{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "*Licensed under the MIT License.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Interpreting Classical Text Classification models\n", "\n", "_**This notebook showcases how to use the interpret-text repo to implement an interpretable module using feature importances and bag of words representation.**_\n", "\n", "\n", "## Contents\n", "1. [Introduction](#Introduction)\n", "2. [Setup](#Setup)\n", "3. [Training](#Training)\n", "4. [Results](#Results)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append(\"../..\")\n", "import os\n", "\n", "# sklearn\n", "from sklearn.metrics import precision_recall_fscore_support\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "\n", "from interpret_text.experimental.classical import ClassicalTextExplainer\n", "\n", "from notebooks.test_utils.utils_mnli import load_mnli_pandas_df\n", "\n", "# for testing\n", "from scrapbook.api import glue\n", "\n", "working_dir = os.getcwd()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 1. Introduction\n", "This notebook illustrates how to locally use interpret-text to help interpret text classification using a logisitic regression baseline and bag of words encoding. It demonstrates the API calls needed to obtain the feature importances along with a visualization dashbard.\n", "\n", "###### Note:\n", "* *Although we use logistic regression, any model that follows sklearn's classifier API should be supported natively or with minimal tweaking.*\n", "* *The interpreter supports interpretations using either coefficients associated with linear models or feature importances associated with ensemble models.*\n", "* *The classifier relies on scipy's sparse representations to keep the dataset in memory.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Setup\n", "\n", "The notebook is built on features made available by [scikit-learn](https://scikit-learn.org/stable/) and [spacy](https://spacy.io/) for easier compatibiltiy with popular tookits." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configuration parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "DATA_FOLDER = './temp'\n", "TRAIN_SIZE = 0.7\n", "TEST_SIZE = 0.3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data loading" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = load_mnli_pandas_df(DATA_FOLDER, \"train\")\n", "df = df[df[\"gold_label\"] == \"neutral\"] # get unique sentences\n", "\n", "# fetch documents and labels from data frame\n", "X_str = df['sentence1'] # the document we want to analyze\n", "ylabels = df['genre'] # the labels, or answers, we want to test against" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create explainer " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Create explainer object that contains default glassbox classifier and explanation methods\n", "explainer = ClassicalTextExplainer(n_jobs=-1, tol=0.1)\n", "label_encoder = LabelEncoder()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "###### Note: Vocabulary\n", "\n", "* *The vocabulary is compiled from the training set. Any word that does not appear in the training data split, will not appear in the vocabulary.*\n", "* *The word must appear one or more times to be considered part of the vocabulary.*\n", "* *However, the sklearn countvectorizer allows the addition of a custom vocabulary as an input parameter.*\n", "\n", "### Configure training setup\n", "This step will cast the training data and labels into the correct format\n", "\n", "1. Split data into train and test using a random shuffle\n", "2. Load desired classifier. In this case, Logistic Regression is set as default.\n", "3. Setup grid search for hyperparameter optimization and train model. Edit the hyper parameter range to search over as per your model.\n", "4. Fit models to train set" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_train, X_test, y_train, y_test = train_test_split(X_str, ylabels, train_size=0.8, test_size=0.2)\n", "y_train = label_encoder.fit_transform(y_train)\n", "y_test = label_encoder.transform(y_test)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"X_train shape =\" + str(X_train.shape))\n", "print(\"y_train shape =\" + str(y_train.shape))\n", "print(\"X_train data structure = \" + str(type(X_train)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Model Overview\n", "\n", "The 1-gram [Bag of Words](https://en.wikipedia.org/wiki/Bag-of-words_model) allows a 1:1 mapping from individual words to their respective frequencies in the [document-term matrix](https://en.wikipedia.org/wiki/Document-term_matrix). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classifier, best_params = explainer.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results\n", "\n", "###### Notes for default Logistic Regression classifier:\n", "* *The parameters are set using cross-validation*\n", "* *Below listed hyperparamters are selected by searching over a larger space.*\n", "* *These apply specifically to this instance of the logistic regression model and mnli dataset.*\n", "* *'Multinomial' setup was found to be better than 'one-vs-all' across the board*\n", "* *Default 'liblinear' solver is not supported for 'multinomial' model setup*\n", "* *For a different model or dataset, set the range as appropriate using the hyperparam_range argument in the train method* " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# obtain best classifier and hyper params\n", "print(\"best classifier: \" + str(best_params))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Performance Metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mean_accuracy = classifier.score(X_test, y_test, sample_weight=None)\n", "print(\"accuracy = \" + str(mean_accuracy * 100) + \"%\")\n", "y_pred = classifier.predict(X_test)\n", "[precision, recall, fscore, support] = precision_recall_fscore_support(y_test, y_pred,average='macro')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Capture metrics for integration testing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "glue(\"accuracy\", mean_accuracy)\n", "glue(\"precision\", precision)\n", "glue(\"recall\", recall)\n", "glue(\"f1\", fscore)\n", "print(\"[precision, recall, fscore, support] = \" + str([precision, recall, fscore, support]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Local Importances\n", "\n", "Local importances are the most and least important words for a single document." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Enter any document or a document and label pair that needs to be interpreted\n", "document = \"I travelled to the beach. I took the train. I saw fairies, dragons and elves\"\n", "document1 = \"The term construction means fabrication, erection, or installation of an affected unit.\"\n", "document2 = \"Demonstrating Product Reliability Indicates the Product Is Ready for Production\"\n", "document3 = \"and see there\\'s no secrecy to that because the bill always comes in and we know how much they pay for it\"\n", "document4 = \"Had that piquant gipsy face been at the bottom of the crime, or was it 73 the baser mainspring of money?\"\n", "document5 = \"No, the boy trusted me, and I shan\\'t let him down.\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Obtain the top feature ids for the selected class label\n", "explainer.preprocessor.labelEncoder = label_encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "local_explanation = explainer.explain_local(document)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, you can pass the predicted label with the document" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "y = classifier.predict(document1)\n", "predicted_label = label_encoder.inverse_transform(y)\n", "local_explanation = explainer.explain_local(document1, predicted_label)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from interpret_text.experimental.widget import ExplanationDashboard" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ExplanationDashboard(local_explanation)" ] } ], "metadata": { "kernelspec": { "display_name": "Python (interpret_cpu)", "language": "python", "name": "interpret_cpu" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 2 }