Feature map visualization (PyTorch)

OmniXAI provides a simple explainer for visualizing feature maps of a specified layer.

# This default renderer is used for sphinx docs only. Please delete this cell in IPython.
import plotly.io as pio
pio.renderers.default = "png"
import os
import torch
from torchvision import models, transforms
from PIL import Image as PilImage

from omnixai.data.image import Image
from omnixai.explainers.vision.specific.feature_visualization.visualizer import \
    FeatureMapVisualizer

Let’s load a test image:

img = Image(PilImage.open("../data/images/dog_cat.png").convert("RGB"))

Here we choose the ResNet50 model for demonstration (you may test other CNN models, e.g., VGG). The preprocessing function converts the input images (an Image instance) into a batched tensor that the model accepts.

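model = models.resnet50(pretrained=True)
transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]; required before Normalize
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Convert every image in the batch to a tensor and stack them into one (N, C, H, W) tensor.
preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])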
Note: torchvision 0.13 and later prints a deprecation warning for this cell. The warning states that the 'pretrained' parameter is deprecated since 0.13 and will be removed in 0.15, and that 'weights' should be used instead. Passing pretrained=True is currently equivalent to weights=ResNet50_Weights.IMAGENET1K_V1; use weights=ResNet50_Weights.DEFAULT to get the most up-to-date weights.
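To make the preprocessing step concrete, here is a minimal sketch of what the tensor conversion, normalization and stacking do to a batch of images. It uses plain torch on synthetic data rather than torchvision and OmniXAI; the helper name `normalize` and the random images are assumptions made for illustration only.

```python
import torch

# ImageNet channel statistics, the same values passed to Normalize above.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize(image_uint8: torch.Tensor) -> torch.Tensor:
    """Scale an (H, W, 3) uint8 image to [0, 1], move channels first,
    then standardize each channel with the ImageNet mean and std."""
    chw = image_uint8.permute(2, 0, 1).float() / 255.0
    return (chw - MEAN) / STD


# Two synthetic images standing in for real photos (hypothetical data).
images = [torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8) for _ in range(2)]

# Stacking per-image tensors yields the (N, C, H, W) batch a CNN expects.
batch = torch.stack([normalize(im) for im in images])
print(batch.shape)
```

Because the per-image tensors are stacked, all images in a batch must share the same height and width after preprocessing.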

We then create a FeatureMapVisualizer explainer by setting the model, the target layer to analyze and the preprocessing function.

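# target_layer is an illustrative choice here; substitute the layer you want to inspect.
explainer = FeatureMapVisualizer(
    model=model, target_layer=model.layer2[-1], preprocess_function=preprocess)
explanations = explainer.explain(img)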
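For intuition about what a feature map is, the sketch below captures the output of one layer with a standard PyTorch forward hook and rescales each channel for display. This is not OmniXAI's implementation, only a rough illustration of the idea; the small network `net`, the random input and the hook name `save_activation` are hypothetical stand-ins chosen so the example runs without downloading ResNet50 weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A tiny stand-in CNN (hypothetical; used instead of ResNet50 to avoid a download).
net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
).eval()

captured = {}


def save_activation(module, inputs, output):
    # Keep a detached copy of the layer output, i.e. its feature maps.
    captured["maps"] = output.detach()


target_layer = net[3]  # the second convolution
handle = target_layer.register_forward_hook(save_activation)
with torch.no_grad():
    net(torch.rand(1, 3, 32, 32))
handle.remove()  # remove the hook once the activations are captured

feature_maps = captured["maps"][0]  # shape: (channels, height, width)
print(feature_maps.shape)

# Rescale each channel to [0, 1] so it can be rendered as a grayscale image.
lo = feature_maps.amin(dim=(1, 2), keepdim=True)
hi = feature_maps.amax(dim=(1, 2), keepdim=True)
channel_images = (feature_maps - lo) / (hi - lo + 1e-8)
```

Each slice `channel_images[i]` is one feature map. Deeper layers produce more channels at lower spatial resolution, which is why the choice of target layer changes what the visualization shows.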