{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### NLPExplainer for sentiment analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The class `NLPExplainer` is designed for NLP tasks, acting as a factory of the supported NLP explainers such as integrated-gradient and LIME. `NLPExplainer` provides a unified easy-to-use interface for all the supported explainers. Because the supported NLP explainers in the current version are limited, one can either use `NLPExplainer` or a specific explainer in the package `omnixai.explainers.nlp` to generate explanations." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# This default renderer is used for sphinx docs only. Please delete this cell in IPython.\n", "import plotly.io as pio\n", "pio.renderers.default = \"png\"" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import transformers\n", "from omnixai.data.text import Text\n", "from omnixai.explainers.nlp import NLPExplainer\n", "from omnixai.visualization.dashboard import Dashboard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we consider a sentiment analysis task. The test input is an instance of `Text`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "x = Text([\n", " \"What a great movie!\",\n", " \"The Interview was neither that funny nor that witty. \"\n", " \"Even if there are words like funny and witty, the overall structure is a negative type.\"\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model considered here is a transformer model. 
Similar to `TabularExplainer`, to initialize `NLPExplainer`, we need to set the following parameters:\n", "\n", " - `explainers`: The names of the explainers to apply, e.g., [\"shap\", \"lime\"].\n", " - `model`: The ML model to explain, e.g., a scikit-learn model, a TensorFlow model, a PyTorch model, or a black-box prediction function.\n", " - `preprocess`: The preprocessing function converting the raw data (a `Text` instance) into the inputs of the model.\n", " - `postprocess`: The postprocessing function transforming the outputs of the model into a user-specific form, e.g., the predicted probability for each class.\n", " - `mode`: The task type, e.g., \"classification\", \"regression\" or \"qa\".\n", " \n", "The preprocessing function takes a `Text` instance as its input and outputs the processed features that the ML model consumes. In this example, the `Text` object is converted into a batch of strings.\n", "\n", "The postprocessing function converts the outputs into class probabilities." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# The preprocessing function\n", "preprocess = lambda x: x.values\n", "# A transformer model for sentiment analysis\n", "model = transformers.pipeline(\n", " 'sentiment-analysis',\n", " model='distilbert-base-uncased-finetuned-sst-2-english',\n", " return_all_scores=True\n", ")\n", "# The postprocessing function\n", "postprocess = lambda outputs: np.array([[s[\"score\"] for s in ss] for ss in outputs])\n", "\n", "# Initialize an NLPExplainer\n", "explainer = NLPExplainer(\n", " explainers=[\"shap\", \"lime\", \"polyjuice\"],\n", " mode=\"classification\",\n", " model=model,\n", " preprocess=preprocess,\n", " postprocess=postprocess\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is no \"global explanation\" for `NLPExplainer` currently. One can simply call `explainer.explain` to generate local explanations for NLP tasks."
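] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a quick sanity check, the cell below exercises the postprocessing function on a mocked pipeline output. The labels and scores in `mock_outputs` are invented for illustration; only the list-of-dicts format, which is what the pipeline returns with `return_all_scores=True`, matters here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Mocked pipeline output: one list of {label, score} dicts per input text.\n", "# These scores are made up for illustration only.\n", "mock_outputs = [\n", "    [{\"label\": \"NEGATIVE\", \"score\": 0.02}, {\"label\": \"POSITIVE\", \"score\": 0.98}],\n", "    [{\"label\": \"NEGATIVE\", \"score\": 0.91}, {\"label\": \"POSITIVE\", \"score\": 0.09}],\n", "]\n", "# Same postprocessing as above: extract the scores into a (batch, classes) array.\n", "postprocess = lambda outputs: np.array([[s[\"score\"] for s in ss] for ss in outputs])\n", "probs = postprocess(mock_outputs)\n", "print(probs.shape)  # (2, 2)"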
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/248 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Partition explainer: 3it [00:12, 12.70s/it] \n", "INFO:polyjuice.polyjuice_wrapper:Setup Polyjuice.\n", "INFO:polyjuice.polyjuice_wrapper:Setup SpaCy processor.\n", "INFO:polyjuice.polyjuice_wrapper:Setup perplexity scorer.\n", "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "SHAP results:\n" ] }, { "data": { "text/html": [ "