Decision tree for income prediction

Similar to linear models, the library provides dedicated classes for decision trees, e.g., TreeClassifier for classification tasks and TreeRegressor for regression tasks. These classes expose an explain method that generates both global explanations (the tree structure) and local explanations (the decision paths for given instances).

[1]:
import os
import sklearn
import numpy as np
import pandas as pd
from omnixai.data.tabular import Tabular
from omnixai.explainers.tabular.specific.decision_tree import TreeClassifier

The dataset used in this example is for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). For tree-based models, we use Tabular to represent a tabular dataset, which can be constructed from a pandas dataframe or a numpy array. To create a Tabular instance from a numpy array, one needs to specify the data, the feature names, the categorical feature names (if any) and the target/label column name (if any).

[2]:
feature_names = [
    "Age", "Workclass", "fnlwgt", "Education",
    "Education-Num", "Marital Status", "Occupation",
    "Relationship", "Race", "Sex", "Capital Gain",
    "Capital Loss", "Hours per week", "Country", "label"
]
data = np.genfromtxt(os.path.join('../data', 'adult.data'), delimiter=', ', dtype=str)
tabular_data = Tabular(
    data,
    feature_columns=feature_names,
    categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],
    target_column='label'
)
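The index list passed to categorical_columns is easy to misread, so the following minimal sketch resolves it to column names using plain Python only. The variable names categorical_idx, categorical, target and continuous are introduced here for illustration; that the remaining non-target columns are handled as continuous features is an assumption, not something stated above.

```python
# Same column names as in the cell above.
feature_names = [
    "Age", "Workclass", "fnlwgt", "Education",
    "Education-Num", "Marital Status", "Occupation",
    "Relationship", "Race", "Sex", "Capital Gain",
    "Capital Loss", "Hours per week", "Country", "label"
]

# Indices used for categorical_columns in the Tabular constructor call.
categorical_idx = [1, 3, 5, 6, 7, 8, 9, 13]
categorical = [feature_names[i] for i in categorical_idx]

# Everything that is neither categorical nor the target column.
# (Assumption: these are the columns treated as continuous.)
target = "label"
continuous = [n for n in feature_names if n not in categorical and n != target]

print("categorical:", categorical)
print("continuous: ", continuous)
```

Printing both lists before building the Tabular instance is a cheap way to catch an off-by-one index.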

To train a tree-based model, call the fit method with the training dataset (a Tabular instance).

[3]:
np.random.seed(1)
model = TreeClassifier()
model.fit(tabular_data, max_depth=4)
Validation accuracy: 0.8446184553968985
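To get a feel for what max_depth=4 implies, the sketch below computes the upper bounds on the size of a binary tree of a given depth. It assumes that max_depth has its usual meaning (the maximum number of splits between the root and any leaf) and that every split is binary; tree_limits is a hypothetical helper, not part of the library.

```python
from typing import Tuple


def tree_limits(max_depth: int) -> Tuple[int, int]:
    """Return (max_leaves, max_internal_nodes) for a binary tree of this depth."""
    max_leaves = 2 ** max_depth        # a full binary tree doubles its leaves per level
    max_internal = 2 ** max_depth - 1  # split nodes above the leaves
    return max_leaves, max_internal


leaves, internal = tree_limits(4)
print(f"at most {leaves} leaves and {internal} split nodes")
```

Under these assumptions, every decision path in the trained model contains at most four feature tests, which keeps the local explanations short enough to read at a glance.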

The tree-based model supports both global and local explanations, which are generated by calling the explain method. For example, calling explain with test instances returns their decision paths.

[4]:
i = 1653
test_x = tabular_data[i:i + 5]
explanations = model.explain(test_x)
explanations.ipython_plot(index=0)
[Figure: decision path plot for the first test instance (index=0)]
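To make the notion of a decision path concrete, here is a minimal, self-contained sketch that walks a hand-built toy tree and records each test along the way. It is not the library's implementation: the Node class, the decision_path function, and the thresholds (7000.0 and 12.0) are all hypothetical and chosen only for illustration, not taken from the trained model.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Node:
    """A toy binary tree node: either a split (feature/threshold) or a leaf."""
    feature: Optional[str] = None
    threshold: Optional[float] = None
    left: Optional["Node"] = None      # branch taken when value <= threshold
    right: Optional["Node"] = None     # branch taken when value > threshold
    prediction: Optional[str] = None   # set only on leaves


def decision_path(node: Node, sample: Dict[str, float]) -> Tuple[List[str], str]:
    """Follow the sample from the root to a leaf, recording each test."""
    steps = []
    while node.prediction is None:
        value = sample[node.feature]
        if value <= node.threshold:
            steps.append(f"{node.feature} = {value} <= {node.threshold}")
            node = node.left
        else:
            steps.append(f"{node.feature} = {value} > {node.threshold}")
            node = node.right
    return steps, node.prediction


# Hypothetical two-level tree; the thresholds are made up for this example.
toy_tree = Node(
    feature="Capital Gain", threshold=7000.0,
    left=Node(
        feature="Education-Num", threshold=12.0,
        left=Node(prediction="<=50K"),
        right=Node(prediction=">50K"),
    ),
    right=Node(prediction=">50K"),
)

sample = {"Capital Gain": 0, "Education-Num": 13}
path, label = decision_path(toy_tree, sample)
for step in path:
    print(step)
print("prediction:", label)
```

Each recorded step corresponds to one node on the path from the root to the leaf, which is the kind of information a local explanation of a tree conveys for a single instance.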