{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### TabularExplainer for income prediction (classification)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class `TabularExplainer` is designed for tabular data, acting as a factory of the supported tabular explainers such as LIME, SHAP and MACE. `TabularExplainer` provides a unified easy-to-use interface for all the supported explainers. In practice, we recommend applying `TabularExplainer` to generate explanations instead of using a specific explainer in the package `omnixai.explainers.tabular`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import os\n", "import sklearn\n", "import sklearn.datasets\n", "import sklearn.ensemble\n", "import xgboost\n", "import numpy as np\n", "import pandas as pd\n", "\n", "from omnixai.data.tabular import Tabular\n", "from omnixai.preprocessing.tabular import TabularTransform\n", "from omnixai.explainers.tabular import TabularExplainer\n", "from omnixai.visualization.dashboard import Dashboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset used in this example is for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). We recommend using `Tabular` to represent a tabular dataset that can be constructed from a pandas dataframe or a numpy array. To create a `Tabular` instance given a pandas dataframe, one needs to specify the dataframe, the categorical feature names (if exists) and the target/label column name (if exists). The package `omnixai.preprocessing` provides several useful preprocessing functions for a `Tabular` data. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Age Workclass fnlwgt Education Education-Num \\\n", "0 39 State-gov 77516 Bachelors 13 \n", "1 50 Self-emp-not-inc 83311 Bachelors 13 \n", "2 38 Private 215646 HS-grad 9 \n", "3 53 Private 234721 11th 7 \n", "4 28 Private 338409 Bachelors 13 \n", "... .. ... ... ... ... \n", "32556 27 Private 257302 Assoc-acdm 12 \n", "32557 40 Private 154374 HS-grad 9 \n", "32558 58 Private 151910 HS-grad 9 \n", "32559 22 Private 201490 HS-grad 9 \n", "32560 52 Self-emp-inc 287927 HS-grad 9 \n", "\n", " Marital Status Occupation Relationship Race Sex \\\n", "0 Never-married Adm-clerical Not-in-family White Male \n", "1 Married-civ-spouse Exec-managerial Husband White Male \n", "2 Divorced Handlers-cleaners Not-in-family White Male \n", "3 Married-civ-spouse Handlers-cleaners Husband Black Male \n", "4 Married-civ-spouse Prof-specialty Wife Black Female \n", "... ... ... ... ... ... \n", "32556 Married-civ-spouse Tech-support Wife White Female \n", "32557 Married-civ-spouse Machine-op-inspct Husband White Male \n", "32558 Widowed Adm-clerical Unmarried White Female \n", "32559 Never-married Adm-clerical Own-child White Male \n", "32560 Married-civ-spouse Exec-managerial Wife White Female \n", "\n", " Capital Gain Capital Loss Hours per week Country label \n", "0 2174 0 40 United-States <=50K \n", "1 0 0 13 United-States <=50K \n", "2 0 0 40 United-States <=50K \n", "3 0 0 40 United-States <=50K \n", "4 0 0 40 Cuba <=50K \n", "... ... ... ... ... ... 
\n", "32556 0 0 38 United-States <=50K \n", "32557 0 0 40 United-States >50K \n", "32558 0 0 40 United-States <=50K \n", "32559 0 0 20 United-States <=50K \n", "32560 15024 0 40 United-States >50K \n", "\n", "[32561 rows x 15 columns]\n" ] } ], "source": [ "# Load the dataset\n", "feature_names = [\n", " \"Age\", \"Workclass\", \"fnlwgt\", \"Education\",\n", " \"Education-Num\", \"Marital Status\", \"Occupation\",\n", " \"Relationship\", \"Race\", \"Sex\", \"Capital Gain\",\n", " \"Capital Loss\", \"Hours per week\", \"Country\", \"label\"\n", "]\n", "df = pd.DataFrame(\n", " np.genfromtxt(os.path.join('data', 'adult.data'), delimiter=', ', dtype=str),\n", " columns=feature_names\n", ")\n", "tabular_data = Tabular(\n", " data=df,\n", " categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],\n", " target_column='label'\n", ")\n", "print(tabular_data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TabularTransform` is a special transform designed for tabular data. By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features (if one wants to normalize continuous-valued features, set the parameter `cont_transform` in `TabularTransform` to `Standard` or `MinMax`). The `transform` method of `TabularTransform` will transform a `Tabular` instance into a numpy array. If the `Tabular` instance has a target/label column, the last column of the transformed numpy array will be the target/label. \n", "\n", "If some other transformations that are not supported in the library are necessary, one can simply convert the `Tabular` instance into a pandas dataframe by calling `Tabular.to_pd()` and try different transformations with it.\n", "\n", "After data preprocessing, we can train a XGBoost classifier for this task (one may try other classifiers). " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training data shape: (26048, 108)\n", "Test data shape: (6513, 108)\n", "Test accuracy: 0.8668816213726394\n" ] } ], "source": [ "# Train an XGBoost model\n", "np.random.seed(1)\n", "transformer = TabularTransform().fit(tabular_data)\n", "class_names = transformer.class_names\n", "x = transformer.transform(tabular_data)\n", "train, test, train_labels, test_labels = \\\n", " sklearn.model_selection.train_test_split(x[:, :-1], x[:, -1], train_size=0.80)\n", "print('Training data shape: {}'.format(train.shape))\n", "print('Test data shape: {}'.format(test.shape))\n", "\n", "gbtree = xgboost.XGBClassifier(n_estimators=300, max_depth=5)\n", "gbtree.fit(train, train_labels)\n", "print('Test accuracy: {}'.format(\n", " sklearn.metrics.accuracy_score(test_labels, gbtree.predict(test))))\n", "\n", "# Convert the transformed data back to Tabular instances\n", "train_data = transformer.invert(train)\n", "test_data = transformer.invert(test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize `TabularExplainer`, we need to set the following parameters:\n", "\n", " - `explainers`: The names of the explainers to apply, e.g., [\"lime\", \"shap\", \"mace\", \"pdp\"].\n", " - `data`: The data used to initialize explainers. ``data`` is the training dataset for training the machine learning model. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "The preprocessing function takes a `Tabular` instance as its input and outputs the processed features that the ML model consumes. In this example, we simply call `transformer.transform`. If one applies custom transforms to pandas dataframes, the preprocessing function typically takes the form `lambda z: some_transform(z.to_pd())`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "preprocess = lambda z: transformer.transform(z)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to create a `TabularExplainer`. The `params` argument of `TabularExplainer` allows us to set parameters for each explainer applied here. For example, \"kernel_width\" for LIME is set to 3.\n", "\n", "In this example, LIME, SHAP and MACE generate local explanations while PDP (partial dependence plot) generates global explanations. `explainers.explain` returns the local explanations generated by the three methods given the test instances, and `explainers.explain_global` returns the global explanations generated by PDP. `TabularExplainer` hides all the details of the underlying explainers, so we can simply call these two methods to generate explanations." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9e03f8867c644333908987da5848c767", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/5 [00:00