{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### L2X (learning to explain) for text classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example applies the L2X explainer to text classification. Unlike gradient-based methods, L2X trains a separate explanation model. Its advantage is that explanations are generated quickly once the explanation model is trained. Its disadvantage is that the quality of the explanations depends heavily on the trained explanation model, which is affected by factors such as its network structure and the training hyperparameters.\n", "\n", "For text classification, we implement the default CNN-based explanation model in `omnixai.explainers.nlp.agnostic.l2x`. Other models can be implemented by following the same interface (please refer to the docs for more details). If you use this explainer, please cite the original work: \"Learning to Explain: An Information-Theoretic Perspective on Model Interpretation, Jianbo Chen, Le Song, Martin J. Wainwright, Michael I. Jordan, https://arxiv.org/abs/1802.07814\"." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import sklearn.ensemble\n", "import sklearn.metrics\n", "from sklearn.datasets import fetch_20newsgroups\n", "\n", "from omnixai.data.text import Text\n", "from omnixai.preprocessing.text import Tfidf\n", "from omnixai.explainers.nlp.agnostic.l2x import L2XText" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use a `Text` object to represent a batch of texts/sentences. The package `omnixai.preprocessing.text` provides text-related transforms such as `Tfidf`." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Load the training and test datasets\n", "categories = ['alt.atheism', 'soc.religion.christian']\n", "newsgroups_train = fetch_20newsgroups(subset='train', categories=categories)\n", "newsgroups_test = fetch_20newsgroups(subset='test', categories=categories)\n", "\n", "x_train = Text(newsgroups_train.data)\n", "y_train = newsgroups_train.target\n", "x_test = Text(newsgroups_test.data)\n", "y_test = newsgroups_test.target\n", "class_names = ['atheism', 'christian']\n", "# A TF-IDF transform fitted on the training texts\n", "transform = Tfidf().fit(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this classification task, we train a random forest classifier on TF-IDF feature vectors." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test F1 score: 0.925233644859813\n" ] } ], "source": [ "train_vectors = transform.transform(x_train)\n", "test_vectors = transform.transform(x_test)\n", "model = sklearn.ensemble.RandomForestClassifier(n_estimators=500)\n", "model.fit(train_vectors, y_train)\n", "# Wrap the TF-IDF transform and the model so that raw `Text` inputs map to class probabilities\n", "predict_function = lambda x: model.predict_proba(transform.transform(x))\n", "\n", "predictions = model.predict(test_vectors)\n", "print('Test F1 score: {}'.format(\n", "    sklearn.metrics.f1_score(y_test, predictions, average='binary')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To initialize `L2XText`, we need to set the following parameters:\n", "\n", " - `training_data`: The data used to train the explainer. `training_data` should be the dataset that was used to train the machine learning model.\n", " - `predict_function`: The prediction function corresponding to the model to explain. When the model is for classification, the outputs of the `predict_function` are the class probabilities. 
When the model is for regression, the outputs of the `predict_function` are the estimated values.\n", " - `mode`: The task type, e.g., `classification` or `regression`.\n", " - `selection_model`: A PyTorch model class for estimating P(S|X) in L2X. If `selection_model = None`, the default model `DefaultSelectionModel` will be used.\n", " - `prediction_model`: A PyTorch model class for estimating Q(X_S) in L2X. If `prediction_model = None`, the default model `DefaultPredictionModel` will be used." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " |████████████████████████████████████████| 100.0% Complete, Loss 0.0039\n", "L2X prediction model accuracy: 0.8674698795180723\n" ] }, { "data": { "text/html": [ "