{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Counterfactual explanation on MNIST (TensorFlow)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of `CounterfactualExplainer` on MNIST with a TensorFlow model. `CounterfactualExplainer` is an optimization-based method for generating counterfactual examples and supports classification tasks only. If you use this explainer, please cite the paper \"Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR, Sandra Wachter, Brent Mittelstadt, Chris Russell, https://arxiv.org/abs/1711.00399\"." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "\n", "from omnixai.data.image import Image\n", "from omnixai.explainers.vision import CounterfactualExplainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code loads the training and test datasets. We recommend using `Image` to represent a batch of images. `Image` can be constructed from a NumPy array or a Pillow image. In this example, `Image` is constructed from a NumPy array containing a batch of digit images."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the MNIST dataset\n", "img_rows, img_cols = 28, 28\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "\n", "if tf.keras.backend.image_data_format() == 'channels_first':\n", " x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n", " x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n", " input_shape = (1, img_rows, img_cols)\n", "else:\n", " x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", " x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", " input_shape = (img_rows, img_cols, 1)\n", "\n", "# Use `Image` objects to represent the training and test datasets\n", "train_imgs, train_labels = Image(x_train.astype('float32'), batched=True), y_train\n", "test_imgs, test_labels = Image(x_test.astype('float32'), batched=True), y_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The preprocessing function takes an `Image` instance as its input and outputs the processed features that the ML model consumes. In this example, the pixel values are normalized to [0, 1]." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "preprocess_func = lambda x: np.expand_dims(x.to_numpy() / 255, axis=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We train a simple convolutional neural network for this task. The network has two convolutional layers and one dense hidden layer. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.1696 - accuracy: 0.9492 - val_loss: 0.0436 - val_accuracy: 0.9855\n", "Epoch 2/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0478 - accuracy: 0.9856 - val_loss: 0.0352 - val_accuracy: 0.9882\n", "Epoch 3/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0324 - accuracy: 0.9896 - val_loss: 0.0315 - val_accuracy: 0.9892\n", "Epoch 4/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0223 - accuracy: 0.9929 - val_loss: 0.0320 - val_accuracy: 0.9887\n", "Epoch 5/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0179 - accuracy: 0.9940 - val_loss: 0.0314 - val_accuracy: 0.9901\n", "Epoch 6/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0141 - accuracy: 0.9952 - val_loss: 0.0365 - val_accuracy: 0.9888\n", "Epoch 7/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0113 - accuracy: 0.9960 - val_loss: 0.0324 - val_accuracy: 0.9903\n", "Epoch 8/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0109 - accuracy: 0.9965 - val_loss: 0.0297 - val_accuracy: 0.9918\n", "Epoch 9/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0083 - accuracy: 0.9972 - val_loss: 0.0337 - val_accuracy: 0.9918\n", "Epoch 10/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0072 - accuracy: 0.9976 - val_loss: 0.0382 - val_accuracy: 0.9895\n", "Test loss: 0.03824701905250549\n", "Test accuracy: 0.9894999861717224\n" ] } ], "source": [ "batch_size = 128\n", "num_classes = 10\n", "epochs = 10\n", "\n", "# Preprocess the training and test data\n", "x_train = preprocess_func(train_imgs)\n", "x_test = preprocess_func(test_imgs)\n", "y_train = tf.keras.utils.to_categorical(y_train, 
num_classes)\n", "y_test = tf.keras.utils.to_categorical(y_test, num_classes)\n", "\n", "# Model structure\n", "model = tf.keras.models.Sequential()\n", "model.add(tf.keras.layers.Conv2D(\n", " 32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))\n", "model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))\n", "model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n", "model.add(tf.keras.layers.Dropout(0.1))\n", "model.add(tf.keras.layers.Flatten())\n", "model.add(tf.keras.layers.Dense(128, activation='relu'))\n", "model.add(tf.keras.layers.Dropout(0.1))\n", "model.add(tf.keras.layers.Dense(num_classes))\n", "\n", "# Train the model\n", "model.compile(\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " optimizer=tf.keras.optimizers.Adam(),\n", " metrics=['accuracy']\n", ")\n", "model.fit(\n", " x_train, y_train,\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " verbose=1,\n", " validation_data=(x_test, y_test)\n", ")\n", "score = model.evaluate(x_test, y_test, verbose=0)\n", "print('Test loss:', score[0])\n", "print('Test accuracy:', score[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize `CounterfactualExplainer`, we need to set the following parameters:\n", " \n", " - `model`: The ML model to explain, e.g., `torch.nn.Module` or `tf.keras.Model`.\n", " - `preprocess_function`: The preprocessing function that converts the raw data (an `Image` instance) into the inputs of `model`.\n", " - Optimization parameters: e.g., `binary_search_steps` and `num_iterations`. Please refer to the docs for more details." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "explainer = CounterfactualExplainer(\n", " model=model,\n", " preprocess_function=preprocess_func\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can simply call `explainer.explain` to generate counterfactual examples for this classification task. 
`ipython_plot` plots the generated explanations in IPython. Parameter `index` indicates which instance to plot, e.g., `index = 0` means plotting the first instance in `test_imgs[0:5]`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Binary step: 5 |███████████████████████████████████████-| 99.9% " ] }, { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explanations = explainer.explain(test_imgs[0:5])\n", "explanations.ipython_plot(index=4)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }