{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DeepHit with Competing Risks\n", "\n", "In this notebook, we show how to apply the [DeepHit](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit) method to data with competing risks.\n", "\n", "The `pycox` package (so far) has limited support for competing-risk data, so the evaluation at the end of the notebook is somewhat limited." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.preprocessing.label_transforms import LabTransDiscreteTime\n", "from pycox.models import DeepHit\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(1234)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset from the DeepHit repo\n", "\n", "We download a competing-risk dataset from the DeepHit authors' repository.\n", "The dataset comes from a simulation study with two event types and censored observations.\n", "\n", "We split the data into training, validation, and test sets: 20% of the rows are held out for testing, and 20% of the remaining rows are held out for validation." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "url = 'https://raw.githubusercontent.com/chl8856/DeepHit/master/sample%20data/SYNTHETIC/synthetic_comprisk.csv'\n", "df_train = pd.read_csv(url)\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 \n", "
|  | time | label | true_time | true_label | feature1 | feature2 | feature3 | feature4 | feature5 | feature6 | feature7 | feature8 | feature9 | feature10 | feature11 | feature12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 1 | 0 | 1 | 1 | 0.015579 | -0.84608 | 0.48753 | 0.65193 | 0.20099 | -0.11238 | -1.39630 | -0.188740 | -0.30001 | -0.24032 | -0.38533 | -1.02450 |
| 2 | 34 | 2 | 34 | 2 | 0.446490 | 1.64100 | -1.74500 | 0.31795 | -1.14060 | 0.36560 | 0.28110 | -0.582530 | -1.69070 | 1.20220 | -0.51920 | 1.78400 |
| 3 | 9 | 0 | 9 | 2 | 0.629460 | -0.61575 | -0.32345 | -0.90020 | 0.45360 | -0.61992 | 2.16240 | 0.198750 | -1.11960 | -2.73210 | -0.25673 | -0.81836 |
| 5 | 11 | 2 | 11 | 2 | 0.487010 | 0.52086 | 1.99370 | -0.94736 | 0.24371 | 1.06550 | 0.57686 | 0.019192 | 0.23212 | 0.48023 | -0.73096 | 1.43960 |
| 6 | 37 | 0 | 40 | 2 | -1.183700 | -0.31602 | -0.58640 | -0.53890 | -1.15830 | 1.04010 | 0.61938 | -0.415420 | -0.50700 | -2.18300 | 0.97320 | 0.97753 |
\n", "