{ "cells": [ { "cell_type": "markdown", "id": "45daa2a3", "metadata": {}, "source": [ "### Integrated-gradient for visual language tasks" ] }, { "cell_type": "markdown", "id": "c11a3120", "metadata": {}, "source": [ "This is an example of integrated-gradient on vision language tasks. The python library ``lavis`` will be released soon." ] }, { "cell_type": "code", "execution_count": 1, "id": "b8cc791d", "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "c7530beb", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ywz/anaconda3/envs/conda_env_py3.9.12/lib/python3.9/site-packages/torch/cuda/__init__.py:83: UserWarning:\n", "\n", "CUDA initialization: The NVIDIA driver on your system is too old (found version 10010). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n", "\n" ] } ], "source": [ "import os\n", "import torch\n", "import unittest\n", "import numpy as np\n", "from PIL import Image as PilImage\n", "from omnixai.data.text import Text\n", "from omnixai.data.image import Image\n", "from omnixai.data.multi_inputs import MultiInputs\n", "from omnixai.preprocessing.image import Resize\n", "from omnixai.explainers.vision_language.specific.ig import IntegratedGradient\n", "\n", "from lavis.models import BlipITM\n", "from lavis.processors import load_processor" ] }, { "cell_type": "markdown", "id": "660e0e32", "metadata": {}, "source": [ "The data class `Image` represents a batch of images, which can be constructed from a numpy array or a Pillow image. The data class `Text` represents a batch of texts/sentences. For vision language tasks, we use `MultiInputs` as the input by setting the attributes `image` and `text`." 
] }, { "cell_type": "code", "execution_count": 3, "id": "9b7b95e0", "metadata": {}, "outputs": [], "source": [ "image = Resize(size=480).transform(\n", " Image(PilImage.open(\"../data/images/girl_dog.jpg\").convert(\"RGB\")))\n", "text = Text(\"A girl playing with her dog on the beach\")\n", "inputs = MultiInputs(image=image, text=text)" ] }, { "cell_type": "markdown", "id": "16e2f8ba", "metadata": {}, "source": [ "We load a BLIP model as an example:" ] }, { "cell_type": "code", "execution_count": 4, "id": "1755116d", "metadata": {}, "outputs": [], "source": [ "pretrained_path = \\\n", " \"https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth\"\n", "model = BlipITM(pretrained=pretrained_path, vit=\"base\")\n", "image_processor = load_processor(\"blip_image_eval\").build(image_size=384)\n", "text_processor = load_processor(\"blip_caption\")\n", "tokenizer = BlipITM.init_tokenizer()" ] }, { "cell_type": "markdown", "id": "2720ad73", "metadata": {}, "source": [ "We then define the preprocessing function that converts a `MultiInputs` instance into the inputs of the BLIP model:" ] }, { "cell_type": "code", "execution_count": 5, "id": "9adaaa87", "metadata": {}, "outputs": [], "source": [ "def preprocess(x: MultiInputs):\n", " images = torch.stack([image_processor(z.to_pil()) for z in x.image])\n", " texts = [text_processor(z) for z in x.text.values]\n", " return images, texts" ] }, { "cell_type": "markdown", "id": "5f4a9400", "metadata": {}, "source": [ "To initialize `IntegratedGradient` for vision language tasks, we need to set the following parameters:\n", "\n", " - `model`: The ML model to explain, e.g., `torch.nn.Module`.\n", " - `preprocess_function`: The preprocessing function converting the raw data (a `MultiInputs` instance) into the inputs of `model`.\n", " - `target_layer`: The target layer for explanation, e.g., `torch.nn.Module`.\n", " - `tokenizer`: The tokenizer for processing text inputs.\n", " - `loss_function`: The loss function used to compute gradients w.r.t the target layer." ] }, { "cell_type": "code", "execution_count": 6, "id": "074b0e51", "metadata": {}, "outputs": [], "source": [ "explainer = IntegratedGradient(\n", " model=model,\n", " embedding_layer=model.text_encoder.embeddings.word_embeddings,\n", " preprocess_function=preprocess,\n", " tokenizer=tokenizer,\n", " loss_function=lambda outputs: outputs[:, 1].sum()\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "id": "44f7b410", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
Instance 0: [CLS] a girl playing with her dog on the beach [SEP]
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explanations = explainer.explain(inputs)\n", "explanations.ipython_plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }