{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cox-Time\n", "\n", "In this notebook we will train the [Cox-Time method](http://jmlr.org/papers/volume20/18-424/18-424.pdf).\n", "We will use the METABRIC data set as an example.\n", "\n", "A more detailed introduction to the `pycox` package can be found in [this notebook](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb) about the `LogisticHazard` method.\n", "\n", "The main benefit Cox-Time (and the other Cox methods) has over Logistic-Hazard is that it is a continuous-time method, meaning we do not need to discretize the time scale." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper\n", "\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.datasets import metabric\n", "from pycox.models import CoxTime\n", "from pycox.models.cox_time import MLPVanillaCoxTime\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Uncomment to install `sklearn-pandas`\n", "# ! pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set and split it into training, validation, and test sets."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n", "<thead>\n", "<tr><th></th><th>x0</th><th>x1</th><th>x2</th><th>x3</th><th>x4</th><th>x5</th><th>x6</th><th>x7</th><th>x8</th><th>duration</th><th>event</th></tr>\n", "</thead>\n", "<tbody>\n",
"<tr><th>0</th><td>5.603834</td><td>7.811392</td><td>10.797988</td><td>5.967607</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>56.840000</td><td>99.333336</td><td>0</td></tr>\n",
"<tr><th>1</th><td>5.284882</td><td>9.581043</td><td>10.204620</td><td>5.664970</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>85.940002</td><td>95.733330</td><td>1</td></tr>\n",
"<tr><th>3</th><td>6.654017</td><td>5.341846</td><td>8.646379</td><td>5.655888</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>66.910004</td><td>239.300003</td><td>0</td></tr>\n",
"<tr><th>4</th><td>5.456747</td><td>5.339741</td><td>10.555724</td><td>6.008429</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>67.849998</td><td>56.933334</td><td>1</td></tr>\n",
"<tr><th>5</th><td>5.425826</td><td>6.331182</td><td>10.455145</td><td>5.749053</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>70.519997</td><td>123.533333</td><td>0</td></tr>\n",
"</tbody>\n", "</table>"
 ], "text/plain": [ " x0 x1 x2 x3 x4 x5 x6 x7 x8 \\\n", "0 5.603834 7.811392 10.797988 5.967607 1.0 1.0 0.0 1.0 56.840000 \n", "1 5.284882 9.581043 10.204620 5.664970 1.0 0.0 0.0 1.0 85.940002 \n", "3 6.654017 5.341846 8.646379 5.655888 0.0 0.0 0.0 0.0 66.910004 \n", "4 5.456747 5.339741 10.555724 6.008429 1.0 0.0 0.0 1.0 67.849998 \n", "5 5.425826 6.331182 10.455145 5.749053 1.0 1.0 0.0 1.0 70.519997 \n", "\n", " duration event \n", "0 99.333336 0 \n", "1 95.733330 1 \n", "3 239.300003 0 \n", "4 56.933334 1 \n", "5 123.533333 0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature transforms\n", "We have 9 covariates, in addition to the durations and event indicators.\n", "\n", "We standardize the 5 numerical covariates and leave the binary variables as is. The variables need to be of type `'float32'`, as this is required by PyTorch." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']\n", "cols_leave = ['x4', 'x5', 'x6', 'x7']\n", "\n", "standardize = [([col], StandardScaler()) for col in cols_standardize]\n", "leave = [(col, None) for col in cols_leave]\n", "\n", "x_mapper = DataFrameMapper(standardize + leave)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x_train = x_mapper.fit_transform(df_train).astype('float32')\n", "x_val = x_mapper.transform(df_val).astype('float32')\n", "x_test = x_mapper.transform(df_test).astype('float32')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The targets (durations and events) also need to be arrays of type `'float32'`, and with `CoxTime.label_transform` we standardize the durations."
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "labtrans = CoxTime.label_transform()\n", "get_target = lambda df: (df['duration'].values, df['event'].values)\n", "y_train = labtrans.fit_transform(*get_target(df_train))\n", "y_val = labtrans.transform(*get_target(df_val))\n", "durations_test, events_test = get_target(df_test)\n", "val = tt.tuplefy(x_val, y_val)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((305, 9), ((305,), (305,)))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val.shapes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `TupleTree` (the result of `tt.tuplefy`) we can easily repeat the validation dataset multiple times. This is useful for reducing the variance of the validation loss, as the validation loss of `CoxTime` is not deterministic." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((610, 9), ((610,), (610,)))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "val.repeat(2).cat().shapes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural net\n", "\n", "We create a simple MLP with two hidden layers, ReLU activations, batch norm and dropout. \n", "The net required by `CoxTime` is slightly different from most of the other methods as it also takes `time` as an additional input argument. \n", "We have therefore created the `MLPVanillaCoxTime` class that is a suitable version of `tt.practical.MLPVanilla`.\n", "This class also removes the options for setting `out_features` and `output_bias` as they should be `1` and `False`, respectively.\n", "\n", "To see the code for the networks call `??MLPVanillaCoxTime`."
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "in_features = x_train.shape[1]\n", "num_nodes = [32, 32]\n", "batch_norm = True\n", "dropout = 0.1\n", "net = MLPVanillaCoxTime(in_features, num_nodes, batch_norm, dropout)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model\n", "\n", "To train the model we need to define an optimizer. You can choose any `torch.optim` optimizer, but here we instead use one from `tt.optim` as it has some added functionality.\n", "We use the `Adam` optimizer, but instead of choosing a learning rate, we will use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate with `model.lr_finder`. See [this post](https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6) for an explanation.\n", "\n", "We also set `labtrans`, which connects the output nodes of the network to the label transform of the durations. This is only useful for prediction and does not affect the training procedure." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "model = CoxTime(net, tt.optim.Adam, labtrans=labtrans)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "batch_size = 256\n", "lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=2)\n", "_ = lrfinder.plot()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.050941380148164093" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lrfinder.get_best_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Often, this learning rate is a little high, so we instead set it manually to 0.01." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "model.optimizer.set_lr(0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "epochs = 512\n", "callbacks = [tt.callbacks.EarlyStopping()]\n", "verbose = True" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\t[0s / 0s],\t\ttrain_loss: 0.6933,\tval_loss: 0.6385\n", "1:\t[0s / 0s],\t\ttrain_loss: 0.6647,\tval_loss: 0.6199\n", "2:\t[0s / 0s],\t\ttrain_loss: 0.6206,\tval_loss: 0.6070\n", "3:\t[0s / 0s],\t\ttrain_loss: 0.6209,\tval_loss: 0.6077\n", "4:\t[0s / 0s],\t\ttrain_loss: 0.6191,\tval_loss: 0.5916\n", "5:\t[0s / 0s],\t\ttrain_loss: 0.5963,\tval_loss: 0.5833\n", "6:\t[0s / 0s],\t\ttrain_loss: 0.5866,\tval_loss: 0.6018\n", "7:\t[0s / 0s],\t\ttrain_loss: 0.5936,\tval_loss: 0.6034\n", "8:\t[0s / 0s],\t\ttrain_loss: 0.5827,\tval_loss: 0.5987\n", "9:\t[0s / 0s],\t\ttrain_loss: 0.5864,\tval_loss: 0.6055\n", "10:\t[0s / 0s],\t\ttrain_loss: 0.5861,\tval_loss: 0.6134\n", "11:\t[0s / 
0s],\t\ttrain_loss: 0.5712,\tval_loss: 0.5782\n", "12:\t[0s / 0s],\t\ttrain_loss: 0.5819,\tval_loss: 0.5991\n", "13:\t[0s / 0s],\t\ttrain_loss: 0.5775,\tval_loss: 0.5765\n", "14:\t[0s / 0s],\t\ttrain_loss: 0.5685,\tval_loss: 0.5881\n", "15:\t[0s / 0s],\t\ttrain_loss: 0.5803,\tval_loss: 0.5782\n", "16:\t[0s / 0s],\t\ttrain_loss: 0.5956,\tval_loss: 0.5880\n", "17:\t[0s / 0s],\t\ttrain_loss: 0.5657,\tval_loss: 0.5825\n", "18:\t[0s / 0s],\t\ttrain_loss: 0.5677,\tval_loss: 0.6120\n", "19:\t[0s / 1s],\t\ttrain_loss: 0.5648,\tval_loss: 0.6027\n", "20:\t[0s / 1s],\t\ttrain_loss: 0.5777,\tval_loss: 0.6050\n", "21:\t[0s / 1s],\t\ttrain_loss: 0.5633,\tval_loss: 0.5860\n", "22:\t[0s / 1s],\t\ttrain_loss: 0.5808,\tval_loss: 0.5941\n", "23:\t[0s / 1s],\t\ttrain_loss: 0.5830,\tval_loss: 0.5917\n", "CPU times: user 2.68 s, sys: 94.4 ms, total: 2.77 s\n", "Wall time: 1.32 s\n" ] } ], "source": [ "%%time\n", "log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,\n", " val_data=val.repeat(10).cat())" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = log.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can compute the mean partial log-likelihood on the validation set." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "-4.855360578086461" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.partial_log_likelihood(*val).mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction\n", "\n", "For evaluation we first need to obtain survival estimates for the test set.\n", "This can be done with `model.predict_surv` which returns an array of survival estimates, or with `model.predict_surv_df` which returns the survival estimates as a dataframe.\n", "\n", "However, as Cox-Time is semi-parametric, we first need to get the non-parametric baseline hazard estimates with `compute_baseline_hazards`. \n", "\n", "Note that for large datasets the `sample` argument can be used to estimate the baseline hazard on a subset." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "_ = model.compute_baseline_hazards()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "surv = model.predict_surv_df(x_test)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "surv.iloc[:, :5].plot()\n", "plt.ylabel('S(t | x)')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that because we set `labtrans` in `CoxTime` we get the correct time scale for our predictions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "\n", "We can use the `EvalSurv` class for evaluating the concordance, Brier score, and binomial log-likelihood. Setting `censor_surv='km'` means that we estimate the censoring distribution by Kaplan-Meier on the test set." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6746906255297508" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.concordance_td()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)\n", "_ = ev.brier_score(time_grid).plot()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15931537174591134" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_brier_score(time_grid)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4700909149743365" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_nbll(time_grid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }