{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 03 Network Architectures \n", "\n", "In this notebook, we investigate how we can work with more complicated network architectures than simple MLPs.\n", "\n", "The example will be a Logistic-Hazard model that uses the encoded (latent) variables of an autoencoder as covariates.\n", "This means the network will have two heads, one for the autoencoder and the other for the Logistic-Hazard.\n", "\n", "To approach this task we can use the `LogisticHazard`, but we need to define the network structure that combines the survival net with an autoencoder.\n", "Also, we need to define a loss function that combines the loss of the `LogisticHazard` with that of the autoencoder." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# For preprocessing\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper \n", "\n", "import torch # For building the networks \n", "from torch import nn\n", "import torch.nn.functional as F\n", "import torchtuples as tt # Some useful functions\n", "\n", "from pycox.datasets import metabric\n", "from pycox.models import LogisticHazard\n", "from pycox.models.loss import NLLLogistiHazardLoss\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Uncomment to install `sklearn-pandas`\n", "# ! 
pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set as a pandas DataFrame and split it into training, validation, and test sets.\n", "\n", "The `duration` column gives the observed times and the `event` column contains indicators of whether the observation is an event (1) or a censored observation (0)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>x0</th><th>x1</th><th>x2</th><th>x3</th><th>x4</th><th>x5</th><th>x6</th><th>x7</th><th>x8</th><th>duration</th><th>event</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>5.603834</td><td>7.811392</td><td>10.797988</td><td>5.967607</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>56.840000</td><td>99.333336</td><td>0</td></tr>\n",
"    <tr><th>1</th><td>5.284882</td><td>9.581043</td><td>10.204620</td><td>5.664970</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>85.940002</td><td>95.733330</td><td>1</td></tr>\n",
"    <tr><th>3</th><td>6.654017</td><td>5.341846</td><td>8.646379</td><td>5.655888</td><td>0.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>66.910004</td><td>239.300003</td><td>0</td></tr>\n",
"    <tr><th>4</th><td>5.456747</td><td>5.339741</td><td>10.555724</td><td>6.008429</td><td>1.0</td><td>0.0</td><td>0.0</td><td>1.0</td><td>67.849998</td><td>56.933334</td><td>1</td></tr>\n",
"    <tr><th>5</th><td>5.425826</td><td>6.331182</td><td>10.455145</td><td>5.749053</td><td>1.0</td><td>1.0</td><td>0.0</td><td>1.0</td><td>70.519997</td><td>123.533333</td><td>0</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>"
], "text/plain": [ " x0 x1 x2 x3 x4 x5 x6 x7 x8 \\\n", "0 5.603834 7.811392 10.797988 5.967607 1.0 1.0 0.0 1.0 56.840000 \n", "1 5.284882 9.581043 10.204620 5.664970 1.0 0.0 0.0 1.0 85.940002 \n", "3 6.654017 5.341846 8.646379 5.655888 0.0 0.0 0.0 0.0 66.910004 \n", "4 5.456747 5.339741 10.555724 6.008429 1.0 0.0 0.0 1.0 67.849998 \n", "5 5.425826 6.331182 10.455145 5.749053 1.0 1.0 0.0 1.0 70.519997 \n", "\n", " duration event \n", "0 99.333336 0 \n", "1 95.733330 1 \n", "3 239.300003 0 \n", "4 56.933334 1 \n", "5 123.533333 0 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature transforms\n", "\n", "The METABRIC dataset has 9 covariates: `x0, ..., x8`.\n", "We will standardize the 5 numerical covariates, and leave the binary covariates as is.\n", "Note that PyTorch requires variables of type `'float32'`.\n", "\n", "We like using the `sklearn_pandas.DataFrameMapper` to make feature mappers." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "cols_standardize = ['x0', 'x1', 'x2', 'x3', 'x8']\n", "cols_leave = ['x4', 'x5', 'x6', 'x7']\n", "\n", "standardize = [([col], StandardScaler()) for col in cols_standardize]\n", "leave = [(col, None) for col in cols_leave]\n", "\n", "x_mapper = DataFrameMapper(standardize + leave)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "x_train = x_mapper.fit_transform(df_train).astype('float32')\n", "x_val = x_mapper.transform(df_val).astype('float32')\n", "x_test = x_mapper.transform(df_test).astype('float32')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Label transforms\n", "\n", "The `LogisticHazard` is a discrete-time method, meaning it requires discretization of the event times to be applied to continuous-time data.\n", "We let `num_durations` define the size of this (equidistant) discretization grid, meaning our network will have `num_durations` output nodes.\n", "\n", "Note that we have two sets of targets.\n", "The first, `y_train_surv`, is for the survival and contains the labels `(idx_durations, events)`.\n", "The second is for the autoencoder (which is just the input covariates `x_train`).\n", "This is important to note, as it will define the call arguments of our loss function." 
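] }, { "cell_type": "markdown", "metadata": {}, "source": [
"To make the discretization concrete, the cell below is a minimal NumPy sketch of mapping continuous durations onto an equidistant grid.\n",
"It is an illustration under simplifying assumptions, not the `pycox` implementation: the `toy_*` names and values are hypothetical, and censored observations are not treated differently from events, which `LogisticHazard.label_transform` may well do."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy durations (not taken from METABRIC).\n",
"toy_durations = np.array([1.0, 2.5, 9.9, 10.0])\n",
"toy_num_durations = 5\n",
"\n",
"# Equidistant grid from 0 to the largest observed duration.\n",
"toy_cuts = np.linspace(0.0, toy_durations.max(), toy_num_durations)\n",
"\n",
"# For each duration, the index of the first grid point that is >= the duration.\n",
"toy_idx = np.searchsorted(toy_cuts, toy_durations, side='left')\n",
"\n",
"print(toy_cuts)\n",
"print(toy_idx)"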
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "num_durations = 10\n", "labtrans = LogisticHazard.label_transform(num_durations)\n", "get_target = lambda df: (df['duration'].values, df['event'].values)\n", "y_train_surv = labtrans.fit_transform(*get_target(df_train))\n", "y_val_surv = labtrans.transform(*get_target(df_val))\n", "\n", "train = tt.tuplefy(x_train, (y_train_surv, x_train))\n", "val = tt.tuplefy(x_val, (y_val_surv, x_val))\n", "\n", "# We don't need to transform the test labels\n", "durations_test, events_test = get_target(df_test)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([2, 3, 6, ..., 1, 5, 3]),\n", " array([0., 1., 0., ..., 1., 0., 0.], dtype=float32))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_train_surv" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that we have combined `idx_durations` and `events` into the tuple `y_train_surv`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Neural Net\n", "\n", "We first define our network `NetAESurv`, which contains the encoder, decoder, and the survival part. \n", "We also include a `predict` method that saves computations by not computing the decoder." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class NetAESurv(nn.Module):\n", "    def __init__(self, in_features, encoded_features, out_features):\n", "        super().__init__()\n", "        self.encoder = nn.Sequential(\n", "            nn.Linear(in_features, 32), nn.ReLU(),\n", "            nn.Linear(32, 16), nn.ReLU(),\n", "            nn.Linear(16, encoded_features),\n", "        )\n", "        self.decoder = nn.Sequential(\n", "            nn.Linear(encoded_features, 16), nn.ReLU(),\n", "            nn.Linear(16, 32), nn.ReLU(),\n", "            nn.Linear(32, in_features),\n", "        )\n", "        self.surv_net = nn.Sequential(\n", "            nn.Linear(encoded_features, 16), nn.ReLU(),\n", "            nn.Linear(16, 16), nn.ReLU(),\n", "            nn.Linear(16, out_features),\n", "        )\n", "\n", "    def forward(self, input):\n", "        encoded = self.encoder(input)\n", "        decoded = self.decoder(encoded)\n", "        phi = self.surv_net(encoded)\n", "        return phi, decoded\n", "\n", "    def predict(self, input):\n", "        # Will be used by model.predict later.\n", "        # As this only has the survival output, \n", "        # we don't have to change LogisticHazard.\n", "        encoded = self.encoder(input)\n", "        return self.surv_net(encoded)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All methods in `pycox` are built on `torchtuples.Model`.\n", "The way `torchtuples` is made, a call to `model.predict` will use `model.net.predict` if it is defined, and use `model.net.forward` if `predict` is not defined.\n", "As all the survival predictions in `pycox` are based on the `model.predict` method, a call to `model.predict_surv` will use the `NetAESurv.predict` method instead of the `NetAESurv.forward` method.\n", "This way, the `model.predict_surv` methods will work as before!" 
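] }, { "cell_type": "markdown", "metadata": {}, "source": [
"The dispatch rule described above can be sketched in plain Python.\n",
"This is only a toy illustration of the idea (prefer `predict`, fall back to `forward`); the classes `TwoHeadNet` and `OneHeadNet` and the helper `dispatch_predict` are hypothetical names, and this is not the actual `torchtuples` code."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"class TwoHeadNet:\n",
"    # Stand-in for a net with two outputs and a dedicated `predict`.\n",
"    def forward(self, x):\n",
"        return ('phi', 'decoded')\n",
"\n",
"    def predict(self, x):\n",
"        return 'phi'\n",
"\n",
"\n",
"class OneHeadNet:\n",
"    # Stand-in for an ordinary net that only defines `forward`.\n",
"    def forward(self, x):\n",
"        return 'phi'\n",
"\n",
"\n",
"def dispatch_predict(net, x):\n",
"    # Use `net.predict` when it exists, otherwise fall back to `net.forward`.\n",
"    predict_fn = getattr(net, 'predict', net.forward)\n",
"    return predict_fn(x)\n",
"\n",
"\n",
"two_head_out = dispatch_predict(TwoHeadNet(), None)\n",
"one_head_out = dispatch_predict(OneHeadNet(), None)\n",
"print(two_head_out, one_head_out)"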
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "in_features = x_train.shape[1]\n", "encoded_features = 4\n", "out_features = labtrans.out_features\n", "net = NetAESurv(in_features, encoded_features, out_features)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The Loss\n", "\n", "We need to define a loss that combines the `LogisticHazard` loss with the loss of the autoencoder, for which we use `MSELoss`.\n", "The `forward` method defines how the loss is called, and needs to be defined in accordance with how the data is structured.\n", "\n", "The first arguments need to be the output of the net (`phi` and `decoded` above), and the remainder of the arguments need to have the same structure as the tuple structure of the targets in your data set, that is `train[1]`." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((1, 1), 0)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train[1].levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have structured our data such that `train[1] = (target_loghaz, target_ae)`, where `target_loghaz = (idx_durations, events)` and `target_ae` is just the input covariates.\n", "So we need a loss with the call signature `loss(phi, decoded, target_loghaz, target_ae)`.\n", "We therefore create the following loss function (note that it is created in the same manner as a torch network)." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class LossAELogHaz(nn.Module):\n", "    def __init__(self, alpha):\n", "        super().__init__()\n", "        assert (alpha >= 0) and (alpha <= 1), 'Need `alpha` in [0, 1].'\n", "        self.alpha = alpha\n", "        self.loss_surv = NLLLogistiHazardLoss()\n", "        self.loss_ae = nn.MSELoss()\n", "    \n", "    def forward(self, phi, decoded, target_loghaz, target_ae):\n", "        idx_durations, events = target_loghaz\n", "        loss_surv = 
self.loss_surv(phi, idx_durations, events)\n", "        loss_ae = self.loss_ae(decoded, target_ae)\n", "        return self.alpha * loss_surv + (1 - self.alpha) * loss_ae" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "loss = LossAELogHaz(0.6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The model\n", "\n", "We can now use the `LogisticHazard` model with the `net` and the `loss`." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "model = LogisticHazard(net, tt.optim.Adam(0.01), duration_index=labtrans.cuts, loss=loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Testing that it works\n", "\n", "We can now test that everything works as expected by considering a single batch of data. This can be done by using the `make_dataloader` method." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "dl = model.make_dataloader(train, batch_size=5, shuffle=False)\n", "batch = next(iter(dl))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([5, 9]), ((torch.Size([5]), torch.Size([5])), torch.Size([5, 9])))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch.shapes()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(0, ((2, 2), 1))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch.levels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `compute_metrics` method is the \"brain\" of the `model` as it is responsible for computing the output of the network and calculating the loss.\n", "Here we also see the logic of how the loss function is called as we unpack the tuples `metric(*out, *target)`.\n", "To test that it works, we can call it with the batch." ] }, { "cell_type": 
"code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "\u001b[0;31mSignature:\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mSource:\u001b[0m \n", " \u001b[0;32mdef\u001b[0m \u001b[0mcompute_metrics\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;34m\"\"\"Function for computing the loss and other metrics.\u001b[0m\n", "\u001b[0;34m \u001b[0m\n", "\u001b[0;34m Arguments:\u001b[0m\n", "\u001b[0;34m data {tensor or tuple} -- A batch of data. Typically the tuple `(input, target)`.\u001b[0m\n", "\u001b[0;34m\u001b[0m\n", "\u001b[0;34m Keyword Arguments:\u001b[0m\n", "\u001b[0;34m metrics {dict} -- A dictionary with metrics. If `None` use `self.metrics`. 
(default: {None})\u001b[0m\n", "\u001b[0;34m \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mmetrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Need to set `self.loss`.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_to_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mtarget\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_to_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtuplefy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", "\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetric\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mFile:\u001b[0m ~/packages/torchtuples/torchtuples/base.py\n", "\u001b[0;31mType:\u001b[0m method\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "??model.compute_metrics" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'loss': tensor(2.1274, grad_fn=)}" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.compute_metrics(batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can call the `score_in_batches` which computes the loss over the full data set. For larger data sets, this can be slow." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'loss': 2.1353516578674316}" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.score_in_batches(*train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For better monitoring, we add some metrics that correspond to the `LogisticHazard` loss and the `MSELoss`, but have the same call structure as `LossAELogHaz`." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "metrics = dict(\n", "    loss_surv = LossAELogHaz(1),\n", "    loss_ae = LossAELogHaz(0)\n", ")\n", "callbacks = [tt.cb.EarlyStopping()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now fit the model and plot the losses." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "batch_size = 256\n", "epochs = 100\n", "log = model.fit(*train, batch_size, epochs, callbacks, False, val_data=val, metrics=metrics)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "res = model.log.to_pandas()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [
"<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
"  <thead>\n",
"    <tr><th></th><th>train_loss</th><th>train_loss_surv</th><th>train_loss_ae</th><th>val_loss</th><th>val_loss_surv</th><th>val_loss_ae</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>2.012700</td><td>2.860405</td><td>0.741140</td><td>1.818385</td><td>2.615916</td><td>0.622087</td></tr>\n",
"    <tr><th>1</th><td>1.676086</td><td>2.354524</td><td>0.658429</td><td>1.358614</td><td>1.878866</td><td>0.578235</td></tr>\n",
"    <tr><th>2</th><td>1.307018</td><td>1.764166</td><td>0.621296</td><td>1.151025</td><td>1.565904</td><td>0.528707</td></tr>\n",
"    <tr><th>3</th><td>1.195468</td><td>1.619311</td><td>0.559702</td><td>1.038817</td><td>1.416264</td><td>0.472648</td></tr>\n",
"    <tr><th>4</th><td>1.118429</td><td>1.528873</td><td>0.502764</td><td>1.049998</td><td>1.459745</td><td>0.435376</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " train_loss train_loss_surv train_loss_ae val_loss val_loss_surv \\\n", "0 2.012700 2.860405 0.741140 1.818385 2.615916 \n", "1 1.676086 2.354524 0.658429 1.358614 1.878866 \n", "2 1.307018 1.764166 0.621296 1.151025 1.565904 \n", "3 1.195468 1.619311 0.559702 1.038817 1.416264 \n", "4 1.118429 1.528873 0.502764 1.049998 1.459745 \n", "\n", " val_loss_ae \n", "0 0.622087 \n", "1 0.578235 \n", "2 0.528707 \n", "3 0.472648 \n", "4 0.435376 " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res.head()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = res[['train_loss', 'val_loss']].plot()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = res[['train_loss_surv', 'val_loss_surv']].plot()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_ = res[['train_loss_ae', 'val_loss_ae']].plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prediction\n", "\n", "For prediction, remember that `model.predict` use the `net.predict` method, and because we defined it as only the survival part, the `predict_surv_df` behave as before." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "surv = model.interpolate(10).predict_surv_df(x_test)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "surv.iloc[:, :5].plot(drawstyle='steps-post')\n", "plt.ylabel('S(t | x)')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Under you can see that the `model.predict` method gives an array out, as the `model.net.predict` only gives an array (or tensor). \n", "\n", "If we want predictions from the `net.forward` method, we can use `model.predict_net` instead." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-13.314146 , -3.232014 , -1.6640139 , -0.32413414,\n", " -0.17921695, -0.24341404, 0.09743847, 0.6742561 ,\n", " -0.08086838, 9.740675 ],\n", " [-12.070942 , -1.3327147 , -1.0963962 , -2.2161934 ,\n", " -2.0299942 , -2.0760274 , -1.5875834 , -0.576796 ,\n", " 0.6470416 , 8.179644 ]], dtype=float32)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict(x_test[:2])" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([[-13.314146 , -3.232014 , -1.6640139 , -0.32413414,\n", " -0.17921695, -0.24341404, 0.09743847, 0.6742561 ,\n", " -0.08086838, 9.740675 ],\n", " [-12.070942 , -1.3327147 , -1.0963962 , -2.2161934 ,\n", " -2.0299942 , -2.0760274 , -1.5875834 , -0.576796 ,\n", " 0.6470416 , 8.179644 ]], dtype=float32),\n", " array([[-0.87191707, 1.8376998 , 0.52234185, 1.4613888 , 0.8503939 ,\n", " 0.93256044, 0.82951564, 0.15673046, 1.0731695 ],\n", " [ 0.05569835, -0.8430163 , 2.3545551 , 0.93773973, -1.2269706 ,\n", " 0.38499445, 0.79142815, 0.6561861 , 0.33640665]],\n", " dtype=float32))" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_net(x_test[:2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also pass a function to the `predict` methods, so we can only get 
the survival part from `predict_net`." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[-13.314146 , -3.232014 , -1.6640139 , -0.32413414,\n", " -0.17921695, -0.24341404, 0.09743847, 0.6742561 ,\n", " -0.08086838, 9.740675 ],\n", " [-12.070942 , -1.3327147 , -1.0963962 , -2.2161934 ,\n", " -2.0299942 , -2.0760274 , -1.5875834 , -0.576796 ,\n", " 0.6470416 , 8.179644 ]], dtype=float32)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.predict_net(x_test[:2], func=lambda x: x[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation (as before)\n", "\n", "The `EvalSurv` class contains some useful evaluation criteria for time-to-event prediction.\n", "We set `censor_surv = 'km'` to state that we want to use Kaplan-Meier for estimating the censoring distribution.\n" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Concordance\n", "\n", "We start with the event-time concordance by [Antolini et al. 2005](https://onlinelibrary.wiley.com/doi/10.1002/sim.2427)." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.6662427402602908" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.concordance_td('antolini')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Brier Score\n", "\n", "We can plot the the [IPCW Brier score](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5) for a given set of times.\n", "Here we just use 100 time-points between the min and max duration in the test set.\n", "Note that the score becomes unstable for the highest times. 
It is therefore common to disregard the rightmost part of the graph." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)\n", "ev.brier_score(time_grid).plot()\n", "plt.ylabel('Brier score')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Negative binomial log-likelihood\n", "\n", "In a similar manner, we can plot the the [IPCW negative binomial log-likelihood](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5)." ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ev.nbll(time_grid).plot()\n", "plt.ylabel('NBLL')\n", "_ = plt.xlabel('Time')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Integrated scores\n", "\n", "The two time-dependent scores above can be integrated over time to produce a single score [Graf et al. 1999](https://onlinelibrary.wiley.com/doi/abs/10.1002/%28SICI%291097-0258%2819990915/30%2918%3A17/18%3C2529%3A%3AAID-SIM274%3E3.0.CO%3B2-5). In practice this is done by numerical integration over a defined `time_grid`." ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16289642991716796" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_brier_score(time_grid) " ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.4834288744418562" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev.integrated_nbll(time_grid) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Next\n", "\n", "You can now look at other examples of survival methods in the [examples folder](https://nbviewer.jupyter.org/github/havakv/pycox/tree/master/examples).\n", "Or, alternatively take a look at\n", "\n", "- the more advanced training procedures in the notebook [02_introduction.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/02_introduction.ipynb).\n", "- working with DataLoaders and convolutional networks in the notebook [04_mnist_dataloaders_cnn.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/04_mnist_dataloaders_cnn.ipynb)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }