{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Tutorial to invoke SHAP explainers via aix360\n", "\n", "There are two ways to use [SHAP](https://github.com/slundberg/shap) explainers after installing aix360:\n", "- [Approach 1 (aix360 style)](#approach1): SHAP explainers can be invoked in the same way as the other explainer algorithms in aix360, via the wrapper classes that aix360 implements.\n", "- [Approach 2 (original style)](#approach2): Since SHAP comes pre-installed with aix360, its explainers can also be invoked directly.\n", "\n", "This notebook demonstrates both approaches. It is based on the following example from the original SHAP tutorial:\n", "https://slundberg.github.io/shap/notebooks/Iris%20classification%20with%20scikit-learn.html\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a id='approach1'></a>\n", "## Approach 1 (aix360 style)\n", "\n", "- Note the import statement for `KernelExplainer`: it comes from `aix360.algorithms.shap` rather than from `shap`.\n", "- Explanations are computed with `explain_instance()`; the wrapped SHAP explainer is exposed as the `explainer` attribute, which is used below to read `expected_value`." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "import sklearn.datasets\n", "import sklearn.ensemble\n", "import sklearn.neighbors\n", "import numpy as np\n", "import time\n", "np.random.seed(1)\n", "\n", "# Importing shap KernelExplainer (aix360 style)\n", "from aix360.algorithms.shap import KernelExplainer\n", "\n", "# The following import is required for access to the SHAP plotting functions and datasets\n", "import shap" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Suppress Jupyter warnings (optional) for cleaner output\n", "import warnings\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-nearest neighbors" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 96.66666666666667%\n" ] } ], "source": [ "X_train,X_test,Y_train,Y_test = train_test_split(*shap.datasets.iris(), test_size=0.2, random_state=0)\n", "\n", "# rather than use the whole training set to estimate expected values, we could summarize with\n", "# a set of weighted kmeans, each weighted by the number of points they represent. But this dataset\n", "# is so small we don't worry about it\n", "#X_train_summary = shap.kmeans(X_train, 50)\n", "\n", "def print_accuracy(f):\n", " print(\"Accuracy = {0}%\".format(100*np.sum(f(X_test) == Y_test)/len(Y_test)))\n", " time.sleep(0.5) # to let the print get out before any progress bars\n", "\n", "shap.initjs()\n", "\n", "knn = sklearn.neighbors.KNeighborsClassifier()\n", "knn.fit(X_train, Y_train)\n", "\n", "print_accuracy(knn.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain a single prediction from the test set" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "shapexplainer = KernelExplainer(knn.predict_proba, X_train)\n", "print(type(shapexplainer))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# aix360 style for explaining input instances\n", "shap_values = shapexplainer.explain_instance(X_test.iloc[0,:])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
 \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "shap.force_plot(shapexplainer.explainer.expected_value[0], shap_values[0], X_test.iloc[0,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain all the predictions in the test set" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "53079d7d468a4a568d975ae624d1d419", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", "
\n", "
 \n", "
 \n", " " ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# aix360 style for explaining input instances\n", "shap_values = shapexplainer.explain_instance(X_test)\n", "shap.force_plot(shapexplainer.explainer.expected_value[0], shap_values[0], X_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a id='approach2'></a>\n", "## Approach 2 (original style)\n", "\n", "- Note the last import statement: `KernelExplainer` is imported directly from `shap`.\n", "- Explanations are computed with SHAP's own `shap_values()` method, and `expected_value` is read directly from the explainer." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "import sklearn\n", "from sklearn.model_selection import train_test_split\n", "import sklearn.datasets\n", "import sklearn.ensemble\n", "import sklearn.neighbors\n", "import numpy as np\n", "import time\n", "np.random.seed(1)\n", "\n", "# Importing shap KernelExplainer (original style)\n", "import shap\n", "from shap import KernelExplainer" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Suppress Jupyter warnings (optional) for cleaner output\n", "import warnings\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### K-nearest neighbors" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy = 96.66666666666667%\n" ] } ], "source": [ "X_train,X_test,Y_train,Y_test = train_test_split(*shap.datasets.iris(), test_size=0.2, random_state=0)\n", "\n", "# rather than use the whole training set to estimate expected values, we could summarize with\n", "# a set of weighted kmeans, each weighted by the number of points they represent. But this dataset\n", "# is so small we don't worry about it\n", "#X_train_summary = shap.kmeans(X_train, 50)\n", "\n", "def print_accuracy(f):\n", " print(\"Accuracy = {0}%\".format(100*np.sum(f(X_test) == Y_test)/len(Y_test)))\n", " time.sleep(0.5) # to let the print get out before any progress bars\n", "\n", "shap.initjs()\n", "\n", "knn = sklearn.neighbors.KNeighborsClassifier()\n", "knn.fit(X_train, Y_train)\n", "\n", "print_accuracy(knn.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain a single prediction from the test set" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using 120 background data samples could cause slower run times. Consider using shap.kmeans(data, K) to summarize the background as K weighted samples.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "explainer = KernelExplainer(knn.predict_proba, X_train)\n", "print(type(explainer))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
 \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Shap original style for explaining input instances\n", "shap_values = explainer.shap_values(X_test.iloc[0,:])\n", "shap.force_plot(explainer.expected_value[0], shap_values[0], X_test.iloc[0,:])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Explain all the predictions in the test set" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "236c4a11af794e15941943bac6ec73bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=30), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/html": [ "\n", "
\n", "
 \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Shap original style for explaining input instances\n", "shap_values = explainer.shap_values(X_test)\n", "shap.force_plot(explainer.expected_value[0], shap_values[0], X_test)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }