{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "azrczYYVvEQb" }, "source": [ "# Survival Analysis for Deep Learning\n", "\n", "Most machine learning algorithms have been developed to perform classification or regression. However, in clinical research we often want to estimate the time to an event, such as death or recurrence of cancer, which leads to a special type of learning task that is distinct from classification and regression. This task is termed *survival analysis*, but is also referred to as time-to-event analysis or reliability analysis.\n", "Many machine learning algorithms have been adapted to perform survival analysis:\n", "[Support Vector Machines](https://scholar.google.com/scholar?oi=bibs&cluster=18092275419152143443),\n", "[Random Forest](https://scholar.google.com/scholar?cluster=16319510831191377024),\n", "or [Boosting](https://scholar.google.com/scholar?cluster=14069073471114367075).\n", "Only recently has survival analysis entered the era of deep learning, which is the focus of this post.\n", "\n", "You will learn how to train a convolutional neural network to predict time to a (generated) event from MNIST images, using a loss function specific to survival analysis. The [first part](#Primer-on-Survival-Analysis) will cover some basic terms and quantities used in survival analysis (feel free to skip this part if you are already familiar). In the [second part](#Generating-Synthetic-Survival-Data-from-MNIST), we will generate synthetic survival data from MNIST images and visualize it. 
In the [third part](#Cox's-Proportional-Hazards-Model), we will briefly revisit the most popular survival model of them all and learn how it can be used as a loss function for training a neural network.\n", "[Finally](#Creating-a-Convolutional-Neural-Network-for-Survival-Analysis-on-MNIST), we put all the pieces together and train a convolutional neural network on MNIST and predict survival functions on the test data.\n", "\n", "\n", "## Requirements:\n", "\n", "Please make sure you have the following packages installed. All are available via [PyPI](https://pypi.org) or [Anaconda](https://www.anaconda.com/distribution/).\n", "\n", "- [numpy](https://www.numpy.org/)\n", "- [matplotlib](https://matplotlib.org/)\n", "- [pandas](https://pandas.pydata.org/)\n", "- [scikit-survival](https://github.com/sebp/scikit-survival/)\n", "- [tensorflow](https://www.tensorflow.org/) >= 2.0.0\n", "\n", "You can also run this notebook in [Google Colaboratory](https://colab.research.google.com/github/sebp/survival-cnn-estimator/blob/master/tutorial_tf2.ipynb) and install scikit-survival using the command below." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "colab_type": "code", "executionInfo": { "elapsed": 12742, "status": "ok", "timestamp": 1589637598598, "user": { "displayName": "Sebastian Pölsterl", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GibzrfdaHThaPgjaoGC9Dfb7YXvuTd-tFLbzoO2Gb6WEwyKUsHIqQpwFQAnUAKIewfdDQm7LzvGMH1MzU0PGgU9JwdQ2_9F-5kiQH_DlB1ZaFKpkST5Oha3_n4379GpkI6TgsLF0WZU_7qikJ61kKM2ytdtJeEz5VVwoz3XdhEPaqbu57hGpX4JZ2aGKRbmVu9JQOU9u8Ym0_w4HOaywrK2s5F1H700i1y89hljff2afH6WLPCP2XSIW2-eK7Mkk1rWCYHvdKt2Q1F2cjNOVoPO3C_LDkAfl1U33HWfwTKRKrlf_fsw5BrBVeV65FDP2xxtFj47t1uNTni3fq9DSzMb30dX4v0k0zjKVI_PtxFOmm0VAhr1NYrNh5PgBfbgxjcCooOJbNg21wsosLvYazfQdbLZfeCNq79hK6ljJblvcDUdu9l8oV5WftCmYipe-pWi5_hd3RSeiJoHg1bRQctViY6KvOx8taENqNS6P3IY1zYVTlNYgews5dtAVR11ei3ofgB5vcBa-bfqgal4ZlJNcsCSwNzUaKMiQ3twG19ESCSnbgJTbLEb6hHeCyhGKoyRwFjCgvEixoU04BnxGH5SEh_qiXf4euMiEaALYK7SrH35KWoZTkW9wXShGv3CmgCdqyOloiG3QsusKVmB9PPCuLjw0A9ixzd3ktRotErkEH2N1_EAdQqti9CK9A3yirLJSyk7Vs6Uem3Jv1Jr21mHsFocw53FciKfwUXm-LydQGUQ9TvgiZepRPHJCypj3l-6Dg=s64", "userId": "18353690321324822306" }, "user_tz": -120 }, "id": "kb7TWFXivEQc", "jupyter": { "outputs_hidden": true }, "outputId": "17fa15cf-e4dd-441e-e667-31e39722f5e8" }, "outputs": [], "source": [ "!pip uninstall --yes --quiet osqp\n", "!pip install scikit-survival" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "executionInfo": { "elapsed": 2942, "status": "ok", "timestamp": 1589637614533, "user": { "displayName": "Sebastian Pölsterl", "photoUrl": 
"https://lh3.googleusercontent.com/a-/AOh14GibzrfdaHThaPgjaoGC9Dfb7YXvuTd-tFLbzoO2Gb6WEwyKUsHIqQpwFQAnUAKIewfdDQm7LzvGMH1MzU0PGgU9JwdQ2_9F-5kiQH_DlB1ZaFKpkST5Oha3_n4379GpkI6TgsLF0WZU_7qikJ61kKM2ytdtJeEz5VVwoz3XdhEPaqbu57hGpX4JZ2aGKRbmVu9JQOU9u8Ym0_w4HOaywrK2s5F1H700i1y89hljff2afH6WLPCP2XSIW2-eK7Mkk1rWCYHvdKt2Q1F2cjNOVoPO3C_LDkAfl1U33HWfwTKRKrlf_fsw5BrBVeV65FDP2xxtFj47t1uNTni3fq9DSzMb30dX4v0k0zjKVI_PtxFOmm0VAhr1NYrNh5PgBfbgxjcCooOJbNg21wsosLvYazfQdbLZfeCNq79hK6ljJblvcDUdu9l8oV5WftCmYipe-pWi5_hd3RSeiJoHg1bRQctViY6KvOx8taENqNS6P3IY1zYVTlNYgews5dtAVR11ei3ofgB5vcBa-bfqgal4ZlJNcsCSwNzUaKMiQ3twG19ESCSnbgJTbLEb6hHeCyhGKoyRwFjCgvEixoU04BnxGH5SEh_qiXf4euMiEaALYK7SrH35KWoZTkW9wXShGv3CmgCdqyOloiG3QsusKVmB9PPCuLjw0A9ixzd3ktRotErkEH2N1_EAdQqti9CK9A3yirLJSyk7Vs6Uem3Jv1Jr21mHsFocw53FciKfwUXm-LydQGUQ9TvgiZepRPHJCypj3l-6Dg=s64", "userId": "18353690321324822306" }, "user_tz": -120 }, "id": "ThvRyUyVvEQf", "jupyter": { "outputs_hidden": true }, "outputId": "872381c5-d695-4bed-a4f8-bcc6d933aa34" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using Tensorflow: 2.2.0\n" ] } ], "source": [ "from typing import Any, Dict, Iterable, Sequence, Tuple, Optional, Union\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "from pathlib import Path\n", "from sksurv.nonparametric import kaplan_meier_estimator\n", "from sksurv.metrics import concordance_index_censored\n", "import tensorflow as tf\n", "from tensorflow.keras.datasets import mnist\n", "\n", "print(\"Using Tensorflow:\", tf.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from distutils.version import LooseVersion\n", "\n", "assert LooseVersion(tf.__version__) >= LooseVersion(\"2.0.0\"), \\\n", " \"This notebook requires TensorFlow 2.0 or above.\"" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "cdL25jjJvEQm" }, "source": [ "## Primer on Survival Analysis\n", "\n", "The objective in survival 
analysis is to establish a connection between covariates and the time of an event. The name *survival analysis* originates from clinical research, where predicting the time to death, i.e., survival, is often the main objective. Survival analysis is a type of regression problem (one wants to predict a continuous value), but with a twist. It differs from traditional regression in that parts of the training data can only be partially observed – they are *censored*.\n", "\n", "As an example, consider a clinical study that has been carried out over a one-year period as in the figure below." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "7UZNLT1GvEQm" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "QqEgCNOxvEQn" }, "source": [ "Patient A was lost to follow-up after three months with no recorded event, patient B experienced an event four and a half months after enrollment, patient C withdrew from the study two months after enrollment, and patient E did not experience any event before the study ended. Consequently, the *exact time* of an event could only be recorded for patients B and D; their records are *uncensored*. For the remaining patients it is unknown whether they did or did not experience an event after termination of the study. The only valid information that is available for patients A, C, and E is that they were event-free up to their last follow-up. Therefore, their records are *censored*.\n", "\n", "Formally, each patient record consists of the time $t>0$ when an event occurred or the time $c>0$ of censoring. Since censoring and experiencing an event are mutually exclusive, it is common to define an event indicator $\\delta \\in \\{0;1\\}$ and the observable survival time $y>0$. 
The observable time $y$ of a right-censored event time is defined as\n", "\n", "$$\n", "y = \\min(t, c) = \n", "\\begin{cases} \n", "t & \\text{if } \\delta = 1 , \\\\\n", "c & \\text{if } \\delta = 0 .\n", "\\end{cases}\n", "$$\n", "\n", "Consequently, survival analysis calls for models that take partially observed, i.e., censored, event times into account.\n", "\n", "\n", "## Basic Quantities\n", "\n", "Typically, the survival time is modelled as a continuous non-negative random variable $T$, from which basic quantities for time-to-event analysis can be derived, most importantly, the *survival function* and the *hazard function*.\n", "\n", "- The **survival function** $S(t)$ returns the probability of survival beyond time $t$ and is defined as $S(t) = P(T > t)$. It is non-increasing with $S(0) = 1$, and $S(\\infty) = 0$.\n", "- The **hazard function** $h(t)$ denotes an approximate probability (it is not bounded from above) that an event occurs in the small time interval $[t; t + \\Delta t[$, under the condition that an individual has remained event-free up to time $t$:\n", "$$\n", "h(t) = \\lim_{\\Delta t \\rightarrow 0} \\frac{P(t \\leq T < t + \\Delta t \\mid T \\geq t)}{\\Delta t} \\geq 0\n", "$$\n", "Alternative names for the hazard function are conditional failure rate, conditional mortality rate, or instantaneous failure rate. In contrast to the survival function, which describes the absence of an event, the hazard function provides information about the occurrence of an event." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Pb3xh-qGvEQn" }, "source": [ "## Generating Synthetic Survival Data from MNIST\n", "\n", "To start off, we use images from the MNIST dataset and synthetically generate\n", "survival times based on the digit each image represents.\n", "We associate a survival time (or risk score) with each class of the ten digits in MNIST. 
First, we randomly assign each class label to one of four overall risk groups, such that some digits will correspond to better and others to worse survival. Next, we generate risk scores that indicate how big the risk of experiencing an event is, relative to each other." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "5Wukwbc5vEQo", "jupyter": { "outputs_hidden": true } }, "outputs": [], "source": [ "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "\n", "x_train = x_train.astype(np.float32) / 255.\n", "x_test = x_test.astype(np.float32) / 255.\n", "\n", "y_train = y_train.astype(np.int32)\n", "y_test = y_test.astype(np.int32)\n", "\n", "y = np.concatenate((y_train, y_test))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 390 }, "colab_type": "code", "executionInfo": { "elapsed": 1008, "status": "ok", "timestamp": 1589637619418, "user": { "displayName": "Sebastian Pölsterl", "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GibzrfdaHThaPgjaoGC9Dfb7YXvuTd-tFLbzoO2Gb6WEwyKUsHIqQpwFQAnUAKIewfdDQm7LzvGMH1MzU0PGgU9JwdQ2_9F-5kiQH_DlB1ZaFKpkST5Oha3_n4379GpkI6TgsLF0WZU_7qikJ61kKM2ytdtJeEz5VVwoz3XdhEPaqbu57hGpX4JZ2aGKRbmVu9JQOU9u8Ym0_w4HOaywrK2s5F1H700i1y89hljff2afH6WLPCP2XSIW2-eK7Mkk1rWCYHvdKt2Q1F2cjNOVoPO3C_LDkAfl1U33HWfwTKRKrlf_fsw5BrBVeV65FDP2xxtFj47t1uNTni3fq9DSzMb30dX4v0k0zjKVI_PtxFOmm0VAhr1NYrNh5PgBfbgxjcCooOJbNg21wsosLvYazfQdbLZfeCNq79hK6ljJblvcDUdu9l8oV5WftCmYipe-pWi5_hd3RSeiJoHg1bRQctViY6KvOx8taENqNS6P3IY1zYVTlNYgews5dtAVR11ei3ofgB5vcBa-bfqgal4ZlJNcsCSwNzUaKMiQ3twG19ESCSnbgJTbLEb6hHeCyhGKoyRwFjCgvEixoU04BnxGH5SEh_qiXf4euMiEaALYK7SrH35KWoZTkW9wXShGv3CmgCdqyOloiG3QsusKVmB9PPCuLjw0A9ixzd3ktRotErkEH2N1_EAdQqti9CK9A3yirLJSyk7Vs6Uem3Jv1Jr21mHsFocw53FciKfwUXm-LydQGUQ9TvgiZepRPHJCypj3l-6Dg=s64", "userId": "18353690321324822306" }, "user_tz": -120 }, "id": "XdgLp1FkvEQq", "jupyter": { "outputs_hidden": true }, "outputId": 
"463ac931-4868-443e-a32f-de3a8de7a283" }, "outputs": [ { "data": { "text/html": [ "
| class_label | risk_score | risk_group |
|---|---|---|
| 0 | 3.071 | 3 |
| 1 | 2.555 | 2 |
| 2 | 0.058 | 0 |
| 3 | 1.790 | 1 |
| 4 | 2.515 | 2 |
| 5 | 3.031 | 3 |
| 6 | 1.750 | 1 |
| 7 | 2.475 | 2 |
| 8 | 0.018 | 0 |
| 9 | 2.435 | 2 |