{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### LIME for image classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of the LIME explainer for image data. This explainer only supports image classification tasks. If you use this explainer, please cite the original work: https://github.com/marcotcr/lime." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is only needed to build the Sphinx docs. Delete this cell when running the notebook in IPython/Jupyter.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import unittest\n", "import torch\n", "from torchvision import models, transforms\n", "from PIL import Image as PilImage\n", "from omnixai.data.image import Image\n", "from omnixai.explainers.vision import LimeImage" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We recommend using `Image` to represent a batch of images. `Image` can be constructed from a numpy array or a Pillow image. The following code loads a test image and the ImageNet class names." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the test image\n", "img = Image(PilImage.open('../data/images/dog_cat.png').convert('RGB'))\n", "# Load the ImageNet class names (index -> label)\n", "with open('../data/images/imagenet_class_index.json', 'r') as read_file:\n", "    class_idx = json.load(read_file)\n", "    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model applied here is an Inception v3 model pretrained on ImageNet."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "model = models.inception_v3(pretrained=True)\n", "transform = transforms.Compose([\n", "    transforms.Resize((256, 256)),\n", "    transforms.CenterCrop(224),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the top-5 predicted labels for the test image." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "((0.93592954, 239, 'Bernese_mountain_dog'), (0.038448237, 241, 'EntleBucher'), (0.023756476, 240, 'Appenzeller'), (0.0018181928, 238, 'Greater_Swiss_Mountain_dog'), (9.113302e-06, 214, 'Gordon_setter'))\n" ] } ], "source": [ "model.eval()\n", "input_img = transform(img.to_pil()).unsqueeze(dim=0)\n", "probs_top_5 = torch.nn.functional.softmax(model(input_img), dim=1).topk(5)\n", "r = tuple((p, c, idx2label[c]) for p, c in\n", "    zip(probs_top_5[0][0].detach().numpy(), probs_top_5[1][0].detach().numpy()))\n", "print(r)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize `LimeImage`, we need to set the following parameter:\n", "\n", " - `predict_function`: The prediction function of the machine learning model to explain. For classification, `predict_function` should return the class probabilities." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f390f60809a5438ea0e11eebec78982f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000 [00:00