{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Counterfactual explanation on the Diabetes dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates the basic counterfactual explainer `CounterfactualExplainer` for tabular data, which only supports continuous-valued features. If you use this explainer, please cite the paper \"Counterfactual Explanations without Opening the Black Box: Automated Decisions and the GDPR\" by Sandra Wachter, Brent Mittelstadt, and Chris Russell (https://arxiv.org/abs/1711.00399)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for Sphinx docs only. Please delete this cell when running in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "from omnixai.data.tabular import Tabular\n", "from omnixai.explainers.tabular import CounterfactualExplainer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset considered here is the Diabetes dataset (https://archive.ics.uci.edu/ml/datasets/diabetes). All features are converted into continuous-valued features: the categorical values (`Yes`/`No`, `Positive`/`Negative`, `Male`/`Female`) are mapped to 1/0, and every column is then standardized."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "def diabetes_data(file_path):\n", "    data = pd.read_csv(file_path)\n", "    # Encode the binary categorical values as 1/0\n", "    data = data.replace(\n", "        to_replace=['Yes', 'No', 'Positive', 'Negative', 'Male', 'Female'],\n", "        value=[1, 0, 1, 0, 1, 0]\n", "    )\n", "    features = [\n", "        'Age', 'Gender', 'Polyuria', 'Polydipsia', 'sudden weight loss',\n", "        'weakness', 'Polyphagia', 'Genital thrush', 'visual blurring',\n", "        'Itching', 'Irritability', 'delayed healing', 'partial paresis',\n", "        'muscle stiffness', 'Alopecia', 'Obesity']\n", "\n", "    y = data['class']\n", "    data = data.drop(['class'], axis=1)\n", "    x_train_un, x_test_un, y_train, y_test = \\\n", "        train_test_split(data, y, test_size=0.2, random_state=2, stratify=y)\n", "\n", "    # Fit the scaler on the training split only, then apply it to the test split\n", "    sc = StandardScaler()\n", "    x_train = sc.fit_transform(x_train_un)\n", "    x_test = sc.transform(x_test_un)\n", "\n", "    x_train = x_train.astype(np.float32)\n", "    y_train = y_train.to_numpy()\n", "    x_test = x_test.astype(np.float32)\n", "    y_test = y_test.to_numpy()\n", "\n", "    return x_train, y_train, x_test, y_test, features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we apply a TensorFlow model to this diabetes prediction task. The model is a feedforward network with two hidden layers."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def train_tf_model(x_train, y_train, x_test, y_test):\n", "    # One-hot encode the binary labels\n", "    y_train = tf.keras.utils.to_categorical(y_train, 2)\n", "    y_test = tf.keras.utils.to_categorical(y_test, 2)\n", "\n", "    model = tf.keras.models.Sequential()\n", "    model.add(tf.keras.layers.Input(shape=(16,)))\n", "    model.add(tf.keras.layers.Dense(units=128, activation=tf.keras.activations.softplus))\n", "    model.add(tf.keras.layers.Dense(units=64, activation=tf.keras.activations.softplus))\n", "    # The output layer returns logits, so the loss below uses from_logits=True\n", "    model.add(tf.keras.layers.Dense(units=2, activation=None))\n", "\n", "    learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n", "        initial_learning_rate=0.1,\n", "        decay_steps=1,\n", "        decay_rate=0.99,\n", "        staircase=True\n", "    )\n", "    optimizer = tf.keras.optimizers.SGD(\n", "        learning_rate=learning_rate,\n", "        momentum=0.9,\n", "        nesterov=True\n", "    )\n", "    loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n", "    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])\n", "    model.fit(x_train, y_train, batch_size=256, epochs=200, verbose=0)\n", "    train_loss, train_accuracy = model.evaluate(x_train, y_train, batch_size=51, verbose=0)\n", "    test_loss, test_accuracy = model.evaluate(x_test, y_test, batch_size=51, verbose=0)\n", "\n", "    print('Train loss: {:.4f}, train accuracy: {:.4f}'.format(train_loss, train_accuracy))\n", "    print('Test loss: {:.4f}, test accuracy: {:.4f}'.format(test_loss, test_accuracy))\n", "    return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then load the dataset and train the TensorFlow model defined above. As with other tabular explainers, we use `Tabular` to represent the tabular dataset used for initializing the explainer."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_train shape: (416, 16)\n", "x_test shape: (104, 16)\n", "Train loss: 0.0631, train accuracy: 0.9856\n", "Test loss: 0.0568, test accuracy: 0.9808\n" ] } ], "source": [ "file_path = '../data/diabetes.csv'\n", "x_train, y_train, x_test, y_test, feature_names = diabetes_data(file_path)\n", "print('x_train shape: {}'.format(x_train.shape))\n", "print('x_test shape: {}'.format(x_test.shape))\n", "\n", "model = train_tf_model(x_train, y_train, x_test, y_test)\n", "# Used for initializing the explainer\n", "tabular_data = Tabular(\n", "    x_train,\n", "    feature_columns=feature_names,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize a `CounterfactualExplainer`, we need to set:\n", " \n", " - `training_data`: The data used to extract information such as the medians of continuous-valued features. `training_data` can be the dataset used to train the machine learning model. If that dataset is large, `training_data` can be a subset of it obtained by applying `omnixai.sampler.tabular.Sampler.subsample`.\n", " - `predict_function`: The prediction function corresponding to the model.\n", " - `mode`: The task type, e.g., \"classification\" or \"regression\".\n", " \n", "In this example, the prediction function is a TensorFlow model, which is callable, so we can set `predict_function=model`."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Binary step: 5 |███████████████████████████████████████-| 99.9% " ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAArwAAACMCAYAAACeX3rHAAAgAElEQVR4Xu1dB1RUV7vd2LsxGntLTOyaWFHRxN5QsVdQo9hb7L2DYBcbgr0i9m4Uu8Yudk1i19hjT6Ji4a1z/sU8GHWYC+Mw5949a731frnt+/bebPac+90bp9IVuoaDHyKgGAKJEsSTFYe9fa9Y5SzX6AhQu0ZXgLr9U7vqcsfKAScGXspARQRovCqyxpoFAtQudaAqAtSuqsyxboEAAy91oCQCNF4laWPRDLzUgMII0HcVJo+lM/BSA2oiQONVkzdWzRVeakBdBOi76nLHyrnCSw0oigCNV1HiWDZHGqgBZRGg7ypLHQvnSAM1oCoCNF5VmWPd1C41oCoC1K6qzLFugQBneKkDJRGg8SpJG4vmDC81oDAC9F2FyWPpDLzUgJoI0HjV5I1Vc4aXGlAXAfquutyxcq7wUgOKIkDjVZQ4ls0ZXmpAWQTou8pSx8I50kANqIoAjVdV5lg3tUsNqIoAtasqc6xbIMAZXupASQRovErSxqI5w0sNKIwAfVdh8lg6Ay81oCYCNF41eWPVnOGlBtRFgL6rLnesnCu81ICiCNB4FSWOZXOGlxpQFgH6rrLUsXCONFADqiJA41WVOdZN7VIDqiJA7arKHOsWCHCGlzpQEgEar5K0sWjO8FIDCiNA31WYPJbOwEsNqIkAjVdN3lg1Z3ipAXURoO+qyx0r5wovNaAoAjReRYlj2ZzhpQaURYC+qyx1LJwjDdSAqgjQeFVljnVTu9SAqghQu6oyx7oFAg49w1uzmjOG9HfHuMnBWLfxABkjAiYE7Gm8KZInxfaN43Dk2EX07D/T5iwUyJcT48d0wNSZa/BryDGbn58ndCwEVNGu8F7hwc1aeeHeg8cIWjAEh45cwPgpwTEG1Gt4G+TMngEenr4IDw+P8Xl4YNwgoFW7WrxTy77m3ffo0gBNGpRHuy4Tcf7i9bgBh1d1eARiFHhLFs8L92ZVkD9PdsSPHx+37zzE0uCd2Lr9qE0bZuC1KZy6OplW4xXNb13ni9SpkkscHj95IQPsNP+1ePrsH4vYxMaIrQE9e7b0GNinORYs3oYjxy9acwj3URgBVbQbOfD+dfshfL3a4+y5K1i0LCTG6LdvUws5c2TEoOFzYnwOHhh3CGjVrhbv1LIvA2/caUDlK2sOvFUqFceIQS3x/n04Dh+7iGfP/kG+PDkwd+EW7Np70qZYVK9SAsMGtuQKr01R1cfJtBpvROBNmDABNm09hNzfZsMPhXNJzQ4ZOS9OA68+GGEX1iKginYH9W2OWjVKyxXeG7fuW9se99MxAlq1qyXERuwr7iL0HuivCUWu8GqCy7A7awq88ePHw4YVXvjiixTo3mc6Tpz886PAVSxfBG1b1kSGDGlw6fJfGD9lBa5eu4MfXQrDd3Q7TPRbgdquZZAta3qcO38NI7wX4MnT/62ydWhbC261XPD06T/Ye+A0WrWoZgq8YnWuV/dGcC6eD6/DwrBpy2HMWbhF3hpbOHsArl+/i2Ohf6Bbx3qYEbgeGzYfNCyxem9cq/FGBN6wsDdwazwU8eI5YdPqMYgfLx6qufWX/7+1R3XUrFYSab5IibPnr2GCXzBu/fUQkU175JhF2LDSC/fuP0Zjj1ES5rataqJtqxoYOykIeb7Lhrq1y6JVO19cunIblSsUxaihP2P6rHVYtmLnR3UqtC5+L/xnb8DioBBkzfIVenVriHx5cyBhgvg4c+4qfCYE4eHf
T/VOqyH6c1TthuwKxdD+7iheNA8u/nkTDx8+RY2qJWXgffT4eZSxnggvFz5bpUIxZMqUFnv2nZLjDm/fvpO3l5s2qog0X6TA3XuPMW/xrwjZeVzqP0f2DChfrafkunmTSnBzdUGG9F/IayxZvgNrN3B8zVF/EbRq1zzwJkgQH/17NZV/w1OmTIYbN+9jyoxVOHXmislnjx7/HcKnhQ7FeMII74VSG8Kj27SqIcdskiVNLBfcJvitwIsX/4GB11EV41h1aQq8eXJnw/xZ/SAE+Uu/GR/tROwzz7+vDAybth5Gm5Y1EB7+Hk1beaF0yfzyD/vr12+wYMk2fPN1JlSpWEwGAREIXEoXxHjvDrh77xEOHDyHKpWK4YvUKUyBV2wr5Zwfs+ZsRMYMX6KBWzkZMtZvOiiNNEumtEiUKCEuX70t9xF18qNPBLQar3ngFTrZuMobr1+HoU6jIWjeuBK6dqyL02ev4Oq1u6hVoxQePHyKpq1GI2mSxFH+2PuOaocfyxaGe5sxuHr9LhYE9Jdadm0wCJ08a0cbeM11miRxoiiBN3nyJJg6oRsOHDqLVCmTy/AQsusEhnst0CeZBuvKUbXb2r0amjWqKL9giSBSs7qzDBmWAq8IG/OX/CoDjHOJfDLwHjx8HmuXj8LlK7fl34C8ebJj245jcoTIPPCKa2bPlgG//3ETDer+iKxZ0qF5a2+uKDvo74RW7X5shVf45527j/DoyXO0bVkD//73CrUbDjYFXtG6+HKULFkSmQl27A7FsNHzpTa7daon79BduXYXndvVkV+yhnsvZOB1UL04WlmaAq8wtMljO8sHyMSDZN4j2qLCjz/Ino6H/iFXfXt0ro8mDSuga6+pcvW2nls5+TPPLhOQ7svU8g+7EOyY8cvwZZqUcpXt8NEL6DXAH31/aYJ6dcqiW+9pcvXYrVYZ9O/VTF4rZNdxhGwcj+Mn/0SfAf6AkxPWrxiNi7/fkMcKI/0uVxb0HRyA3w6dczScWY+NEdBqvBGBV4w0bN12BIUL5ZJ6Wb1+v7zjIP8QZ0svV3vFF7LePRrLL1Tde0/D73/eihJ4y5YphHFe7TF7/mZs/vUw1gWPxv6DZ9F/SCD69WwSbeA112nEalnECq85VDs2T8C9e4/g3tbHxijydHGBgKNqVzyUJhYSatYbgJevwjBycCuIETZLgTd41W74zVyDTBnTYvWyEdJ7x4xfig2rvOXKW+DcTdiy/ahcsRMf88AbGf+IQCPme/fsPx0X1PCa0SCgVbvRjTSIRSwRaqvX7Y/378Klz167fg8t2njDyclJ3k1LmjQxKrv2wYLA/siRLQNc6w+UdxHEnTORSSrU6MXAS+VahYCmwCueJp89ozf27j+NgcPnyPmu/PlyoG4tF1Pg9RrWBmKkwfwj9g9/Hy4Db+C8TXKFN2mSRNi5ZaLp2DEjPVG+3Peo13QY7j94Im9pTJ3QVQbeU2cuY9n8wR+c98rVO/Dw9JFG+k3OTPix6i98+tcq6tXeSavxRgReMRYjRmDE6q1YHQiYtwmvXoXJB9pevgxD/WbDJDAtmlZGl/Zu8naaWLGK/JYGseoljFicQ3x5E+F46Kj52Lkn9IPAK24JDx3gEWWkwVyn5oG3UIGv0e5nV+TMnhFJkiZC8mRJcPOvBzJ48KM+Ao6q3V/XjcU//75EwxYjJMjiATO56mthpCHCy8Udkz2/TsKFizfk4obQfdcOdZEmTUr8/egZ+gychT8v/xUl8Irfow6etVHOpRDSpE6JxIkTyv8Tq3liVY8fx0NAq3bNA2+6tKnlKm3B/DmRIkUymQHEmEOtBoMQFvZW+uyxE3+gR9/psvm5M/vI0S4xArM2eLQckTH/iEUKcSeZb2lwPL04WkWaAq8woy1rfRHPyQnNWv9vjtE8tEas8I72XYybtx6Y+hV/sIsU/jbKrVvzYyNWxzr1mCJvLbtWL4XB/VpEWeEVsz4zAtaZzvvqdRhE6LW0cuBooLOe2COg
1XgjAm/EDK95BUvmDULWzOlMK7x9ejRGfbdy6PzLFFy+cueD15IJ027asIK8TVu44DeoWV8Y9hs5e9uw3k/o3NMPp05flsH1Z4/qUQJv5BlGUYd54BW3g1OlSo6Bw2bLGeLFcwfK+V0G3tjrxhHO4KjaDVo4BOm/SoMabv0R9uat/KIW3QxvxN26r3NmxNJ5g3HwyHkZbsVHhODmjSvK4LzvwBkMGDY7ik+LsaFBfVvIV/EtXLoNVSsVl78rDLyOoNKP16BVu+aBN0JTfjNW48Chc+jXsylKFMsTJfCKN4KI5yPEcxYbV3pLHVWp3Veu8AqPFq+GFA/NR3zEOEzXTvUYeB1XNg5TmabAK6r2bF1Tfpt6/vw/+Qol8e1MjDVEjDTk/i4r5vn3k3O04nZv4kQJ8d23WeX8ofkfdvPAG3GrWPyR3//bGWmA6dKlNs3witkflzIFIW6jiQchsmRKhwu/X5erAQy8DqMpuxSi1XijC7zi4RmxIhV5hvfe/Sdo/rPXBzO84lxilVaEZGG8Yj5RfMETHzEGIVZ8//jzFk6dvQI31zJIkiSRpsC7eY0PxByvGHEQq7x1XEtzhdcuqrLPRRxVu9071UfTRhXk2JiY4a1TszTECJClFV5xa1mE3oL5v8a3ubLIGV4xA9yutStOnb0sZ9DFKrF4ZaX4HYns0+ILpfhiKUbaDh65IANLlszpGHjtI8MYXUWrds0Db8Qd4BVr9sgH08VD6WIhLfIKrxgpEwsJiRIlQKmS+aW/ioeFG9cvj1+6NpDjLidC/5B3D5zghNkLNnOkIUZsGu8gzYFXQCRMsW6tssicKa2cdxTfyDZvO4JVa/dKBF1KFYBna1f5vkWx/Vjo7/KWb3SBVxwrbiPXcS2D5y/+Q8DcjXJOJ+I/PJEyRVJ06VAXZUoVQKqUyXD3/mNMmrpS3gJh4DWWeLUab3SBV8yLebaqiTq1ykA8RHbu4jVM9Fsptf2pObR5s/oib+7s6Nlvpun9ucK8hw9qJWfLnjx5gcnTV/3vrkbgBtNbGqJb4RVf9MSdksRJEsl5Y/EpXiwPV3h1InFH1a7Q+bCBHnKUTKyaiRU48SCnpcC7ccsh+VCaeLPIjl0nMG7Scnz11Rfy4eNsWb+Ss5biAWafCcvkCFBknxYPJfmO8pTz9Hfu/C3v3Il3/Y70XsiRBgfVulbtmnuneH5h5JDWyJw5nbwDJsbK+vVqGiXwiplw8Y7/ci6F5ZenkT6LpJeKFV/xcLHIBxnSp8GzZ/9i+ard0lf5lgYHFYyDlRWjwOtgPbAcAyKg1XhtDZGYP1y+aKh8PU7tRoOj3GKz9bV4Pn0hoAftRixezAxcL18lxo8xEIhr7RoDZXb5uRBg4P1cyPK8nxWBuDRe8UCbc/G8ciWMf/A/K826PLketBvdm0V0SRybQlxql/ATgdgiwMAbWwR5fJwgEJfGO9e/r3x4YuPWQ5gRsJ5vBYkTBah7UT1ol4FXXf3FpvK41G5s6uaxREAgwMBLHSiJAI1XSdpYtHh7QYJ4Eoewt++JBxFQCgFqVym6WKwZAgy8lISSCNB4laSNRTPwUgMKI0DfVZg8ls4VXmpATQRovGryxqq5wksNqIsAfVdd7lg5RxqoAUURoPEqShzL5kgDNaAsAvRdZalj4ZzhpQZURYDGqypzrJvapQZURYDaVZU51i0Q4AwvdaAkAjReJWlj0ZzhpQYURoC+qzB5LJ2BlxpQEwEar5q8sWrO8FID6iJA31WXO1bOFV5qQFEEaLyKEseyOcNLDSiLAH1XWepYOEcaqAFVEaDxqsoc66Z2qQFVEaB2VWWOdQsEOMNLHSiJAI1XSdpYNGd4qQGFEaDvKkweS2fgpQbURIDGqyZvrJozvNSAugjQd9XljpVzhZcaUBQBGq+ixLFszvBSA8oiQN9VljoWzpEGakBVBGi8qjLHuqldakBVBKhdVZlj3QIBzvBSB0oiQONVkjYWzRleakBhBOi7CpPH0uHk
UrFrOHEgAqohkCB+PFny23fvVSud9RocAWrX4AJQuH1qV2HyWDqcfqrSnYGXQiACRIAIEAEiQASIABHQLQJOT54+Y+DVLb36bSyek5Ns7n045atflvXZGbWrT16N0BW1awSW9duj08O/nzIx6Jdf3XaWKOH/RhrC3nCkQbck67QxalenxBqgLWrXACTruEUGXh2Tq+fWaLx6ZlffvVG7+uZXz91Ru3pmV/+9MfDqn2Nddkjj1SWthmiK2jUEzbpsktrVJa2GaYqB1zBU66tRGq+++DRSN9SukdjWV6/Urr74NFo3DLxGY1wn/dJ4dUKkAdugdg1Iuk5apnZ1QqRB22DgNSjxqrdN41WdQePWT+0al3vVO6d2VWfQ2PUz8Bqbf2W7p/EqS53hC6d2DS8BZQGgdpWljoWL/7QwX0tGHaiIAI1XRdZYs0CA2qUOVEWA2lWVOdYtEGDgpQ6URIDGqyRtLJqBlxpQGAH6rsLksXQGXmpATQRovGryxqq5wksNqIsAfVdd7lg5V3ipAUURoPEqShzL5kgDNaAsAvRdZalj4RxpoAZURYDGqypzrJvapQZURYDaVZU51i0Q4AwvdaAkAjReJWlj0ZzhpQYURoC+qzB5LJ2BlxpQEwEar5q8sWrO8FID6iJA31WXO1bOFV5qQFEEaLyKEseyOcNLDSiLAH1XWepYOEcaqAFVEaDxqsoc66Z2qQFVEaB2VWWOdQsEOMNLHSiJAI1XSdpYNGd4qQGFEaDvKkweS2fgpQbURIDGqyZvrJozvNSAugjQd9XljpXHYIX35MlQ9OjeFffv30OBAgXh7x+IDBkzRsFyduAseHuPjvKzly9fYt26jShdxgUh27fBy2sUXr56idSpUsN37HgUK1YcDx48wKiRw7F7904kSJAQlSpXxtixE5AwYUJ5rn///Rdeo0di69YtiBcvHry8fVCzpit5NCACtjDeoKBl8PXxRljYa9SsWQvjxk9E/PjxP0DzwIH9GDxoAB4/fozChb/H1GkzkDZtWrlfYIA/AgNnISzsDerVq49hw0eazpElc3qTduW+gXNRtVp1nD51CqNGDcfFixeQLFkydOrcFW3btjMgi8Zs2Z7a/ZTXRqfdT+l6z57dGD/OF9euXUXSpEnR+ue26NathzGJNGDXMdGuNZlBQGntfgaEnS3bCAFNIw3v3r1DKedi8B07AZUqVYYItnv37sWSpUEWy3n69CkqlC+Hg4eOIkGCBChcKB82b9mOb775Bvv378PAAf1w4LfDOHbsKK5cuYyGDRtDXKtVyxYyILRp4ynP7+HeHAULFUKvXn2iBAkbYcHTKIRATIw3cntXLl9G/fpu2LR5KzJmzIROndqjWNHi6NS5SxQUhHbL/1QWwStWIU+evBg3zhfnzp7FosVL8dtvBzBs2GCsWbMBiRMnRkuP5qhUqQo6dOyEJ0+eoHatGlLX5h8RtHPlyoWSJZ1x/949VKlSEStWrkbevPkUYoClxhQBe2n3zZs3n/RaS9q1tG3lymAUKlRYalV8AaxevTJmzgxA8eIlYgoHj1MIAa3atTYzWLufQlCxVAdEQFPgDQ09gSGDB2LL1u2ylffv36NggTw4fOQ4UqVK/cn2pk6dgmdPn2LosBH4558XcHYujjNnLsiVsL///huVKv6E02fOf3B8wCx/3Lh5A2PG+OL8+XPo1bMHtm3f6YAwsiR7I6DVeM3rmzbND8+ePcOQIcPkpnPnzuKXHt2wY+eeKLvuCNmO4BXLMXv2PJPm8+X9FkePncSUyRORI2dOtG7dRm47e+YMfunZDTt37sWlS3+if78+WLN2Q7TQtPRogSZNm8HVtVa0+3IH9RGwl3Ytee3IEcM+qV1L28zRb9umFWrXdkPdevXVJ4YdRIuAVu1amxms3S/aArkDEbCAgKbAu2rVCuzbuwdTp800nbJ6tcrw8R2HIkWKfvQyb9++hXPJoli/YTOyZs0m9xGrYn/duoW2nu3hN2USmjVr
gXr1G3xwvKfnz6hevYZc8Z09OwBnz57BPy9e4NKlS3J1WIxCZMqUmQQbEAGtxmsOUc+e3eHsXApNmzaXm169eoU8ub/BjZt3ouy6adNGbPt1K6ZN/3/NixXfadP9sWzZEjni0KzZ/84hVoOLFf0eV67ewKlTJ9Ggfl1kyJABb9+9ReXKVTF48FAkT548yvnDwsJQyrk4Nmz8/98PA9JpqJbtpV1LXjtwYP9PatfStgiixGLH3r170L9/H2zdGmIa8TEUkQZsVqt2rc0M1u5nQMjZsg0R0BR4Fy9aKEOnmHWM+NR1q4U+ffujbNlyHy1r7ZrV2LBhHeYvWGzafvnyJTRt0hDx4sdH1ixZETh7HtKlSxfleDEr5jPGC5u3bJNjEKNHjcD69WuxcNFSOTs8d+5sOQu8PHiVDeHgqVRBQKvxmvfVsUM71KhRE25165k2pf8qDe4/eAwnJyfTz+7evYMa1ati46YtyJw5C4KXB6FPn57YHrIL165dw9w5gVgWtELuL1bGVq9eictXbsh/v3jxHClTppIryX379kLatOng4zM2Sik+Pt74559/4O3towr0rDOWCNhLu6LMT3ntxo0bPqldS9vEOQf074ugoKXyOQvfsePQqFGTWCLCw1VBQKt2rc0M1u6nCk6s0zER0BR4xR/zHSEh8J8VaOqmYsUfMWHCZBQtWuyjHVarWgnDh49EGZeycvvDhw/h6loNCxcuQb58+bF8+TLMmD4NO3ftRaJEieQ+Yni9a5dOcq4xS5as8mdjxnjJB9UGDBgk/y1mfr779mu5mhY5oDgmzKzK1ghoNV4xfjBlyiRZxoSJk3Hw4G8oWqQY3D1amsJpgfx5cPPW3Q9K3bJlsxxfEDORzZu3wMyZM+Tog3hwbeKEcRABIUmSJPDwaIWAAH/s23/wg3OImeFmzRvj6NFQ07YFC+Zh86aNWLos2KR9W+PE8zkeAvbSbnRea0m71uhafOHr1q0z3N09THdKHA9tVmRLBLRq19rMYO1+tuyF5zIeApoC75nTp9GrVw/TnKMYV8ibJxeOHT+FNGnSfIDe0aNH0K9vb+zZe8C0bc3qVdi9e1eUW8Q1a1TF2HETUahQITmr26G9J+YvWITvvsttOm7pksUQcz4TJ02RPxO3ggvkz41Ll68bjzV2HOv/WpX/zBm4d+8uRo7yMn3JEjPiu/fst4jujRs34N6iKfYfOPTBfuJhtJOhJ6LcAYnY6Y8/fpe6jvhdCA4OwpIlixEUtAIpUqQgowZCQGtoMIfGWu1G57WRz2tJu5a2iZW50JMnMHnyVAMxaNxWtWrX2sxg7X7GRZ6d2wIBTYFXzG25lHGG9xhfVKxYSb6lQbwiTDyYI4LAnDkBGD16jKku8UBDhQqVTKtoYoN4JVPHju3kg28iJItVArc6rjjw2yHcv38fnm1/xuw585A7d54o/Yn5SBcXZyxfvkoGY/Eg3OnTpzB37gJb4MBzKIaAVuM1b+/mzZtSd2J2VrylQWgyf7786N2n30e1LI4XAblDh3bw8Ggp58ojf06cOI4unTti2bJgfJMrF4SBp02XVt6hEK/T696ts7yjIcZ/1q9fhzlzAhEUFIwUKVIqhjzLjS0C9tKuJa+N/JCxuXYt6frQwd9Q0rmUfOBYvKVBeHyDBo2ieHxs8eHxjouAVu1amxks7ee4aLAy1RDQFHhFc2IFtmvXTrj9118ylM6YGYAcOXJArOaKnx85ckKOGNy6dRNinOFE6Bn5vsbIH/HHfv68uQgPD0eSpEkwaNBQVK5cRa6ArV27Wo4uRHzEreLrN27Lf+7cuUO+JULMPH7//Q+YNNkP6dOnVw1z1msDBLQa78cuKebLR44cjpcv/5OvE5s8Zap8vZi5lsUf9dDQUPnAWZeu3U0PqYlzfl+4gNSxeFvDsGEjUKJESZNWBw3qL7WaLGlSuLnVQ7/+A+XoQsECefHw4YMoozgVKlaSq7386B8Be2r3U15rSbuWtnXu1AH7
9u2VgVf4epMmTfFLz94cK9O/bGWHMdGutZnhU/sZBFq2aQcENAdeO9TESxCBaBGIifFGe1LuQATsgAC1aweQeYnPggC1+1lg5UnthAADr52A5mVsiwCN17Z48mz2Q4DatR/WvJJtEaB2bYsnz2ZfBBh47Ys3r2YjBGi8NgKSp7E7AtSu3SHnBW2EALVrIyB5mjhBgIE3TmDnRWOLAI03tgjy+LhCgNqNK+R53dgiQO3GFkEeH5cIMPDGJfq8dowRoPHGGDoeGMcIULtxTAAvH2MEqN0YQ8cDHQABBl4HIIElaEeAxqsdMx7hGAhQu47BA6vQjgC1qx0zHuE4CDDwOg4XrEQDAjReDWBxV4dCgNp1KDpYjAYEqF0NYHFXh0OAgdfhKGFB1iBA47UGJe7jiAhQu47ICmuyBgFq1xqUuI+jIsDA66jMsC6LCNB4KRBVEaB2VWWOdVO71IDKCDDwqsyegWun8RqYfMVbp3YVJ9DA5VO7BiZfB60z8OqARCO2QOM1Iuv66Jna1QePRuyC2jUi6/rpmYFXP1waqhMar6Ho1lWz1K6u6DRUM9SuoejWXbMMvLqj1BgN0XiNwbMeu6R29ciqMXqido3Bs167ZODVK7M674vGq3OCddwetatjcnXeGrWrc4J13h4Dr84J1mt7NF69Mqv/vqhd/XOs1w6pXb0ya4y+GHiNwbPuuqTx6o5SwzRE7RqGat01Su3qjlJDNcTAayi69dMsjVc/XBqtE2rXaIzrp19qVz9cGrETpydPn4UbsXH2rDYC8ZycZAPvwylftZk0XvXUrvE410vH1K5emDRmH07h4UwMxqSeXRMBIkAEiAARIAJEwBgIOIW9ecclMmNwrasuEyaIJ/t58/a9rvpiM/pHgNrVP8d67ZDa1SuzxujL6VXYWwZeY3Ctqy4TJ4wv+3n95p2u+mIz+keA2tU/x3rtkNrVK7PG6IuB1xg8665LGq/uKDVMQ9SuYajWXaPUru4oNVRDDLyGols/zdJ49cOl0Tqhdo3GuH76pXb1w6URO2HgNSLrOuiZxqsDEg3aArVrUOJ10Da1qwMSDdwCA6+ByVe5dRqvyuwZu3Zq19j8q9w9tasye6ydgZcaUBIBGq+StLFoANQuZaAqAtSuqsyxboEAAy91oCQCNF4laWPRDLzUgMII0HcVJo+lM/BSA2oiQONVkzdWzRVeakBdBOi76pc61u0AAAuFSURBVHLHyrnCSw0oigCNV1HiWDZHGqgBZRGg7ypLHQvnSAM1oCoCNF5VmWPd1C41oCoC1K6qzLFugQBneKkDJRGg8SpJG4vmDC81oDAC9F2FyWPpDLzUgJoI0HjV5I1Vc4aXGlAXAfquutyxcq7wUgOKIkDjVZQ4ls0ZXmpAWQTou8pSx8I50kANqIoAjVdV5lg3tUsNqIoAtasqc6xbIMAZXupASQRovErSxqI5w0sNKIwAfVdh8lg6Ay81oCYCNF41eWPVnOGlBtRFgL6rLnesnCu81ICiCNB4FSWOZXOGlxpQFgH6rrLUsfCYjDQcP34MHdq1w717d1GocGHMX7AImTJl+iiYM6ZPQ0DALLx+9QoNGjbCGB9fvHnzBmN9fbBo4QK8ffsW+QsUgP+sQGTLlk2eY/GihRg5Yjhev36NOm51MXXadMSPH19uCw8Ph88Yb0yb6oe79x+SQAMjYAvjtaS1CGiFhocNHRIF6ZcvX2JbyA7kzp0HgwcOQEjIdiRMmBBVq1WH39Rp8n/fv3//k9v+++8/dOnUEUeOHsH79+9Rq1ZtjBs/AfHixTMwo8Zp3RG0W67cj5/0WkvaFdrv3asnQrZvQ8JEiTBo8BC4u3sYhzyDdxpT7W7dsgWtWrojZOcufP/9Dx9FUUu2MDgNbD+GCGia4X337h0KFsgHP79pqFqtGkQY2LVrJ1avWffB5f2mTMbevXsQOHsu0qVLZ9r++PFjBAbMQucuXZEqVSp4e43GhfPnsTRoOS79+SdqVK+Knbv3
InPmzPi5VUuUKFkSPX7pKYNyS48WyJQpM4KXB+H23fsxbJmH6QGBmBpvRO+WtGYJnydPnsC5RDGcPnsep0+fwqVLl9CsWXOI341GDevD1bUWOnTshMOHD31y27ixvrhx4wamz5gpde1WpxY8PdujQcOGeqCGPUSDgCNo969btz7ptZa06zV6FG7fvo0ZM/1x584d/FTOBavWrEWRIkXJuwEQiIl2p0yehC2bN+Pff//FzFmzPhp4tWQLA8DMFj8TApoC77FjR9G3d2/s2bdfliNWp3Jmz4qz5y8iderUphKFePPnzY2Dh48ibdq0FksXocGzTRscOxGKiRPG49mzZxg12ksec+bMaXRs3x4HDx+R/96+bRsqV6mCTBm+wv2Hjz4TJDytCgjExHgj9xWd1j6FwYTx4/D06VN4eY/5YBdx5+H69euYOGmyxW2DBg5AlixZ0KVrN7lfj+7dUKRIEbT+uY0K0LPGWCLgCNrVov/Iui5bpjTmzpuPPHnzShSm+k2RAXjsuPGxRIWHq4BATLS7d88elCpdGq41qmPi5MkfDbzWZgsVMGKNjouApsAbFLQMu3buxOw5c00d/VjWBZOmTEHx4iVMPzt37izae3qifIUK2PbrrzIMj/b2hotL2Q+QEKu9J0+ehP+sAHTq2AFlXFzg4dFS7vfq1Stkzpgej58+Nx0nxiCyZMrAwOu4mrJLZTEx3siFWaM180aE9grkz4uQHbuQPXv2D/p0b94MrrVryxVf80/kbVevXEHdunUwYOAgvHj+HGvWrMaateuRPHlyu2DHi8QtAo6gXS36j6zd4kWLYPWatciRM6cEcf26dViyZBFWrloTt6Dy6nZBIDbarVj+J0z28/to4LU2W9ilSV5EtwhoCrzz5s7BqVOn5FxtxKdqlUoYPHgofipf3vQzsRLbpHFDzJk7H/Xq18epUyfRqEF9nDl3IcofdbEyUL1aFWzavBU5cuRA65YeqFWnDho2bGQ6V9LECfHfqzA4OTnJnzHw6laLmhqLjfGKC1mjNfOCVqwIxprVq7A8eOUHte7csQMjhg/D7r37kCBBgijbzbeJWfTBgwbKoPvo778xw38WGjduoql/7qwuAo6gXWv1b67dnr/0kB4+2ssb9+7dQ/OmTfBV+q+wYuVqdQlh5VYjEBvtWgq81mYLqwvljkTgIwhoCrzLlwdh29atmL9wkelUpUqWwLQZM1CiREnTz8Qq8BhvL+zYtdv0MxFsR44aDWfnUvJnjx49Qs3q1eA9xkeOKYhP504dUaJECfzcpq389/Pnz5EjWxY8efbCdB4GXupYIKDVeMWDkmJ2Vnymz/TH/n37otWaOdLlXMrA28cHP/74U5RN4mELz7ZtsHHTFtPDlxE7fGybCMZC/1P8puLBgwdo0awp2rZrhxYt3EmuARBwBO1a47Uf064Y5+nbuxdCQ0ORKXMmlCpVWmp58hQ/AzDHFrVqNzJilgKvtdmCDBCB2CCgKfCePBmKLp06mWZqRfjMmjkjLvz+J7788ktTHeKWrZtbbZw9d8H0MyF2Mfrwww9F5Jxubdea8mG0yA/qiAfd7t69C9+x4+RxwnA7d+yIo8dPMPDGhmUdHhsb4xVwWKO1yLAdOnQQ3bt2lbPmkT9nz55BS3d3LA9eYZprjNj+qW2FCubHpk1bTLeFd4SEICDAn7eFdajTj7XkCNqNTv+WdB25p3Zt26BajRpR7soZhEZDthkb7VoKvNZmC0OCzqZthoCmwCseUvvh+0KYOHEyqlStKt/SsHHjBvy6LQTXr13DzJkz5OuVxKdShfJo4e6ONm09IcJC61YtcebseflUet06tdGxc+cPTPLG9euoUrkiQnbulm9pELfdChQsKF99E/HhCq/NuFf6RLExXtG4Ja2Za1ns37xZE1SpUtV090H87I/ff0eL5s2wZOky5M2XLwqelrbVr+eGatWqy7c5iN+pfn37IEWKFBgxcpTSnLB46xBwBO1a0r8l7UZ0
KHQr5i6nTJokF0DEq/j40T8CsdGueeCN7LOWsoX+UWWH9kJAU+AVRYlv/uKtCrdu3UTevPkwb/4C5Pz6axlqxW3dc+cvynlbscrbzrMtrl69Il8l5jdtmhx7mDtnNrp07mR6t25Eo2L2sWRJZ4g5SfFuU/GuUhEK/AMCkThxYgZeeylCkevExngjWvyU1sy1LF4hVs6lNP64dAVJkyY1IdTKw13qNfL7c5MkSYJHT57B0jZxvu7duuDatWvy3dLitrB4f2+yZMkUQZ9lxgYBR9CuqP9T+rek3SNHDkM8xCbexCMeQh47foJcnODHGAjERrvmgdfcZz+VLYyBLLu0BwKaA689iuI1iEB0CMTGeKM7N7cTgc+JALX7OdHluT8nAtTu50SX5/7cCDDwfm6Eef7PggCN97PAypPaAQFq1w4g8xKfBQFq97PAypPaCQEGXjsBzcvYFgEar23x5NnshwC1az+seSXbIkDt2hZPns2+CDDw2hdvXs1GCNB4bQQkT2N3BKhdu0POC9oIAWrXRkDyNHGCAANvnMDOi8YWARpvbBHk8XGFALUbV8jzurFFgNqNLYI8Pi4RYOCNS/R57RgjQOONMXQ8MI4RoHbjmABePsYIULsxho4HOgACDLwOQAJL0I4AjVc7ZjzCMRCgdh2DB1ahHQFqVztmPMJxEGDgdRwuWIkGBGi8GsDirg6FALXrUHSwGA0IULsawOKuDocAA6/DUcKCrEGAxmsNStzHERGgdh2RFdZkDQLUrjUocR9HRYCB11GZYV0WEaDxUiCqIkDtqsoc66Z2qQGVEWDgVZk9A9dO4zUw+Yq3Tu0qTqCBy6d2DUy+Dlpn4NUBiUZsgcZrRNb10TO1qw8ejdgFtWtE1vXTMwOvfrg0VCc0XkPRratmqV1d0WmoZqhdQ9Gtu2YZeHVHqTEaovEag2c9dknt6pFVY/RE7RqDZ712ycCrV2Z13heNV+cE67g9alfH5Oq8NWpX5wTrvD0GXp0TrNf2aLx6ZVb/fVG7+udYrx1Su3pl1hh9MfAag2fddUnj1R2lhmmI2jUM1bprlNrVHaWGaoiB11B066dZGq9+uDRaJ9Su0RjXT7/Urn64NGInDLxGZF0HPdN4dUCiQVugdg1KvA7apnZ1QKKBW/g/YPMTWjsthOgAAAAASUVORK5CYII=" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "explainer = CounterfactualExplainer(\n", " training_data=tabular_data,\n", " predict_function=model\n", ")\n", "explanations = explainer.explain(x_test[:1])\n", "explanations.ipython_plot()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }