{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "#
Model Interpretability on Random Forest using LIME
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Table of Contents\n", "\n", "1. [Problem Statement](#section1)

\n", "2. [Importing Packages](#section2)

\n", "3. [Loading Data](#section3)\n", " - 3.1 [Description of the Dataset](#section301)

\n", "4. [Data train/test split](#section4)

\n", "5. [Random Forest Model](#section5)\n", " - 5.1 [Random Forest in scikit-learn](#section501)

\n", " - 5.2 [Feature Importances](#section502)

\n", " - 5.3 [Using the Model for Prediction](#section503)

\n", "6. [Model Evaluation](#section6)\n", " - 6.1 [R-Squared Value](#section601)

\n", "7. [Model Interpretability using LIME](#section7) \n", " - 7.1 [Setup LIME Algorithm](#section701)

\n", " - 7.2 [Explore Key Features in Instance-by-Instance Predictions](#section702)
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 1. Problem Statement\n", "\n", "- We have often found that **Machine Learning (ML)** algorithms capable of capturing **structural non-linearities** in training data - models that are sometimes referred to as **'black box' (e.g. Random Forests, Deep Neural Networks, etc.)** - perform far **better at prediction** than their **linear counterparts (e.g. Generalised Linear Models)**. \n", "\n", "\n", "- They are, however, much **harder to interpret** - in fact, quite often it is **not possible to gain any insight into why a particular prediction has been produced**, when given an **instance of input data (i.e. the model features)**. \n", "\n", "\n", "- Consequently, it has **not been possible to use 'black box' ML algorithms** in situations where clients have sought **cause-and-effect explanations for model predictions**, with end-results being that sub-optimal predictive models have been used in their place, as their explanatory power has been more valuable, in relative terms.\n", "\n", "\n", "- The **problem with model explainability** is that it’s **very hard to define a model’s decision boundary in human understandable manner**. \n", "\n", "\n", "- **LIME** is a **python library** which tries to **solve for model interpretability by producing locally faithful explanations**. \n", "\n", "
\n", "

\n", "\n", "\n", "- We will use **LIME** to **interpret** our **Random Forest model**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. Importing Packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install LIME using the following command.\n", "\n", "!pip install lime" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "np.set_printoptions(precision=4) # To display values only up to four decimal places. \n", "\n", "import pandas as pd\n", "pd.set_option('mode.chained_assignment', None) # To suppress pandas warnings.\n", "pd.set_option('display.max_colwidth', -1) # To display all the data in the columns.\n", "pd.options.display.max_columns = 40 # To display all the columns. (Set the value to a high number)\n", "\n", "import matplotlib.pyplot as plt\n", "plt.style.use('seaborn-whitegrid') # To apply seaborn whitegrid style to the plots.\n", "plt.rc('figure', figsize=(10, 8)) # Set the default figure size of plots.\n", "%matplotlib inline\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore') # To suppress all the warnings in the notebook.\n", "\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import r2_score" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 3. Loading Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " crim zn indus chas nox rm age dis rad tax \\\n", "0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296.0 \n", "1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242.0 \n", "2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242.0 \n", "3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222.0 \n", "4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222.0 \n", "\n", " ptratio black lstat medv \n", "0 15.3 396.90 4.98 24.0 \n", "1 17.8 396.90 9.14 21.6 \n", "2 17.8 392.83 4.03 34.7 \n", "3 18.7 394.63 2.94 33.4 \n", "4 18.7 396.90 5.33 36.2 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv('../../data/Boston.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 3.1 Description of the Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- This dataset contains information on **Housing Values in Suburbs of Boston**.\n", "\n", "\n", "- The column **medv** is the **target variable**. It is the **median** value of **owner-occupied homes in $1000s**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "| Column Name | Description |\n", "| ---------------------------------|:----------------------------------------------------------------------------------------:| \n", "| crim | Per capita crime rate by town. |\n", "| zn | Proportion of residential land zoned for lots over 25,000 sq.ft. |\n", "| indus | Proportion of non-retail business acres per town. |\n", "| chas | Charles River dummy variable (= 1 if tract bounds river; 0 otherwise). |\n", "| nox | Nitrogen oxides concentration (parts per 10 million). |\n", "| rm | Average number of rooms per dwelling. |\n", "| age | Proportion of owner-occupied units built prior to 1940. |\n", "| dis | Weighted mean of distances to five Boston employment centres. |\n", "| rad | Index of accessibility to radial highways. |\n", "| tax | Full-value property-tax rate per 10,000 dollars. 
|\n", "| ptratio | Pupil-teacher ratio by town. |\n", "| black | 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town. |\n", "| lstat | Lower status of the population (percent). |\n", "| medv | Target, median value of owner-occupied homes in $1000s. |" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "RangeIndex: 506 entries, 0 to 505\n", "Data columns (total 14 columns):\n", "crim 506 non-null float64\n", "zn 506 non-null float64\n", "indus 506 non-null float64\n", "chas 506 non-null int64\n", "nox 506 non-null float64\n", "rm 506 non-null float64\n", "age 506 non-null float64\n", "dis 506 non-null float64\n", "rad 506 non-null int64\n", "tax 506 non-null float64\n", "ptratio 506 non-null float64\n", "black 506 non-null float64\n", "lstat 506 non-null float64\n", "medv 506 non-null float64\n", "dtypes: float64(12), int64(2)\n", "memory usage: 55.5 KB\n" ] } ], "source": [ "df.info()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " crim zn indus chas nox rm \\\n", "count 506.000000 506.000000 506.000000 506.000000 506.000000 506.000000 \n", "mean 3.613524 11.363636 11.136779 0.069170 0.554695 6.284634 \n", "std 8.601545 23.322453 6.860353 0.253994 0.115878 0.702617 \n", "min 0.006320 0.000000 0.460000 0.000000 0.385000 3.561000 \n", "25% 0.082045 0.000000 5.190000 0.000000 0.449000 5.885500 \n", "50% 0.256510 0.000000 9.690000 0.000000 0.538000 6.208500 \n", "75% 3.677082 12.500000 18.100000 0.000000 0.624000 6.623500 \n", "max 88.976200 100.000000 27.740000 1.000000 0.871000 8.780000 \n", "\n", " age dis rad tax ptratio black \\\n", "count 506.000000 506.000000 506.000000 506.000000 506.000000 506.000000 \n", "mean 68.574901 3.795043 9.549407 408.237154 18.455534 356.674032 \n", "std 28.148861 2.105710 8.707259 168.537116 2.164946 91.294864 \n", "min 2.900000 1.129600 1.000000 187.000000 12.600000 0.320000 \n", "25% 45.025000 2.100175 4.000000 279.000000 17.400000 375.377500 \n", "50% 77.500000 3.207450 5.000000 330.000000 19.050000 391.440000 \n", "75% 94.075000 5.188425 24.000000 666.000000 20.200000 396.225000 \n", "max 100.000000 12.126500 24.000000 711.000000 22.000000 396.900000 \n", "\n", " lstat medv \n", "count 506.000000 506.000000 \n", "mean 12.653063 22.532806 \n", "std 7.141062 9.197104 \n", "min 1.730000 5.000000 \n", "25% 6.950000 17.025000 \n", "50% 11.360000 21.200000 \n", "75% 16.955000 25.000000 \n", "max 37.970000 50.000000 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 4. Data train/test split\n", "\n", "- Now that the entire **data** is of **numeric datatype**, lets begin our modelling process.\n", "\n", "\n", "- Firstly, **splitting** the complete **dataset** into **training** and **testing** datasets." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " crim zn indus chas nox rm age dis rad tax \\\n", "0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296.0 \n", "1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242.0 \n", "2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242.0 \n", "3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222.0 \n", "4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222.0 \n", "\n", " ptratio black lstat medv \n", "0 15.3 396.90 4.98 24.0 \n", "1 17.8 396.90 9.14 21.6 \n", "2 17.8 392.83 4.03 34.7 \n", "3 18.7 394.63 2.94 33.4 \n", "4 18.7 396.90 5.33 36.2 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " crim zn indus chas nox rm age dis rad tax \\\n", "0 0.00632 18.0 2.31 0 0.538 6.575 65.2 4.0900 1 296.0 \n", "1 0.02731 0.0 7.07 0 0.469 6.421 78.9 4.9671 2 242.0 \n", "2 0.02729 0.0 7.07 0 0.469 7.185 61.1 4.9671 2 242.0 \n", "3 0.03237 0.0 2.18 0 0.458 6.998 45.8 6.0622 3 222.0 \n", "4 0.06905 0.0 2.18 0 0.458 7.147 54.2 6.0622 3 222.0 \n", "\n", " ptratio black lstat \n", "0 15.3 396.90 4.98 \n", "1 17.8 396.90 9.14 \n", "2 17.8 392.83 4.03 \n", "3 18.7 394.63 2.94 \n", "4 18.7 396.90 5.33 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = df.iloc[:, :-1]\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 24.0\n", "1 21.6\n", "2 34.7\n", "3 33.4\n", "4 36.2\n", "Name: medv, dtype: float64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = df.iloc[:, -1]\n", "y.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Using scikit-learn's train_test_split function to split the dataset into train and test sets.\n", "# 80% of the data will be in the train set and 20% in the test set, as specified by test_size=0.2\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(404, 13)\n", "(404,)\n", "(102, 13)\n", "(102,)\n" ] } ], "source": [ "# Checking the shapes of all the training and test sets for the dependent and independent features.\n", "\n", "print(X_train.shape)\n", "print(y_train.shape)\n", "print(X_test.shape)\n", "print(y_test.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 5. 
Random Forest Model\n", "\n", "\n", "### 5.1 Random Forest with Scikit-Learn" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Creating a Random Forest Regressor.\n", "\n", "regressor_rf = RandomForestRegressor(n_estimators=200, random_state=0, oob_score=True, n_jobs=-1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,\n", " max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=200, n_jobs=-1,\n", " oob_score=True, random_state=0, verbose=0,\n", " warm_start=False)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fitting the model on the training set.\n", "\n", "regressor_rf.fit(X_train, y_train)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.8375610635726134" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "regressor_rf.oob_score_" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The **OOB score** (the R-squared value computed on the out-of-bag samples) gives an estimate of how our model is likely to perform on the **test set or new** samples." 
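, "

- Why out-of-bag samples exist at all can be sketched without scikit-learn. Assuming each tree is trained on a bootstrap sample of the 404 training rows (drawn with replacement), roughly a third of the rows are left out of any given tree, and those left-out rows act as a built-in validation set for that tree. The helper name `oob_fraction` is hypothetical, and the simulation is a plain-Python illustration rather than the library's internals.

```python
import random


def oob_fraction(n_rows, seed=0):
    # Hypothetical helper: draw one bootstrap sample (n_rows draws with
    # replacement) and return the share of rows that were never drawn.
    rng = random.Random(seed)
    drawn = {rng.randrange(n_rows) for _ in range(n_rows)}
    return 1 - len(drawn) / n_rows


n_rows = 404  # size of the training set in this notebook

# Probability that a given row is missed by all n_rows draws of one tree.
expected = (1 - 1 / n_rows) ** n_rows

# Average left-out share over 200 simulated bootstrap samples.
simulated = sum(oob_fraction(n_rows, seed=s) for s in range(200)) / 200

print(round(expected, 4), round(simulated, 4))
```

- Both numbers should land close to 0.37, which is why every row has many trees that never saw it and can be scored on those trees alone.
"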
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 5.2 Feature Importances" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Index(['crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age', 'dis', 'rad', 'tax',\n", " 'ptratio', 'black', 'lstat'],\n", " dtype='object')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_train.columns" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Feature importance of rm : 48.652000069343806 %\n", "Feature importance of lstat : 32.62519467239989 %\n", "Feature importance of dis : 5.723002765414628 %\n", "Feature importance of crim : 3.688078587412392 %\n", "Feature importance of nox : 1.7565510857613313 %\n", "Feature importance of ptratio : 1.7390709979839825 %\n", "Feature importance of tax : 1.662967097458929 %\n", "Feature importance of age : 1.470551943168394 %\n", "Feature importance of black : 1.3621235473729538 %\n", "Feature importance of indus : 0.6664660162044432 %\n", "Feature importance of rad : 0.3849848939558503 %\n", "Feature importance of zn : 0.1498360991801689 %\n", "Feature importance of chas : 0.11917222434323611 %\n" ] } ], "source": [ "# Checking the feature importances of various features.\n", "# Sorting the importances by descending order (lowest importance at the bottom).\n", "\n", "for score, name in sorted(zip(regressor_rf.feature_importances_, X_train.columns), reverse=True):\n", " print('Feature importance of', name, ':', score*100, '%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 5.3 Using the Model for Prediction" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([12.511 , 
19.9775, 19.77 , 13.258 , 18.5805, 24.453 , 20.89 ,\n", " 23.869 , 8.595 , 23.5375])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Making predictions on the training set.\n", "\n", "y_pred_train = regressor_rf.predict(X_train)\n", "y_pred_train[:10]" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([22.57 , 31.5625, 17.109 , 23.243 , 16.7635, 21.303 , 19.24 ,\n", " 15.6195, 21.429 , 20.964 ])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Making predictions on the test set.\n", "\n", "y_pred_test = regressor_rf.predict(X_test)\n", "y_pred_test[:10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 6. Model Evaluation\n", "\n", "**Error** is the deviation of the values predicted by the model from the true values." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 6.1 R-Squared Value" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R-Squared Value for train data is: 0.9774359941752926\n" ] } ], "source": [ "# R-Squared Value on the training set.\n", "\n", "print('R-Squared Value for train data is:', r2_score(y_train, y_pred_train))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "R-Squared Value for test data is: 0.8798665934496468\n" ] } ], "source": [ "# R-Squared Value on the test set.\n", "\n", "print('R-Squared Value for test data is:', r2_score(y_test, y_pred_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- We get an **R-Squared Value** of about **0.977** on our train set and an **R-Squared Value** of about **0.880** on our test set; the lower test value shows that the model fits the training data more closely than unseen data." 
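, "

- For readers who want to see what the R-squared value measures, here is a minimal sketch that computes it from its definition, one minus the ratio of the residual sum of squares to the total sum of squares. The function name `r_squared` is hypothetical and the predictions are made-up toy numbers; the notebook itself relies on `r2_score` from scikit-learn.

```python
def r_squared(y_true, y_pred):
    # R^2 = 1 - SS_res / SS_tot, where SS_tot is measured around the mean of y_true.
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


# First five medv values from the data, paired with made-up predictions.
y_true = [24.0, 21.6, 34.7, 33.4, 36.2]
y_pred = [25.0, 22.0, 33.0, 32.0, 35.0]

print(r_squared(y_true, y_pred))

# A model that always predicts the mean of y_true scores 0 by construction.
mean_only = [sum(y_true) / len(y_true)] * len(y_true)
print(r_squared(y_true, mean_only))
```

- A value of 1 means perfect predictions, while 0 means the model does no better than always predicting the mean of the target.
"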
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 7. Model Interpretability using LIME\n", "\n", "\n", "- **LIME**, which stands for **Local Interpretable Model-Agnostic Explanations**, is a technique to **explain the predictions of any machine learning classifier or regressor**; its **usefulness** has been evaluated on various **tasks** related to **trust**. " ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 8], dtype=int64)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Treating features with at most 10 unique values as categorical and selecting their column indexes.\n", "\n", "categorical_features = np.argwhere(np.array([len(set(X_train.values[:, x])) for x in range(X_train.shape[1])]) <= 10).flatten()\n", "categorical_features" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 7.1 Setup LIME Algorithm" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "from lime.lime_tabular import LimeTabularExplainer" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# Creating the LIME explainer object.\n", "\n", "explainer = LimeTabularExplainer(X_train.values, mode='regression', feature_names=X_train.columns, \n", " categorical_features=categorical_features, verbose=True, random_state=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 7.2 Explore Key Features in Instance-by-Instance Predictions\n", "\n", "\n", "- **Start by choosing an instance** from the **test dataset**.\n", "\n", "\n", "- Use **LIME** to **estimate a local model** to use for **explaining our model's predictions**. The **outputs** will be:\n", "\n", " 1. The **intercept** estimated for the local model.\n", " 2. 
The **local model's estimate** for the **Regression Forest's prediction**.\n", " 3. The **Regression Forest's actual prediction**.\n", "\n", "\n", "- Note, that the **actual value from the data does not enter into this** - the **idea of LIME** is to **gain insight** into **why the chosen model** - in our case the Random Forest regressor - **is predicting whatever it has been asked to predict**. Whether or not this prediction is actually any good, is a separate issue." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i = 34\n" ] } ], "source": [ "# Selecting a random instance from the test dataset.\n", "\n", "i = np.random.randint(0, X_test.shape[0])\n", "print('i =', i)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept 26.667118750528495\n", "Prediction_local [16.6593]\n", "Right: 14.596999999999998\n" ] } ], "source": [ "# Using LIME to estimate a local model. Using only 6 features to explain our model's predictions.\n", "\n", "exp = explainer.explain_instance(X_test.values[i], regressor_rf.predict, num_features=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Printing** the **DataFrame row** for the **chosen test instance**." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " index crim zn indus chas nox rm age dis rad tax \\\n", "34 467 4.42228 0.0 18.1 0 0.584 6.003 94.5 2.5403 24 666.0 \n", "\n", " ptratio black lstat \n", "34 20.2 331.29 21.32 " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Here the index column is the original index as per the df dataframe and the number at the beginning the index after reset.\n", "\n", "X_test.reset_index().loc[[i]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **LIME's interpretation** of our **Random Forest's prediction**." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(show_table=True, show_all=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- First, note that the **row** we **explained** is **displayed** on the **right side**, in **table** format. Since we had the **show_all parameter** set to **False**, only the **features used in the explanation are displayed**.\n", "\n", "\n", "- The **value column** displays the **original value for each feature**." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The **output generated above** can also be obtained in the **form of a list**." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('lstat > 16.37', -5.335049207889584),\n", " ('5.89 < rm <= 6.21', -3.3708905664632702),\n", " ('crim > 3.20', -0.9545632488848613),\n", " ('330.00 < tax <= 666.00', -0.4234910180513856),\n", " ('rad=24', 0.38388434357383366),\n", " ('18.70 < ptratio <= 20.20', -0.30774577277796716)]" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.as_list()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Observations obtained from LIME's interpretation of our Random Forest's prediction**:\n", "\n", "- The **value** shown after each condition is the **amount** by which that condition **shifts** the prediction away from the **intercept** estimated for the local model. \n", " \n", " \n", "- When all these values are **added** to the **intercept**, the result is the **Prediction_local** (the local model's estimate for the Regression Forest's prediction) calculated by **LIME**." 
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept = 26.667118750528495\n", "Prediction_local = 16.65926328003526\n" ] } ], "source": [ "print('Intercept =', exp.intercept[0])\n", "print('Prediction_local =', exp.local_pred[0])" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction_local = 16.659263280035262\n" ] } ], "source": [ "# Calculating the Prediction_local by adding all the values obtained above for each condition to the intercept.\n", "# The intercept can be obtained from the exp.intercept using the index 0.\n", "\n", "intercept = exp.intercept[0]\n", "prediction_local = intercept\n", "\n", "# Unpacking each (condition, weight) pair avoids overwriting the instance index i.\n", "for _, weight in exp.as_list():\n", " prediction_local += weight\n", "\n", "print('Prediction_local =', prediction_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", " \n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Choosing **another instance** from the **test dataset**." ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "i = 93\n", "Intercept 25.658304604538877\n", "Prediction_local [19.7191]\n", "Right: 18.784999999999997\n" ] } ], "source": [ "i = np.random.randint(0, X_test.shape[0])\n", "print('i =', i)\n", "\n", "exp = explainer.explain_instance(X_test.values[i], regressor_rf.predict, num_features=6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Printing** the **DataFrame row** for the **chosen test instance**." ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " index crim zn indus chas nox rm age dis rad tax \\\n", "93 346 0.06162 0.0 4.39 0 0.442 5.898 52.3 8.0136 3 352.0 \n", "\n", " ptratio black lstat \n", "93 18.8 364.61 12.67 " ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X_test.reset_index().loc[[i]]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **LIME's interpretation** of our **Random Forest's prediction**." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", "
\n", " \n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exp.show_in_notebook(show_table=True, show_all=False)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[('5.89 < rm <= 6.21', -3.3739831077720055),\n", " ('10.93 < lstat <= 16.37', -1.2396077198032527),\n", " ('18.70 < ptratio <= 20.20', -0.8122143970518639),\n", " ('black <= 375.47', -0.48573831431424713),\n", " ('330.00 < tax <= 666.00', -0.4728522138910121),\n", " ('zn <= 0.00', 0.44516993577257424)]" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "exp.as_list()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Intercept = 25.658304604538877\n", "Prediction_local = 19.71907878747907\n" ] } ], "source": [ "print('Intercept =', exp.intercept[0])\n", "print('Prediction_local =', exp.local_pred[0])" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Prediction_local = 19.71907878747907\n" ] } ], "source": [ "intercept = exp.intercept[0]\n", "prediction_local = intercept\n", "\n", "# Unpacking each (condition, weight) pair avoids overwriting the instance index i.\n", "for _, weight in exp.as_list():\n", " prediction_local += weight\n", "\n", "print('Prediction_local =', prediction_local)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- By **changing** the chosen **i**, we observe that the **narrative provided by LIME** also **changes, reflecting how the model behaves** in the **local region** of the **feature space** around the instance for which it **generates a given prediction**. \n", "\n", "\n", "- This is clearly an **improvement on relying purely** on the **Regression Forest's (static) expected relative feature importance** and is of **great benefit for models that provide no insight whatsoever**." 
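, "

- A local explanation is only as trustworthy as the surrogate's fit at that point, so it can help to compare the two numbers LIME prints for each run: `Prediction_local` (the surrogate's estimate) and `Right` (the Random Forest's own prediction). The sketch below does this for the two instances explained in this notebook, using values copied from the printed outputs; the variable names are hypothetical and the absolute gap is just one simple, informal way to eyeball fidelity.

```python
# Values copied from the two LIME runs in this notebook:
# (instance i, local surrogate prediction, Random Forest prediction)
runs = [
    (34, 16.65926328003526, 14.596999999999998),
    (93, 19.71907878747907, 18.784999999999997),
]

gaps = {}
for instance, local_pred, forest_pred in runs:
    # Absolute difference between the surrogate and the model it explains,
    # in the units of medv (thousands of dollars).
    gaps[instance] = abs(local_pred - forest_pred)
    print(instance, round(gaps[instance], 4))
```

- Here the surrogate is off by roughly 2.06 for instance 34 and roughly 0.93 for instance 93, so the second explanation tracks the forest's prediction more closely than the first.
"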
] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "307.2px" }, "toc_section_display": true, "toc_window_display": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 1 }