{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Contrastive explanation on MNIST (TensorFlow)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of `ContrastiveExplainer` on MNIST with a TensorFlow model. `ContrastiveExplainer` is an optimization-based method for generating explanations (pertinent negatives and pertinent positives), and it supports classification tasks only. If you use this explainer, please cite the original work: https://arxiv.org/abs/1802.07623." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt\n", "\n", "from omnixai.data.image import Image\n", "from omnixai.explainers.vision import ContrastiveExplainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following code loads the training and test datasets. We recommend using `Image` to represent a batch of images. `Image` can be constructed from a NumPy array or a Pillow image. In this example, `Image` is constructed from a NumPy array containing a batch of digit images." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the MNIST dataset\n", "img_rows, img_cols = 28, 28\n", "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n", "\n", "if tf.keras.backend.image_data_format() == 'channels_first':\n", " x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n", " x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n", " input_shape = (1, img_rows, img_cols)\n", "else:\n", " x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", " x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", " input_shape = (img_rows, img_cols, 1)\n", "\n", "# Use `Image` objects to represent the training and test datasets\n", "train_imgs, train_labels = Image(x_train.astype('float32'), batched=True), y_train\n", "test_imgs, test_labels = Image(x_test.astype('float32'), batched=True), y_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The preprocessing function takes an `Image` instance as its input and outputs the processed features that the ML model consumes. In this example, the pixel values are normalized to [0, 1]." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "preprocess_func = lambda x: np.expand_dims(x.to_numpy() / 255, axis=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We train a simple convolutional neural network for this task. The network has two convolutional layers and one dense hidden layer. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.1712 - accuracy: 0.9493 - val_loss: 0.0509 - val_accuracy: 0.9837\n", "Epoch 2/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0467 - accuracy: 0.9857 - val_loss: 0.0364 - val_accuracy: 0.9880\n", "Epoch 3/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0331 - accuracy: 0.9896 - val_loss: 0.0323 - val_accuracy: 0.9884\n", "Epoch 4/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0226 - accuracy: 0.9927 - val_loss: 0.0345 - val_accuracy: 0.9890\n", "Epoch 5/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0171 - accuracy: 0.9942 - val_loss: 0.0371 - val_accuracy: 0.9880\n", "Epoch 6/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0150 - accuracy: 0.9949 - val_loss: 0.0297 - val_accuracy: 0.9906\n", "Epoch 7/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0109 - accuracy: 0.9966 - val_loss: 0.0428 - val_accuracy: 0.9887\n", "Epoch 8/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0101 - accuracy: 0.9967 - val_loss: 0.0356 - val_accuracy: 0.9895\n", "Epoch 9/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0086 - accuracy: 0.9969 - val_loss: 0.0393 - val_accuracy: 0.9892\n", "Epoch 10/10\n", "469/469 [==============================] - 2s 5ms/step - loss: 0.0065 - accuracy: 0.9977 - val_loss: 0.0399 - val_accuracy: 0.9898\n", "Test loss: 0.03988948091864586\n", "Test accuracy: 0.989799976348877\n" ] } ], "source": [ "batch_size = 128\n", "num_classes = 10\n", "epochs = 10\n", "\n", "# Preprocess the training and test data\n", "x_train = preprocess_func(train_imgs)\n", "x_test = preprocess_func(test_imgs)\n", "y_train = tf.keras.utils.to_categorical(y_train, 
num_classes)\n", "y_test = tf.keras.utils.to_categorical(y_test, num_classes)\n", "\n", "# Model structure\n", "model = tf.keras.models.Sequential()\n", "model.add(tf.keras.layers.Conv2D(\n", " 32, kernel_size=(3, 3), activation='relu', input_shape=input_shape))\n", "model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))\n", "model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n", "model.add(tf.keras.layers.Dropout(0.1))\n", "model.add(tf.keras.layers.Flatten())\n", "model.add(tf.keras.layers.Dense(128, activation='relu'))\n", "model.add(tf.keras.layers.Dropout(0.1))\n", "model.add(tf.keras.layers.Dense(num_classes))\n", "\n", "# Train the model\n", "model.compile(\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " optimizer=tf.keras.optimizers.Adam(),\n", " metrics=['accuracy']\n", ")\n", "model.fit(\n", " x_train, y_train,\n", " batch_size=batch_size,\n", " epochs=epochs,\n", " verbose=1,\n", " validation_data=(x_test, y_test)\n", ")\n", "score = model.evaluate(x_test, y_test, verbose=0)\n", "print('Test loss:', score[0])\n", "print('Test accuracy:', score[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize `ContrastiveExplainer`, we need to set the following parameters:\n", " \n", " - `model`: The ML model to explain, e.g., `torch.nn.Module` or `tf.keras.Model`.\n", " - `preprocess_function`: The preprocessing function that converts the raw data (an `Image` instance) into the inputs of `model`.\n", " - \"optimization parameters\": e.g., `binary_search_steps`, `num_iterations`. Please refer to the docs for more details." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "explainer = ContrastiveExplainer(\n", " model=model,\n", " preprocess_function=preprocess_func\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can simply call `explainer.explain` to generate explanations for this classification task. 
`ipython_plot` plots the generated explanations in IPython. The `index` parameter selects which instance to plot: `index = 0` plots the first instance in `test_imgs[0:5]`, and `index = 4`, used here, plots the fifth." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Binary step: 5 |----------------------------------------| 0.6% " ] }, { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explanations = explainer.explain(test_imgs[0:5])\n", "explanations.ipython_plot(index=4)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }