{ "cells": [ { "cell_type": "markdown", "id": "45daa2a3", "metadata": {}, "source": [ "### Integrated gradients for vision-language tasks" ] }, { "cell_type": "markdown", "id": "c11a3120", "metadata": {}, "source": [ "This is an example of applying integrated gradients to vision-language tasks. The model comes from the Python library ``lavis``, which will be released soon." ] }, { "cell_type": "code", "execution_count": 1, "id": "b8cc791d", "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "c7530beb", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import unittest\n", "import numpy as np\n", "from PIL import Image as PilImage\n", "from omnixai.data.text import Text\n", "from omnixai.data.image import Image\n", "from omnixai.data.multi_inputs import MultiInputs\n", "from omnixai.preprocessing.image import Resize\n", "from omnixai.explainers.vision_language.specific.ig import IntegratedGradient\n", "\n", "from lavis.models import BlipITM\n", "from lavis.processors import load_processor" ] }, { "cell_type": "markdown", "id": "660e0e32", "metadata": {}, "source": [ "The data class `Image` represents a batch of images, which can be constructed from a NumPy array or a Pillow image. 
The data class `Text` represents a batch of texts/sentences. For vision language tasks, we use `MultiInputs` as the input by setting the attributes `image` and `text`." ] }, { "cell_type": "code", "execution_count": 3, "id": "9b7b95e0", "metadata": {}, "outputs": [], "source": [ "image = Resize(size=480).transform(\n", " Image(PilImage.open(\"../data/images/girl_dog.jpg\").convert(\"RGB\")))\n", "text = Text(\"A girl playing with her dog on the beach\")\n", "inputs = MultiInputs(image=image, text=text)" ] }, { "cell_type": "markdown", "id": "16e2f8ba", "metadata": {}, "source": [ "We load a BLIP model as an example:" ] }, { "cell_type": "code", "execution_count": 4, "id": "1755116d", "metadata": {}, "outputs": [], "source": [ "pretrained_path = \\\n", " \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth\"\n", "model = BlipITM(pretrained=pretrained_path, vit=\"base\")\n", "image_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n", "text_processor = load_processor(\"blip_caption\")\n", "tokenizer = BlipITM.init_tokenizer()" ] }, { "cell_type": "markdown", "id": "2720ad73", "metadata": {}, "source": [ "We then define the preprocessing function that converts a `MultiInputs` instance into the inputs of the BLIP model:" ] }, { "cell_type": "code", "execution_count": 5, "id": "9adaaa87", "metadata": {}, "outputs": [], "source": [ "def preprocess(x: MultiInputs):\n", " images = torch.stack([image_processor(z.to_pil()) for z in x.image])\n", " texts = [text_processor(z) for z in x.text.values]\n", " return images, texts" ] }, { "cell_type": "markdown", "id": "5f4a9400", "metadata": {}, "source": [ "To initialize `IntegratedGradient` for vision language tasks, we need to set the following parameters:\n", "\n", " - `model`: The ML model to explain, e.g., `torch.nn.Module`.\n", " - `preprocess_function`: The preprocessing function converting the raw data (a `MultiInputs` instance) into the inputs of 
`model`.\n", " - `embedding_layer`: The embedding layer with respect to which attributions are computed, e.g., a `torch.nn.Module` such as the text encoder's word-embedding layer.\n", " - `tokenizer`: The tokenizer for processing text inputs.\n", " - `loss_function`: The loss function used to compute gradients w.r.t. the embedding layer, e.g., the sum of one column of the model outputs." ] }, { "cell_type": "code", "execution_count": 6, "id": "074b0e51", "metadata": {}, "outputs": [], "source": [ "explainer = IntegratedGradient(\n", " model=model,\n", " embedding_layer=model.text_encoder.embeddings.word_embeddings,\n", " preprocess_function=preprocess,\n", " tokenizer=tokenizer,\n", " loss_function=lambda outputs: outputs[:, 1].sum()\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "44f7b410", "metadata": {}, "outputs": [ { "data": { "text/html": [ "