{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EZqu9a_ChWYv" }, "source": [ "# Cox Proportional Hazards and Random Survival Forests\n", "\n", "Welcome to the final assignment in Course 2! In this assignment you'll develop risk models using survival data and a combination of linear and non-linear techniques. We'll be using a dataset with survival data of patients with Primary Biliary Cirrhosis (PBC). PBC is a progressive disease of the liver caused by a buildup of bile within the liver (cholestasis) that results in damage to the small bile ducts that drain bile from the liver. Our goal will be to understand the effects of different factors on the survival times of the patients. Along the way you'll learn about the following topics: \n", "\n", "- Cox Proportional Hazards\n", " - Data Preprocessing for Cox Models.\n", "- Random Survival Forests\n", " - Permutation Methods for Interpretation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Outline\n", "\n", "- [1. Import Packages](#1)\n", "- [2. Load the Dataset](#2)\n", "- [3. Explore the Dataset](#3)\n", "- [4. Cox Proportional Hazards](#4)\n", " - [Exercise 1](#Ex-1)\n", "- [5. Fitting and Interpreting a Cox Model](#5)\n", "- [6. Hazard Ratio](#6)\n", " - [Exercise 2](#Ex-2)\n", "- [7. Harrell's C-Index](#7)\n", " - [Exercise 3](#Ex-3)\n", "- [8. Random Survival Forests](#8)\n", "- [9. Permutation Method for Interpretation](#9)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "IH0ukiNS3zG-" }, "source": [ "<a name='1'></a>\n", "## 1. Import Packages\n", "\n", "We'll first import all the packages that we need for this assignment. 
\n", "\n", "- `sklearn` is one of the most popular machine learning libraries.\n", "- `numpy` is the fundamental package for scientific computing in Python.\n", "- `pandas` is what we'll use to manipulate our data.\n", "- `matplotlib` is a plotting library.\n", "- `lifelines` is an open-source survival analysis library." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "0JHzRJaQi_nU" }, "outputs": [], "source": [ "import sklearn\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "from lifelines import CoxPHFitter\n", "from lifelines.utils import concordance_index as cindex\n", "from sklearn.model_selection import train_test_split\n", "\n", "from util import load_data" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "vZMwq0VfW5TW" }, "source": [ "<a name='2'></a>\n", "## 2. Load the Dataset\n", "\n", "Run the next cell to load the data." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "df = load_data()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a name='3'></a>\n", "## 3. Explore the Dataset\n", "\n", "In the lecture videos, `time` was in months; in this assignment, `time` has been converted into years. Also notice that we have assigned a numeric value to `sex`, where `female = 0` and `male = 1`.\n", "\n", "Next, familiarize yourself with the data and the shape of it. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 241 }, "colab_type": "code", "id": "T1a_aHGmXT_C", "outputId": "1bbcf6d9-f293-49f4-963a-827c8e79813b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(258, 19)\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>time</th>\n", " <th>status</th>\n", " <th>trt</th>\n", " <th>age</th>\n", " <th>sex</th>\n", " <th>ascites</th>\n", " <th>hepato</th>\n", " <th>spiders</th>\n", " <th>edema</th>\n", " <th>bili</th>\n", " <th>chol</th>\n", " <th>albumin</th>\n", " <th>copper</th>\n", " <th>alk.phos</th>\n", " <th>ast</th>\n", " <th>trig</th>\n", " <th>platelet</th>\n", " <th>protime</th>\n", " <th>stage</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>1.095890</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>58.765229</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>14.5</td>\n", " <td>261.0</td>\n", " <td>2.60</td>\n", " <td>156.0</td>\n", " <td>1718.0</td>\n", " <td>137.95</td>\n", " <td>172.0</td>\n", " <td>190.0</td>\n", " <td>12.2</td>\n", " <td>4.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>12.328767</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>56.446270</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>1.1</td>\n", " <td>302.0</td>\n", " <td>4.14</td>\n", " <td>54.0</td>\n", " <td>7394.8</td>\n", " <td>113.52</td>\n", " <td>88.0</td>\n", " <td>221.0</td>\n", " <td>10.6</td>\n", " 
<td>3.0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>2.772603</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>70.072553</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.5</td>\n", " <td>1.4</td>\n", " <td>176.0</td>\n", " <td>3.48</td>\n", " <td>210.0</td>\n", " <td>516.0</td>\n", " <td>96.10</td>\n", " <td>55.0</td>\n", " <td>151.0</td>\n", " <td>12.0</td>\n", " <td>4.0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>5.273973</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>54.740589</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.5</td>\n", " <td>1.8</td>\n", " <td>244.0</td>\n", " <td>2.54</td>\n", " <td>64.0</td>\n", " <td>6121.8</td>\n", " <td>60.63</td>\n", " <td>92.0</td>\n", " <td>183.0</td>\n", " <td>10.3</td>\n", " <td>4.0</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>5.019178</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>55.534565</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>322.0</td>\n", " <td>4.09</td>\n", " <td>52.0</td>\n", " <td>824.0</td>\n", " <td>60.45</td>\n", " <td>213.0</td>\n", " <td>204.0</td>\n", " <td>9.7</td>\n", " <td>3.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " time status trt age sex ascites hepato spiders edema \\\n", "0 1.095890 1.0 0.0 58.765229 0.0 1.0 1.0 1.0 1.0 \n", "1 12.328767 0.0 0.0 56.446270 0.0 0.0 1.0 1.0 0.0 \n", "2 2.772603 1.0 0.0 70.072553 1.0 0.0 0.0 0.0 0.5 \n", "3 5.273973 1.0 0.0 54.740589 0.0 0.0 1.0 1.0 0.5 \n", "6 5.019178 0.0 1.0 55.534565 0.0 0.0 1.0 0.0 0.0 \n", "\n", " bili chol albumin copper alk.phos ast trig platelet protime \\\n", "0 14.5 261.0 2.60 156.0 1718.0 137.95 172.0 190.0 12.2 \n", "1 1.1 302.0 4.14 54.0 7394.8 113.52 88.0 221.0 10.6 \n", "2 1.4 176.0 3.48 210.0 516.0 96.10 55.0 151.0 12.0 \n", "3 1.8 244.0 2.54 64.0 6121.8 60.63 92.0 183.0 
10.3 \n", "6 1.0 322.0 4.09 52.0 824.0 60.45 213.0 204.0 9.7 \n", "\n", " stage \n", "0 4.0 \n", "1 3.0 \n", "2 4.0 \n", "3 4.0 \n", "6 3.0 " ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(df.shape)\n", "\n", "# df.head() only outputs the top few rows\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Zy5BmjCV-Uo2" }, "source": [ "Take a minute to examine particular cases." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 357 }, "colab_type": "code", "id": "01I3ChzL-T-f", "outputId": "68e209dc-7a44-434b-d44c-4a1e817ee6ca" }, "outputs": [ { "data": { "text/plain": [ "time 11.175342\n", "status 1.000000\n", "trt 0.000000\n", "age 44.520192\n", "sex 1.000000\n", "ascites 0.000000\n", "hepato 1.000000\n", "spiders 0.000000\n", "edema 0.000000\n", "bili 2.100000\n", "chol 456.000000\n", "albumin 4.000000\n", "copper 124.000000\n", "alk.phos 5719.000000\n", "ast 221.880000\n", "trig 230.000000\n", "platelet 70.000000\n", "protime 9.900000\n", "stage 2.000000\n", "Name: 23, dtype: float64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "i = 20\n", "df.iloc[i, :]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pYZKl_9Tk2vS" }, "source": [ "Now, split your dataset into training, validation, and test sets using a 60/20/20 split. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "V4HJSZaMk1xG" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of patients: 258\n", "Total number of patients in training set: 154\n", "Total number of patients in validation set: 52\n", "Total number of patients in test set: 52\n" ] } ], "source": [ "np.random.seed(0)\n", "df_dev, df_test = train_test_split(df, test_size = 0.2)\n", "df_train, df_val = train_test_split(df_dev, test_size = 0.25)\n", "\n", "print(\"Total number of patients:\", df.shape[0])\n", "print(\"Total number of patients in training set:\", df_train.shape[0])\n", "print(\"Total number of patients in validation set:\", df_val.shape[0])\n", "print(\"Total number of patients in test set:\", df_test.shape[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before proceeding to modeling, let's normalize the continuous covariates to make sure they're on the same scale. Again, we should normalize the test data using statistics from the train data." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "continuous_columns = ['age', 'bili', 'chol', 'albumin', 'copper', 'alk.phos', 'ast', 'trig', 'platelet', 'protime']\n", "mean = df_train.loc[:, continuous_columns].mean()\n", "std = df_train.loc[:, continuous_columns].std()\n", "df_train.loc[:, continuous_columns] = (df_train.loc[:, continuous_columns] - mean) / std\n", "df_val.loc[:, continuous_columns] = (df_val.loc[:, continuous_columns] - mean) / std\n", "df_test.loc[:, continuous_columns] = (df_test.loc[:, continuous_columns] - mean) / std" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check the summary statistics on our training dataset to make sure it's standardized." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>bili</th>\n", " <th>chol</th>\n", " <th>albumin</th>\n", " <th>copper</th>\n", " <th>alk.phos</th>\n", " <th>ast</th>\n", " <th>trig</th>\n", " <th>platelet</th>\n", " <th>protime</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>count</th>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " <td>1.540000e+02</td>\n", " </tr>\n", " <tr>\n", " <th>mean</th>\n", " <td>9.833404e-16</td>\n", " <td>-3.258577e-16</td>\n", " <td>1.153478e-16</td>\n", " <td>1.153478e-16</td>\n", " <td>5.767392e-18</td>\n", " <td>1.326500e-16</td>\n", " <td>-1.263059e-15</td>\n", " <td>8.074349e-17</td>\n", " <td>2.018587e-17</td>\n", " <td>1.291896e-14</td>\n", " </tr>\n", " <tr>\n", " <th>std</th>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " <td>1.000000e+00</td>\n", " </tr>\n", " <tr>\n", " <th>min</th>\n", " <td>-2.304107e+00</td>\n", " <td>-5.735172e-01</td>\n", " <td>-1.115330e+00</td>\n", " <td>-3.738104e+00</td>\n", " <td>-9.856552e-01</td>\n", " <td>-7.882167e-01</td>\n", " <td>-1.489281e+00</td>\n", " <td>-1.226674e+00</td>\n", " 
<td>-2.058899e+00</td>\n", " <td>-1.735556e+00</td>\n", " </tr>\n", " <tr>\n", " <th>25%</th>\n", " <td>-6.535035e-01</td>\n", " <td>-4.895812e-01</td>\n", " <td>-5.186963e-01</td>\n", " <td>-5.697976e-01</td>\n", " <td>-6.470611e-01</td>\n", " <td>-5.186471e-01</td>\n", " <td>-8.353982e-01</td>\n", " <td>-6.884514e-01</td>\n", " <td>-6.399831e-01</td>\n", " <td>-7.382590e-01</td>\n", " </tr>\n", " <tr>\n", " <th>50%</th>\n", " <td>-6.443852e-03</td>\n", " <td>-3.846612e-01</td>\n", " <td>-2.576693e-01</td>\n", " <td>5.663556e-02</td>\n", " <td>-3.140636e-01</td>\n", " <td>-3.416086e-01</td>\n", " <td>-2.260984e-01</td>\n", " <td>-2.495932e-01</td>\n", " <td>-4.100373e-02</td>\n", " <td>-1.398807e-01</td>\n", " </tr>\n", " <tr>\n", " <th>75%</th>\n", " <td>5.724289e-01</td>\n", " <td>2.977275e-02</td>\n", " <td>1.798617e-01</td>\n", " <td>6.890921e-01</td>\n", " <td>3.435366e-01</td>\n", " <td>-4.620597e-03</td>\n", " <td>6.061159e-01</td>\n", " <td>3.755727e-01</td>\n", " <td>6.617988e-01</td>\n", " <td>3.587680e-01</td>\n", " </tr>\n", " <tr>\n", " <th>max</th>\n", " <td>2.654276e+00</td>\n", " <td>5.239050e+00</td>\n", " <td>6.243146e+00</td>\n", " <td>2.140730e+00</td>\n", " <td>5.495204e+00</td>\n", " <td>4.869263e+00</td>\n", " <td>3.058176e+00</td>\n", " <td>5.165751e+00</td>\n", " <td>3.190823e+00</td>\n", " <td>4.447687e+00</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " age bili chol albumin copper \\\n", "count 1.540000e+02 1.540000e+02 1.540000e+02 1.540000e+02 1.540000e+02 \n", "mean 9.833404e-16 -3.258577e-16 1.153478e-16 1.153478e-16 5.767392e-18 \n", "std 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00 \n", "min -2.304107e+00 -5.735172e-01 -1.115330e+00 -3.738104e+00 -9.856552e-01 \n", "25% -6.535035e-01 -4.895812e-01 -5.186963e-01 -5.697976e-01 -6.470611e-01 \n", "50% -6.443852e-03 -3.846612e-01 -2.576693e-01 5.663556e-02 -3.140636e-01 \n", "75% 5.724289e-01 2.977275e-02 1.798617e-01 6.890921e-01 
3.435366e-01 \n", "max 2.654276e+00 5.239050e+00 6.243146e+00 2.140730e+00 5.495204e+00 \n", "\n", " alk.phos ast trig platelet protime \n", "count 1.540000e+02 1.540000e+02 1.540000e+02 1.540000e+02 1.540000e+02 \n", "mean 1.326500e-16 -1.263059e-15 8.074349e-17 2.018587e-17 1.291896e-14 \n", "std 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00 1.000000e+00 \n", "min -7.882167e-01 -1.489281e+00 -1.226674e+00 -2.058899e+00 -1.735556e+00 \n", "25% -5.186471e-01 -8.353982e-01 -6.884514e-01 -6.399831e-01 -7.382590e-01 \n", "50% -3.416086e-01 -2.260984e-01 -2.495932e-01 -4.100373e-02 -1.398807e-01 \n", "75% -4.620597e-03 6.061159e-01 3.755727e-01 6.617988e-01 3.587680e-01 \n", "max 4.869263e+00 3.058176e+00 5.165751e+00 3.190823e+00 4.447687e+00 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.loc[:, continuous_columns].describe()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "BX3woHz-jit1" }, "source": [ "<a name='4'></a>\n", "## 4. Cox Proportional Hazards\n", "\n", "Our goal is to build a risk score using the survival data that we have. We'll begin by fitting a Cox Proportional Hazards model to your data.\n", "\n", "Recall that the Cox Proportional Hazards model describes the hazard for an individual $i$ at time $t$ as \n", "\n", "$$\n", "\\lambda(t, X_i) = \\lambda_0(t)e^{\\theta^T X_i}\n", "$$\n", "\n", "The $\\lambda_0$ term is a baseline hazard and incorporates the risk over time, and the other term incorporates the risk due to the individual's covariates. After fitting the model, we can rank individuals using the person-dependent risk term $e^{\\theta^T X_i}$. \n", "\n", "Categorical variables cannot be used in a regression model as they are. To use them, we must first convert each one into a series of binary indicator variables.\n", "\n", "Since our data has a mix of categorical (`stage`) and continuous (`age`) variables, before we proceed further we need to do some data engineering. 
To tackle the issue at hand we'll be using the `Dummy Coding` technique. In order to use Cox Proportional Hazards, we will have to turn the categorical data into one-hot features so that we can fit our Cox model. Luckily, Pandas has a built-in function called `get_dummies` that will make it easier for us to implement our function. It turns categorical features into multiple binary features.\n", "\n", "<img src=\"1-hot-encode.png\" style=\"padding-top: 5px;width: 60%;left: 0px;margin-left: 150px;margin-right: 0px;\">\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a name='Ex-1'></a>\n", "### Exercise 1\n", "In the cell below, implement the `to_one_hot(...)` function." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", " <li>Remember to drop the first dummy for each category to avoid convergence issues when fitting the proportional hazards model.</li>\n", " <li> Check out the <a href=\"https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.get_dummies.html\" > get_dummies() </a> documentation. 
</li>\n", " <li>Use <code>dtype=np.float64</code>.</li>\n", "</ul>\n", "</p>" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "VMzvx0xF_C3I" }, "outputs": [], "source": [ "# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n", "def to_one_hot(dataframe, columns):\n", "    '''\n", "    Convert columns in dataframe to one-hot encoding.\n", "    Args:\n", "        dataframe (dataframe): pandas dataframe containing covariates\n", "        columns (list of strings): list of categorical column names to one-hot encode\n", "    Returns:\n", "        one_hot_df (dataframe): dataframe with categorical columns encoded\n", "            as binary variables\n", "    '''\n", "\n", "    ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###\n", "\n", "    one_hot_df = pd.get_dummies(dataframe, columns=columns, drop_first=True, dtype=np.float64)\n", "\n", "    ### END CODE HERE ###\n", "\n", "    return one_hot_df" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rM2tIzvG_ifc" }, "source": [ "Now we'll use the function you coded to transform the training, validation, and test sets." 
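, "\n", "\n", "As a quick sanity check on the number of columns to expect (a sketch that assumes every level of `edema` and `stage` appears in each split), dropping the first dummy turns a feature with $K$ levels into $K - 1$ binary columns. Starting from the 19 original columns, we remove the 2 encoded columns and add $3 - 1$ dummies for `edema` and $4 - 1$ dummies for `stage`:\n", "\n", "$$\n", "19 - 2 + (3 - 1) + (4 - 1) = 22\n", "$$\n", "\n", "If a level were missing from one of the splits, that split would end up with fewer dummy columns, so it is worth comparing the column lists of the three encoded dataframes."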
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "SGZfLeup_fUL" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['time', 'status', 'trt', 'age', 'sex', 'ascites', 'hepato', 'spiders', 'bili', 'chol', 'albumin', 'copper', 'alk.phos', 'ast', 'trig', 'platelet', 'protime', 'edema_0.5', 'edema_1.0', 'stage_2.0', 'stage_3.0', 'stage_4.0']\n", "There are 22 columns\n" ] } ], "source": [ "# List of categorical columns\n", "to_encode = ['edema', 'stage']\n", "\n", "one_hot_train = to_one_hot(df_train, to_encode)\n", "one_hot_val = to_one_hot(df_val, to_encode)\n", "one_hot_test = to_one_hot(df_test, to_encode)\n", "\n", "print(one_hot_val.columns.tolist())\n", "print(f\"There are {len(one_hot_val.columns)} columns\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Expected output\n", "```Python\n", "['time', 'status', 'trt', 'age', 'sex', 'ascites', 'hepato', 'spiders', 'bili', 'chol', 'albumin', 'copper', 'alk.phos', 'ast', 'trig', 'platelet', 'protime', 'edema_0.5', 'edema_1.0', 'stage_2.0', 'stage_3.0', 'stage_4.0']\n", "There are 22 columns\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Look for new features\n", "Now, let's take a peek at one of the transformed data sets. Do you notice any new features?" 
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 241 }, "colab_type": "code", "id": "w8EG8A9gXcpu", "outputId": "384d9ade-2c96-4979-d3b7-da2b8e50f2e0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(154, 22)\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>time</th>\n", " <th>status</th>\n", " <th>trt</th>\n", " <th>age</th>\n", " <th>sex</th>\n", " <th>ascites</th>\n", " <th>hepato</th>\n", " <th>spiders</th>\n", " <th>bili</th>\n", " <th>chol</th>\n", " <th>...</th>\n", " <th>alk.phos</th>\n", " <th>ast</th>\n", " <th>trig</th>\n", " <th>platelet</th>\n", " <th>protime</th>\n", " <th>edema_0.5</th>\n", " <th>edema_1.0</th>\n", " <th>stage_2.0</th>\n", " <th>stage_3.0</th>\n", " <th>stage_4.0</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>279</th>\n", " <td>3.868493</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.414654</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>-0.300725</td>\n", " <td>-0.096081</td>\n", " <td>...</td>\n", " <td>0.167937</td>\n", " <td>0.401418</td>\n", " <td>0.330031</td>\n", " <td>0.219885</td>\n", " <td>-1.137178</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>137</th>\n", " <td>3.553425</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.069681</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.895363</td>\n", " <td>0.406085</td>\n", " <td>...</td>\n", " 
<td>0.101665</td>\n", " <td>0.472367</td>\n", " <td>1.621764</td>\n", " <td>-0.120868</td>\n", " <td>-0.239610</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " <tr>\n", " <th>249</th>\n", " <td>4.846575</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>-0.924494</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>-0.510565</td>\n", " <td>-0.225352</td>\n", " <td>...</td>\n", " <td>0.245463</td>\n", " <td>1.899020</td>\n", " <td>-0.580807</td>\n", " <td>0.422207</td>\n", " <td>0.159309</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>266</th>\n", " <td>0.490411</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>1.938314</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>0.748475</td>\n", " <td>-0.608191</td>\n", " <td>...</td>\n", " <td>-0.650254</td>\n", " <td>-0.288898</td>\n", " <td>-0.481443</td>\n", " <td>-0.727833</td>\n", " <td>1.356065</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>12.328767</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.563645</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>1.0</td>\n", " <td>-0.405645</td>\n", " <td>-0.210436</td>\n", " <td>...</td>\n", " <td>2.173526</td>\n", " <td>-0.144699</td>\n", " <td>-0.531125</td>\n", " <td>-0.450972</td>\n", " <td>-0.139881</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>1.0</td>\n", " <td>0.0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>5 rows × 22 columns</p>\n", "</div>" ], "text/plain": [ " time status trt age sex ascites hepato spiders \\\n", "279 3.868493 0.0 0.0 -0.414654 0.0 0.0 0.0 0.0 \n", "137 3.553425 1.0 0.0 0.069681 1.0 0.0 1.0 0.0 \n", "249 4.846575 0.0 1.0 -0.924494 0.0 0.0 1.0 0.0 
\n", "266 0.490411 1.0 0.0 1.938314 0.0 1.0 1.0 1.0 \n", "1 12.328767 0.0 0.0 0.563645 0.0 0.0 1.0 1.0 \n", "\n", " bili chol ... alk.phos ast trig platelet \\\n", "279 -0.300725 -0.096081 ... 0.167937 0.401418 0.330031 0.219885 \n", "137 0.895363 0.406085 ... 0.101665 0.472367 1.621764 -0.120868 \n", "249 -0.510565 -0.225352 ... 0.245463 1.899020 -0.580807 0.422207 \n", "266 0.748475 -0.608191 ... -0.650254 -0.288898 -0.481443 -0.727833 \n", "1 -0.405645 -0.210436 ... 2.173526 -0.144699 -0.531125 -0.450972 \n", "\n", " protime edema_0.5 edema_1.0 stage_2.0 stage_3.0 stage_4.0 \n", "279 -1.137178 0.0 0.0 0.0 1.0 0.0 \n", "137 -0.239610 0.0 0.0 0.0 1.0 0.0 \n", "249 0.159309 0.0 0.0 0.0 0.0 1.0 \n", "266 1.356065 0.0 1.0 0.0 0.0 1.0 \n", "1 -0.139881 0.0 0.0 0.0 1.0 0.0 \n", "\n", "[5 rows x 22 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(one_hot_train.shape)\n", "one_hot_train.head()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hNxuymLwyjqM" }, "source": [ "<a name='5'></a>\n", "## 5. Fitting and Interpreting a Cox Model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ygiFcUKcAFQk" }, "source": [ "Run the following cell to fit your Cox Proportional Hazards model using the `lifelines` package." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "dDCS7p3xjbXB", "outputId": "41b12f82-8b35-43e1-d2a9-05258ac50b20" }, "outputs": [ { "data": { "text/plain": [ "<lifelines.CoxPHFitter: fitted with 154 total observations, 90 right-censored observations>" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cph = CoxPHFitter()\n", "cph.fit(one_hot_train, duration_col = 'time', event_col = 'status', step_size=0.1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "5MUITR0QANDH" }, "source": [ "You can use `cph.print_summary()` to view the coefficients associated with each covariate as well as confidence intervals. " ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 903 }, "colab_type": "code", "id": "fH5AZs8vjcEv", "outputId": "5429f7d5-5669-431f-a014-cf609c90997f" }, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <tbody>\n", " <tr>\n", " <th>model</th>\n", " <td>lifelines.CoxPHFitter</td>\n", " </tr>\n", " <tr>\n", " <th>duration col</th>\n", " <td>'time'</td>\n", " </tr>\n", " <tr>\n", " <th>event col</th>\n", " <td>'status'</td>\n", " </tr>\n", " <tr>\n", " <th>number of observations</th>\n", " <td>154</td>\n", " </tr>\n", " <tr>\n", " <th>number of events observed</th>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>partial log-likelihood</th>\n", " <td>-230.82</td>\n", " </tr>\n", " <tr>\n", " <th>time fit was run</th>\n", " <td>2020-06-26 05:28:02 UTC</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", 
"</div><table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>coef</th>\n", " <th>exp(coef)</th>\n", " <th>se(coef)</th>\n", " <th>coef lower 95%</th>\n", " <th>coef upper 95%</th>\n", " <th>exp(coef) lower 95%</th>\n", " <th>exp(coef) upper 95%</th>\n", " <th>z</th>\n", " <th>p</th>\n", " <th>-log2(p)</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>trt</th>\n", " <td>-0.22</td>\n", " <td>0.80</td>\n", " <td>0.30</td>\n", " <td>-0.82</td>\n", " <td>0.37</td>\n", " <td>0.44</td>\n", " <td>1.45</td>\n", " <td>-0.73</td>\n", " <td>0.46</td>\n", " <td>1.11</td>\n", " </tr>\n", " <tr>\n", " <th>age</th>\n", " <td>0.23</td>\n", " <td>1.26</td>\n", " <td>0.19</td>\n", " <td>-0.13</td>\n", " <td>0.60</td>\n", " <td>0.88</td>\n", " <td>1.82</td>\n", " <td>1.26</td>\n", " <td>0.21</td>\n", " <td>2.27</td>\n", " </tr>\n", " <tr>\n", " <th>sex</th>\n", " <td>0.34</td>\n", " <td>1.41</td>\n", " <td>0.40</td>\n", " <td>-0.45</td>\n", " <td>1.14</td>\n", " <td>0.64</td>\n", " <td>3.11</td>\n", " <td>0.84</td>\n", " <td>0.40</td>\n", " <td>1.33</td>\n", " </tr>\n", " <tr>\n", " <th>ascites</th>\n", " <td>-0.10</td>\n", " <td>0.91</td>\n", " <td>0.56</td>\n", " <td>-1.20</td>\n", " <td>1.01</td>\n", " <td>0.30</td>\n", " <td>2.75</td>\n", " <td>-0.17</td>\n", " <td>0.86</td>\n", " <td>0.21</td>\n", " </tr>\n", " <tr>\n", " <th>hepato</th>\n", " <td>0.31</td>\n", " <td>1.36</td>\n", " <td>0.38</td>\n", " <td>-0.44</td>\n", " <td>1.06</td>\n", " <td>0.64</td>\n", " <td>2.89</td>\n", " <td>0.81</td>\n", " <td>0.42</td>\n", " <td>1.26</td>\n", " </tr>\n", " <tr>\n", " <th>spiders</th>\n", " <td>-0.18</td>\n", " <td>0.83</td>\n", " <td>0.38</td>\n", " <td>-0.94</td>\n", " <td>0.57</td>\n", " <td>0.39</td>\n", " <td>1.77</td>\n", " <td>-0.47</td>\n", " <td>0.64</td>\n", " <td>0.66</td>\n", " </tr>\n", " <tr>\n", " <th>bili</th>\n", " <td>0.05</td>\n", " <td>1.05</td>\n", " <td>0.18</td>\n", " 
<td>-0.29</td>\n", " <td>0.39</td>\n", " <td>0.75</td>\n", " <td>1.48</td>\n", " <td>0.29</td>\n", " <td>0.77</td>\n", " <td>0.37</td>\n", " </tr>\n", " <tr>\n", " <th>chol</th>\n", " <td>0.19</td>\n", " <td>1.20</td>\n", " <td>0.15</td>\n", " <td>-0.10</td>\n", " <td>0.47</td>\n", " <td>0.91</td>\n", " <td>1.60</td>\n", " <td>1.28</td>\n", " <td>0.20</td>\n", " <td>2.33</td>\n", " </tr>\n", " <tr>\n", " <th>albumin</th>\n", " <td>-0.40</td>\n", " <td>0.67</td>\n", " <td>0.18</td>\n", " <td>-0.75</td>\n", " <td>-0.06</td>\n", " <td>0.47</td>\n", " <td>0.94</td>\n", " <td>-2.28</td>\n", " <td>0.02</td>\n", " <td>5.46</td>\n", " </tr>\n", " <tr>\n", " <th>copper</th>\n", " <td>0.30</td>\n", " <td>1.35</td>\n", " <td>0.16</td>\n", " <td>-0.01</td>\n", " <td>0.61</td>\n", " <td>0.99</td>\n", " <td>1.84</td>\n", " <td>1.91</td>\n", " <td>0.06</td>\n", " <td>4.14</td>\n", " </tr>\n", " <tr>\n", " <th>alk.phos</th>\n", " <td>-0.22</td>\n", " <td>0.80</td>\n", " <td>0.14</td>\n", " <td>-0.49</td>\n", " <td>0.05</td>\n", " <td>0.61</td>\n", " <td>1.05</td>\n", " <td>-1.62</td>\n", " <td>0.11</td>\n", " <td>3.24</td>\n", " </tr>\n", " <tr>\n", " <th>ast</th>\n", " <td>0.21</td>\n", " <td>1.24</td>\n", " <td>0.16</td>\n", " <td>-0.10</td>\n", " <td>0.53</td>\n", " <td>0.91</td>\n", " <td>1.69</td>\n", " <td>1.34</td>\n", " <td>0.18</td>\n", " <td>2.48</td>\n", " </tr>\n", " <tr>\n", " <th>trig</th>\n", " <td>0.20</td>\n", " <td>1.23</td>\n", " <td>0.16</td>\n", " <td>-0.11</td>\n", " <td>0.52</td>\n", " <td>0.89</td>\n", " <td>1.68</td>\n", " <td>1.27</td>\n", " <td>0.21</td>\n", " <td>2.28</td>\n", " </tr>\n", " <tr>\n", " <th>platelet</th>\n", " <td>0.14</td>\n", " <td>1.15</td>\n", " <td>0.15</td>\n", " <td>-0.16</td>\n", " <td>0.43</td>\n", " <td>0.86</td>\n", " <td>1.54</td>\n", " <td>0.92</td>\n", " <td>0.36</td>\n", " <td>1.48</td>\n", " </tr>\n", " <tr>\n", " <th>protime</th>\n", " <td>0.36</td>\n", " <td>1.43</td>\n", " <td>0.17</td>\n", " <td>0.03</td>\n", " 
<td>0.69</td>\n", " <td>1.03</td>\n", " <td>1.99</td>\n", " <td>2.15</td>\n", " <td>0.03</td>\n", " <td>4.97</td>\n", " </tr>\n", " <tr>\n", " <th>edema_0.5</th>\n", " <td>1.24</td>\n", " <td>3.47</td>\n", " <td>0.46</td>\n", " <td>0.35</td>\n", " <td>2.14</td>\n", " <td>1.42</td>\n", " <td>8.50</td>\n", " <td>2.72</td>\n", " <td>0.01</td>\n", " <td>7.28</td>\n", " </tr>\n", " <tr>\n", " <th>edema_1.0</th>\n", " <td>2.02</td>\n", " <td>7.51</td>\n", " <td>0.60</td>\n", " <td>0.84</td>\n", " <td>3.20</td>\n", " <td>2.31</td>\n", " <td>24.43</td>\n", " <td>3.35</td>\n", " <td><0.005</td>\n", " <td>10.28</td>\n", " </tr>\n", " <tr>\n", " <th>stage_2.0</th>\n", " <td>1.21</td>\n", " <td>3.35</td>\n", " <td>1.08</td>\n", " <td>-0.92</td>\n", " <td>3.33</td>\n", " <td>0.40</td>\n", " <td>28.06</td>\n", " <td>1.11</td>\n", " <td>0.27</td>\n", " <td>1.91</td>\n", " </tr>\n", " <tr>\n", " <th>stage_3.0</th>\n", " <td>1.18</td>\n", " <td>3.27</td>\n", " <td>1.09</td>\n", " <td>-0.96</td>\n", " <td>3.33</td>\n", " <td>0.38</td>\n", " <td>27.86</td>\n", " <td>1.08</td>\n", " <td>0.28</td>\n", " <td>1.84</td>\n", " </tr>\n", " <tr>\n", " <th>stage_4.0</th>\n", " <td>1.41</td>\n", " <td>4.10</td>\n", " <td>1.15</td>\n", " <td>-0.85</td>\n", " <td>3.67</td>\n", " <td>0.43</td>\n", " <td>39.43</td>\n", " <td>1.22</td>\n", " <td>0.22</td>\n", " <td>2.18</td>\n", " </tr>\n", " </tbody>\n", "</table><div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <tbody>\n", " <tr>\n", " <th>Concordance</th>\n", " <td>0.83</td>\n", " </tr>\n", " <tr>\n", " <th>Log-likelihood ratio test</th>\n", " <td>97.63 on 20 df, -log2(p)=38.13</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ "<IPython.core.display.HTML 
object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cph.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:**\n", "\n", "- According to the model, was treatment `trt` beneficial? \n", "- What was its associated hazard ratio? \n", " - Note that the hazard ratio is the factor by which the hazard is multiplied when the feature variable increases by one unit." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Check your answer!</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", "<ul>\n", " <li>You should see that the treatment (trt) was beneficial because it has a negative impact on the hazard (the coefficient is negative, and exp(coef) is less than 1).</li>\n", " <li>The associated hazard ratio is ~0.8, because this is the exp(coef) of treatment.</li>\n", "</ul>\n", "</p>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can compare the predicted survival curves for the two values of the treatment variable. Run the next cell to plot survival curves using the `plot_covariate_groups()` function. \n", "- The y-axis is the survival rate\n", "- The x-axis is time" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 282 }, "colab_type": "code", "id": "Uxl0icyBS4Dr", "outputId": "5fa08369-e89e-424f-f9f0-60cf7a1cfbcd" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "cph.plot_covariate_groups('trt', values=[0, 1]);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice how the group without treatment has a lower survival rate at all times (the x-axis is time) compared to the treatment group." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a name='6'></a>\n", "## 6. 
Hazard Ratio\n", "\n", "Recall from the lecture videos that the hazard ratio between two patients is the ratio of their hazards: it tells you how many times more at risk one patient (e.g. a smoker) is than the other (e.g. a non-smoker).\n", "$$\n", "\\frac{\\lambda_{smoker}(t)}{\\lambda_{nonsmoker}(t)} = e^{\\theta (X_{smoker} - X_{nonsmoker})^T}\n", "$$\n", "\n", "Where\n", "\n", "$$\n", "\\lambda_{smoker}(t) = \\lambda_0(t)e^{\\theta X_{smoker}^T}\n", "$$\n", "and\n", "$$\n", "\\lambda_{nonsmoker}(t) = \\lambda_0(t)e^{\\theta X_{nonsmoker}^T} \\\\\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a name='Ex-2'></a>\n", "### Exercise 2\n", "In the cell below, write a function to compute the hazard ratio between two individuals given the Cox model's coefficients." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", " <li>use numpy.dot</li>\n", " <li>use numpy.exp</li>\n", "</ul>\n", "</p>\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": {}, "colab_type": "code", "id": "WbBmxbeDA3k1" }, "outputs": [], "source": [ "# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n", "def hazard_ratio(case_1, case_2, cox_params):\n", " '''\n", " Return the hazard ratio of case_1 : case_2 using\n", " the coefficients of the cox model.\n", " \n", " Args:\n", " case_1 (np.array): (1 x d) array of covariates\n", " case_2 (np.array): (1 x d) array of covariates\n", " cox_params (np.array): (1 x d) array of cox model coefficients\n", " Returns:\n", " hazard_ratio (float): hazard ratio of case_1 : case_2\n", " '''\n", " \n", " ### START CODE HERE (REPLACE INSTANCES OF 'None' with your code) ###\n", " \n", " hr = np.exp(cox_params.dot((case_1 - case_2).T))\n", " \n", " ### END CODE HERE ###\n", " \n", " return hr" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "zbDQUxE6CcA3" }, "source": [ "Now, evaluate it on the following pair 
of individuals: `i = 1` and `j = 5`" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": {}, "colab_type": "code", "id": "7flsvTRXCgqO" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "15.029017732492221\n" ] } ], "source": [ "i = 1\n", "case_1 = one_hot_train.iloc[i, :].drop(['time', 'status'])\n", "\n", "j = 5\n", "case_2 = one_hot_train.iloc[j, :].drop(['time', 'status'])\n", "\n", "print(hazard_ratio(case_1.values, case_2.values, cph.params_.values))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Expected Output:\n", "```CPP\n", "15.029017732492221\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Question:** \n", "\n", "Is `case_1` or `case_2` at greater risk? " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Check your answer!</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", "<ul>\n", " Important! The following answer only applies if you picked i = 1 and j = 5\n", " <li>You should see that `case_1` is at higher risk.</li>\n", " <li>The hazard ratio of case 1 / case 2 is greater than 1, so case 1 had a higher hazard relative to case 2</li>\n", "</ul>\n", "</p>" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inspect different pairs, and see if you can figure out which patient is more at risk." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 374 }, "colab_type": "code", "id": "g2PZ3sGvCs0K", "outputId": "59336868-d421-4645-d88e-76a8a8cffc9f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Case 1\n", "\n", " trt 0.000000\n", "age 0.563645\n", "sex 0.000000\n", "ascites 0.000000\n", "hepato 1.000000\n", "spiders 1.000000\n", "bili -0.405645\n", "chol -0.210436\n", "albumin 1.514297\n", "copper -0.481961\n", "alk.phos 2.173526\n", "ast -0.144699\n", "trig -0.531125\n", "platelet -0.450972\n", "protime -0.139881\n", "edema_0.5 0.000000\n", "edema_1.0 0.000000\n", "stage_2.0 0.000000\n", "stage_3.0 1.000000\n", "stage_4.0 0.000000\n", "Name: 1, dtype: float64 \n", "\n", "Case 2\n", "\n", " trt 0.000000\n", "age 0.463447\n", "sex 0.000000\n", "ascites 0.000000\n", "hepato 1.000000\n", "spiders 0.000000\n", "bili -0.489581\n", "chol -0.309875\n", "albumin -1.232371\n", "copper -0.504348\n", "alk.phos 2.870427\n", "ast -0.936261\n", "trig -0.150229\n", "platelet 3.190823\n", "protime -0.139881\n", "edema_0.5 0.000000\n", "edema_1.0 0.000000\n", "stage_2.0 0.000000\n", "stage_3.0 0.000000\n", "stage_4.0 1.000000\n", "Name: 38, dtype: float64 \n", "\n", "Hazard Ratio: 0.1780450006997129\n" ] } ], "source": [ "i = 4\n", "case_1 = one_hot_train.iloc[i, :].drop(['time', 'status'])\n", "\n", "j = 7\n", "case_2 = one_hot_train.iloc[j, :].drop(['time', 'status'])\n", "\n", "print(\"Case 1\\n\\n\", case_1, \"\\n\")\n", "print(\"Case 2\\n\\n\", case_2, \"\\n\")\n", "print(\"Hazard Ratio:\", hazard_ratio(case_1.values, case_2.values, cph.params_.values))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Check your answer!</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", "<ul>\n", " Important! 
The following answer only applies if you picked i = 4 and j = 7\n", " <li>You should see that `case_2` is at higher risk.</li>\n", " <li>The hazard ratio of case 1 / case 2 is less than 1, so case 2 had a higher hazard relative to case 1</li>\n", "</ul>\n", "</p>" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KUa6r-KOyySp" }, "source": [ "<a name='7'></a>\n", "## 7. Harrell's C-index" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "woQAtSmRXrgr" }, "source": [ "To evaluate how well our model is performing, we will write our own version of the C-index. Similar to the week 1 case, the C-index in the survival context is the probability that, given a randomly selected pair of individuals, the one who died sooner has a higher risk score. \n", "\n", "However, we need to take into account censoring. Imagine a pair of patients, $A$ and $B$. \n", "\n", "#### Scenario 1\n", "- A was censored at time $t_A$ \n", "- B died at $t_B$\n", "- $t_A < t_B$. \n", "\n", "Because of censoring, we can't say whether $A$ or $B$ should have a higher risk score. \n", "\n", "#### Scenario 2\n", "Now imagine that $t_A > t_B$.\n", "\n", "- A was censored at time $t_A$ \n", "- B died at $t_B$\n", "- $t_A > t_B$\n", "\n", "Now we can definitively say that $B$ should have a higher risk score than $A$, since we know for a fact that $A$ lived longer. \n", "\n", "Therefore, when we compute our C-index\n", "- We should only consider pairs where at most one person is censored\n", "- If one person is censored, then their censoring time should be *at or after* the other person's time of death. \n", "\n", "The metric we get if we use this rule is called **Harrell's C-index**.\n", "\n", "Note that in this case, being censored at time $t$ means that the true death time was some time AFTER time $t$ and not at $t$. \n", "- Therefore if $t_A = t_B$ and A was censored:\n", " - Then $A$ actually lived longer than $B$. 
\n", " - This will affect how you deal with ties in the exercise below!\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a name='Ex-3'></a>\n", "### Exercise 3\n", "Fill in the function below to compute Harrell's C-index." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<details> \n", "<summary>\n", " <font size=\"3\" color=\"darkgreen\"><b>Hints</b></font>\n", "</summary>\n", "<p>\n", "<ul>\n", " <li>If you get a division by zero error, consider checking how you count when a pair is permissible (in the case where one patient is censored and the other is not censored).</li>\n", "</ul>\n", "</p>" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n", "def harrell_c(y_true, scores, event):\n", " '''\n", " Compute Harrell's C-index given true event/censoring times,\n", " model output, and event indicators.\n", " \n", " Args:\n", " y_true (array): array of true event times\n", " scores (array): model risk scores\n", " event (array): indicator, 1 if event occurred at that index, 0 for censorship\n", " Returns:\n", " result (float): C-index metric\n", " '''\n", " \n", " n = len(y_true)\n", " assert (len(scores) == n and len(event) == n)\n", " \n", " concordant = 0.0\n", " permissible = 0.0\n", " ties = 0.0\n", " \n", " result = 0.0\n", " \n", " ### START CODE HERE (REPLACE INSTANCES OF 'None' and 'pass' with your code) ###\n", " \n", " # use double for loop to go through cases\n", " for i in range(n):\n", " # set lower bound on j to avoid double counting\n", " for j in range(i+1, n):\n", " \n", " # check if at most one is censored\n", " if event[i] == 1 or event[j] == 1:\n", " \n", " # check if neither are censored\n", " if event[i] == 1 and event[j] == 1:\n", " \n", " permissible += 1.0\n", " \n", " # check if scores are tied\n", " if scores[i] == scores[j]:\n", " ties += 1.0\n", " \n", " # check for concordant\n", " elif y_true[i] < y_true[j] and 
scores[i] > scores[j]:\n", " concordant += 1.0\n", " elif y_true[i] > y_true[j] and scores[i] < scores[j]:\n", " concordant += 1.0\n", " \n", " # check if one is censored\n", " elif event[i] != event[j]:\n", " \n", " # get censored index\n", " censored = j\n", " uncensored = i\n", " \n", " if event[i] == 0:\n", " censored = i\n", " uncensored = j\n", " \n", " # check if permissible\n", " # Note: in this case, we are assuming that censored at a time\n", " # means that you did NOT die at that time. That is, if you\n", " # live until time 30 and have event = 0, then you lived THROUGH\n", " # time 30.\n", " if y_true[uncensored] <= y_true[censored]:\n", " permissible += 1.0\n", " \n", " # check if scores are tied\n", " if scores[uncensored] == scores[censored]:\n", " # update ties \n", " ties += 1.0\n", " \n", " # check if scores are concordant \n", " if scores[uncensored] > scores[censored]:\n", " concordant += 1.0\n", " \n", " # set result to c-index computed from number of concordant pairs,\n", " # number of ties, and number of permissible pairs (REPLACE 0 with your code) \n", " result = (concordant + 0.5*ties) / permissible\n", " \n", " ### END CODE HERE ###\n", " \n", " return result " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can test your function on the following test cases:" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Case 1\n", "Expected: 1.0, Output: 1.0\n", "\n", "Case 2\n", "Expected: 0.0, Output: 0.0\n", "\n", "Case 3\n", "Expected: 1.0, Output: 1.0\n", "\n", "Case 4\n", "Expected: 0.75, Output: 0.75\n", "\n", "Case 5\n", "Expected: 0.583, Output: 0.5833333333333334\n", "\n", "Case 6\n", "Expected: 1.0 , Output:1.0000\n" ] } ], "source": [ "y_true = [30, 12, 84, 9]\n", "\n", "# Case 1\n", "event = [1, 1, 1, 1]\n", "scores = [0.5, 0.9, 0.1, 1.0]\n", "print(\"Case 1\")\n", "print(\"Expected: 1.0, Output: {}\".format(harrell_c(y_true, scores, 
event)))\n", "\n", "# Case 2\n", "scores = [0.9, 0.5, 1.0, 0.1]\n", "print(\"\\nCase 2\")\n", "print(\"Expected: 0.0, Output: {}\".format(harrell_c(y_true, scores, event)))\n", "\n", "# Case 3\n", "event = [1, 0, 1, 1]\n", "scores = [0.5, 0.9, 0.1, 1.0]\n", "print(\"\\nCase 3\")\n", "print(\"Expected: 1.0, Output: {}\".format(harrell_c(y_true, scores, event)))\n", "\n", "# Case 4\n", "y_true = [30, 30, 20, 20]\n", "event = [1, 0, 1, 0]\n", "scores = [10, 5, 15, 20]\n", "print(\"\\nCase 4\")\n", "print(\"Expected: 0.75, Output: {}\".format(harrell_c(y_true, scores, event)))\n", "\n", "# Case 5\n", "y_true = list(reversed([30, 30, 30, 20, 20]))\n", "event = [0, 1, 0, 1, 0]\n", "scores = list(reversed([15, 10, 5, 15, 20]))\n", "print(\"\\nCase 5\")\n", "print(\"Expected: 0.583, Output: {}\".format(harrell_c(y_true, scores, event)))\n", "\n", "# Case 6\n", "y_true = [10,10]\n", "event = [0,1]\n", "scores = [4,5]\n", "print(\"\\nCase 6\")\n", "print(f\"Expected: 1.0 , Output:{harrell_c(y_true, scores, event):.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CtQVe4pAn8ic" }, "source": [ "Now use the Harrell's C-index function to evaluate the cox model on our data sets." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "8nzHc_Qbn7dM", "outputId": "bc2f960d-16e5-46b2-a41f-695892c311c7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 0.8265139116202946\n", "Val: 0.8544776119402985\n", "Test: 0.8478543563068921\n" ] } ], "source": [ "# Train\n", "scores = cph.predict_partial_hazard(one_hot_train)\n", "cox_train_scores = harrell_c(one_hot_train['time'].values, scores.values, one_hot_train['status'].values)\n", "# Validation\n", "scores = cph.predict_partial_hazard(one_hot_val)\n", "cox_val_scores = harrell_c(one_hot_val['time'].values, scores.values, one_hot_val['status'].values)\n", "# Test\n", "scores = cph.predict_partial_hazard(one_hot_test)\n", "cox_test_scores = harrell_c(one_hot_test['time'].values, scores.values, one_hot_test['status'].values)\n", "\n", "print(\"Train:\", cox_train_scores)\n", "print(\"Val:\", cox_val_scores)\n", "print(\"Test:\", cox_test_scores)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What do these values tell us?" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AuNjR_wNkpWz" }, "source": [ "<a name='8'></a>\n", "## 8. Random Survival Forests\n", "\n", "This performed well, but you have a hunch you can squeeze out better performance by using a machine learning approach. You decide to use a Random Survival Forest. To do this, you can use the `randomForestSRC` package in R. To call R functions from Python, we'll use the `rpy2` package. Run the following cell to import the necessary requirements. 
\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 530 }, "colab_type": "code", "id": "ZgSy-Dj6kquK", "outputId": "4aa5d2fa-30f4-4328-ae29-a2ff05223e22" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "R[write to console]: Loading required package: ggplot2\n", "\n" ] } ], "source": [ "%load_ext rpy2.ipython\n", "%R require(ggplot2)\n", "\n", "from rpy2.robjects.packages import importr\n", "# import R's \"base\" package\n", "base = importr('base')\n", "\n", "# import R's \"utils\" package\n", "utils = importr('utils')\n", "\n", "# import rpy2's package module\n", "import rpy2.robjects.packages as rpackages\n", "\n", "forest = rpackages.importr('randomForestSRC', lib_loc='R')\n", "\n", "from rpy2 import robjects as ro\n", "R = ro.r\n", "\n", "from rpy2.robjects import pandas2ri\n", "pandas2ri.activate()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "LXBgqQBfuMA5" }, "source": [ "Instead of encoding our categories as binary features, we can use the original dataframe since trees deal well with raw categorical data (can you think why this might be?).\n", "\n", "Run the code cell below to build your forest." ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": {}, "colab_type": "code", "id": "B-pio4o4mdVJ" }, "outputs": [], "source": [ "model = forest.rfsrc(ro.Formula('Surv(time, status) ~ .'), data=df_train, ntree=300, nodedepth=5, seed=-1)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 289 }, "colab_type": "code", "id": "zZfcUvJ3nL04", "outputId": "27d00bd8-ea33-4c1b-f5d7-ca2c73ead721" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Sample size: 154\n", " Number of deaths: 64\n", " Number of trees: 300\n", " Forest terminal node size: 15\n", " Average no. of terminal nodes: 6.54\n", "No. 
of variables tried at each split: 5\n", " Total no. of variables: 17\n", " Resampling used to grow trees: swor\n", " Resample size used to grow trees: 97\n", " Analysis: RSF\n", " Family: surv\n", " Splitting rule: logrank *random*\n", " Number of random split points: 10\n", " Error rate: 19.07%\n", "\n", "\n" ] } ], "source": [ "print(model)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9Mwzm55H-QKV" }, "source": [ "Finally, let's evaluate on our validation and test sets, and compare it with our Cox model." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "vfl4LbGfpbKp", "outputId": "13f8b560-e171-41e9-f6dc-cf0468a3f786" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cox Model Validation Score: 0.8544776119402985\n", "Survival Forest Validation Score: 0.8296019900497512\n" ] } ], "source": [ "result = R.predict(model, newdata=df_val)\n", "scores = np.array(result.rx('predicted')[0])\n", "\n", "print(\"Cox Model Validation Score:\", cox_val_scores)\n", "print(\"Survival Forest Validation Score:\", harrell_c(df_val['time'].values, scores, df_val['status'].values))" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "uhqSQJhrplSG", "outputId": "752c266e-0234-45c5-d53f-554e2ff17a5a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cox Model Test Score: 0.8478543563068921\n", "Survival Forest Test Score: 0.8621586475942783\n" ] } ], "source": [ "result = R.predict(model, newdata=df_test)\n", "scores = np.array(result.rx('predicted')[0])\n", "\n", "print(\"Cox Model Test Score:\", cox_test_scores)\n", "print(\"Survival Forest Test Score:\", harrell_c(df_test['time'].values, scores, df_test['status'].values))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": 
"text", "id": "Gp_SgUXreAWn" }, "source": [ "Your random forest model should be outperforming the Cox model slightly. Let's dig deeper to see how they differ." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ZtPMPaSli8GB" }, "source": [ "<a name='9'></a>\n", "## 9. Permutation Method for Interpretation\n", "\n", "We'll look more closely at interpretation methods for forests later, but for now just know that random survival forests come with their own built-in variable importance feature. The method is referred to as VIMP, and for the purpose of this section you should just know that a higher absolute value of the VIMP means that the variable generally has a larger effect on the model outcome.\n", "\n", "Run the next cell to compute and plot VIMP for the random survival forest." ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 281 }, "colab_type": "code", "id": "u7M4_N_d-YJu", "outputId": "7e1830cb-4b67-444f-8ba5-d49d3ff2f172" }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "vimps = np.array(forest.vimp(model).rx('importance')[0])\n", "\n", "y = np.arange(len(vimps))\n", "plt.barh(y, np.abs(vimps))\n", "plt.yticks(y, df_train.drop(['time', 'status'], axis=1).columns)\n", "plt.title(\"VIMP (absolute value)\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2YGhK2xwjkiA" }, "source": [ "### Question:\n", "\n", "How does the variable importance compare to that of the Cox model? Which variable is important in both models? Which variable is important in the random survival forest but not in the Cox model? You should see that `edema` is important in both the random survival forest and the Cox model. 
You should also see that `bili` is important in the random survival forest but not the Cox model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Congratulations!\n", "\n", "You've finished the last assignment in Course 2! Take a minute to look back at the analysis you've done over the last four assignments. You've done a great job!" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "C2M4_Assignment.ipynb", "provenance": [] }, "coursera": { "schema_names": [ "AI4MC2-4" ] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 4 }