{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cox-CC\n", "\n", "In this notebook, we will train the [Cox-CC method](http://jmlr.org/papers/volume20/18-424/18-424.pdf).\n", "We will use the METABRIC data set as an example.\n", "\n", "A more detailed introduction to the `pycox` package can be found in [this notebook](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb) about the `LogisticHazard` method.\n", "\n", "The main benefit of Cox-CC (and the other Cox methods) over Logistic-Hazard is that it is a continuous-time method, so we do not need to discretize the time scale." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper\n", "\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.datasets import metabric\n", "from pycox.models import CoxCC\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "## Uncomment to install `sklearn-pandas`\n", "# ! pip install sklearn-pandas" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "np.random.seed(1234)\n", "_ = torch.manual_seed(123)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the METABRIC data set and split it into training, test, and validation sets: 20% of the data is held out for testing, and 20% of the remaining data is held out for validation." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "df_train = metabric.read_df()\n", "df_test = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_test.index)\n", "df_val = df_train.sample(frac=0.2)\n", "df_train = df_train.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
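<table>\n",
"  <thead>\n", "    <tr>\n", "      <th></th>\n", "      <th>x0</th>\n", "      <th>x1</th>\n", "      <th>x2</th>\n", "      <th>x3</th>\n", "      <th>x4</th>\n", "      <th>x5</th>\n", "      <th>x6</th>\n", "      <th>x7</th>\n", "      <th>x8</th>\n", "      <th>duration</th>\n", "      <th>event</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <th>0</th>\n", "      <td>5.603834</td>\n", "      <td>7.811392</td>\n", "      <td>10.797988</td>\n", "      <td>5.967607</td>\n", "      <td>1.0</td>\n", "      <td>1.0</td>\n", "      <td>0.0</td>\n", "      <td>1.0</td>\n", "      <td>56.840000</td>\n", "      <td>99.333336</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>1</th>\n", "      <td>5.284882</td>\n", "      <td>9.581043</td>\n", "      <td>10.204620</td>\n", "      <td>5.664970</td>\n", "      <td>1.0</td>\n", "      <td>0.0</td>\n", "      <td>0.0</td>\n", "      <td>1.0</td>\n", "      <td>85.940002</td>\n", "      <td>95.733330</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>3</th>\n", "      <td>6.654017</td>\n", "      <td>5.341846</td>\n", "      <td>8.646379</td>\n", "      <td>5.655888</td>\n", "      <td>0.0</td>\n", "      <td>0.0</td>\n", "      <td>0.0</td>\n", "      <td>0.0</td>\n", "      <td>66.910004</td>\n", "      <td>239.300003</td>\n", "      <td>0</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>4</th>\n", "      <td>5.456747</td>\n", "      <td>5.339741</td>\n", "      <td>10.555724</td>\n", "      <td>6.008429</td>\n", "      <td>1.0</td>\n", "      <td>0.0</td>\n", "      <td>0.0</td>\n", "      <td>1.0</td>\n", "      <td>67.849998</td>\n", "      <td>56.933334</td>\n", "      <td>1</td>\n", "    </tr>\n",
"    <tr>\n", "      <th>5</th>\n", "      <td>5.425826</td>\n", "      <td>6.331182</td>\n", "      <td>10.455145</td>\n", "      <td>5.749053</td>\n", "      <td>1.0</td>\n", "      <td>1.0</td>\n", "      <td>0.0</td>\n", "      <td>1.0</td>\n", "      <td>70.519997</td>\n", "      <td>123.533333</td>\n", "      <td>0</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>\n", "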