{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Counterfactual explanation on MNIST (PyTorch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of `CounterfactualExplainer` on MNIST with a PyTorch model. `CounterfactualExplainer` is an optimization-based method for generating counterfactual examples. It supports classification tasks only. If you use this explainer, please cite the paper: Sandra Wachter, Brent Mittelstadt, and Chris Russell, \"Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR\" (https://arxiv.org/abs/1711.00399)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for the Sphinx docs only. Please delete this cell when running in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "from torch.utils.data import Dataset\n", "from torch.utils.data import DataLoader\n", "import matplotlib.pyplot as plt\n", "from omnixai.data.image import Image" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is a simple convolutional neural network with two convolutional layers and one dense hidden layer."
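] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first dense layer of `MNISTNet` is `nn.Linear(320, 50)`. The 320 input features follow from the shape of a 28×28 digit after the convolutional block: each 5×5 convolution (PyTorch defaults of stride 1 and no padding) shrinks each spatial side by 4, each `nn.MaxPool2d(2)` halves it, and the second convolution outputs 20 channels. The next cell is a minimal plain-Python sketch of this arithmetic; `conv_out` and `pool_out` are hypothetical helpers written for illustration, not part of PyTorch or OmniXAI." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch: trace the spatial size of one 28x28 MNIST digit through MNISTNet.conv_layers.\n", "# Assumes stride-1 convolutions without padding and 2x2 max pooling with stride 2, as in MNISTNet.\n", "\n", "\n", "def conv_out(size, kernel_size, stride=1, padding=0):\n", "    # Output length of a convolution along one spatial dimension\n", "    return (size + 2 * padding - kernel_size) // stride + 1\n", "\n", "\n", "def pool_out(size, kernel_size=2):\n", "    # Output length of max pooling whose stride equals kernel_size\n", "    return size // kernel_size\n", "\n", "\n", "size = 28  # MNIST digits are 28x28\n", "size = pool_out(conv_out(size, 5))  # Conv2d(1, 10, kernel_size=5): 24, then MaxPool2d(2): 12\n", "size = pool_out(conv_out(size, 5))  # Conv2d(10, 20, kernel_size=5): 8, then MaxPool2d(2): 4\n", "flattened = 20 * size * size  # 20 channels x 4 x 4 after torch.flatten(x, 1)\n", "print(size, flattened)  # prints: 4 320"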
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class InputData(Dataset):\n", "\n", "    def __init__(self, images, labels):\n", "        self.images = images\n", "        self.labels = labels\n", "\n", "    def __len__(self):\n", "        return self.images.shape[0]\n", "\n", "    def __getitem__(self, index):\n", "        return self.images[index], self.labels[index]\n", "\n", "\n", "class MNISTNet(nn.Module):\n", "\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.conv_layers = nn.Sequential(\n", "            nn.Conv2d(1, 10, kernel_size=5),\n", "            nn.MaxPool2d(2),\n", "            nn.ReLU(),\n", "            nn.Conv2d(10, 20, kernel_size=5),\n", "            nn.Dropout(),\n", "            nn.MaxPool2d(2),\n", "            nn.ReLU(),\n", "        )\n", "        self.fc_layers = nn.Sequential(\n", "            nn.Linear(320, 50),\n", "            nn.ReLU(),\n", "            nn.Dropout(),\n", "            nn.Linear(50, 10)\n", "        )\n", "\n", "    def forward(self, x):\n", "        x = self.conv_layers(x)\n", "        x = torch.flatten(x, 1)\n", "        x = self.fc_layers(x)\n", "        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code loads the training and test datasets. We recommend using `Image` to represent a batch of images. `Image` can be constructed from a numpy array or a Pillow image. In this example, `Image` is constructed from a numpy array containing a batch of digit images."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load the training and test datasets\n", "train_data = torchvision.datasets.MNIST(root='../data', train=True, download=True)\n", "test_data = torchvision.datasets.MNIST(root='../data', train=False, download=True)\n", "train_data.data = train_data.data.numpy()\n", "test_data.data = test_data.data.numpy()\n", "\n", "class_names = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)\n", "# Use `Image` objects to represent the training and test datasets\n", "x_train, y_train = Image(train_data.data, batched=True), train_data.targets\n", "x_test, y_test = Image(test_data.data, batched=True), test_data.targets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The preprocessing function takes an `Image` instance as its input and outputs the processed features that the ML model consumes. In this example, the `Image` object is converted into a torch tensor via the defined `transform`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "# Build the CNN model\n", "model = MNISTNet().to(device)\n", "# The preprocessing function\n", "transform = transforms.Compose([transforms.ToTensor()])\n", "preprocess = lambda ims: torch.stack([transform(im.to_pil()) for im in ims])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now train the CNN model defined above and evaluate its performance." 
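] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, `transforms.ToTensor` converts an 8-bit grayscale PIL image with values in [0, 255] into a float tensor of shape (1, H, W) with values in [0.0, 1.0], and `torch.stack` then adds a leading batch dimension, so `preprocess` returns a tensor of shape (N, 1, 28, 28) for N digits. This assumes `im.to_pil()` returns a single-channel image, which is consistent with the one-input-channel `nn.Conv2d(1, 10, kernel_size=5)`. The next cell is a pure-Python sketch of that behavior, not torchvision's implementation; `to_tensor_like` and `shape` are hypothetical helpers." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Pure-Python sketch of transforms.ToTensor on one 8-bit grayscale image,\n", "# followed by torch.stack-style batching. Illustrative only, not torchvision code.\n", "\n", "\n", "def to_tensor_like(pixels):\n", "    # pixels: H x W nested list of ints in [0, 255]\n", "    # returns: 1 x H x W nested list of floats in [0.0, 1.0] (channel dimension first)\n", "    return [[[p / 255.0 for p in row] for row in pixels]]\n", "\n", "\n", "def shape(nested):\n", "    # Shape of a rectangular nested list, e.g. (N, C, H, W)\n", "    dims = []\n", "    while isinstance(nested, list):\n", "        dims.append(len(nested))\n", "        nested = nested[0]\n", "    return tuple(dims)\n", "\n", "\n", "digit = [[0] * 28 for _ in range(28)]  # a blank 28x28 image\n", "digit[14][14] = 255  # one white pixel in the middle\n", "# Stacking N per-image tensors adds a leading batch dimension\n", "batch = [to_tensor_like(digit) for _ in range(5)]\n", "print(shape(batch))  # prints: (5, 1, 28, 28)\n", "print(batch[0][0][14][14])  # prints: 1.0"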
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy for class 0 is: 99.7 %\n", "Accuracy for class 1 is: 99.8 %\n", "Accuracy for class 2 is: 99.3 %\n", "Accuracy for class 3 is: 99.2 %\n", "Accuracy for class 4 is: 99.5 %\n", "Accuracy for class 5 is: 99.2 %\n", "Accuracy for class 6 is: 98.6 %\n", "Accuracy for class 7 is: 98.4 %\n", "Accuracy for class 8 is: 99.2 %\n", "Accuracy for class 9 is: 97.7 %\n" ] } ], "source": [ "learning_rate = 1e-3\n", "batch_size = 128\n", "num_epochs = 10\n", "\n", "train_loader = DataLoader(\n", "    dataset=InputData(preprocess(x_train), y_train),\n", "    batch_size=batch_size,\n", "    shuffle=True\n", ")\n", "test_loader = DataLoader(\n", "    dataset=InputData(preprocess(x_test), y_test),\n", "    batch_size=batch_size,\n", "    shuffle=False\n", ")\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", "loss_func = nn.CrossEntropyLoss()\n", "\n", "# Train the model\n", "model.train()\n", "for epoch in range(num_epochs):\n", "    for x, y in train_loader:\n", "        x, y = x.to(device), y.to(device)\n", "        loss = loss_func(model(x), y)\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "# Evaluate the per-class accuracy on the test set (no gradients are needed)\n", "correct_pred = {name: 0 for name in class_names}\n", "total_pred = {name: 0 for name in class_names}\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    for x, y in test_loader:\n", "        images, labels = x.to(device), y.to(device)\n", "        outputs = model(images)\n", "        _, predictions = torch.max(outputs, 1)\n", "        for label, prediction in zip(labels, predictions):\n", "            if label == prediction:\n", "                correct_pred[class_names[label]] += 1\n", "            total_pred[class_names[label]] += 1\n", "\n", "for name, correct_count in correct_pred.items():\n", "    accuracy = 100 * float(correct_count) / total_pred[name]\n", "    print(\"Accuracy for class {} is: {:.1f} %\".format(name, accuracy))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize 
`CounterfactualExplainer`, we set the following parameters:\n", " \n", " - `model`: The ML model to explain, e.g., `torch.nn.Module` or `tf.keras.Model`.\n", " - `preprocess_function`: The preprocessing function that converts the raw data (an `Image` instance) into the inputs of `model`.\n", " - Optimization parameters (optional): e.g., `binary_search_steps`, `num_iterations`. Please refer to the docs for more details." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "from omnixai.explainers.vision import CounterfactualExplainer\n", "\n", "explainer = CounterfactualExplainer(\n", "    model=model,\n", "    preprocess_function=preprocess\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can simply call `explainer.explain` to generate counterfactual examples for this classification task. `ipython_plot` plots the generated explanations in IPython. The parameter `index` selects which instance to plot, e.g., `index=0` plots the first instance in `x_test[0:5]`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Binary step: 5 |███████████████████████████████████████-| 99.9% " ] }, { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explanations = explainer.explain(x_test[0:5])\n", "explanations.ipython_plot(index=4)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }