{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### VisionExplainer for image classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class `VisionExplainer` is designed for vision tasks, acting as a factory of the supported vision explainers such as integrated-gradient and Grad-CAM. `VisionExplainer` provides a unified easy-to-use interface for all the supported explainers. In practice, we recommend applying `VisionExplainer` to generate explanations instead of using a specific explainer in the package `omnixai.explainers.vision`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import json\n", "import torch\n", "import numpy as np\n", "from torchvision import models, transforms\n", "from PIL import Image as PilImage\n", "\n", "from omnixai.preprocessing.image import Resize\n", "from omnixai.data.image import Image\n", "from omnixai.explainers.vision import VisionExplainer\n", "from omnixai.visualization.dashboard import Dashboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we consider an image classification task. We recommend using `Image` to represent a batch of images. `Image` can be constructed from a numpy array or a Pillow image. The following code loads three images and resizes them to (256, 256). It then constructs an `Image` object to store the three images. By default, the shape of an `Image` object has the format (batch_size, height, width, channel), e.g., (3, 256, 256, 3)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3, 256, 256, 3)\n" ] } ], "source": [ "# Load images for testing\n", "img_1 = Resize((256, 256)).transform(Image(PilImage.open('data/images/dog_cat.png').convert('RGB')))\n", "img_2 = Resize((256, 256)).transform(Image(PilImage.open('data/images/dog.jpg').convert('RGB')))\n", "img_3 = Resize((256, 256)).transform(Image(PilImage.open('data/images/camera.jpg').convert('RGB')))\n", "img = Image(\n", " data=np.concatenate([\n", " img_1.to_numpy(), img_2.to_numpy(), img_3.to_numpy()]),\n", " batched=True\n", ")\n", "print(img.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For visulization, the class names corresponding to the labels are also loaded." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "with open('data/images/imagenet_class_index.json', 'r') as read_file:\n", " class_idx = json.load(read_file)\n", " idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model considered here is a ResNet model pretrained on ImageNet. 
Similar to `TabularExplainer`, to initialize `VisionExplainer`, we need to set the following parameters:\n", "\n", " - `explainers`: The names of the explainers to apply, e.g., [\"gradcam\", \"lime\", \"ig\", \"ce\"].\n", " - `model`: The ML model to explain, e.g., a scikit-learn model, a tensorflow model, a pytorch model, or a black-box prediction function.\n", " - `preprocess`: The preprocessing function converting the raw data (an `Image` instance) into the inputs of `model`.\n", " - `postprocess` (optional): The postprocessing function transforming the outputs of `model` to a user-specified form, e.g., the predicted probability for each class.\n", " - `mode`: The task type, e.g., \"classification\" or \"regression\".\n", " \n", "The preprocessing function takes an `Image` instance as its input and outputs the processed features that the ML model consumes. In this example, the `Image` object is first converted into a torch tensor via the defined `transform` and then moved to the target device.\n", "\n", "The postprocessing function is a simple softmax function transforming the output logits into class probabilities." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "# The preprocessing function\n", "transform = transforms.Compose([\n", "    transforms.Resize(256),\n", "    transforms.CenterCrop(224),\n", "    transforms.ToTensor(),\n", "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", "])\n", "preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims]).to(device)\n", "# A ResNet model to explain\n", "model = models.resnet50(pretrained=True).to(device)\n", "# The postprocessing function\n", "postprocess = lambda logits: torch.nn.functional.softmax(logits, dim=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now create a `VisionExplainer`; the selected explainers include Grad-CAM, LIME, integrated-gradient, and counterfactual. `params` in `VisionExplainer` allows setting parameters for each applied explainer. For example, \"target_layer\" (a convolutional layer for analysis) for Grad-CAM is set to the last layer of `model.layer4`.\n", "\n", "There is no \"global explanation\" for `VisionExplainer` currently. One can simply call `explainer.explain` to generate local explanations for vision tasks." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0cdf2d333b21470fb0f3661a404fee57", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1000 [00:00