{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cox models: Introduction\n", "\n", "In this notebook we will train three Cox models with neural networks:\n", "\n", "- Cox CC: a proportional hazards model trained with case-control sampling.\n", "- Cox DeepSurv: a proportional hazards model more or less identical to DeepSurv.\n", "- Cox-Time: a non-proportional model trained with case-control sampling.\n", "\n", "The Cox CC and DeepSurv models are very similar, while the Cox-Time model is a little more flexible.\n", "\n", "We will use the METABRIC data set as an example." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import torch\n", "from torch import nn\n", "from pycox import datasets\n", "from pycox.models import CoxCC, CoxPH, CoxTime\n", "from pycox.evaluation import EvalSurv\n", "from torchtuples import tuplefy, optim\n", "from torchtuples import callbacks as cb\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`sklearn_pandas` can be installed with `! pip install sklearn-pandas`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set and split it into training, test, and validation sets." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "df_train = datasets.metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "| | x0 | x1 | x2 | x3 | x4 | x5 | x6 | x7 | x8 | duration | event |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 5.603834 | 7.811392 | 10.797988 | 5.967607 | 1.0 | 1.0 | 0.0 | 1.0 | 56.840000 | 99.333336 | 0 |
| 1 | 5.284882 | 9.581043 | 10.204620 | 5.664970 | 1.0 | 0.0 | 0.0 | 1.0 | 85.940002 | 95.733330 | 1 |
| 3 | 6.654017 | 5.341846 | 8.646379 | 5.655888 | 0.0 | 0.0 | 0.0 | 0.0 | 66.910004 | 239.300003 | 0 |
| 4 | 5.456747 | 5.339741 | 10.555724 | 6.008429 | 1.0 | 0.0 | 0.0 | 1.0 | 67.849998 | 56.933334 | 1 |
| 5 | 5.425826 | 6.331182 | 10.455145 | 5.749053 | 1.0 | 1.0 | 0.0 | 1.0 | 70.519997 | 123.533333 | 0 |\n", "