{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Heterogeneous Effect Mixture Model (HEMM) Demo" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "# Contents\n", "\n", "### 1. [Introduction](#Introduction) \n", "####               1.1 [Subgroup Discovery and Heterogeneous Treatment Effect Problem](#introsgdisc)\n", "####               1.2 [HEMM Description and Plate Notation](#introhemm)\n", "\n", " \n", "### 2. [Synthetic Data Example](#syndata)\n", "\n", "\n", "####               2.1 [Data Description and Generative Process](#syndatagen)\n", "\n", "####               2.2 [Estimation of Counterfactual Outcomes and PEHE Estimation](#syndatapehe)\n", "\n", "####               2.3 [Subgroup Discovery](#syndatasg)\n", "\n", "####               2.4 Qualitative Assessment of the Discovered Subgroups\n", "\n", "####               2.5 [Bootstrapping PEHE Estimates](#syndatabs)\n", "\n", "\n", "### 3. [IHDP Data Example](#IHDPdata)\n", "\n", "####               3.1 [Data Description](#IHDPdatadesc)\n", "\n", "####               3.2 Estimation of Counterfactual Outcomes and PEHE Estimation\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 1. Introduction\n", "\n", "In many decision problems, estimating heterogeneous outcomes of a treatment is not sufficient from a decision-making perspective. A neural network might approximate the outcome and the corresponding counterfactuals well, but its deployment for real-world decision making would be limited by its lack of transparency and explainability. \n", "The Heterogeneous Effect Mixture Model approach was originally proposed in the paper\n", "[\"***Interpretable Subgroup Discovery in Treatment Effect Estimation with Application to Opioid Prescribing Guidelines***\"](https://arxiv.org/abs/1905.03297) to mitigate this challenge and give decision makers more insight. 
\n", "\n", "\n", "### 1.1 Subgroup Discovery and Heterogeneous Treatment Effect Problem\n", "\n", "\n", "The idea behind HEMM is to assume a low-dimensional clustering, i.e. a latent subgroup $\\mathcal{Z}$, for each individual in the dataset. **Figure A** below gives an intuitive visual example of this phenomenon: almost all instances receiving treatment in $\\mathcal{Z}_1$ have a positive outcome, while very few in $\\mathcal{Z}_3$ do. We are interested in recovering such latent subgroups. \n", " \n", "\n", "\n", "\n", "\n", "\n", "

Fig A. Example of the Heterogenous Effect Subgroup Problem

Fig B. The HEMM Model in Plate Notation
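To make the subgroup picture in Fig A concrete, the minimal sketch below uses plain NumPy (it does not call HEMM) to simulate three latent subgroups with assumed, made-up outcome probabilities. The subgroup labels are known in this simulation, so the treatment effect in each subgroup can be estimated directly as a difference in outcome rates. HEMM's job is harder: it has to recover the subgroups from the features alone.

```python
import numpy as np

# Illustrative simulation only: the subgroup-specific probabilities below are
# assumptions chosen to mimic Fig A, not values produced by HEMM.
rng = np.random.default_rng(0)
n = 30000
p_treated = np.array([0.9, 0.5, 0.1])  # assumed p(Y=1 | T=1, Z=k)
p_control = np.array([0.3, 0.3, 0.3])  # assumed p(Y=1 | T=0, Z=k)

z = rng.integers(0, 3, size=n)    # latent subgroup label (known only in simulation)
t = rng.binomial(1, 0.5, size=n)  # randomized treatment assignment
y = rng.binomial(1, np.where(t == 1, p_treated[z], p_control[z]))

# With known labels, the effect in subgroup k is a difference of outcome rates.
effects = np.array([
    y[(z == k) & (t == 1)].mean() - y[(z == k) & (t == 0)].mean()
    for k in range(3)
])
print('estimated effect per subgroup:', np.round(effects, 2))
```

Under these assumed probabilities, the true subgroup effects are 0.6, 0.2 and -0.2. Pooling all three subgroups gives an average effect of about 0.2, which hides the heterogeneity.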

\n", "\n", "\n", "### 1.2 HEMM Description and Plate Notation\n", " \n", "In HEMM, these latent subgroups are modeled with a finite mixture of $K$ Normals and Bernoullis. The effect of receiving the treatment is then mediated by membership in one of the $K$ subgroups. \n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 2. Synthetic Data Example\n", "This section demonstrates **HEMM** on a synthetic dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "### 2.1 Data Description and Generative Process\n", "\n", "We will use a synthetic dataset included with the **`HEMM`** module in **`causallib`** to demonstrate the ability of **HEMM** to discover subgroups as well as to estimate heterogeneous effects. The synthetic dataset is generated according to the following scheme:\n", "\n", "\n", "\n", "\n", "\n", "\n", "

Fig C. The Generative Process for Synthetic Data
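The generative process in Fig C and the numbered steps that follow can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the code behind `gen_montecarlo`:

- The exact outcome functions appear only in the figure, so the outcome probabilities used here are hypothetical: a baseline of 0.4 or 0.6, plus an uplift of 0.1 outside or 0.3 inside the disc S.
- The upper triangular region is assumed to be `x0 + x1 > 1`. This region overlaps mostly with the right half-square, consistent with step 4.

```python
import numpy as np

def sample_synthetic(n, rng):
    # Hypothetical re-implementation for illustration; NOT causallib's gen_montecarlo.
    # Step 1: features drawn uniformly from the unit square.
    x = rng.uniform(0.0, 1.0, size=(n, 2))
    # Step 2: selection bias, T ~ Bernoulli(0.4) if x0 < 0.5, else Bernoulli(0.6).
    t = rng.binomial(1, np.where(x[:, 0] < 0.5, 0.4, 0.6))
    # Step 4 (assumed form): the upper triangle x0 + x1 > 1 has a higher
    # baseline probability of a positive outcome.
    upper = x[:, 0] + x[:, 1] > 1.0
    mu0 = np.where(upper, 0.6, 0.4)        # assumed p(Y(0)=1 | x)
    # Step 5: enhanced-effect disc S with centre c = (0.5, 0.5) and radius r = 0.25.
    in_s = np.linalg.norm(x - 0.5, axis=1) < 0.25
    mu1 = mu0 + np.where(in_s, 0.3, 0.1)   # assumed p(Y(1)=1 | x)
    # Step 3: Bernoulli outcomes; we observe the potential outcome of the treatment received.
    y = rng.binomial(1, np.where(t == 1, mu1, mu0))
    return x, t, y, mu0, mu1, in_s

rng = np.random.default_rng(42)
x, t, y, mu0, mu1, in_s = sample_synthetic(2500, rng)
cate = mu1 - mu0  # true CATE, available only because the data are simulated
print('treated fraction, x0 < 0.5:', t[x[:, 0] < 0.5].mean())
print('treated fraction, x0 > 0.5:', t[x[:, 0] > 0.5].mean())
print('true CATE inside / outside S:', cate[in_s].mean(), cate[~in_s].mean())
```

Here the disc indicator `in_s` plays the role of the true subgroup label, which the notebook later compares against HEMM's group probabilities.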

\n", "\n", "\n", "1. We take $\\mathbf{X} = (X_0, X_1) \\in \\mathbb{R}^2$ and sample it from a uniform distribution over $\\mathcal{X} = [0,1]^2$. \n", "\n", "2. To simulate the selection bias inherent in observational studies, the treatment variable depends on $\\mathbf{X}$ as $T \\sim \\text{Bernoulli}(0.4)$ for $x_0 < 0.5$ and $T \\sim \\text{Bernoulli}(0.6)$ for $x_0 > 0.5$. \n", "\n", "3. The potential outcomes $Y(0)$ and $Y(1)$ are also Bernoulli, with means given by the functions of $\\mathbf{X}$ and $T$ shown in the figure. The figure shows that $p(Y(1)=1 | \\mathbf{X}) > p(Y(0)=1 | \\mathbf{X})$, i.e. treatment increases the probability of a positive outcome. Note that under the conditional exchangeability assumption we have $p(Y(t)=1 | \\mathbf{X}) = p(Y=1 | T=t, \\mathbf{X})$. \n", "\n", "4. We model the effect of the confounders $\\mathbf{X}$ by assigning a higher probability of a positive outcome to the upper triangular region of $\\mathcal{X}$. Together with the distribution of $T$, this implies that individuals who are more likely to have a positive outcome regardless of treatment (upper triangle) are also more likely to receive treatment (right half-square). \n", "\n", "5. Lastly, we model the enhanced treatment effect group as a circular region $\\mathcal{S} = \\{x: \\lVert x-c \\rVert_{2} < r \\}$, where $p(Y(1)=1 | \\mathcal{S}) > p(Y(1)=1 | \\mathcal{X}\\backslash\\mathcal{S})$. We set $c = ( \\frac{1}{2}, \\frac{1}{2})$ and $r = \\frac{1}{4}$. 
A total of $2,500$ samples $(\\mathbf{x}_i, t_i, y_i)$ are generated as described above.\n", "\n", "\n", "\n", "The `gen_montecarlo` function from `causallib.contrib.hemm.gen_synthetic_data` allows us to randomly sample multiple realizations of the synthetic dataset." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from causallib.contrib.hemm.gen_synthetic_data import gen_montecarlo\n", "\n", "syn_data = gen_montecarlo(5000, 2, 100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "### 2.2 Estimation of Counterfactual Outcomes and PEHE Estimation\n", "\n", "We will now fit an **HEMM** model to the first realization of the synthetic dataset. We then evaluate how well **HEMM** estimates the **Individual Treatment Effect** (**ITE**), also referred to here as the **Conditional Average Treatment Effect** (**CATE**). For this we use the **Precision in Estimation of Heterogeneous Effect** (**PEHE**), the mean squared error of the estimated treatment effects with respect to the true **CATE**. More formally:\n", "\n", "$$ \\text{PEHE} = \\frac{1}{n} \\sum^{n}_{i=1} \\left(f_1(\\mathbf{x}_i) - f_0(\\mathbf{x}_i) - \\mathbb{E}[Y(1) - Y(0) | \\mathbf{X} = \\mathbf{x}_i] \\right)^2.\n", "$$\n", "Here $f_1$ and $f_0$ are the functions used to estimate the potential outcomes under treatment and under control for an observation with features $\\mathbf{x}_i$. The second term inside the parentheses is the **true CATE**. Note that the **true CATE** can only be observed in synthetic or semi-synthetic settings, not in real-world observational data. Hence, **PEHE** can only be used as a goodness-of-fit metric for **CATE** estimates on such datasets." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from causallib.contrib.hemm import HEMM" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "#Leave 30% of the Training Data as Validation\n", "vsize = int(0.3*syn_data['TRAIN']['x'].shape[0])" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "#Set the Training Dataset (realization 0)\n", "Xtr = syn_data['TRAIN']['x'] [:-vsize,:, 0]\n", "Ytr = syn_data['TRAIN']['yf'][:-vsize , 0]\n", "Ttr = syn_data['TRAIN']['t'] [:-vsize , 0]\n", "\n", "#Set the Dev/Val Dataset\n", "Xdev = syn_data['TRAIN']['x'] [-vsize:,:, 0]\n", "Ydev = syn_data['TRAIN']['yf'][-vsize: , 0]\n", "Tdev = syn_data['TRAIN']['t'] [-vsize: , 0]\n", "\n", "#Set the Test Dataset\n", "Xte = syn_data['TEST']['x'] [:,:, 0]\n", "Yte = syn_data['TEST']['yf'][: , 0]\n", "Tte = syn_data['TEST']['t'] [: , 0]\n", "\n", "#Feature size\n", "Xdim = Xtr.shape[1]\n", "\n", "#Number of Components to Discover.\n", "K = 3\n", "\n", "#Initialize the model with Population means and Std. Devs.\n", "#Empirically, this results in faster and better convergence.\n", "mu = Xtr.mean(axis=0).reshape(1,-1)\n", "std = Xtr.std(axis=0).reshape(1,-1)\n", "\n", "#Set the Learning Rate for Adam Optimizer and Batch Size\n", "learning_rate = 1e-4\n", "batch_size = 100\n", "\n", "#Indicate the Outcome Distribution (Y): 'bin' for \n", "#Binary Outcomes or 'cont' for Continuous Outcomes.\n", "response = 'bin'\n", "\n", "#Indicate what kind of model to use to adjust for \n", "#confounding. 
You can also pass your own PyTorch model.\n", "outcome_model='linear'\n", "\n", "#Instantiate an HEMM model\n", "model = HEMM(Xdim, K, homo=True, mu=mu, std=std, bc=2, lamb=0.0000,\\\n", " spread=.1,outcome_model=outcome_model,sep_heads=True,epochs=10,\\\n", " learning_rate=learning_rate,weight_decay=0.0001,metric='LL', use_p_correction=False,\\\n", " response=response,imb_fun=None,batch_size=batch_size )\n", "\n", "\n", "#Fit the HEMM model on the Training Data, with Early Stopping\n", "#on the Dev Set.\n", "cd = model.fit(Xtr, Ttr,Ytr, validation_data=(Xdev, Tdev, Ydev))\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "#Extract the True Potential outcomes for the Training \n", "#and the Test Dataset.\n", "mu1tr = syn_data['TRAIN']['mu1'][:-vsize, 0] \n", "mu0tr = syn_data['TRAIN']['mu0'][:-vsize, 0]\n", "\n", "mu1te = syn_data['TEST']['mu1'][:, 0] \n", "mu0te = syn_data['TEST']['mu0'][:, 0]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "#Estimate the Individual Potential Outcomes using HEMM\n", "inSampleCFoutcomes = model.estimate_individual_outcome(Xtr, Ttr)\n", "outSampleCFoutcomes = model.estimate_individual_outcome(Xte, Tte)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In Sample PEHE: 0.011010778786300154\n" ] } ], "source": [ "#Compute the True In-sample Individual Treatment Effect (ITE)\n", "trueCATE = mu1tr - mu0tr\n", "\n", "#Compute the In-Sample ITEs\n", "inSampleITE = inSampleCFoutcomes[1]-inSampleCFoutcomes[0]\n", "\n", "#Compute the In-Sample PEHE\n", "inSamplePEHE = ((inSampleITE - trueCATE)**2).mean()\n", "print(\"In Sample PEHE:\", inSamplePEHE)\n", "\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Out of Sample PEHE: 0.010915930183215048\n" ] } ], "source": [ "#Compute the True Out of Sample Individual Treatment Effect (ITE)\n", "trueCATE = mu1te - mu0te\n", "\n", "#Compute the Out of Sample ITEs\n", "outSampleITE = outSampleCFoutcomes[1]-outSampleCFoutcomes[0]\n", "\n", "#Compute the Out of Sample PEHE\n", "outSamplePEHE = ((outSampleITE - trueCATE)**2).mean()\n", "\n", "print(\"Out of Sample PEHE:\", outSamplePEHE)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The low in-sample and out-of-sample **PEHE** scores suggest that **HEMM** estimates the heterogeneous treatment effects well on this synthetic dataset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 2.3 Subgroup Discovery" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now demonstrate **HEMM**'s ability to extract the subgroup with the **enhanced treatment effect**. We first use the `get_groups_effect` member function of the **`HEMM`** class to extract the absolute values of the treatment effect terms (**$\\gamma_k$** in Figure B)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.69842033 0.71704322 0.50003658]\n", "Subgroup 2 demonstrates a Higher Effect to Treatment\n" ] } ], "source": [ "import numpy as np\n", "\n", "gamma = model.get_groups_effect()\n", "print(gamma)\n", "\n", "#Note: K is reused from here on to hold the index of the subgroup with the largest effect\n", "K = np.argmax(gamma) \n", "\n", "print (\"Subgroup\", K+1, \"demonstrates a Higher Effect to Treatment\")\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have identified the subgroup with the highest treatment effect, we will evaluate how well **HEMM** extracts this subgroup by comparing its group assignment probabilities with the true subgroup assignment." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "ZinSample = model.get_groups_proba(Xtr)\n", "ZoutSample = model.get_groups_proba(Xte)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have the group assignment scores, we will compute the ROC Curve 
and Precision-Recall statistics for recovering the subgroup." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In Sample ROC AUC: 0.9738558368695354\n", "Out Sample ROC AUC: 0.9772955843564394\n" ] } ], "source": [ "from sklearn.metrics import roc_auc_score, roc_curve\n", "\n", "trueInSampleZ = syn_data['TRAIN']['z'][:-vsize, 0] \n", "trueOutSampleZ = syn_data['TEST']['z'][:,0] \n", "\n", "inSample_roc_auc = roc_auc_score( trueInSampleZ , ZinSample[K] )\n", "outSample_roc_auc = roc_auc_score( trueOutSampleZ, ZoutSample[K])\n", "\n", "print (\"In Sample ROC AUC:\", inSample_roc_auc)\n", "print (\"Out Sample ROC AUC:\", outSample_roc_auc)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us now plot the ROC curves." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.metrics import roc_curve\n", "from matplotlib import pyplot as plt\n", "\n", "infpr , intpr , _ = roc_curve( trueInSampleZ , ZinSample[K] )\n", "outfpr, outtpr, _ = roc_curve( trueOutSampleZ , ZoutSample[K] )\n", "\n", "plt.title(\"ROC Curve\")\n", "plt.plot(infpr , intpr, label = \"In-Sample\")\n", "plt.plot(outfpr, outtpr, label= \"Out-Sample\")\n", "plt.plot(np.linspace(0,1,100),np.linspace(0,1,100),c='k', label='Random', ls=':')\n", "plt.xscale('linear')\n", "plt.legend()\n", "plt.xlabel('FPR')\n", "plt.ylabel('TPR')\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In Sample PR AUC: 0.9486824676876224\n", "Out Sample PR AUC: 0.9545712084631546\n" ] } ], "source": [ "from sklearn.metrics import precision_recall_curve, average_precision_score\n", "\n", "\n", "inSample_pr_auc = average_precision_score( trueInSampleZ , ZinSample[K] )\n", "outSample_pr_auc = average_precision_score( trueOutSampleZ, ZoutSample[K])\n", "\n", "print (\"In Sample PR AUC:\", inSample_pr_auc)\n", "print (\"Out Sample PR AUC:\", outSample_pr_auc)\n" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": 
[ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "inpr , inre , _ = precision_recall_curve( trueInSampleZ , ZinSample[K] )\n", "outpr, outre, _ = precision_recall_curve( trueOutSampleZ , ZoutSample[K] )\n", "\n", "from matplotlib import pyplot as plt\n", "plt.title(\"Precision-Recall Curve\")\n", "plt.plot(inre , inpr, label = \"In-Sample\")\n", "plt.plot(outre, outpr, label= \"Out-Sample\")\n", "plt.legend()\n", "plt.xlabel('Recall')\n", "plt.ylabel('Precision')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 2.4 Qualitative Assessment of the Discovered Subgroups\n", "In the previous section, we quantitatively assessed the capability of **HEMM** to discover the subgroup with an **enhanced treatment effect**. Since we had access to the **true subgroup assignment labels**, we computed the **ROC** and **Precision-Recall** curves. In this section, we perform a more qualitative assessment by plotting the discovered subgroup as a heatmap over the feature space and comparing it with the true subgroup." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "#Let's first create a uniform grid over the feature space. \n", "\n", "d = 100\n", "\n", "x1 = np.linspace(0, 1, d)\n", "x2 = np.linspace(0, 1, d)\n", "\n", "X0, X1 = np.meshgrid(x1, x2)\n", "\n", "X0 = X0.reshape(-1,1)\n", "X1 = X1.reshape(-1,1)\n", "\n", "Xgrid = np.hstack([X0, X1, np.ones_like(X1)])\n", "\n", "Zgrid = model.get_groups_proba(Xgrid)[K].values\n", "\n", "Zgrid = Zgrid.reshape(d, d)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Notice how HEMM comes close to recovering the true subgroup\n" ] } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "plt.figure(figsize=(10,8))\n", "\n", "axis = plt.gca()\n", "\n", "plt.imshow(Zgrid**5, cmap='rainbow', interpolation='nearest', origin='lower', extent=[0, 1, 0, 1])\n", "\n", "truesg = plt.Circle((0.5, 0.5), 0.2, color='k', lw=3, ls='--', fc='none', label='True Subgroup')\n", "plt.gca().add_artist(truesg)\n", "plt.legend(handles=[truesg], fontsize=20)\n", "plt.xlabel(r'$\\mathcal{X}_1$', size=24)\n", "plt.ylabel(r'$\\mathcal{X}_2$', size=24)\n", "plt.xticks(size=18)\n", "plt.yticks(size=18)\n", "plt.title(\"The Discovered Subgroup vs. True Subgroup\", size=20)\n", "plt.colorbar(orientation='vertical')\n", "plt.show()\n", "\n", "print (\"Notice how HEMM comes close to recovering the true subgroup\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 2.5 Bootstrapping PEHE Estimates" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far, we have only looked at point estimates of the counterfactual outcomes and the associated **PEHE** score. For a fair comparison, we would also like to estimate the uncertainty around our **PEHE** estimates. We do this by fitting a separate model on each of the 100 sampled realizations of the synthetic dataset, then reporting the mean **PEHE** together with its standard error. \n", "\n", "We will wrap the previous experiment into a single function and use `joblib` to estimate the **PEHE** on each realization in parallel. Unlike the previous section, the function returns the root PEHE, i.e. the square root of the quantity defined in Section 2.2." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "###################################################################################\n", "# NOTE: For more complicated decision problems, linear adjustment for confounding #\n", "# may not be sufficient. 
HEMM allows passing any PyTorch model for determining    #\n", "# the outcome as a function of the features, independent of the treatment         #\n", "# assignment. The parameters of this 'Outcome Model' and the underlying           #\n", "# 'Heterogeneous Effect Mixture Model' are optimized jointly. HEMM implementation #\n", "# in Causallib has helper functions to instantiate Neural Network outcome models. #\n", "###################################################################################\n", "\n", "from causallib.contrib.hemm.outcome_models import genMLPModule, genLinearModule, BalancedNet\n", "\n", "\n", "def experiment(data, i, K=2, response='bin', outcomeModel='linear', lr=1e-3, batch_size=100, vsize=0.3, bc=2, epochs=20):\n", "    #Fit HEMM on realization i of `data`; return the in- and out-of-sample root PEHE.\n", "    vsize = int(vsize*data['TRAIN']['x'].shape[0])\n", "    \n", "    Xtr = data['TRAIN']['x'] [:-vsize,:, i]\n", "    Ytr = data['TRAIN']['yf'][:-vsize , i]\n", "    Ttr = data['TRAIN']['t'] [:-vsize , i]\n", "\n", "    Xdev = data['TRAIN']['x'] [-vsize:,:, i]\n", "    Ydev = data['TRAIN']['yf'][-vsize: , i]\n", "    Tdev = data['TRAIN']['t'] [-vsize: , i]\n", "\n", "    Xte = data['TEST']['x'] [:,:, i]\n", "    Yte = data['TEST']['yf'][: , i]\n", "    Tte = data['TEST']['t'] [: , i]\n", "    \n", "    mu1tr = data['TRAIN']['mu1'][:-vsize, i] \n", "    mu0tr = data['TRAIN']['mu0'][:-vsize, i]\n", "\n", "    mu1te = data['TEST']['mu1'][:, i] \n", "    mu0te = data['TEST']['mu0'][:, i]\n", "\n", "    mu = Xtr.mean(axis=0).reshape(1,-1)\n", "    std = Xtr.std(axis=0).reshape(1,-1)\n", "    \n", "\n", "    #Set the preferred Outcome Adjustment\n", "    if outcomeModel == 'MLP':\n", "        outcomeModel = genMLPModule(Xte.shape[1], Xte.shape[1], 2 )\n", "    elif outcomeModel == 'linear':\n", "        outcomeModel = genLinearModule(Xte.shape[1], 2 )\n", "    elif outcomeModel == 'CF':\n", "        outcomeModel = BalancedNet(Xte.shape[1], Xte.shape[1], 1 )\n", "    \n", "    Xdim = Xte.shape[1]\n", "    \n", "    model = HEMM(Xdim, K, homo=True, mu=mu, std=std, bc=bc, lamb=0.0000,\\\n", "                 spread=.01,outcome_model=outcomeModel,sep_heads=True,epochs=epochs,\\\n", "                 learning_rate=lr,weight_decay=0.0001,metric='LL', use_p_correction=False,\\\n", "                 response=response,imb_fun=None,batch_size=batch_size )\n", "\n", "    cd = model.fit(Xtr, Ttr,Ytr, validation_data=(Xdev, Tdev, Ydev))\n", "    \n", "    inSampleCFoutcomes = model.estimate_individual_outcome(Xtr, Ttr)\n", "    outSampleCFoutcomes = model.estimate_individual_outcome(Xte, Tte)\n", "    \n", "    inSampleITE = inSampleCFoutcomes[1]-inSampleCFoutcomes[0]\n", "    outSampleITE = outSampleCFoutcomes[1]-outSampleCFoutcomes[0]\n", "\n", "    #Compute the In-Sample Root PEHE\n", "    trueCATE = mu1tr - mu0tr\n", "    inSamplePEHE = np.sqrt(((inSampleITE - trueCATE)**2).mean())\n", "\n", "    #Compute the Out of Sample Root PEHE\n", "    trueCATE = mu1te - mu0te\n", "    outSamplePEHE = np.sqrt(((outSampleITE - trueCATE)**2).mean())\n", "\n", "    return [inSamplePEHE, outSamplePEHE]" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "from joblib import Parallel, delayed\n", "\n", "PEHEs = Parallel(n_jobs=10)(delayed(experiment)(data=syn_data, i=i, K=3,lr=1e-4) for i in range(100))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In Sample Root PEHE 0.10267147617234845 , Std. Error 0.0001256753369609759\n", "Out Sample Root PEHE 0.1026503798822282 , Std. Error 0.0001542102418122442\n" ] } ], "source": [ "PEHEmean = np.mean(PEHEs,axis=0)\n", "\n", "PEHEstd = np.std(PEHEs,axis=0)\n", "\n", "print (\"In Sample Root PEHE\", PEHEmean[0], ',', \"Std. Error\", PEHEstd[0]/np.sqrt(len(PEHEs)) )\n", "print (\"Out Sample Root PEHE\", PEHEmean[1],',', \"Std. Error\", PEHEstd[1]/np.sqrt(len(PEHEs)) )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## 3. 
IHDP Example\n", "In this section, we experiment with the IHDP dataset. Unlike the previous dataset, the IHDP dataset comes from a real-world study with simulated outcomes, and is hence considered 'semi-synthetic'.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 3.1 Data Description" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The **Infant Health and Development Program** (**IHDP**) dataset is widely used in the causal inference literature on heterogeneous treatment effects. The original data includes 25 real covariates and comes from a randomized experiment evaluating the effect of the **IHDP** intervention on the IQ scores of three-year-old children. A selection bias was introduced by removing a subset of the treated population, resulting in **608 control** and **139 treated** children (**747 total**). The outcomes were simulated using the standard non-linear **‘Response Surface B’** as described in \n", "***Hill, J. L. (2011). Bayesian nonparametric modeling for causal inference. 
Journal of Computational and Graphical Statistics, 20(1), 217-240.***\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**`causallib`**'s **`HEMM`** module contains helper functions to download the **IHDP** dataset, hosted on _Fredrik Johansson's_ [personal website](https://www.fredjo.com)." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "IHDP Data exists\n" ] } ], "source": [ "from causallib.contrib.hemm.load_ihdp_data import loadIHDPData\n", "\n", "ihdp_data = loadIHDPData()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "### 3.2 Estimation of Counterfactual Outcomes and PEHE Estimation" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from joblib import Parallel, delayed\n", "\n", "#Fit HEMM on the first 10 IHDP realizations (continuous outcomes)\n", "PEHEs = Parallel(n_jobs=10)(delayed(experiment)(data=ihdp_data, i=i,outcomeModel='CF', K=3,lr=1e-3,vsize=0.25, batch_size=10, bc=6, response='cont', epochs=500) for i in range(10))" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "In Sample Root PEHE 4.779168419249578 , Std. Error 0.8265074053983431\n", "Out Sample Root PEHE 4.300757764814082 , Std. Error 0.7104830767106435\n" ] } ], "source": [ "PEHEmean = np.mean(PEHEs,axis=0)\n", "\n", "PEHEstd = np.std(PEHEs,axis=0)\n", "\n", "#The standard error uses the number of realizations actually run (10 here)\n", "print (\"In Sample Root PEHE\", PEHEmean[0], ',', \"Std. Error\", PEHEstd[0]/np.sqrt(len(PEHEs)) )\n", "print (\"Out Sample Root PEHE\", PEHEmean[1],',', \"Std. Error\", PEHEstd[1]/np.sqrt(len(PEHEs)) )" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.7.5 64-bit", "language": "python", "name": "python37564bitd905ce786ede48728f7c06fb20d97845" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }