{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### NLPExplainer on IMDB dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class `NLPExplainer` is designed for NLP tasks, acting as a factory of the supported NLP explainers such as integrated-gradient and LIME. `NLPExplainer` provides a unified easy-to-use interface for all the supported explainers. Because the supported NLP explainers in the current version are limited, one can either use `NLPExplainer` or a specific explainer in the package `omnixai.explainers.nlp` to generate explanations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import sklearn\n", "from sklearn.datasets import fetch_20newsgroups\n", "\n", "from omnixai.data.text import Text\n", "from omnixai.preprocessing.text import Word2Id\n", "from omnixai.explainers.tabular.agnostic.L2X.utils import Trainer, InputData, DataLoader\n", "from omnixai.explainers.nlp import NLPExplainer\n", "from omnixai.visualization.dashboard import Dashboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We apply a simple CNN model for this text classification task. Note that the method `forward` has two inputs `inputs` (token ids) and `masks` (the sentence masks). Note that the first input of the model must be the token ids." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class TextModel(nn.Module):\n", "\n", " def __init__(self, num_embeddings, num_classes, **kwargs):\n", " super().__init__()\n", " self.num_embeddings = num_embeddings\n", " self.embedding_size = kwargs.get(\"embedding_size\", 50)\n", " self.embedding = nn.Embedding(self.num_embeddings, self.embedding_size)\n", " self.embedding.weight.data.normal_(mean=0.0, std=0.01)\n", " \n", " hidden_size = kwargs.get(\"hidden_size\", 100)\n", " kernel_sizes = kwargs.get(\"kernel_sizes\", [3, 4, 5])\n", " if type(kernel_sizes) == int:\n", " kernel_sizes = [kernel_sizes]\n", "\n", " self.activation = nn.ReLU()\n", " self.conv_layers = nn.ModuleList([\n", " nn.Conv1d(self.embedding_size, hidden_size, k, padding=k // 2) for k in kernel_sizes])\n", " self.dropout = nn.Dropout(0.2)\n", " self.output_layer = nn.Linear(len(kernel_sizes) * hidden_size, num_classes)\n", "\n", " def forward(self, inputs, masks):\n", " embeddings = self.embedding(inputs)\n", " x = embeddings * masks.unsqueeze(dim=-1)\n", " x = x.permute(0, 2, 1)\n", " x = [self.activation(layer(x).max(2)[0]) for layer in self.conv_layers]\n", " outputs = self.output_layer(self.dropout(torch.cat(x, dim=1)))\n", " if outputs.shape[1] == 1:\n", " outputs = outputs.squeeze(dim=1)\n", " return outputs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `Text` object is used to represent a batch of texts/sentences. The package `omnixai.preprocessing.text` provides some transforms related to text data such as `Tfidf` and `Word2Id`." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Load the training and test datasets\n", "train_data = pd.read_csv('/home/ywz/data/imdb/labeledTrainData.tsv', sep='\\t')\n", "n = int(0.8 * len(train_data))\n", "x_train = Text(train_data[\"review\"].values[:n])\n", "y_train = train_data[\"sentiment\"].values[:n].astype(int)\n", "x_test = Text(train_data[\"review\"].values[n:])\n", "y_test = train_data[\"sentiment\"].values[n:].astype(int)\n", "class_names = [\"negative\", \"positive\"]\n", "# The transform for converting words/tokens to IDs\n", "transform = Word2Id().fit(x_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The preprocessing function converts a batch of texts into token IDs and the masks. The outputs of the preprocessing function must fit the inputs of the model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "max_length = 256\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "def preprocess(X: Text):\n", " samples = transform.transform(X)\n", " max_len = 0\n", " for i in range(len(samples)):\n", " max_len = max(max_len, len(samples[i]))\n", " max_len = min(max_len, max_length)\n", " inputs = np.zeros((len(samples), max_len), dtype=int)\n", " masks = np.zeros((len(samples), max_len), dtype=np.float32)\n", " for i in range(len(samples)):\n", " x = samples[i][:max_len]\n", " inputs[i, :len(x)] = x\n", " masks[i, :len(x)] = 1\n", " return inputs, masks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now train the CNN model and evaluate its performance." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " |████████████████████████████████████████| 100.0% Complete, Loss 0.0010\n" ] } ], "source": [ "model = TextModel(\n", " num_embeddings=transform.vocab_size,\n", " num_classes=len(class_names)\n", ").to(device)\n", "\n", "Trainer(\n", " optimizer_class=torch.optim.AdamW,\n", " learning_rate=1e-3,\n", " batch_size=128,\n", " num_epochs=10,\n", ").train(\n", " model=model,\n", " loss_func=nn.CrossEntropyLoss(),\n", " train_x=transform.transform(x_train),\n", " train_y=y_train,\n", " padding=True,\n", " max_length=max_length,\n", " verbose=True\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test accuracy: 0.8492442322991249\n" ] } ], "source": [ "model.eval()\n", "data = transform.transform(x_test)\n", "data_loader = DataLoader(\n", " dataset=InputData(data, [0] * len(data), max_length),\n", " batch_size=32,\n", " collate_fn=InputData.collate_func,\n", " shuffle=False\n", ")\n", "outputs = []\n", "for inputs in data_loader:\n", " value, mask, target = inputs\n", " y = model(value.to(device), mask.to(device))\n", " outputs.append(y.detach().cpu().numpy())\n", "outputs = np.concatenate(outputs, axis=0)\n", "predictions = np.argmax(outputs, axis=1)\n", "print('Test accuracy: {}'.format(\n", " sklearn.metrics.f1_score(y_test, predictions, average='binary')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to `TabularExplainer`, to initialize `NLPExplainer`, we need to set the following parameters:\n", "\n", " - `explainers`: The names of the explainers to apply, e.g., [\"ig\", \"lime\"].\n", " - `model`: The ML model to explain, e.g., a scikit-learn model, a tensorflow model or a pytorch model.\n", " - `preprocess`: The preprocessing function converting the raw data (a `Text` instance) into the inputs of model.\n", " - 
`postprocess`: The postprocessing function transforming the outputs of the model into a user-specific form, e.g., the predicted probability for each class.\n", " - `mode`: The task type, e.g., \"classification\" or \"regression\".\n", "\n", "The preprocessing function takes a `Text` instance as its input and outputs the processed features that the ML model consumes; in this example, the `Text` object is converted into PyTorch tensors.\n", "\n", "The postprocessing function converts the outputs into class probabilities." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# The preprocessing function\n", "preprocess_func = lambda x: tuple(torch.tensor(y).to(device) for y in preprocess(x))\n", "# The postprocessing function\n", "postprocess_func = lambda logits: torch.nn.functional.softmax(logits, dim=1)\n", "# Initialize an NLPExplainer\n", "explainer = NLPExplainer(\n", " explainers=[\"ig\", \"lime\", \"polyjuice\"],\n", " mode=\"classification\",\n", " model=model,\n", " preprocess=preprocess_func,\n", " postprocess=postprocess_func,\n", " params={\"ig\": {\"embedding_layer\": model.embedding, \n", " \"id2token\": transform.id_to_word}}\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is currently no \"global explanation\" for `NLPExplainer`. One can simply call `explainer.explain` to generate local explanations for NLP tasks." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:polyjuice.polyjuice_wrapper:Setup Polyjuice.\n", "INFO:polyjuice.polyjuice_wrapper:Setup SpaCy processor.\n", "INFO:polyjuice.polyjuice_wrapper:Setup perplexity scorer.\n", "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Integrated gradient results:\n" ] }, { "data": { "text/html": [ "