{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PMF Example \n", "\n", "In this notebook, we will present a simple example of the `PMF` method described in [this paper](https://arxiv.org/abs/1910.06724).\n", "\n", "For a more verbose introduction to `pycox` see [this notebook](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# For preprocessing\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper \n", "\n", "import torch # For building the networks \n", "import torchtuples as tt # Some useful functions\n", "\n", "from pycox.datasets import metabric\n", "from pycox.models import PMF\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Uncomment to install `sklearn-pandas`\n", "# ! pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set as a pandas DataFrame and split the data into train, validation, and test sets.\n", "\n", "The `duration` column gives the observed times and the `event` column contains indicators of whether the observation is an event (1) or a censored observation (0)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>x0</th><th>x1</th><th>x2</th><th>x3</th><th>x4</th><th>x5</th><th>x6</th><th>x7</th><th>x8</th><th>duration</th><th>event</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>5.603834</td><td>7.811392</td><td>10.797988</td><td>5.967607</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>56.840000</td><td>99.333336</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>5.284882</td><td>9.581043</td><td>10.204620</td><td>5.664970</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>85.940002</td><td>95.733330</td><td>1</td></tr>\n",
"    <tr><th>3</th><td>6.654017</td><td>5.341846</td><td>8.646379</td><td>5.655888</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>66.910004</td><td>239.300003</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>5.456747</td><td>5.339741</td><td>10.555724</td><td>6.008429</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>67.849998</td><td>56.933334</td><td>1</td></tr>\n",
"    <tr><th>5</th><td>5.425826</td><td>6.331182</td><td>10.455145</td><td>5.749053</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>70.519997</td><td>123.533333</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " x0 x1 x2 x3 x4 x5 x6 x7 x8 \\\n", "0 5.603834 7.811392 10.797988 5.967607 1.0 1.0 0.0 1.0 56.840000 \n", "1 5.284882 9.581043 10.204620 5.664970 1.0 0.0 0.0 1.0 85.940002 \n", "3 6.654017 5.341846 8.646379 5.655888 0.0 0.0 0.0 0.0 66.910004 \n", "4 5.456747 5.339741 10.555724 6.008429 1.0 0.0 0.0 1.0 67.849998 \n", "5 5.425826 6.331182 10.455145 5.749053 1.0 1.0 0.0 1.0 70.519997 \n", "\n", " duration event \n", "0 99.333336 0 \n", "1 95.733330 1 \n", "3 239.300003 0 \n", "4 56.933334 1 \n", "5 123.533333 0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature transforms\n", "\n", "The METABRIC dataset has 9 covariates: `x0, ..., x8`.\n", "We will standardize the 5 numerical covariates, and leave the binary covariates as is.\n", "Note that PyTorch requires variables of type `'float32'`.\n", "\n", "We like using the `sklearn_pandas.DataFrameMapper` to make feature mappers." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']\n", "cols_leave = ['x4', 'x5', 'x6', 'x7']\n", "\n", "standardize = [([col], StandardScaler()) for col in cols_standardize]\n", "leave = [(col, None) for col in cols_leave]\n", "\n", "x_mapper = DataFrameMapper(standardize + leave)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x_train = x_mapper.fit_transform(df_train).astype('float32')\n", "x_val = x_mapper.transform(df_val).astype('float32')\n", "x_test = x_mapper.transform(df_test).astype('float32')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Label transforms\n", "\n", "The survival methods require individual label transforms, so we have included a proposed `label_transform` for each method.\n", "In this case `label_transform` is just a shorthand for the class `pycox.preprocessing.label_transforms.LabTransDiscreteTime`.\n", "\n", "The `PMF` is a discrete-time method, meaning it requires discretization of the event times to be applied to continuous-time data.\n", "We let `num_durations` define the size of this (equidistant) discretization grid, meaning our network will have `num_durations` output nodes." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "num_durations = 10\n", "labtrans = PMF.label_transform(num_durations)\n", "get_target = lambda df: (df['duration'].values, df['event'].values)\n", "y_train = labtrans.fit_transform(*get_target(df_train))\n", "y_val = labtrans.transform(*get_target(df_val))\n", "\n", "train = (x_train, y_train)\n", "val = (x_val, y_val)\n", "\n", "# We don't need to transform the test labels\n", "durations_test, events_test = get_target(df_test)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "pycox.preprocessing.label_transforms.LabTransDiscreteTime" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(labtrans)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural net\n", "\n", "We make a neural net with `torch`.\n", "For simple network structures, we can use the `MLPVanilla` provided by `torchtuples`.\n", "For building more advanced network architectures, see, for example, [the tutorials by PyTorch](https://pytorch.org/tutorials/).\n", "\n", "The following net is an MLP with two hidden layers (with 32 nodes each), ReLU activations, and `out_features` output nodes.\n", "We also have batch normalization and dropout between the layers." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "in_features = x_train.shape[1]\n", "num_nodes = [32, 32]\n", "out_features = labtrans.out_features\n", "batch_norm = True\n", "dropout = 0.1\n", "\n", "net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you instead want to build this network with `torch` you can uncomment the following code.\n", "It is essentially equivalent to the `MLPVanilla`, but without the `torch.nn.init.kaiming_normal_` weight initialization." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# net = torch.nn.Sequential(\n", "# torch.nn.Linear(in_features, 32),\n", "# torch.nn.ReLU(),\n", "# torch.nn.BatchNorm1d(32),\n", "# torch.nn.Dropout(0.1),\n", " \n", "# torch.nn.Linear(32, 32),\n", "# torch.nn.ReLU(),\n", "# torch.nn.BatchNorm1d(32),\n", "# torch.nn.Dropout(0.1),\n", " \n", "# torch.nn.Linear(32, out_features)\n", "# )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model\n", "\n", "To train the model we need to define an optimizer. You can choose any `torch.optim` optimizer, but here we instead use one from `tt.optim` as it has some added functionality.\n", "We use the `Adam` optimizer, but instead of choosing a learning rate, we will use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate with `model.lr_finder`. See [this post](https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6) for an explanation.\n", "\n", "We also set `duration_index`, which connects the output nodes of the network to the discretization times. This is only useful for prediction and does not affect the training procedure." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "model = PMF(net, tt.optim.Adam, duration_index=labtrans.cuts)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "batch_size = 256\n", "lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=4)\n", "_ = lr_finder.plot()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.04229242874389523" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lr_finder.get_best_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Often, this learning rate is a little high, so we instead set it manually to 0.01." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "model.optimizer.set_lr(0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We include the `EarlyStopping` callback to stop training when the validation loss stops improving. After training, this callback will also load the best performing model in terms of validation loss." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0:\t[0s / 0s],\t\ttrain_loss: 1.6639,\tval_loss: 1.4607\n", "1:\t[0s / 0s],\t\ttrain_loss: 1.5303,\tval_loss: 1.4279\n", "2:\t[0s / 0s],\t\ttrain_loss: 1.4937,\tval_loss: 1.4188\n", "3:\t[0s / 0s],\t\ttrain_loss: 1.4411,\tval_loss: 1.4113\n", "4:\t[0s / 0s],\t\ttrain_loss: 1.4168,\tval_loss: 1.3912\n", "5:\t[0s / 0s],\t\ttrain_loss: 1.3792,\tval_loss: 1.3784\n", "6:\t[0s / 0s],\t\ttrain_loss: 1.3632,\tval_loss: 1.3795\n", "7:\t[0s / 0s],\t\ttrain_loss: 1.3575,\tval_loss: 1.3768\n", "8:\t[0s / 0s],\t\ttrain_loss: 1.3495,\tval_loss: 1.3730\n", "9:\t[0s / 0s],\t\ttrain_loss: 1.3305,\tval_loss: 1.3724\n", "10:\t[0s / 0s],\t\ttrain_loss: 1.3109,\tval_loss: 1.3740\n", "11:\t[0s / 0s],\t\ttrain_loss: 1.3149,\tval_loss: 1.3700\n", "12:\t[0s / 0s],\t\ttrain_loss: 1.3116,\tval_loss: 1.3695\n", "13:\t[0s / 0s],\t\ttrain_loss: 1.2873,\tval_loss: 1.3767\n", 
"14:\t[0s / 0s],\t\ttrain_loss: 1.2961,\tval_loss: 1.3831\n", "15:\t[0s / 0s],\t\ttrain_loss: 1.2934,\tval_loss: 1.3877\n", "16:\t[0s / 0s],\t\ttrain_loss: 1.2987,\tval_loss: 1.3941\n", "17:\t[0s / 0s],\t\ttrain_loss: 1.2882,\tval_loss: 1.3949\n", "18:\t[0s / 0s],\t\ttrain_loss: 1.2891,\tval_loss: 1.3841\n", "19:\t[0s / 0s],\t\ttrain_loss: 1.2712,\tval_loss: 1.3858\n", "20:\t[0s / 0s],\t\ttrain_loss: 1.2582,\tval_loss: 1.3860\n", "21:\t[0s / 0s],\t\ttrain_loss: 1.2768,\tval_loss: 1.3951\n", "22:\t[0s / 0s],\t\ttrain_loss: 1.2572,\tval_loss: 1.3902\n" ] } ], "source": [ "epochs = 100\n", "callbacks = [tt.callbacks.EarlyStopping()]\n", "log = model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = log.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction\n", "\n", "For evaluation we first need to obtain survival estimates for the test set.\n", "This can be done with `model.predict_surv` which returns an array of survival estimates, or with `model.predict_surv_df` which returns the survival estimates as a dataframe." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "surv = model.predict_surv_df(x_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot the survival estimates for the first 5 individuals.\n", "Note that the time scale is correct because we have set `model.duration_index` to be the grid points.\n", "We have, however, only defined the survival estimates at the 10 times in our discretization grid, so the survival estimates are a step function." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "surv.iloc[:, :5].plot(drawstyle='steps-post')\n", "plt.ylabel('S(t | x)')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is, therefore, often beneficial to interpolate the survival estimates; see [this paper](https://arxiv.org/abs/1910.06724) for a discussion.\n", "Linear interpolation (constant density interpolation) can be performed with the `interpolate` method. We also need to choose how many points we want to replace each grid point with. Here we will use 10." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "surv = model.interpolate(10).predict_surv_df(x_test)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "surv.iloc[:, :5].plot(drawstyle='steps-post')\n", "plt.ylabel('S(t | x)')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "\n", "The `EvalSurv` class contains some useful evaluation criteria for time-to-event prediction.\n", "We set `censor_surv = 'km'` to state that we want to use Kaplan-Meier for estimating the censoring distribution.\n" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Concordance\n", "\n", "We start with the event-time concordance by [Antolini et al. 2005](https://onlinelibrary.wiley.com/doi/10.1002/sim.2427)." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6523591504514816" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.concordance_td('antolini')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Brier Score\n", "\n", "We can plot the [IPCW Brier score](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5) for a given set of times.\n", "Here we just use 100 time-points between the min and max duration in the test set.\n", "Note that the score becomes unstable for the highest times. It is therefore common to disregard the rightmost part of the graph." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)\n", "ev.brier_score(time_grid).plot()\n", "plt.ylabel('Brier score')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Negative binomial log-likelihood\n", "\n", "In a similar manner, we can plot the [IPCW negative binomial log-likelihood](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5)." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ev.nbll(time_grid).plot()\n", "plt.ylabel('NBLL')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Integrated scores\n", "\n", "The two time-dependent scores above can be integrated over time to produce a single score [Graf et al. 1999](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5). In practice this is done by numerical integration over a defined `time_grid`." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16816426632487566" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_brier_score(time_grid) " ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4969224495850561" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_nbll(time_grid) " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }