{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 02 Continuing the Introduction\n", "\n", "\n", "For a simpler introduction to `pycox`, see [01_introduction.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb) instead.\n", "\n", "In this notebook we show more of the functionality of the `pycox` package, as well as more of the `torchtuples` package.\n", "For simplicity, we continue with the `LogisticHazard` method.\n", "\n", "In the following we will:\n", "- use SUPPORT as an example dataset,\n", "- use entity embeddings for categorical variables,\n", "- use the [AdamWR optimizer](https://arxiv.org/pdf/1711.05101.pdf) with cyclical learning rates,\n", "- use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper\n", "\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.datasets import support\n", "from pycox.preprocessing.feature_transforms import OrderedCategoricalLong\n", "from pycox.models import LogisticHazard\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(123456)\n", "_ = torch.manual_seed(123456)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the SUPPORT dataset and split it into training, test and validation sets."
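, "\n", "\n", "The split removes 20% of the rows for testing and then 20% of the remaining rows for validation, so roughly 64% / 16% / 20% of the rows end up in the training, validation and test sets. The sketch below replays the same `sample`/`drop` logic on a synthetic stand-in DataFrame. The row count of 8873 is an assumption about the size of SUPPORT (it is consistent with the 5678 training rows reported further down), and the columns are made up.

```python
import numpy as np
import pandas as pd

np.random.seed(123456)
n_rows = 8873  # assumed size of SUPPORT
# synthetic stand-in for support.read_df(); the columns are made up
df_full = pd.DataFrame({'duration': np.random.exponential(300.0, n_rows),
                        'event': np.random.binomial(1, 0.7, n_rows)})

# same two-stage split as in the notebook: test set first, then validation set
df_test = df_full.sample(frac=0.2)
df_train = df_full.drop(df_test.index)
df_val = df_train.sample(frac=0.2)
df_train = df_train.drop(df_val.index)

print(len(df_train), len(df_val), len(df_test))  # -> 5678 1420 1775
```

Since `sample(frac=...)` rounds `frac * n` to an integer, the sizes are only approximately 64/16/20."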
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df_train = support.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " x0 x1 x2 x3 x4 x5 x6 x7 x8 x9 x10 \\\n", "0 82.709961 1.0 2.0 1.0 0.0 0.0 0.0 160.0 55.0 16.0 38.195309 \n", "1 79.660950 1.0 0.0 1.0 0.0 0.0 1.0 54.0 67.0 16.0 38.000000 \n", "4 71.794983 0.0 1.0 1.0 0.0 0.0 0.0 65.0 135.0 40.0 38.593750 \n", "5 49.932980 0.0 1.0 1.0 0.0 0.0 0.0 70.0 105.0 33.0 38.195309 \n", "6 62.942989 0.0 5.0 2.0 1.0 0.0 1.0 116.0 130.0 35.0 38.195309 \n", "\n", " x11 x12 x13 duration event \n", "0 142.0 19.000000 1.099854 30.0 1 \n", "1 142.0 10.000000 0.899902 1527.0 0 \n", "4 146.0 0.099991 0.399963 7.0 1 \n", "5 127.0 5.299805 1.199951 50.0 1 \n", "6 133.0 14.099609 0.799927 381.0 0 " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature transforms\n", "We have 14 covariates, in addition to the durations and event indicators.\n", "\n", "We will standardize the 8 numerical covariates, and leave the 3 binary variables unaltered. \n", "\n", "We will use entity embedding for the 3 categorical variables `x2`, `x3`, and `x6`.\n", "Hence, they are transformed to `int64` integers representing the categories. The category 0 is reserved for `None` and very small categories that are set to `None`. \n", "We use the `OrderedCategoricalLong` transform to achieve this." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "cols_standardize = ['x0', 'x7', 'x8', 'x9', 'x10', 'x11', 'x12', 'x13']\n", "cols_leave = ['x1', 'x4', 'x5']\n", "cols_categorical = ['x2', 'x3', 'x6']\n", "\n", "standardize = [([col], StandardScaler()) for col in cols_standardize]\n", "leave = [(col, None) for col in cols_leave]\n", "categorical = [(col, OrderedCategoricalLong()) for col in cols_categorical]\n", "\n", "x_mapper_float = DataFrameMapper(standardize + leave)\n", "x_mapper_long = DataFrameMapper(categorical) # we need a separate mapper to ensure the data type 'int64'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "x_fit_transform = lambda df: tt.tuplefy(x_mapper_float.fit_transform(df), x_mapper_long.fit_transform(df))\n", "x_transform = lambda df: tt.tuplefy(x_mapper_float.transform(df), x_mapper_long.transform(df))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x_train = x_fit_transform(df_train)\n", "x_val = x_transform(df_val)\n", "x_test = x_transform(df_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In `x_fit_transform` and `x_transform` we have wrapped the results with `tt.tuplefy`. The result is a `TupleTree`, which is equivalent to a regular `tuple` but has some added functionality that makes it easier to investigate the data.\n", "\n", "From the code below we see that `x_train` is a `tuple` with two arrays representing the transformed numerical and categorical covariates."
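, "\n", "\n", "The `types`, `shapes` and `dtypes` methods apply a function to every array in a (possibly nested) tuple and keep the nesting. The plain-Python sketch below, built around a hypothetical `tree_map` helper, mimics that behaviour on zero-filled stand-ins with the same layout as `(x_train, y_train)` in this notebook; it illustrates the idea and is not how `torchtuples` implements `TupleTree`.

```python
import numpy as np

def tree_map(fn, obj):
    # recurse into tuples and apply fn to every non-tuple leaf
    if isinstance(obj, tuple):
        return tuple(tree_map(fn, sub) for sub in obj)
    return fn(obj)

# zero-filled stand-ins with the same layout as (x_train, y_train)
x = (np.zeros((5678, 11), dtype='float32'), np.zeros((5678, 3), dtype='int64'))
y = (np.zeros(5678, dtype='int64'), np.zeros(5678, dtype='float32'))
nested = (x, y)

print(tree_map(lambda a: a.shape, nested))
# -> (((5678, 11), (5678, 3)), ((5678,), (5678,)))
print(tree_map(lambda a: str(a.dtype), nested))
# -> (('float32', 'int64'), ('int64', 'float32'))
```"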
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(numpy.ndarray, numpy.ndarray)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train.types()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((5678, 11), (5678, 3))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train.shapes()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(dtype('float32'), dtype('int64'))" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_train.dtypes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Label transforms\n", "\n", "`LogisticHazard` is a discrete-time method, meaning it requires discretization of the event times to be applied to continuous-time data.\n", "We let `num_durations` define the size of the discretization grid, but we will now let the **quantiles** of the estimated event-time distribution define the grid, as explained in [this paper](https://arxiv.org/abs/1910.06724)." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "num_durations = 20\n", "scheme = 'quantiles'\n", "labtrans = LogisticHazard.label_transform(num_durations, scheme)\n", "\n", "get_target = lambda df: (df['duration'].values, df['event'].values)\n", "y_train = labtrans.fit_transform(*get_target(df_train))\n", "y_val = labtrans.transform(*get_target(df_val))\n", "durations_test, events_test = get_target(df_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the discretization grid is far from equidistant. The idea behind the quantile discretization is that the grid is finer where there are many events and coarser where there are few." 
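, "\n", "\n", "To make the idea concrete, the sketch below estimates a Kaplan-Meier curve with NumPy and places the interior grid points where the curve first drops to evenly spaced survival levels between 1 and its minimum, using the smallest and largest durations as end points. This is a simplified illustration under that assumption: `km_curve` and `quantile_cuts` are hypothetical helpers, the data are synthetic, and the rule used by `LogisticHazard.label_transform` may differ in its details.

```python
import numpy as np

def km_curve(durations, events):
    # Kaplan-Meier estimate S(t) at every unique duration
    times = np.unique(durations)
    surv = np.empty(len(times))
    s = 1.0
    for i, t in enumerate(times):
        at_risk = np.sum(durations >= t)
        n_events = np.sum((durations == t) & (events == 1))
        s *= 1.0 - n_events / at_risk
        surv[i] = s
    return times, surv

def quantile_cuts(durations, events, num):
    # hypothetical sketch: interior cuts where S(t) first reaches evenly spaced
    # survival levels between 1 and min S(t); end points are min/max duration
    times, surv = km_curve(durations, events)
    levels = np.linspace(1.0, surv.min(), num)[1:-1]
    interior = [times[np.argmax(surv <= level)] for level in levels]
    cuts = np.concatenate([[durations.min()], interior, [durations.max()]])
    return np.unique(cuts)  # sorted; ties collapse, so len(cuts) <= num

# synthetic censored data (assumption: exponential events, uniform censoring)
rng = np.random.default_rng(0)
t_event = rng.exponential(100.0, 2000)
t_cens = rng.uniform(0.0, 500.0, 2000)
durations = np.minimum(t_event, t_cens)
events = (t_event <= t_cens).astype(int)

cuts = quantile_cuts(durations, events, 10)
print(np.round(np.diff(cuts), 1))  # narrow gaps early, wide gaps late
```

Because events are frequent early on, the early survival levels are reached quickly and the first gaps are much narrower than the last one."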
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0., 3., 5., 8., 11., 16., 22., 30., 43.,\n", " 64., 92., 129., 192., 256., 368., 522., 739., 1005.,\n", " 1348., 2029.], dtype=float32)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labtrans.cuts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can visualize the grid together with the Kaplan-Meier estimates to see this clearly." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from pycox.utils import kaplan_meier\n", "plt.vlines(labtrans.cuts, 0, 1, colors='gray', linestyles=\"--\", label='Discretization Grid')\n", "kaplan_meier(*get_target(df_train)).plot(label='Kaplan-Meier')\n", "plt.ylabel('S(t)')\n", "plt.legend()\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Investigating the data\n", "\n", "Next we collect the training and validation data with `tt.tuplefy` in a nested tuple to make it simpler to inspect them" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "train = tt.tuplefy(x_train, y_train)\n", "val = tt.tuplefy(x_val, y_val)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((numpy.ndarray, numpy.ndarray), (numpy.ndarray, numpy.ndarray))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.types()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(((5678, 11), (5678, 3)), ((5678,), (5678,)))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.shapes()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((dtype('float32'), dtype('int64')), (dtype('int64'), dtype('float32')))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train.dtypes()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now alternatively transform the data to torch tensors with `to_tensor`. This is not useful for this notebook, but can be very handy for development." 
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "train_tensor = train.to_tensor()" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((torch.Tensor, torch.Tensor), (torch.Tensor, torch.Tensor))" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_tensor.types()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((torch.Size([5678, 11]), torch.Size([5678, 3])),\n", " (torch.Size([5678]), torch.Size([5678])))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_tensor.shapes()" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((torch.float32, torch.int64), (torch.int64, torch.float32))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_tensor.dtypes()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "del train_tensor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neural net\n", "\n", "We want our network to take two input arguments, one for the numerical covariates and one for the categorical covariates, so that we can apply entity embeddings.\n", "\n", "The `tt.practical.MixedInputMLP` does exactly this for us. It first applies entity embeddings to the categorical covariates and then concatenates the embeddings with the numerical covariates." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we need to define the embedding sizes. Here we let the embedding dimension of each categorical covariate be half its number of categories (rounded down).\n", "This means that each category is represented by a vector whose length is half the number of categories of that covariate."
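, "\n", "\n", "The embedding sizes follow from simple arithmetic. In this notebook `num_embeddings` turns out to be `[8, 7, 4]` and `embedding_dims` is `[4, 3, 2]`, so after concatenation the first linear layer receives 11 + 4 + 3 + 2 = 20 inputs, which matches `in_features=20` in the printed network. The NumPy sketch below mimics the lookup-and-concatenate step with randomly initialized tables; it illustrates entity embeddings in general rather than the `MixedInputMLP` code, and the small `x_cat` array is made up.

```python
import numpy as np

rng = np.random.default_rng(0)
x_num = rng.normal(size=(5, 11)).astype('float32')  # 11 numerical covariates
# made-up category codes for 3 categorical covariates (0 is the None category)
x_cat = np.array([[1, 2, 3], [7, 6, 0], [0, 0, 1], [4, 1, 2], [2, 3, 3]])

num_embeddings = x_cat.max(0) + 1     # -> [8 7 4]
embedding_dims = num_embeddings // 2  # -> [4 3 2]

# one randomly initialized lookup table per categorical covariate (hypothetical)
tables = [rng.normal(size=(n, d)) for n, d in zip(num_embeddings, embedding_dims)]
# entity embedding: each integer code selects one row of its table
embedded = [table[x_cat[:, j]] for j, table in enumerate(tables)]
mlp_input = np.concatenate([x_num] + embedded, axis=1)
print(mlp_input.shape)  # -> (5, 20), i.e. 11 + 4 + 3 + 2 features
```

In a real network the tables are trainable parameters, so categories that behave similarly can end up with similar vectors."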
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "num_embeddings = x_train[1].max(0) + 1\n", "embedding_dims = num_embeddings // 2" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([8, 7, 4]), array([4, 3, 2]))" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "num_embeddings, embedding_dims" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then define a net with four hidden layers, each of size 32, and include batch normalization and dropout in each hidden layer." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "in_features = x_train[0].shape[1]\n", "out_features = labtrans.out_features\n", "num_nodes = [32, 32, 32, 32]\n", "batch_norm = True\n", "dropout = 0.2\n", "\n", "net = tt.practical.MixedInputMLP(in_features, num_embeddings, embedding_dims,\n", " num_nodes, out_features, batch_norm, dropout)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MixedInputMLP(\n", " (embeddings): EntityEmbeddings(\n", " (embeddings): ModuleList(\n", " (0): Embedding(8, 4)\n", " (1): Embedding(7, 3)\n", " (2): Embedding(4, 2)\n", " )\n", " )\n", " (mlp): MLPVanilla(\n", " (net): Sequential(\n", " (0): DenseVanillaBlock(\n", " (linear): Linear(in_features=20, out_features=32, bias=True)\n", " (activation): ReLU()\n", " (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " )\n", " (1): DenseVanillaBlock(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (activation): ReLU()\n", " (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " )\n", " (2): DenseVanillaBlock(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (activation): ReLU()\n", " (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " )\n", " (3): DenseVanillaBlock(\n", " (linear): Linear(in_features=32, out_features=32, bias=True)\n", " (activation): ReLU()\n", " (batch_norm): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (dropout): Dropout(p=0.2, inplace=False)\n", " )\n", " (4): Linear(in_features=32, out_features=20, bias=True)\n", " )\n", " )\n", ")" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fitting model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We want to use the cyclic [AdamWR optimizer](https://arxiv.org/abs/1711.05101), where the learning rate is multiplied by 0.8 and the cycle length is doubled after every cycle.\n", "We also add [decoupled weight decay](https://arxiv.org/abs/1711.05101) for regularization." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "optimizer = tt.optim.AdamWR(decoupled_weight_decay=0.01, cycle_eta_multiplier=0.8,\n", " cycle_multiplier=2)\n", "model = LogisticHazard(net, optimizer, duration_index=labtrans.cuts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use `lr_finder` to find a suitable initial learning rate\n", "with the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf).\n", "See [this post](https://towardsdatascience.com/finding-good-learning-rate-and-the-one-cycle-policy-7159fe1db5d6) for an explanation.\n", "\n", "The `tolerance` argument defines the largest loss allowed before the procedure is terminated. It serves mostly a visual purpose." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "batch_size = 256\n", "lrfind = model.lr_finder(x_train, y_train, batch_size, tolerance=50)\n", "_ = lrfind.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that this sets the optimizer learning rate in our model to the same value as that of `get_best_lr`." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.08902150854450441" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.optimizer.param_groups[0]['lr']" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.08902150854450441" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lrfind.get_best_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, we have found that `get_best_lr` sometimes gives a little high learning rate, so we instead set it to" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "model.optimizer.set_lr(0.02)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For early stopping, we will use `EarlyStoppingCycle` which work in the same manner as `EarlyStopping` but will stop at the end of **the cycle** if the current best model was not obtained in the current cycle." 
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "epochs = 512\n", "callbacks = [tt.cb.EarlyStoppingCycle()]\n", "verbose = False # set to True if you want printout" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 47.1 s, sys: 1.74 s, total: 48.9 s\n", "Wall time: 23.3 s\n" ] } ], "source": [ "%%time\n", "log = model.fit(x_train, y_train, batch_size, epochs, callbacks, verbose,\n", " val_data=val)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = log.to_pandas().iloc[1:].plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now plot the learning rates used through the training with following piece of code" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "lrs = model.optimizer.lr_scheduler.to_pandas() * model.optimizer.param_groups[0]['initial_lr']\n", "lrs.plot()\n", "plt.grid(linestyle='--')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation\n", "\n", "The `LogisticHazard` method has two implemented interpolation schemes: the constant density interpolation (default) and constant hazard interpolation. See [this paper](https://arxiv.org/abs/1910.06724) for details." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "surv_cdi = model.interpolate(100).predict_surv_df(x_test)\n", "surv_chi = model.interpolate(100, 'const_hazard').predict_surv_df(x_test)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "ev_cdi = EvalSurv(surv_cdi, durations_test, events_test, censor_surv='km')\n", "ev_chi = EvalSurv(surv_chi, durations_test, events_test, censor_surv='km')" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.6305366659876536, 0.630559781328872)" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev_cdi.concordance_td(), ev_chi.concordance_td()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0.19502317626753513, 0.20115977074187133)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)\n", "ev_cdi.integrated_brier_score(time_grid), ev_chi.integrated_brier_score(time_grid)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ev_cdi.brier_score(time_grid).rename('CDI').plot()\n", "ev_chi.brier_score(time_grid).rename('CHI').plot()\n", "plt.legend()\n", "plt.ylabel('Brier score')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see from the figures that, in this case, the constant hazard interpolated estimates are not as good as the constant density interpolated estimates." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Investigate what's going on\n", "\n", "The instabilities at the end of the plot above is a consequence for our discretization scheme.\n", "\n", "From `labtrans` we can get the last two discretization points" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1348., 2029.], dtype=float32)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labtrans.cuts[-2:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, because the censoring times in this interval are rounded down while events times are rounded up, we get an unnatural proportion of events at the final time point." ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", " 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", " 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n", " 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.],\n", " dtype=float32)" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data = train.iloc[train[1][0] == train[1][0].max()]\n", "data[1][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see the almost all training individuals here have an event!" 
] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.93939394" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data[1][1].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In contrast, the actual proportion of events among individuals who survive past 1500 is very low:" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.061224489795918366" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.loc[lambda x: x['duration'] > 1500]['event'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is one of the dangers of discretization, and one should be cautious. A simple solution would be to ensure that there are more discretization points in this interval, or to not evaluate past time 1348.\n", "\n", "If we look at individuals in the test set who are censored some time after 1500, we see that the survival estimates are not very appropriate." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "test = tt.tuplefy(x_test, (durations_test, events_test))\n", "data = test.iloc[(durations_test > 1500) & (events_test == 0)]\n", "n = data[0][0].shape[0]" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "idx = np.random.choice(n, 6)\n", "fig, axs = plt.subplots(2, 3, figsize=(12, 6), sharex=True, sharey=True)\n", "for i, ax in zip(idx, axs.flat):\n", " x, (t, _) = data.iloc[[i]]\n", " surv = model.predict_surv_df(x)\n", " surv[0].rename('Survival estimate').plot(ax=ax)\n", " ax.vlines(t, 0, 1, colors='red', linestyles=\"--\",\n", " label='censoring time')\n", " ax.grid(linestyle='--')\n", " ax.legend()\n", " ax.set_ylabel('S(t | x)')\n", " _ = ax.set_xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Next\n", "\n", "You can now look at other examples of survival methods in the [examples folder](https://nbviewer.jupyter.org/github/havakv/pycox/tree/master/examples).\n", "Or, alternatively take a look at\n", "\n", "- other network architectures that combine autoencoders and survival networks in the notebook [03_network_architectures.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/03_network_architectures.ipynb).\n", "- working with DataLoaders and convolutional networks in the notebook [04_mnist_dataloaders_cnn.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/04_mnist_dataloaders_cnn.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }