{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial to invoke LIME explainers via aix360\n", "\n", "There are two ways to use [LIME](https://github.com/marcotcr/lime) explainers after installing aix360:\n", "- [Approach 1 (aix360 style)](#approach1): LIME explainers can be invoked in the same manner as the other explainer algorithms in aix360, via the provided wrapper classes.\n", "- [Approach 2 (original style)](#approach2): Since LIME is installed along with aix360, its explainers can also be invoked directly.\n", "\n", "This notebook demonstrates both approaches. It is based on the following example from the original LIME tutorial: https://marcotcr.github.io/lime/tutorials/Lime%20-%20multiclass.html" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Approach 1 (aix360 style)\n", "\n", "- Note the import statement related to LimeTextExplainer" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import sklearn\n", "import numpy as np\n", "import sklearn.feature_extraction.text\n", "import sklearn.ensemble\n", "import sklearn.metrics\n", "\n", "# Importing LimeTextExplainer (aix360 style)\n", "from aix360.algorithms.lime import LimeTextExplainer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Suppress Jupyter warnings if required for cleaner output\n", "import warnings\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fetching data, training a classifier" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "newsgroups_train = fetch_20newsgroups(subset='train')\n", "newsgroups_test = fetch_20newsgroups(subset='test')\n", "# making class names shorter\n", "class_names = [x.split('.')[-1] if 'misc' not in x else '.'.join(x.split('.')[-2:]) for x in 
newsgroups_train.target_names]\n", "class_names[3] = 'pc.hardware'\n", "class_names[4] = 'mac.hardware'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "atheism,graphics,ms-windows.misc,pc.hardware,mac.hardware,x,misc.forsale,autos,motorcycles,baseball,hockey,crypt,electronics,med,space,christian,guns,mideast,politics.misc,religion.misc\n" ] } ], "source": [ "print(','.join(class_names))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=False)\n", "train_vectors = vectorizer.fit_transform(newsgroups_train.data)\n", "test_vectors = vectorizer.transform(newsgroups_test.data)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MultinomialNB(alpha=0.01)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.naive_bayes import MultinomialNB\n", "nb = MultinomialNB(alpha=.01)\n", "nb.fit(train_vectors, newsgroups_train.target)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8350184193998174" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = nb.predict(test_vectors)\n", "sklearn.metrics.f1_score(newsgroups_test.target, pred, average='weighted')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining predictions using lime" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline\n", "c = make_pipeline(vectorizer, nb)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.001 0.01 0.003 0.047 0.006 0.002 0.003 0.521 0.022 0.008 0.025 0.\n", " 0.331 0.003 0.006 0. 0.003 0. 
0.001 0.009]]\n" ] } ], "source": [ "print(c.predict_proba([newsgroups_test.data[0]]).round(3))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "limeexplainer = LimeTextExplainer(class_names=class_names)\n", "print(type(limeexplainer))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Document id: 1340\n", "Predicted class = atheism\n", "True class: atheism\n" ] } ], "source": [ "idx = 1340\n", "# aix360 style for explaining input instances\n", "exp = limeexplainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, labels=[0, 17])\n", "print('Document id: %d' % idx)\n", "print('Predicted class =', class_names[nb.predict(test_vectors[idx]).reshape(1,-1)[0,0]])\n", "print('True class: %s' % class_names[newsgroups_test.target[idx]])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explanation for class atheism\n", "('Caused', 0.26238077066590293)\n", "('Rice', 0.13518306047536863)\n", "('Genocide', 0.13057135539747292)\n", "('scri', -0.08935717895685585)\n", "('owlnet', -0.0866086512211189)\n", "('Semitic', -0.08450085710232705)\n", "\n", "Explanation for class mideast\n", "('fsu', -0.05515369453392734)\n", "('Theism', -0.054747456372004545)\n", "('Luther', -0.0470171558437435)\n", "('Caused', -0.03726710714147818)\n", "('jews', 0.03390695559844592)\n", "('PBS', 0.032592761622772665)\n" ] } ], "source": [ "print ('Explanation for class %s' % class_names[0])\n", "print ('\\n'.join(map(str, exp.as_list(label=0))))\n", "print ()\n", "print ('Explanation for class %s' % class_names[17])\n", "print ('\\n'.join(map(str, exp.as_list(label=17))))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Approach 2 (original style)\n", "\n", "- Note the import 
statement related to LimeTextExplainer" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import sklearn\n", "import numpy as np\n", "import sklearn.feature_extraction.text\n", "import sklearn.ensemble\n", "import sklearn.metrics\n", "\n", "# Importing LimeTextExplainer (original style)\n", "from lime.lime_text import LimeTextExplainer" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# Suppress Jupyter warnings if required for cleaner output\n", "import warnings\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fetching data, training a classifier" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets import fetch_20newsgroups\n", "newsgroups_train = fetch_20newsgroups(subset='train')\n", "newsgroups_test = fetch_20newsgroups(subset='test')\n", "# making class names shorter\n", "class_names = [x.split('.')[-1] if 'misc' not in x else '.'.join(x.split('.')[-2:]) for x in newsgroups_train.target_names]\n", "class_names[3] = 'pc.hardware'\n", "class_names[4] = 'mac.hardware'" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "atheism,graphics,ms-windows.misc,pc.hardware,mac.hardware,x,misc.forsale,autos,motorcycles,baseball,hockey,crypt,electronics,med,space,christian,guns,mideast,politics.misc,religion.misc\n" ] } ], "source": [ "print(','.join(class_names))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "vectorizer = sklearn.feature_extraction.text.TfidfVectorizer(lowercase=False)\n", "train_vectors = vectorizer.fit_transform(newsgroups_train.data)\n", "test_vectors = vectorizer.transform(newsgroups_test.data)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "MultinomialNB(alpha=0.01)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.naive_bayes import MultinomialNB\n", "nb = MultinomialNB(alpha=.01)\n", "nb.fit(train_vectors, newsgroups_train.target)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8350184193998174" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = nb.predict(test_vectors)\n", "sklearn.metrics.f1_score(newsgroups_test.target, pred, average='weighted')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explaining predictions using lime" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "from sklearn.pipeline import make_pipeline\n", "c = make_pipeline(vectorizer, nb)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[0.001 0.01 0.003 0.047 0.006 0.002 0.003 0.521 0.022 0.008 0.025 0.\n", " 0.331 0.003 0.006 0. 0.003 0. 
0.001 0.009]]\n" ] } ], "source": [ "print(c.predict_proba([newsgroups_test.data[0]]).round(3))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "explainer = LimeTextExplainer(class_names=class_names)\n", "print(type(explainer))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Document id: 1340\n", "Predicted class = atheism\n", "True class: atheism\n" ] } ], "source": [ "idx = 1340\n", "# LIME original style for explaining input instances\n", "exp = explainer.explain_instance(newsgroups_test.data[idx], c.predict_proba, num_features=6, labels=[0, 17])\n", "print('Document id: %d' % idx)\n", "print('Predicted class =', class_names[nb.predict(test_vectors[idx]).reshape(1,-1)[0,0]])\n", "print('True class: %s' % class_names[newsgroups_test.target[idx]])" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Explanation for class atheism\n", "('Caused', 0.25632751113248026)\n", "('Rice', 0.14626237282926857)\n", "('Genocide', 0.13133950204310213)\n", "('owlnet', -0.0905864608741577)\n", "('scri', -0.08526752642948768)\n", "('Semitic', -0.08287079245253995)\n", "\n", "Explanation for class mideast\n", "('fsu', -0.05714575197470851)\n", "('Theism', -0.05248892887178565)\n", "('Luther', -0.05034586274638904)\n", "('Rice', -0.03470048019899408)\n", "('jews', 0.03372175158157365)\n", "('PBS', 0.030816731955572742)\n" ] } ], "source": [ "print ('Explanation for class %s' % class_names[0])\n", "print ('\\n'.join(map(str, exp.as_list(label=0))))\n", "print ()\n", "print ('Explanation for class %s' % class_names[17])\n", "print ('\\n'.join(map(str, exp.as_list(label=17))))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" } }, "nbformat": 4, "nbformat_minor": 4 }