{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction to the pycox package\n", "\n", "\n", "In this notebook we introduce the use of `pycox` through an example dataset.\n", "We illustrate the procedure with the `LogisticHazard` method ([paper_link](https://arxiv.org/abs/1910.06724)), but it can easily be replaced by, for example, `PMF`, `MTLR`, or `DeepHitSingle`.\n", "\n", "In the following we will:\n", "\n", "- Load the METABRIC survival dataset.\n", "- Process the event labels so that they work with our methods.\n", "- Create a [PyTorch](https://pytorch.org) neural network.\n", "- Fit the model.\n", "- Evaluate the predictive performance using the concordance, Brier score, and negative binomial log-likelihood.\n", "\n", "While some knowledge of the [PyTorch](https://pytorch.org) framework is preferable, it is not required for the use of simple neural networks.\n", "For building more advanced network architectures, however, we would recommend looking at [the PyTorch tutorials](https://pytorch.org/tutorials/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports\n", "\n", "You need `sklearn-pandas`, which can be installed by uncommenting and running the following block" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# ! pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# For preprocessing\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`pycox` is built on top of [PyTorch](https://pytorch.org) and [torchtuples](https://github.com/havakv/torchtuples), where the latter is just a simple way of training neural nets with less boilerplate code."
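] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `LogisticHazard` method (like `PMF`, `MTLR`, and `DeepHitSingle`) works on a discretized time scale, which is why the event labels need some processing before training. As a purely illustrative sketch (this is not the `pycox` implementation, and the variable names are made up for the example), an equidistant discretization with `numpy` could look like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Purely illustrative: discretize continuous durations into equidistant intervals\n", "durations = np.array([99.3, 95.7, 239.3, 56.9, 123.5])\n", "num_intervals = 10\n", "cuts = np.linspace(0., durations.max(), num_intervals)\n", "# Interval index for each duration (the discrete 'label' used by these methods)\n", "idx_durations = np.digitize(durations, cuts[1:], right=True)\n", "idx_durations  # array([3, 3, 8, 2, 4])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice, `pycox` provides this transformation through the model's label transform (e.g. `LogisticHazard.label_transform`), which we will rely on rather than discretizing by hand.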
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import torch # For building the networks\n", "import torchtuples as tt # Some useful functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We import the `metabric` dataset, the `LogisticHazard` method ([paper_link](https://arxiv.org/abs/1910.06724)), also known as [Nnet-survival](https://peerj.com/articles/6257/), and `EvalSurv`, which simplifies the evaluation procedure at the end.\n", "\n", "You can alternatively replace `LogisticHazard` with, for example, `PMF` or `DeepHitSingle`, which should both work in this notebook." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from pycox.datasets import metabric\n", "from pycox.models import LogisticHazard\n", "# from pycox.models import PMF\n", "# from pycox.models import DeepHitSingle\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# We also set some seeds to make this reproducible.\n", "# Note that on gpu, there is still some randomness.\n", "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set as a pandas DataFrame and split the data into train, test, and validation sets.\n", "\n", "The `duration` column gives the observed times and the `event` column contains indicators of whether the observation is an event (1) or a censored observation (0)." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|   | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | duration | event |\n", "|---|---|---|---|---|---|---|---|---|---|---|---|\n", "| 0 | 5.603834 | 7.811392 | 10.797988 | 5.967607 | 1.0 | 1.0 | 0.0 | 1.0 | 56.840000 | 99.333336 | 0 |\n", "| 1 | 5.284882 | 9.581043 | 10.204620 | 5.664970 | 1.0 | 0.0 | 0.0 | 1.0 | 85.940002 | 95.733330 | 1 |\n", "| 3 | 6.654017 | 5.341846 | 8.646379 | 5.655888 | 0.0 | 0.0 | 0.0 | 0.0 | 66.910004 | 239.300003 | 0 |\n", "| 4 | 5.456747 | 5.339741 | 10.555724 | 6.008429 | 1.0 | 0.0 | 0.0 | 1.0 | 67.849998 | 56.933334 | 1 |\n", "| 5 | 5.425826 | 6.331182 | 10.455145 | 5.749053 | 1.0 | 1.0 | 0.0 | 1.0 | 70.519997 | 123.533333 | 0 |\n", "