{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 02 Continuing the Introduction\n", "\n", "\n", "For a simpler introduction to `pycox`, see [01_introduction.ipynb](https://nbviewer.jupyter.org/github/havakv/pycox/blob/master/examples/01_introduction.ipynb) instead.\n", "\n", "In this notebook we show more of the functionality of the `pycox` package, along with more of the functionality of the `torchtuples` package.\n", "We will continue with the `LogisticHazard` method for simplicity.\n", "\n", "In the following, we will:\n", "- use SUPPORT as an example dataset,\n", "- use entity embeddings for categorical variables,\n", "- use the [AdamWR optimizer](https://arxiv.org/pdf/1711.05101.pdf) with cyclical learning rates,\n", "- use the scheme proposed by [Smith 2017](https://arxiv.org/pdf/1506.01186.pdf) to find a suitable learning rate." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn_pandas import DataFrameMapper\n", "\n", "import torch\n", "import torchtuples as tt\n", "\n", "from pycox.datasets import support\n", "from pycox.preprocessing.feature_transforms import OrderedCategoricalLong\n", "from pycox.models import LogisticHazard\n", "from pycox.evaluation import EvalSurv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "np.random.seed(123456)\n", "_ = torch.manual_seed(123456)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset\n", "\n", "We load the SUPPORT dataset and split it into training, test, and validation sets."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Hold out 20% of all rows for testing, then 20% of the remainder for validation.\n", "df_full = support.read_df()\n", "df_test = df_full.sample(frac=0.2)\n", "df_train_val = df_full.drop(df_test.index)\n", "df_val = df_train_val.sample(frac=0.2)\n", "df_train = df_train_val.drop(df_val.index)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | x0 | \n", "x1 | \n", "x2 | \n", "x3 | \n", "x4 | \n", "x5 | \n", "x6 | \n", "x7 | \n", "x8 | \n", "x9 | \n", "x10 | \n", "x11 | \n", "x12 | \n", "x13 | \n", "duration | \n", "event | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "82.709961 | \n", "1.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "160.0 | \n", "55.0 | \n", "16.0 | \n", "38.195309 | \n", "142.0 | \n", "19.000000 | \n", "1.099854 | \n", "30.0 | \n", "1 | \n", "
1 | \n", "79.660950 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "1.0 | \n", "54.0 | \n", "67.0 | \n", "16.0 | \n", "38.000000 | \n", "142.0 | \n", "10.000000 | \n", "0.899902 | \n", "1527.0 | \n", "0 | \n", "
4 | \n", "71.794983 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "65.0 | \n", "135.0 | \n", "40.0 | \n", "38.593750 | \n", "146.0 | \n", "0.099991 | \n", "0.399963 | \n", "7.0 | \n", "1 | \n", "
5 | \n", "49.932980 | \n", "0.0 | \n", "1.0 | \n", "1.0 | \n", "0.0 | \n", "0.0 | \n", "0.0 | \n", "70.0 | \n", "105.0 | \n", "33.0 | \n", "38.195309 | \n", "127.0 | \n", "5.299805 | \n", "1.199951 | \n", "50.0 | \n", "1 | \n", "
6 | \n", "62.942989 | \n", "0.0 | \n", "5.0 | \n", "2.0 | \n", "1.0 | \n", "0.0 | \n", "1.0 | \n", "116.0 | \n", "130.0 | \n", "35.0 | \n", "38.195309 | \n", "133.0 | \n", "14.099609 | \n", "0.799927 | \n", "381.0 | \n", "0 | \n", "