Grad-CAM for image classification (PyTorch)

This example applies Grad-CAM to image classification with a PyTorch model. If you use this explainer, please cite “Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization” by Selvaraju et al., https://arxiv.org/abs/1610.02391.

[1]:
# This default renderer is used for sphinx docs only. Please delete this cell in IPython.
import plotly.io as pio
pio.renderers.default = "png"
[2]:
import json
import torch
from torchvision import models, transforms
from PIL import Image as PilImage

from omnixai.data.image import Image
from omnixai.explainers.vision.specific.gradcam.pytorch.gradcam import GradCAM

We recommend using Image to represent a batch of images. An Image can be constructed from a numpy array or a Pillow image. The following code loads one test image and the ImageNet class names.

[3]:
# Load the test image
img = Image(PilImage.open('../data/images/camera.jpg').convert('RGB'))
# Load the class names
with open('../data/images/imagenet_class_index.json', 'r') as read_file:
    class_idx = json.load(read_file)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
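
The list comprehension above relies on the layout of the class index file. As a minimal sketch, assuming the file maps each class index (stored as a string, since JSON object keys are always strings) to a pair of WordNet ID and human-readable label, the lookup works as follows. The three-entry excerpt is illustrative and stands in for the full file.

```python
import json

# Illustrative excerpt, assumed to mirror the layout of imagenet_class_index.json:
# {"<class index>": ["<WordNet ID>", "<label>"], ...}
raw = """
{
  "0": ["n01440764", "tench"],
  "1": ["n01443537", "goldfish"],
  "2": ["n01484850", "great_white_shark"]
}
"""
class_idx = json.loads(raw)

# JSON keys are strings, so the integer index k must be converted with str(k).
# Element [1] of each entry is the readable label; element [0] is the WordNet ID.
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

print(idx2label)
```

With this layout, idx2label[k] gives the label for the model's k-th output logit, which is how the plot at the end of the example names the explained class.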

The model considered here is a ResNet model pretrained on ImageNet. The preprocessing function takes an Image instance as its input and outputs the processed features that the ML model consumes. In this example, the Image object is converted into a torch tensor via the defined transform.

[4]:
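# A ResNet-50 model pretrained on ImageNet
# (newer torchvision releases prefer the weights= argument over pretrained=True)
model = models.resnet50(pretrained=True)
# The preprocessing transform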
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])
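
To make the numeric effect of the transform concrete, here is a dependency-free sketch of what ToTensor followed by Normalize does to a single 8-bit RGB pixel. The helper name normalize_pixel is hypothetical and exists only for illustration; the mean and std values are the ones used in the transform above.

```python
# Per-channel statistics copied from the Normalize step above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Hypothetical helper: map one 8-bit RGB pixel to the values the model sees."""
    # ToTensor rescales 8-bit integers from [0, 255] to floats in [0, 1].
    scaled = [channel / 255.0 for channel in rgb]
    # Normalize subtracts the channel mean and divides by the channel std.
    return [(s - m) / d for s, m, d in zip(scaled, MEAN, STD)]


black = normalize_pixel((0, 0, 0))
white = normalize_pixel((255, 255, 255))
print(black)
print(white)
```

Because the image is resized, center-cropped to 224 × 224, and then stacked, the tensor returned by preprocess for a batch of N images has shape (N, 3, 224, 224).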

To initialize GradCAM, we need to set the following parameters:

  • model: The ML model to explain, e.g., tf.keras.Model or torch.nn.Module.

  • preprocess_function: The preprocessing function converting the raw data (an Image instance) into the inputs of model.

  • target_layer: The target convolutional layer for explanation, which can be tf.keras.layers.Layer or torch.nn.Module.

  • mode: The task type, e.g., “classification” or “regression”. The code below does not pass mode, so the explainer falls back to its default.

[5]:
explainer = GradCAM(
    model=model,
    target_layer=model.layer4[-1],
    preprocess_function=preprocess
)
# Explain the top label
explanations = explainer.explain(img)
explanations.ipython_plot(index=0, class_names=idx2label)
[Figure: Grad-CAM heatmap for the top predicted label of the test image]
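
For intuition about why the explainer needs a target_layer, the core computation described in the Grad-CAM paper can be sketched without any framework. This is a toy illustration on hand-written lists, not OmniXAI's implementation: the function name grad_cam and the tiny 2 × 2 inputs are hypothetical. In practice the activations are the target layer's feature maps and the gradients are the derivatives of the class score with respect to those maps.

```python
def grad_cam(activations, gradients):
    """Toy Grad-CAM for one image and one class.

    activations, gradients: nested lists shaped [K][H][W], where K is the
    number of feature maps produced by the target layer.
    """
    num_maps = len(activations)
    height = len(activations[0])
    width = len(activations[0][0])

    # Step 1: one weight per feature map, obtained by global-average-pooling
    # the gradients of the class score over the spatial positions.
    weights = [
        sum(sum(row) for row in gradients[k]) / (height * width)
        for k in range(num_maps)
    ]

    # Step 2: weighted sum of the feature maps, followed by ReLU so that only
    # positions with a positive influence on the class score remain.
    cam = [
        [
            max(0.0, sum(weights[k] * activations[k][i][j] for k in range(num_maps)))
            for j in range(width)
        ]
        for i in range(height)
    ]
    return weights, cam


# Two hypothetical 2 x 2 feature maps and their gradients.
activations = [
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.0, 3.0], [1.0, 0.0]],
]
gradients = [
    [[0.5, 0.5], [0.5, 0.5]],
    [[-1.0, -1.0], [-1.0, -1.0]],
]

weights, cam = grad_cam(activations, gradients)
print(weights)  # [0.5, -1.0]
print(cam)      # [[0.5, 0.0], [0.0, 1.0]]
```

The second feature map receives a negative weight, so the positions where only it is active are zeroed by the ReLU. The resulting coarse map has the spatial size of the target layer's output and is upsampled to the input resolution for display, which is why a late convolutional layer such as model.layer4[-1] is a common choice.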