{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.seed:Setting seed to 1245502385\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import random\n", "random.seed(1100038344)\n", "import survivalstan\n", "import numpy as np\n", "import pandas as pd\n", "from stancache import stancache\n", "from matplotlib import pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The model\n", "\n", "This style of modeling is often called the \"piecewise exponential model\", or PEM. It is the simplest case where we estimate the *hazard* of an event occurring in a time period as the outcome, rather than estimating the *survival* (ie, time to event) as the outcome.\n", "\n", "Recall that, in the context of survival modeling, we have two models:\n", "\n", "1. A model for **Survival ($S$)**, ie the probability of surviving to time $t$:\n", "\n", " $$S(t)=Pr(Y > t)$$\n", "\n", "2. A model for the **instantaneous *hazard* $\\lambda$**, ie the probability per unit time of a failure event occurring in the interval [$t$, $t+\\delta t$], given survival to time $t$:\n", "\n", " $$\\lambda(t) = \\lim_{\\delta t \\rightarrow 0 } \\; \\frac{Pr( t \\le Y \\le t + \\delta t | Y > t)}{\\delta t}$$\n", "\n", "\n", "By definition, these two are related to one another by the following equation:\n", "\n", " $$\\lambda(t) = \\frac{-S'(t)}{S(t)}$$\n", " \n", "Solving this differential equation yields the following:\n", "\n", " $$S(t) = \\exp\\left( -\\int_0^t \\lambda(z) dz \\right)$$\n", "\n", "This model is called the **piecewise exponential model** because of this relationship between the Survival and hazard functions: when the hazard is a constant $\\lambda$, it reduces to $S(t) = \\exp(-\\lambda t)$, the exponential survival function. 
It's piecewise because we are not estimating the *instantaneous* hazard; we are instead breaking the follow-up time into pieces and estimating a constant hazard for each piece.\n", "\n", "There are several variations on the PEM implemented in survivalstan. In this notebook, we are exploring just one of them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A note about data formatting \n", "\n", "When we model *Survival*, we typically operate on data in time-to-event form. In this form, we have one record per Subject (ie, per patient). Each record contains [event_status, time_to_event] as the outcome. This data format is sometimes called *per-subject*.\n", "\n", "When we model the *hazard*, by comparison, we typically operate on data that are transformed to include one record per Subject per time_period. This is called *per-timepoint* or *long* form.\n", "\n", "All other things being equal, a model for *Survival* will typically estimate more efficiently (faster & smaller memory footprint) than one for *hazard* simply because the data are larger in the per-timepoint form than the per-subject form. The benefit of the *hazard* models is increased flexibility in specifying the baseline hazard, modeling time-varying effects, and introducing time-varying covariates.\n", "\n", "In this example, we are demonstrating use of the standard **PEM survival model**, which uses data in long form. The Stan code expects to receive data in this structure." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stan code for the model\n", "\n", "This model is provided in survivalstan.models.pem_survival_model. Let's take a look at the Stan code. 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/* Variable naming:\n", " // dimensions\n", " N = total number of observations (length of data)\n", " S = number of sample ids\n", " T = max timepoint (number of timepoint ids)\n", " M = number of covariates\n", " \n", " // main data matrix (per observed timepoint*record)\n", " s = sample id for each obs\n", " t = timepoint id for each obs\n", " event = integer indicating if there was an event at time t for sample s\n", " x = matrix of real-valued covariates at time t for sample n [N, X]\n", " \n", " // timepoint-specific data (per timepoint, ordered by timepoint id)\n", " t_obs = observed time since origin for each timepoint id (end of period)\n", " t_dur = duration of each timepoint period (first diff of t_obs)\n", " \n", "*/\n", "// Jacqueline Buros Novik \n", "\n", "data {\n", " // dimensions\n", " int N;\n", " int S;\n", " int T;\n", " int M;\n", " \n", " // data matrix\n", " int s[N]; // sample id\n", " int t[N]; // timepoint id\n", " int event[N]; // 1: event, 0:censor\n", " matrix[N, M] x; // explanatory vars\n", " \n", " // timepoint data\n", " vector[T] t_obs;\n", " vector[T] t_dur;\n", "}\n", "transformed data {\n", " vector[T] log_t_dur; // log-duration for each timepoint\n", " int n_trans[S, T]; \n", " \n", " log_t_dur = log(t_dur);\n", "\n", " // n_trans used to map each sample*timepoint to n (used in gen quantities)\n", " // map each patient/timepoint combination to n values\n", " for (n in 1:N) {\n", " n_trans[s[n], t[n]] = n;\n", " }\n", "\n", " // fill in missing values with n for max t for that patient\n", " // ie assume \"last observed\" state applies forward (may be problematic for TVC)\n", " // this allows us to predict failure times >= observed survival times\n", " for (samp in 1:S) {\n", " int last_value;\n", " last_value = 0;\n", " for (tp in 1:T) {\n", " // manual says ints are 
initialized to neg values\n", " // so <=0 is a shorthand for \"unassigned\"\n", " if (n_trans[samp, tp] <= 0 && last_value != 0) {\n", " n_trans[samp, tp] = last_value;\n", " } else {\n", " last_value = n_trans[samp, tp];\n", " }\n", " }\n", " } \n", "}\n", "parameters {\n", " vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t\n", " vector[M] beta; // beta for each covariate\n", " real baseline_sigma;\n", " real log_baseline_mu;\n", "}\n", "transformed parameters {\n", " vector[N] log_hazard;\n", " vector[T] log_baseline; // unstructured baseline hazard for each timepoint t\n", " \n", " log_baseline = log_baseline_mu + log_baseline_raw + log_t_dur;\n", " \n", " for (n in 1:N) {\n", " log_hazard[n] = log_baseline[t[n]] + x[n,]*beta;\n", " }\n", "}\n", "model {\n", " beta ~ cauchy(0, 2);\n", " event ~ poisson_log(log_hazard);\n", " log_baseline_mu ~ normal(0, 1);\n", " baseline_sigma ~ normal(0, 1);\n", " log_baseline_raw ~ normal(0, baseline_sigma);\n", "}\n", "generated quantities {\n", " real log_lik[N];\n", " vector[T] baseline;\n", " real y_hat_time[S]; // predicted failure time for each sample\n", " int y_hat_event[S]; // predicted event (0:censor, 1:event)\n", " \n", " // compute raw baseline hazard, for summary/plotting\n", " baseline = exp(log_baseline_mu + log_baseline_raw);\n", " \n", " // prepare log_lik for loo-psis\n", " for (n in 1:N) {\n", " log_lik[n] = poisson_log_log(event[n], log_hazard[n]);\n", " }\n", "\n", " // posterior predicted values\n", " for (samp in 1:S) {\n", " int sample_alive;\n", " sample_alive = 1;\n", " for (tp in 1:T) {\n", " if (sample_alive == 1) {\n", " int n;\n", " int pred_y;\n", " real log_haz;\n", " \n", " // determine predicted value of this sample's hazard\n", " n = n_trans[samp, tp];\n", " log_haz = log_baseline[tp] + x[n,] * beta;\n", " \n", " // now, make posterior prediction of an event at this tp\n", " if (log_haz < log(pow(2, 30))) \n", " pred_y = poisson_log_rng(log_haz);\n", " 
else\n", " pred_y = 9; \n", " \n", " // summarize survival time (observed) for this pt\n", " if (pred_y >= 1) {\n", " // mark this patient as ineligible for future tps\n", " // note: deliberately treat 9s as events \n", " sample_alive = 0;\n", " y_hat_time[samp] = t_obs[tp];\n", " y_hat_event[samp] = 1;\n", " }\n", " \n", " }\n", " } // end per-timepoint loop\n", " \n", " // if patient still alive at max\n", " if (sample_alive == 1) {\n", " y_hat_time[samp] = t_obs[T];\n", " y_hat_event[samp] = 0;\n", " }\n", " } // end per-sample loop \n", "}\n", "\n" ] } ], "source": [ "print(survivalstan.models.pem_survival_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Simulate survival data " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to demonstrate the use of this model, we will first simulate some survival data using survivalstan.sim.sim_data_exp_correlated. As the name implies, this function simulates data assuming a constant hazard throughout the follow-up time period, which is consistent with the Exponential survival function.\n", "\n", "This function includes two simulated covariates by default (age and sex). We also simulate a situation where hazard is a function of the simulated value for sex. 
\n", "\n", "We also center the age variable since this will make it easier to interpret estimates of the baseline hazard.\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache\n" ] } ], "source": [ "d = stancache.cached(\n", " survivalstan.sim.sim_data_exp_correlated,\n", " N=100,\n", " censor_time=20,\n", " rate_form='1 + sex',\n", " rate_coefs=[-3, 0.5],\n", ")\n", "d['age_centered'] = d['age'] - d['age'].mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Aside: To make this a more reproducible example, this code uses the file-caching function stancache.cached to wrap the call to survivalstan.sim.sim_data_exp_correlated.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Explore simulated data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is what these data look like - this is per-subject or time-to-event form:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table>\n", "  <thead>\n",
"    <tr><th></th><th>sex</th><th>age</th><th>rate</th><th>true_t</th><th>t</th><th>event</th><th>index</th><th>age_centered</th></tr>\n",
"  </thead>\n", "  <tbody>\n",
"    <tr><th>0</th><td>male</td><td>54</td><td>0.082085</td><td>1.013855</td><td>1.013855</td><td>True</td><td>0</td><td>-1.12</td></tr>\n",
"    <tr><th>1</th><td>male</td><td>39</td><td>0.082085</td><td>4.890597</td><td>4.890597</td><td>True</td><td>1</td><td>-16.12</td></tr>\n",
"    <tr><th>2</th><td>female</td><td>45</td><td>0.049787</td><td>4.093404</td><td>4.093404</td><td>True</td><td>2</td><td>-10.12</td></tr>\n",
"    <tr><th>3</th><td>female</td><td>43</td><td>0.049787</td><td>7.036226</td><td>7.036226</td><td>True</td><td>3</td><td>-12.12</td></tr>\n",
"    <tr><th>4</th><td>female</td><td>57</td><td>0.049787</td><td>5.712299</td><td>5.712299</td><td>True</td><td>4</td><td>1.88</td></tr>\n",
"  </tbody>\n", "</table>\n", "</div>
" ], "text/plain": [ " sex age rate true_t t event index age_centered\n", "0 male 54 0.082085 1.013855 1.013855 True 0 -1.12\n", "1 male 39 0.082085 4.890597 4.890597 True 1 -16.12\n", "2 female 45 0.049787 4.093404 4.093404 True 2 -10.12\n", "3 female 43 0.049787 7.036226 7.036226 True 3 -12.12\n", "4 female 57 0.049787 5.712299 5.712299 True 4 1.88" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*It's not that obvious from the field names, but in this example \"subjects\" are indexed by the field index.*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can plot these data using lifelines, or the rudimentary plotting functions provided by survivalstan." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "