{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Counterfactual explanation on Diabetes dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an example of the basic counterfactual explainer `CounterfactualExplainer` for tabular data. It only supports continuous-valued features. If using this explainer, please cite the paper \"Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR, Sandra Wachter, Brent Mittelstadt, Chris Russell, https://arxiv.org/abs/1711.00399\"." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from omnixai.data.tabular import Tabular\n", "from omnixai.explainers.tabular import CounterfactualExplainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset considered here is the Diabetes dataset (https://archive.ics.uci.edu/ml/datasets/diabetes). We convert all the features into continuous-valued features." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def diabetes_data(file_path):\n", " data = pd.read_csv(file_path)\n", " data = data.replace(\n", " to_replace=['Yes', 'No', 'Positive', 'Negative', 'Male', 'Female'],\n", " value=[1, 0, 1, 0, 1, 0]\n", " )\n", " features = [\n", " 'Age', 'Gender', 'Polyuria', 'Polydipsia', 'sudden weight loss',\n", " 'weakness', 'Polyphagia', 'Genital thrush', 'visual blurring',\n", " 'Itching', 'Irritability', 'delayed healing', 'partial paresis',\n", " 'muscle stiffness', 'Alopecia', 'Obesity']\n", "\n", " y = data['class']\n", " data = data.drop(['class'], axis=1)\n", " x_train_un, x_test_un, y_train, y_test = \\\n", " train_test_split(data, y, test_size=0.2, random_state=2, stratify=y)\n", "\n", " sc = StandardScaler()\n", " x_train = sc.fit_transform(x_train_un)\n", " x_test = sc.transform(x_test_un)\n", "\n", " x_train = x_train.astype(np.float32)\n", " y_train = y_train.to_numpy()\n", " x_test = x_test.astype(np.float32)\n", " y_test = y_test.to_numpy()\n", "\n", " return x_train, y_train, x_test, y_test, features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we apply a tensorflow model for this diabetes prediction task. The model is a feedforward network with two hidden layers." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def train_tf_model(x_train, y_train, x_test, y_test):\n", " y_train = tf.keras.utils.to_categorical(y_train, 2)\n", " y_test = tf.keras.utils.to_categorical(y_test, 2)\n", "\n", " model = tf.keras.models.Sequential()\n", " model.add(tf.keras.layers.Input(shape=(16,)))\n", " model.add(tf.keras.layers.Dense(units=128, activation=tf.keras.activations.softplus))\n", " model.add(tf.keras.layers.Dense(units=64, activation=tf.keras.activations.softplus))\n", " model.add(tf.keras.layers.Dense(units=2, activation=None))\n", "\n", " learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n", " initial_learning_rate=0.1,\n", " decay_steps=1,\n", " decay_rate=0.99,\n", " staircase=True\n", " )\n", " optimizer = tf.keras.optimizers.SGD(\n", " learning_rate=learning_rate, \n", " momentum=0.9, \n", " nesterov=True\n", " )\n", " loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n", " model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])\n", " model.fit(x_train, y_train, batch_size=256, epochs=200, verbose=0)\n", " train_loss, train_accuracy = model.evaluate(x_train, y_train, batch_size=51, verbose=0)\n", " test_loss, test_accuracy = model.evaluate(x_test, y_test, batch_size=51, verbose=0)\n", "\n", " print('Train loss: {:.4f}, train accuracy: {:.4f}'.format(train_loss, train_accuracy))\n", " print('Test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_accuracy))\n", " return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then load the dataset and train the tensorflow model defined above. Similar to other tabular explainers, we use `Tabular` to represent a tabular dataset used for initializing the explainer." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train shape: (416, 16)\n", "x_test shape: (104, 16)\n", "Train loss: 0.0631, train accuracy: 0.9856\n", "Test loss: 0.0568, test accuracy: 0.9808\n" ] } ], "source": [ "file_path = '../data/diabetes.csv'\n", "x_train, y_train, x_test, y_test, feature_names = diabetes_data(file_path)\n", "print('x_train shape: {}'.format(x_train.shape))\n", "print('x_test shape: {}'.format(x_test.shape))\n", "\n", "model = train_tf_model(x_train, y_train, x_test, y_test)\n", "# Used for initializing the explainer\n", "tabular_data = Tabular(\n", " x_train,\n", " feature_columns=feature_names,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize a `CounterfactualExplainer`, we need to set:\n", " \n", " - `training_data`: The data used to extract information such as medians of continuous-valued features. `training_data` can be the dataset used to train the machine learning model. If the training dataset is large, `training_data` can be a subset of it obtained by applying `omnixai.sampler.tabular.Sampler.subsample`.\n", " - `predict_function`: The prediction function corresponding to the model.\n", " - `mode`: The task type, e.g., \"classification\" or \"regression\".\n", " \n", "In this example, the prediction function is the tensorflow model itself, which is callable, so we can set `predict_function=model`."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Binary step: 5 |███████████████████████████████████████-| 99.9% " ] }, { "data": { "image/png": "" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explainer = CounterfactualExplainer(\n", " training_data=tabular_data,\n", " predict_function=model\n", ")\n", "explanations = explainer.explain(x_test[:1])\n", "explanations.ipython_plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }