{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Risk Models Using Tree-based Models\n", "\n", "Welcome to the second assignment of Course 2!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Outline\n", "\n", "- [1. Import Packages](#1)\n", "- [2. Load the Dataset](#2)\n", "- [3. Explore the Dataset](#3)\n", "- [4. Dealing with Missing Data](#4)\n", " - [Exercise 1](#Ex-1)\n", "- [5. Decision Trees](#5)\n", " - [Exercise 2](#Ex-2)\n", "- [6. Random Forests](#6)\n", " - [Exercise 3](#Ex-3)\n", "- [7. Imputation](#7)\n", "- [8. Error Analysis](#8)\n", " - [Exercise 4](#Ex-4)\n", "- [9. Imputation Approaches](#Ex-9)\n", " - [Exercise 5](#Ex-5)\n", " - [Exercise 6](#Ex-6)\n", "- [10. Comparison](#10)\n", "- [11. Explanations: SHAP](#)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1id9x6FmKclN" }, "source": [ "In this assignment, you'll gain experience with tree based models by predicting the 10-year risk of death of individuals from the NHANES I epidemiology dataset (for a detailed description of this dataset you can check the [CDC Website](https://wwwn.cdc.gov/nchs/nhanes/nhefs/default.aspx/)). This is a challenging task and a great test bed for the machine learning methods we learned this week.\n", "\n", "As you go through the assignment, you'll learn about: \n", "\n", "- Dealing with Missing Data\n", " - Complete Case Analysis.\n", " - Imputation\n", "- Decision Trees\n", " - Evaluation.\n", " - Regularization.\n", "- Random Forests \n", " - Hyperparameter Tuning." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "F6k2pItifWeK" }, "source": [ "\n", "## 1. Import Packages\n", "\n", "We'll first import all the common packages that we need for this assignment. 
\n", "\n", "- `shap` is a library that explains predictions made by machine learning models.\n", "- `sklearn` is one of the most popular machine learning libraries.\n", "- `itertools` allows us to conveniently manipulate iterable objects such as lists.\n", "- `pydotplus` is used together with `IPython.display.Image` to visualize graph structures such as decision trees.\n", "- `numpy` is a fundamental package for scientific computing in Python.\n", "- `pandas` is what we'll use to manipulate our data.\n", "- `seaborn` is a plotting library which has some convenient functions for visualizing missing data.\n", "- `matplotlib` is a plotting library." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "V5s0iQ82okBv" }, "outputs": [], "source": [ "import shap\n", "import sklearn\n", "import itertools\n", "import pydotplus\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "from IPython.display import Image \n", "\n", "from sklearn.tree import export_graphviz\n", "from sklearn.externals.six import StringIO\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.experimental import enable_iterative_imputer\n", "from sklearn.impute import IterativeImputer, SimpleImputer\n", "\n", "# We'll also import some helper functions that will be useful later on.\n", "from util import load_data, cindex" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "YckMl2bwg5Hb" }, "source": [ "\n", "## 2. Load the Dataset" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xrqlr_ZQhnr4" }, "source": [ "Run the next cell to load in the NHANES I epidemiology dataset. This dataset contains various features of hospital patients as well as their outcomes, i.e. whether or not they died within 10 years." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 122 }, "colab_type": "code", "id": "iM2qfgvUs9c_", "outputId": "53895f4d-48f8-429f-b447-e175a80472d9" }, "outputs": [], "source": [ "X_dev, X_test, y_dev, y_test = load_data(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset has been split into a development set (or dev set), which we will use to develop our risk models, and a test set, which we will use to test our models.\n", "\n", "We further split the dev set into a training and validation set, respectively to train and tune our models, using a 75/25 split (note that we set a random state to make this split repeatable)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "X_train, X_val, y_train, y_val = train_test_split(X_dev, y_dev, test_size=0.25, random_state=10)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "L6ijpFDIx_I6" }, "source": [ "\n", "## 3. Explore the Dataset\n", "\n", "The first step is to familiarize yourself with the data. Run the next cell to get the size of your training set and look at a small sample. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 275 }, "colab_type": "code", "id": "V4gvn20Gx-pF", "outputId": "b8e98069-70a6-425c-b26e-fc18571c2233" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X_train shape: (5147, 18)\n" ] }, { "data": { "text/html": [ "
\n", " | Age | \n", "Diastolic BP | \n", "Poverty index | \n", "Race | \n", "Red blood cells | \n", "Sedimentation rate | \n", "Serum Albumin | \n", "Serum Cholesterol | \n", "Serum Iron | \n", "Serum Magnesium | \n", "Serum Protein | \n", "Sex | \n", "Systolic BP | \n", "TIBC | \n", "TS | \n", "White blood cells | \n", "BMI | \n", "Pulse pressure | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
1599 | \n", "43.0 | \n", "84.0 | \n", "637.0 | \n", "1.0 | \n", "49.3 | \n", "10.0 | \n", "5.0 | \n", "253.0 | \n", "134.0 | \n", "1.59 | \n", "7.7 | \n", "1.0 | \n", "NaN | \n", "490.0 | \n", "27.3 | \n", "9.1 | \n", "25.803007 | \n", "34.0 | \n", "
2794 | \n", "72.0 | \n", "96.0 | \n", "154.0 | \n", "2.0 | \n", "43.4 | \n", "23.0 | \n", "4.3 | \n", "265.0 | \n", "106.0 | \n", "1.66 | \n", "6.8 | \n", "2.0 | \n", "208.0 | \n", "301.0 | \n", "35.2 | \n", "6.0 | \n", "33.394319 | \n", "112.0 | \n", "
1182 | \n", "54.0 | \n", "78.0 | \n", "205.0 | \n", "1.0 | \n", "43.8 | \n", "12.0 | \n", "4.2 | \n", "206.0 | \n", "180.0 | \n", "1.67 | \n", "6.6 | \n", "2.0 | \n", "NaN | \n", "363.0 | \n", "49.6 | \n", "5.9 | \n", "20.278410 | \n", "34.0 | \n", "
6915 | \n", "59.0 | \n", "90.0 | \n", "417.0 | \n", "1.0 | \n", "43.4 | \n", "9.0 | \n", "4.5 | \n", "327.0 | \n", "114.0 | \n", "1.65 | \n", "7.6 | \n", "2.0 | \n", "NaN | \n", "347.0 | \n", "32.9 | \n", "6.1 | \n", "32.917744 | \n", "78.0 | \n", "
500 | \n", "34.0 | \n", "80.0 | \n", "385.0 | \n", "1.0 | \n", "77.7 | \n", "9.0 | \n", "4.1 | \n", "197.0 | \n", "64.0 | \n", "1.74 | \n", "7.3 | \n", "2.0 | \n", "NaN | \n", "376.0 | \n", "17.0 | \n", "8.2 | \n", "30.743489 | \n", "30.0 | \n", "
\n", "
Hints:
- The `pandas.DataFrame.isnull()` method is helpful in this case.
- Use the `pandas.DataFrame.any()` method and set the `axis` parameter.
- `True` values are equal to 1.
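Put together, these hints yield a one-line solution. A minimal sketch, assuming a helper named `fraction_rows_missing` (the function name and toy data are illustrative, not part of the graded code):

```python
import numpy as np
import pandas as pd

def fraction_rows_missing(df):
    """Return the fraction of rows in df with at least one missing value."""
    # isnull() flags each missing entry, any(axis=1) flags rows containing
    # any missing entry, and since True counts as 1, mean() is the fraction.
    return df.isnull().any(axis=1).mean()

df = pd.DataFrame({"a": [1.0, np.nan, 3.0, 4.0],
                   "b": [np.nan, 2.0, 3.0, 4.0]})
print(fraction_rows_missing(df))  # 0.5
```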
One way to reduce overfitting is to limit the depth of the tree (using the `'max_depth'` hyperparameter).
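A short sketch of the effect of this hyperparameter on synthetic data (the real training split comes from `load_data`; the data and random seeds here are illustrative). An unconstrained `DecisionTreeClassifier` keeps splitting until its leaves are pure, while `max_depth` caps the growth:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.RandomState(10)
X = rng.rand(200, 4)
# Noisy labels: the tree cannot separate them without growing deep.
y = (X[:, 0] + 0.3 * rng.randn(200) > 0.5).astype(int)

# Unregularized tree: grows as deep as needed to fit the training data.
deep = DecisionTreeClassifier(random_state=10).fit(X, y)
# Regularized tree: depth is capped at 3.
shallow = DecisionTreeClassifier(max_depth=3, random_state=10).fit(X, y)

print(deep.get_depth(), shallow.get_depth())
```

The deep tree fits the training labels (noise included) almost perfectly, which typically hurts validation performance; capping the depth trades training fit for generalization.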
\n", "
mask = X_test['BMI'] < 20
. \n", "
\n", "
\n", " | Age | \n", "Diastolic BP | \n", "Poverty index | \n", "Race | \n", "Red blood cells | \n", "Sedimentation rate | \n", "Serum Albumin | \n", "Serum Cholesterol | \n", "Serum Iron | \n", "Serum Magnesium | \n", "Serum Protein | \n", "Sex | \n", "Systolic BP | \n", "TIBC | \n", "TS | \n", "White blood cells | \n", "BMI | \n", "Pulse pressure | \n", "risk | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
5493 | \n", "67.0 | \n", "80.0 | \n", "30.0 | \n", "1.0 | \n", "77.7 | \n", "59.0 | \n", "3.4 | \n", "231.0 | \n", "36.0 | \n", "1.40 | \n", "6.3 | \n", "1.0 | \n", "170.0 | \n", "202.0 | \n", "17.8 | \n", "8.4 | \n", "17.029470 | \n", "90.0 | \n", "0.619022 | \n", "
1017 | \n", "65.0 | \n", "98.0 | \n", "16.0 | \n", "1.0 | \n", "49.4 | \n", "30.0 | \n", "3.4 | \n", "124.0 | \n", "129.0 | \n", "1.59 | \n", "7.7 | \n", "1.0 | \n", "184.0 | \n", "293.0 | \n", "44.0 | \n", "5.9 | \n", "30.858853 | \n", "86.0 | \n", "0.545443 | \n", "
2050 | \n", "66.0 | \n", "100.0 | \n", "69.0 | \n", "2.0 | \n", "42.9 | \n", "47.0 | \n", "3.8 | \n", "233.0 | \n", "170.0 | \n", "1.42 | \n", "8.6 | \n", "1.0 | \n", "180.0 | \n", "411.0 | \n", "41.4 | \n", "7.2 | \n", "22.129498 | \n", "80.0 | \n", "0.527768 | \n", "
6337 | \n", "69.0 | \n", "80.0 | \n", "233.0 | \n", "1.0 | \n", "77.7 | \n", "48.0 | \n", "4.2 | \n", "159.0 | \n", "87.0 | \n", "1.81 | \n", "6.9 | \n", "1.0 | \n", "146.0 | \n", "291.0 | \n", "29.9 | \n", "15.2 | \n", "17.931276 | \n", "66.0 | \n", "0.526019 | \n", "
2608 | \n", "71.0 | \n", "80.0 | \n", "104.0 | \n", "1.0 | \n", "43.8 | \n", "23.0 | \n", "4.0 | \n", "201.0 | \n", "119.0 | \n", "1.60 | \n", "7.0 | \n", "1.0 | \n", "166.0 | \n", "311.0 | \n", "38.3 | \n", "6.3 | \n", "17.760766 | \n", "86.0 | \n", "0.525624 | \n", "
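Risk scores like the `risk` column above are the model's predicted probability of the positive class (death within 10 years). A sketch of producing and ranking such scores with a `RandomForestClassifier` on synthetic features (the feature names, data, and hyperparameters here are illustrative; the assignment's model is trained on the real NHANES features):

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(10)
X = pd.DataFrame({"Age": rng.uniform(30, 80, 300),
                  "Systolic BP": rng.uniform(100, 210, 300)})
# Synthetic outcome loosely tied to age and blood pressure.
y = ((X["Age"] + 0.2 * X["Systolic BP"] + 5 * rng.randn(300)) > 95).astype(int)

rf = RandomForestClassifier(n_estimators=100, random_state=10).fit(X, y)

# Column 1 of predict_proba is the positive-class probability, i.e. the risk.
risk = rf.predict_proba(X)[:, 1]
highest_risk = X.assign(risk=risk).sort_values("risk", ascending=False)
print(highest_risk.head())
```

Sorting by `risk` descending, as above, surfaces the highest-risk patients first, which is how a table like the one shown can be produced.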