{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from __future__ import print_function\n", "import lime\n", "import numpy as np\n", "import sklearn\n", "import sklearn.ensemble\n", "import sklearn.feature_extraction.text\n", "import sklearn.metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fetching data, training a classifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the [previous tutorial](http://marcotcr.github.io/lime-ml/tutorials/Lime%20-%20basic%20usage%2C%20two%20class%20case.html), we looked at lime in the two-class case. In this tutorial, we will use the [20 newsgroups dataset](http://scikit-learn.org/stable/datasets/#the-20-newsgroups-text-dataset) again, but this time using all of the classes." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "newsgroups_train = fetch_20newsgroups(subset='train')\n", "newsgroups_test = fetch_20newsgroups(subset='test')\n", "# shorten class names: keep the last component, or the last two for 'misc' groups\n", "class_names = [x.split('.')[-1] if 'misc' not in x else '.'.join(x.split('.')[-2:]) for x in newsgroups_train.target_names]\n", "# the two hardware groups would otherwise both be shortened to 'hardware'\n", "class_names[3] = 'pc.hardware'\n", "class_names[4] = 'mac.hardware'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "atheism,graphics,ms-windows.misc,pc.hardware,mac.hardware,x,misc.forsale,autos,motorcycles,baseball,hockey,crypt,electronics,med,space,christian,guns,mideast,politics.misc,religion.misc\n" ] } ], "source": [ "print(','.join(class_names))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, let's use the TF-IDF vectorizer, commonly used for text."
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [], "source": [ "vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=False)\n", "train_vectors = vectorizer.fit_transform(newsgroups_train.data)\n", "test_vectors = vectorizer.transform(newsgroups_test.data)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This time we will use Multinomial Naive Bayes for classification, so that we can refer to [this document](http://scikit-learn.org/stable/datasets/#filtering-text-for-more-realistic-training)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "MultinomialNB(alpha=0.01, class_prior=None, fit_prior=True)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.naive_bayes import MultinomialNB\n", "nb = MultinomialNB(alpha=.01)\n", "nb.fit(train_vectors, newsgroups_train.target)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.83501841939981736" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = nb.predict(test_vectors)\n", "sklearn.metrics.f1_score(newsgroups_test.target, pred, average='weighted')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This classifier achieves a high weighted F1 score (about 0.835). [The sklearn guide to 20 newsgroups](http://scikit-learn.org/stable/datasets/#filtering-text-for-more-realistic-training) shows, by looking at the features with the highest coefficients for the model as a whole, that Multinomial Naive Bayes overfits this dataset by learning irrelevant features such as headers. We now use lime to explain individual predictions instead."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explaining predictions using lime" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from lime import lime_text\n", "from sklearn.pipeline import make_pipeline\n", "# chain the vectorizer and the classifier so that predict_proba accepts raw text\n", "c = make_pipeline(vectorizer, nb)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ 0.001 0.01 0.003 0.047 0.006 0.002 0.003 0.521 0.022 0.008\n", " 0.025 0. 0.331 0.003 0.006 0. 0.003 0. 0.001 0.009]]\n" ] } ], "source": [ "print(c.predict_proba([newsgroups_test.data[0]]).round(3))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from lime.lime_text import LimeTextExplainer\n", "explainer = LimeTextExplainer(class_names=class_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Previously, we used the default value of the label parameter when generating explanations, which works well in the binary case. \n", "For the multiclass case, we have to specify which labels to explain, via the 'labels' parameter. \n", "Below, we generate explanations for labels 0 and 17." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Document id: 1340\n", "Predicted class = atheism\n", "True class: atheism\n" ] } ], "source": [ "idx = 1340\n", "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, labels=[0, 17])\n", "print('Document id: %d' % idx)\n", "print('Predicted class =', class_names[nb.predict(test_vectors[idx])[0]])\n", "print('True class: %s' % class_names[newsgroups_test.target[idx]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can see the explanations for different labels. 
Notice that the positive and negative signs are relative to a particular label, so words that are negative towards class 0 may be positive towards class 17, and vice versa." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explanation for class atheism\n", "(u'Caused', 0.26601045845449439)\n", "(u'Rice', 0.13658462676578711)\n", "(u'Genocide', 0.13519771973974065)\n", "(u'owlnet', -0.092585962025604027)\n", "(u'certainty', -0.088001257903975866)\n", "(u'Semitic', -0.085734572813977061)\n", "\n", "Explanation for class mideast\n", "(u'fsu', -0.056622666266163954)\n", "(u'Luther', -0.051897643027619206)\n", "(u'Theism', -0.049640283384298073)\n", "(u'jews', 0.037819200983811155)\n", "(u'Caused', -0.037513316854574666)\n", "(u'Genocide', 0.027357407069969929)\n" ] } ], "source": [ "print('Explanation for class %s' % class_names[0])\n", "print('\\n'.join(map(str, exp.as_list(label=0))))\n", "print()\n", "print('Explanation for class %s' % class_names[17])\n", "print('\\n'.join(map(str, exp.as_list(label=17))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An alternative is to ask LIME to generate explanations for the top K classes. This is shown below with K=2. \n", "To see which labels have explanations, use the available_labels method." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0, 15]\n" ] } ], "source": [ "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, top_labels=2)\n", "print(exp.available_labels())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's look at a visualization of the explanations. Notice that for each class, the words on the right side of the line are positive, and the words on the left side are negative. 
Thus, 'Caused' is positive for atheism, but negative for christian." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(text=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We notice that the classifier is using reasonable words (such as 'Genocide', 'Luther', 'Semitic', etc.), as well as unreasonable ones ('Rice', 'owlnet'). Let's zoom in and look only at the explanations for class 'atheism'." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(text=newsgroups_test.data[idx], labels=(0,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example demonstrates that the header or quotes can contain useful signal that would generalize (e.g., the Subject line). They also contain signal that would not generalize (e.g., email addresses and institution names)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explaining predictions without headers, quotes and footers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we follow the [suggestion of removing headers, footers and quotes](http://scikit-learn.org/stable/datasets/#filtering-text-for-more-realistic-training), and explain the same example with the new data." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [], "source": [ "newsgroups_train = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))\n", "newsgroups_test = fetch_20newsgroups(subset='test', remove=('headers', 'footers', 'quotes'))\n", "# refit the vectorizer and the classifier on the stripped text\n", "train_vectors = vectorizer.fit_transform(newsgroups_train.data)\n", "test_vectors = vectorizer.transform(newsgroups_test.data)\n", "nb = MultinomialNB(alpha=.01)\n", "nb.fit(train_vectors, newsgroups_train.target)\n", "c = make_pipeline(vectorizer, nb)\n", "explainer = LimeTextExplainer(class_names=class_names)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[15, 17]\n" ] } ], "source": [ "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, top_labels=2)\n", "print(exp.available_labels())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice how different the explanations are for the classifier without headers, footers and quotes. 
Not only does the prediction change (the top class is now christian rather than atheism), but so do the reasons behind it." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(text=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's see the explanation with the text for the top class (christian):" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(text=newsgroups_test.data[idx], labels=(15,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice how short the text became after removing all of that information. One begins to wonder if this version of the dataset is still useful, or if it is better to find another dataset altogether. Could a reasonable classifier detect that this document belongs to the class atheism?\n", "\n", "Anyway, I hope this illustrated how to use LIME to explain arbitrary classifiers in the multiclass case!" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.11" } }, "nbformat": 4, "nbformat_minor": 0 }