Integrated-gradient for vision-language tasks

This example demonstrates the integrated-gradient explainer on a vision-language task. It relies on the Python library lavis, which will be released soon.

import plotly.io as pio
pio.renderers.default = "png"
import os
import torch
import unittest
import numpy as np
from PIL import Image as PilImage
from omnixai.data.text import Text
from omnixai.data.image import Image
from omnixai.data.multi_inputs import MultiInputs
from omnixai.preprocessing.image import Resize
from omnixai.explainers.vision_language.specific.ig import IntegratedGradient

from lavis.models import BlipITM
from lavis.processors import load_processor

The data class Image represents a batch of images and can be constructed from a NumPy array or a Pillow image. The data class Text represents a batch of texts or sentences. For vision-language tasks, the input is a MultiInputs instance whose attributes image and text hold the two modalities.

image = Resize(size=480).transform(
text = Text("A girl playing with her dog on the beach")
inputs = MultiInputs(image=image, text=text)

We load a BLIP model as an example:

pretrained_path = \
model = BlipITM(pretrained=pretrained_path, vit="base")
image_processor = load_processor("blip_image_eval").build(image_size=384)
text_processor = load_processor("blip_caption")
tokenizer = BlipITM.init_tokenizer()

We then define the preprocessing function that converts a MultiInputs instance into the inputs of the BLIP model:

def preprocess(x: MultiInputs):
    images = torch.stack([image_processor(z.to_pil()) for z in x.image])
    texts = [text_processor(z) for z in x.text.values]
    return images, texts

To initialize IntegratedGradient for vision language tasks, we need to set the following parameters:

  • model: The ML model to explain, e.g., torch.nn.Module.

  • preprocess_function: The preprocessing function converting the raw data (a MultiInputs instance) into the inputs of model.

  • target_layer: The target layer for explanation, e.g., torch.nn.Module.

  • tokenizer: The tokenizer for processing text inputs.

  • loss_function: The loss function that maps the model outputs to a scalar, whose gradients are computed with respect to the target layer.
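
For intuition about what the explainer computes from these parameters, here is a minimal, library-independent sketch of integrated gradients on a toy function. It is not OmniXAI's implementation: toy_model, toy_gradient, and integrated_gradients are hypothetical names, the baseline is assumed to be all zeros, and the path integral is approximated with a midpoint Riemann sum. In the explainer, the scalar being differentiated would be the value returned by loss_function rather than toy_model.

```python
from typing import Callable, List, Sequence

WEIGHTS = [0.5, -1.0, 2.0]  # hypothetical toy parameters


def toy_model(x: Sequence[float]) -> float:
    """Scalar toy 'loss': F(x) = sum_i w_i * x_i**2."""
    return sum(w * v * v for w, v in zip(WEIGHTS, x))


def toy_gradient(x: Sequence[float]) -> List[float]:
    """Analytic gradient of toy_model: dF/dx_i = 2 * w_i * x_i."""
    return [2.0 * w * v for w, v in zip(WEIGHTS, x)]


def integrated_gradients(
    x: Sequence[float],
    baseline: Sequence[float],
    gradient_fn: Callable[[Sequence[float]], List[float]],
    steps: int = 50,
) -> List[float]:
    """Approximate IG_i = (x_i - b_i) * integral_0^1 dF/dx_i(b + a * (x - b)) da."""
    total = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        # Point on the straight line from the baseline to the input.
        point = [b + alpha * (v - b) for v, b in zip(x, baseline)]
        grads = gradient_fn(point)
        total = [t + g for t, g in zip(total, grads)]
    # Average the gradients along the path and scale by (input - baseline).
    return [(v - b) * t / steps for v, b, t in zip(x, baseline, total)]


x = [1.0, 2.0, -1.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline, toy_gradient)
```

The attributions sum to toy_model(x) - toy_model(baseline), which is the completeness property of integrated gradients and a convenient sanity check for any implementation.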

explainer = IntegratedGradient(
    loss_function=lambda outputs: outputs[:, 1].sum()
explanations = explainer.explain(inputs)

Instance 0
[CLS] a girl playing with her dog on the beach [SEP]
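
The loss_function passed to the explainer reduces the model outputs to a single scalar that can be differentiated. The sketch below shows what outputs[:, 1].sum() computes. It assumes that outputs has shape (batch, 2) and that column 1 holds the image-text "match" logit, which is an assumption about the BlipITM output layout that this example does not state. A NumPy array with made-up values stands in for the model's tensor.

```python
import numpy as np

# Same reduction as the lambda passed to IntegratedGradient.
loss_function = lambda outputs: outputs[:, 1].sum()

# Hypothetical logits for a batch of two image-text pairs:
# column 0 = "no match", column 1 = "match" (assumed layout).
outputs = np.array([
    [0.25, 1.5],
    [-0.5, 0.5],
])

# Selects column 1 of every row and sums over the batch.
loss = loss_function(outputs)
```

Summing over the batch keeps the result a scalar while leaving the gradient of each instance independent of the others, since each term in the sum depends on only one instance.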