{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to Survival Analysis with scikit-survival" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**scikit-survival** is a Python module for [survival analysis](https://en.wikipedia.org/wiki/Survival_analysis) built on top of [scikit-learn](http://scikit-learn.org/). It allows doing survival analysis while utilizing the power of scikit-learn, e.g., for pre-processing or doing cross-validation.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Table of Contents\n", "\n", "1. [What is Survival Analysis?](#What-is-Survival-Analysis?)\n", "2. [The Veterans' Administration Lung Cancer Trial](#The-Veterans%27-Administration-Lung-Cancer-Trial)\n", "3. [Survival Data](#Survival-Data)\n", "4. [The Survival Function](#The-Survival-Function)\n", "5. [Considering other variables by stratification](#Considering-other-variables-by-stratification)\n", "6. [Multivariate Survival Models](#Multivariate-Survival-Models)\n", "7. [Measuring the Performance of Survival Models](#Measuring-the-Performance-of-Survival-Models)\n", "8. [Feature Selection: Which Variable is Most Predictive?](#Feature-Selection:-Which-Variable-is-Most-Predictive?)\n", "9. [What's next?](#What%27s-next?)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What is Survival Analysis?\n", "\n", "The objective in survival analysis — also referred to as reliability analysis in engineering — is to establish a connection between covariates and the time of an event. The name *survival analysis* originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. 
It differs from traditional regression in that parts of the training data can only be partially observed: they are *censored*.\n", "\n", "As an example, consider a clinical study that investigates coronary heart disease and has been carried out over a one-year period, as in the figure below.\n", "\n", "![image censoring](https://k-d-w.org/clipboard/censoring.png)\n", "\n", "Patient A was lost to follow-up after three months with no recorded cardiovascular event, patient B experienced an event four and a half months after enrollment, patient D withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the exact time of a cardiovascular event could only be recorded for patients B and C; their records are *uncensored*. For the remaining patients it is unknown whether they did or did not experience an event after termination of the study. The only valid information that is available for patients A, D, and E is that they were event-free up to their last follow-up. Therefore, their records are *censored*.\n", "\n", "Formally, each patient record consists of a set of covariates $x \\in \\mathbb{R}^d$, and the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\\delta \\in \\{0, 1\\}$ and the observable survival time $y>0$. The observable time $y$ of a right-censored sample is defined as\n", "\n", "$$\n", "y = \\min(t, c) = \n", "\\begin{cases} \n", "t & \\text{if } \\delta = 1 , \\\\ \n", "c & \\text{if } \\delta = 0 .\n", "\\end{cases}\n", "$$\n", "\n", "Consequently, survival analysis demands models that take this unique characteristic of such a dataset into account, some of which are showcased below."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Veterans' Administration Lung Cancer Trial\n", "\n", "The Veterans' Administration Lung Cancer Trial is a randomized trial of two treatment regimens for lung cancer. The [data set](http://lib.stat.cmu.edu/datasets/veteran) (Kalbfleisch J. and Prentice R, (1980) The Statistical Analysis of Failure Time Data. New York: Wiley) consists of 137 patients and 8 variables, which are described below:\n", "\n", "- `Treatment`: denotes the type of lung cancer treatment; `standard` and `test` drug.\n", "- `Celltype`: denotes the type of cell involved; `squamous`, `small cell`, `adeno`, `large`.\n", "- `Karnofsky_score`: is the Karnofsky score.\n", "- `Months_from_Diagnosis`: is the time since diagnosis in months.\n", "- `Age_in_years`: is the age in years.\n", "- `Prior_therapy`: denotes any prior therapy; `none` or `yes`.\n", "- `Status`: denotes the status of the patient as dead or alive; `dead` or `alive`.\n", "- `Survival_in_days`: is the survival time in days since the treatment.\n", "\n", "Our primary interest is studying whether there are subgroups that differ in survival and whether we can predict survival times." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Survival Data\n", "\n", "As described in the section *What is Survival Analysis?* above, survival times are subject to right-censoring; therefore, we need to consider an individual's status in addition to survival time. To be fully compatible with scikit-learn, `Status` and `Survival_in_days` need to be stored as a [structured array](https://docs.scipy.org/doc/numpy/user/basics.rec.html) with the first field indicating whether the actual survival time was observed or if it was censored, and the second field denoting the observed survival time, which corresponds to the time of death (if `Status == 'dead'`, $\\delta = 1$) or the last time that person was contacted (if `Status == 'alive'`, $\\delta = 0$)."
] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([( True, 72.), ( True, 411.), ( True, 228.), ( True, 126.),\n", " ( True, 118.), ( True, 10.), ( True, 82.), ( True, 110.),\n", " ( True, 314.), (False, 100.), ( True, 42.), ( True, 8.),\n", " ( True, 144.), (False, 25.), ( True, 11.), ( True, 30.),\n", " ( True, 384.), ( True, 4.), ( True, 54.), ( True, 13.),\n", " (False, 123.), (False, 97.), ( True, 153.), ( True, 59.),\n", " ( True, 117.), ( True, 16.), ( True, 151.), ( True, 22.),\n", " ( True, 56.), ( True, 21.), ( True, 18.), ( True, 139.),\n", " ( True, 20.), ( True, 31.), ( True, 52.), ( True, 287.),\n", " ( True, 18.), ( True, 51.), ( True, 122.), ( True, 27.),\n", " ( True, 54.), ( True, 7.), ( True, 63.), ( True, 392.),\n", " ( True, 10.), ( True, 8.), ( True, 92.), ( True, 35.),\n", " ( True, 117.), ( True, 132.), ( True, 12.), ( True, 162.),\n", " ( True, 3.), ( True, 95.), ( True, 177.), ( True, 162.),\n", " ( True, 216.), ( True, 553.), ( True, 278.), ( True, 12.),\n", " ( True, 260.), ( True, 200.), ( True, 156.), (False, 182.),\n", " ( True, 143.), ( True, 105.), ( True, 103.), ( True, 250.),\n", " ( True, 100.), ( True, 999.), ( True, 112.), (False, 87.),\n", " (False, 231.), ( True, 242.), ( True, 991.), ( True, 111.),\n", " ( True, 1.), ( True, 587.), ( True, 389.), ( True, 33.),\n", " ( True, 25.), ( True, 357.), ( True, 467.), ( True, 201.),\n", " ( True, 1.), ( True, 30.), ( True, 44.), ( True, 283.),\n", " ( True, 15.), ( True, 25.), (False, 103.), ( True, 21.),\n", " ( True, 13.), ( True, 87.), ( True, 2.), ( True, 20.),\n", " ( True, 7.), ( True, 24.), ( True, 99.), ( True, 8.),\n", " ( True, 99.), ( True, 61.), ( True, 25.), ( True, 95.),\n", " ( True, 80.), ( True, 51.), ( True, 29.), ( True, 24.),\n", " ( True, 18.), (False, 83.), ( True, 31.), ( True, 51.),\n", " ( True, 90.), ( True, 52.), ( True, 73.), ( True, 8.),\n", " ( True, 36.), ( True, 48.), ( True, 7.), ( True, 
140.),\n", " ( True, 186.), ( True, 84.), ( True, 19.), ( True, 45.),\n", " ( True, 80.), ( True, 52.), ( True, 164.), ( True, 19.),\n", " ( True, 53.), ( True, 15.), ( True, 43.), ( True, 340.),\n", " ( True, 133.), ( True, 111.), ( True, 231.), ( True, 378.),\n", " ( True, 49.)],\n", " dtype=[('Status', '?'), ('Survival_in_days', '<f8')])" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.datasets import load_veterans_lung_cancer\n", "\n", "data_x, data_y = load_veterans_lung_cancer()\n", "data_y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Survival Function\n", "\n", "> Let $T$ denote a continuous non-negative random variable corresponding to a patient’s survival time. The survival function $S(t)$ returns the probability of survival beyond time $t$ and is defined as\n", "> $$ S(t) = P (T > t). $$\n", "\n", "If we observed the exact survival time of all subjects, i.e., everyone died before the study ended, the survival function at time $t$ can simply be estimated by the ratio of patients surviving beyond time $t$ and the total number of patients:\n", "\n", "$$\n", "\\hat{S}(t) = \\frac{ \\text{number of patients surviving beyond $t$} }{ \\text{total number of patients} }\n", "$$\n", "\n", "In the presence of censoring, this estimator cannot be used, because the numerator is not always defined. For instance, consider the following set of patients:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Status Survival_in_days\n", "1 True 8.0\n", "2 True 10.0\n", "3 True 20.0\n", "4 False 25.0\n", "5 True 59.0" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "pd.DataFrame.from_records(data_y[[11, 5, 32, 13, 23]], index=range(1, 6))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the formula from above, we can compute $\\hat{S}(t=11) = \\frac{3}{5}$, but not $\\hat{S}(t=30)$, because we don't know whether the 4th patient is still alive at $t = 30$; all we know is that when we last checked at $t = 25$, the patient was still alive.\n", "\n", "An estimator, similar to the one above, that *is* valid if survival times are right-censored is the [Kaplan-Meier estimator](https://en.wikipedia.org/wiki/Kaplan%E2%80%93Meier_estimator)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'time $t$')" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from sksurv.nonparametric import kaplan_meier_estimator\n", "\n", "time, survival_prob = kaplan_meier_estimator(data_y[\"Status\"], data_y[\"Survival_in_days\"])\n", "plt.step(time, survival_prob, where=\"post\")\n", "plt.ylabel(\"est. probability of survival $\\hat{S}(t)$\")\n", "plt.xlabel(\"time $t$\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The estimated curve is a step function, with steps occurring at time points where one or more patients died. From the plot we can see that most patients died in the first 200 days, as indicated by the steep slope of the estimated survival function in the first 200 days." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Considering other variables by stratification\n", "\n", "### Survival functions by treatment\n", "\n", "Patients enrolled in the Veterans' Administration Lung Cancer Trial were randomized to one of two treatments: `standard` and a new `test` drug. Next, let's have a look at how many patients underwent the standard treatment and how many received the new drug." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "standard 69\n", "test 68\n", "Name: Treatment, dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data_x[\"Treatment\"].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Roughly half the patients received the alternative treatment.\n", "\n", "The obvious question to ask is:\n", "> *Is there any difference in survival between the two treatment groups?*\n", "\n", "As a first attempt, we can estimate the survival function in both treatment groups separately."
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for treatment_type in (\"standard\", \"test\"):\n", "    mask_treat = data_x[\"Treatment\"] == treatment_type\n", "    time_treatment, survival_prob_treatment = kaplan_meier_estimator(\n", "        data_y[\"Status\"][mask_treat],\n", "        data_y[\"Survival_in_days\"][mask_treat])\n", "\n", "    plt.step(time_treatment, survival_prob_treatment, where=\"post\",\n", "             label=\"Treatment = %s\" % treatment_type)\n", "\n", "plt.ylabel(\"est. probability of survival $\\hat{S}(t)$\")\n", "plt.xlabel(\"time $t$\")\n", "plt.legend(loc=\"best\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Unfortunately, the results are inconclusive, because the difference between the two estimated survival functions is too small to confidently argue whether or not the drug affects survival.\n", "\n", "*Sidenote: Visually comparing estimated survival curves in order to assess whether there is a difference in survival between groups is usually not recommended, because it is highly subjective. Statistical tests such as the [log-rank test](https://en.wikipedia.org/wiki/Log-rank_test) are usually more appropriate.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Survival functions by cell type\n", "\n", "Next, let's have a look at the cell type, which has been recorded as well, and repeat the analysis from above." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for value in data_x[\"Celltype\"].unique():\n", "    mask = data_x[\"Celltype\"] == value\n", "    time_cell, survival_prob_cell = kaplan_meier_estimator(data_y[\"Status\"][mask],\n", "                                                           data_y[\"Survival_in_days\"][mask])\n", "    plt.step(time_cell, survival_prob_cell, where=\"post\",\n", "             label=\"%s (n = %d)\" % (value, mask.sum()))\n", "\n", "plt.ylabel(\"est. probability of survival $\\hat{S}(t)$\")\n", "plt.xlabel(\"time $t$\")\n", "plt.legend(loc=\"best\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, we observe a pronounced difference between two groups. Patients with *squamous* or *large* cells seem to have a better prognosis compared to patients with *small* or *adeno* cells." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multivariate Survival Models\n", "\n", "In the Kaplan-Meier approach used above, we estimated multiple survival curves by dividing the dataset into smaller sub-groups according to a variable. If we want to consider more than 1 or 2 variables, this approach quickly becomes infeasible, because subgroups will get very small. Instead, we can use a linear model, [Cox's proportional hazards model](https://en.wikipedia.org/wiki/Proportional_hazards_model), to estimate the impact each variable has on survival.\n", "\n", "First, however, we need to convert the categorical variables in the data set into numeric values." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age_in_years Celltype=large Celltype=smallcell Celltype=squamous \\\n", "0 69.0 0.0 0.0 1.0 \n", "1 64.0 0.0 0.0 1.0 \n", "2 38.0 0.0 0.0 1.0 \n", "3 63.0 0.0 0.0 1.0 \n", "4 65.0 0.0 0.0 1.0 \n", "\n", " Karnofsky_score Months_from_Diagnosis Prior_therapy=yes Treatment=test \n", "0 60.0 7.0 0.0 0.0 \n", "1 70.0 5.0 1.0 0.0 \n", "2 60.0 3.0 0.0 0.0 \n", "3 60.0 9.0 1.0 0.0 \n", "4 70.0 11.0 1.0 0.0 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.preprocessing import OneHotEncoder\n", "\n", "data_x_numeric = OneHotEncoder().fit_transform(data_x)\n", "data_x_numeric.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Survival models in **scikit-survival** follow the same rules as estimators in scikit-learn, i.e., they have a `fit` method, which expects a data matrix and a structured array of survival times and binary event indicators." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CoxPHSurvivalAnalysis(alpha=0, n_iter=100, tol=1e-09, verbose=0)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.linear_model import CoxPHSurvivalAnalysis\n", "\n", "estimator = CoxPHSurvivalAnalysis()\n", "estimator.fit(data_x_numeric, data_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a vector of coefficients, one for each variable, where each value corresponds to the [log hazard ratio](https://en.wikipedia.org/wiki/Hazard_ratio)."
] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Age_in_years -0.008549\n", "Celltype=large -0.788672\n", "Celltype=smallcell -0.331813\n", "Celltype=squamous -1.188299\n", "Karnofsky_score -0.032622\n", "Months_from_Diagnosis -0.000092\n", "Prior_therapy=yes 0.072327\n", "Treatment=test 0.289936\n", "dtype: float64" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.Series(estimator.coef_, index=data_x_numeric.columns)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the fitted model, we can predict a patient-specific survival function by passing an appropriate data matrix to the estimator's `predict_survival_function` method.\n", "\n", "First, let's create a set of four synthetic patients." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " Age_in_years Celltype=large Celltype=smallcell Celltype=squamous \\\n", "1 65 0 0 1 \n", "2 65 0 0 1 \n", "3 65 0 1 0 \n", "4 65 0 1 0 \n", "\n", " Karnofsky_score Months_from_Diagnosis Prior_therapy=yes Treatment=test \n", "1 60 1 0 1 \n", "2 60 1 0 0 \n", "3 60 1 0 0 \n", "4 60 1 0 1 " ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_new = pd.DataFrame.from_dict({\n", " 1: [65, 0, 0, 1, 60, 1, 0, 1],\n", " 2: [65, 0, 0, 1, 60, 1, 0, 0],\n", " 3: [65, 0, 1, 0, 60, 1, 0, 0],\n", " 4: [65, 0, 1, 0, 60, 1, 0, 1]},\n", " columns=data_x_numeric.columns, orient='index')\n", "x_new" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to `kaplan_meier_estimator`, the `predict_survival_function` method returns a sequence of step functions, which we can plot." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "pred_surv = estimator.predict_survival_function(x_new)\n", "for i, c in enumerate(pred_surv):\n", "    plt.step(c.x, c.y, where=\"post\", label=\"Sample %d\" % (i + 1))\n", "plt.ylabel(\"est. probability of survival $\\hat{S}(t)$\")\n", "plt.xlabel(\"time $t$\")\n", "plt.legend(loc=\"best\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Measuring the Performance of Survival Models\n", "\n", "Once we fit a survival model, we usually want to assess how well a model can actually predict survival. Our test data is usually subject to censoring too; therefore, metrics like root mean squared error or correlation are unsuitable. Instead, we use a generalization of the area under the receiver operating characteristic (ROC) curve called [Harrell's concordance index](https://pdfs.semanticscholar.org/7705/392f1068c76669de750c6d0da8144da3304d.pdf) or c-index.\n", "\n", "The interpretation is identical to the traditional area under the [ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) metric for binary classification:\n", "- a value of 0.5 denotes a random model,\n", "- a value of 1.0 denotes a perfect model,\n", "- a value of 0.0 denotes a perfectly wrong model."
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7362562471603816" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sksurv.metrics import concordance_index_censored\n", "\n", "prediction = estimator.predict(data_x_numeric)\n", "result = concordance_index_censored(data_y[\"Status\"], data_y[\"Survival_in_days\"], prediction)\n", "result[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "or alternatively" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7362562471603816" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "estimator.score(data_x_numeric, data_y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our model's c-index indicates that the model clearly performs better than random, but is also far from perfect." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Selection: Which Variable is Most Predictive?\n", "\n", "The model above considered all available variables for prediction. Next, we want to investigate which single variable is the best risk predictor. Therefore, we fit a Cox model to each variable individually and record the c-index on the training set." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Karnofsky_score 0.709280\n", "Celltype=smallcell 0.572581\n", "Celltype=large 0.561620\n", "Celltype=squamous 0.550545\n", "Treatment=test 0.525386\n", "Age_in_years 0.515107\n", "Months_from_Diagnosis 0.509030\n", "Prior_therapy=yes 0.494434\n", "dtype: float64" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "\n", "def fit_and_score_features(X, y):\n", "    n_features = X.shape[1]\n", "    scores = np.empty(n_features)\n", "    m = CoxPHSurvivalAnalysis()\n", "    for j in range(n_features):\n", "        Xj = X[:, j:j+1]\n", "        m.fit(Xj, y)\n", "        scores[j] = m.score(Xj, y)\n", "    return scores\n", "\n", "scores = fit_and_score_features(data_x_numeric.values, data_y)\n", "pd.Series(scores, index=data_x_numeric.columns).sort_values(ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Karnofsky_score` is the best variable, whereas `Months_from_Diagnosis` and `Prior_therapy=yes` have almost no predictive power on their own.\n", "\n", "Next, we want to build a parsimonious model by excluding irrelevant features. We could use the ranking from above, but would need to determine what the optimal cut-off should be. Luckily, scikit-learn has built-in support for performing grid search.\n", "\n", "First, we create a pipeline that puts all the parts together." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from sklearn.feature_selection import SelectKBest\n", "from sklearn.pipeline import Pipeline\n", "\n", "pipe = Pipeline([('encode', OneHotEncoder()),\n", "                 ('select', SelectKBest(fit_and_score_features, k=3)),\n", "                 ('model', CoxPHSurvivalAnalysis())])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we need to define the range of parameters we want to explore during grid search. 
Here, we want to optimize the parameter `k` of the `SelectKBest` class and allow `k` to vary from 1 feature to all 8 features." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " mean_fit_time std_fit_time mean_score_time std_score_time \\\n", "2 0.126878 0.008825 0.005426 0.000059 \n", "4 0.125702 0.008785 0.005293 0.000101 \n", "3 0.125528 0.008009 0.005279 0.000040 \n", "1 0.129095 0.011951 0.005358 0.000054 \n", "0 0.132264 0.016141 0.005395 0.000110 \n", "5 0.125061 0.008772 0.005272 0.000049 \n", "6 0.125344 0.009078 0.005225 0.000024 \n", "7 0.124911 0.009041 0.005294 0.000054 \n", "\n", " param_select__k params split0_test_score split1_test_score \\\n", "2 3 {'select__k': 3} 0.650628 0.718220 \n", "4 5 {'select__k': 5} 0.644874 0.738347 \n", "3 4 {'select__k': 4} 0.650628 0.719809 \n", "1 2 {'select__k': 2} 0.630753 0.717161 \n", "0 1 {'select__k': 1} 0.630753 0.715042 \n", "5 6 {'select__k': 6} 0.657427 0.669492 \n", "6 7 {'select__k': 7} 0.654812 0.659958 \n", "7 8 {'select__k': 8} 0.656904 0.653602 \n", "\n", " split2_test_score mean_test_score std_test_score rank_test_score \\\n", "2 0.754649 0.707491 0.043067 1 \n", "4 0.728822 0.703834 0.042098 2 \n", "3 0.725723 0.698523 0.034138 3 \n", "1 0.747934 0.698256 0.049604 4 \n", "0 0.737087 0.693982 0.045843 5 \n", "5 0.724690 0.683572 0.029179 6 \n", "6 0.714876 0.676269 0.027083 7 \n", "7 0.716942 0.675516 0.029004 8 \n", "\n", " split0_train_score split1_train_score split2_train_score \\\n", "2 0.765569 0.695647 0.714082 \n", "4 0.783946 0.698207 0.718498 \n", "3 0.768121 0.691037 0.707327 \n", "1 0.758550 0.683611 0.705638 \n", "0 0.744640 0.676697 0.695246 \n", "5 0.783946 0.698848 0.716160 \n", "6 0.788412 0.695519 0.712133 \n", "7 0.786371 0.695006 0.713692 \n", "\n", " mean_train_score std_train_score \n", "2 0.725099 0.029590 \n", "4 0.733551 0.036585 \n", "3 0.722162 0.033172 \n", "1 0.715933 0.031448 \n", "0 0.705527 0.028675 \n", "5 0.732985 0.036722 \n", "6 0.732021 0.040447 \n", "7 0.731690 0.039411 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.model_selection import GridSearchCV\n", 
"\n", "param_grid = {'select__k': np.arange(1, data_x_numeric.shape[1] + 1)}\n", "gcv = GridSearchCV(pipe, param_grid, return_train_score=True, cv=3, iid=True)\n", "gcv.fit(data_x, data_y)\n", "\n", "pd.DataFrame(gcv.cv_results_).sort_values(by='mean_test_score', ascending=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The results show that it is sufficient to select the 3 most predictive features." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Celltype=large -0.067277\n", "Celltype=smallcell 0.271007\n", "Karnofsky_score -0.031285\n", "dtype: float64" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pipe.set_params(**gcv.best_params_)\n", "pipe.fit(data_x, data_y)\n", "\n", "encoder, transformer, final_estimator = [s[1] for s in pipe.steps]\n", "pd.Series(final_estimator.coef_, index=encoder.encoded_columns_[transformer.get_support()])" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## What's next?\n", "\n", "Cox's proportional hazards model is by far the most popular survival model, because once trained, it is easy to interpret. However, if prediction performance is the main objective, more sophisticated, non-linear or ensemble models might lead to better results. Check out\n", "[this notebook](https://nbviewer.jupyter.org/github/sebp/scikit-survival/blob/master/examples/evaluating-survival-models.ipynb) for getting a better understanding of how to evaluate survival models,\n", "and [this notebook](https://nbviewer.jupyter.org/github/sebp/scikit-survival/blob/master/examples/survival-svm.ipynb) to learn more about Kernel Survival Support Vector Machines. The [API reference](https://scikit-survival.readthedocs.io/en/latest/api.html) contains a full list of models that are available within **scikit-survival**. 
In addition, you can use any unsupervised pre-processing method available with scikit-learn; for instance, you could perform dimensionality reduction using [Non-Negative Matrix Factorization (NMF)](http://scikit-learn.org/dev/modules/generated/sklearn.decomposition.NMF.html#sklearn.decomposition.NMF) before training a Cox model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }