{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "*Licensed under the MIT License.*\n", "\n", "# Text Classification of SST-2 Sentences using a 3-Player Introspective Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../..\")\n", "import os\n", "import json\n", "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import scrapbook as sb\n", "\n", "from interpret_text.experimental.introspective_rationale import IntrospectiveRationaleExplainer\n", "from interpret_text.experimental.common.preprocessor.glove_preprocessor import GlovePreprocessor\n", "from interpret_text.experimental.common.preprocessor.bert_preprocessor import BertPreprocessor\n", "from interpret_text.experimental.common.model_config.introspective_rationale_model_config import IntrospectiveRationaleModelConfig\n", "from interpret_text.experimental.widget import ExplanationDashboard\n", "\n", "from notebooks.test_utils.utils_sst2 import load_sst2_pandas_df\n", "from notebooks.test_utils.utils_data_shared import load_glove_embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction\n", "In this notebook, we train and evaluate a [three-player explainer](http://people.csail.mit.edu/tommi/papers/YCZJ_EMNLP2019.pdf) model on a subset of the [SST-2](https://nlp.stanford.edu/sentiment/index.html/) dataset. To run this notebook, we used the SST-2 data files provided [here](https://github.com/AcademiaSinicaNLPLab/sentiment_dataset)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set parameters\n", "Here we set some parameters that we use for our modeling task." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "# if quick run true, skips over embedding, most of model training, and model evaluation; used to quickly test pipeline\n", "QUICK_RUN = True\n", "MODEL_TYPE = \"RNN\" # currently support either RNN, BERT, or a combination of RNN and BERT\n", "CUDA = False\n", "\n", "# data processing parameters\n", "DATA_FOLDER = \"../../../data/sst2\"\n", "LABEL_COL = \"labels\" \n", "TEXT_COL = \"sentences\"\n", "token_count_thresh = 1\n", "max_sentence_token_count = 70\n", "\n", "# training procedure parameters\n", "load_pretrained_model = False\n", "pretrained_model_path = \"../models/rnn.pth\"\n", "MODEL_SAVE_DIR = os.path.join(\"..\", \"models\")\n", "model_prefix = \"sst2rnpmodel\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_config = {\n", " \"cuda\": CUDA,\n", " \"model_save_dir\": MODEL_SAVE_DIR, \n", " \"model_prefix\": model_prefix,\n", " \"lr\": 2e-4\n", "}\n", "\n", "if QUICK_RUN:\n", " model_config[\"save_best_model\"] = False\n", " model_config[\"pretrain_cls\"] = True\n", " model_config[\"num_epochs\"] = 1\n", "\n", "if MODEL_TYPE == \"RNN\":\n", " # (i.e. not using BERT), load pretrained glove embeddings\n", " if not QUICK_RUN:\n", " model_config[\"embedding_path\"] = load_glove_embeddings(DATA_FOLDER)\n", " else:\n", " model_config[\"embedding_path\"] = os.path.join(DATA_FOLDER, \"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Read Dataset\n", "We start by loading a subset of the data for training and testing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data = load_sst2_pandas_df('train')\n", "test_data = load_sst2_pandas_df('test')\n", "all_data = pd.concat([train_data, test_data])\n", "if QUICK_RUN:\n", " batch_size = 50\n", " train_data = train_data.head(batch_size)\n", " test_data = test_data.head(batch_size)\n", "X_train = train_data[TEXT_COL]\n", "X_test = test_data[TEXT_COL]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# get all unique labels\n", "y_labels = all_data[LABEL_COL].unique()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_config[\"labels\"] = np.array(sorted(y_labels))\n", "model_config[\"num_labels\"] = len(y_labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tokenization and embedding\n", "The data is then tokenized and embedded using glove embeddings." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if MODEL_TYPE == \"RNN\":\n", " preprocessor = GlovePreprocessor(token_count_thresh, max_sentence_token_count)\n", " preprocessor.build_vocab(all_data[TEXT_COL])\n", "if MODEL_TYPE == \"BERT\":\n", " preprocessor = BertPreprocessor()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# append labels to tokenizer output\n", "df_train = pd.concat([train_data[LABEL_COL], preprocessor.preprocess(X_train)], axis=1)\n", "df_test = pd.concat([test_data[LABEL_COL], preprocessor.preprocess(X_test)], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explainer\n", "Then, we create the explainer and train it (or load a pretrained model).\n", "The steps involved to set up the explainer: \n", "- Initialize explainer\n", "- Setup preprocessor for the explainer\n", "- Supply necessary model configurations to the explainer\n", "- Load the explainer once all necessary modules are setup\n", "- Fit/Train the explainer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer = IntrospectiveRationaleExplainer(classifier_type=MODEL_TYPE, cuda=CUDA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.set_preprocessor(preprocessor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.build_model_config(model_config)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.load()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.fit(df_train, df_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can test the explainer and measure its performance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if not QUICK_RUN:\n", " explainer.score(df_test)\n", " sparsity = explainer.model.avg_sparsity\n", " accuracy = explainer.model.avg_accuracy\n", " anti_accuracy = explainer.model.avg_anti_accuracy\n", " print(\"Test sparsity: \", sparsity)\n", " print(\"Test accuracy: \", accuracy, \"% Anti-accuracy: \", anti_accuracy)\n", " \n", " # for testing\n", " sb.glue(\"sparsity\", sparsity)\n", " sb.glue(\"accuracy\", accuracy)\n", " sb.glue(\"anti_accuracy\", anti_accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Local importances\n", "We can display the found local importances (the most and least important words for a given sentence):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Enter a sentence that needs to be interpreted\n", "sentence = \"Beautiful movie ; really good , the popcorn was bad\"\n", "s2 = \"a beautiful and haunting examination of the stories we tell ourselves to make sense of the mundane horrors of the world.\"\n", "s3 = \"the premise is in extremely bad taste , and the film's supposed insights are so poorly executed and done that even a high school dropout taking his or her first psychology class could dismiss them .\"\n", "s4= \"This is a super amazing movie with bad acting\"\n", "\n", "local_explanation = explainer.explain_local(s4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize explanations\n", "We can visualize local feature importances as a heatmap over words in the document and view importance values of individual words." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "explainer.visualize(local_explanation)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ExplanationDashboard(local_explanation)" ] } ], "metadata": { "kernelspec": { "display_name": "Python (interpret_cpu)", "language": "python", "name": "interpret_cpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }