Logistic regression for income prediction

The library contains several classes for linear models, e.g., LogisticRegression for classification tasks and LinearRegression for regression tasks. These classes provide an explain method that generates both global explanations (the linear coefficients) and local explanations (per-instance feature importance scores). If a linear model is sufficient for the task, these classes can be used directly, serving as both the model and its explainer.
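Why the coefficients work as a global explanation: a logistic regression models the log-odds as a weighted sum of the features, so each weight states how strongly, and in which direction, a feature moves the prediction across the whole dataset. The following is a minimal NumPy sketch, independent of this library, that fits a logistic regression by batch gradient descent on synthetic data and reads the weights back. The helper `fit_logistic_regression` and the synthetic data are illustrative assumptions; the library's own training procedure and feature preprocessing may differ.

```python
import numpy as np

def fit_logistic_regression(X, y, lr=0.1, n_iter=2000):
    """Fit weights w and bias b by batch gradient descent on the mean log-loss."""
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(y = 1)
        w -= lr * (X.T @ (p - y) / n)           # gradient of the mean log-loss w.r.t. w
        b -= lr * np.mean(p - y)                # gradient w.r.t. the bias
    return w, b

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
# Synthetic ground truth: feature 0 raises the odds, feature 1 lowers them, feature 2 is noise.
true_logits = 2.0 * X[:, 0] - 1.0 * X[:, 1]
y = (rng.random(1000) < 1.0 / (1.0 + np.exp(-true_logits))).astype(float)

w, b = fit_logistic_regression(X, y)
# The global explanation of a linear model is its weight vector.
for name, coef in zip(["x0", "x1", "x2"], w):
    print(f"{name}: {coef:+.2f}")
```

The recovered weights should be close to +2, -1 and 0, matching how the data was generated, which is what makes them directly interpretable.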

[1]:
# This default renderer is used for sphinx docs only. Please delete this cell in IPython.
import plotly.io as pio
pio.renderers.default = "png"
[2]:
import os

import numpy as np
import pandas as pd
from omnixai.data.tabular import Tabular
from omnixai.explainers.tabular import LogisticRegression

The dataset used in this example is the Adult dataset for income prediction (https://archive.ics.uci.edu/ml/datasets/adult). For linear models, a tabular dataset is represented by a Tabular instance, which can be constructed from a pandas dataframe or a numpy array. To create a Tabular instance from a numpy array, specify the data, the feature names, the categorical feature names (if any), and the target/label column name (if any).

[3]:
feature_names = [
    "Age", "Workclass", "fnlwgt", "Education",
    "Education-Num", "Marital Status", "Occupation",
    "Relationship", "Race", "Sex", "Capital Gain",
    "Capital Loss", "Hours per week", "Country", "label"
]
data = np.genfromtxt(os.path.join('../data', 'adult.data'), delimiter=', ', dtype=str)
tabular_data = Tabular(
    data,
    feature_columns=feature_names,
    categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],
    target_column='label'
)
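To make the role of each Tabular argument concrete, the sketch below does similar bookkeeping by hand with pandas: it names the columns, treats the undeclared columns as continuous, separates the target, and one-hot encodes the categorical columns so that a linear model can learn one weight per category. This does not reproduce Tabular's internals. The three rows are illustrative, and the use of one-hot encoding is an assumption, since this is a common way to feed categorical features to a linear model.

```python
import numpy as np
import pandas as pd

# Three illustrative rows in the same raw layout as adult.data (all values are strings).
raw = np.array([
    ["39", "State-gov", "Male", "<=50K"],
    ["50", "Self-emp-not-inc", "Male", ">50K"],
    ["38", "Private", "Female", "<=50K"],
])
feature_names = ["Age", "Workclass", "Sex", "label"]
categorical_columns = ["Workclass", "Sex"]
target_column = "label"

df = pd.DataFrame(raw, columns=feature_names)
# Columns that are neither categorical nor the target are treated as continuous.
continuous_columns = [c for c in feature_names if c not in categorical_columns + [target_column]]
df[continuous_columns] = df[continuous_columns].astype(float)

# One-hot encode categorical columns: one 0/1 column per category value.
X = pd.get_dummies(df.drop(columns=[target_column]), columns=categorical_columns, dtype=float)
y = (df[target_column] == ">50K").astype(int)

print(list(X.columns))
print(y.tolist())
```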

To train the linear model, call its fit method with the training dataset (a Tabular instance). As the output below shows, fit also reports the model's accuracy on a validation set.

[4]:
np.random.seed(1)
model = LogisticRegression()
model.fit(tabular_data)
Validation accuracy: 0.8518347919545525
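A validation accuracy like the one printed above is the fraction of correctly predicted labels on rows held out from training. The sketch below shows that idea with a shuffled hold-out split and an accuracy function. The tutorial does not state how fit splits the data, so the 80/20 ratio and the helpers `train_val_split` and `accuracy` are hypothetical.

```python
import numpy as np

def train_val_split(n_samples, val_fraction=0.2, seed=1):
    """Return disjoint, shuffled train/validation index arrays (80/20 is an assumed ratio)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(n_samples)
    n_val = int(round(n_samples * val_fraction))
    return idx[n_val:], idx[:n_val]

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

train_idx, val_idx = train_val_split(10)
print(len(train_idx), len(val_idx))          # 8 2
print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```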

Calling the model's explain method generates both global and local explanations. Here the first five rows are explained, and the explanations for the first row (index=0) are plotted.

[5]:
test_x = tabular_data[0:5]
explanations = model.explain(test_x)
explanations.ipython_plot(index=0)
(Figure: explanation plot for the first test instance, produced by ipython_plot(index=0).)
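For a linear model, a natural local explanation follows from the decomposition of the logit, b + Σ w_j·x_j: each term w_j·x_j is the contribution of feature j to this particular prediction. The sketch below computes and ranks those contributions for one instance. The weights, bias, instance, and helper `local_contributions` are hypothetical, and whether the library's local scores use exactly this definition (or apply it to transformed features) is an assumption.

```python
import numpy as np

def local_contributions(w, x, feature_names):
    """Per-feature contributions w_j * x_j to the logit, ranked by absolute size."""
    scores = w * x
    order = np.argsort(-np.abs(scores), kind="stable")
    return [(feature_names[i], float(scores[i])) for i in order]

# Hypothetical weights, bias and one (already encoded) instance.
names = ["Age", "Capital Gain", "Hours per week"]
w = np.array([0.5, 2.0, -0.25])
b = -1.0
x = np.array([1.0, 0.5, 3.0])

contribs = local_contributions(w, x, names)
print(contribs)  # [('Capital Gain', 1.0), ('Hours per week', -0.75), ('Age', 0.5)]

# The contributions plus the bias add up to the model's logit for this instance.
logit = b + float(w @ x)
print(f"logit = {logit:+.2f}")  # logit = -0.25
```

Because the contributions sum exactly to the logit (minus the bias), a linear model's local explanation is faithful by construction, unlike post-hoc approximations for non-linear models.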