{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MTLR (N-MTLR) Example\n", "\n", "In this notebook, we present a simple example of the neural network version of the `MTLR` method described in \n", "[this paper](https://papers.nips.cc/paper/4210-learning-patient-specific-cancer-survival-distributions-as-a-sequence-of-dependent-regressors) and [this paper](https://arxiv.org/pdf/1801.05512.pdf).\n", "\n", "For a more detailed introduction to `pycox`, see [this notebook](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# For preprocessing\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper \n", "\n", "import torch # For building the networks \n", "import torchtuples as tt # Some useful functions\n", "\n", "from pycox.datasets import metabric\n", "from pycox.models import MTLR\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Uncomment to install `sklearn-pandas`\n", "# ! pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set as a pandas DataFrame and split it into training, validation, and test sets.\n", "\n", "The `duration` column gives the observed times, and the `event` column indicates whether each observation is an event (1) or censored (0)."
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | duration | event |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 5.603834 | 7.811392 | 10.797988 | 5.967607 | 1.0 | 1.0 | 0.0 | 1.0 | 56.840000 | 99.333336 | 0 |
| 1 | 5.284882 | 9.581043 | 10.204620 | 5.664970 | 1.0 | 0.0 | 0.0 | 1.0 | 85.940002 | 95.733330 | 1 |
| 3 | 6.654017 | 5.341846 | 8.646379 | 5.655888 | 0.0 | 0.0 | 0.0 | 0.0 | 66.910004 | 239.300003 | 0 |
| 4 | 5.456747 | 5.339741 | 10.555724 | 6.008429 | 1.0 | 0.0 | 0.0 | 1.0 | 67.849998 | 56.933334 | 1 |
| 5 | 5.425826 | 6.331182 | 10.455145 | 5.749053 | 1.0 | 1.0 | 0.0 | 1.0 | 70.519997 | 123.533333 | 0 |