{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DeepHit with Competing Risks\n", "\n", "In this notebook, we show how to apply the [DeepHit](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit) method to competing-risk data.\n", "\n", "The `pycox` package has (so far) limited support for competing-risk data, so the evaluation procedure at the end is somewhat limited." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.preprocessing.label_transforms import LabTransDiscreteTime\n", "from pycox.models import DeepHit\n", "from pycox.evaluation import EvalSurv" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Fix the random seeds for reproducibility\n", "np.random.seed(1234)\n", "_ = torch.manual_seed(1234)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset from the DeepHit repo\n", "\n", "We download a competing-risk dataset from the DeepHit authors' repository.\n", "The dataset comes from a simulation study with two event types and censored observations.\n", "\n", "We split the data into training, validation, and test sets: 20% of the data is held out for testing, and 20% of the remainder for validation." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "url = 'https://raw.githubusercontent.com/chl8856/DeepHit/master/sample%20data/SYNTHETIC/synthetic_comprisk.csv'\n", "df_train = pd.read_csv(url)\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ " time label true_time true_label feature1 feature2 feature3 feature4 \\\n", "1 1 0 1 1 0.015579 -0.84608 0.48753 0.65193 \n", "2 34 2 34 2 0.446490 1.64100 -1.74500 0.31795 \n", "3 9 0 9 2 0.629460 -0.61575 -0.32345 -0.90020 \n", "5 11 2 11 2 0.487010 0.52086 1.99370 -0.94736 \n", "6 37 0 40 2 -1.183700 -0.31602 -0.58640 -0.53890 \n", "\n", " feature5 feature6 feature7 feature8 feature9 feature10 feature11 \\\n", "1 0.20099 -0.11238 -1.39630 -0.188740 -0.30001 -0.24032 -0.38533 \n", "2 -1.14060 0.36560 0.28110 -0.582530 -1.69070 1.20220 -0.51920 \n", "3 0.45360 -0.61992 2.16240 0.198750 -1.11960 -2.73210 -0.25673 \n", "5 0.24371 1.06550 0.57686 0.019192 0.23212 0.48023 -0.73096 \n", "6 -1.15830 1.04010 0.61938 -0.415420 -0.50700 -2.18300 0.97320 \n", "\n", " feature12 \n", "1 -1.02450 \n", "2 1.78400 \n", "3 -0.81836 \n", "5 1.43960 \n", "6 0.97753 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Feature transforms\n", "\n", "The covariates are standardized, so we don't need to do any preprocessing. We only drop the label columns `time`, `label`, `true_time`, and `true_label`, and convert the features to `float32`." ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "get_x = lambda df: (df\n", "                    .drop(columns=['time', 'label', 'true_time', 'true_label'])\n", "                    .values.astype('float32'))" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "x_train = get_x(df_train)\n", "x_val = get_x(df_val)\n", "x_test = get_x(df_test)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Label transforms\n", "\n", "Currently, the `pycox` package is mainly focused on single-event data, so there is no dedicated label transformer for competing risks.\n", "Instead, we write a simple one ourselves, based on the discrete-time label transform `LabTransDiscreteTime`.\n", "\n", "The class returns the discretized durations (as integer indices) and the event types (as integers), where 0 is reserved for censored observations.\n", "\n", "We discretize the data to `num_durations` time points, which are stored in `labtrans.cuts`." ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class LabTransform(LabTransDiscreteTime):\n", "    def transform(self, durations, events):\n", "        # Discretize the durations; every event type > 0 counts as an event here\n", "        durations, is_event = super().transform(durations, events > 0)\n", "        # Copy, so the caller's array (e.g., a DataFrame column) is not modified in place\n", "        events = events.copy()\n", "        # Observations marked as censored by the discretization get event type 0\n", "        events[is_event == 0] = 0\n", "        return durations, events.astype('int64')" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "num_durations = 10\n", "labtrans = LabTransform(num_durations)\n", "get_target = lambda df: (df['time'].values, df['label'].values)" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "y_train = labtrans.fit_transform(*get_target(df_train))\n", "y_val = labtrans.transform(*get_target(df_val))\n", "durations_test, events_test = get_target(df_test)\n", "val = (x_val, y_val)" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([0, 2, 0, 1, 1, 1]), array([0, 2, 0, 2, 0, 1]))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train[0][:6], y_train[1][:6]" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0. , 21.33333333, 42.66666667, 64. ,\n", " 85.33333333, 106.66666667, 128. , 149.33333333,\n", " 170.66666667, 192. ])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labtrans.cuts" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Network architecture\n", "\n", "Below, we describe two networks that can be used for competing risks.\n", "\n", "The first, `SimpleMLP`, is a regular MLP that outputs a tensor of shape `[batch_size, num_risks, num_durations]`, e.g., `[64, 2, 10]`.\n", "\n", "The second, `CauseSpecificNet`, is similar to the cause-specific network described in the [DeepHit paper](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit).\n", "It has the same output shape as `SimpleMLP` but is a little more complex to build: a shared sub-network feeds into one sub-network per risk, and the risk-specific outputs are stacked along the second dimension." ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "class SimpleMLP(torch.nn.Module):\n", "    \"\"\"Simple network structure for competing risks.\n", "    \"\"\"\n", "    def __init__(self, in_features, num_nodes, num_risks, out_features, batch_norm=True,\n", "                 dropout=None):\n", "        super().__init__()\n", "        self.num_risks = num_risks\n", "        # One MLP with num_risks * out_features outputs\n", "        self.mlp = tt.practical.MLPVanilla(\n", "            in_features, num_nodes, num_risks * out_features,\n", "            batch_norm, dropout,\n", "        )\n", "\n", "    def forward(self, input):\n", "        out = self.mlp(input)\n", "        # Reshape to [batch_size, num_risks, out_features]\n", "        return out.view(out.size(0), self.num_risks, -1)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class CauseSpecificNet(torch.nn.Module):\n", "    \"\"\"Network structure similar to the DeepHit paper, but without the residual\n", "    connections (for simplicity).\n", "    \"\"\"\n", "    def __init__(self, in_features, num_nodes_shared, num_nodes_indiv, num_risks,\n", "                 out_features, batch_norm=True, dropout=None):\n", "        super().__init__()\n", "        # Shared sub-network; its output has num_nodes_shared[-1] features\n", "        self.shared_net = tt.practical.MLPVanilla(\n", "            in_features, num_nodes_shared[:-1], num_nodes_shared[-1],\n", "            batch_norm, dropout,\n", "        )\n", "        # One cause-specific sub-network per risk\n", "        self.risk_nets = torch.nn.ModuleList()\n", "        for _ in range(num_risks):\n", "            net = tt.practical.MLPVanilla(\n", "                num_nodes_shared[-1], num_nodes_indiv, out_features,\n", "                batch_norm, dropout,\n", "            )\n", "            self.risk_nets.append(net)\n", "\n", "    def forward(self, input):\n", "        out = self.shared_net(input)\n", "        out = [net(out) for net in self.risk_nets]\n", "        # Stack to shape [batch_size, num_risks, out_features]\n", "        out = torch.stack(out, dim=1)\n", "        return out" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "in_features = x_train.shape[1]\n", "num_nodes_shared = [64, 64]\n", "num_nodes_indiv = [32]\n", "num_risks = int(y_train[1].max())  # event types are coded 1, ..., num_risks\n", "out_features = len(labtrans.cuts)\n", "batch_norm = True\n", "dropout = 0.1\n", "\n", "# net = SimpleMLP(in_features, num_nodes_shared, num_risks, out_features, batch_norm, dropout)\n", "net = CauseSpecificNet(in_features, num_nodes_shared, num_nodes_indiv, num_risks,\n", "                       out_features, batch_norm, dropout)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training\n", "\n", "We fit the net with the [AdamWR](https://arxiv.org/abs/1711.05101) cyclic optimizer, using an initial learning rate of 0.01, a decoupled weight decay of 0.01, and an initial cycle length of 1 epoch. At each new cycle, the learning rate is multiplied by 0.8 and the cycle length by 2.\n", "\n", "The hyperparameters of the DeepHit loss function, `alpha` and `sigma`, are set to 0.2 and 0.1, respectively. Note that `alpha` here controls the convex combination of the two losses,\n", "$$\\text{loss} = \\alpha \\text{loss}_\\text{NLL} + (1 - \\alpha) \\text{loss}_\\text{rank},$$\n", "and therefore has a different interpretation than in the [DeepHit paper](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit)."
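] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a small, self-contained sketch in plain Python (it does not call `pycox` or `torchtuples`) that restates the two settings above. It lists the epoch range and starting learning rate of the first few AdamWR cycles, assuming each cycle restarts at its peak learning rate. It also shows the weights that `alpha = 0.2` puts on the two loss terms. The helper `combined_loss` and the unit loss values passed to it are hypothetical, for illustration only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical sketch: it only restates the settings above and does not call pycox or torchtuples\n", "lr0 = 0.01  # initial learning rate\n", "eta_multiplier = 0.8  # learning-rate factor at each new cycle\n", "cycle_len0 = 1  # length of the first cycle, in epochs\n", "cycle_multiplier = 2  # cycle-length factor at each new cycle\n", "\n", "start = 1  # epochs are counted from 1 here\n", "for k in range(5):\n", "    cycle_len = cycle_len0 * cycle_multiplier**k\n", "    lr_start = lr0 * eta_multiplier**k\n", "    print(f'cycle {k}: epochs {start}-{start + cycle_len - 1}, starting lr {lr_start:.5f}')\n", "    start += cycle_len\n", "\n", "\n", "def combined_loss(loss_nll, loss_rank, alpha):\n", "    # Convex combination: alpha * NLL loss + (1 - alpha) * ranking loss\n", "    return alpha * loss_nll + (1 - alpha) * loss_rank\n", "\n", "\n", "# Unit losses expose the weights: 0.2 on the NLL term and 0.8 on the ranking term\n", "print(combined_loss(loss_nll=1.0, loss_rank=0.0, alpha=0.2))\n", "print(combined_loss(loss_nll=0.0, loss_rank=1.0, alpha=0.2))"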
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "optimizer = tt.optim.AdamWR(lr=0.01, decoupled_weight_decay=0.01,\n", "                            cycle_eta_multiplier=0.8)\n", "model = DeepHit(net, optimizer, alpha=0.2, sigma=0.1,\n", "                duration_index=labtrans.cuts)" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "epochs = 512  # maximum number of epochs; the callback may stop training earlier\n", "batch_size = 256\n", "callbacks = [tt.callbacks.EarlyStoppingCycle()]\n", "verbose = False  # set to True to print the training progress" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 3min 53s, sys: 19.9 s, total: 4min 13s\n", "Wall time: 2min 2s\n" ] } ], "source": [ "%%time\n", "log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose, val_data=val)" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "_ = log.plot()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "\n", "Currently, there is only limited support for competing risks in `pycox`, so the evaluation here is just illustrative and should not be considered *best practice*.\n", "\n", "The survival function obtained with `predict_surv_df` is the probability of not having experienced any of the events, so it does not distinguish between the event types.\n", "This means that we can evaluate it as in the single-event case, treating every event type as an event (`events_test != 0`)." ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "surv = model.predict_surv_df(x_test)\n", "ev = EvalSurv(surv, durations_test, events_test != 0, censor_surv='km')" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7202672608491353" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.concordance_td()" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.10413137528575585" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_brier_score(np.linspace(0, durations_test.max(), 100))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## The cumulative incidence function\n", "\n", "The cumulative incidence function, or CIF, is commonly used in settings with competing risks.\n", "We can evaluate the cause-specific concordance using these CIFs, treating only events of the given type as events (e.g., `events_test == 1`).\n", "\n", "We use `1 - cif` here because `EvalSurv` expects survival-like curves: the CIF increases with higher risk, while the survival function decreases with higher risk."
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# cif has shape [num_risks, num_durations, num_individuals]\n", "cif = model.predict_cif(x_test)\n", "cif1 = pd.DataFrame(cif[0], model.duration_index)\n", "cif2 = pd.DataFrame(cif[1], model.duration_index)" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "ev1 = EvalSurv(1-cif1, durations_test, events_test == 1, censor_surv='km')\n", "ev2 = EvalSurv(1-cif2, durations_test, events_test == 2, censor_surv='km')" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.7150254240668652" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev1.concordance_td()" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.712722920851343" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev2.concordance_td()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot CIF\n", "\n", "Finally, we plot the cumulative incidence functions of six randomly chosen test individuals." ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# Six distinct test individuals, chosen at random\n", "sample = np.random.choice(len(durations_test), 6, replace=False)\n", "fig, axs = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(10, 5))\n", "for ax, idx in zip(axs.flat, sample):\n", "    # cif.transpose()[idx] has shape [num_durations, num_risks]: one column per event type\n", "    pd.DataFrame(cif.transpose()[idx], index=labtrans.cuts).plot(ax=ax)\n", "    ax.set_ylabel('CIF')\n", "    ax.set_xlabel('Time')\n", "    ax.grid(linestyle='--')" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }