{ "cells": [ { "cell_type": "markdown", "id": "0c22256a", "metadata": {}, "source": [ "# TMLE - Targeted Maximum Likelihood Estimation\n", "\n", "Targeted learning is a method developed by Mark van der Laan [[1](https://www.degruyter.com/document/doi/10.2202/1557-4679.1043/html)] establishing a theoretically-guaranteed way of applying complex machine learning estimators to causal inference estimation. \n", "\n", "### Motivation\n", "Let's denote $X$ as our confounding features, $Y$ our target variable and $A$ a binary intervention variable whose effect we seek. \n", "Recall that a regular outcome model (i.e., a Standardization model like S-Learner) simply works by estimating $E[Y|X,A]$. \n", "However, the way it adjusts for $X$ and $A$ depends on the core statistical estimator used. \n", "For example, a strictly linear (i.e. no interactions or polynomials) regression model will adjust for $X$ linearly, but that might not describe the response surface properly, especially in high-dimensional complex data.\n", "To account for more complex data, we can use a more complex core estimator. \n", "However, applying expressive estimators might lead to some bias in estimating the causal effect of the treatment.\n", "In most real-world scenarios, the treatment effect is usually pretty small, and not unrelatedly, it might have very little predictive power over the outcome. \n", "For example, imagine feeding a tree-based estimator a matrix with features stacked with a treatment assignment column (like an S-learner). \n", "It is not unreasonable that the tree might simply ignore the treatment variable altogether. \n", "And so, it will yield a strictly zero causal effect estimate, since there's no difference in outcome when we plug in $A=1$ and $A=0$. \n", "This happens because we optimize the prediction $E[Y|X,A]$, which is not the same as optimizing for the causal parameter of interest $E[Y|X,A=1]-E[Y|X,A=0]$. \n", "\n", "### Intuition\n", "A naive statistical estimator will maximize the global likelihood - it will try to estimate correctly _all_ the coefficients of all covariates, only one of which is the treatment assignment.\n", "However, we care for one single parameter - the treatment effect - more than we care for other parameters.\n", "Therefore, we would like to focus our estimator's attention on that parameter of interest, even at the price of neglecting some other parameters. \n", "Using Dr. Susan Gruber's example, think of it like a picture - where the person smiling is of greater importance, so in order to focus on them we allow the background to become a bit more blurry in exchange. \n", "\n", "While the math behind it is complex, the basic principle is pretty simple: \n", "In order to focus our estimator on the treatment effect, we will use information from the treatment mechanism to update and re-target the initial outcome model prediction.\n", "This will allow us to use highly data-adaptive prediction models, but still estimate the treatment effect properly.\n", "\n", "\n", "### Steps\n", "Fitting a TMLE can be summarized into four simple steps:\n", "1. Fit an outcome model $Q_0(A,X)$, estimating $E[Y|X,A]$ by predicting the outcome $Y$ using the covariates $X$ and the treatment $A$. \n", " $Q_0$ can be a highly expressive method, and a common use is a \"Super Learner\", which is basically a stacking meta-learner using a broad library (pool) of base-estimators. \n", " In `causallib`, this will be done by specifying a `Standardization` model with any kind of core estimator.\n", "1. Fit a propensity model $g(A,X)$, estimating $\\Pr[A|X]$ by predicting the treatment assignment $A$ using the covariates $X$. \n", " In `causallib`, this will be done by specifying an `IPW` model with any kind of core estimator. \n", " Note that `causallib` also allows this set of `X` to be different than the `X` used for the outcome model in step 1.\n", "1. 
Generate a \"clever covariate\"* $H(A,X)$ using the propensity scores: $\\frac{2A-1}{g(A,X)}$.\n", " Namely: take the inverse propensity scores of the treated units, and the _minus_ inverse propensities for the controls.\n", "1. Update the initial outcome prediction using treatment information from the \"clever covariate\": estimate an $\\epsilon$ parameter such that: \n", " $Q_*(A,X) = expit(logit(Q_0(A,X))+ \\epsilon H(A,X))$ \n", " Namely, we update the initial $Q_0$ with some contribution of $H(A,X)$ estimated in logit space. \n", " In `causallib`, this is estimated by applying a univariable logistic regression, regressing $Y$ on $H(A,X)$ with $logit(Q_0(A,X))$ as offset (i.e. forcing its coefficient to 1). \n", " \n", "The intuition behind step (4) is that we are basically regressing the residuals of the outcome prediction on the \"clever covariate\" (with its treatment mechanism information). \n", "If the initial prediction is perfect - then the residuals are just random noise unrelated to $H(A,X)$, and $\\epsilon$ is therefore $\\approx 0$, contributing nothing to the update step.\n", "However, in case there _is_ residual bias in the initial estimator, $\\epsilon$ will control the magnitude of correction needed - small residual bias will lead to a small update and vice versa. \n", "This is because \"surprising\" units, those with small $g(A,X)$, have large $\\frac{1}{g(A,X)}$, so small changes in $\\epsilon$ will have a bigger impact on the fitting. \n", "This is also why we need to be extra careful to avoid overfitting the initial estimator $Q_0$, because overfitting will falsely minimize the signal in the residuals required for the updating step. \n", "\n", "Note that there are several flavors of the \"clever covariate\", of which `causallib` implements four, and they are described further down. \n", "\n", "#### Counterfactual prediction\n", "For counterfactual prediction we assign a specific treatment value $a\\in A$ and propagate it through the model's components: \n", "$Q_*(a,X) = expit(logit(Q_0(a,X))+ \\epsilon H(a,X))$ \n", "And then we can calculate any contrast of two intervention values to obtain an effect, like risk difference ($Q_*(1,X)-Q_*(0,X)$) or risk ratio ($\\frac{Q_*(1,X)}{Q_*(0,X)}$).\n", "\n", "\n", "### Doubly robust\n", "TMLE combines an outcome model with a treatment model in a way that makes it doubly robust: we get two chances to get things right. \n", "As we have seen above, if we correctly specify the outcome model, there's no signal left for correction by the treatment model.\n", "Conversely, in cases where the initial model is strongly misspecified (think a simple `Y~A` regression), $\\epsilon$ will be large and compensate for it like an IPW model.\n", "Therefore, the targeting step is a second chance to get things right.\n", "\n", "### Conclusion\n", "TMLE as a causal inference framework allows us to apply flexible machine learning estimators on high-dimensional complex data and still obtain valid causal effect estimations.\n" ] }, { "cell_type": "markdown", "id": "27fa6e93-0df3-4a30-a457-641d8c4007bb", "metadata": {}, "source": [ "## TMLE in `causallib`\n", "\n", "Let's see an example of how TMLE works using `causallib`" ] }, { "cell_type": "code", "execution_count": 1, "id": "a11a9eed", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "from causallib.estimation import TMLE\n", "from causallib.estimation import Standardization, IPW\n", "from mlxtend.classifier import StackingCVClassifier\n", "from mlxtend.regressor import StackingCVRegressor\n", "\n", "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import PolynomialFeatures\n", "from sklearn.model_selection import RandomizedSearchCV, GridSearchCV\n", "from sklearn.ensemble import 
GradientBoostingClassifier, HistGradientBoostingClassifier, RandomForestClassifier\n", "# from sklearn.svm import SVC\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import GradientBoostingRegressor, HistGradientBoostingRegressor, RandomForestRegressor\n", "from sklearn.svm import SVR\n", "from sklearn.linear_model import Lasso, Ridge, LinearRegression, LassoCV, RidgeCV\n", "\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns" ] }, { "cell_type": "markdown", "id": "13cf5805-0488-4ecd-b89f-be99373d66ce", "metadata": {}, "source": [ "### Data\n", "We synthesize data, so we know the true causal effect, based on [_Schuler and Rose 2017_](https://academic.oup.com/aje/article/185/1/65/2662306). \n", "This is a relatively simple mechanism where exposure is dependent on a linear combination of 3 Bernoulli variables and the outcome also depends on an effect modification of the third variable." ] }, { "cell_type": "code", "execution_count": 2, "id": "4ff56b4f-7d32-4a0b-b3c5-0dc9eb9bf019", "metadata": {}, "outputs": [], "source": [ "def generate_single_dataset(n, seed=None):\n", "    if seed is not None:\n", "        np.random.seed(seed)\n", "    \n", "    X = np.random.binomial(n=1, p=(0.55, 0.30, 0.25), size=(n, 3))\n", "    \n", "    a_logit = -0.5 + X @ np.array([0.75, 1, 1.5])\n", "    propensity = 1 / (1 + np.exp(-a_logit))\n", "    a = np.random.binomial(1, propensity)\n", "    \n", "    y_func = lambda z: 24 - 3*z + X@np.array([3, -4, -6]) - 1.5*z*X[:, -1]\n", "    noise = np.random.normal(0, 4.5, size=n)\n", "    y = y_func(a) + noise\n", "    y0 = y_func(np.zeros_like(a)) + noise\n", "    y1 = y_func(np.ones_like(a)) + noise\n", "    sate = np.mean(y1 - y0)\n", "    # sate = -3.38\n", "    # cate = y1 - y0\n", "    \n", "    X = pd.DataFrame(X)\n", "    a = pd.Series(a)\n", "    y = pd.Series(y)\n", "    return X, a, y, sate" ] }, { "cell_type": "code", "execution_count": 3, "id": "55b06cc7-0eee-4f22-86ed-2ccd268e0142", "metadata": {}, "outputs": [], "source": [ "def 
run_single_simulation(tmle, n_samples, seed=None):\n", "    # Pass the seed through so each simulation is reproducible\n", "    X, a, y, sate = generate_single_dataset(n_samples, seed=seed)\n", "    tmle.fit(X, a, y)\n", "    \n", "    tmle_po = tmle.estimate_population_outcome(X, a)\n", "    tmle_ate = tmle.estimate_effect(tmle_po[1], tmle_po[0])['diff']\n", "    \n", "    outcome_model_po = tmle.outcome_model.estimate_population_outcome(X, a)\n", "    outcome_model_ate = tmle.outcome_model.estimate_effect(outcome_model_po[1], outcome_model_po[0])['diff']\n", "    \n", "    return tmle_ate, outcome_model_ate, sate\n", "\n", "\n", "def run_multiple_simulations(tmle, n_simulations, n_samples):\n", "    true_ates = []\n", "    tmle_ates = []\n", "    om_ates = []\n", "    for i in range(n_simulations):\n", "        tmle_ate, outcome_model_ate, sate = run_single_simulation(tmle, n_samples, i)\n", "\n", "        tmle_ates.append(tmle_ate)\n", "        om_ates.append(outcome_model_ate)\n", "        true_ates.append(sate)\n", "    \n", "    predictions = pd.DataFrame(\n", "        {\"tmle\": tmle_ates, \"initial_model\": om_ates, \"true\": true_ates}\n", "    ).rename_axis(\"simulation\").reset_index()\n", "    # true_ate = np.mean(true_ates)\n", "    true_ate = -3.38\n", "    return predictions, true_ate\n", "\n", "def plot_multiple_simulations(results, true_ate):\n", "    results = results.drop(columns=[\"true\"])\n", "    results = results.melt(id_vars=\"simulation\", var_name=\"method\", value_name=\"ate\")\n", "    \n", "    # Plot individual experiments:\n", "    fig, axes = plt.subplots(1, 2, sharey=True, \n", "                             gridspec_kw={'width_ratios': [3, 1]},\n", "                             figsize=(8, 3))\n", "    sns.scatterplot(\n", "        x=\"simulation\", y=\"ate\", hue=\"method\", style=\"method\", \n", "        data=results, \n", "        ax=axes[0]\n", "    )\n", "    axes[0].axhline(y=true_ate, linestyle='--', color=\"grey\")\n", "    # axes[0].text(results[\"simulation\"].max(), true_ate - 0.1, \"True ATE\", color=\"slategrey\")\n", "    # (results.set_index([\"simulation\", \"method\"]) - true_ate).groupby(\"method\").mean()\n", "    \n", "    # Plot distribution summary:\n", "    axes[1].axhline(y=true_ate, linestyle='--', color=\"grey\")\n", "    sns.kdeplot(\n", "        y=\"ate\", hue=\"method\", \n", "        data=results,\n", "        legend=False,\n", "        ax=axes[1]\n", "    )\n", "    axes[1].text(axes[1].get_xlim()[1], true_ate - 0.2, \"True ATE\", color=\"slategrey\")\n", "    \n", "    # mean_bias = (results.set_index([\"simulation\", \"method\"]) - true_ate).groupby(\"method\").mean()\n", "    # pd.plotting.table(axes[2], mean_bias)\n", "    \n", "    axes[0].legend(loc='center left', bbox_to_anchor=(1.35, 0.1))\n", "\n", "    plt.subplots_adjust(wspace=0.01)\n", "    return axes\n", "    " ] }, { "cell_type": "markdown", "id": "549d339e-ae60-4e62-b679-f363c08bb9dc", "metadata": {}, "source": [ "### Double Robustness\n", "\n", "We will see how misspecifying either the outcome model or the treatment model is counteracted by TMLE, which still results in an overall good estimation." ] }, { "cell_type": "markdown", "id": "356ad954-3806-4cac-b26e-277a5eeb2669", "metadata": {}, "source": [ "We start by specifying the outcome model to be an L1-regularized linear regression (LASSO). The model is also misspecified as it doesn't take into account the interaction between the treatment and the third indicator." ] }, { "cell_type": "code", "execution_count": 4, "id": "c7127543-8d48-4a48-961a-99842e866f56", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "outcome_model = Standardization(Lasso(random_state=0))\n", "weight_model = IPW(LogisticRegression(penalty=\"none\", random_state=0))\n", "tmle = TMLE(\n", "    outcome_model=outcome_model,\n", "    weight_model=weight_model,\n", "    reduced=True\n", ")\n", "\n", "n_samples = 500\n", "n_simulations = 20\n", "\n", "results, true_ate = run_multiple_simulations(tmle, n_simulations, n_samples)\n", "plot_multiple_simulations(results, true_ate);" ] }, { "cell_type": "markdown", "id": "b8aff33c-ce14-4893-94bb-f3ad74ed3c4f", "metadata": {}, "source": [ "We see that the outcome model predictions are way off towards the null, probably even ignoring the treatment column occasionally (where the effect is near 0).\n", "\n", "We can improve that by first allowing a more flexible model (by using interactions), and also using cross-validation to find the right amount of regularization over all the new polynomial parameters." ] }, { "cell_type": "code", "execution_count": 5, "id": "5886ab69-320e-4c2c-b47a-684a2982a031", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "outcome_model = Standardization(make_pipeline(PolynomialFeatures(2), LassoCV(random_state=0)))\n", "tmle = TMLE(\n", "    outcome_model=outcome_model,\n", "    weight_model=weight_model,\n", "    reduced=True\n", ")\n", "\n", "n_samples = 500\n", "n_simulations = 20\n", "\n", "results, true_ate = run_multiple_simulations(tmle, n_simulations, n_samples)\n", "plot_multiple_simulations(results, true_ate);" ] }, { "cell_type": "markdown", "id": "234fb34c-6685-4879-b53f-e8bdfa746bd7", "metadata": {}, "source": [ "This is much better performance by the initial outcome model, though we can see TMLE is still able to slightly correct it.\n", "\n", "Finally, now that we have a well-specified outcome model, we can severely restrict our weight model so it is misspecified and see what happens:" ] }, { "cell_type": "code", "execution_count": 6, "id": "676295c6-9526-4cd4-b95f-0cac4f529788", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "weight_model = IPW(make_pipeline(PolynomialFeatures(2), LogisticRegression(penalty=\"l1\", C=0.00001, solver=\"saga\")))\n", "tmle = TMLE(\n", "    outcome_model=outcome_model,\n", "    weight_model=weight_model,\n", "    reduced=True\n", ")\n", "\n", "n_samples = 500\n", "n_simulations = 20\n", "\n", "results, true_ate = run_multiple_simulations(tmle, n_simulations, n_samples)\n", "plot_multiple_simulations(results, true_ate);" ] }, { "cell_type": "markdown", "id": "73f79aea-015b-4cd2-b146-5a441b0110e6", "metadata": {}, "source": [ "We see that the TMLE knows to ignore the misspecified propensity model. It does little to no correction, essentially reproducing the effect estimated by the outcome model." ] }, { "cell_type": "markdown", "id": "cf02470a-0885-4661-9fed-d32e5d581c35", "metadata": {}, "source": [ "Finally, TMLE is usually coupled with \"super learning\" (as both came from Mark van der Laan and his students), which is basically a stacking meta-learner over a rich library of base estimators. \n", "We can see how TMLE behaves when both the outcome model and the treatment model are highly data-adaptive. \n", "Note that we use stacking _CV_ estimators, so we ensure that both the base estimators and the meta estimator use cross validation. \n", "This nested cross validation reduces the risk of information leakage between the base and meta estimators, and so it further guards against overfitting. \n", "If you recall, TMLE regresses the residuals of the initial estimator, and so it is very important that the initial outcome model does not overfit and falsely hide residual bias information." ] }, { "cell_type": "code", "execution_count": 7, "id": "7fb6fca9-6747-4fb1-a435-f857ef1e5b44", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "outcome_model = StackingCVRegressor(\n", " [\n", " GradientBoostingRegressor(n_estimators=10),\n", " GradientBoostingRegressor(n_estimators=30),\n", " GradientBoostingRegressor(n_estimators=50),\n", " HistGradientBoostingRegressor(max_iter=10),\n", " HistGradientBoostingRegressor(max_iter=30),\n", " HistGradientBoostingRegressor(max_iter=50),\n", " RandomForestRegressor(n_estimators=25),\n", " RandomForestRegressor(n_estimators=50),\n", " RandomForestRegressor(n_estimators=100),\n", " SVR(kernel=\"rbf\", C=0.01),\n", " SVR(kernel=\"rbf\", C=0.1),\n", " SVR(kernel=\"rbf\", C=1),\n", " SVR(kernel=\"rbf\", C=10),\n", " SVR(kernel=\"poly\", C=0.01, degree=2),\n", " SVR(kernel=\"poly\", C=0.1, degree=2),\n", " SVR(kernel=\"poly\", C=1, degree=2),\n", " SVR(kernel=\"poly\", C=10, degree=2),\n", " SVR(kernel=\"poly\", C=0.01, degree=3),\n", " SVR(kernel=\"poly\", C=0.1, degree=3),\n", " SVR(kernel=\"poly\", C=1, degree=3),\n", " SVR(kernel=\"poly\", C=10, degree=3),\n", " make_pipeline(PolynomialFeatures(degree=2), Lasso(alpha=0.01)),\n", " make_pipeline(PolynomialFeatures(degree=2), Lasso(alpha=0.1)),\n", " make_pipeline(PolynomialFeatures(degree=2), Lasso(alpha=1)),\n", " make_pipeline(PolynomialFeatures(degree=2), Lasso(alpha=10)),\n", " make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=0.01)),\n", " make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=0.1)),\n", " make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=1)),\n", " make_pipeline(PolynomialFeatures(degree=2), Ridge(alpha=10)),\n", " ],\n", " LinearRegression(),\n", " cv=10,\n", " random_state=0,\n", " # verbose=1,\n", ")\n", "\n", "weight_model = StackingCVClassifier(\n", " [\n", " GradientBoostingClassifier(n_estimators=10),\n", " GradientBoostingClassifier(n_estimators=30),\n", " GradientBoostingClassifier(n_estimators=50),\n", " make_pipeline(PolynomialFeatures(degree=2), 
LogisticRegression(penalty=\"l1\", C=0.01, solver=\"saga\", max_iter=5000)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l1\", C=0.1, solver=\"saga\", max_iter=5000)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l1\", C=1, solver=\"saga\", max_iter=5000)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l1\", C=10, solver=\"saga\", max_iter=5000)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l2\", C=0.01)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l2\", C=0.1)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l2\", C=1)),\n", " make_pipeline(PolynomialFeatures(degree=2), LogisticRegression(penalty=\"l2\", C=10)),\n", " ],\n", " LogisticRegression(penalty=\"none\"),\n", " cv=10,\n", " random_state=0,\n", " # verbose=1,\n", ")\n", "\n", "outcome_model = Standardization(outcome_model)\n", "weight_model = IPW(weight_model)\n", "\n", "tmle = TMLE(\n", " outcome_model=outcome_model,\n", " weight_model=weight_model,\n", " reduced=True\n", ")\n", "\n", "n_samples = 500\n", "n_simulations = 20\n", "\n", "results, true_ate = run_multiple_simulations(tmle, n_simulations, n_samples)\n", "plot_multiple_simulations(results, true_ate);" ] }, { "cell_type": "markdown", "id": "84dbc187-c613-4a77-8bb1-5578cdcc0b94", "metadata": {}, "source": [ " " ] }, { "cell_type": "markdown", "id": "a5e011e9-9278-455d-a4e4-d501985db715", "metadata": {}, "source": [ "### TMLE Flavors\n", "If you go all the way up to step number 3 where we construct the \"clever covariate\", you'll see an asterisk.\n", "This is because there are multiple flavors for that clever covariate. \n", "All flavors use information from the treatment mechanism as inverse propensity weights. \n", "However, exactly _how_ they use it can slightly differ.\n", "\n", "#### Reduced vs. 
full form\n", "First, the \"clever covariate\" can either be a single covariate or a matrix. \n", "TMLE requires $\\Pr[A=a|X=x_i] \\forall a \\in A$ - we need to estimate the probability for every individual and for every possible treatment level.\n", "For binary treatment, it is more straightforward, since we have one value (and we can easily calculate its complement). \n", "However, for 3 or more treatment levels, we need to retain more information.\n", "\n", "Therefore, for binary cases, we can use a `reduced` form, which is a single-vector form.\n", "Since we only have two groups, we can code the treated units as $\\frac{1}{\\Pr[A=1|X]}$ and the controls as $-\\frac{1}{\\Pr[A=0|X]}$. That is, we simply negate the inverse propensity scores of the controls. \n", "The full form, which can capture both binary and multi-valued treatments, codes the clever covariate as a matrix. \n", "For each unit we will have a vector with $\\frac{1}{\\Pr[A=a_i|X]}$ in the entry of the observed treatment $a_i$ and 0 for all $a \\neq a_i$. \n", "We then regress the residual outcomes using multiple regression on a matrix, and estimate multiple $\\epsilon_a$ values rather than a single $\\epsilon$ value.\n", "\n", "\n", "#### Feature vs. weighting\n", "The second choice is how we use the inverse propensity information.\n", "Up to this point, we described it as a feature (predictor / covariate) we regress on.\n", "However, we can also use the inverse propensity information as _weights_ to the targeting-step logistic model. \n", "\n", "This means we can break down the \"clever covariate\" to two components: \n", "1. The actual inverse-propensity information from the treatment mechanism, and\n", "1. 
The coding of treatment groups\n", "\n", "Returning to the reduced/full form - using a vector or a matrix is essentially an indicator coding:\n", " * In the vector (reduced) version, the coding is 1 for treated and -1 for controls\n", " * In the matrix (full) version, the coding is one-hot (full dummy) encoding.\n", "\n", "`causallib` allows users to control these 4 flavors using two binary parameters:\n", " * `reduced=True` will create a single vector covariate, while `reduced=False` will create a matrix, and\n", " * `importance_sampling=True` will use the inverse propensity as regression weights, while `importance_sampling=False` will use it as a covariate." ] }, { "cell_type": "code", "execution_count": 8, "id": "a933cc1a-2f95-41a5-9399-b87ba3252779", "metadata": {}, "outputs": [], "source": [ "from itertools import product\n", "from collections import defaultdict\n", "\n", "def compare_TMLE_flavors(outcome_model, weight_model, n_simulations, n_samples):\n", "    true_ates = []\n", "    pred_ates = defaultdict(list)\n", "    for i in range(n_simulations):\n", "        X, a, y, sate = generate_single_dataset(n_samples, seed=i)\n", "        \n", "        for reduced, imp_samp in product([True, False], [True, False]):\n", "            tmle = TMLE(\n", "                outcome_model=outcome_model,\n", "                weight_model=weight_model,\n", "                reduced=reduced,\n", "                importance_sampling=imp_samp,\n", "            )\n", "            tmle.fit(X, a, y)\n", "            potential_outcomes = tmle.estimate_population_outcome(X, a)\n", "            effect = potential_outcomes[1] - potential_outcomes[0]\n", "            \n", "            name = str(type(tmle.clever_covariate_)).split(\".\")[-1][:-2]\n", "            # Drop the common prefix (str.lstrip strips a character set, not a prefix)\n", "            name = name.replace(\"CleverCovariate\", \"\")  # shorten name\n", "            pred_ates[name].append(effect)\n", "\n", "        true_ates.append(sate)\n", "    \n", "    # # true_ate = np.mean(true_ates)\n", "    # true_ate = -3.38\n", "    pred_ates = pd.DataFrame(pred_ates).rename_axis(\"simulation\").reset_index()\n", "    return pred_ates, true_ates\n", "\n", "def plot_tmle_flavor_comparison(pred_ates, true_ates):\n", "    pred_ates = pred_ates.melt(id_vars=\"simulation\", value_name=\"ate\", var_name=\"method\")\n", "    fig, ax = plt.subplots(figsize=(8,5))\n", "    sns.violinplot(y=\"method\", x=\"ate\", hue=\"method\", \n", "                   data=pred_ates, orient=\"h\", dodge=False, ax=ax)\n", "    sns.stripplot(y=\"method\", x=\"ate\", \n", "                  color=\"lightgrey\", alpha=0.5, \n", "                  data=pred_ates, orient=\"h\", ax=ax)\n", "    ax.legend([]) # ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) \n", "    ax.axvline(x=np.mean(true_ates), linestyle=\"--\", color=\"grey\")\n", "    ax.text(np.mean(true_ates), ax.get_ylim()[0] + 0.2, \"True ATE\", color=\"grey\");\n", "    # ax.text(true_ate + 0.02, (ax.get_ylim()[1] + ax.get_ylim()[0]) / 2 + 0.1, \"True ATE\", color=\"grey\");" ] }, { "cell_type": "code", "execution_count": 9, "id": "cf56b02b-7b29-44a3-a936-a27831074390", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "outcome_model = Standardization(make_pipeline(PolynomialFeatures(2), LassoCV(random_state=0)))\n", "weight_model = IPW(LogisticRegression(penalty=\"none\", random_state=0))\n", "\n", "pred_ates, true_ates = compare_TMLE_flavors(outcome_model, weight_model, 50, 5000)\n", "plot_tmle_flavor_comparison(pred_ates, true_ates)" ] }, { "cell_type": "markdown", "id": "4949a3b8-222a-4781-8ea4-1bdb07b9aab4", "metadata": {}, "source": [ "We see that all methods are comparable for the relatively simple case at hand." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" } }, "nbformat": 4, "nbformat_minor": 5 }