{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### SHAP for income prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SHAP is a model-agnostic explainer for tabular data. If using this explainer, please cite the original work: https://github.com/slundberg/shap." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "import sklearn\n", "import xgboost\n", "import numpy as np\n", "import pandas as pd\n", "from omnixai.data.tabular import Tabular\n", "from omnixai.preprocessing.tabular import TabularTransform\n", "from omnixai.explainers.tabular import ShapTabular" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset used in this example is for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). We recommend using `Tabular` to represent a tabular dataset, which can be constructed from a pandas dataframe or a numpy array. To create a `Tabular` instance from a numpy array, one needs to specify the data, the feature names, the categorical feature names (if any) and the target/label column name (if any)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Age Workclass fnlwgt Education Education-Num \\\n", "0 39 State-gov 77516 Bachelors 13 \n", "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", "2 38 Private 215646 HS-grad 9 \n", "3 53 Private 234721 11th 7 \n", "4 28 Private 338409 Bachelors 13 \n", "... .. ... ... ... ... 
\n", "32556 27 Private 257302 Assoc-acdm 12 \n", "32557 40 Private 154374 HS-grad 9 \n", "32558 58 Private 151910 HS-grad 9 \n", "32559 22 Private 201490 HS-grad 9 \n", "32560 52 Self-emp-inc 287927 HS-grad 9 \n", "\n", " Marital Status Occupation Relationship Race Sex \\\n", "0 Never-married Adm-clerical Not-in-family White Male \n", "1 Married-civ-spouse Exec-managerial Husband White Male \n", "2 Divorced Handlers-cleaners Not-in-family White Male \n", "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", "4 Married-civ-spouse Prof-specialty Wife Black Female \n", "... ... ... ... ... ... \n", "32556 Married-civ-spouse Tech-support Wife White Female \n", "32557 Married-civ-spouse Machine-op-inspct Husband White Male \n", "32558 Widowed Adm-clerical Unmarried White Female \n", "32559 Never-married Adm-clerical Own-child White Male \n", "32560 Married-civ-spouse Exec-managerial Wife White Female \n", "\n", " Capital Gain Capital Loss Hours per week Country label \n", "0 2174 0 40 United-States <=50K \n", "1 0 0 13 United-States <=50K \n", "2 0 0 40 United-States <=50K \n", "3 0 0 40 United-States <=50K \n", "4 0 0 40 Cuba <=50K \n", "... ... ... ... ... ... 
\n", "32556 0 0 38 United-States <=50K \n", "32557 0 0 40 United-States >50K \n", "32558 0 0 40 United-States <=50K \n", "32559 0 0 20 United-States <=50K \n", "32560 15024 0 40 United-States >50K \n", "\n", "[32561 rows x 15 columns]\n" ] } ], "source": [ "feature_names = [\n", " \"Age\", \"Workclass\", \"fnlwgt\", \"Education\",\n", " \"Education-Num\", \"Marital Status\", \"Occupation\",\n", " \"Relationship\", \"Race\", \"Sex\", \"Capital Gain\",\n", " \"Capital Loss\", \"Hours per week\", \"Country\", \"label\"\n", "]\n", "data = np.genfromtxt(os.path.join('../data', 'adult.data'), delimiter=', ', dtype=str)\n", "tabular_data = Tabular(\n", " data,\n", " feature_columns=feature_names,\n", " categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],\n", " target_column='label'\n", ")\n", "print(tabular_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TabularTransform` is a special transform designed for tabular data. By default, it converts categorical features into one-hot encodings and leaves continuous-valued features unchanged (to normalize continuous-valued features, set the parameter `cont_transform` in `TabularTransform` to `Standard` or `MinMax`). The `transform` method of `TabularTransform` transforms a `Tabular` instance into a numpy array. If the `Tabular` instance has a target/label column, the last column of the transformed numpy array will be the target/label. \n", "\n", "For transformations that the library does not support, one can convert the `Tabular` instance into a pandas dataframe by calling `Tabular.to_pd()` and apply them to the dataframe.\n", "\n", "After data preprocessing, we can train an XGBoost classifier for this task (one may try other classifiers). 
" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training data shape: (26048, 108)\n", "Test data shape: (6513, 108)\n", "Test accuracy: 0.8668816213726394\n" ] } ], "source": [ "np.random.seed(1)\n", "transformer = TabularTransform().fit(tabular_data)\n", "class_names = transformer.class_names\n", "x = transformer.transform(tabular_data)\n", "train, test, labels_train, labels_test = \\\n", " sklearn.model_selection.train_test_split(x[:, :-1], x[:, -1], train_size=0.80)\n", "print('Training data shape: {}'.format(train.shape))\n", "print('Test data shape: {}'.format(test.shape))\n", "\n", "gbtree = xgboost.XGBClassifier(n_estimators=300, max_depth=5)\n", "gbtree.fit(train, labels_train)\n", "print('Test accuracy: {}'.format(\n", " sklearn.metrics.accuracy_score(labels_test, gbtree.predict(test))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The prediction function takes a `Tabular` instance as its input and outputs the class probabilities for classification tasks or the estimated values for regression tasks. In this example, we simply call `transformer.transform` to preprocess the data and then pass the result to the prediction function of `gbtree`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "predict_function = lambda z: gbtree.predict_proba(transformer.transform(z))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize a SHAP explainer, we need to set:\n", " \n", " - `training_data`: The data used to initialize a SHAP explainer. `training_data` can be the training dataset used to train the machine learning model. If the training dataset is too large, `training_data` can be a subset of it obtained by applying `omnixai.sampler.tabular.Sampler.subsample`.\n", " - `predict_function`: The prediction function corresponding to the model.\n", " - `mode`: The task type, e.g., \"classification\" or \"regression\"."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "explainer = ShapTabular(\n", " training_data=tabular_data,\n", " predict_function=predict_function,\n", " mode=\"classification\",\n", " nsamples=100\n", ")\n", "# Apply an inverse transform, i.e., converting the numpy array back to `Tabular`\n", "test_instances = transformer.invert(test)\n", "test_x = test_instances[1653:1655]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SHAP generates local explanations: calling `explainer.explain` on the test instances produces one explanation per instance. `ipython_plot` plots the generated explanations in IPython. The parameter `index` indicates which instance in `test_x` to plot, e.g., `index = 0` means plotting the first instance in `test_x`." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "16e23bc7be4445b5ade249aa6647a32d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/2 [00:00