{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.seed:Setting seed to 1245502385\n" ] } ], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "import random\n", "random.seed(13847942484)\n", "import survivalstan\n", "import numpy as np\n", "import pandas as pd\n", "from stancache import stancache\n", "from matplotlib import pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/* Variable naming:\n", " // dimensions\n", " N = total number of observations (length of data)\n", " S = number of sample ids\n", " T = max timepoint (number of timepoint ids)\n", " M = number of covariates\n", " \n", " // data\n", " s = sample id for each obs\n", " t = timepoint id for each obs\n", " event = integer indicating if there was an event at time t for sample s\n", " x = matrix of real-valued covariates at time t for sample n [N, X]\n", " obs_t = observed end time for interval for timepoint for that obs\n", " \n", "*/\n", "// Jacqueline Buros Novik \n", "\n", "functions {\n", " matrix spline(vector x, int N, int H, vector xi, int P) {\n", " matrix[N, H + P] b_x; // expanded predictors\n", " for (n in 1:N) {\n", " for (p in 1:P) {\n", " b_x[n,p] <- pow(x[n],p-1); // x[n]^(p-1)\n", " }\n", " for (h in 1:H)\n", " b_x[n, h + P] <- fmax(0, pow(x[n] - xi[h],P-1)); \n", " }\n", " return b_x;\n", " }\n", "}\n", "data {\n", " // dimensions\n", " int N;\n", " int S;\n", " int T;\n", " int M;\n", " \n", " // data matrix\n", " int s[N]; // sample id\n", " int t[N]; // timepoint id\n", " int event[N]; // 1: event, 0:censor\n", " matrix[N, M] x; // explanatory vars\n", "\n", " // timepoint data\n", " vector[T] t_obs;\n", " vector[T] t_dur;\n", "}\n", "transformed data {\n", " vector[T] log_t_dur;\n", " int n_trans[S, T]; \n", "\n", " log_t_dur = log(t_dur);\n", " \n", " // n_trans used to map each sample*timepoint to n (used in gen quantities)\n", " // map each patient/timepoint combination to n values\n", " for (n in 1:N) {\n", " n_trans[s[n], t[n]] = n;\n", " }\n", "\n", " // fill in missing values with n for max t for that patient\n", " // ie assume \"last observed\" state applies forward (may be problematic for TVC)\n", " // this allows us to predict failure times >= observed survival times\n", " for (samp in 1:S) {\n", " int last_value;\n", " last_value = 0;\n", " for (tp in 1:T) {\n", " // manual says ints are initialized to neg values\n", " // so <=0 is a shorthand for \"unassigned\"\n", " if (n_trans[samp, tp] <= 0 && last_value != 0) {\n", " n_trans[samp, tp] = last_value;\n", " } else {\n", " last_value = n_trans[samp, tp];\n", " }\n", " }\n", " } \n", "}\n", "parameters {\n", " vector[T] log_baseline_raw; // unstructured baseline hazard for each timepoint t\n", " real baseline_sigma;\n", " real log_baseline_mu;\n", " \n", " vector[M] beta; // beta-intercept\n", " vector[M] beta_time_sigma;\n", " vector[T-1] raw_beta_time_deltas[M]; // for each coefficient\n", " // change in coefficient value from previous time\n", "}\n", "transformed parameters {\n", " vector[N] log_hazard;\n", " vector[T] log_baseline;\n", " vector[T] beta_time[M];\n", " vector[T] beta_time_deltas[M];\n", "\n", " // adjust baseline hazard for duration of each period\n", " log_baseline = log_baseline_raw + log_t_dur;\n", " \n", " // compute timepoint-specific betas \n", " // offsets from previous time\n", " for (coef in 1:M) {\n", " beta_time_deltas[coef][1] = 0;\n", " for (time in 2:T) {\n", " beta_time_deltas[coef][time] = raw_beta_time_deltas[coef][time-1];\n", " }\n", " }\n", " \n", " // coefficients for each timepoint T\n", " for (coef in 1:M) {\n", " beta_time[coef] = beta[coef] + cumulative_sum(beta_time_deltas[coef]);\n", " }\n", "\n", " // compute log-hazard for each obs\n", " for (n in 1:N) {\n", " real log_linpred;\n", " log_linpred <- 0;\n", " for (coef in 1:M) {\n", " // for now, handle each coef separately\n", " // (to be sure we pull out the \"right\" beta..)\n", " log_linpred = log_linpred + x[n, coef] * beta_time[coef][t[n]]; \n", " }\n", " log_hazard[n] = log_baseline_mu + log_baseline[t[n]] + log_linpred;\n", " }\n", "}\n", "model {\n", " // priors on time-varying coefficients\n", " for (m in 1:M) {\n", " raw_beta_time_deltas[m][1] ~ normal(0, 100);\n", " for(i in 2:(T-1)){\n", " raw_beta_time_deltas[m][i] ~ normal(raw_beta_time_deltas[m][i-1], beta_time_sigma[m]);\n", " }\n", " }\n", " beta_time_sigma ~ cauchy(0, 1);\n", " beta ~ cauchy(0, 1);\n", " \n", " // priors on baseline hazard\n", " log_baseline_mu ~ normal(0, 1);\n", " baseline_sigma ~ normal(0, 1);\n", " log_baseline_raw[1] ~ normal(0, 1);\n", " for (i in 2:T) {\n", " log_baseline_raw[i] ~ normal(log_baseline_raw[i-1], baseline_sigma);\n", " }\n", " \n", " // model\n", " event ~ poisson_log(log_hazard);\n", "}\n", "generated quantities {\n", " real log_lik[N];\n", " vector[T] baseline;\n", " int y_hat_mat[S, T]; // ppcheck for each S*T combination\n", " real y_hat_time[S]; // predicted failure time for each sample\n", " int y_hat_event[S]; // predicted event (0:censor, 1:event)\n", " \n", " // compute raw baseline hazard, for summary/plotting\n", " baseline = exp(log_baseline_raw);\n", " \n", " // log_likelihood for loo-psis\n", " for (n in 1:N) {\n", " log_lik[n] <- poisson_log_lpmf(event[n] | log_hazard[n]);\n", " }\n", " \n", " // posterior predicted values\n", " for (samp in 1:S) {\n", " int sample_alive;\n", " sample_alive = 1;\n", " for (tp in 1:T) {\n", " if (sample_alive == 1) {\n", " int n;\n", " int pred_y;\n", " real log_linpred;\n", " real log_haz;\n", " \n", " // determine predicted value of y\n", " n = n_trans[samp, tp];\n", " \n", " // (borrow code from above to calc linpred)\n", " // but use sim tp not t[n] \n", " log_linpred = 0;\n", " for (coef in 1:M) {\n", " // for now, handle each coef separately\n", " // (to be sure we pull out the \"right\" beta..)\n", " log_linpred = log_linpred + x[n, coef] * beta_time[coef][tp]; \n", " }\n", " log_haz = log_baseline_mu + log_baseline[tp] + log_linpred;\n", " \n", " // now, make posterior prediction \n", " if (log_haz < log(pow(2, 30))) \n", " pred_y = poisson_log_rng(log_haz);\n", " else\n", " pred_y = 9; \n", " \n", " // mark this patient as ineligible for future tps\n", " // note: deliberately make 9s ineligible \n", " if (pred_y >= 1) {\n", " sample_alive = 0;\n", " y_hat_time[samp] = t_obs[tp];\n", " y_hat_event[samp] = 1;\n", " }\n", " \n", " // save predicted value of y to matrix\n", " y_hat_mat[samp, tp] = pred_y;\n", " }\n", " else if (sample_alive == 0) {\n", " y_hat_mat[samp, tp] = 9;\n", " } \n", " } // end per-timepoint loop\n", " \n", " // if patient still alive at max\n", " // \n", " if (sample_alive == 1) {\n", " y_hat_time[samp] = t_obs[T];\n", " y_hat_event[samp] = 0;\n", " }\n", " } // end per-sample loop\n", "}\n", "\n" ] } ], "source": [ "print(survivalstan.models.pem_survival_model_timevarying)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.stancache:sim_data_exp_correlated: cache_filename set to sim_data_exp_correlated.cached.N_100.censor_time_20.rate_coefs_54462717316.rate_form_1 + sex.pkl\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "INFO:stancache.stancache:sim_data_exp_correlated: Loading result from cache\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
sexageratetrue_tteventindexage_centered
0male540.0820851.0138551.013855True0-1.12
1male390.0820854.8905974.890597True1-16.12
2female450.0497874.0934044.093404True2-10.12
3female430.0497877.0362267.036226True3-12.12
4female570.0497875.7122995.712299True41.88
\n", "