{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Interpretability on Random Forest using LIME
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Table of Contents\n", "\n", "1. [Problem Statement](#section1)

\n", "2. [Importing Packages](#section2)\n", "3. [Loading Data](#section3)\n", "    - 3.1 [Description of the Dataset](#section301)\n", "4. [Data Preprocessing](#section4)\n", "5. [Data train/test split](#section5)\n", "6. [Random Forest Model](#section6)\n", "    - 6.1 [Random Forest in scikit-learn](#section601)\n", "    - 6.2 [Using the Model for Prediction](#section602)\n", "7. [Model Evaluation](#section7)\n", "    - 7.1 [Accuracy Score](#section701)\n", "8. [Model Interpretability using LIME](#section8)\n", "    - 8.1 [Setup LIME Algorithm](#section801)

\n", "    - 8.2 [Explore Key Features in Instance-by-Instance Predictions](#section802)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 1. Problem Statement\n", "\n", "- **Machine Learning (ML)** algorithms capable of capturing **structural non-linearities** in training data, often referred to as **'black box' models (e.g. Random Forests, Deep Neural Networks)**, frequently **predict better** than their **linear counterparts (e.g. Generalized Linear Models)**.\n", "\n", "\n", "- They are, however, much **harder to interpret**: it is often **not possible to gain insight into why a particular prediction was produced** for a given **instance of input data (i.e. the model features)**.\n", "\n", "\n", "- Consequently, **'black box' ML algorithms are often ruled out** in situations where clients seek **cause-and-effect explanations for model predictions**. Less accurate but more interpretable models are used in their place, because their explanatory power is valued more highly than the lost accuracy.\n", "\n", "\n", "- The **difficulty with model explainability** is that it is **very hard to describe a model's decision boundary in a human-understandable manner**.\n", "\n", "\n", "- **LIME (Local Interpretable Model-agnostic Explanations)** is a **Python library** that addresses **model interpretability by producing locally faithful explanations**: each explanation approximates the model's behaviour around a single prediction.\n", "\n", "
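To make the idea of a locally faithful explanation concrete, the following is a minimal sketch in plain Python. It is not the LIME library's actual implementation. It assumes a hypothetical `black_box` function with two numeric features, perturbs one instance with Gaussian noise, weights each perturbed sample by its proximity to the instance, and fits a weighted linear model whose coefficients act as the local explanation. The names `black_box`, `solve` and `explain_locally`, and the values chosen for `scale` and `kernel_width`, are illustrative assumptions.

```python
import math
import random


def black_box(x):
    # Hypothetical non-linear model (an assumption for illustration):
    # the output rises with x[0] squared and falls with x[1].
    return 1.0 / (1.0 + math.exp(-(x[0] ** 2 - 2.0 * x[1])))


def solve(a, b):
    # Gauss-Jordan elimination with partial pivoting for a small dense system.
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(n):
            if r != col:
                factor = m[r][col] / m[col][col]
                m[r] = [v - factor * p for v, p in zip(m[r], m[col])]
    return [m[i][n] / m[i][i] for i in range(n)]


def explain_locally(predict, instance, n_samples=500, scale=0.3,
                    kernel_width=0.5, seed=0):
    rng = random.Random(seed)
    d = len(instance)
    # Accumulate the weighted normal equations (X^T W X) beta = X^T W y.
    xtwx = [[0.0] * (d + 1) for _ in range(d + 1)]
    xtwy = [0.0] * (d + 1)
    for _ in range(n_samples):
        # Perturb the instance and query the black box at the new point.
        z = [v + rng.gauss(0.0, scale) for v in instance]
        y = predict(z)
        # Exponential kernel: samples closer to the instance weigh more.
        dist2 = sum((p - q) ** 2 for p, q in zip(z, instance))
        w = math.exp(-dist2 / kernel_width ** 2)
        row = [1.0] + z  # leading 1.0 is the intercept column
        for i in range(d + 1):
            xtwy[i] += w * row[i] * y
            for j in range(d + 1):
                xtwx[i][j] += w * row[i] * row[j]
    beta = solve(xtwx, xtwy)
    return beta[1:]  # drop the intercept, keep one weight per feature


weights = explain_locally(black_box, [1.0, 0.0])
print(weights)
```

For the instance `[1.0, 0.0]` the first coefficient comes out positive and the second negative, matching the direction in which each feature moves the black-box output near that point. The linear model is only trusted in that neighbourhood, which is what locally faithful means here.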

\n", "\n", "\n", "- We will use **LIME** to **interpret** our **RandomForest model**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. Importing Packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install LIME using the following command.\n", "\n", "!pip install lime" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.set_printoptions(precision=4)  # To display values only up to four decimal places.\n", "\n", "import pandas as pd\n", "pd.set_option('mode.chained_assignment', None)  # To suppress pandas chained-assignment warnings.\n", "pd.set_option('display.max_colwidth', None)  # To display all the data in the columns (None replaces the deprecated -1 in pandas >= 1.0).\n", "pd.options.display.max_columns = 40  # To display all the columns. (Set the value to a high number)\n", "\n", "import matplotlib.pyplot as plt\n", "plt.style.use('seaborn-v0_8-whitegrid')  # To apply the seaborn whitegrid style (named 'seaborn-whitegrid' before matplotlib 3.6).\n", "plt.rc('figure', figsize=(10, 8))  # Set the default figure size of plots.\n", "%matplotlib inline\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')  # To suppress all the warnings in the notebook.\n", "\n", "from sklearn.preprocessing import LabelEncoder\n", "from sklearn.preprocessing import OneHotEncoder\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 3. Loading Data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " class cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "0 p x s n t p f \n", "1 e x s y t a f \n", "2 e b s w t l f \n", "3 p x y w t p f \n", "4 e x s g f n f \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "0 c n k e e \n", "1 c b k e c \n", "2 c b n e c \n", "3 c n n e e \n", "4 w b k t e \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring stalk-color-above-ring \\\n", "0 s s w \n", "1 s s w \n", "2 s s w \n", "3 s s w \n", "4 s s w \n", "\n", " stalk-color-below-ring veil-type veil-color ring-number ring-type \\\n", "0 w p w o p \n", "1 w p w o p \n", "2 w p w o p \n", "3 w p w o p \n", "4 w p w o e \n", "\n", " spore-print-color population habitat \n", "0 k s u \n", "1 n n g \n", "2 n n m \n", "3 k s u \n", "4 n a g " ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('../../data/mushrooms.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 3.1 Description of the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- This **dataset includes descriptions** of hypothetical samples corresponding to **23 species** of **gilled mushrooms** in the **Agaricus and Lepiota family**, drawn from **The Audubon Society Field Guide to North American Mushrooms (1981)**.\n", "\n", "\n", "- **Each species** is **identified as definitely edible**, **definitely poisonous**, or **of unknown edibility** and **not recommended**. This **last class was combined with** the **poisonous one**.\n", "\n", "\n", "- The **Guide clearly states** that there is no **simple rule for determining** the **edibility of a mushroom**; no rule like **\"leaflets three, let it be\"** for **poison oak and poison ivy**." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['class', 'cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',\n", " 'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',\n", " 'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',\n", " 'stalk-surface-below-ring', 'stalk-color-above-ring',\n", " 'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',\n", " 'ring-type', 'spore-print-color', 'population', 'habitat'],\n", " dtype='object')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.columns" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| **Column Name** | **Description** |\n", "| ---------------------------------|:----------------------------------------------------------------------------------------:| \n", "| class | classes: edible=e, poisonous=p. |\n", "| cap-shape | bell=b,conical=c, convex=x, flat=f, knobbed=k, sunken=s. |\n", "| cap-surface | fibrous=f, grooves=g, scaly=y, smooth=s. |\n", "| cap-color | brown=n, buff=b, cinnamon=c, gray=g, green=r, pink=p, purple=u, red=e, white=w, yellow=y.|\n", "| bruises | bruises=t, no=f. |\n", "| odor | almond=a, anise=l, creosote=c, fishy=y, foul=f, musty=m ,none=n, pungent=p, spicy=s. |\n", "| gill-attachment | attached=a, descending=d, free=f, notched=n. |\n", "| gill-spacing | close=c, crowded=w, distant=d. |\n", "| gill-size | broad=b, narrow=n. |\n", "| gill-color | black=k, brown=n ,buff=b, chocolate=h, gray=g, green=r, orange=o, pink=p, purple=u, red=e, white=w, yellow=y. |\n", "| stalk-shape | enlarging=e, tapering=t. |\n", "| stalk-root | bulbous=b, club=c, cup=u, equal=e, rhizomorphs=z, rooted=r, missing=?. |\n", "| stalk-surface-above-ring | fibrous=f, scaly=y, silky=k, smooth=s. |\n", "| stalk-surface-below-ring | fibrous=f, scaly=y, silky=k, smooth=s. 
|\n", "| stalk-color-above-ring | brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y. |\n", "| stalk-color-below-ring | brown=n, buff=b, cinnamon=c, gray=g, orange=o, pink=p, red=e, white=w, yellow=y. |\n", "| veil-type | partial=p ,universal=u. |\n", "| veil-color | brown=n, orange=o, white=w, yellow=y. |\n", "| ring-number | none=n, one=o, two=t. |\n", "| ring-type | cobwebby=c, evanescent=e, flaring=f, large=l, none=n, pendant=p, sheathing=s, zone=z. |\n", "| spore-print-color | black=k, brown=n, buff=b, chocolate=h, green=r, orange=o, purple=u, white=w, yellow=y. |\n", "| population | abundant=a, clustered=c, numerous=n, scattered=s, several=v, solitary=y. |\n", "| habitat | grasses=g, leaves=l, meadows=m, paths=p, urban=u, waste=w, woods=d. |" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 8124 entries, 0 to 8123\n", "Data columns (total 23 columns):\n", "class 8124 non-null object\n", "cap-shape 8124 non-null object\n", "cap-surface 8124 non-null object\n", "cap-color 8124 non-null object\n", "bruises 8124 non-null object\n", "odor 8124 non-null object\n", "gill-attachment 8124 non-null object\n", "gill-spacing 8124 non-null object\n", "gill-size 8124 non-null object\n", "gill-color 8124 non-null object\n", "stalk-shape 8124 non-null object\n", "stalk-root 8124 non-null object\n", "stalk-surface-above-ring 8124 non-null object\n", "stalk-surface-below-ring 8124 non-null object\n", "stalk-color-above-ring 8124 non-null object\n", "stalk-color-below-ring 8124 non-null object\n", "veil-type 8124 non-null object\n", "veil-color 8124 non-null object\n", "ring-number 8124 non-null object\n", "ring-type 8124 non-null object\n", "spore-print-color 8124 non-null object\n", "population 8124 non-null object\n", "habitat 8124 non-null object\n", "dtypes: object(23)\n", "memory usage: 1.4+ MB\n" ] } ], "source": [ "df.info()" ] }, { 
"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " class cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "count 8124 8124 8124 8124 8124 8124 8124 \n", "unique 2 6 4 10 2 9 2 \n", "top e x y n f n f \n", "freq 4208 3656 3244 2284 4748 3528 7914 \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "count 8124 8124 8124 8124 8124 \n", "unique 2 2 12 2 5 \n", "top c b b t b \n", "freq 6812 5612 1728 4608 3776 \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring \\\n", "count 8124 8124 \n", "unique 4 4 \n", "top s s \n", "freq 5176 4936 \n", "\n", " stalk-color-above-ring stalk-color-below-ring veil-type veil-color \\\n", "count 8124 8124 8124 8124 \n", "unique 9 9 1 4 \n", "top w w p w \n", "freq 4464 4384 8124 7924 \n", "\n", " ring-number ring-type spore-print-color population habitat \n", "count 8124 8124 8124 8124 8124 \n", "unique 3 5 9 6 7 \n", "top o p w v d \n", "freq 7488 3968 2388 4040 3148 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 4. Data Preprocessing" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " class cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "0 p x s n t p f \n", "1 e x s y t a f \n", "2 e b s w t l f \n", "3 p x y w t p f \n", "4 e x s g f n f \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "0 c n k e e \n", "1 c b k e c \n", "2 c b n e c \n", "3 c n n e e \n", "4 w b k t e \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring stalk-color-above-ring \\\n", "0 s s w \n", "1 s s w \n", "2 s s w \n", "3 s s w \n", "4 s s w \n", "\n", " stalk-color-below-ring veil-type veil-color ring-number ring-type \\\n", "0 w p w o p \n", "1 w p w o p \n", "2 w p w o p \n", "3 w p w o p \n", "4 w p w o e \n", "\n", " spore-print-color population habitat \n", "0 k s u \n", "1 n n g \n", "2 n n m \n", "3 k s u \n", "4 n a g " ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['p', 'e', 'e', ..., 'e', 'p', 'e'], dtype=object)" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating labels array from the class column.\n", "\n", "labels = df.iloc[:, 0].values\n", "labels" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LabelEncoder()" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating a LabelEncoder object le and fitting labels array into it.\n", "\n", "le = LabelEncoder()\n", "le.fit(labels)" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 0, ..., 0, 1, 0])" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Transforming the labels array to have numerical values.\n", "\n", "labels = le.transform(labels)\n", "labels" ] }, { "cell_type": "code", 
"execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['e', 'p'], dtype=object)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Storing the different classes found by LabelEncoder in labels array into class_names.\n", "\n", "class_names = le.classes_\n", "class_names" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "# Dropping the class column from the df dataframe.\n", "\n", "df.drop(['class'], axis=1, inplace=True)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " cap-shape cap-surface cap-color bruises odor gill-attachment gill-spacing \\\n", "0 x s n t p f c \n", "1 x s y t a f c \n", "2 b s w t l f c \n", "3 x y w t p f c \n", "4 x s g f n f w \n", "\n", " gill-size gill-color stalk-shape stalk-root stalk-surface-above-ring \\\n", "0 n k e e s \n", "1 b k e c s \n", "2 b n e c s \n", "3 n n e e s \n", "4 b k t e s \n", "\n", " stalk-surface-below-ring stalk-color-above-ring stalk-color-below-ring \\\n", "0 s w w \n", "1 s w w \n", "2 s w w \n", "3 s w w \n", "4 s w w \n", "\n", " veil-type veil-color ring-number ring-type spore-print-color population \\\n", "0 p w o p k s \n", "1 p w o p n n \n", "2 p w o p n n \n", "3 p w o p k s \n", "4 p w o e n a \n", "\n", " habitat \n", "0 u \n", "1 g \n", "2 m \n", "3 u \n", "4 g " ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "range(0, 22)" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating a range from 0 up to the number of features with len(), since all the features in df are categorical.
\n", "\n", "categorical_features = range(len(df.columns))\n", "categorical_features" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['cap-shape', 'cap-surface', 'cap-color', 'bruises', 'odor',\n", " 'gill-attachment', 'gill-spacing', 'gill-size', 'gill-color',\n", " 'stalk-shape', 'stalk-root', 'stalk-surface-above-ring',\n", " 'stalk-surface-below-ring', 'stalk-color-above-ring',\n", " 'stalk-color-below-ring', 'veil-type', 'veil-color', 'ring-number',\n", " 'ring-type', 'spore-print-color', 'population', 'habitat'],\n", " dtype=object)" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating an array of feature names.\n", "\n", "feature_names = df.columns.values\n", "feature_names" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'bell=b,conical=c,convex=x,flat=f,knobbed=k,sunken=s'" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We expand the characters into words, using the dataset description provided in the beginning.\n", "\n", "categorical_names = '''bell=b,conical=c,convex=x,flat=f,knobbed=k,sunken=s\n", "fibrous=f,grooves=g,scaly=y,smooth=s\n", "brown=n,buff=b,cinnamon=c,gray=g,green=r,pink=p,purple=u,red=e,white=w,yellow=y\n", "bruises=t,no=f\n", "almond=a,anise=l,creosote=c,fishy=y,foul=f,musty=m,none=n,pungent=p,spicy=s\n", "attached=a,descending=d,free=f,notched=n\n", "close=c,crowded=w,distant=d\n", "broad=b,narrow=n\n", "black=k,brown=n,buff=b,chocolate=h,gray=g,green=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y\n", "enlarging=e,tapering=t\n", "bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r,missing=?\n", "fibrous=f,scaly=y,silky=k,smooth=s\n", "fibrous=f,scaly=y,silky=k,smooth=s\n", "brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y\n", 
"brown=n,buff=b,cinnamon=c,gray=g,orange=o,pink=p,red=e,white=w,yellow=y\n", "partial=p,universal=u\n", "brown=n,orange=o,white=w,yellow=y\n", "none=n,one=o,two=t\n", "cobwebby=c,evanescent=e,flaring=f,large=l,none=n,pendant=p,sheathing=s,zone=z\n", "black=k,brown=n,buff=b,chocolate=h,green=r,orange=o,purple=u,white=w,yellow=y\n", "abundant=a,clustered=c,numerous=n,scattered=s,several=v,solitary=y\n", "grasses=g,leaves=l,meadows=m,paths=p,urban=u,waste=w,woods=d'''.split('\\n')\n", "\n", "categorical_names[0]" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "for j, names in enumerate(categorical_names):\n", " values = names.split(',')\n", " values = dict([(x.split('=')[1], x.split('=')[0]) for x in values])\n", " df.iloc[:, j] = df.iloc[:, j].map(values)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "0 convex smooth brown bruises pungent free \n", "1 convex smooth yellow bruises almond free \n", "2 bell smooth white bruises anise free \n", "3 convex scaly white bruises pungent free \n", "4 convex smooth gray no none free \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "0 close narrow black enlarging equal \n", "1 close broad black enlarging club \n", "2 close broad brown enlarging club \n", "3 close narrow brown enlarging equal \n", "4 crowded broad black tapering equal \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring stalk-color-above-ring \\\n", "0 smooth smooth white \n", "1 smooth smooth white \n", "2 smooth smooth white \n", "3 smooth smooth white \n", "4 smooth smooth white \n", "\n", " stalk-color-below-ring veil-type veil-color ring-number ring-type \\\n", "0 white partial white one pendant \n", "1 white partial white one pendant \n", "2 white partial white one pendant \n", "3 white partial white one pendant \n", "4 white partial white one evanescent \n", "\n", " spore-print-color population habitat \n", "0 black scattered urban \n", "1 brown numerous grasses \n", "2 brown numerous meadows \n", "3 black scattered urban \n", "4 brown abundant grasses " ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "# LabelEncoding all the features. 
Capturing the different class values for each feature in the categorical_names dictionary.\n", "\n", "categorical_names = {}\n", "\n", "for feature in categorical_features:\n", "    le = LabelEncoder()\n", "    le.fit(df.iloc[:, feature])\n", "    df.iloc[:, feature] = le.transform(df.iloc[:, feature])\n", "    categorical_names[feature] = le.classes_" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['bell', 'conical', 'convex', 'flat', 'knobbed', 'sunken'],\n", " dtype=object)" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "categorical_names[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 5. Data train/test split\n", "\n", "- Now that all the **data** is **numeric**, let's begin the modelling process.\n", "\n", "\n", "- First, we **split** the complete **dataset** into **training** and **test** sets." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "0 2 3 0 0 7 1 \n", "1 2 3 9 0 0 1 \n", "2 0 3 8 0 1 1 \n", "3 2 2 8 0 7 1 \n", "4 2 3 3 1 6 1 \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "0 0 1 0 0 2 \n", "1 0 0 0 0 1 \n", "2 0 0 1 0 1 \n", "3 0 1 1 0 2 \n", "4 1 0 0 1 2 \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring stalk-color-above-ring \\\n", "0 3 3 7 \n", "1 3 3 7 \n", "2 3 3 7 \n", "3 3 3 7 \n", "4 3 3 7 \n", "\n", " stalk-color-below-ring veil-type veil-color ring-number ring-type \\\n", "0 7 0 2 1 4 \n", "1 7 0 2 1 4 \n", "2 7 0 2 1 4 \n", "3 7 0 2 1 4 \n", "4 7 0 2 1 0 \n", "\n", " spore-print-color population habitat \n", "0 0 3 4 \n", "1 1 2 0 \n", "2 1 2 2 \n", "3 0 3 4 \n", "4 1 0 0 " ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "0 2 3 0 0 7 1 \n", "1 2 3 9 0 0 1 \n", "2 0 3 8 0 1 1 \n", "3 2 2 8 0 7 1 \n", "4 2 3 3 1 6 1 \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "0 0 1 0 0 2 \n", "1 0 0 0 0 1 \n", "2 0 0 1 0 1 \n", "3 0 1 1 0 2 \n", "4 1 0 0 1 2 \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring stalk-color-above-ring \\\n", "0 3 3 7 \n", "1 3 3 7 \n", "2 3 3 7 \n", "3 3 3 7 \n", "4 3 3 7 \n", "\n", " stalk-color-below-ring veil-type veil-color ring-number ring-type \\\n", "0 7 0 2 1 4 \n", "1 7 0 2 1 4 \n", "2 7 0 2 1 4 \n", "3 7 0 2 1 4 \n", "4 7 0 2 1 0 \n", "\n", " spore-print-color population habitat \n", "0 0 3 4 \n", "1 1 2 0 \n", "2 1 2 2 \n", "3 0 3 4 \n", "4 1 0 0 " ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = df.iloc[:, :]\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 0, 1, 0, 0, 0, 0, 1, 0])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = labels[:]\n", "y[:10]" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "# Using scikit-learn's train_test_split function to split the dataset into train and test sets.\n", "# 80% of the data will be in the train set and 20% in the test set, as specified by test_size=0.2\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(6499, 22)\n", "(6499,)\n", "(1625, 22)\n", "(1625,)\n" ] } ], "source": [ "# Checking the shapes of all the training and test sets for the dependent and independent features.\n", "\n", "print(X_train.shape)\n", "print(y_train.shape)\n", "print(X_test.shape)\n", 
"print(y_test.shape)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "OneHotEncoder(categorical_features=range(0, 22), categories=None, drop=None,\n", " dtype=, handle_unknown='error',\n", " n_values=None, sparse=True)" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Finally, we use a One-hot encoder, so that the classifier does not take our categorical features as continuous features. \n", "# We will use this encoder only for the classifier, not for the explainer - \n", "# and the reason is that the explainer must make sure that a categorical feature only has one value.\n", "\n", "ohe = OneHotEncoder(categorical_features=categorical_features)\n", "ohe.fit(df)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "X_train_encoded = ohe.transform(X_train)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "X_test_encoded = ohe.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(6499, 117)\n", "(1625, 117)\n" ] } ], "source": [ "print(X_train_encoded.shape)\n", "print(X_test_encoded.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 6. 
Random Forest Model\n", "\n", "\n", "### 6.1 Random Forest with Scikit-Learn" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [], "source": [ "# Creating a Random Forest Classifier.\n", "\n", "classifier_rf = RandomForestClassifier(n_estimators=500, random_state=0, oob_score=True, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=500,\n", " n_jobs=-1, oob_score=True, random_state=0, verbose=0,\n", " warm_start=False)" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fitting the model on the training set.\n", "\n", "classifier_rf.fit(X_train_encoded, y_train)" ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0" ] }, "execution_count": 76, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classifier_rf.oob_score_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The **OOB (out-of-bag) score** estimates how the model will perform on the **test set or new** samples, because each tree is scored only on the training samples left out of its bootstrap sample." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 6.2 Using the Model for Prediction" ] }, { "cell_type": "code", "execution_count": 124, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, 0, 0, 1, 0, 0, 1, 1])" ] }, "execution_count": 124, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Making predictions on the training set.\n", "\n", "y_pred_train = classifier_rf.predict(X_train_encoded)\n", "y_pred_train[:10]" ] }, { "cell_type": "code", "execution_count": 125, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0, 1, 1, 0, 1, 1, 1, 1, 0, 0])" ] }, "execution_count": 125, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Making predictions on the test set.\n", "\n", "y_pred_test = classifier_rf.predict(X_test_encoded)\n", "y_pred_test[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 7. Model Evaluation\n", "\n", "**Error** is the deviation of the values predicted by the model from the true values." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 7.1 Accuracy Score" ] }, { "cell_type": "code", "execution_count": 126, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy score for train data is: 1.0\n" ] } ], "source": [ "# Accuracy score on the training set.\n", "\n", "print('Accuracy score for train data is:', accuracy_score(y_train, y_pred_train))" ] }, { "cell_type": "code", "execution_count": 127, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy score for test data is: 1.0\n" ] } ], "source": [ "# Accuracy score on the test set.\n", "\n", "print('Accuracy score for test data is:', accuracy_score(y_test, y_pred_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We get an **accuracy** of **100%** on our train set and an **accuracy** of **100%** on our test set.\n", "\n", "\n", "- Note that the **accuracy** obtained on the **test set (1.0)** matches the **oob_score_ (1.0)**, so the **oob_score_** can serve as a **validation** estimate before testing our model on the **test set**. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 8. Model Interpretability using LIME\n", "\n", "\n", "- **LIME** stands for **Local Interpretable Model-Agnostic Explanations**. It is a technique to **explain the predictions of any machine learning classifier**, and **evaluate its usefulness** in various **tasks** related to **trust**. " ] }, { "cell_type": "code", "execution_count": 132, "metadata": {}, "outputs": [], "source": [ "# Our predict function first transforms the data into the one-hot representation. 
\n", "# Then it calculates the prediction probability for each class of the target variable.\n", "\n", "predict_fn = lambda x: classifier_rf.predict_proba(ohe.transform(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 8.1 Setup LIME Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We now **create** our **explainer**. \n", "\n", "\n", "- The **categorical_features parameter** lets it know **which features** are **categorical** (in **this case**, **all of them**). \n", "\n", "\n", "- The **categorical_names parameter** gives a **string representation** of **each categorical feature's numerical value**." ] }, { "cell_type": "code", "execution_count": 128, "metadata": {}, "outputs": [], "source": [ "from lime.lime_tabular import LimeTabularExplainer" ] }, { "cell_type": "code", "execution_count": 129, "metadata": {}, "outputs": [], "source": [ "# Creating the LIME explainer object.\n", "\n", "explainer = LimeTabularExplainer(X_train.values, mode='classification', class_names=['edible', 'poisonous'], \n", " feature_names=feature_names, categorical_features=categorical_features, \n", " categorical_names=categorical_names, kernel_width=3, verbose=True, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 8.2 Explore Key Features in Instance-by-Instance Predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Start by choosing an instance** from the **test dataset**.\n", "\n", "\n", "- Use **LIME** to **estimate a local model** to use for **explaining our model's predictions**. The **outputs** will be:\n", "\n", " 1. The **intercept** estimated for the local model.\n", " 2. The **local model's estimate** for the **Random Forest's prediction**.\n", " 3. The **Random Forest's actual prediction**.\n", "\n", "\n", "- Note that the **actual value from the data does not enter into this** - the **idea of LIME** is to **gain insight** into **why the chosen model** - in our case the Random Forest classifier - **is predicting whatever it has been asked to predict**. Whether or not this prediction is actually any good is a separate issue." ] }, { "cell_type": "code", "execution_count": 253, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i = 1075\n" ] } ], "source": [ "# Selecting a random instance from the test dataset.\n", "\n", "i = np.random.randint(0, X_test.shape[0])\n", "print('i =', i)" ] }, { "cell_type": "code", "execution_count": 254, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept 0.7144173526963122\n", "Prediction_local [0.2041]\n", "Right: 0.0\n" ] } ], "source": [ "# Using LIME to estimate a local model. Using only 6 features to explain our model's predictions.\n", "\n", "exp = explainer.explain_instance(X_test.values[i], predict_fn, num_features=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Printing** the **DataFrame row** for the **chosen test instance**." ] }, { "cell_type": "code", "execution_count": 255, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " index cap-shape cap-surface cap-color bruises odor \\\n", "1075 2280 2 0 0 0 6 \n", "\n", " gill-attachment gill-spacing gill-size gill-color stalk-shape \\\n", "1075 1 0 0 7 1 \n", "\n", " stalk-root stalk-surface-above-ring stalk-surface-below-ring \\\n", "1075 0 3 3 \n", "\n", " stalk-color-above-ring stalk-color-below-ring veil-type veil-color \\\n", "1075 5 3 0 2 \n", "\n", " ring-number ring-type spore-print-color population habitat \n", "1075 1 4 0 5 6 " ] }, "execution_count": 255, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Here the index column is the original index as per the df dataframe and the number at the beginning the index after reset.\n", "\n", "X_test.reset_index().loc[[i]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **LIME's interpretation** of our **Random Forest's prediction**." ] }, { "cell_type": "code", "execution_count": 256, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(show_table=True, show_all=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- First, note that the **row** we **explained** is **displayed** on the **right side**, in **table** format. Since we had the **show_all parameter** set to **False**, only the **features used in the explanation are displayed**.\n", "\n", "\n", "- The **value column** displays the **original value for each feature**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The **output generated above** can also be obtained in the **form of a list**." ] }, { "cell_type": "code", "execution_count": 257, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('odor=none', -0.25505675405753075),\n", " ('gill-size=broad', -0.1277562900735544),\n", " ('stalk-surface-above-ring=smooth', -0.09797362037586274),\n", " ('gill-spacing=close', 0.06957793771888027),\n", " ('bruises=bruises', -0.05072352183511148),\n", " ('ring-type=pendant', -0.0483729727557791)]" ] }, "execution_count": 257, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.as_list()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Observations obtained from LIME's interpretation of our Random Forest's prediction**:\n", "\n", "- The **value** shown after each condition is the **amount** by which the local prediction is **shifted** from the **intercept** estimated for the local model. \n", " \n", " \n", "- When all these values are **added** to the **intercept**, the result is the **Prediction_local** (the local model's estimate for the Random Forest's prediction) calculated by **LIME**." 
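, "\n", "\n", "\n", "- As a quick **worked check** for this instance, using the **intercept** and the **six weights** printed above, each rounded to four decimal places (so the last digit is only approximate):\n", "\n", "$$0.7144 - 0.2551 - 0.1278 - 0.0980 + 0.0696 - 0.0507 - 0.0484 \\approx 0.204$$\n", "\n", "- This agrees, up to rounding, with the **Prediction_local** of **0.2041** reported by **LIME**."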
] }, { "cell_type": "code", "execution_count": 258, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept = 0.7144173526963122\n", "Prediction_local = 0.20411213131735406\n" ] } ], "source": [ "print('Intercept =', exp.intercept[1])\n", "print('Prediction_local =', exp.local_pred[0])" ] }, { "cell_type": "code", "execution_count": 259, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction_local = 0.20411213131735403\n" ] } ], "source": [ "# Calculating the Prediction_local by adding the weights obtained above for each condition to the intercept.\n", "# The intercept is read from exp.intercept using index 1, the label being explained ('poisonous').\n", "\n", "intercept = exp.intercept[1]\n", "prediction_local = intercept\n", "\n", "# Each entry of exp.as_list() is a (condition, weight) pair.\n", "for _, weight in exp.as_list():\n", "    prediction_local += weight\n", "\n", "print('Prediction_local =', prediction_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", " \n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Choosing **another instance** from the **test dataset**." ] }, { "cell_type": "code", "execution_count": 264, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i = 515 \n", "\n", "Intercept 0.4517135442299755\n", "Prediction_local [0.6529]\n", "Right: 1.0\n" ] } ], "source": [ "# This time specifying a particular value of i in order to explain the working of LIME.\n", "\n", "i = 515\n", "print('i =', i, '\\n')\n", "\n", "exp = explainer.explain_instance(X_test.values[i], predict_fn, num_features=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Printing** the **DataFrame row** for the **chosen test instance**." ] }, { "cell_type": "code", "execution_count": 277, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " index cap-shape cap-surface cap-color bruises odor gill-attachment \\\n", "515 5357 2 3 8 0 4 1 \n", "\n", " gill-spacing gill-size gill-color stalk-shape stalk-root \\\n", "515 0 0 10 1 0 \n", "\n", " stalk-surface-above-ring stalk-surface-below-ring \\\n", "515 0 3 \n", "\n", " stalk-color-above-ring stalk-color-below-ring veil-type veil-color \\\n", "515 7 7 0 2 \n", "\n", " ring-number ring-type spore-print-color population habitat \n", "515 1 4 3 3 4 " ] }, "execution_count": 277, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test.reset_index().loc[[i]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **LIME's interpretation** of our **Random Forest's prediction**." ] }, { "cell_type": "code", "execution_count": 266, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(show_table=True, show_all=False)" ] }, { "cell_type": "code", "execution_count": 267, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('odor=foul', 0.274586091806193),\n", " ('gill-size=broad', -0.1292596563485946),\n", " ('spore-print-color=chocolate', 0.08298611784301334),\n", " ('gill-spacing=close', 0.07639601609745206),\n", " ('ring-type=pendant', -0.053535847394535555),\n", " ('bruises=bruises', -0.049955819676421445)]" ] }, "execution_count": 267, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.as_list()" ] }, { "cell_type": "code", "execution_count": 268, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept = 0.4517135442299755\n", "Prediction_local = 0.6529304465570823\n" ] } ], "source": [ "print('Intercept =', exp.intercept[1])\n", "print('Prediction_local =', exp.local_pred[0])" ] }, { "cell_type": "code", "execution_count": 269, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction_local = 0.6529304465570822\n" ] } ], "source": [ "intercept = exp.intercept[1]\n", "prediction_local = intercept\n", "\n", "for j in range(len(exp.as_list())):\n", "    prediction_local += exp.as_list()[j][1]\n", "\n", "print('Prediction_local =', prediction_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- By **changing** the chosen **i**, we observe that the **narrative provided by LIME** also **changes, in response to changes in the model** in the **local region** of the **feature space** in which it is working to **generate a given prediction**. \n", "\n", "\n", "- This is clearly an **improvement on relying purely** on the **Random Forest's (static) expected relative feature importance** and of **great benefit to models that provide no insight whatsoever**." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Now note that the **explanations** are **based not only on features**, **but** on **feature-value pairs**. \n", "\n", "\n", "- **For example**, we are saying that **odor = foul** is **indicative of** a **poisonous mushroom**. \n", "

\n", " - **In** the **context** of a **categorical feature**, **odor** could **take** many **other values**. \n", "

\n", " - Since we **perturb** each **categorical feature drawing samples** according to the **original training distribution**, the way to interpret this is: **if odor was not foul**, on **average**, this **prediction** would be **0.27 less 'poisonous'**. \n", " \n", "
\n", "- Let's **check** if **this** is the **case**:" ] }, { "cell_type": "code", "execution_count": 270, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['almond', 'anise', 'creosote', 'fishy', 'foul', 'musty', 'none',\n", " 'pungent', 'spicy'], dtype=object)" ] }, "execution_count": 270, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Checking the different categories in the odor feature.\n", "\n", "odor_idx = list(feature_names).index('odor')\n", "explainer.categorical_names[odor_idx]" ] }, { "cell_type": "code", "execution_count": 271, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.0492, 0.0472, 0.0238, 0.0697, 0.2662, 0.0048, 0.4359, 0.0306,\n", " 0.0725])" ] }, "execution_count": 271, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Checking the feature frequencies of different categories in the odor feature.\n", "\n", "explainer.feature_frequencies[odor_idx]" ] }, { "cell_type": "code", "execution_count": 272, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['almond', 'anise', 'creosote', 'fishy', 'musty', 'none', 'pungent',\n", " 'spicy'], dtype=object)" ] }, "execution_count": 272, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Setting foul_idx equal to the index of 'foul' category in the odor feature.\n", "# Then creating non_foul array with different categories in the odor feature except foul category.\n", "\n", "foul_idx = 4\n", "non_foul = np.delete(explainer.categorical_names[odor_idx], foul_idx)\n", "non_foul" ] }, { "cell_type": "code", "execution_count": 273, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.0671, 0.0644, 0.0325, 0.095 , 0. 
, 0.0065, 0.594 , 0.0417,\n", " 0.0988])" ] }, "execution_count": 273, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Creating the non_foul_normalized_frequencies array from the feature frequencies of the categories in the odor feature.\n", "# Setting the frequency of the foul category to 0, then normalizing the frequencies so that they sum to 1.\n", "\n", "non_foul_normalized_frequencies = explainer.feature_frequencies[odor_idx].copy()\n", "non_foul_normalized_frequencies[foul_idx] = 0\n", "non_foul_normalized_frequencies /= non_foul_normalized_frequencies.sum()\n", "non_foul_normalized_frequencies" ] }, { "cell_type": "code", "execution_count": 286, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Making odor not equal foul\n", "P(poisonous) before: 1.0 \n", "\n", "P(poisonous | odor=almond): 0.66\n", "P(poisonous | odor=anise): 0.65\n", "P(poisonous | odor=creosote): 0.73\n", "P(poisonous | odor=fishy): 0.73\n", "P(poisonous | odor=musty): 0.72\n", "P(poisonous | odor=none): 0.49\n", "P(poisonous | odor=pungent): 0.76\n", "P(poisonous | odor=spicy): 0.72\n", "\n", "P(poisonous | odor != foul) = 0.58\n" ] } ], "source": [ "# Calculating the probability of the mushroom being poisonous for each value of odor except foul.\n", "# Finally calculating the frequency-weighted average probability of poisonous when odor is not foul.\n", "\n", "print('Making odor not equal foul')\n", "\n", "temp = X_test.values[i].copy()\n", "print('P(poisonous) before:', predict_fn(temp.reshape(1,-1))[0,1], '\\n')\n", "\n", "average_poisonous = 0\n", "\n", "for idx, (name, frequency) in enumerate(zip(explainer.categorical_names[odor_idx], non_foul_normalized_frequencies)):\n", "    if name == 'foul':\n", "        continue\n", "    temp[odor_idx] = idx\n", "    p_poisonous = predict_fn(temp.reshape(1,-1))[0,1]\n", "    average_poisonous += p_poisonous * frequency\n", "    print('P(poisonous | odor=%s): %.2f' % (name, p_poisonous))\n", "\n", "print 
('\\nP(poisonous | odor != foul) = %.2f' % average_poisonous)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Increase in P(poisonous) attributable to odor = foul** = **P(poisonous) before - P(poisonous | odor != foul)** = **1.0 - 0.58** = **0.42**\n", "\n", "\n", "- We see that **in this** particular **case**, the **linear model** is **reasonably close**: it **predicted** that **on average odor = foul increases** the **probability of poisonous by 0.27**, when **in fact the increase is 0.42**. \n", "\n", "\n", "- Notice though that **we only changed one feature (odor)**, whereas the **linear model takes into account perturbations of all** the **features at once**." ] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "307.2px" }, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }