{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Comparing the result with SHAP\n",
    "\n",
    "SHAP is a method developed by Scott Lundberg et al. that also tries to estimate Shapley values.\n",
    "For a linear model the Shapley values estimated by SHAP and by integrated gradients should agree; this notebook checks that on one sample of the Iris data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Flatten, Dense, Dropout\n",
    "from keras.layers.core import Activation\n",
    "\n",
    "import sys\n",
    "# Make the parent directory importable; IntegratedGradients is expected to live there.\n",
    "sys.path.append(\"../\")\n",
    "# Explicit import instead of `import *`, so it is clear where the name comes from.\n",
    "from IntegratedGradients import integrated_gradients\n",
    "\n",
    "from shap import KernelExplainer, DenseData, visualize, initjs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Parse every comma-separated field of a row as a float, except the last one\n",
    "# (presumably the class name), which is dropped.\n",
    "# Blank lines are skipped explicitly, so a trailing empty line (or its absence) does not matter,\n",
    "# and the `with` block closes the file as soon as it has been read.\n",
    "with open(\"iris.data\") as iris_file:\n",
    "    X = np.array([[float(field) for field in line.rstrip().split(\",\")[:-1]]\n",
    "                  for line in iris_file if line.strip()])\n",
    "\n",
    "# Binary target: rows 0-99 are labelled 0, rows 100-149 are labelled 1.\n",
    "Y = np.array([0] * 100 + [1] * 50)\n",
    "\n",
    "# Fail early if the file does not hold exactly one 4-feature row per label.\n",
    "assert X.shape == (len(Y), 4), \"unexpected iris.data shape: {}\".format(X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = Sequential([\n",
    "    Dense(1, input_dim=4),\n",
    "    # No activation: the comparison in this notebook is made on a linear model.\n",
    "    #Activation('sigmoid'),\n",
    "])\n",
    "model.compile(optimizer='sgd', loss='mean_squared_error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "history = model.fit(X, Y,\n",
    "                    epochs=300, batch_size=10,\n",
    "                    validation_split=0.1, verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
"source": [ "predictions = model.predict(X)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAABPCAYAAABs67OLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAAPYQAAD2EBqD+naQAAE1VJREFUeJzt3X9UVWW+x/HPJjBQhICQFBnSSg+GIJpFSQMXBTTHUfMW\nOUWNtzLGzLFJTDQTzTWjY2QW2hr8Oa30Ni2v1Dh3rKycESUzx4KU1EosUUEEBSERhHP/aHWux+do\naluy5v1ay7XYn/Ps/TzP5nDg6372OZbT6XQKAAAAAGzk9UMPAAAAAMBPD4UGAAAAANtRaAAAAACw\nHYUGAAAAANtRaAAAAACwHYUGAAAAANtRaAAAAACwHYUGAAAAANtRaAAAAACwHYUGAFxmsrOz5XA4\nFBUVJYfDIYfDoejoaKWlpSkvL09NTU2XpN/CwkI5HA59+OGHkqS8vDxFRUV9r/4yMjJ0zz332DVE\nAMCPiPcPPQAAgCkkJERr166V0+mUJNXV1amoqEjPPvusysrKlJube0n6tSzL9fWDDz6o0aNHq127\ndue9f3JysubOnav+/ftLkhYuXGj7GAEAPw4UGgBwGbIsS8HBwa7tkJAQdevWTTU1NVq0aJEmT56s\nsLCwSzoGPz8/+fn5nXf7yspKHTx40C0LCAiwe1gAgB8Jlk4BwI+Iw+GQJB06dEgZGRl69NFHtWDB\nAvXt21crV66UJNXX1+uZZ55RWlqaYmJilJKSosWLF7sdp76+XpMmTVK/fv3Uv39/TZo0SXV1da4r\nKJL04osvyuFwuC2dKigo0LBhwxQbG6uUlBS98MILamlp0datW5WYmCjLspSRkaGBAwdKMpdONTU1\nKTc3V8nJyYqOjtaAAQOUnZ2tmpoaV5vs7GyNGDFCW7du1Z133qk+ffooNTVVr7/+uttx5syZo+Tk\nZMXExCghIUFTpkzRsWPHbDzbAIDvgysaAPAjUlZWJknq3LmzJGnPnj3y9fVVQUGBQkJCJEnjx4/X\n7t27lZOTo169emnLli2aPXu2mpubNW7cOEnSrFmztGHDBs2ePVs33nijNm3apOeee85t6ZRlWW7b\na9eu1VNPPaUpU6YoKSlJe/bsUVZWlpqamjRx4kTl5uZq0qRJysvLU79+/TyO/6mnntKGDRs0ffp0\nxcXFad++fZoxY4bGjh2r1atXu9rV1NRo4cKFmjFjhq666ir94Q9/0PTp03XrrbcqLCxMixYt0rp1\n6zRv3jxFRkaqvLxcM2fO1OTJk5Wfn2/vSQcAXBQKDQD4ETh16pS2bNmi5cuXKzU11bVsqqKiQgUF\nBfL395cklZSUaMuWLZozZ47S0tIkSREREfrss8+0bNkyPfzww2ppadG6des0ZswYDRkyRJL0q1/9\nSnv37nVdFfEkPz9fycnJysjIcB33ySef1L59++Tt7e1aJhUYGKigoCBj/8rKSq1du1ZZWVn65S9/\n6TrGlClT9Nvf/lbbt29X3759JUlVVVVavny5rrvuOknSQw89pH/+858qLS1VWFiYSktL1bNnT918\n882SpLCwMC1evFh1dXXf70QDAGxDoQEAl6Hq6mrFxcW5tpuamuTj46Phw4drypQprjwiIsJVZEhS\ncXGxLMvSbbfd5na8+Ph4vfzyy/ryyy916tQpNTc3q1evXm5tTl9+daaTJ0/qs88+07Bhw9zy9PT0\n857Tz
p07Jcm42hEXFyen06nS0lJXoeHn5+cqMiQpKChITqdTtbW1kqSBAwcqJydHEyZM0ODBgxUf\nH6+wsLBLft8KAOD8UWgAwGUoKChIf/nLX1zb3t7eCg0Nlbe3+8v2mTdb19fXy+l0avDgwW73Wzid\nTlmWpaqqKvn4+EiS2rdv77Zvhw4dzjqeb68UnKvNd6mvr5ckt8Lo9O2GhgZXdubYTl/CJX1T4Fxz\nzTVatWqVpk2bppMnTyo+Pl7Tpk1zK1AAAD8cCg0AuAx5eXkpIiLigvcLCAiQZVl6+eWXFRgYaDwe\nGhqqvXv3SpIaGxvdHjvXsqOgoCB5eXl9r6VJ3xZF3xYc3zp+/LgkqWPHjhd0vMTERCUmJqq5uVlF\nRUXKzc3V2LFj9e677170GAEA9uFdpwDgJ6RPnz5yOp06fPiwIiIiXP86duwoX19f+fr6KjIyUt7e\n3iouLnbb99sP6vPE29tb3bp1M9qsWrVKjzzyiGvb6XS6XUk5XXR0tCzLMo6xbds2WZalmJiY85qj\n0+nU+vXrVVFRIUny8fFRYmKiJkyYoIMHD3KfBgBcJriiAQA/ITfeeKMSEhL0zDPPqLW1VVFRUTpw\n4ID++Mc/SpJWr16tDh06KDk5Wa+99pr69Omjnj17qrCwUO+///45j/3www8rOztbf/rTnzRs2DDt\n2rVLCxYs0KhRoyTJdQVl06ZN6tixo6Kiotz2v/rqqzVy5Ejl5+erc+fOiomJ0Z49ezRnzhzFx8cr\nOjr6vOZoWZaWLFkiy7I0adIkde3aVdXV1Xr11VfVo0cPPrsDAC4TFBoAcBk6856EC5GXl6f58+dr\n9uzZOnLkiAIDAzVo0CA9/vjjrjazZs1STk6OnnzySXl5eSkxMVHTp093uzpx5jhGjBih1tZWLVu2\nTIsWLVKnTp2UkZGh3/zmN5Kk3r17a9CgQVqxYoXWrFmjwsJCY2wzZ85USEiIcnNzdfjwYQUHBys1\nNdVtbGeb/+nZokWLNHfuXE2cOFG1tbUKDg7WzTffrFmzZl3cSQMA2M5ynu0aNwAAAABcJO7RAAAA\nAGA7Cg0AAAAAtqPQAAAAAGA7Cg0AAAAAtqPQAAAAAGA7Cg0AAAAAtjvvz9H4srreyBqaW43Mz9t8\n73MvD++H3u4KMwv2MY8nSY0ehnmsscXIWjy8U299k3nMq3yvMDIfL8/vWX/weLORdWhn1mdXephP\nuPcJc4y+5gdJHT9pzkWSPJ0Nfx+z70P1p4wswt9sd/yUOcaOPp7nXefhvB05YY7zZwE+Rlb1tTme\nFg9vouzpOSBJp1rNxp7m7Ym3h+9juYfv4dn69qTJw+A7engOeOLfznyuHTje5LFtWAfzXF7dXG1k\nrVf6G1lFczsPfZtj9PQzK3n+H4cmD9+HLtv+28gav9pnZO2CrjIy767Xeey7pfIrDwMyz9uRrcVG\nFj7mESM7dXi/OcZd5r6S1H5whpHV/k++kXW8L8vs562lRubV0Zy3p7lIUs22j4ws7D5zPi37dxmZ\nd+drjax40tNGdmPWQx77bqmuMLIrQsPNhq3mz3zr8WNG5tXBfF2r3mR+hoYk+YUEGpn/UPP74NV4\n3Ozbt6N5wIq9RlSz4W2PfZ88Zv4eazreYGTXxJsfHOjXf5CRNX600ciuCLnGY9+engenDpUZ2ZWO\nm4zMCu5sZM07i86rD0lyNppz9AoIMTLvLt3MfvbuMLKJQ+caWeadPT323dRgvv4GdDWfL2XvmOfi\nmj5hRuZ3dXsj8w8P9dj3Fb7m6+KLOW8a2bABXY0s+Hrz/NR8br4e+/h6/jNq86ZyIwv08Hvs1lFR\nRvbJG7uNLCKmk5HFPn63x77X/VeekYVEmj934fE/MzKf9n5Gtvv1EiPrnnqDx74tL3OOnn7uaj47\namS9H/wPI6v94oCRlRd97rHvhsqvjeyWJ4ca2RV+5nPoowVrjex/Pzx
oZPel9/LYtye15XVGVvpx\npZF19fCcDuttPqdLN5q/2yTp0+Mnjey2a8y/Exz/GWNk1aXmHDd46Cf1Ds+/v2NfW+cxPx1XNAAA\nAADYjkIDAAAAgO0oNAAAAADYjkIDAAAAgO0oNAAAAADYjkIDAAAAgO0oNAAAAADYjkIDAAAAgO0o\nNAAAAADYjkIDAAAAgO0oNAAAAADYjkIDAAAAgO0oNAAAAADYjkIDAAAAgO0oNAAAAADYjkIDAAAA\ngO0oNAAAAADYjkIDAAAAgO0sp9Pp/KEHAQAAAOCnhSsaAAAAAGxHoQEAAADAdhQaAAAAAGxHoQEA\nAADAdhQaAAAAAGxHoQEAAADAdhQaAAAlJyfrueeeu6R9nDx5Unfeeafy8/MvaT/fOnDggG699VZ9\n8MEHbdIfAMAdhQYAoE3k5OQoKChIY8eOdWXl5eW6//775XA4VFZW5ta+oKBADodDsbGxio2NVUxM\njGJjYzVmzBhXm8bGRuXk5GjgwIG66aabdM8996iwsFCSFB4erpkzZ2rixImqqqpqm0kCAFwoNAAA\nl9wnn3yi119/XVlZWa7snXfeUXp6urp27SrLsjzuFx4eruLiYhUXF6ukpETFxcVavny56/GZM2fq\nX//6l5YsWaKioiKNHDlS48aN0969eyVJqampioyM1IsvvnhpJwgAMFBoAEAb+/vf/67hw4erb9++\nuuWWW/TYY4/p8OHDrsdXrFih1NRUxcTE6Pbbb9fTTz+tEydOuB53OBxavXq1MjMzFRcXp0GDBmnT\npk166623lJaWpri4OI0bN05ff/21pG+uDERHR2vz5s264447FBMTo7S0tHMuKVq/fr3uvvtu9evX\nT/Hx8Zo8ebJqamrOew5nWrZsmfr37y+Hw+HKamtrtXLlSo0YMeKizmNdXZ3Wrl2r8ePHq1u3bmrX\nrp3S09N1/fXXa9WqVa52v/71r7VmzRrV1tZeVD8AgItDoQEAbaiyslJZWVnKysrS9u3b9fbbb8uy\nLM2bN0+S9Pbbb2vevHmaO3euSkpKtGrVKr333nt66aWX3I6zfPlyTZw4UVu3blX37t2VnZ2tjRs3\n6q9//asKCgq0efNmFRQUuNqfOnVKr7zyiv785z/rgw8+0IABA5SZmamGhgZjjO+//76eeOIJjRkz\nRtu2bdMbb7yhqqoqTZgw4bzmcKbW1lYVFRXp5z//uVs+atQoXXvttec8X/X19ZowYYIGDBig22+/\nXU888YRrGdTOnTvV0tKi2NhYt3169+6t4uJi1/Ztt92m1tZWbd68+Zx9AQDsRaEBAG2ooaFBra2t\n8vX1lSQFBgbqhRdecP2RnpKSoqKiIsXFxUmSIiIidMstt+jjjz92O05ycrIcDod8fHyUlJSkI0eO\n6NFHH9WVV16pa6+9Vj169NDnn3/uam9ZljIzMxUaGio/Pz899thjOnHihDZu3GiMceXKlUpKStKQ\nIUNkWZbCwsL0u9/9Ttu2bVN5efl3zuFMBw8eVG1trXr16nVB5yooKEg33HCDRo8ercLCQq1YsUJf\nfvmlMjMz5XQ6XVdYAgMDjf2qq6td2wEBAercubN27dp1Qf0DAL4f7x96AADw76R79+7KyMjQAw88\noB49eig+Pl5DhgxRTEyMJKm5uVl5eXl69913VVNTo9bWVrW0tKh3795ux+nSpYvraz8/PyPz9fVV\nY2Oj2z7XXXed6+ugoCD5+/uroqLCGOPevXv11VdfuV0pcDqd8vb2Vnl5ueLj4885hzPV1NTIsiwF\nBQWd72mSJCUlJSkpKclt/E8//bTS09NVUlJyQccKDg52W/oFALj0uKIBAG1s6tSp+sc//qH7779f\nFRUVuvfee/X8889L+ubm5jfffFO5ubnavn27SkpKNHToUOMYXl4X/vLd0tLitu10Oj3ehO3r66v0\n9HTXTdjf3oi9Y8cOxcfHf+ccLqX
IyEg5nU5VVlYqJCREknTs2DG3NkePHlVoaOglHwsA4NwoNACg\nDTmdTtXW1io0NFQjR47U/PnzNWPGDL3yyiuSpO3btyslJUVxcXHy8vJSS0uLPvnkE1v63rdvn+vr\nmpoaNTQ0KDw83GjXrVs37dy50y1rbGx03RvxXXM4U3BwsJxOp44ePXpB43311Ve1Zs0at2z37t2y\nLEuRkZGKjo6Wj4+Psazso48+ci09O32+F3pFBQDw/VBoAEAb+tvf/qZf/OIXrqU/DQ0N2rFjh2tZ\nU2RkpEpLS9XQ0KDKykrl5OQoICBAVVVVxhWJC+F0OpWfn6+qqio1NDRowYIF8vf3V0JCgtH2gQce\nUElJiVasWKETJ07o6NGjmjZtmuvzK842h+7du3vsu0uXLgoMDNSnn3561rE5nU4jb25u1uzZs1VY\nWKiWlhZ98cUXmjNnjvr06aOePXvK399fo0aN0sKFC1VWVqbGxkYtXbpU+/fv17333us6Tl1dnQ4d\nOqSoqKgLPm8AgIvHPRoA0IaGDRumgwcP6vHHH1d1dbXat2+vfv366dlnn5UkTZ48WVOnTlVCQoI6\ndeqk8ePH66677lJmZqZSUlL03nvvnfUzJ850ejvLsnTXXXdpzJgx2r9/v8LDw7VkyRLX/R2nt42J\nidHzzz+vl156SfPnz5ePj48SEhK0ePHic84hNzfX4zi8vLw0YMAAbdy4UQ8++KArHzx4sA4dOqTW\n1lZZlqXhw4fLsizddNNNWrp0qTIyMtTa2qrf//73qqioUIcOHTR06FDXu19J3yzhmjdvnu677z7V\n19crKipKS5cuVUREhKvN5s2bXWMAALQdy+npv5EAAD8ZBQUFmjp1qoqLi9WuXbsfZAw7duzQ3Xff\nrYKCAvXs2bNN+x49erRuuOEGzZo1q037BYB/dyydAgBcctHR0Ro+fPhZ3wL3Ulm/fr3Kyso0fvz4\nNu0XAEChAQBoIzk5OaqpqVF+fn6b9HfgwAHl5ORowYIF6tSpU5v0CQD4fyydAgAAAGA7rmgAAAAA\nsB2FBgAAAADbUWgAAAAAsB2FBgAAAADbUWgAAAAAsB2FBgAAAADbUWgAAAAAsB2FBgAAAADbUWgA\nAAAAsN3/ATSLOdutoFAtAAAAAElFTkSuQmCC\n", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(10,0.25))\n", "\n", "ax = sns.heatmap(np.transpose(predictions), cbar=False)\n", "plt.xticks([],[])\n", "plt.yticks([],[])\n", "plt.xlabel(\"samples (150)\")\n", "plt.title(\"Predictions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Explaning with integrated gradients" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Evaluated output channel (0-based index): All\n", "Building gradient functions\n", "Progress: 100.0%\n", "Done.\n" ] } ], "source": [ "ig = integrated_gradients(model)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": false }, 
"outputs": [], "source": [ "exp1 = ig.explain(X[0], num_steps=10000)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": false }, "outputs": [], "source": [ "f = lambda x: model.predict(x)[:,0]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# use Shap to explain a single prediction\n", "x = X[0:1,:]\n", "background = DenseData(np.zeros((1,4)), range(4)) \n", "explainer = KernelExplainer(f, background, nsamples=10000)\n", "exp2=explainer.explain(x).effects" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Integrated Gradients: [-0.57897955 0.56409896 0.06661686 0.09188381]\n", "Shap: [-0.57898097 0.56416369 0.06662203 0.09187283]\n" ] } ], "source": [ "print \"Integrated Gradients:\", exp1\n", "print \"Shap:\", exp2" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.13" } }, "nbformat": 4, "nbformat_minor": 2 }