{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lesson 7 - Classification\n",
"\n",
"> In this lesson we bring together all the knowledge we have gained about Random Forests and apply it to a new type of supervised learning task: binary classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/lewtun/dslectures/master?urlpath=lab/tree/notebooks%2Flesson07_classification.ipynb) \n",
"[![slides](https://img.shields.io/static/v1?label=slides&message=lesson07_classification.pdf&color=blue&logo=Google-drive)](https://drive.google.com/open?id=12gVqkVJ9KIaBGPCq0auFxXOPlEdo9Dik)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning objectives"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* Know how to apply Random Forests to classification tasks\n",
"* Understand the performance metrics associated with binary classification\n",
"* Gain an introduction to fast.ai's data preprocessing functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## References"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This lesson is inspired by the following textbooks and online courses:\n",
"\n",
"* Chapter 3 of _Hands-On Machine Learning with Scikit-Learn and TensorFlow_ by Aurélien Géron\n",
"* Chapter 7 of _Data Science for Business_ by Provost and Fawcett\n",
"* Lessons 1 - 4 of Jeremy Howard's fantastic online course [_Introduction to Machine Learning for Coders_](https://course18.fast.ai/ml)\n",
"\n",
"You may also find the following blog post useful:\n",
"\n",
"* [Grumpy, euphoric, and smart classifiers (interactive)](https://christian.bock.ml/posts/metrics/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Homework\n",
"\n",
"* Solve the exercises included in this notebook\n",
"* Read chapter 3 of _Hands-On Machine Learning with Scikit-Learn and TensorFlow_ by Aurélien Géron\n",
"* Read chapter 7 of _Data Science for Business_ by Provost and Fawcett\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What is customer churn?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will explore [IBM's telecommunications dataset](https://www.kaggle.com/blastchar/telco-customer-churn) and determine which attributes are most informative for predicting customer attrition (also known as customer churn). As described by IBM, the problem setting is as follows:\n",
"\n",
"> A telecommunications company is concerned about the number of customers leaving their landline business for cable competitors. They need to understand who is leaving. Imagine that you’re an analyst at this company and you have to find out who is leaving and why.\n",
"\n",
"The kind of questions we'd like to find answers to are:\n",
"\n",
"* Which customers are likely to leave?\n",
"* Which attributes influence customers who leave?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As noted above, in this lesson we will analyse IBM's customer churn dataset:\n",
"\n",
"* `churn.csv`\n",
"\n",
"The dataset includes information about:\n",
"\n",
"* Customers who left within the last month – the column is called `Churn`\n",
"* Services that each customer has signed up for – phone, multiple lines, internet, online security, online backup, device protection, tech support, and streaming TV and movies\n",
"* Customer account information – how long they’ve been a customer (tenure), contract, payment method, paperless billing, monthly charges, and total charges\n",
"* Demographic info about customers – gender, whether they're a senior citizen or not, and if they have partners and dependents\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reload modules before executing user code\n",
"%load_ext autoreload\n",
"# reload all modules every time before executing Python code\n",
"%autoreload 2\n",
"# render plots in notebook\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncomment to update the library if working locally\n",
"# !pip install dslectures --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# data wrangling\n",
"import pandas as pd\n",
"import numpy as np\n",
"from dslectures.core import (\n",
" get_dataset,\n",
" display_large,\n",
" convert_strings_to_categories,\n",
" rf_feature_importance,\n",
" plot_feature_importance,\n",
" plot_dendogram,\n",
")\n",
"from dslectures.structured import proc_df\n",
"from pathlib import Path\n",
"\n",
"# data viz\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics import plot_confusion_matrix, plot_roc_curve\n",
"from sklearn.tree import plot_tree\n",
"\n",
"sns.set(color_codes=True)\n",
"sns.set_palette(sns.color_palette(\"muted\"))\n",
"\n",
"# ml magic\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"import scipy\n",
"from scipy.cluster import hierarchy as hc"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Download of churn.csv dataset complete.\n"
]
}
],
"source": [
"get_dataset(\"churn.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also make use of the `pathlib` library to handle our filepaths:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"churn.csv housing_merged.csv\n",
"churn_processed.csv housing_processed.csv\n",
"housing.csv submission.csv\n",
"housing_addresses.csv test.csv\n",
"housing_gmaps_data_raw.csv train.csv\n"
]
}
],
"source": [
"DATA = Path('../data/')\n",
"!ls {DATA}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"churn_data = pd.read_csv(DATA / \"churn.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inspect the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preview the data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sometimes you will find that the dataset has too many columns to be displayed with the standard `DataFrame.head()` method and just shows `...` for intermediate columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" customerID gender SeniorCitizen Partner Dependents tenure PhoneService \\\n",
"0 7590-VHVEG Female 0 Yes No 1 No \n",
"1 5575-GNVDE Male 0 No No 34 Yes \n",
"2 3668-QPYBK Male 0 No No 2 Yes \n",
"3 7795-CFOCW Male 0 No No 45 No \n",
"4 9237-HQITU Female 0 No No 2 Yes \n",
"\n",
" MultipleLines InternetService OnlineSecurity ... DeviceProtection \\\n",
"0 No phone service DSL No ... No \n",
"1 No DSL Yes ... Yes \n",
"2 No DSL Yes ... No \n",
"3 No phone service DSL Yes ... Yes \n",
"4 No Fiber optic No ... No \n",
"\n",
" TechSupport StreamingTV StreamingMovies Contract PaperlessBilling \\\n",
"0 No No No Month-to-month Yes \n",
"1 No No No One year No \n",
"2 No No No Month-to-month Yes \n",
"3 Yes No No One year No \n",
"4 No No No Month-to-month Yes \n",
"\n",
" PaymentMethod MonthlyCharges TotalCharges Churn \n",
"0 Electronic check 29.85 29.85 No \n",
"1 Mailed check 56.95 1889.5 No \n",
"2 Mailed check 53.85 108.15 Yes \n",
"3 Bank transfer (automatic) 42.30 1840.75 No \n",
"4 Electronic check 70.70 151.65 Yes \n",
"\n",
"[5 rows x 21 columns]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To fix that we can configure the [options in pandas](https://pandas.pydata.org/pandas-docs/version/0.15/options.html), which we wrap inside a simple function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" customerID gender SeniorCitizen Partner Dependents tenure PhoneService \\\n",
"0 7590-VHVEG Female 0 Yes No 1 No \n",
"1 5575-GNVDE Male 0 No No 34 Yes \n",
"2 3668-QPYBK Male 0 No No 2 Yes \n",
"3 7795-CFOCW Male 0 No No 45 No \n",
"4 9237-HQITU Female 0 No No 2 Yes \n",
"\n",
" MultipleLines InternetService OnlineSecurity OnlineBackup \\\n",
"0 No phone service DSL No Yes \n",
"1 No DSL Yes No \n",
"2 No DSL Yes Yes \n",
"3 No phone service DSL Yes No \n",
"4 No Fiber optic No No \n",
"\n",
" DeviceProtection TechSupport StreamingTV StreamingMovies Contract \\\n",
"0 No No No No Month-to-month \n",
"1 Yes No No No One year \n",
"2 No No No No Month-to-month \n",
"3 Yes Yes No No One year \n",
"4 No No No No Month-to-month \n",
"\n",
" PaperlessBilling PaymentMethod MonthlyCharges TotalCharges \\\n",
"0 Yes Electronic check 29.85 29.85 \n",
"1 No Mailed check 56.95 1889.5 \n",
"2 Yes Mailed check 53.85 108.15 \n",
"3 No Bank transfer (automatic) 42.30 1840.75 \n",
"4 Yes Electronic check 70.70 151.65 \n",
"\n",
" Churn \n",
"0 No \n",
"1 No \n",
"2 Yes \n",
"3 No \n",
"4 Yes "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display_large(churn_data.head())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can take the transpose to see all the columns more easily:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" 0 1 2 \\\n",
"customerID 7590-VHVEG 5575-GNVDE 3668-QPYBK \n",
"gender Female Male Male \n",
"SeniorCitizen 0 0 0 \n",
"Partner Yes No No \n",
"Dependents No No No \n",
"tenure 1 34 2 \n",
"PhoneService No Yes Yes \n",
"MultipleLines No phone service No No \n",
"InternetService DSL DSL DSL \n",
"OnlineSecurity No Yes Yes \n",
"OnlineBackup Yes No Yes \n",
"DeviceProtection No Yes No \n",
"TechSupport No No No \n",
"StreamingTV No No No \n",
"StreamingMovies No No No \n",
"Contract Month-to-month One year Month-to-month \n",
"PaperlessBilling Yes No Yes \n",
"PaymentMethod Electronic check Mailed check Mailed check \n",
"MonthlyCharges 29.85 56.95 53.85 \n",
"TotalCharges 29.85 1889.5 108.15 \n",
"Churn No No Yes \n",
"\n",
" 3 4 \n",
"customerID 7795-CFOCW 9237-HQITU \n",
"gender Male Female \n",
"SeniorCitizen 0 0 \n",
"Partner No No \n",
"Dependents No No \n",
"tenure 45 2 \n",
"PhoneService No Yes \n",
"MultipleLines No phone service No \n",
"InternetService DSL Fiber optic \n",
"OnlineSecurity Yes No \n",
"OnlineBackup No No \n",
"DeviceProtection Yes No \n",
"TechSupport Yes No \n",
"StreamingTV No No \n",
"StreamingMovies No No \n",
"Contract One year Month-to-month \n",
"PaperlessBilling No Yes \n",
"PaymentMethod Bank transfer (automatic) Electronic check \n",
"MonthlyCharges 42.3 70.7 \n",
"TotalCharges 1840.75 151.65 \n",
"Churn No Yes "
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data.head().T"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### The shape of the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7043"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get number of rows\n",
"len(churn_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(7043, 21)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# get tuple of (n_rows, n_columns)\n",
"churn_data.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we see that we have 7043 customers and 21 variables or attributes that describe their telecom subscription. Let's have a look at the columns:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Index(['customerID', 'gender', 'SeniorCitizen', 'Partner', 'Dependents',\n",
" 'tenure', 'PhoneService', 'MultipleLines', 'InternetService',\n",
" 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport',\n",
" 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling',\n",
" 'PaymentMethod', 'MonthlyCharges', 'TotalCharges', 'Churn'],\n",
" dtype='object')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data.columns"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Note: As explained in the summary, the _**target attribute**_ is `Churn` and thus we have a _**classification problem**_ (rather than regression) because the target is a _**category**_ (Yes or No) rather than a continuous number."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Unique values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Whenever we see an ID column like `customerID`, it is useful to perform a sanity check that each value is unique. Otherwise you may have duplicates in your data that can bias your models and hence your conclusions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"7043"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data[\"customerID\"].nunique()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Good! The number of unique IDs matches the number of rows in our DataFrame."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data types"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"customerID object\n",
"gender object\n",
"SeniorCitizen int64\n",
"Partner object\n",
"Dependents object\n",
"tenure int64\n",
"PhoneService object\n",
"MultipleLines object\n",
"InternetService object\n",
"OnlineSecurity object\n",
"OnlineBackup object\n",
"DeviceProtection object\n",
"TechSupport object\n",
"StreamingTV object\n",
"StreamingMovies object\n",
"Contract object\n",
"PaperlessBilling object\n",
"PaymentMethod object\n",
"MonthlyCharges float64\n",
"TotalCharges object\n",
"Churn object\n",
"dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data.dtypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hmm, `TotalCharges` is of type **object** (i.e. string) even though it is clearly a float. Since null values or NaNs don't produce this behaviour, there are presumably empty strings lurking in this column. Let's test this hypothesis using `Series.value_counts()`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"20.2 11\n",
" 11\n",
"19.75 9\n",
"20.05 8\n",
"19.9 8\n",
" ..\n",
"514.75 1\n",
"676.35 1\n",
"6510.45 1\n",
"428.45 1\n",
"6004.85 1\n",
"Name: TotalCharges, Length: 6531, dtype: int64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data[\"TotalCharges\"].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will deal with these empty strings in the preprocessing steps below."
]
},
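{
"cell_type": "markdown",
"metadata": {},
"source": [
"One possible way to confirm and repair this kind of problem is sketched below on a toy `Series` rather than on the real column. The values and the names `toy_charges` and `toy_numeric` are made up for illustration. The idea is to count the strings that are empty once whitespace is stripped, and then let `pd.to_numeric` coerce anything it cannot parse to NaN:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# toy stand-in for the TotalCharges column (illustrative values only)\n",
"toy_charges = pd.Series([\"29.85\", \" \", \"1889.5\", \" \", \"108.15\"])\n",
"\n",
"# count entries that are empty once surrounding whitespace is removed\n",
"n_blank = (toy_charges.str.strip() == \"\").sum()\n",
"\n",
"# errors=\"coerce\" turns anything that cannot be parsed into NaN\n",
"toy_numeric = pd.to_numeric(toy_charges, errors=\"coerce\")\n",
"\n",
"print(f\"{n_blank} blank strings, {toy_numeric.isna().sum()} NaNs after conversion\")\n",
"print(toy_numeric.dtype)"
]
},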
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Recall that in our housing analysis, we needed to perform four main steps to bring our DataFrame into a form suitable for training a Random Forest:\n",
"\n",
"* Convert strings to categorical data type\n",
"* Fill missing values\n",
"* Numericalise the DataFrame and create a features matrix $X$ and target vector $y$\n",
"* Create train and validation sets\n",
"\n",
"Let's perform each of those steps below."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Convert strings to categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we convert all the string columns to pandas' categorical data type:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"convert_strings_to_categories(churn_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"customerID category\n",
"gender category\n",
"SeniorCitizen int64\n",
"Partner category\n",
"Dependents category\n",
"tenure int64\n",
"PhoneService category\n",
"MultipleLines category\n",
"InternetService category\n",
"OnlineSecurity category\n",
"OnlineBackup category\n",
"DeviceProtection category\n",
"TechSupport category\n",
"StreamingTV category\n",
"StreamingMovies category\n",
"Contract category\n",
"PaperlessBilling category\n",
"PaymentMethod category\n",
"MonthlyCharges float64\n",
"TotalCharges category\n",
"Churn category\n",
"dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data.dtypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is almost correct, although a closer look at `SeniorCitizen` reveals that it refers to a binary feature and thus should also be categorical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 1])"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data[\"SeniorCitizen\"].unique()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can fix this easily by simply changing the data type:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"churn_data[\"SeniorCitizen\"] = churn_data[\"SeniorCitizen\"].astype(\"category\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"customerID category\n",
"gender category\n",
"SeniorCitizen category\n",
"Partner category\n",
"Dependents category\n",
"tenure int64\n",
"PhoneService category\n",
"MultipleLines category\n",
"InternetService category\n",
"OnlineSecurity category\n",
"OnlineBackup category\n",
"DeviceProtection category\n",
"TechSupport category\n",
"StreamingTV category\n",
"StreamingMovies category\n",
"Contract category\n",
"PaperlessBilling category\n",
"PaymentMethod category\n",
"MonthlyCharges float64\n",
"TotalCharges category\n",
"Churn category\n",
"dtype: object"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sanity check on the transformation\n",
"churn_data.dtypes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fill missing values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to test for missing values is to apply the `isna` method from pandas and calculate the fraction of missing values in each column of our DataFrame:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Churn 0.0\n",
"OnlineSecurity 0.0\n",
"gender 0.0\n",
"SeniorCitizen 0.0\n",
"Partner 0.0\n",
"Dependents 0.0\n",
"tenure 0.0\n",
"PhoneService 0.0\n",
"MultipleLines 0.0\n",
"InternetService 0.0\n",
"OnlineBackup 0.0\n",
"TotalCharges 0.0\n",
"DeviceProtection 0.0\n",
"TechSupport 0.0\n",
"StreamingTV 0.0\n",
"StreamingMovies 0.0\n",
"Contract 0.0\n",
"PaperlessBilling 0.0\n",
"PaymentMethod 0.0\n",
"MonthlyCharges 0.0\n",
"customerID 0.0\n",
"dtype: float64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(churn_data.isna().sum() / len(churn_data)).sort_values(ascending=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
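"In this case there are no NaNs, so at first sight we have a pre-cleaned dataset. Note, however, that `isna` does not flag the blank strings we found in `TotalCharges`, which is why that column also reports zero missing values."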
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create feature matrix and target vector"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have done some basic preprocessing, the final step is to numericalise the `pandas.DataFrame` and create the feature matrix $X$ and target vector $y$. In previous lessons we created some functions to automate these steps. Below we use fast.ai's utility function `proc_df` to wrap all these steps into a single step:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X, y, nas = proc_df(churn_data, \"Churn\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" customerID gender SeniorCitizen Partner Dependents tenure \\\n",
"0 5376 1 1 2 1 1 \n",
"1 3963 2 1 1 1 34 \n",
"2 2565 2 1 1 1 2 \n",
"3 5536 2 1 1 1 45 \n",
"4 6512 1 1 1 1 2 \n",
"\n",
" PhoneService MultipleLines InternetService OnlineSecurity OnlineBackup \\\n",
"0 1 2 1 1 3 \n",
"1 2 1 1 3 1 \n",
"2 2 1 1 3 3 \n",
"3 1 2 1 3 1 \n",
"4 2 1 2 1 1 \n",
"\n",
" DeviceProtection TechSupport StreamingTV StreamingMovies Contract \\\n",
"0 1 1 1 1 1 \n",
"1 3 1 1 1 2 \n",
"2 1 1 1 1 1 \n",
"3 3 3 1 1 2 \n",
"4 1 1 1 1 1 \n",
"\n",
" PaperlessBilling PaymentMethod MonthlyCharges TotalCharges \n",
"0 2 3 29.85 2506 \n",
"1 1 4 56.95 1467 \n",
"2 2 4 53.85 158 \n",
"3 1 1 42.30 1401 \n",
"4 2 3 70.70 926 "
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"display_large(X.head())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((7043, 20), (7043,))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, y.shape"
]
},
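{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what numericalising a categorical column means, here is a minimal sketch on a toy `Series`. It assumes that each category is replaced by its integer code shifted by one, which appears consistent with the output above (Female maps to 1 and Male to 2). The exact internals of `proc_df` are not shown in this lesson, so treat this as an illustration rather than its implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# toy stand-in for the gender column (illustrative values only)\n",
"toy_gender = pd.Series([\"Female\", \"Male\", \"Male\", \"Female\"]).astype(\"category\")\n",
"\n",
"# string categories are stored in sorted order and their codes start at 0\n",
"print(toy_gender.cat.categories.tolist())\n",
"\n",
"# pandas encodes missing values as -1, so shifting by one maps them to 0\n",
"toy_codes = toy_gender.cat.codes + 1\n",
"print(toy_codes.tolist())"
]
},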
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For future use we can save our processed quantities:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"churn_processed = X.join(pd.Series(y, name=\"Churn\"))\n",
"\n",
"churn_processed.to_csv(DATA / \"churn_processed.csv\", index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create train and validation sets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5634 train rows + 1409 valid rows\n"
]
}
],
"source": [
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X, y, test_size=0.2, random_state=42\n",
")\n",
"print(f\"{len(X_train)} train rows + {len(X_valid)} valid rows\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Select a performance measure"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Evaluating classifiers is often significantly trickier than evaluating a regressor. One way to do this is to compare the accuracy of each classifier, where \n",
"\n",
"$$ \\mbox{accuracy} = \\frac{\\mbox{Number of correct decisions made}}{\\mbox{Total number of decisions made}} $$\n",
"\n",
"In general, however, accuracy is _**not**_ the preferred performance measure for classifiers, especially when you are dealing with _**skewed datasets**_ (i.e. when some classes are much more frequent than others). For our churn example, suppose we build a model that achieves 75% accuracy. Is this any good? Let's have a look at the distribution of churn in the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure: count plot of the Churn column (x-axis: Churn, y-axis: count)>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sns.countplot(x=\"Churn\", data=churn_data)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"No 0.73463\n",
"Yes 0.26537\n",
"Name: Churn, dtype: float64"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"churn_data[\"Churn\"].value_counts(normalize=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the plot and numbers we see that the \"No Churn\" and \"Churn\" classes appear in approximately a 3:1 ratio. If we built a dumb classifier that just classifies every single customer as \"No Churn\", then we would be right about 73.5% of the time! In practice skews of 99:1 are common, for which a report of 99% accuracy is somewhat meaningless."
]
},
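{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the point about the \"dumb\" classifier concrete, here is a minimal sketch using scikit-learn's `DummyClassifier` with the `most_frequent` strategy. The labels are synthetic (a hypothetical `y_toy` array with roughly the same 73:27 split as our churn data), not the real dataset, so treat it as an illustration rather than a result:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.dummy import DummyClassifier\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# synthetic labels (assumption): 73 \"No\" and 27 \"Yes\", mimicking the churn skew\n",
"y_toy = np.array([\"No\"] * 73 + [\"Yes\"] * 27)\n",
"# the features are irrelevant for a dummy classifier, so we use a column of zeros\n",
"X_toy = np.zeros((len(y_toy), 1))\n",
"\n",
"# always predict the most frequent class seen during fitting\n",
"dummy = DummyClassifier(strategy=\"most_frequent\")\n",
"dummy.fit(X_toy, y_toy)\n",
"toy_accuracy = accuracy_score(y_toy, dummy.predict(X_toy))\n",
"print(f\"Accuracy of always predicting the majority class: {toy_accuracy:.2f}\")"
]
},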
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Confusion matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A much better way to evaluate the performance of a classifier is to look at the _**confusion matrix.**_ Recall that a confusion matrix for a problem involving $n$ classes is an $n\\times n$ matrix with the rows labelled by the _**actual**_ classes and the columns with the _**predicted**_ classes. Our churn example is a two-class problem (\"Churn\" vs \"No Churn\"), so the confusion matrix is $2\\times 2$.\n",
"\n",
"If we denote the true classes as $\\mathbf{p}$(positive) and $\\mathbf{n}$(egative), and the classes predicted by the model as $\\mathbf{Y}$(es) and $\\mathbf{N}$(o) then the confusion matrix has the form:\n",
"\n",
"| | **N** | **Y** | \n",
"| :---: | :---: | :---: |\n",
"| **n** | True negatives | False positives | \n",
"| **p** | False negatives | True positives |\n",
"\n",
"The main diagonal contains the counts of correct decisions. The errors of the classifier are the _**false negatives**_ (positives classified as negative) and **false positives** (negatives classified as positive)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"**You should know**\n",
"\n",
"In the _Data Science for Business_ textbook, the confusion matrix is given a different layout, namely the columns a relabelled by the actual classes and the rows by the predicted classes: \n",
"\n",
"| | **p** | **n** | \n",
"| :---: | :---: | :---: |\n",
"| **Y** | True positives | False positives | \n",
"| **P** | False negatives | True negatives |\n",
"\n",
"The layout adopted above and in this notebook is the one produced by scikit-learn, since we'd like to make use of the in-built functions included in that library.\n",
"\n",
"---"
]
},
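{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two layouts can be checked on a handful of made-up labels. The sketch below uses `sklearn.metrics.confusion_matrix` on hypothetical arrays `y_true_toy` and `y_pred_toy` (1 = positive, 0 = negative) and then rearranges the result into the textbook layout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# made-up labels for illustration only: 1 = positive, 0 = negative\n",
"y_true_toy = np.array([0, 0, 0, 1, 1])\n",
"y_pred_toy = np.array([0, 1, 0, 1, 0])\n",
"\n",
"# scikit-learn layout: rows = actual classes, columns = predicted classes\n",
"cm = confusion_matrix(y_true_toy, y_pred_toy)\n",
"tn, fp, fn, tp = cm.ravel()\n",
"print(cm)\n",
"print(f\"TN={tn}, FP={fp}, FN={fn}, TP={tp}\")\n",
"\n",
"# textbook layout: rows = predicted (Y, N), columns = actual (p, n)\n",
"# transpose swaps actual/predicted, reversing both axes puts positives first\n",
"cm_textbook = cm.T[::-1, ::-1]\n",
"print(cm_textbook)"
]
},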
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Baseline model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's begin by creating a baseline Random Forest model to build upon. First we need a scoring function for classifiers, similar to our $R^2$ and RMSE function for regression:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def print_scores(fitted_model):\n",
" res = {\n",
" \"Accuracy on train:\": accuracy_score(fitted_model.predict(X_train), y_train),\n",
" \"ROC AUC on train:\": roc_auc_score(\n",
" y_train, fitted_model.predict_proba(X_train)[:, 1]\n",
" ),\n",
" \"Accuracy on valid:\": accuracy_score(fitted_model.predict(X_valid), y_valid),\n",
" \"ROC AUC on valid:\": roc_auc_score(\n",
" y_valid, fitted_model.predict_proba(X_valid)[:, 1]\n",
" ),\n",
" }\n",
" if hasattr(fitted_model, \"oob_score_\"):\n",
" res[\"OOB accuracy:\"] = fitted_model.oob_score_\n",
"\n",
" for k, v in res.items():\n",
" print(k, round(v, 3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n",
" criterion='gini', max_depth=None, max_features='auto',\n",
" max_leaf_nodes=None, max_samples=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=-1,\n",
" oob_score=False, random_state=42, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = RandomForestClassifier(n_estimators=10, n_jobs=-1, random_state=42)\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on train: 0.984\n",
"ROC AUC on train: 0.999\n",
"Accuracy on valid: 0.771\n",
"ROC AUC on valid: 0.801\n"
]
}
],
"source": [
"print_scores(model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ROC Curves and AUC\n",
"Note that in addition to accuracy, we also show a second value, the _**Area Under the ROC Curve**_ (AUC). The AUC is a good summary statistic of the predictiveness of a binary classifier. It varies from zero to one. A value of 0.5 corresponds to randomness (the classifier cannot distinguish at all between \"churn\" and \"no churn\") and a value of 1.0 means that it is perfect.\n",
"\n",
"The \"ROC\" refers to the Receiver Operating Characteristic (ROC) curve which plots the _true positive rate_ \n",
"\n",
"$$ \\mbox{TPR} = \\frac{\\mbox{TP}}{\\mbox{TP} + \\mbox{FP}} \\,, \\qquad \\mbox{TP (FP)} = \\mbox{number of true (false) positives}\\,,$$\n",
"\n",
"against the _false positive rate_ FPR, where the FPR is the ratio of negative instances that are incorrectly classified as positive. In general there is a tradeoff between these two quantities: the higher the TPR, the more false positives (FPR) the classifier produces. A good classifiers stays as close to the top-left corner of a ROC curve plot as possible.\n",
"\n",
"In scikit-learn we can visualise the ROC curve of an estimator using the plotting API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extract labels for classes\n",
"class_names = churn_data[\"Churn\"].cat.categories\n",
"\n",
"plot_confusion_matrix(\n",
" model,\n",
" X_valid,\n",
" y_valid,\n",
" display_labels=class_names,\n",
" cmap=plt.cm.Blues,\n",
" normalize=\"true\",\n",
")\n",
"plt.grid(None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the confusion matrix we see that our baseline model is able to identify churners only 42% of the time and incorrectly classifies people who churned 58% of the time. We can do better, but first let's inspect how a single decision tree is making decisions on this data."
]
},
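{
"cell_type": "markdown",
"metadata": {},
"source": [
"To connect the confusion matrix back to the ROC discussion, the sketch below computes the true positive rate (the fraction of actual positives that are correctly identified) and the false positive rate at a single threshold, and then the full ROC curve and AUC with `roc_curve` and `roc_auc_score`. The labels and scores are made up for illustration (the names `y_true_toy` and `scores_toy` are not part of our churn pipeline):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve\n",
"\n",
"# made-up labels and predicted probabilities of the positive class\n",
"y_true_toy = np.array([0, 0, 1, 1])\n",
"scores_toy = np.array([0.1, 0.4, 0.35, 0.8])\n",
"\n",
"# TPR and FPR at a single decision threshold of 0.5\n",
"y_pred_toy = (scores_toy >= 0.5).astype(int)\n",
"tn, fp, fn, tp = confusion_matrix(y_true_toy, y_pred_toy).ravel()\n",
"tpr_at_half = tp / (tp + fn)\n",
"fpr_at_half = fp / (fp + tn)\n",
"print(f\"threshold=0.5: TPR={tpr_at_half:.2f}, FPR={fpr_at_half:.2f}\")\n",
"\n",
"# sweeping the threshold traces out the ROC curve; AUC is the area under it\n",
"fpr, tpr, thresholds = roc_curve(y_true_toy, scores_toy)\n",
"auc_toy = roc_auc_score(y_true_toy, scores_toy)\n",
"print(f\"AUC={auc_toy:.2f}\")"
]
},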
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single tree"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have a baseline model, let's make a single tree so we can gain some insight into how the decisions are being made:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=False, ccp_alpha=0.0, class_weight=None,\n",
" criterion='gini', max_depth=3, max_features='auto',\n",
" max_leaf_nodes=None, max_samples=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=1, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=1, n_jobs=-1,\n",
" oob_score=False, random_state=42, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = RandomForestClassifier(\n",
" n_estimators=1, max_depth=3, bootstrap=False, n_jobs=-1, random_state=42\n",
")\n",
"model.fit(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# get column names\n",
"feature_names = X_train.columns\n",
"# we need to specify the background color because of a bug in sklearn\n",
"fig, ax = plt.subplots(figsize=(30, 10), facecolor=\"k\")\n",
"# generate tree plot\n",
"plot_tree(model.estimators_[0], filled=True, feature_names=feature_names, ax=ax)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main difference here compared to our housing example is that the splitting criterion is no longer the mean squared error, but instead is something known as the **_Gini index_**:\n",
"\n",
"$$ G = 1 - \\sum_{i=1}^n p_i^2 $$\n",
"\n",
"where $p_i$ is the probability of an object being classified to a particular class (in our case \"Yes\" or \"No\"). For classification tasks, the goal is to _minimise_ the Gini index across each split, which amounts to finding which segments are most \"pure\".\n",
"\n",
"From the figure, we can already start seeing some features that might be interesting for predicting churn, e.g. `TotalCharges`, `tenure`, and `TechSupport` seem like good indicators."
]
},
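{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the formula above, the sketch below computes the Gini index from the class proportions in a node using a small hypothetical helper `gini_index` (not a scikit-learn function). A perfectly pure node gives 0, while an evenly mixed two-class node gives the maximum of 0.5:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def gini_index(labels):\n",
"    # G = 1 - sum_i p_i^2, where p_i is the fraction of labels in class i\n",
"    _, counts = np.unique(labels, return_counts=True)\n",
"    p = counts / counts.sum()\n",
"    return 1.0 - np.sum(p ** 2)\n",
"\n",
"\n",
"# toy nodes for illustration only\n",
"pure_node = [\"No\"] * 10\n",
"mixed_node = [\"No\"] * 5 + [\"Yes\"] * 5\n",
"skewed_node = [\"No\"] * 3 + [\"Yes\"] * 1\n",
"\n",
"for name, node in [(\"pure\", pure_node), (\"mixed\", mixed_node), (\"skewed\", skewed_node)]:\n",
"    print(name, round(gini_index(node), 3))"
]
},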
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Hyperparameter tuning"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our baseline model has an accuracy of 0.771 and ROC AUC score of 0.801 on the validation set. Let's now examine whether we can improve this by tuning the:\n",
"\n",
"* number of trees in the forest\n",
"* minimum number of samples per leaf\n",
"* maximum number of features per split\n",
"\n",
"In our previous lessons we manually inspected how the performance evolved when we changed these hyperparameters one at a time. Instead we can automate this process using scikit-learn's `GridSearchCV` to search for the best combination of hyperparameter values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import GridSearchCV\n",
"\n",
"# define range of values for each hyperparameter\n",
"param_grid = [\n",
" {\n",
" \"n_estimators\": [10, 20, 40, 80, 100],\n",
" \"max_features\": [0.5, 1.0, \"sqrt\", \"log2\"],\n",
" \"min_samples_leaf\": [1, 3, 5, 10, 25],\n",
" }\n",
"]\n",
"\n",
"# instantiate baseline model\n",
"model = RandomForestClassifier(n_estimators=10, n_jobs=-1, random_state=42)\n",
"\n",
"# initialise grid search with cross-validation\n",
"grid_search = GridSearchCV(\n",
" model, param_grid=param_grid, cv=3, scoring=\"roc_auc\", n_jobs=-1\n",
")"
]
},
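{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before launching the search, it can be useful to know how many models will be trained. The sketch below uses scikit-learn's `ParameterGrid` to count the combinations in a grid with the same values as above (redefined here as `toy_param_grid` so the cell is self-contained) and multiplies by the number of cross-validation folds. With the default `refit=True`, the final refit on the best parameters adds one more fit:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import ParameterGrid\n",
"\n",
"# same values as the grid defined above, copied so this cell runs on its own\n",
"toy_param_grid = [\n",
"    {\n",
"        \"n_estimators\": [10, 20, 40, 80, 100],\n",
"        \"max_features\": [0.5, 1.0, \"sqrt\", \"log2\"],\n",
"        \"min_samples_leaf\": [1, 3, 5, 10, 25],\n",
"    }\n",
"]\n",
"n_folds = 3  # matches cv=3 in the grid search\n",
"\n",
"# every combination is trained once per fold\n",
"n_combinations = len(ParameterGrid(toy_param_grid))\n",
"n_fits = n_combinations * n_folds\n",
"print(f\"{n_combinations} combinations x {n_folds} folds = {n_fits} fits\")"
]
},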
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.15 s, sys: 285 ms, total: 2.44 s\n",
"Wall time: 17.3 s\n"
]
},
{
"data": {
"text/plain": [
"GridSearchCV(cv=3, error_score=nan,\n",
" estimator=RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,\n",
" class_weight=None,\n",
" criterion='gini', max_depth=None,\n",
" max_features='auto',\n",
" max_leaf_nodes=None,\n",
" max_samples=None,\n",
" min_impurity_decrease=0.0,\n",
" min_impurity_split=None,\n",
" min_samples_leaf=1,\n",
" min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0,\n",
" n_estimators=10, n_jobs=-1,\n",
" oob_score=False, random_state=42,\n",
" verbose=0, warm_start=False),\n",
" iid='deprecated', n_jobs=-1,\n",
" param_grid=[{'max_features': [0.5, 1.0, 'sqrt', 'log2'],\n",
" 'min_samples_leaf': [1, 3, 5, 10, 25],\n",
" 'n_estimators': [10, 20, 40, 80, 100]}],\n",
" pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n",
" scoring='roc_auc', verbose=0)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time grid_search.fit(X, y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once the search is finished, we can get the best combination of parameters as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'max_features': 'sqrt', 'min_samples_leaf': 25, 'n_estimators': 80}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_params = grid_search.best_params_\n",
"best_params"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similarly we can get the best model in the search:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n",
" criterion='gini', max_depth=None, max_features='sqrt',\n",
" max_leaf_nodes=None, max_samples=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=25, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, n_estimators=80, n_jobs=-1,\n",
" oob_score=False, random_state=42, verbose=0,\n",
" warm_start=False)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_model = grid_search.best_estimator_\n",
"best_model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how this model performs on our validation set in terms of metrics and the confusion matrix:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy on train: 0.819\n",
"ROC AUC on train: 0.876\n",
"Accuracy on valid: 0.828\n",
"ROC AUC on valid: 0.891\n"
]
}
],
"source": [
"print_scores(best_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"
"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plot_confusion_matrix(\n",
" best_model,\n",
" X_valid,\n",
" y_valid,\n",
" display_labels=class_names,\n",
" cmap=plt.cm.Blues,\n",
" normalize=\"true\",\n",
")\n",
"plt.grid(None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In terms of the AUC score, we see about a 10% boost over our baseline model - not bad! Our confusion matrix has also visibly improved."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model interpretability"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in the housing example, we now examine which features were deemed to be important for our Random Forest model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"