{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TextExplainer: debugging black-box text classifiers\n", "\n", "While eli5 supports many classifiers and preprocessing methods, it can't support them all. \n", "\n", "If a library is not supported by eli5 directly, or the text processing pipeline is too complex for eli5, eli5 can still help - it provides an implementation of [LIME](http://arxiv.org/abs/1602.04938) (Ribeiro et al., 2016) algorithm which allows to explain predictions of arbitrary classifiers, including text classifiers. `eli5.lime` can also help when it is hard to get exact mapping between model coefficients and text features, e.g. if there is dimension reduction involved.\n", "\n", "## Example problem: LSA+SVM for 20 Newsgroups dataset\n", "\n", "Let's load \"20 Newsgroups\" dataset and create a text processing pipeline which is hard to debug using conventional methods: SVM with RBF kernel trained on [LSA](https://en.wikipedia.org/wiki/Latent_semantic_analysis) features." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "\n", "categories = ['alt.atheism', 'soc.religion.christian', \n", " 'comp.graphics', 'sci.med']\n", "twenty_train = fetch_20newsgroups(\n", " subset='train',\n", " categories=categories,\n", " shuffle=True,\n", " random_state=42,\n", " remove=('headers', 'footers'),\n", ")\n", "twenty_test = fetch_20newsgroups(\n", " subset='test',\n", " categories=categories,\n", " shuffle=True,\n", " random_state=42,\n", " remove=('headers', 'footers'),\n", ")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "0.89014647137150471" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.feature_extraction.text import TfidfVectorizer\n", "from sklearn.svm import SVC\n", "from sklearn.decomposition import TruncatedSVD\n", "from sklearn.pipeline import Pipeline, make_pipeline\n", "\n", "vec = TfidfVectorizer(min_df=3, stop_words='english',\n", " ngram_range=(1, 2))\n", "svd = TruncatedSVD(n_components=100, n_iter=7, random_state=42)\n", "lsa = make_pipeline(vec, svd)\n", "\n", "clf = SVC(C=150, gamma=2e-2, probability=True)\n", "pipe = make_pipeline(lsa, clf)\n", "pipe.fit(twenty_train.data, twenty_train.target)\n", "pipe.score(twenty_test.data, twenty_test.target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dimension of the input documents is reduced to 100, and then a kernel SVM is used to classify the documents. 
\n", "\n", "This is what the pipeline returns for a document - it is pretty sure the first message in test data belongs to sci.med:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.001 alt.atheism\n", "0.001 comp.graphics\n", "0.995 sci.med\n", "0.004 soc.religion.christian\n" ] } ], "source": [ "def print_prediction(doc):\n", " y_pred = pipe.predict_proba([doc])[0]\n", " for target, prob in zip(twenty_train.target_names, y_pred):\n", " print(\"{:.3f} {}\".format(prob, target)) \n", "\n", "doc = twenty_test.data[0]\n", "print_prediction(doc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TextExplainer\n", "Such pipelines are not supported by eli5 directly, but one can use `eli5.lime.TextExplainer` to debug the prediction - to check what was important in the document to make this decision.\n", "\n", "Create a `TextExplainer` instance, then pass the document to explain and a black-box classifier (a function which returns probabilities) to the `TextExplainer.fit` method, then check the explanation:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=alt.atheism\n", " \n", "\n", "\n", " \n", " (probability 0.000, score -9.583)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " -0.360\n", " \n", " <BIAS>\n", "
\n", " -9.223\n", " \n", " Highlighted in text (sum)\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=comp.graphics\n", " \n", "\n", "\n", " \n", " (probability 0.000, score -8.285)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " -0.213\n", " \n", " <BIAS>\n", "
\n", " -8.073\n", " \n", " Highlighted in text (sum)\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.996, score 5.846)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +5.959\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.113\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=soc.religion.christian\n", " \n", "\n", "\n", " \n", " (probability 0.004, score -5.484)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " -0.346\n", " \n", " <BIAS>\n", "
\n", " -5.137\n", " \n", " Highlighted in text (sum)\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import eli5\n", "from eli5.lime import TextExplainer\n", "\n", "te = TextExplainer(random_state=42)\n", "te.fit(doc, pipe.predict_proba)\n", "te.show_prediction(target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Why it works\n", "\n", "Explanation makes sense - we expect reasonable classifier to take highlighted words in account. But how can we be sure this is how the pipeline works, not just a nice-looking lie? A simple sanity check is to remove or change the highlighted words, to confirm that they change the outcome:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.068 alt.atheism\n", "0.149 comp.graphics\n", "0.369 sci.med\n", "0.414 soc.religion.christian\n" ] } ], "source": [ "import re\n", "doc2 = re.sub(r'(recall|kidney|stones|medication|pain|tech)', '', doc, flags=re.I)\n", "print_prediction(doc2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predicted probabilities changed a lot indeed. \n", "\n", "And in fact, `TextExplainer` did something similar to get the explanation. `TextExplainer` generated a lot of texts similar to the document (by removing some of the words), and then trained a white-box classifier which predicts the output of the black-box classifier (not the true labels!). The explanation we saw is for this white-box classifier.\n", "\n", "This approach follows the LIME algorithm; for text data the algorithm is actually pretty straightforward:\n", "\n", "1. generate distorted versions of the text;\n", "2. predict probabilities for these distorted texts \n", " using the black-box classifier;\n", "3. 
The algorithm works because even though it could be hard or impossible to approximate a black-box classifier globally (for every possible text), approximating it in a small neighbourhood near a given text often works well, even with simple white-box classifiers.\n", "\n", "Generated samples (distorted texts) are available in the `samples_` attribute:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "As my kidney , isn' any\n", " can .\n", "\n", "Either they , be , \n", "to .\n", "\n", " ,  - tech to mention ' had kidney\n", " and , .\n" ] } ], "source": [ "print(te.samples_[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default `TextExplainer` generates 5000 distorted texts (use the `n_samples` argument to change this number):" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "5000" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(te.samples_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The trained white-box classifier and vectorizer are available as the `vec_` and `clf_` attributes:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "(CountVectorizer(analyzer='word', binary=False, decode_error='strict',\n", "         dtype=, encoding='utf-8', input='content',\n", "         lowercase=True, max_df=1.0, max_features=None, min_df=1,\n", "         ngram_range=(1, 2), preprocessor=None, stop_words=None,\n", "         strip_accents=None, token_pattern='(?u)\\\\b\\\\w+\\\\b', tokenizer=None,\n", "         vocabulary=None),\n", " SGDClassifier(alpha=0.001, average=False, class_weight=None, epsilon=0.1,\n", "        eta0=0.0, fit_intercept=True, l1_ratio=0.15,\n", "        learning_rate='optimal', loss='log', n_iter=5, n_jobs=1,\n", "        penalty='elasticnet', power_t=0.5,\n", "        random_state=,\n", "        shuffle=True, verbose=0, warm_start=False))" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te.vec_, te.clf_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Should we trust the explanation?\n", "\n", "Ok, this sounds fine, but how can we be sure that this simple text classification pipeline approximated the black-box classifier well?\n", "\n", "One way to do that is to check the quality on a held-out dataset (which is also generated). `TextExplainer` does that by default and stores the metrics in the `metrics_` attribute:" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{'mean_KL_divergence': 0.020277596015756863, 'score': 0.98684669657535129}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te.metrics_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "* 'score' is an accuracy score weighted by the cosine distance between a generated sample and the original document (i.e. texts which are closer to the example are more important). Accuracy shows how good the 'top 1' predictions are. \n", "* 'mean_KL_divergence' is the mean [Kullback–Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) over all target classes; it is also weighted by distance. KL divergence shows how well the probabilities are approximated; 0.0 means a perfect match.
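\n", "\n", "To get a feel for these numbers, we can also compare the black-box pipeline with its white-box approximation directly on the original document (a sketch; `te.vec_` and `te.clf_` are the white-box vectorizer and classifier shown above):" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# The two probability vectors should be close\n", "# if the local approximation is good.\n", "print(pipe.predict_proba([doc])[0])                        # black-box\n", "print(te.clf_.predict_proba(te.vec_.transform([doc]))[0])  # white-box" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "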
In this example both accuracy and KL divergence are good; it means our white-box classifier usually assigns the same labels as the black-box classifier on the dataset we generated, and its predicted probabilities are close to those predicted by our LSA+SVM pipeline. So it is likely (though not guaranteed - we'll discuss this later) that the explanation is correct and can be trusted.\n", "\n", "When working with LIME (e.g. via `TextExplainer`) it is always a good idea to check these scores. If they are not good, you can tell that something is not right." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Let's make it fail\n", "\n", "By default `TextExplainer` uses a very basic text processing pipeline: Logistic Regression trained on bag-of-words and bag-of-bigrams features (see the `te.clf_` and `te.vec_` attributes). This limits the set of black-box classifiers it can explain: because the text is seen as a \"bag of words/ngrams\", the default white-box pipeline can't distinguish e.g. between the same word at the beginning of the document and at the end of the document. Bigrams help to alleviate the problem in practice, but not completely. \n", "\n", "Black-box classifiers which use features like \"text length\" (not directly related to tokens) can also be hard to approximate using the default bag-of-words/ngrams model. \n", "\n", "This kind of failure is usually detectable though - the scores (accuracy and KL divergence) will be low. Let's check it on a completely synthetic example - a black-box classifier which assigns a class based on whether the document length is odd and on the presence of the word 'medication'." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.998, score 6.380)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +6.445\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.065\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "def predict_proba_len(docs):\n", " # nasty predict_proba - the result is based on document length,\n", " # and also on a presence of \"medication\"\n", " proba = [\n", " [0, 0, 1.0, 0] if len(doc) % 2 or 'medication' in doc else [1.0, 0, 0, 0] \n", " for doc in docs\n", " ]\n", " return np.array(proba) \n", "\n", "te3 = TextExplainer().fit(doc, predict_proba_len)\n", "te3.show_prediction(target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TextExplainer` correctly figured out that 'medication' is important, but failed to account for \"len(doc) % 2\" condition, so the explanation is incomplete. We can detect this failure by looking at metrics - they are low:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "{'mean_KL_divergence': 0.29813769123006623, 'score': 0.80148602213214504}" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te3.metrics_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If (a big if...) we suspect that the fact document length is even or odd is important, it is possible to customize `TextExplainer` to check this hypothesis. \n", "\n", "To do that, we need to create a vectorizer which returns both \"is odd\" feature and bag-of-words features, and pass this vectorizer to `TextExplainer`. This vectorizer should follow scikit-learn API. The easiest way is to use `FeatureUnion` - just make sure all transformers joined by `FeatureUnion` have `get_feature_names()` methods." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 1.0, 'mean_KL_divergence': 0.0247695693408547}\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.997, score 5.654)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +8.864\n", " \n", " countvectorizer: Highlighted in text (sum)\n", "
\n", " -0.083\n", " \n", " <BIAS>\n", "
\n", " -3.128\n", " \n", " doclength__is_even\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " countvectorizer: as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "Explanation(estimator=\"SGDClassifier(alpha=0.001, average=False, class_weight=None, epsilon=0.1,\\n eta0=0.0, fit_intercept=True, l1_ratio=0.15,\\n learning_rate='optimal', loss='log', n_iter=5, n_jobs=1,\\n penalty='elasticnet', power_t=0.5,\\n random_state=,\\n shuffle=True, verbose=0, warm_start=False)\", description=None, error=None, method='linear model', is_regression=False, targets=[TargetExplanation(target='sci.med', feature_weights=FeatureWeights(pos=[FeatureWeight(feature='countvectorizer__medication', weight=4.9077070275088301, std=None, value=1.0), FeatureWeight(feature='countvectorizer__any medication', weight=1.409764951748643, std=None, value=1.0), FeatureWeight(feature='countvectorizer__medication that', weight=1.4075922954882905, std=None, value=1.0), FeatureWeight(feature='countvectorizer__with', weight=0.23348147145315884, std=None, value=2.0), FeatureWeight(feature='countvectorizer__stones there', weight=0.16161551402103283, std=None, value=1.0), FeatureWeight(feature='countvectorizer__the', weight=0.13270154381365193, std=None, value=3.0), FeatureWeight(feature='countvectorizer__have to', weight=0.12980991179263227, std=None, value=2.0), FeatureWeight(feature='countvectorizer__or', weight=0.11985977768061165, std=None, value=2.0), FeatureWeight(feature='countvectorizer__them', weight=0.11240293217075215, std=None, value=1.0), FeatureWeight(feature='countvectorizer__to', weight=0.09242458490228754, std=None, value=3.0), FeatureWeight(feature='countvectorizer__children', weight=0.084783452059869369, std=None, value=1.0), FeatureWeight(feature='countvectorizer__relieve', weight=0.070809994007045157, std=None, value=1.0), FeatureWeight(feature='countvectorizer__as recall', weight=0.062726317704725823, std=None, value=1.0), FeatureWeight(feature='countvectorizer__and', weight=0.060329695557079885, std=None, value=2.0), FeatureWeight(feature='countvectorizer__isn', weight=0.055731230297561586, std=None, value=1.0), FeatureWeight(feature='countvectorizer__was', weight=0.055574937183372843, std=None, value=1.0), FeatureWeight(feature='countvectorizer__pass', weight=0.051479578351954741, std=None, value=1.0), FeatureWeight(feature='countvectorizer__anything', weight=0.045478923958651947, std=None, value=1.0), FeatureWeight(feature='countvectorizer__from', weight=0.042921373696465967, std=None, value=1.0), FeatureWeight(feature='countvectorizer__anything about', weight=0.03140271572138572, std=None, value=1.0), FeatureWeight(feature='countvectorizer__can', weight=0.025449656672696049, std=None, value=1.0), FeatureWeight(feature='countvectorizer__except relieve', weight=0.023126864736359332, std=None, value=1.0), FeatureWeight(feature='countvectorizer__relieve the', weight=0.020249690687074631, std=None, value=1.0), FeatureWeight(feature='countvectorizer__up', weight=0.017726204502091284, std=None, value=1.0), FeatureWeight(feature='countvectorizer__bout with', weight=0.012887647445101694, std=None, value=1.0), FeatureWeight(feature='countvectorizer__from my', weight=0.0072005020063512009, std=None, value=1.0)], neg=[FeatureWeight(feature='doclength__is_even', weight=-3.1275898258396633, std=None, value=1.0), FeatureWeight(feature='countvectorizer__that', weight=-0.22019037154050045, std=None, value=2.0), 
FeatureWeight(feature='countvectorizer__any', weight=-0.16921983637837387, std=None, value=1.0), FeatureWeight(feature='', weight=-0.082964516417960252, std=None, value=1.0), FeatureWeight(feature='countvectorizer__that can', weight=-0.067580541098713517, std=None, value=1.0), FeatureWeight(feature='countvectorizer__was in', weight=-0.029424751811322588, std=None, value=1.0), FeatureWeight(feature='countvectorizer__extracted', weight=-0.024172514413087477, std=None, value=1.0), FeatureWeight(feature='countvectorizer__either they', weight=-0.00058697237704910565, std=None, value=1.0)], pos_remaining=0, neg_remaining=0), proba=0.99650704965986392, score=5.6535094652910081, weighted_spans=WeightedSpans(docs_weighted_spans=[DocWeightedSpans(document=\"as i recall from my bout with kidney stones, there isn't any\\nmedication that can do anything about them except relieve the pain.\\n\\neither they pass, or they have to be broken up with sound, or they have\\nto be extracted surgically.\\n\\nwhen i was in, the x-ray tech happened to mention that she'd had kidney\\nstones and children, and the childbirth hurt less\", spans=[('from', [(12, 16)], 0.042921373696465967), ('with', [(25, 29)], 0.23348147145315884), ('isn', [(51, 54)], 0.055731230297561586), ('any', [(57, 60)], -0.16921983637837387), ('medication', [(61, 71)], 4.9077070275088301), ('that', [(72, 76)], -0.22019037154050045), ('can', [(77, 80)], 0.025449656672696049), ('anything', [(84, 92)], 0.045478923958651947), ('them', [(99, 103)], 0.11240293217075215), ('relieve', [(111, 118)], 0.070809994007045157), ('the', [(119, 122)], 0.13270154381365193), ('pass', [(142, 146)], 0.051479578351954741), ('or', [(148, 150)], 0.11985977768061165), ('to', [(161, 163)], 0.09242458490228754), ('up', [(174, 176)], 0.017726204502091284), ('with', [(177, 181)], 0.23348147145315884), ('or', [(189, 191)], 0.11985977768061165), ('to', [(202, 204)], 0.09242458490228754), ('extracted', [(208, 217)], -0.024172514413087477), ('was', [(238, 241)], 0.055574937183372843), ('the', [(246, 249)], 0.13270154381365193), ('to', [(270, 272)], 0.09242458490228754), ('that', [(281, 285)], -0.22019037154050045), ('and', [(310, 313)], 0.060329695557079885), ('children', [(314, 322)], 0.084783452059869369), ('and', [(324, 327)], 0.060329695557079885), ('the', [(328, 331)], 0.13270154381365193), ('as recall', [(0, 2), (5, 11)], 0.062726317704725823), ('from my', [(12, 16), (17, 19)], 0.0072005020063512009), ('bout with', [(20, 24), (25, 29)], 0.012887647445101694), ('stones there', [(37, 43), (45, 50)], 0.16161551402103283), ('any medication', [(57, 60), (61, 71)], 1.409764951748643), ('medication that', [(61, 71), (72, 76)], 1.4075922954882905), ('that can', [(72, 76), (77, 80)], -0.067580541098713517), ('anything about', [(84, 92), (93, 98)], 0.03140271572138572), ('except relieve', [(104, 110), (111, 118)], 0.023126864736359332), ('relieve the', [(111, 118), (119, 122)], 0.020249690687074631), ('either they', [(130, 136), (137, 141)], -0.00058697237704910565), ('have to', [(156, 160), (161, 163)], 0.12980991179263227), ('have to', [(197, 201), (202, 204)], 0.12980991179263227), ('was in', [(238, 241), (242, 244)], -0.029424751811322588)], preserve_density=False, vec_name='countvectorizer')], other=FeatureWeights(pos=[FeatureWeight(feature=, weight=8.8640638075486287, std=None, value=None)], neg=[FeatureWeight(feature='doclength__is_even', weight=-3.1275898258396633, std=None, value=1.0), FeatureWeight(feature='', weight=-0.082964516417960252, std=None, value=1.0)], 
pos_remaining=0, neg_remaining=0)))], feature_importances=None, decision_tree=None, highlight_spaces=None, transition_features=None)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.pipeline import make_union\n", "from sklearn.feature_extraction.text import CountVectorizer\n", "from sklearn.base import TransformerMixin\n", "\n", "class DocLength(TransformerMixin):\n", "    def fit(self, X, y=None):  # some boilerplate\n", "        return self\n", "    \n", "    def transform(self, X):\n", "        return [\n", "            # note that we need both a positive and a negative\n", "            # feature - otherwise, for a linear model, there won't\n", "            # be a feature to show in half of the cases\n", "            [len(doc) % 2, not len(doc) % 2] \n", "            for doc in X\n", "        ]\n", "    \n", "    def get_feature_names(self):\n", "        return ['is_odd', 'is_even']\n", "\n", "vec = make_union(DocLength(), CountVectorizer(ngram_range=(1,2)))\n", "te4 = TextExplainer(vec=vec).fit(doc[:-1], predict_proba_len)\n", "\n", "print(te4.metrics_)\n", "te4.explain_prediction(target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Much better! It was a toy example, but the idea stands - if you think something could be important, add it to the mix as a feature for `TextExplainer`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Let's make it fail, again\n", "\n", "Another possible issue is the dataset generation method. Not only should the feature extraction be powerful enough, the auto-generated texts should also be diverse enough. \n", "\n", "`TextExplainer` removes random words by default, so by default it can't e.g. provide a good explanation for a black-box classifier which works at the character level. Let's try to use `TextExplainer` to explain a classifier which uses char ngrams as features:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "0.87017310252996005" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.feature_extraction.text import HashingVectorizer\n", "from sklearn.linear_model import SGDClassifier\n", "\n", "vec_char = HashingVectorizer(analyzer='char_wb', ngram_range=(4,5))\n", "clf_char = SGDClassifier(loss='log')\n", "\n", "pipe_char = make_pipeline(vec_char, clf_char)\n", "pipe_char.fit(twenty_train.data, twenty_train.target)\n", "pipe_char.score(twenty_test.data, twenty_test.target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This pipeline is supported by eli5 directly, so in practice there is no need to use `TextExplainer` for it. We're using this pipeline as an example - it is possible to check the \"true\" explanation first, without using `TextExplainer`, and then compare the results with the `TextExplainer` results." ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.572, score -0.116)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +0.880\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.995\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "eli5.show_prediction(clf_char, doc, vec=vec_char,\n", " targets=['sci.med'], target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`TextExplainer` produces a different result:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 0.93454054240068041, 'mean_KL_divergence': 0.014021429806131684}\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.564, score 0.602)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +0.982\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.380\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain.\n", "\n", "either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te = TextExplainer(random_state=42).fit(doc, pipe_char.predict_proba)\n", "print(te.metrics_)\n", "te.show_prediction(targets=['sci.med'], target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Scores look OK but not great; the explanation kind of makes sense on a first sight, but we know that the classifier works in a different way. \n", "\n", "To explain such black-box classifiers we need to change both dataset generation method (change/remove individual characters, not only words) and feature extraction method (e.g. use char ngrams instead of words and word ngrams).\n", "\n", "`TextExplainer` has an option (`char_based=True`) to use char-based sampling and char-based classifier. If this makes a more powerful explanation engine why not always use it?" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 0.52104555439486744, 'mean_KL_divergence': 0.19554815684055157}\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.360, score 0.043)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +0.241\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.198\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te = TextExplainer(char_based=True, random_state=42)\n", "te.fit(doc, pipe_char.predict_proba)\n", "print(te.metrics_)\n", "te.show_prediction(targets=['sci.med'], target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hm, the result look worse. `TextExplainer` detected correctly that only the first part of word \"medication\" is important, but the result is noisy overall, and scores are bad. Let's try it with more samples:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 0.85575209964207921, 'mean_KL_divergence': 0.071035516578501337}\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.648, score 0.749)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +0.962\n", " \n", " Highlighted in text (sum)\n", "
\n", " -0.213\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te = TextExplainer(char_based=True, n_samples=50000, random_state=42)\n", "te.fit(doc, pipe_char.predict_proba)\n", "print(te.metrics_)\n", "te.show_prediction(targets=['sci.med'], target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is getting closer, but still not there yet. The problem is that it is much more resource intensive - you need a lot more samples to get non-noisy results. Here explaining a single example took more time than training the original pipeline.\n", "\n", "Generally speaking, to do an efficient explanation we should make some assumptions about black-box classifier, such as:\n", "\n", "1. it uses words as features and doesn't take word position in account;\n", "2. it uses words as features and takes word positions in account;\n", "3. it uses words ngrams as features;\n", "4. it uses char ngrams as features, positions don't matter (i.e. an ngram means the same everywhere);\n", "5. it uses arbitrary attention over the text characters, i.e. every part of text could be potentionally important for a classifier on its own;\n", "6. it is important to have a particular token at a particular position, e.g. \"third token is X\", and if we delete 2nd token then prediction changes not because 2nd token changed, but because 3rd token is shifted.\n", "\n", "Depending on assumptions we should choose both dataset generation method and a white-box classifier. There is a tradeoff between generality and speed. \n", "\n", "Simple bag-of-words assumptions allow for fast sample generation, and just a few hundreds of samples could be required to get an OK quality if the assumption is correct. But such generation methods / models will fail to explain a more complex classifier properly (they could still provide an explanation which is useful in practice though). \n", "\n", "On the other hand, allowing for each character to be important is a more powerful method, but it can require a lot of samples (maybe hundreds thousands) and a lot of CPU time to get non-noisy results.\n", "\n", "What's bad about this kind of failure (wrong assumption about the black-box pipeline) is that it could be impossible to detect the failure by looking at the scores. Scores could be high because generated dataset is not diverse enough, not because our approximation is good.\n", "\n", "The takeaway is that it is important to understand the \"lenses\" you're looking through when using LIME to explain a prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing TextExplainer: sampling\n", "\n", "`TextExplainer` uses `MaskingTextSampler` or `MaskingTextSamplers` instances to generate texts to train on. `MaskingTextSampler` is the main text generation class; `MaskingTextSamplers` provides a way to combine multiple samplers in a single object with the same interface.\n", "\n", "A custom sampler instance can be passed to `TextExplainer` if we want to experiment with sampling. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing TextExplainer: sampling\n", "\n", "`TextExplainer` uses `MaskingTextSampler` or `MaskingTextSamplers` instances to generate texts to train on. `MaskingTextSampler` is the main text generation class; `MaskingTextSamplers` provides a way to combine multiple samplers in a single object with the same interface.\n", "\n", "A custom sampler instance can be passed to `TextExplainer` if we want to experiment with sampling. For example, let's try a sampler which replaces no more than 3 characters in the text (the default is to replace a random number of characters):" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "As I recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain\n", "\n", "Either thy pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically.\n", "\n", "When I was in, the X-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n" ] } ], "source": [ "from eli5.lime.samplers import MaskingTextSampler\n", "sampler = MaskingTextSampler(\n", "    # Regex to split text into tokens.\n", "    # \".\" means any single character is a token, i.e.\n", "    # we work on chars.\n", "    token_pattern='.',\n", "\n", "    # replace no more than 3 tokens\n", "    max_replace=3,\n", "\n", "    # by default all occurrences of a given token are replaced\n", "    # at once; bow=False replaces only a token at a given position\n", "    bow=False,\n", ")\n", "samples, similarity = sampler.sample_near(doc)\n", "print(samples[0])" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 1.0, 'mean_KL_divergence': 1.0004596970275623}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " y=sci.med\n", " \n", "\n", "\n", " \n", " (probability 0.970, score 4.522)\n", "\n", "top features\n", "

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", " \n", "\n", " \n", "
\n", " Contribution?\n", " Feature
\n", " +4.512\n", " \n", " Highlighted in text (sum)\n", "
\n", " +0.010\n", " \n", " <BIAS>\n", "
\n", "\n", " \n", "\n", "\n", "\n", "

\n", " as i recall from my bout with kidney stones, there isn't any\n", "medication that can do anything about them except relieve the pain. either they pass, or they have to be broken up with sound, or they have\n", "to be extracted surgically. when i was in, the x-ray tech happened to mention that she'd had kidney\n", "stones and children, and the childbirth hurt less.\n", "

\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "te = TextExplainer(char_based=True, sampler=sampler, random_state=42)\n", "te.fit(doc, pipe_char.predict_proba)\n", "print(te.metrics_)\n", "te.show_prediction(targets=['sci.med'], target_names=twenty_train.target_names)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that accuracy score is perfect, but KL divergence is bad. It means this sampler was not very useful: most generated texts were \"easy\" in sense that most (or all?) of them should be still classified as `sci.med`, so it was easy to get a good accuracy. But because generated texts were not diverse enough classifier haven't learned anything useful; it's having a hard time predicting the probability output of the black-box pipeline on a held-out dataset.\n", "\n", "By default `TextExplainer` uses a mix of several sampling strategies which seems to work OK for token-based explanations. But a good sampling strategy which works for many real-world tasks could be a research topic on itself. If you've got some experience with it we'd love to hear from you - please share your findings in eli5 issue tracker ( https://github.com/TeamHG-Memex/eli5/issues )!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Customizing TextExplainer: classifier\n", "\n", "In one of the previous examples we already changed the vectorizer TextExplainer uses (to take additional features in account). It is also possible to change the white-box classifier - for example, use a small decision tree:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'score': 0.9838155527960798, 'mean_KL_divergence': 0.03812615869329402}\n" ] }, { "data": { "text/html": [ "\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
WeightFeature
\n", " 0.5449\n", " \n", " \n", " kidney\n", "
\n", " 0.4551\n", " \n", " \n", " pain\n", "
\n", " \n", "\n", " \n", "\n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", "\n", " \n", " \n", "
\n", "
\n",
       "\n",
       "Tree\n",
       "\n",
       "\n",
       "0\n",
       "\n",
       "kidney <= 0.5\n",
       "gini = 0.1595\n",
       "samples = 100.0%\n",
       "value = [0.01, 0.03, 0.92, 0.04]\n",
       "\n",
       "\n",
       "1\n",
       "\n",
       "pain <= 0.5\n",
       "gini = 0.3893\n",
       "samples = 38.9%\n",
       "value = [0.03, 0.1, 0.77, 0.11]\n",
       "\n",
       "\n",
       "0->1\n",
       "\n",
       "\n",
       "True\n",
       "\n",
       "\n",
       "4\n",
       "\n",
       "pain <= 0.5\n",
       "gini = 0.0474\n",
       "samples = 61.1%\n",
       "value = [0.0, 0.01, 0.98, 0.01]\n",
       "\n",
       "\n",
       "0->4\n",
       "\n",
       "\n",
       "False\n",
       "\n",
       "\n",
       "2\n",
       "\n",
       "gini = 0.5253\n",
       "samples = 28.4%\n",
       "value = [0.04, 0.14, 0.65, 0.16]\n",
       "\n",
       "\n",
       "1->2\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "3\n",
       "\n",
       "gini = 0.0445\n",
       "samples = 10.6%\n",
       "value = [0.0, 0.0, 0.98, 0.02]\n",
       "\n",
       "\n",
       "1->3\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "5\n",
       "\n",
       "gini = 0.1194\n",
       "samples = 22.8%\n",
       "value = [0.01, 0.02, 0.94, 0.04]\n",
       "\n",
       "\n",
       "4->5\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "6\n",
       "\n",
       "gini = 0.0121\n",
       "samples = 38.2%\n",
       "value = [0.0, 0.0, 0.99, 0.0]\n",
       "\n",
       "\n",
       "4->6\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "\n",
       "
\n", " \n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.tree import DecisionTreeClassifier\n", "\n", "te5 = TextExplainer(clf=DecisionTreeClassifier(max_depth=2), random_state=0)\n", "te5.fit(doc, pipe.predict_proba)\n", "print(te5.metrics_)\n", "te5.show_weights()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How to read it: \"kidney <= 0.5\" means \"word 'kidney' is not in the document\" (we're explaining the orginal LDA+SVM pipeline again).\n", "\n", "So according to this tree if \"kidney\" is not in the document and \"pain\" is not in the document then the probability of a document belonging to `sci.med` drops to `0.65`. If at least one of these words remain `sci.med` probability stays `0.9+`." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "both words removed::\n", "0.014 alt.atheism\n", "0.024 comp.graphics\n", "0.891 sci.med\n", "0.072 soc.religion.christian\n", "\n", "only 'pain' removed:\n", "0.002 alt.atheism\n", "0.004 comp.graphics\n", "0.978 sci.med\n", "0.015 soc.religion.christian\n" ] } ], "source": [ "print(\"both words removed::\")\n", "print_prediction(re.sub(r\"(kidney|pain)\", \"\", doc, flags=re.I))\n", "print(\"\\nonly 'pain' removed:\")\n", "print_prediction(re.sub(r\"pain\", \"\", doc, flags=re.I))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, after removing both words probability of `sci.med` decreased, though not as much as our simple decision tree predicted (to 0.9 instead of 0.64). Removing `pain` provided exactly the same effect as predicted - probability of `sci.med` became `0.98`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" } }, "nbformat": 4, "nbformat_minor": 2 }