{ "cells": [ { "cell_type": "markdown", "id": "d5e1404f", "metadata": {}, "source": [ "# Deep Cox Mixtures with Heterogenous Effects (CMHE) Demo\n", "
\n", "\n", "Author: ***Mononito Goswami*** <mgoswami@cs.cmu.edu>\n", "\n", "
\n", "\n", "\n", "
\n", "\n", "# Contents\n", "\n", "\n", "\n", "### 1. [Introduction](#introduction) \n", "\n", "\n", "### 2. [Synthetic Data](#syndata) \n", "####               2.1 [Generative Process for the Synthetic Dataset.](#gensyndata)\n", "####               2.2 [Loading and Visualizing the Dataset.](#vissyndata)\n", "####               2.3 [Split Dataset into Train and Test.](#splitdata)\n", "\n", " \n", "### 3. [Counterfactual Phenotyping](#phenotyping)\n", "\n", "####               3.1 [Phenotyping with CMHE](#phenocmhe)\n", "####               3.2 [Comparison with Clustering](#clustering)\n", "\n", "\n", "### 4. [Factual Regression](#regression)\n", "####               4.1 [Factual Regression with CMHE](#regcmhe)\n", "####               4.2 [Comparison with a Deep Cox Proportional Hazards Model](#deepcph)\n", "\n", "
\n" ] }, { "cell_type": "markdown", "id": "bebb3be4", "metadata": {}, "source": [ "\n", "\n", "## 1. Introduction\n", "\n", "Figure A: Schematic Description of CMHE:The set of features (confounders) $\\mathbf{x}$ are passed through an encoder to obtain deep non-linear representations. These representations then describe the latent phenogroups $\\mathbf{P}(Z|X=\\mathbf{x})$ and $\\mathbf{P}(\\mathbf{\\phi}|X=\\mathbf{x})$ that determine the base survival rate and the treatment effect respectively. Finally, the individual level hazard (survival) curve under an intervention $A=\\mathbf{a}$ is described by marginalizing over $Z$ and $\\mathbf{\\phi}$ as $\\mathbf{S}(t|X=x, A=a) = \\mathbf{E}_{(Z,\\mathbf{\\phi)}\\sim \\mathbf{P}(\\cdot|X)}\\big[ \\mathbf{S}(t|A=\\mathbf{a}, X, Z, \\mathbf{\\phi})\\big]$. \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "

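\n", "\n", "Below is a minimal, self-contained sketch of this marginalization. All quantities in it (component survival curves, posterior probabilities) are hypothetical stand-ins rather than outputs of the library.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a1b2c3d4", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Sketch of S(t|X=x, A=a) = E_{(Z, phi) ~ P(.|X)}[ S(t|A=a, X, Z, phi) ].\n", "# Hypothetical setup: 2 base survival groups (Z) x 2 treatment effect groups (phi).\n", "time_grid = np.linspace(0, 5, 6)\n", "\n", "# component_survival[k, m, :] stands in for S(t | A=a, X=x, Z=k, phi=m)\n", "rates = np.array([[0.2, 0.4], [0.6, 0.9]]) # made-up hazard rates\n", "component_survival = np.exp(-rates[..., None] * time_grid)\n", "\n", "# Hypothetical posteriors P(Z|X=x) and P(phi|X=x) for one individual\n", "p_z, p_phi = np.array([0.7, 0.3]), np.array([0.6, 0.4])\n", "\n", "# Marginalize: weight each component survival curve by P(Z=k, phi=m | X=x)\n", "joint = np.outer(p_z, p_phi)\n", "marginal_survival = np.einsum('km,kmt->t', joint, component_survival)\n", "print(np.round(marginal_survival, 3))" ] }, { "cell_type": "markdown", "id": "b2c3d4e5", "metadata": {}, "source": [ "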
Cox Mixtures with Heterogeneous Effects (CMHE) is a flexible approach for recovering counterfactual phenotypes of individuals that demonstrate heterogeneous effects to an intervention, in terms of censored time-to-event outcomes. CMHE is not restricted by the strong Cox Proportional Hazards assumption or any parametric assumption on the time-to-event distributions. CMHE achieves this by describing each individual as belonging to two different latent groups: \n", "$Z$, which mediates the base survival rate, and $\phi$, which mediates the effect of the treatment. CMHE can also be employed to model individual level counterfactuals or for standard factual survival regression.\n", "\n", "**Figure B (Right)**: CMHE in Plate Notation. $\mathbf{x}$ confounds treatment assignment $A$ and outcome $T$ (model parameters and the censoring distribution have been abstracted out).\n", "\n", "*For full details on Cox Mixtures with Heterogeneous Effects, please refer to our preprint*:\n", "\n", "[Counterfactual Phenotyping with Censored Time-to-Events, arXiv preprint, C. Nagpal, M. Goswami, K. Dufendach, A. Dubrawski](https://arxiv.org/abs/2202.11089)\n", "\n" ] }, { "cell_type": "markdown", "id": "82376e75", "metadata": {}, "source": [ "<a id=\"syndata\"></a>\n", "\n", "## 2. Synthetic Data Example\n" ] }, { "cell_type": "code", "execution_count": null, "id": "376a63a9", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import torch\n", "from tqdm import tqdm \n", "import sys\n", "sys.path.append('../')\n", "\n", "from auton_survival.datasets import load_dataset\n", "from cmhe_demo_utils import * " ] }, { "cell_type": "markdown", "id": "2cacdcff", "metadata": {}, "source": [ "<a id=\"gensyndata\"></a>\n", "### 2.1. Generative Process for the Synthetic Data" ] }, { "cell_type": "markdown", "id": "98834662", "metadata": {}, "source": [ "1. Features $x_1$, $x_2$ and the base survival phenotypes $Z$ are sampled from $\texttt{scikit-learn's make_blobs(...)}$ function, which generates isotropic Gaussian blobs:\n", "$$[x_1, x_2], Z \sim \texttt{sklearn.datasets.make_blobs(K = 3)}$$\n", "2. Features $x_3$ and $x_4$ are sampled uniformly, whereas the underlying treatment effect phenotypes $\phi$ are defined according to an $L_1$-ball:\n", "$$ [x_3, x_4] \sim \texttt{Uniform}(-2, 2) $$\n", "$$ \phi \triangleq \mathbb{1}\{|x_3| + |x_4| > 2\} $$\n", "3. We then sample treatment assignments from a Bernoulli distribution:\n", "$$ A \sim \texttt{Bernoulli}(\frac{1}{2}) $$\n", "4. Next, the time-to-event $T^{*}$ conditioned on the confounders $x$, latent group $Z$ and latent effect group $\phi$ is generated from a Gompertz distribution:\n", "$$ T^{*}| (Z=k, {\phi}=m, A={a}) \sim \texttt{Gompertz}\big({\beta}_{k}^{\top}{x} +({-a}^m)\big) $$\n", "5. Finally, the observed time $T$ is obtained by censoring some of the events, with the censoring time chosen uniformly at random up to $T^*$ (see the simulation sketch after the data preview below):\n", "$$\delta \sim \texttt{Bernoulli}(\frac{3}{4}), \quad C \sim \texttt{Uniform}(0, {T}^{*})$$\n", "$$ T = \begin{cases} T^*, & \text{if } \delta = 1 \\ C, & \text{if } \delta = 0 \end{cases} $$" ] }, { "cell_type": "code", "execution_count": null, "id": "89915ea1", "metadata": {}, "outputs": [], "source": [ "# Load the synthetic dataset\n", "outcomes, features, interventions = load_dataset(dataset='SYNTHETIC')\n", "\n", "# Let's take a look at the dataset\n", "features.head(5)" ] }
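, { "cell_type": "markdown", "id": "c3d4e5f6", "metadata": {}, "source": [ "To make the generative process above concrete, here is a self-contained simulation sketch. It is not the code behind `load_dataset(dataset='SYNTHETIC')`: the Gompertz parameterization, the $\beta_k$ coefficients, and the reading of the $({-a}^m)$ term are illustrative assumptions." ] }, { "cell_type": "code", "execution_count": null, "id": "d4e5f607", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.datasets import make_blobs\n", "\n", "rng = np.random.RandomState(0)\n", "n_samples = 5000\n", "\n", "# 1. x1, x2 and base survival phenotypes Z from isotropic Gaussian blobs (K=3)\n", "x12, z = make_blobs(n_samples=n_samples, centers=3, n_features=2, random_state=0)\n", "\n", "# 2. x3, x4 uniform on [-2, 2]; effect phenotype phi from an L1-ball\n", "x34 = rng.uniform(-2, 2, size=(n_samples, 2))\n", "phi = ((np.abs(x34[:, 0]) + np.abs(x34[:, 1])) > 2).astype(int)\n", "\n", "# 3. Treatment assignment A ~ Bernoulli(1/2)\n", "a = rng.binomial(1, 0.5, size=n_samples)\n", "\n", "# 4. Event times from a Gompertz distribution; beta_k are hypothetical coefficients,\n", "# and the (-a)^m term is read here as a sign flip of the treatment by phenotype\n", "beta = rng.randn(3, 2)\n", "eta = np.einsum('ij,ij->i', beta[z], x12) + ((-1.0) ** phi) * a\n", "u = rng.uniform(size=n_samples)\n", "t_star = np.log1p(-np.log(u) / np.exp(eta)) # inverse-CDF sample (shape 1, rate exp(eta))\n", "\n", "# 5. Censor ~25% of events uniformly at random up to T*\n", "delta = rng.binomial(1, 0.75, size=n_samples)\n", "t_obs = np.where(delta == 1, t_star, rng.uniform(0, t_star))\n", "print(f'Observed event fraction: {delta.mean():.2f}')" ] }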
, { "cell_type": "markdown", "id": "fbcb9532", "metadata": {}, "source": [ "<a id=\"vissyndata\"></a>\n", "### 2.2. Visualizing the Synthetic Data" ] }, { "cell_type": "code", "execution_count": null, "id": "6919fbe6", "metadata": {}, "outputs": [], "source": [ "plot_synthetic_data(outcomes, features, interventions)" ] }, { "cell_type": "markdown", "id": "c63e29de", "metadata": {}, "source": [ "<a id=\"splitdata\"></a>\n", "### 2.3 Split Dataset into Train and Test folds" ] }, { "cell_type": "code", "execution_count": null, "id": "8f1cd518", "metadata": {}, "outputs": [], "source": [ "# Hyper-parameters\n", "random_seed = 0\n", "test_size = 0.25\n", "\n", "# Split the synthetic data into training and testing data\n", "import numpy as np\n", "\n", "np.random.seed(random_seed)\n", "n = features.shape[0] \n", "\n", "# Sample test indices without replacement so the test fraction is exactly test_size\n", "test_idx = np.zeros(n).astype('bool')\n", "test_idx[np.random.choice(n, size=int(n*test_size), replace=False)] = True \n", "\n", "features_tr = features.iloc[~test_idx] \n", "outcomes_tr = outcomes.iloc[~test_idx]\n", "interventions_tr = interventions[~test_idx]\n", "print(f'Number of training data points: {len(features_tr)}')\n", "\n", "features_te = features.iloc[test_idx] \n", "outcomes_te = outcomes.iloc[test_idx]\n", "interventions_te = interventions[test_idx]\n", "print(f'Number of test data points: {len(features_te)}')\n", "\n", "x_tr = features_tr.values.astype('float32')\n", "t_tr = outcomes_tr['time'].values.astype('float32')\n", "e_tr = outcomes_tr['event'].values.astype('float32')\n", "a_tr = interventions_tr.values.astype('float32')\n", "\n", "x_te = features_te.values.astype('float32')\n", "t_te = outcomes_te['time'].values.astype('float32')\n", "e_te = outcomes_te['event'].values.astype('float32')\n", "a_te = interventions_te.values.astype('float32')\n", "\n", "print('Training Data Statistics:')\n", "print(f'Shape of covariates: {x_tr.shape} | times: {t_tr.shape} | events: {e_tr.shape} | interventions: {a_tr.shape}')" ] }, { "cell_type": "code", "execution_count": null, "id": "48457e79", "metadata": {}, "outputs": [], "source": [ "def find_max_treatment_effect_phenotype(g, zeta_probs, factual_outcomes):\n", "    \"\"\"\n", "    Find the treatment phenotype group with the maximum treatment effect.\n", "    \"\"\"\n", "    # Mean differential survival (treatment effect) for each phenotype group\n", "    mean_differential_survival = np.zeros(zeta_probs.shape[1])\n", "    outcomes_train, interventions_train = factual_outcomes \n", "\n", "    # Assign each individual to their treatment phenotype group\n", "    for gr in range(g): # For each treatment phenotype group\n", "        # Probability of belonging to the gr^th treatment phenotype\n", "        zeta_probs_g = zeta_probs[:, gr] \n", "        # Consider only individuals above the 75th percentile of membership probability\n", "        z_mask = zeta_probs_g > np.quantile(zeta_probs_g, 0.75) \n", "\n", "        mean_differential_survival[gr] = find_mean_differential_survival(\n", "            outcomes_train.loc[z_mask], interventions_train.loc[z_mask]) \n", "\n", "    return np.nanargmax(mean_differential_survival)" ] }
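, { "cell_type": "markdown", "id": "e5f60718", "metadata": {}, "source": [ "The helper above relies on `find_mean_differential_survival`, imported via `from cmhe_demo_utils import *`. For reference, the sketch below shows one plausible way to compute such a quantity: the difference in restricted mean survival time (RMST) between treated and control individuals within a phenotype. The `lifelines` dependency, the function name, and the default `horizon` are assumptions of this sketch, not part of the demo utilities." ] }, { "cell_type": "code", "execution_count": null, "id": "f6071829", "metadata": {}, "outputs": [], "source": [ "from lifelines import KaplanMeierFitter\n", "from lifelines.utils import restricted_mean_survival_time\n", "\n", "def sketch_mean_differential_survival(outcomes, interventions, horizon=5.0):\n", "    \"\"\"RMST(treated) - RMST(control) within a phenotype, up to `horizon`.\"\"\"\n", "    # Assumes binary 0/1 interventions, as in this demo's synthetic data\n", "    rmst = {}\n", "    for arm, treated in [('treated', 1), ('control', 0)]:\n", "        mask = (interventions.values == treated)\n", "        # Kaplan-Meier curve for this arm, then its area under the curve up to the horizon\n", "        km = KaplanMeierFitter().fit(outcomes['time'][mask], outcomes['event'][mask])\n", "        rmst[arm] = restricted_mean_survival_time(km, t=horizon)\n", "    return rmst['treated'] - rmst['control']" ] }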
, { "cell_type": "markdown", "id": "e490d4f1", "metadata": {}, "source": [ "<a id=\"phenotyping\"></a>\n", "## 3. Counterfactual Phenotyping" ] }, { "cell_type": "markdown", "id": "18d411d0", "metadata": {}, "source": [ "<a id=\"phenocmhe\"></a>\n", "### 3.1 Counterfactual Phenotyping with CMHE" ] }, { "cell_type": "code", "execution_count": null, "id": "d79ebeff", "metadata": {}, "outputs": [], "source": [ "# Hyper-parameters to train the model\n", "k = 1 # number of underlying base survival phenotypes\n", "g = 2 # number of underlying treatment effect phenotypes\n", "layers = [50, 50] # number of neurons in each hidden layer\n", "\n", "random_seed = 10\n", "iters = 100 # number of training epochs\n", "learning_rate = 0.01\n", "batch_size = 256 \n", "vsize = 0.15 # size of the validation split\n", "patience = 3\n", "optimizer = \"Adam\"" ] }, { "cell_type": "code", "execution_count": null, "id": "cdde29c8", "metadata": { "scrolled": true }, "outputs": [], "source": [ "from auton_survival.models.cmhe import DeepCoxMixturesHeterogenousEffects\n", "\n", "torch.manual_seed(random_seed)\n", "np.random.seed(random_seed)\n", "\n", "# Instantiate the CMHE model\n", "model = DeepCoxMixturesHeterogenousEffects(random_seed=random_seed, k=k, g=g, layers=layers)\n", "\n", "model = model.fit(x_tr, t_tr, e_tr, a_tr, vsize=vsize, val_data=None, iters=iters, \n", "                  learning_rate=learning_rate, batch_size=batch_size, \n", "                  optimizer=optimizer, patience=patience)" ] }, { "cell_type": "code", "execution_count": null, "id": "03128cb8", "metadata": {}, "outputs": [], "source": [ "print(f'Treatment Effect for the {g} groups: {model.torch_model[0].omega.detach()}')\n", "\n", "zeta_probs_train = model.predict_latent_phi(x_tr)\n", "zeta_train = np.argmax(zeta_probs_train, axis=1)\n", "print(f'Distribution of individuals in each treatment phenotype in the training data: \\\n", "{np.unique(zeta_train, return_counts=True)[1]}')\n", "\n", "max_treat_idx_CMHE = find_max_treatment_effect_phenotype(\n", "    g=g, zeta_probs=zeta_probs_train, factual_outcomes=(outcomes_tr, interventions_tr))\n", "print(f'\\nGroup {max_treat_idx_CMHE} has the maximum mean differential survival (treatment effect) on the training data!')" ] }, { "cell_type": "markdown", "id": "f80efe86", "metadata": {}, "source": [ "### Evaluate CMHE on Test Data" ] }, { "cell_type": "code", "execution_count": null, "id": "d03fc282", "metadata": {}, "outputs": [], "source": [ "# Now, for each individual in the test data, let's find the probability that \n", "# they belong to the maximum treatment effect group\n", "\n", "zeta_probs_test_CMHE = model.predict_latent_phi(x_te)\n", "zeta_test = np.argmax(zeta_probs_test_CMHE, axis=1)\n", "print(f'Distribution of individuals in each treatment phenotype in the test data: \\\n", "{np.unique(zeta_test, return_counts=True)[1]}')\n", "\n", "# Now let us evaluate our performance\n", "plot_phenotypes_roc(outcomes_te, zeta_probs_test_CMHE[:, max_treat_idx_CMHE])" ] }, { "cell_type": "markdown", "id": "bce5722a", "metadata": {}, "source": [ "<a id=\"clustering\"></a>\n", "### 3.2 Comparison with the Clustering Phenotyper" ] }, { "cell_type": "markdown", "id": "94f63bc7", "metadata": {}, "source": [ "We compare CMHE against dimensionality reduction followed by clustering for counterfactual phenotyping. Ordinarily, we would first reduce the dimensionality of the input confounders, $\mathbf{x}$, and then cluster the reduced representations. Since the synthetic data has only a small number of confounders, in the following experiment we skip dimensionality reduction and directly perform clustering using a Gaussian Mixture Model (GMM) with 2 components and diagonal covariance matrices."
] }, { "cell_type": "code", "execution_count": null, "id": "35a063e4", "metadata": {}, "outputs": [], "source": [ "from phenotyping import ClusteringPhenotyper\n", "from sklearn.metrics import auc\n", "\n", "clustering_method = 'gmm'\n", "dim_red_method = None # We would not perform dimensionality reduction for the synthetic dataset\n", "n_components = None \n", "n_clusters = 2 # Number of underlying treatment effect phenotypes\n", "\n", "# Running the phenotyper\n", "phenotyper = ClusteringPhenotyper(clustering_method=clustering_method, \n", " dim_red_method=dim_red_method, \n", " n_components=n_components, \n", " n_clusters=n_clusters,\n", " random_seed=36) " ] }, { "cell_type": "code", "execution_count": null, "id": "6842678b", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "5bf55005", "metadata": {}, "outputs": [], "source": [ "zeta_probs_train = phenotyper.fit(features_tr.values).predict_proba(features_tr.values)\n", "zeta_train = phenotyper.fit_predict(features_tr.values)\n", "print(f'Distribution of individuals in each treatement phenotype in the training data: \\\n", "{np.unique(zeta_train, return_counts=True)[1]}')\n", "\n", "max_treat_idx_CP = find_max_treatment_effect_phenotype(\n", " g=2, zeta_probs=zeta_probs_train, factual_outcomes=(outcomes_tr, interventions_tr))\n", "print(f'\\nGroup {max_treat_idx_CP} has the maximum restricted mean survival time on the training data!')" ] }, { "cell_type": "markdown", "id": "1d3d15a6", "metadata": {}, "source": [ "### Evaluate Clustering Phenotyper on Test Data" ] }, { "cell_type": "code", "execution_count": null, "id": "659ef29f", "metadata": {}, "outputs": [], "source": [ "# Now for each individual in the test data, let's find the probability that \n", "# they belong to the max treatment effect group\n", "\n", "# Use the phenotyper trained on training data to phenotype on testing data\n", "zeta_probs_test_CP = phenotyper.predict_proba(x_te)\n", "zeta_test_CP = np.argmax(zeta_probs_test_CP, axis=1)\n", "print(f'Distribution of individuals in each treatement phenotype in the test data: \\\n", "{np.unique(zeta_test_CP, return_counts=True)[1]}')\n", "\n", "# Now let us evaluate our performance\n", "plot_phenotypes_roc(outcomes_te, zeta_probs_test_CP[:, max_treat_idx_CP])" ] }, { "cell_type": "markdown", "id": "e6bfd2ce", "metadata": {}, "source": [ "\n", "## 4. CMHE for Factual Regression" ] }, { "cell_type": "markdown", "id": "c4798065", "metadata": {}, "source": [ "For completeness, we further evaluate the performance of CMHE in estimating factual risk over multiple time horizons using the standard survival analysis metrics, including: \n", "\n", "1. $\\textbf{Brier Score} \\ (\\textrm{BS})$: Defined as the Mean Squared Error (MSE) around the probabilistic prediction at a certain time horizon.\n", "\\begin{align}\n", "\\text{BS}(t) = \\mathop{\\mathbf{E}}_{x\\sim\\mathcal{D}}\\big[ ||\\mathbf{1}\\{ T > t \\} - \\widehat{\\mathbf{P}}(T>t|X)\\big)||_{_\\textbf{2}}^\\textbf{2} \\big]\n", "\\end{align}\n", "2. 
, { "cell_type": "markdown", "id": "293a4b5c", "metadata": {}, "source": [ "<a id=\"regcmhe\"></a>\n", "### 4.1 Factual Regression Performance of CMHE" ] }, { "cell_type": "code", "execution_count": null, "id": "48303372", "metadata": {}, "outputs": [], "source": [ "horizons = [1, 3, 5]\n", "\n", "# Now let us predict survival using CMHE\n", "predictions_test_CMHE = model.predict_survival(x_te, a_te, t=horizons)\n", "\n", "CI1, CI3, CI5, IBS = factual_evaluate((x_tr, t_tr, e_tr, a_tr), (x_te, t_te, e_te, a_te), \n", "                                      horizons, predictions_test_CMHE)\n", "print(f'Concordance Index (1 Year): {np.around(CI1, 4)} (3 Year): {np.around(CI3, 4)} (5 Year): {np.around(CI5, 4)}')\n", "print(f'Integrated Brier Score: {np.around(IBS, 4)}')" ] }, { "cell_type": "markdown", "id": "d38ccb2e", "metadata": {}, "source": [ "<a id=\"deepcph\"></a>\n", "### 4.2 Comparison with a Deep Cox Proportional Hazards Model" ] }, { "cell_type": "code", "execution_count": null, "id": "fee7fcea", "metadata": { "scrolled": true }, "outputs": [], "source": [ "from auton_survival.estimators import SurvivalModel\n", "\n", "# Now let us train a Deep Cox Proportional Hazards (DCPH) model with two hidden layers\n", "dcph_model = SurvivalModel('dcph', random_seed=0, bs=100, learning_rate=1e-3, layers=[50, 50])\n", "\n", "# Append the treatment indicator to the features, since DCPH has no explicit treatment input\n", "interventions_tr.name, interventions_te.name = 'treat', 'treat'\n", "features_tr_dcph = pd.concat([features_tr, interventions_tr.astype('float64')], axis=1)\n", "features_te_dcph = pd.concat([features_te, interventions_te.astype('float64')], axis=1)\n", "outcomes_tr_dcph = pd.DataFrame(outcomes_tr, columns=['event', 'time']).astype('float64')\n", "\n", "# Train the DCPH model\n", "dcph_model = dcph_model.fit(features_tr_dcph, outcomes_tr_dcph)" ] }, { "cell_type": "markdown", "id": "dd154102", "metadata": {}, "source": [ "### Evaluate DCPH on Test Data" ] }, { "cell_type": "code", "execution_count": null, "id": "d2887aa7", "metadata": {}, "outputs": [], "source": [ "# Find survival scores on the test data\n", "predictions_test_DCPH = dcph_model.predict_survival(features_te_dcph, horizons)\n", "\n", "CI1, CI3, CI5, IBS = factual_evaluate((x_tr, t_tr, e_tr, a_tr), (x_te, t_te, e_te, a_te), \n", "                                      horizons, predictions_test_DCPH)\n", "print(f'Concordance Index (1 Year): {np.around(CI1, 4)} (3 Year): {np.around(CI3, 4)} (5 Year): {np.around(CI5, 4)}')\n", "print(f'Integrated Brier Score: {np.around(IBS, 4)}')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }