{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Chapter 6. The Haunted DAG & The Causal Terror" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -q numpyro arviz daft networkx" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import collections\n", "import itertools\n", "import os\n", "import warnings\n", "\n", "import arviz as az\n", "import daft\n", "import matplotlib.pyplot as plt\n", "import networkx as nx\n", "import pandas as pd\n", "\n", "import jax.numpy as jnp\n", "from jax import lax, random\n", "\n", "import numpyro\n", "import numpyro.distributions as dist\n", "import numpyro.optim as optim\n", "from numpyro.diagnostics import print_summary\n", "from numpyro.infer import SVI, Trace_ELBO\n", "from numpyro.infer.autoguide import AutoLaplaceApproximation\n", "\n", "if \"SVG\" in os.environ:\n", " %config InlineBackend.figure_formats = [\"svg\"]\n", "warnings.formatwarning = lambda message, category, *args, **kwargs: \"{}: {}\\n\".format(\n", " category.__name__, message\n", ")\n", "az.style.use(\"arviz-darkgrid\")\n", "numpyro.set_platform(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.1" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray(-0.6453402, dtype=float32)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with numpyro.handlers.seed(rng_seed=1914):\n", " N = 200 # num grant proposals\n", " p = 0.1 # proportion to select\n", " # uncorrelated newsworthiness and trustworthiness\n", " nw = numpyro.sample(\"nw\", dist.Normal().expand([N]))\n", " tw = numpyro.sample(\"tw\", dist.Normal().expand([N]))\n", " # select top 10% of combined scores\n", " s = nw + tw # total score\n", " q = jnp.quantile(s, 1 - p) # top 10% threshold\n", " selected = jnp.where(s >= q, True, False)\n", 
"jnp.corrcoef(jnp.stack([tw[selected], nw[selected]], 0))[0, 1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.2" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "N = 100 # number of individuals\n", "with numpyro.handlers.seed(rng_seed=909):\n", " # sim total height of each\n", " height = numpyro.sample(\"height\", dist.Normal(10, 2).expand([N]))\n", " # leg as proportion of height\n", " leg_prop = numpyro.sample(\"prop\", dist.Uniform(0.4, 0.5).expand([N]))\n", " # sim left leg as proportion + error\n", " leg_left = leg_prop * height + numpyro.sample(\n", " \"left_error\", dist.Normal(0, 0.02).expand([N])\n", " )\n", " # sim right leg as proportion + error\n", " leg_right = leg_prop * height + numpyro.sample(\n", " \"right_error\", dist.Normal(0, 0.02).expand([N])\n", " )\n", " # combine into data frame\n", " d = pd.DataFrame({\"height\": height, \"leg_left\": leg_left, \"leg_right\": leg_right})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.3" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 2000/2000 [00:01<00:00, 1291.68it/s, init loss: 62894.3672, avg. 
loss [1901-2000]: 112.9629]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 0.81 0.34 0.80 0.27 1.31 1049.96 1.00\n", " bl 2.61 2.28 2.63 -1.06 6.26 813.11 1.00\n", " br -0.59 2.28 -0.60 -4.41 2.96 805.68 1.00\n", " sigma 0.67 0.05 0.67 0.60 0.74 968.52 1.00\n", "\n" ] } ], "source": [
"def model(leg_left, leg_right, height):\n",
"    \"\"\"Linear model: height ~ Normal(a + bl * leg_left + br * leg_right, sigma).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.Normal(10, 100))\n",
"    bl = numpyro.sample(\"bl\", dist.Normal(2, 10))\n",
"    br = numpyro.sample(\"br\", dist.Normal(2, 10))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    mu = a + bl * leg_left + br * leg_right\n",
"    numpyro.sample(\"height\", dist.Normal(mu, sigma), obs=height)\n",
"\n",
"\n",
"# fit a Laplace approximation guide for the model with SVI\n",
"m6_1 = AutoLaplaceApproximation(model)\n",
"svi = SVI(\n",
"    model,\n",
"    m6_1,\n",
"    optim.Adam(0.1),\n",
"    Trace_ELBO(),\n",
"    leg_left=d.leg_left.values,\n",
"    leg_right=d.leg_right.values,\n",
"    height=d.height.values,\n",
")\n",
"svi_result = svi.run(random.PRNGKey(0), 2000)\n",
"p6_1 = svi_result.params\n",
"# 1000 draws from the fitted guide, summarised with 89% intervals\n",
"post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.4" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "az.plot_forest(post, hdi_prob=0.89)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.5" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "post = m6_1.sample_posterior(random.PRNGKey(1), p6_1, sample_shape=(1000,))\n", "az.plot_pair(post, var_names=[\"br\", \"bl\"], scatter_kwargs={\"alpha\": 0.1})\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.6" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sum_blbr = post[\"bl\"] + post[\"br\"]\n", "az.plot_kde(sum_blbr, label=\"sum of bl and br\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.7" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:01<00:00, 936.38it/s, init loss: 2486.2402, avg. loss [951-1000]: 108.4584]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 0.83 0.35 0.84 0.25 1.35 931.50 1.00\n", " bl 2.02 0.08 2.02 1.91 2.15 940.42 1.00\n", " sigma 0.67 0.05 0.67 0.60 0.75 949.09 1.00\n", "\n" ] } ], "source": [ "def model(leg_left, height):\n", " a = numpyro.sample(\"a\", dist.Normal(10, 100))\n", " bl = numpyro.sample(\"bl\", dist.Normal(2, 10))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a + bl * leg_left\n", " numpyro.sample(\"height\", dist.Normal(mu, sigma), obs=height)\n", "\n", "\n", "m6_2 = AutoLaplaceApproximation(model)\n", "svi = SVI(\n", " model,\n", " m6_2,\n", " optim.Adam(1),\n", " Trace_ELBO(),\n", " leg_left=d.leg_left.values,\n", " height=d.height.values,\n", ")\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_2 = svi_result.params\n", "post = m6_2.sample_posterior(random.PRNGKey(1), p6_2, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.8" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "milk = pd.read_csv(\"../data/milk.csv\", sep=\";\")\n", "d = milk\n", "d[\"K\"] = d[\"kcal.per.g\"].pipe(lambda x: (x - x.mean()) / x.std())\n", "d[\"F\"] = d[\"perc.fat\"].pipe(lambda x: (x - x.mean()) / x.std())\n", "d[\"L\"] = d[\"perc.lactose\"].pipe(lambda x: (x - x.mean()) / x.std())" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "### Code 6.9" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:00<00:00, 1426.15it/s, init loss: 198.3563, avg. loss [951-1000]: 20.1324]\n", "100%|██████████| 1000/1000 [00:00<00:00, 1464.26it/s, init loss: 1449.6163, avg. loss [951-1000]: 15.5010]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 0.01 0.08 0.01 -0.13 0.12 931.50 1.00\n", " bF 0.86 0.09 0.86 0.73 1.01 1111.41 1.00\n", " sigma 0.46 0.06 0.46 0.37 0.57 940.36 1.00\n", "\n", "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 0.01 0.07 0.01 -0.10 0.11 931.50 1.00\n", " bL -0.90 0.07 -0.90 -1.01 -0.78 1111.89 1.00\n", " sigma 0.39 0.05 0.39 0.31 0.48 957.39 1.00\n", "\n" ] } ], "source": [ "# kcal.per.g regressed on perc.fat\n", "def model(F, K):\n", " a = numpyro.sample(\"a\", dist.Normal(0, 0.2))\n", " bF = numpyro.sample(\"bF\", dist.Normal(0, 0.5))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a + bF * F\n", " numpyro.sample(\"K\", dist.Normal(mu, sigma), obs=K)\n", "\n", "\n", "m6_3 = AutoLaplaceApproximation(model)\n", "svi = SVI(model, m6_3, optim.Adam(1), Trace_ELBO(), F=d.F.values, K=d.K.values)\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_3 = svi_result.params\n", "\n", "\n", "# kcal.per.g regressed on perc.lactose\n", "def model(L, K):\n", " a = numpyro.sample(\"a\", dist.Normal(0, 0.2))\n", " bL = numpyro.sample(\"bL\", dist.Normal(0, 0.5))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a + bL * L\n", " numpyro.sample(\"K\", dist.Normal(mu, sigma), obs=K)\n", "\n", "\n", "m6_4 = AutoLaplaceApproximation(model)\n", "svi = SVI(model, m6_4, optim.Adam(1), Trace_ELBO(), L=d.L.values, K=d.K.values)\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_4 = svi_result.params\n", "\n", "post = 
m6_3.sample_posterior(random.PRNGKey(1), p6_3, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)\n", "post = m6_4.sample_posterior(random.PRNGKey(1), p6_4, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.10" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:00<00:00, 1346.20it/s, init loss: 1360.7051, avg. loss [951-1000]: 15.2050]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a -0.02 0.07 -0.03 -0.13 0.07 1049.96 1.00\n", " bF 0.25 0.19 0.25 -0.05 0.56 823.80 1.00\n", " bL -0.67 0.19 -0.67 -0.99 -0.37 875.48 1.00\n", " sigma 0.38 0.05 0.38 0.30 0.46 982.83 1.00\n", "\n" ] } ], "source": [ "def model(F, L, K):\n", " a = numpyro.sample(\"a\", dist.Normal(0, 0.2))\n", " bF = numpyro.sample(\"bF\", dist.Normal(0, 0.5))\n", " bL = numpyro.sample(\"bL\", dist.Normal(0, 0.5))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a + bF * F + bL * L\n", " numpyro.sample(\"K\", dist.Normal(mu, sigma), obs=K)\n", "\n", "\n", "m6_5 = AutoLaplaceApproximation(model)\n", "svi = SVI(\n", " model, m6_5, optim.Adam(1), Trace_ELBO(), F=d.F.values, L=d.L.values, K=d.K.values\n", ")\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_5 = svi_result.params\n", "post = m6_5.sample_posterior(random.PRNGKey(1), p6_5, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.11" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "az.plot_pair(d[[\"kcal.per.g\", \"perc.fat\", \"perc.lactose\"]].to_dict(\"list\"))\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.12" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "milk = pd.read_csv(\"../data/milk.csv\", sep=\";\")\n", "d = milk\n", "\n", "\n", "def sim_coll(i, r=0.9):\n", " sd = jnp.sqrt((1 - r**2) * jnp.var(d[\"perc.fat\"].values))\n", " x = dist.Normal(r * d[\"perc.fat\"].values, sd).sample(random.PRNGKey(3 * i))\n", "\n", " def model(perc_fat, kcal_per_g):\n", " intercept = numpyro.sample(\"intercept\", dist.Normal(0, 10))\n", " b_perc_flat = numpyro.sample(\"b_perc.fat\", dist.Normal(0, 10))\n", " b_x = numpyro.sample(\"b_x\", dist.Normal(0, 10))\n", " sigma = numpyro.sample(\"sigma\", dist.HalfCauchy(2))\n", " mu = intercept + b_perc_flat * perc_fat + b_x * x\n", " numpyro.sample(\"kcal.per.g\", dist.Normal(mu, sigma), obs=kcal_per_g)\n", "\n", " m = AutoLaplaceApproximation(model)\n", " svi = SVI(\n", " model,\n", " m,\n", " optim.Adam(0.01),\n", " Trace_ELBO(),\n", " perc_fat=d[\"perc.fat\"].values,\n", " kcal_per_g=d[\"kcal.per.g\"].values,\n", " )\n", " svi_result = svi.run(random.PRNGKey(3 * i + 1), 20000, progress_bar=False)\n", " params = svi_result.params\n", " samples = m.sample_posterior(random.PRNGKey(3 * i + 2), params, sample_shape=(1000,))\n", " vcov = jnp.cov(jnp.stack(list(samples.values()), axis=0))\n", " stddev = jnp.sqrt(jnp.diag(vcov)) # stddev of parameter\n", " return dict(zip(samples.keys(), stddev))[\"b_perc.fat\"]\n", "\n", "\n", "def rep_sim_coll(r=0.9, n=100):\n", " stddev = lax.map(lambda i: sim_coll(i, r=r), jnp.arange(n))\n", " return jnp.nanmean(stddev)\n", "\n", "\n", "r_seq = jnp.arange(start=0, stop=1, step=0.01)\n", "stddev = lax.map(lambda z: rep_sim_coll(r=z, n=100), r_seq)\n", "plt.plot(r_seq, stddev)\n", "plt.xlabel(\"correlation\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.13" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " 
fungus 0.31 0.46 0.00 0.00 1.00 18.52 1.17\n", " h0 9.73 1.95 9.63 7.05 13.33 80.22 0.99\n", " h1 13.72 2.47 13.60 10.73 18.38 43.44 1.08\n", " treatment 0.50 0.50 0.50 0.00 1.00 2.64 inf\n", "\n" ] } ], "source": [
"with numpyro.handlers.seed(rng_seed=71):\n",
"    # number of plants\n",
"    N = 100\n",
"\n",
"    # simulate initial heights\n",
"    h0 = numpyro.sample(\"h0\", dist.Normal(10, 2).expand([N]))\n",
"\n",
"    # first half of the plants untreated (0), second half treated (1)\n",
"    treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)\n",
"    # fungus probability is 0.5 without treatment and 0.1 with it\n",
"    p_fungus = 0.5 - treatment * 0.4\n",
"    fungus = numpyro.sample(\"fungus\", dist.Binomial(total_count=1, probs=p_fungus))\n",
"    # expected growth is 5, reduced by 3 when fungus is present\n",
"    growth = numpyro.sample(\"diff\", dist.Normal(5 - 3 * fungus))\n",
"    h1 = h0 + growth\n",
"\n",
"    # compose a clean data frame\n",
"    d = pd.DataFrame({\"h0\": h0, \"h1\": h1, \"treatment\": treatment, \"fungus\": fungus})\n",
"print_summary(dict(zip(d.columns, d.T.values)), prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.14" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " sim_p 1.04 0.27 1.00 0.63 1.44 9936.32 1.00\n", "\n" ] } ], "source": [
"# 10,000 draws from the LogNormal(0, 0.25) prior\n",
"sim_p = dist.LogNormal(0, 0.25).sample(random.PRNGKey(0), (10_000,))\n",
"print_summary({\"sim_p\": sim_p}, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.15" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:00<00:00, 1057.45it/s, init loss: 279.8950, avg. 
loss [951-1000]: 200.2917]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " p 1.39 0.02 1.39 1.36 1.42 994.30 1.00\n", " sigma 1.84 0.13 1.84 1.65 2.06 1011.70 1.00\n", "\n" ] } ], "source": [
"def model(h0, h1):\n",
"    \"\"\"Growth model: h1 ~ Normal(h0 * p, sigma), with p ~ LogNormal(0, 0.25).\"\"\"\n",
"    p = numpyro.sample(\"p\", dist.LogNormal(0, 0.25))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    mu = h0 * p\n",
"    numpyro.sample(\"h1\", dist.Normal(mu, sigma), obs=h1)\n",
"\n",
"\n",
"# fit a Laplace approximation guide for the growth model with SVI\n",
"m6_6 = AutoLaplaceApproximation(model)\n",
"svi = SVI(model, m6_6, optim.Adam(1), Trace_ELBO(), h0=d.h0.values, h1=d.h1.values)\n",
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_6 = svi_result.params\n",
"post = m6_6.sample_posterior(random.PRNGKey(1), p6_6, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.16" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:01<00:00, 763.78it/s, init loss: 151456.4062, avg. 
loss [951-1000]: 164.8744]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 1.47 0.03 1.47 1.43 1.51 1049.04 1.00\n", " bf -0.28 0.03 -0.28 -0.33 -0.23 910.93 1.00\n", " bt 0.01 0.03 0.01 -0.03 0.06 1123.06 1.00\n", " sigma 1.39 0.10 1.39 1.21 1.54 976.96 1.00\n", "\n" ] } ], "source": [
"def model(treatment, fungus, h0, h1):\n",
"    \"\"\"Growth model: h1 ~ Normal(h0 * (a + bt * treatment + bf * fungus), sigma).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.LogNormal(0, 0.2))\n",
"    bt = numpyro.sample(\"bt\", dist.Normal(0, 0.5))\n",
"    bf = numpyro.sample(\"bf\", dist.Normal(0, 0.5))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    # growth proportion as a linear function of treatment and fungus\n",
"    p = a + bt * treatment + bf * fungus\n",
"    mu = h0 * p\n",
"    numpyro.sample(\"h1\", dist.Normal(mu, sigma), obs=h1)\n",
"\n",
"\n",
"m6_7 = AutoLaplaceApproximation(model)\n",
"svi = SVI(\n",
"    model,\n",
"    m6_7,\n",
"    optim.Adam(0.3),\n",
"    Trace_ELBO(),\n",
"    treatment=d.treatment.values,\n",
"    fungus=d.fungus.values,\n",
"    h0=d.h0.values,\n",
"    h1=d.h1.values,\n",
")\n",
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_7 = svi_result.params\n",
"post = m6_7.sample_posterior(random.PRNGKey(1), p6_7, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.17" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:01<00:00, 855.77it/s, init loss: 87469.1172, avg. 
loss [951-1000]: 194.5041]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 1.33 0.02 1.33 1.29 1.37 930.82 1.00\n", " bt 0.13 0.04 0.12 0.08 0.19 880.02 1.00\n", " sigma 1.73 0.12 1.73 1.55 1.94 948.82 1.00\n", "\n" ] } ], "source": [
"def model(treatment, h0, h1):\n",
"    \"\"\"Growth model without fungus: h1 ~ Normal(h0 * (a + bt * treatment), sigma).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.LogNormal(0, 0.2))\n",
"    bt = numpyro.sample(\"bt\", dist.Normal(0, 0.5))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    # growth proportion as a linear function of treatment only\n",
"    p = a + bt * treatment\n",
"    mu = h0 * p\n",
"    numpyro.sample(\"h1\", dist.Normal(mu, sigma), obs=h1)\n",
"\n",
"\n",
"m6_8 = AutoLaplaceApproximation(model)\n",
"svi = SVI(\n",
"    model,\n",
"    m6_8,\n",
"    optim.Adam(1),\n",
"    Trace_ELBO(),\n",
"    treatment=d.treatment.values,\n",
"    h0=d.h0.values,\n",
"    h1=d.h1.values,\n",
")\n",
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_8 = svi_result.params\n",
"post = m6_8.sample_posterior(random.PRNGKey(1), p6_8, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.18" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plant_dag = nx.DiGraph()\n", "plant_dag.add_edges_from([(\"H0\", \"H1\"), (\"F\", \"H1\"), (\"T\", \"F\")])\n", "pgm = daft.PGM()\n", "coordinates = {\"H0\": (0, 0), \"T\": (4, 0), \"F\": (3, 0), \"H1\": (2, 0)}\n", "for node in plant_dag.nodes:\n", " pgm.add_node(node, node, *coordinates[node])\n", "for edge in plant_dag.edges:\n", " pgm.add_edge(*edge)\n", "with plt.rc_context({\"figure.constrained_layout.use\": False}):\n", " pgm.render()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.19" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "F _||_ H0\n", "H0 _||_ T\n", "H1 _||_ T | F\n" ] } ], "source": [ "conditional_independencies = collections.defaultdict(list)\n", "for edge in itertools.combinations(sorted(plant_dag.nodes), 2):\n", " remaining = sorted(set(plant_dag.nodes) - set(edge))\n", " for size in range(len(remaining) + 1):\n", " for subset in itertools.combinations(remaining, size):\n", " if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):\n", " continue\n", " if nx.d_separated(plant_dag, {edge[0]}, {edge[1]}, set(subset)):\n", " conditional_independencies[edge].append(set(subset))\n", " print(f\"{edge[0]} _||_ {edge[1]}\" + (f\" | {' '.join(subset)}\" if subset else \"\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.20" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "with numpyro.handlers.seed(rng_seed=71):\n", " N = 1000\n", " h0 = numpyro.sample(\"h0\", dist.Normal(10, 2).expand([N]))\n", " treatment = jnp.repeat(jnp.arange(2), repeats=N // 2)\n", " M = numpyro.sample(\"M\", dist.Bernoulli(probs=0.5).expand([N]))\n", " fungus = numpyro.sample(\n", " \"fungus\", dist.Binomial(total_count=1, probs=(0.5 - treatment * 0.4))\n", " )\n", " h1 = h0 + numpyro.sample(\"diff\", 
dist.Normal(5 + 3 * M))\n", " d2 = pd.DataFrame({\"h0\": h0, \"h1\": h1, \"treatment\": treatment, \"fungus\": fungus})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.21" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " age 33.00 18.77 33.00 1.00 58.00 2.51 2.64\n", " happiness 0.00 1.21 0.00 -2.00 1.58 338.78 1.00\n", " married 0.28 0.45 0.00 0.00 1.00 48.04 1.18\n", "\n" ] } ], "source": [ "def sim_happiness(seed=1977, N_years=1000, max_age=65, N_births=20, aom=18):\n", " # age existing individuals & newborns\n", " A = jnp.repeat(jnp.arange(1, N_years + 1), N_births)\n", " # sim happiness trait - never changes\n", " H = jnp.repeat(jnp.linspace(-2, 2, N_births)[None, :], N_years, 0).reshape(-1)\n", " # not yet married\n", " M = jnp.zeros(N_years * N_births, dtype=jnp.int32)\n", "\n", " def update_M(i, M):\n", " # for each person over 17, chance get married\n", " married = dist.Bernoulli(logits=(H - 4)).sample(random.PRNGKey(seed + i))\n", " return jnp.where((A >= i) & (M == 0), married, M)\n", "\n", " M = lax.fori_loop(aom, max_age + 1, update_M, M)\n", " # mortality\n", " deaths = A > max_age\n", " A = A[~deaths]\n", " H = H[~deaths]\n", " M = M[~deaths]\n", "\n", " d = pd.DataFrame({\"age\": A, \"married\": M, \"happiness\": H})\n", " return d\n", "\n", "\n", "d = sim_happiness(seed=1977, N_years=1000)\n", "print_summary(dict(zip(d.columns, d.T.values)), 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.22" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "d2 = d[d.age > 17].copy() # only adults\n", "d2[\"A\"] = (d2.age - 18) / (65 - 18)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.23" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:00<00:00, 1245.44it/s, init loss: 12964.1768, avg. loss [951-1000]: 1355.7842]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a[0] -0.20 0.06 -0.20 -0.30 -0.10 1049.96 1.00\n", " a[1] 1.23 0.09 1.23 1.09 1.37 898.97 1.00\n", " bA -0.69 0.11 -0.69 -0.88 -0.53 1126.51 1.00\n", " sigma 1.02 0.02 1.02 0.98 1.05 966.00 1.00\n", "\n" ] } ], "source": [ "d2[\"mid\"] = d2.married\n", "\n", "\n", "def model(mid, A, happiness):\n", " a = numpyro.sample(\"a\", dist.Normal(0, 1).expand([len(set(mid))]))\n", " bA = numpyro.sample(\"bA\", dist.Normal(0, 2))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a[mid] + bA * A\n", " numpyro.sample(\"happiness\", dist.Normal(mu, sigma), obs=happiness)\n", "\n", "\n", "m6_9 = AutoLaplaceApproximation(model)\n", "svi = SVI(\n", " model,\n", " m6_9,\n", " optim.Adam(1),\n", " Trace_ELBO(),\n", " mid=d2.mid.values,\n", " A=d2.A.values,\n", " happiness=d2.happiness.values,\n", ")\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_9 = svi_result.params\n", "post = m6_9.sample_posterior(random.PRNGKey(1), p6_9, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.24" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:00<00:00, 1244.73it/s, init loss: 19561.3906, avg. 
loss [951-1000]: 1520.8224]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a 0.01 0.08 0.01 -0.12 0.12 931.50 1.00\n", " bA -0.01 0.13 -0.01 -0.22 0.21 940.88 1.00\n", " sigma 1.21 0.03 1.21 1.17 1.26 949.78 1.00\n", "\n" ] } ], "source": [
"def model(A, happiness):\n",
"    \"\"\"Linear model without marriage status: happiness ~ Normal(a + bA * A, sigma).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n",
"    bA = numpyro.sample(\"bA\", dist.Normal(0, 2))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    mu = a + bA * A\n",
"    numpyro.sample(\"happiness\", dist.Normal(mu, sigma), obs=happiness)\n",
"\n",
"\n",
"m6_10 = AutoLaplaceApproximation(model)\n",
"svi = SVI(\n",
"    model,\n",
"    m6_10,\n",
"    optim.Adam(1),\n",
"    Trace_ELBO(),\n",
"    A=d2.A.values,\n",
"    happiness=d2.happiness.values,\n",
")\n",
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_10 = svi_result.params\n",
"post = m6_10.sample_posterior(random.PRNGKey(1), p6_10, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.25" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [
"N = 200  # number of grandparent-parent-child triads\n",
"b_GP = 1  # direct effect of G on P\n",
"b_GC = 0  # direct effect of G on C\n",
"b_PC = 1  # direct effect of P on C\n",
"b_U = 2  # direct effect of U on P and C" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.26" ] },
{ "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [
"with numpyro.handlers.seed(rng_seed=1):\n",
"    # map the Bernoulli 0/1 draws to -1/+1\n",
"    U = 2 * numpyro.sample(\"U\", dist.Bernoulli(0.5).expand([N])) - 1\n",
"    G = numpyro.sample(\"G\", dist.Normal().expand([N]))\n",
"    # P depends on G and U; C depends on P, G and U\n",
"    P = numpyro.sample(\"P\", dist.Normal(b_GP * G + b_U * U))\n",
"    C = numpyro.sample(\"C\", dist.Normal(b_PC * P + b_GC * G + b_U * U))\n",
"    d = pd.DataFrame({\"C\": C, \"P\": P, \"G\": G, \"U\": U})" ] },
{ "cell_type": "markdown", "metadata": {}, 
"source": [ "### Code 6.27" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:01<00:00, 733.78it/s, init loss: 4805.8169, avg. loss [951-1000]: 348.3594]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " a -0.08 0.10 -0.09 -0.24 0.06 1049.96 1.00\n", " b_GC -0.71 0.11 -0.71 -0.89 -0.55 813.76 1.00\n", " b_PC 1.72 0.04 1.72 1.65 1.79 982.64 1.00\n", " sigma 1.39 0.07 1.39 1.28 1.49 968.54 1.00\n", "\n" ] } ], "source": [ "def model(P, G, C):\n", " a = numpyro.sample(\"a\", dist.Normal(0, 1))\n", " b_PC = numpyro.sample(\"b_PC\", dist.Normal(0, 1))\n", " b_GC = numpyro.sample(\"b_GC\", dist.Normal(0, 1))\n", " sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n", " mu = a + b_PC * P + b_GC * G\n", " numpyro.sample(\"C\", dist.Normal(mu, sigma), obs=C)\n", "\n", "\n", "m6_11 = AutoLaplaceApproximation(model)\n", "svi = SVI(\n", " model,\n", " m6_11,\n", " optim.Adam(0.3),\n", " Trace_ELBO(),\n", " P=d.P.values,\n", " G=d.G.values,\n", " C=d.C.values,\n", ")\n", "svi_result = svi.run(random.PRNGKey(0), 1000)\n", "p6_11 = svi_result.params\n", "post = m6_11.sample_posterior(random.PRNGKey(1), p6_11, sample_shape=(1000,))\n", "print_summary(post, 0.89, False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.28" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1000/1000 [00:01<00:00, 668.45it/s, init loss: 565.4859, avg. 
loss [951-1000]: 300.9767]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", " mean std median 5.5% 94.5% n_eff r_hat\n", " U 1.87 0.17 1.88 1.59 2.11 1009.20 1.00\n", " a -0.06 0.08 -0.05 -0.18 0.07 766.03 1.00\n", " b_GC 0.01 0.10 0.01 -0.15 0.17 1031.98 1.00\n", " b_PC 0.99 0.07 0.99 0.88 1.11 1106.62 1.00\n", " sigma 1.08 0.05 1.08 0.99 1.16 797.58 1.00\n", "\n" ] } ], "source": [
"def model(P, G, U, C):\n",
"    \"\"\"Linear model including U: C ~ Normal(a + b_PC * P + b_GC * G + b_U * U, sigma).\"\"\"\n",
"    a = numpyro.sample(\"a\", dist.Normal(0, 1))\n",
"    b_PC = numpyro.sample(\"b_PC\", dist.Normal(0, 1))\n",
"    b_GC = numpyro.sample(\"b_GC\", dist.Normal(0, 1))\n",
"    # the coefficient on U is registered under the site name \"U\",\n",
"    # which is the row label it gets in the printed summary\n",
"    b_U = numpyro.sample(\"U\", dist.Normal(0, 1))\n",
"    sigma = numpyro.sample(\"sigma\", dist.Exponential(1))\n",
"    mu = a + b_PC * P + b_GC * G + b_U * U\n",
"    numpyro.sample(\"C\", dist.Normal(mu, sigma), obs=C)\n",
"\n",
"\n",
"m6_12 = AutoLaplaceApproximation(model)\n",
"svi = SVI(\n",
"    model,\n",
"    m6_12,\n",
"    optim.Adam(1),\n",
"    Trace_ELBO(),\n",
"    P=d.P.values,\n",
"    G=d.G.values,\n",
"    U=d.U.values,\n",
"    C=d.C.values,\n",
")\n",
"svi_result = svi.run(random.PRNGKey(0), 1000)\n",
"p6_12 = svi_result.params\n",
"post = m6_12.sample_posterior(random.PRNGKey(1), p6_12, sample_shape=(1000,))\n",
"print_summary(post, prob=0.89, group_by_chain=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.29" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'A'}\n", "{'C'}\n" ] } ], "source": [
"def path_is_blocked(dag, path, conditioned):\n",
"    \"\"\"Return True if conditioning on `conditioned` blocks `path` in `dag`.\n",
"\n",
"    `path` is a list of nodes; every interior node is inspected together\n",
"    with its two neighbours on the path.\n",
"    \"\"\"\n",
"    for x, z, y in zip(path[:-2], path[1:-1], path[2:]):\n",
"        if dag.has_edge(x, z) and dag.has_edge(y, z):\n",
"            # z is a collider: it blocks the path unless z itself or one of\n",
"            # its descendants is conditioned on\n",
"            if z not in conditioned and not (set(nx.descendants(dag, z)) & conditioned):\n",
"                return True\n",
"        elif z in conditioned:\n",
"            # chain or fork: blocked once the middle node is conditioned on\n",
"            return True\n",
"    return False\n",
"\n",
"\n",
"def minimal_adjustment_sets(dag, exposure, outcome, unobserved=()):\n",
"    \"\"\"List the minimal node sets that block every backdoor path.\n",
"\n",
"    A backdoor path is a simple path from `exposure` to `outcome` in the\n",
"    undirected version of `dag` whose first edge points into `exposure`.\n",
"    Candidates exclude the exposure, the outcome, descendants of the exposure\n",
"    and anything listed in `unobserved`. Sets are tried smallest first and\n",
"    supersets of an already accepted set are skipped.\n",
"\n",
"    Returns a list of sets in the order they were found.\n",
"    \"\"\"\n",
"    backdoor_paths = [\n",
"        path\n",
"        for path in nx.all_simple_paths(dag.to_undirected(), exposure, outcome)\n",
"        if dag.has_edge(path[1], exposure)\n",
"    ]\n",
"    candidates = sorted(\n",
"        set(dag.nodes) - {exposure, outcome} - set(unobserved) - set(nx.descendants(dag, exposure))\n",
"    )\n",
"    found = []\n",
"    for size in range(len(candidates) + 1):\n",
"        for subset in map(set, itertools.combinations(candidates, size)):\n",
"            if any(smaller.issubset(subset) for smaller in found):\n",
"                continue\n",
"            if all(path_is_blocked(dag, path, subset) for path in backdoor_paths):\n",
"                found.append(subset)\n",
"    return found\n",
"\n",
"\n",
"dag_6_1 = nx.DiGraph()\n",
"dag_6_1.add_edges_from(\n",
"    [(\"X\", \"Y\"), (\"U\", \"X\"), (\"A\", \"U\"), (\"A\", \"C\"), (\"C\", \"Y\"), (\"U\", \"B\"), (\"C\", \"B\")]\n",
")\n",
"# U is kept out of the candidate adjustment sets\n",
"adjustment_sets = minimal_adjustment_sets(dag_6_1, \"X\", \"Y\", unobserved={\"U\"})\n",
"for subset in adjustment_sets:\n",
"    print(subset)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.30" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'S'}\n", "{'A', 'M'}\n" ] } ], "source": [
"dag_6_2 = nx.DiGraph()\n",
"dag_6_2.add_edges_from(\n",
"    [(\"S\", \"A\"), (\"A\", \"D\"), (\"S\", \"M\"), (\"M\", \"D\"), (\"S\", \"W\"), (\"W\", \"D\"), (\"A\", \"M\")]\n",
")\n",
"adjustment_sets = minimal_adjustment_sets(dag_6_2, \"W\", \"D\")\n",
"for subset in adjustment_sets:\n",
"    print(subset)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Code 6.31" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A _||_ W | S\n", "D _||_ S | A M W\n", "M _||_ W | S\n" ] } ], "source": [
"# networkx 3.3 renamed d_separated to is_d_separator and later dropped the old name;\n",
"# use whichever one the installed version provides\n",
"is_d_separator = getattr(nx, \"is_d_separator\", None) or nx.d_separated\n",
"\n",
"# for every pair of nodes, print the smallest conditioning sets that d-separate them\n",
"conditional_independencies = collections.defaultdict(list)\n",
"for edge in itertools.combinations(sorted(dag_6_2.nodes), 2):\n",
"    remaining = sorted(set(dag_6_2.nodes) - set(edge))\n",
"    for size in range(len(remaining) + 1):\n",
"        for subset in itertools.combinations(remaining, size):\n",
"            # skip supersets of a conditioning set already reported for this pair\n",
"            if any(cond.issubset(set(subset)) for cond in conditional_independencies[edge]):\n",
"                continue\n",
"            if is_d_separator(dag_6_2, {edge[0]}, {edge[1]}, set(subset)):\n",
"                conditional_independencies[edge].append(set(subset))\n",
"                print(f\"{edge[0]} _||_ {edge[1]}\" + (f\" | {' '.join(subset)}\" if subset else \"\"))" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 4 }