{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lesson 7 - Model interpretability\n",
"\n",
"> How to interpret the predictions from Random Forest models and use these insights to prune the feature space."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lewtun/dslectures/master?urlpath=lab/tree/notebooks%2Flesson07_model-interpretation.ipynb) \n",
"[![slides](https://img.shields.io/static/v1?label=slides&message=2021-lesson07.pdf&color=blue&logo=Google-drive)](https://drive.google.com/open?id=1Tib76a0leS8xhuAlni2WUhqlcITGEf38)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning objectives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* Understand how to interpret feature importance plots for Random Forest models.\n",
"* Know how to drop uninformative features to build simpler models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lesson is adapted (with permission) from Jeremy Howard's fantastic online course [_Introduction to Machine Learning for Coders_](https://course18.fast.ai/ml), in particular:\n",
"\n",
"* [3 — Performance, validation and model interpretation](https://course18.fast.ai/lessonsml1/lesson3.html)\n",
"\n",
"Below are a few relevant articles that may be of general interest:\n",
"\n",
"* [Explaining Feature Importance by example of a Random Forest](https://towardsdatascience.com/explaining-feature-importance-by-example-of-a-random-forest-d9166011959e)\n",
"* [Beware Default Random Forest Importances](https://explained.ai/rf-importance/index.html)\n",
"* [Explainable AI won’t deliver. Here’s why.](https://hackernoon.com/explainable-ai-wont-deliver-here-s-why-6738f54216be)\n",
"* [Confidence Intervals](https://dfrieds.com/math/confidence-intervals.html)\n",
"* [Reading and Writing Files in Python (Guide)](https://realpython.com/read-write-files-python/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Homework\n",
"\n",
"* Solve the exercises included in this notebook"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this lesson we will analyse the preprocessed table of clean housing data and their addresses that we prepared in lesson 3:\n",
"\n",
"* `housing_processed.csv`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What is model interpretability?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
\n",
"\n",
"
Figure reference: https://bit.ly/3djjWc6
\n",
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A nice explanation for what it means to interpret a model's predictions is given in the _Beware Default Random Forest Importances_ article:\n",
"\n",
"> Training a model that accurately predicts outcomes is great, but most of the time you don't just need predictions, you want to be able to interpret your model. For example, if you build a model of house prices, knowing which features are most predictive of price tells us which features people are willing to pay for.\n",
"\n",
"In this lesson we will focus on one specific aspect of interpretability for Random Forests, namely _feature importance_ which is a technique that (with care) can be used to identify the most informative features in a dataset."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reload modules before executing user code\n",
"%load_ext autoreload\n",
"# reload all modules every time before executing Python code\n",
"%autoreload 2\n",
"# render plots in notebook\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncomment to update the library if working locally\n",
"# !pip install dslectures --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# data wrangling\n",
"import pandas as pd\n",
"import numpy as np\n",
"from dslectures.core import get_dataset, convert_strings_to_categories, rmse, fill_missing_values_with_median\n",
"from pathlib import Path\n",
"\n",
"# data viz\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"sns.set(color_codes=True)\n",
"sns.set_palette(sns.color_palette(\"muted\"))\n",
"\n",
"# ml magic\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.metrics import r2_score\n",
"import scipy\n",
"from scipy.cluster import hierarchy as hc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset already exists at '../data/housing_processed.csv' and is not downloaded again.\n"
]
}
],
"source": [
"get_dataset('housing_processed.csv')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also make use of the `pathlib` library to handle our filepaths:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"autos.csv housing_gmaps_data_raw.csv\n",
"\u001b[1m\u001b[36mcar-challenge\u001b[m\u001b[m housing_merged.csv\n",
"\u001b[1m\u001b[36mcats_vs_dogs\u001b[m\u001b[m housing_processed.csv\n",
"churn.csv imdb.csv\n",
"fine_tuned.pth \u001b[1m\u001b[36mspotify\u001b[m\u001b[m\n",
"housing.csv word2vec-google-news-300.pkl\n",
"housing_addresses.csv\n"
]
}
],
"source": [
"DATA = Path('../data/')\n",
"!ls {DATA}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
longitude
\n",
"
latitude
\n",
"
housing_median_age
\n",
"
total_rooms
\n",
"
total_bedrooms
\n",
"
population
\n",
"
households
\n",
"
median_income
\n",
"
median_house_value
\n",
"
city
\n",
"
postal_code
\n",
"
rooms_per_household
\n",
"
bedrooms_per_household
\n",
"
bedrooms_per_room
\n",
"
population_per_household
\n",
"
ocean_proximity_INLAND
\n",
"
ocean_proximity_<1H OCEAN
\n",
"
ocean_proximity_NEAR BAY
\n",
"
ocean_proximity_NEAR OCEAN
\n",
"
ocean_proximity_ISLAND
\n",
"
\n",
" \n",
" \n",
"
\n",
"
0
\n",
"
-122.23
\n",
"
37.88
\n",
"
41.0
\n",
"
880.0
\n",
"
129.0
\n",
"
322.0
\n",
"
126.0
\n",
"
8.3252
\n",
"
452600.0
\n",
"
69
\n",
"
94705
\n",
"
6.984127
\n",
"
1.023810
\n",
"
0.146591
\n",
"
2.555556
\n",
"
0
\n",
"
0
\n",
"
1
\n",
"
0
\n",
"
0
\n",
"
\n",
"
\n",
"
1
\n",
"
-122.22
\n",
"
37.86
\n",
"
21.0
\n",
"
7099.0
\n",
"
1106.0
\n",
"
2401.0
\n",
"
1138.0
\n",
"
8.3014
\n",
"
358500.0
\n",
"
620
\n",
"
94611
\n",
"
6.238137
\n",
"
0.971880
\n",
"
0.155797
\n",
"
2.109842
\n",
"
0
\n",
"
0
\n",
"
1
\n",
"
0
\n",
"
0
\n",
"
\n",
"
\n",
"
2
\n",
"
-122.24
\n",
"
37.85
\n",
"
52.0
\n",
"
1467.0
\n",
"
190.0
\n",
"
496.0
\n",
"
177.0
\n",
"
7.2574
\n",
"
352100.0
\n",
"
620
\n",
"
94618
\n",
"
8.288136
\n",
"
1.073446
\n",
"
0.129516
\n",
"
2.802260
\n",
"
0
\n",
"
0
\n",
"
1
\n",
"
0
\n",
"
0
\n",
"
\n",
"
\n",
"
3
\n",
"
-122.25
\n",
"
37.85
\n",
"
52.0
\n",
"
1274.0
\n",
"
235.0
\n",
"
558.0
\n",
"
219.0
\n",
"
5.6431
\n",
"
341300.0
\n",
"
620
\n",
"
94618
\n",
"
5.817352
\n",
"
1.073059
\n",
"
0.184458
\n",
"
2.547945
\n",
"
0
\n",
"
0
\n",
"
1
\n",
"
0
\n",
"
0
\n",
"
\n",
"
\n",
"
4
\n",
"
-122.25
\n",
"
37.85
\n",
"
52.0
\n",
"
1627.0
\n",
"
280.0
\n",
"
565.0
\n",
"
259.0
\n",
"
3.8462
\n",
"
342200.0
\n",
"
620
\n",
"
94618
\n",
"
6.281853
\n",
"
1.081081
\n",
"
0.172096
\n",
"
2.181467
\n",
"
0
\n",
"
0
\n",
"
1
\n",
"
0
\n",
"
0
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" longitude latitude housing_median_age total_rooms total_bedrooms \\\n",
"0 -122.23 37.88 41.0 880.0 129.0 \n",
"1 -122.22 37.86 21.0 7099.0 1106.0 \n",
"2 -122.24 37.85 52.0 1467.0 190.0 \n",
"3 -122.25 37.85 52.0 1274.0 235.0 \n",
"4 -122.25 37.85 52.0 1627.0 280.0 \n",
"\n",
" population households median_income median_house_value city \\\n",
"0 322.0 126.0 8.3252 452600.0 69 \n",
"1 2401.0 1138.0 8.3014 358500.0 620 \n",
"2 496.0 177.0 7.2574 352100.0 620 \n",
"3 558.0 219.0 5.6431 341300.0 620 \n",
"4 565.0 259.0 3.8462 342200.0 620 \n",
"\n",
" postal_code rooms_per_household bedrooms_per_household \\\n",
"0 94705 6.984127 1.023810 \n",
"1 94611 6.238137 0.971880 \n",
"2 94618 8.288136 1.073446 \n",
"3 94618 5.817352 1.073059 \n",
"4 94618 6.281853 1.081081 \n",
"\n",
" bedrooms_per_room population_per_household ocean_proximity_INLAND \\\n",
"0 0.146591 2.555556 0 \n",
"1 0.155797 2.109842 0 \n",
"2 0.129516 2.802260 0 \n",
"3 0.184458 2.547945 0 \n",
"4 0.172096 2.181467 0 \n",
"\n",
" ocean_proximity_<1H OCEAN ocean_proximity_NEAR BAY \\\n",
"0 0 1 \n",
"1 0 1 \n",
"2 0 1 \n",
"3 0 1 \n",
"4 0 1 \n",
"\n",
" ocean_proximity_NEAR OCEAN ocean_proximity_ISLAND \n",
"0 0 0 \n",
"1 0 0 \n",
"2 0 0 \n",
"3 0 0 \n",
"4 0 0 "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"housing_data = pd.read_csv(DATA/'housing_processed.csv'); housing_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the data loaded, we can recreate our train/validation splits as before:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X = housing_data.drop('median_house_value', axis=1)\n",
"y = housing_data['median_house_value']"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15554 train rows + 3889 valid rows\n"
]
}
],
"source": [
"X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"print(f'{len(X_train)} train rows + {len(X_valid)} valid rows')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To simplify the evaluation of our models, we'll reuse our scoring function from lesson 5:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def print_rf_scores(fitted_model):\n",
" \"\"\"Generates RMSE and R^2 scores from fitted Random Forest model.\"\"\"\n",
"\n",
" yhat_train = fitted_model.predict(X_train)\n",
" R2_train = fitted_model.score(X_train, y_train)\n",
" yhat_valid = fitted_model.predict(X_valid)\n",
" R2_valid = fitted_model.score(X_valid, y_valid)\n",
"\n",
" scores = {\n",
" \"RMSE on train:\": rmse(y_train, yhat_train),\n",
" \"R^2 on train:\": R2_train,\n",
" \"RMSE on valid:\": rmse(y_valid, yhat_valid),\n",
" \"R^2 on valid:\": R2_valid,\n",
" }\n",
" if hasattr(fitted_model, \"oob_score_\"):\n",
" scores[\"OOB R^2:\"] = fitted_model.oob_score_\n",
"\n",
" for score_name, score_value in scores.items():\n",
" print(score_name, round(score_value, 3))"
]
},
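{
"cell_type": "markdown",
"metadata": {},
"source": [
"`rmse` is imported from the course's `dslectures` library. Assuming it computes the standard root mean squared error, the two metrics printed by `print_rf_scores` can be sketched in plain NumPy as follows. The names `rmse_sketch` and `r2_sketch` are hypothetical, and `fitted_model.score` returns the R^2 for scikit-learn regressors.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def rmse_sketch(y_true, y_pred):\n",
"    \"\"\"Root mean squared error (assumed to match dslectures' rmse).\"\"\"\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    return np.sqrt(np.mean((y_true - y_pred) ** 2))\n",
"\n",
"def r2_sketch(y_true, y_pred):\n",
"    \"\"\"Coefficient of determination: 1 - SS_res / SS_tot.\"\"\"\n",
"    y_true = np.asarray(y_true, dtype=float)\n",
"    y_pred = np.asarray(y_pred, dtype=float)\n",
"    ss_res = np.sum((y_true - y_pred) ** 2)\n",
"    ss_tot = np.sum((y_true - y_true.mean()) ** 2)\n",
"    return 1 - ss_res / ss_tot\n",
"\n",
"print(rmse_sketch([3, 5], [1, 5]))      # sqrt((4 + 0) / 2), about 1.414\n",
"print(r2_sketch([1, 2, 3], [1, 2, 4]))  # 1 - 1 / 2 = 0.5\n",
"```"
]
},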
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Confidence intervals"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Recall that to make predictions with our Random Forest models, we take the _average value_ in each leaf node as we pass each row in the validation set through the tree. However, we would also like to estimate our _**confidence**_ in these predictions - how can we achieve this?\n",
"\n",
"One way to do this is to calculate the _**standard deviation of the predictions**_ of the trees. Conceptually, the idea is that if the standard deviation is high, each tree is generating very different predictions and may indicate the model has not learnt the most important features of the data.\n",
"\n",
"To get started, let's use our baseline model from the previous lesson:"
]
},
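{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of this idea on synthetic data. It assumes only NumPy and scikit-learn, and the names `X_demo`, `y_demo`, `rf` and `tree_preds` are hypothetical and unrelated to the housing data. It checks that the forest's prediction equals the average of the individual tree predictions. It then uses the spread of those predictions as a per-row measure of disagreement between trees.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"# hypothetical synthetic regression data (not the housing dataset)\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 10, size=(200, 3))\n",
"y_demo = 3 * X_demo[:, 0] + rng.normal(0, 1, size=200)\n",
"\n",
"rf = RandomForestRegressor(n_estimators=20, random_state=0)\n",
"rf.fit(X_demo, y_demo)\n",
"\n",
"# one row of predictions per tree: shape (n_trees, n_rows)\n",
"tree_preds = np.stack([tree.predict(X_demo) for tree in rf.estimators_])\n",
"\n",
"# average and spread across the trees, i.e. along axis 0\n",
"preds_mean = tree_preds.mean(axis=0)\n",
"preds_std = tree_preds.std(axis=0)\n",
"\n",
"# for regression, the forest prediction is the mean of the tree predictions\n",
"print(np.allclose(preds_mean, rf.predict(X_demo)))  # True\n",
"```"
]
},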
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE on train: 16520.37\n",
"R^2 on train: 0.971\n",
"RMSE on valid: 42727.043\n",
"R^2 on valid: 0.81\n",
"OOB R^2: 0.791\n"
]
}
],
"source": [
"model = RandomForestRegressor(n_estimators=40, max_features='sqrt', n_jobs=-1, oob_score=True, random_state=42)\n",
"model.fit(X_train, y_train)\n",
"print_rf_scores(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we concatenate all the predictions from each tree into a single array:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(137125.0, 41845.23120978064)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds = np.stack([t.predict(X_valid) for t in model.estimators_])\n",
"# calculate mean and standard deviation for single observation\n",
"np.mean(preds[:,0]), np.std(preds[:,0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(40, 3889)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's create a copy of the validation dataset and add the mean predictions and their standard deviation as new columns."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"\n",
"#### Exercise #1\n",
"\n",
"* Combine `X_valid` and `y_valid` into a single array called `valid_copy`. You may find the `DataFrame.join(DataFrame)` method from pandas useful here.\n",
"* Create two new columns `preds_mean` and `preds_std` that are the mean and standard deviation of the predictions for each tree in `preds`\n",
"\n",
"---"
]
},
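{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a hint for the second part of the exercise, here is a toy illustration of computing statistics _across trees_ with NumPy. The small `toy_preds` array is made up. Each row plays the role of one tree and each column one housing district, mirroring the layout of `preds`. Note that `np.std` uses `ddof=0` by default, whereas the pandas `Series.std` method uses `ddof=1`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# made-up predictions: 3 trees (rows) x 4 districts (columns)\n",
"toy_preds = np.array([\n",
"    [100.0, 200.0, 300.0, 400.0],\n",
"    [110.0, 190.0, 330.0, 400.0],\n",
"    [ 90.0, 210.0, 270.0, 400.0],\n",
"])\n",
"\n",
"# axis=0 aggregates over the trees, giving one value per district\n",
"toy_mean = toy_preds.mean(axis=0)  # [100. 200. 300. 400.]\n",
"toy_std = toy_preds.std(axis=0)    # approx. [8.16 8.16 24.49 0.]\n",
"print(toy_mean, toy_std)\n",
"```"
]
},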
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use these new columns to drill-down into the predictions of each individual, categorical feature. Let's examine `ocean_proximity_INLAND` which denotes whether a housing district is inland or not:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.countplot(y='ocean_proximity_INLAND', data=valid_copy);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can calculate the predictions and standard deviation per category by applying a _**group by**_ operation in pandas, followed by taking the mean."
]
},
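{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same group-by pattern on a tiny, made-up DataFrame shows what `groupby(..., as_index=False).mean()` returns. The values are illustrative, not from the housing data. The result has one row per category, with each remaining column averaged within that category.\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# hypothetical toy data with the same column names as valid_copy\n",
"toy = pd.DataFrame({\n",
"    'ocean_proximity_INLAND': [0, 0, 1, 1],\n",
"    'median_house_value': [200.0, 300.0, 100.0, 120.0],\n",
"    'preds_mean': [210.0, 290.0, 105.0, 125.0],\n",
"    'preds_std': [50.0, 60.0, 30.0, 40.0],\n",
"})\n",
"\n",
"# one row per category; every other column is averaged within its group\n",
"summary = toy.groupby('ocean_proximity_INLAND', as_index=False).mean()\n",
"print(summary)\n",
"```"
]
},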
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
ocean_proximity_INLAND
\n",
"
median_house_value
\n",
"
preds_mean
\n",
"
preds_std
\n",
"
\n",
" \n",
" \n",
"
\n",
"
0
\n",
"
0
\n",
"
227863.194178
\n",
"
227319.699349
\n",
"
51913.127856
\n",
"
\n",
"
\n",
"
1
\n",
"
1
\n",
"
123654.224570
\n",
"
124241.488165
\n",
"
35858.872619
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" ocean_proximity_INLAND median_house_value preds_mean preds_std\n",
"0 0 227863.194178 227319.699349 51913.127856\n",
"1 1 123654.224570 124241.488165 35858.872619"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cols = ['ocean_proximity_INLAND', 'median_house_value', 'preds_mean', 'preds_std']\n",
"preds_quality = valid_copy[cols].groupby('ocean_proximity_INLAND', as_index=False).mean()\n",
"preds_quality"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the table, we can see that the predictions and ground truth are close to each other on average, while the standard deviation varies somewhat for each category. We can visualise this table in terms of bar plots as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# prepare figure \n",
"fig, (ax0, ax1) = plt.subplots(nrows=2, ncols=1, figsize=(12,7), sharex=True)\n",
"\n",
"# plot ground truth\n",
"preds_quality.plot('ocean_proximity_INLAND', 'median_house_value', 'barh', ax=ax0)\n",
"# put legend outside plot\n",
"ax0.legend(loc='upper left', bbox_to_anchor=(1.0, 0.5))\n",
"\n",
"# plot preds\n",
"preds_quality.plot('ocean_proximity_INLAND', 'preds_mean', 'barh', xerr='preds_std', alpha=0.6, ax=ax1)\n",
"# put legend outside plot\n",
"ax1.legend(loc='upper left', bbox_to_anchor=(1.0, 0.5))\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The error bars in the plot indicate how confident the model is at predicting each category. Alternatively, we can compare the _distribution_ of values to inspect how close the predictions match the ground truth. For example, for housing districts where `ocean_proximity_INLAND` is 0 we have:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/leandro/git/dslectures/env/lib/python3.7/site-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).\n",
" warnings.warn(msg, FutureWarning)\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample = valid_copy.copy().loc[valid_copy[\"ocean_proximity_INLAND\"] == 0]\n",
"sample_mean = sample['preds_mean'].mean()\n",
"sample_std = sample['preds_mean'].std()\n",
"lower_bound = sample_mean - sample_std\n",
"upper_bound = sample_mean + sample_std\n",
"\n",
"sns.distplot(\n",
" sample[\"median_house_value\"], kde=False,\n",
")\n",
"sns.distplot(sample[\"preds_mean\"], kde=False)\n",
"plt.axvline(\n",
" x=sample_mean,\n",
" linestyle=\"--\",\n",
" linewidth=2.5,\n",
" label=\"Mean of predictions\",\n",
" c=\"k\",\n",
")\n",
"plt.axvline(\n",
" x=lower_bound,\n",
" linestyle=\"--\",\n",
" linewidth=2.5,\n",
" label=\"Lower bound 68% CI\",\n",
" c=\"g\",\n",
")\n",
"plt.axvline(\n",
" x=upper_bound,\n",
" linestyle=\"--\",\n",
" linewidth=2.5,\n",
" label=\"Upper bound 68% CI\",\n",
" c=\"purple\",\n",
")\n",
"plt.legend(bbox_to_anchor=(1.01, 1), loc=\"upper left\");"
]
},
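{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dashed lines above mark the mean of the predictions plus or minus one standard deviation. The 68% in the legend assumes the values are roughly normally distributed. As a quick sanity check on simulated normal data (illustrative values, not the housing predictions), about 68% of draws fall within one standard deviation of the mean. For skewed data such as house prices, the actual fraction can differ.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# simulated, normally distributed values (illustrative only)\n",
"rng = np.random.default_rng(42)\n",
"values = rng.normal(loc=200_000, scale=50_000, size=100_000)\n",
"\n",
"mean, std = values.mean(), values.std()\n",
"# fraction of values inside the band [mean - std, mean + std]\n",
"within_one_std = np.mean((values >= mean - std) & (values <= mean + std))\n",
"print(round(within_one_std, 3))  # should be close to 0.683 for normal data\n",
"```"
]
},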
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In general, we expect our models to perform best on categories that are most frequent in the data. One way to validate this hypothesis is by calculating the ratio of the standard deviation of the predictions to the predictions themselves:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ocean_proximity_INLAND\n",
"1 0.288622\n",
"0 0.228371\n",
"dtype: float64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds_quality = valid_copy[cols].groupby(\"ocean_proximity_INLAND\", as_index=True).mean()\n",
"(preds_quality[\"preds_std\"] / preds_quality[\"preds_mean\"]).sort_values(ascending=False)"
]
},
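{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the ratio concrete, here it is recomputed by hand from the group means in the `preds_quality` table printed earlier; the numbers are copied from that table. This relative spread is closely related to the coefficient of variation (standard deviation divided by mean).\n",
"\n",
"```python\n",
"import pandas as pd\n",
"\n",
"# group means copied from the preds_quality table above (index = ocean_proximity_INLAND)\n",
"preds_mean = pd.Series({0: 227319.699349, 1: 124241.488165})\n",
"preds_std = pd.Series({0: 51913.127856, 1: 35858.872619})\n",
"\n",
"# relative spread of the tree predictions per category, largest first\n",
"relative_spread = (preds_std / preds_mean).sort_values(ascending=False)\n",
"print(relative_spread.round(6))\n",
"```"
]
},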
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What the above tells us is that our predictions are less confident (i.e. higher variance) for housing districts that are inland - indeed looking at our bar plot we see these categories are under-represented in the data!\n",
"\n",
"In general, confidence intervals serve two main purposes:\n",
"\n",
"* We can identify which categories the model is less confident about and investigate further\n",
"* We can identify which rows in the data the model is not confident about. This is particularly important when deploying models to production, where e.g. we need to decide how to evaluate the model's predictions for a _single_ housing district."
]
},
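{
"cell_type": "markdown",
"metadata": {},
"source": [
"For the second use case, one simple rule is to flag the rows whose relative spread `preds_std / preds_mean` is unusually large, for example above the 90th percentile of the validation set. This rule is a hypothetical choice, not part of the lesson. The helper name `flag_uncertain_rows` and its cut-off are illustrative assumptions. A minimal NumPy sketch:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def flag_uncertain_rows(preds_mean, preds_std, quantile=0.9):\n",
"    \"\"\"Return a boolean mask marking the rows with the largest relative spread.\n",
"\n",
"    Hypothetical helper: the quantile cut-off is an arbitrary, illustrative choice.\n",
"    \"\"\"\n",
"    relative_spread = np.asarray(preds_std, dtype=float) / np.asarray(preds_mean, dtype=float)\n",
"    cutoff = np.quantile(relative_spread, quantile)\n",
"    return relative_spread > cutoff\n",
"\n",
"# made-up per-row statistics for five housing districts\n",
"toy_mean = np.array([200.0, 150.0, 300.0, 120.0, 250.0])\n",
"toy_std = np.array([20.0, 15.0, 30.0, 60.0, 25.0])\n",
"print(flag_uncertain_rows(toy_mean, toy_std, quantile=0.8).tolist())  # [False, False, False, True, False]\n",
"```"
]
},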
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Feature importance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One drawback with the confidence interval analysis is that we need to drill-down into each feature to see where the model is making mistakes. In practice, we can get a global view by ranking each feature in terms of its importance to the model's predictions. In scikit-learn, the Random Forest model has an attribute called `feature_importances_` that we can use to rank each feature:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def rf_feature_importance(fitted_model, df):\n",
" return pd.DataFrame(\n",
" {\"Column\": df.columns, \"Importance\": fitted_model.feature_importances_}\n",
" ).sort_values(\"Importance\", ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use this function to calculate the feature importance for our fitted model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
"
\n",
"
\n",
"
Column
\n",
"
Importance
\n",
"
\n",
" \n",
" \n",
"
\n",
"
7
\n",
"
median_income
\n",
"
0.248165
\n",
"
\n",
"
\n",
"
14
\n",
"
ocean_proximity_INLAND
\n",
"
0.135266
\n",
"
\n",
"
\n",
"
13
\n",
"
population_per_household
\n",
"
0.088856
\n",
"
\n",
"
\n",
"
9
\n",
"
postal_code
\n",
"
0.088172
\n",
"
\n",
"
\n",
"
0
\n",
"
longitude
\n",
"
0.080466
\n",
"
\n",
"
\n",
"
12
\n",
"
bedrooms_per_room
\n",
"
0.070059
\n",
"
\n",
"
\n",
"
1
\n",
"
latitude
\n",
"
0.066439
\n",
"
\n",
"
\n",
"
10
\n",
"
rooms_per_household
\n",
"
0.048305
\n",
"
\n",
"
\n",
"
2
\n",
"
housing_median_age
\n",
"
0.030146
\n",
"
\n",
"
\n",
"
8
\n",
"
city
\n",
"
0.022991
\n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" Column Importance\n",
"7 median_income 0.248165\n",
"14 ocean_proximity_INLAND 0.135266\n",
"13 population_per_household 0.088856\n",
"9 postal_code 0.088172\n",
"0 longitude 0.080466\n",
"12 bedrooms_per_room 0.070059\n",
"1 latitude 0.066439\n",
"10 rooms_per_household 0.048305\n",
"2 housing_median_age 0.030146\n",
"8 city 0.022991"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# expected shape - (n_features, 2)\n",
"feature_importance = rf_feature_importance(model, X)\n",
"\n",
"# peek at top 10 features\n",
"feature_importance[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the table we see that `median_income`, `ocean_proximity_INLAND`, and `population_per_household` are the most important features - this is not entirely surprising since income and house location seem to be good indicators of house value. We can also plot the feature importance to gain a visual understanding:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_feature_importance(feature_importance):\n",
" return sns.barplot(y=\"Column\", x=\"Importance\", data=feature_importance, color='b')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_feature_importance(feature_importance);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In nearly every real-world dataset, this is what the feature importance looks like: a handful of columns are very important, while most are not. The powerful aspect of this approach is that is _focuses our attention_ on which features we should investigate further and which ones we can safely ignore."
]
},
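{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this pattern in a setting where the answer is known in advance, here is a sketch on synthetic data in which only the first of three features influences the target. The names `X_demo`, `y_demo` and `rf` are hypothetical. The impurity-based importances returned by `feature_importances_` are normalised to sum to 1, and the informative feature should receive almost all of that total.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"\n",
"# synthetic data: the target depends on feature 0 only; features 1 and 2 are noise\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 1, size=(500, 3))\n",
"y_demo = 10 * X_demo[:, 0] + rng.normal(0, 0.1, size=500)\n",
"\n",
"rf = RandomForestRegressor(n_estimators=50, random_state=0).fit(X_demo, y_demo)\n",
"\n",
"# impurity-based importances, one per column, normalised to sum to 1\n",
"importances = rf.feature_importances_\n",
"print(importances.round(3), importances.sum())\n",
"```"
]
},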
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Warning: The feature importance analysis above can be biased and has a tendency to inflate the importance of continuous features or categorical features with high cardinality (i.e. many unique categories). See the Beware Default Random Forest Importances article in the references for more information."
]
},
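{
"cell_type": "markdown",
"metadata": {},
"source": [
"One alternative recommended in the _Beware Default Random Forest Importances_ article is _permutation importance_. It shuffles one column at a time on held-out data and measures how much the model's score drops. This is a different technique from the `feature_importances_` attribute used in this lesson. The sketch uses scikit-learn's `sklearn.inspection.permutation_importance` on synthetic data; all variable names are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.inspection import permutation_importance\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# synthetic data: feature 0 drives the target, feature 1 is pure noise\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 1, size=(1000, 2))\n",
"y_demo = 10 * X_demo[:, 0] + rng.normal(0, 0.5, size=1000)\n",
"X_tr, X_va, y_tr, y_va = train_test_split(X_demo, y_demo, test_size=0.25, random_state=0)\n",
"\n",
"rf = RandomForestRegressor(n_estimators=50, random_state=0).fit(X_tr, y_tr)\n",
"\n",
"# drop in validation R^2 when each column is shuffled, averaged over 5 repeats\n",
"result = permutation_importance(rf, X_va, y_va, n_repeats=5, random_state=0)\n",
"print(result.importances_mean)\n",
"```"
]
},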
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Drop uninformative features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the feature importance plot above, we can see there are only a handful of informative features - let's use this insight to make a simpler model by dropping uninformative columns from our data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"feature_importance_threshold = 0.03\n",
"cols_to_keep = feature_importance[\n",
" feature_importance['Importance'] > feature_importance_threshold\n",
"]['Column']\n",
"\n",
"len(cols_to_keep)"
]
},
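{
"cell_type": "markdown",
"metadata": {},
"source": [
"scikit-learn can also perform this thresholding step with `sklearn.feature_selection.SelectFromModel`, an alternative to the manual filtering used in this lesson. When fitted, it trains its own copy of the wrapped estimator. Note that `SelectFromModel` keeps features whose importance is _greater than or equal to_ the threshold, whereas the cell above uses a strict `>` comparison. A minimal sketch on synthetic data follows; the names are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.feature_selection import SelectFromModel\n",
"\n",
"# synthetic data: features 0 and 1 matter, features 2-4 are noise\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.uniform(0, 1, size=(500, 5))\n",
"y_demo = 5 * X_demo[:, 0] + 3 * X_demo[:, 1] + rng.normal(0, 0.1, size=500)\n",
"\n",
"# fit a forest and keep every feature whose importance is >= 0.03\n",
"selector = SelectFromModel(\n",
"    RandomForestRegressor(n_estimators=50, random_state=0), threshold=0.03\n",
").fit(X_demo, y_demo)\n",
"\n",
"support = selector.get_support()        # boolean mask of kept columns\n",
"X_reduced = selector.transform(X_demo)  # data restricted to the kept columns\n",
"print(support, X_reduced.shape)\n",
"```"
]
},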
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create a copy of the data with selected columns and create new train / test set\n",
"X_keep = X.copy()[cols_to_keep]\n",
"X_train, X_valid = train_test_split(X_keep, test_size=0.2, random_state=42)"
]
},
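{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing the features and the target to the same `train_test_split` call keeps their rows aligned. A small sketch on toy data illustrates this; `X_toy` and `y_toy` are hypothetical. It also shows that splitting the features alone, with the same `random_state` and the same number of rows, reproduces the same permutation. Relying on that coupling is fragile, because any change to the seed or the row count silently misaligns features and targets.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# toy data: row i is [2i, 2i + 1] and its target is i, so alignment is easy to check\n",
"X_toy = np.arange(20).reshape(10, 2)\n",
"y_toy = np.arange(10)\n",
"\n",
"# splitting features and target in one call keeps their rows aligned\n",
"X_tr, X_va, y_tr, y_va = train_test_split(X_toy, y_toy, test_size=0.2, random_state=42)\n",
"print((X_tr[:, 0] // 2 == y_tr).all())  # True\n",
"\n",
"# splitting the features alone with the same random_state and row count\n",
"# reproduces the same permutation, but this coupling is easy to break\n",
"X_tr_only, X_va_only = train_test_split(X_toy, test_size=0.2, random_state=42)\n",
"print(np.array_equal(X_tr_only, X_tr))  # True\n",
"```"
]
},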
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"