{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### L2X (learning to explain) on MNIST" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of the L2X explainer for image classification. Unlike gradient-based methods, L2X trains a separate explanation model. The advantage of L2X is that, once the explanation model is trained, it generates explanations quickly. The disadvantage is that the quality of the explanations depends heavily on the trained explanation model, which can be affected by multiple factors, e.g., the network architecture of the explanation model and the training parameters.\n", "\n", "For image data, we implement the default explanation model in `omnixai.explainers.vision.agnostic.l2x`. Other explanation models can be implemented by following the same interface; please refer to the docs for more details. If you use this explainer, please cite the original work: \"Learning to Explain: An Information-Theoretic Perspective on Model Interpretation, Jianbo Chen, Le Song, Martin J. Wainwright, Michael I. Jordan, https://arxiv.org/abs/1802.07814\"." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for Sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "\n", "from omnixai.data.image import Image\n", "from omnixai.explainers.vision import L2XImage" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is a simple convolutional neural network with two convolutional layers and one dense hidden layer."
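, "\n", "\n", "The input size of the model's first dense layer (320) follows from the shape arithmetic of the convolutional layers: each 5×5 convolution without padding shrinks the side length by 4 and each 2×2 max-pooling halves it, so 28 → 24 → 12 → 8 → 4 and the flattened output has 20 × 4 × 4 = 320 features. The sketch below checks this with the standard output-size formula and with a dummy forward pass through a copy of the convolutional stack; the helper `conv_output_size` is illustrative and not part of OmniXAI or PyTorch.\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "\n", "def conv_output_size(size, kernel, stride=1, padding=0):\n", "    # Output side length of Conv2d / MaxPool2d with dilation 1\n", "    return (size + 2 * padding - kernel) // stride + 1\n", "\n", "side = 28\n", "side = conv_output_size(side, 5)            # 5x5 conv: 28 -> 24\n", "side = conv_output_size(side, 2, stride=2)  # 2x2 max-pool: 24 -> 12\n", "side = conv_output_size(side, 5)            # 5x5 conv: 12 -> 8\n", "side = conv_output_size(side, 2, stride=2)  # 2x2 max-pool: 8 -> 4\n", "print(20 * side * side)  # 320\n", "\n", "# Cross-check with the same layer stack as MNISTNet.conv_layers\n", "conv_layers = nn.Sequential(\n", "    nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(),\n", "    nn.Conv2d(10, 20, kernel_size=5), nn.Dropout(), nn.MaxPool2d(2), nn.ReLU(),\n", ")\n", "out = conv_layers(torch.zeros(1, 1, 28, 28))\n", "print(torch.flatten(out, 1).shape)  # torch.Size([1, 320])\n", "```"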
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# A minimal Dataset that pairs preprocessed image tensors with their labels\n", "class InputData(Dataset):\n", "\n", "    def __init__(self, images, labels):\n", "        self.images = images\n", "        self.labels = labels\n", "\n", "    def __len__(self):\n", "        return self.images.shape[0]\n", "\n", "    def __getitem__(self, index):\n", "        return self.images[index], self.labels[index]\n", "\n", "\n", "class MNISTNet(nn.Module):\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Two convolutional blocks (conv, max-pool, ReLU)\n", "        self.conv_layers = nn.Sequential(\n", "            nn.Conv2d(1, 10, kernel_size=5),\n", "            nn.MaxPool2d(2),\n", "            nn.ReLU(),\n", "            nn.Conv2d(10, 20, kernel_size=5),\n", "            nn.Dropout(),\n", "            nn.MaxPool2d(2),\n", "            nn.ReLU(),\n", "        )\n", "        # Dense hidden layer (50 units) followed by the output layer (10 class logits)\n", "        self.fc_layers = nn.Sequential(\n", "            nn.Linear(320, 50),\n", "            nn.ReLU(),\n", "            nn.Dropout(),\n", "            nn.Linear(50, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.conv_layers(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.fc_layers(x)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code loads the training and test datasets. We recommend using `Image` to represent a batch of images. `Image` can be constructed from a NumPy array or a Pillow image. In this example, `Image` is constructed from a NumPy array containing a batch of digit images."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load the training and test datasets\n", "train_data = torchvision.datasets.MNIST(root='../data', train=True, download=True)\n", "test_data = torchvision.datasets.MNIST(root='../data', train=False, download=True)\n", "train_data.data = train_data.data.numpy()\n", "test_data.data = test_data.data.numpy()\n", "\n", "class_names = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)\n", "# Use `Image` objects to represent the training and test datasets\n", "train_imgs, train_labels = Image(train_data.data, batched=True), train_data.targets\n", "test_imgs, test_labels = Image(test_data.data, batched=True), test_data.targets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The prediction function takes an `Image` instance as input and, for classification tasks, outputs the class probabilities or logits. The CNN defined above ends with a linear layer, so this prediction function returns logits." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "# The CNN model\n", "model = MNISTNet().to(device)\n", "# The preprocessing function: converts each image to a float tensor and stacks them into a batch\n", "transform = transforms.Compose([transforms.ToTensor()])\n", "preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])\n", "# The prediction function\n", "predict_function = lambda ims: model(preprocess(ims).to(device)).detach().cpu().numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now train the CNN model defined above and evaluate its performance."
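, "\n", "\n", "The evaluation part of the next cell counts correct predictions class by class in a Python loop. As an aside, the same per-class accuracy can be computed in a vectorized way with NumPy; the sketch below assumes the true and predicted labels are available as integer arrays, and the function name `per_class_accuracy` is hypothetical.\n", "\n", "```python\n", "import numpy as np\n", "\n", "def per_class_accuracy(labels, predictions, num_classes=10):\n", "    labels = np.asarray(labels)\n", "    predictions = np.asarray(predictions)\n", "    # Number of samples whose true label is each class\n", "    totals = np.bincount(labels, minlength=num_classes)\n", "    # Number of correctly classified samples for each true class\n", "    correct = np.bincount(labels[labels == predictions], minlength=num_classes)\n", "    # Accuracy in percent; classes without any samples get NaN\n", "    return np.where(totals > 0, 100.0 * correct / np.maximum(totals, 1), np.nan)\n", "\n", "acc = per_class_accuracy([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=3)\n", "print(acc)  # accuracies of 50, 100 and 0 percent for classes 0, 1 and 2\n", "```\n", "\n", "With the test set in this notebook, `labels` would be the test labels and `predictions` the arg-max of the model outputs, each converted to a NumPy array."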
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for class 0 is: 99.6 %\n", "Accuracy for class 1 is: 99.7 %\n", "Accuracy for class 2 is: 98.8 %\n", "Accuracy for class 3 is: 98.5 %\n", "Accuracy for class 4 is: 99.4 %\n", "Accuracy for class 5 is: 99.1 %\n", "Accuracy for class 6 is: 98.4 %\n", "Accuracy for class 7 is: 99.4 %\n", "Accuracy for class 8 is: 98.7 %\n", "Accuracy for class 9 is: 96.4 %\n" ] } ], "source": [ "learning_rate = 1e-3\n", "batch_size = 32\n", "num_epochs = 5\n", "\n", "train_loader = DataLoader(\n", "    dataset=InputData(preprocess(train_imgs), train_labels),\n", "    batch_size=batch_size,\n", "    shuffle=True\n", ")\n", "test_loader = DataLoader(\n", "    dataset=InputData(preprocess(test_imgs), test_labels),\n", "    batch_size=batch_size,\n", "    shuffle=False\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "loss_func = nn.CrossEntropyLoss()\n", "\n", "# Train the CNN\n", "model.train()\n", "for epoch in range(num_epochs):\n", "    for x, y in train_loader:\n", "        x, y = x.to(device), y.to(device)\n", "        loss = loss_func(model(x), y)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "# Count correct and total predictions per class on the test set\n", "correct_pred = {name: 0 for name in class_names}\n", "total_pred = {name: 0 for name in class_names}\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    for x, y in test_loader:\n", "        images, labels = x.to(device), y.to(device)\n", "        outputs = model(images)\n", "        _, predictions = torch.max(outputs, 1)\n", "        for label, prediction in zip(labels, predictions):\n", "            if label == prediction:\n", "                correct_pred[class_names[label]] += 1\n", "            total_pred[class_names[label]] += 1\n", "\n", "for name, correct_count in correct_pred.items():\n", "    accuracy = 100 * float(correct_count) / total_pred[name]\n", "    print(\"Accuracy for class {} is: {:.1f} %\".format(name, accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`L2XImage` accepts the following parameters; this example sets only `training_data` and `predict_function` and keeps the defaults for the others:\n", "\n", "- `training_data`: The data used to train the explainer. This should be the same dataset used to train the machine learning model being explained.\n", "- `predict_function`: The prediction function of the model to explain. For classification, its outputs are the class probabilities (or logits, as in this example); for regression, its outputs are the predicted values.\n", "- `mode`: The task type, e.g., `classification` or `regression`.\n", "- `selection_model`: A PyTorch model class for estimating P(S|X) in L2X, i.e., the distribution over the subset S of features selected for an input X. If `selection_model = None`, the default model `DefaultSelectionModel` is used.\n", "- `prediction_model`: A PyTorch model class for estimating Q(X_S) in L2X, i.e., the model that predicts the output from the selected features X_S only. If `prediction_model = None`, the default model `DefaultPredictionModel` is used." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " |████████████████████████████████████████| 100.0% Complete, Loss 0.2665\n", "L2X prediction model accuracy: 0.8901166666666667\n" ] } ], "source": [ "explainer = L2XImage(\n", "    training_data=train_imgs,\n", "    predict_function=predict_function,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We call `explainer.explain` to generate explanations for this classification task. `ipython_plot` plots the generated explanations in IPython. The parameter `index` selects which instance to plot, e.g., `index = 0` plots the first instance in `test_imgs[0:5]`, and the cell below uses `index=1` to plot the second one."
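, "\n", "\n", "For intuition about what the selection model learns: in the L2X paper, the selection model assigns a score to each feature (for images, e.g., pixels or image patches), and a subset of `k` features is sampled through a continuous Gumbel-softmax relaxation so that the explainer can be trained by gradient descent. The sketch below implements that relaxation in PyTorch as described in the paper; it is an illustration, not OmniXAI's implementation, and the function names, `k`, and the temperature `tau` are assumptions chosen for the example.\n", "\n", "```python\n", "import torch\n", "\n", "def relaxed_k_subset(logits, k, tau=0.5):\n", "    # logits: (batch, d) feature scores from a selection model\n", "    batch, d = logits.shape\n", "    # k independent Gumbel(0, 1) samples per feature: (batch, k, d)\n", "    u = torch.rand(batch, k, d).clamp(min=1e-10)\n", "    gumbel = -torch.log(-torch.log(u))\n", "    # Each of the k draws is a relaxed one-hot vector over the d features\n", "    draws = torch.softmax((logits.unsqueeze(1) + gumbel) / tau, dim=-1)\n", "    # Element-wise max over the k draws gives a soft mask in [0, 1]\n", "    return draws.max(dim=1).values\n", "\n", "def hard_k_subset(logits, k):\n", "    # At explanation time, keep exactly the k highest-scoring features\n", "    idx = logits.topk(k, dim=-1).indices\n", "    return torch.zeros_like(logits).scatter_(-1, idx, 1.0)\n", "\n", "torch.manual_seed(0)\n", "scores = torch.randn(2, 784)  # e.g., one score per pixel of a 28x28 image\n", "soft_mask = relaxed_k_subset(scores, k=16)\n", "hard_mask = hard_k_subset(scores, k=16)\n", "print(soft_mask.shape, hard_mask.sum(dim=-1))  # torch.Size([2, 784]) tensor([16., 16.])\n", "```\n", "\n", "In the paper's formulation, the soft mask multiplies the input during training so that the approximating model Q sees only the (softly) selected features, while the hard top-`k` mask is used to produce the final explanation."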
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n" ] }, { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explanations = explainer.explain(test_imgs[0:5])\n", "explanations.ipython_plot(index=1)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }