{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# DeepHit\n", "\n", "In this notebook we show an example of how we can fit a [DeepHit](http://medianetlab.ee.ucla.edu/papers/AAAI_2018_DeepHit) model.\n", "\n", "We will: \n", "- use SUPPORT as an example dataset,\n", "- use entity embeddings for categorical variables,\n", "- use the [AdamWR optimizer](https://arxiv.org/pdf/1711.05101.pdf) with cyclical learning rates,\n", "- use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from pycox import datasets\n", "from pycox.models import DeepHitSingle\n", "from pycox.evaluation import EvalSurv\n", "from torchtuples import optim\n", "from torchtuples import callbacks as cb\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper\n", "from pycox.preprocessing.feature_transforms import OrderedCategoricalLong\n", "from pycox.preprocessing.label_transforms import LabTransDiscreteSurv\n", "from torchtuples.practical import MixedInputMLP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`sklearn_pandas` can be installed with `! pip install sklearn-pandas`" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(123456)\n", "_ = torch.manual_seed(123456)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the SUPPORT data set and split it into train, test and validation sets." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Load the full SUPPORT frame, then carve out the hold-out sets by index.\n", "df_full = datasets.support.read_df()\n", "\n", "# 20% of all rows -> test; 20% of the remaining rows -> validation.\n", "df_test = df_full.sample(frac=0.2)\n", "df_rest = df_full.drop(df_test.index)\n", "df_val = df_rest.sample(frac=0.2)\n", "df_train = df_rest.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | x0 | \n", "x1 | \n", "x2 | \n", "x3 | \n", "x4 | \n", "x5 | \n", "x6 | \n", "x7 | \n", "x8 | \n", "x9 | \n", "x10 | \n", "x11 | \n", "x12 | \n", "x13 | \n", "duration | \n", "event | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "82.709961 | \n", "1.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "160.0 | \n", "55.0 | \n", "16.0 | \n", "38.195309 | \n", "142.0 | \n", "19.000000 | \n", "1.099854 | \n", "30.0 | \n", "1 | \n", "
1 | \n", "79.660950 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "54.0 | \n", "67.0 | \n", "16.0 | \n", "38.000000 | \n", "142.0 | \n", "10.000000 | \n", "0.899902 | \n", "1527.0 | \n", "0 | \n", "
4 | \n", "71.794983 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "65.0 | \n", "135.0 | \n", "40.0 | \n", "38.593750 | \n", "146.0 | \n", "0.099991 | \n", "0.399963 | \n", "7.0 | \n", "1 | \n", "
5 | \n", "49.932980 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "70.0 | \n", "105.0 | \n", "33.0 | \n", "38.195309 | \n", "127.0 | \n", "5.299805 | \n", "1.199951 | \n", "50.0 | \n", "1 | \n", "
6 | \n", "62.942989 | \n", "0.0 | \n", "5.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "116.0 | \n", "130.0 | \n", "35.0 | \n", "38.195309 | \n", "133.0 | \n", "14.099609 | \n", "0.799927 | \n", "381.0 | \n", "0 | \n", "