L2X (learning to explain) for income prediction
This is an example of the L2X explainer. Unlike LIME, SHAP, or MACE, L2X trains a separate model to generate explanations. The advantage of L2X is that it produces explanations quickly once the explanation model is trained. The disadvantage is that the quality of the explanations depends heavily on the trained explanation model, which is affected by multiple factors, e.g., the network architecture of the explanation model and the training hyperparameters.
For tabular data, we implement the default explanation model in omnixai.explainers.tabular.agnostic.L2X.l2x. One may implement other models by following the same interface (please refer to the docs for more details). If using this explainer, please cite the original work: “Learning to Explain: An Information-Theoretic Perspective on Model Interpretation, Jianbo Chen, Le Song, Martin J. Wainwright, Michael I. Jordan, https://arxiv.org/abs/1802.07814”.
[1]:
# This default renderer is used for sphinx docs only. Please delete this cell in IPython.
import plotly.io as pio
pio.renderers.default = "png"
[2]:
import os
import sklearn
import sklearn.metrics
import sklearn.model_selection
import xgboost
import numpy as np
import pandas as pd
from omnixai.data.tabular import Tabular
from omnixai.preprocessing.tabular import TabularTransform
The dataset used in this example is for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). We recommend using Tabular to represent a tabular dataset, which can be constructed from a pandas dataframe or a numpy array. To create a Tabular instance given a numpy array, one needs to specify the data, the feature names, the categorical feature names (if any) and the target/label column name (if any).
[3]:
feature_names = [
    "Age", "Workclass", "fnlwgt", "Education",
    "Education-Num", "Marital Status", "Occupation",
    "Relationship", "Race", "Sex", "Capital Gain",
    "Capital Loss", "Hours per week", "Country", "label"
]
data = np.genfromtxt(os.path.join('../data', 'adult.data'), delimiter=', ', dtype=str)
tabular_data = Tabular(
    data,
    feature_columns=feature_names,
    categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],
    target_column='label'
)
TabularTransform is a special transform designed for tabular data. By default, it converts categorical features into one-hot encoding and keeps continuous-valued features unchanged (if one wants to normalize continuous-valued features, set the parameter cont_transform in TabularTransform to Standard or MinMax). The transform method of TabularTransform transforms a Tabular instance into a numpy array. If the Tabular instance has a target/label column, the last column of the transformed numpy array will be the target/label.
If some transformations that are not supported in the library are necessary, one can simply convert the Tabular instance into a pandas dataframe by calling Tabular.to_pd() and try different transformations with it, as in the sketch below.
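A minimal sketch of that detour follows. Tabular.to_pd() and the Tabular constructor are from the library; the log transform of "Capital Gain" and the name custom_data are purely illustrative choices, not part of this example's pipeline.
# Convert the Tabular instance into a pandas dataframe for custom preprocessing.
df = tabular_data.to_pd()
# Illustrative custom transformation: log-scale a heavy-tailed numeric column.
df["Capital Gain"] = np.log1p(df["Capital Gain"].astype(float))
# Wrap the transformed dataframe back into a Tabular instance
# (the dataframe keeps the column names, so feature_columns is not needed).
custom_data = Tabular(
    df,
    categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],
    target_column='label'
)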
After data preprocessing, we can train an XGBoost classifier for this task (one may try other classifiers).
[4]:
np.random.seed(1)
transformer = TabularTransform().fit(tabular_data)
class_names = transformer.class_names
x = transformer.transform(tabular_data)
train, test, labels_train, labels_test = \
    sklearn.model_selection.train_test_split(x[:, :-1], x[:, -1], train_size=0.80)
print('Training data shape: {}'.format(train.shape))
print('Test data shape: {}'.format(test.shape))
gbtree = xgboost.XGBClassifier(n_estimators=300, max_depth=5)
gbtree.fit(train, labels_train)
print('Test accuracy: {}'.format(
    sklearn.metrics.accuracy_score(labels_test, gbtree.predict(test))))
Training data shape: (26048, 108)
Test data shape: (6513, 108)
Test accuracy: 0.8668816213726394
The prediction function takes a Tabular instance as its input and outputs the class probabilities or logits for classification tasks, or the estimated values for regression tasks. In this example, we simply call transformer.transform to do the data preprocessing, followed by the prediction function of gbtree.
[5]:
predict_function = lambda z: gbtree.predict_proba(transformer.transform(z))
To initialize an L2X explainer, we need to set:

- training_data: The data used to train the explainer. training_data should be the training dataset for training the machine learning model.
- predict_function: The prediction function corresponding to the model.
- mode: The task type, e.g., “classification” or “regression”.
- selection_model: A pytorch model class for estimating P(S|X) in L2X (see the objective sketched after this list). If selection_model = None, a default model DefaultSelectionModel will be used.
- prediction_model: A pytorch model class for estimating Q(X_S) in L2X. If prediction_model = None, a default model DefaultPredictionModel will be used.
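For intuition, the objective below comes from the cited paper, not from library code: the selection model and the prediction model are trained jointly to maximize a variational lower bound on the mutual information I(X_S; Y) between the selected feature subset S and the model response Y,

\max_{\theta,\,\phi}\;\mathbb{E}_{X}\,\mathbb{E}_{S \sim P_\theta(S \mid X)}\big[\log Q_\phi(Y \mid X_S)\big]

where P_\theta(S \mid X) selects a subset S of k features for each input and Q_\phi(Y \mid X_S) approximates the black-box prediction from the selected features alone; a Gumbel-softmax (Concrete) relaxation of the subset sampling keeps the objective differentiable.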
[6]:
from omnixai.explainers.tabular.agnostic.L2X.l2x import L2XTabular
explainer = L2XTabular(
    training_data=tabular_data,
    predict_function=predict_function,
)
|████████████████████████████████████████| 100.0% Complete, Loss 0.1953
L2X prediction model accuracy: 0.8647768803169436
L2X generates local explanations, i.e., explainer.explain is called with the test instances. ipython_plot plots the generated explanations in IPython. The parameter index indicates which instance in test_x to plot, e.g., index = 0 means plotting the first instance in test_x.
[7]:
i = 1653
test_x = tabular_data[i:i + 5]
explanations = explainer.explain(test_x)
explanations.ipython_plot(index=0, class_names=class_names)
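Outside IPython, the same explanations can be rendered with the explanation object's other plotting methods. The sketch below assumes plotly_plot mirrors ipython_plot's index/class_names arguments, which is the convention across OmniXAI's plotting API.
# Build a plotly figure for the second test instance
# (assumed signature, mirroring ipython_plot above).
fig = explanations.plotly_plot(index=1, class_names=class_names)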