{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Tree SHAP on a toy tree\n", "\n", "TODO:\n", "\n", "* implement Algorithm 1\n", "* apply it to the example from [this article](https://medium.com/analytics-vidhya/shap-part-3-tree-shap-3af9bcd7cd9b)\n", "* use shap to reproduce the manually calculated outputs" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import itertools\n", "from collections import OrderedDict\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import shap\n", "import matplotlib.pyplot as plt\n", "from sklearn.tree import DecisionTreeRegressor, plot_tree\n", "\n", "import test_tree_shap\n", "import tree_shap" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# for auto-reloading external modules\n", "# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython\n", "%load_ext autoreload\n", "%autoreload 2" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "data, tree = test_tree_shap.toy_tree()" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# sanity checks for naive_tree_shap on the toy tree (isclose: the results may be floats)\n", "assert np.isclose(tree_shap.naive_tree_shap(tree, current_node=0, features={0: 150}), 20)\n", "assert np.isclose(tree_shap.naive_tree_shap(tree, current_node=0, features={1: 75}), 27)\n", "\n", "# given feature 0, having feature 1 doesn't make a difference\n", "assert np.isclose(\n", "    tree_shap.naive_tree_shap(tree, current_node=0, features={0: 150, 1: 75}), 20\n", ")" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "features_tuple = ((0, 150), (1, 75), (2, 200))" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "23.0" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "basis = data['y'].mean()\n", "basis" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "For every ordering of the features, add the features one at a time and record how much each one changes the value returned by `tree_shap.naive_tree_shap`.\n", "\n", "**TODO: the variable names in the next cell could still be made more readable.**" ] },
{ "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "all_phis = []\n", "for permutation in itertools.permutations(features_tuple):\n", "    # contributions for this ordering, keyed by 'basis' and by feature index\n", "    phis = {'basis': basis}\n", "    for n_known, (feature, _value) in enumerate(permutation, start=1):\n", "        value_with_feature = tree_shap.naive_tree_shap(\n", "            tree, current_node=0, features=dict(permutation[:n_known])\n", "        )\n", "        # what this feature adds on top of everything attributed so far\n", "        phis[feature] = value_with_feature - sum(phis.values())\n", "    all_phis.append(phis)" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/html": [ "[{'basis': 23.0, 0: -3.0, 1: 0.0, 2: 0.0}, {'basis': 23.0, 0: -3.0, 2: 0.0, 1: 0.0}, {'basis': 23.0, 1: 4.0, 0: -7.0, 2: 0.0}, {'basis': 23.0, 1: 4.0, 2: 0.0, 0: -7.0}, {'basis': 23.0, 2: 0.0, 0: -3.0, 1: 0.0}, {'basis': 23.0, 2: 0.0, 1: 4.0, 0: -7.0}]" ], "text/plain": [ "[{'basis': 23.0, 0: -3.0, 1: 0.0, 2: 0.0},\n", " {'basis': 23.0, 0: -3.0, 2: 0.0, 1: 0.0},\n", " {'basis': 23.0, 1: 4.0, 0: -7.0, 2: 0.0},\n", " {'basis': 23.0, 1: 4.0, 2: 0.0, 0: -7.0},\n", " {'basis': 23.0, 2: 0.0, 0: -3.0, 1: 0.0},\n", " {'basis': 23.0, 2: 0.0, 1: 4.0, 0: -7.0}]" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "all_phis" ] },
{ "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "basis 23.0\n", "0 -5.0\n", "1 2.0\n", "2 0.0\n", "dtype: float64" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(all_phis).mean(axis=0)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# a mapping cannot be sliced directly (TypeError: unhashable type: 'slice'),\n", "# so slice its items instead\n", "list(OrderedDict(features_tuple).items())[:2]" ] },
{ "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/html": [ "[(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]" ], "text/plain": [ "[(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(itertools.permutations(OrderedDict(features_tuple)))" ] },
{ "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((2, 200), (1, 75), (0, 150))" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "permutation" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "[23.0, 0.0, 4.0, -7.0]" ], "text/plain": [ "[23.0, 0.0, 4.0, -7.0]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "phis" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs":
[ { "data": { "text/plain": [ "((2, 200), (1, 75), (0, 150))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "permutation" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# one row per permutation, one column per feature (0, 1, 2); all_phis holds dicts,\n", "# so select the columns by feature index rather than by insertion position\n", "pd.DataFrame(all_phis)[[0, 1, 2]].to_numpy()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# average contribution of each feature over all permutations\n", "pd.DataFrame(all_phis)[[0, 1, 2]].to_numpy().mean(axis=0)" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0, 1, 2)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# first permutation of the feature indices\n", "next(itertools.permutations(dict(features_tuple)))" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0, 1, -2, -2, 0, -2, -2])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.feature" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "array([[[15.]],\n", "\n", " [[40.]],\n", "\n", " [[50.]],\n", "\n", " [[30.]],\n", "\n", " [[10.]],\n", "\n", " [[20.]],\n", "\n", " [[10.]]])" ] }, "execution_count": 30, "metadata":
{}, "output_type": "execute_result" } ], "source": [ "tree.tree_.value" ] },
{ "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "matrix([[1, 0, 0, 0, 1, 1, 0]])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.decision_path([[150, 75, 200]]).todense()" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 1, 2, -1, -1, 5, -1, -1])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.children_left" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 4, 3, -1, -1, 6, -1, -1])" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.children_right" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.n_leaves" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([100., 300., -2., -2., 200., -2., -2.])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.threshold" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = plot_tree(tree, filled=True, proportion=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Explain the example [x=150, y=75, z=200]" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "import sklearn" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0, 1, -2, -2, 0, -2, -2])" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.feature" ] },
{ "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "from collections import OrderedDict" ] },
{ "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "features = [[1, 75], [2, 200], [0, 150]]" ] },
{ "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/html": [ "[[1, 75], [2, 200]]" ], "text/plain": [ "[[1, 75], [2, 200]]" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "features[:2]" ] },
{ "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([10, 4, 2, 2, 6, 1, 5])" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.tree_.n_node_samples" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def decision_path(i):\n", "    \"\"\"\n", "    i: index of features\n", "    \"\"\"\n", "    # TODO: unfinished draft; the original cell stopped at the incomplete line\n", "    #   if i == tree.tree_.feature[0]\n", "    raise NotImplementedError(\"decision_path is not implemented yet\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# contribution of the empty feature set: the mean target computed above as `basis`\n", "phi_null = basis" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Ordering x > y > z (features are added in this order)" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "phi_x = 20 - phi_null" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "phi_y = 20 - phi_x - phi_null" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "phi_z = 20 - phi_y - phi_x - phi_null" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "phi_z" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Ordering y > z > x (features are added in this order)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "phi_y = (4 / 10) * 50 + (6 / 10) * (1 / 6 * 20 + 5 / 6 * 10) - phi_null" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4.0" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "phi_y" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "phi_z = 0" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "phi_x = 20 - phi_z - phi_y - phi_null" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-7.0" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "phi_x" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([20.])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tree.predict([[150, 75, 200]])" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "X_test = pd.DataFrame({'x': [150], 'y': [75], 'z': [200]})" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting feature_perturbation = \"tree_path_dependent\" because no background data was given.\n" ] }, { "data": { "text/plain": [ "array([[-5., 2., 0.]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "explainer = shap.TreeExplainer(tree)\n", "shap_values = explainer.shap_values(X_test)\n", "shap_values" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# the shap output should match the per-feature average over all permutations\n", "np.testing.assert_allclose(shap_values[0], pd.DataFrame(all_phis)[[0, 1, 2]].mean(axis=0))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }
], "metadata": { "kernelspec": { "display_name":
"Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 4 }