{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "import pystan\n", "import numpy as np\n", "import ujson as json" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "with open(\"radon.json\", \"rb\") as f:\n", " radon_data = json.load(f)\n", "\n", "key_renaming = {\"x\": \"floor_idx\", \"county\": \"county_idx\", \"u\": \"uranium\"}\n", "radon_data = {\n", " key_renaming.get(key, key): np.array(value) if isinstance(value, list) else value\n", " for key, value in radon_data.items()\n", "}\n", "# Shift county indices by one so they can be used as 1-based Stan array indices.\n", "radon_data[\"county_idx\"] = radon_data[\"county_idx\"] + 1" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Stan program with only data and generated quantities: every value is drawn with *_rng.\n", "prior_code = \"\"\"\n", "data {\n", " int<lower=0> J;\n", " int<lower=0> N;\n", " int floor_idx[N];\n", " int county_idx[N];\n", " real uranium[J];\n", "}\n", "\n", "generated quantities {\n", " real g[2];\n", " real sigma_a = exponential_rng(1);\n", " real sigma = exponential_rng(1);\n", " real b = normal_rng(0, 1);\n", " real za_county[J]; \n", " real y_hat[N];\n", " real a[J];\n", " real a_county[J];\n", " \n", " g[1] = normal_rng(0, 10);\n", " g[2] = normal_rng(0, 10);\n", " \n", " for (i in 1:J) {\n", " za_county[i] = normal_rng(0, 1);\n", " a[i] = g[1] + g[2] * uranium[i];\n", " a_county[i] = a[i] + za_county[i] * sigma_a;\n", " }\n", " \n", " for (j in 1:N) {\n", " y_hat[j] = normal_rng(a_county[county_idx[j]] + b * floor_idx[j], sigma);\n", " }\n", "}\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_51a6e73bb4685d9d898431904d164252 NOW.\n" ] } ], "source": [ "prior_model = pystan.StanModel(model_code=prior_code, extra_compile_args=['-flto'])" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", 
"text": [ "WARNING:pystan:`warmup=0` forced with `algorithm=\"Fixed_param\"`.\n" ] } ], "source": [ "prior_data = {key: value for key, value in radon_data.items() if key not in (\"county_name\", \"y\")}\n", "# Fixed seed so reruns reproduce the same prior draws.\n", "prior = prior_model.sampling(data=prior_data, iter=500, warmup=0, algorithm=\"Fixed_param\", seed=1234)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Stan program for the posterior; log_lik and y_hat are produced in generated quantities.\n", "radon_code = \"\"\"\n", "data {\n", " int<lower=0> J;\n", " int<lower=0> N;\n", " int floor_idx[N];\n", " int county_idx[N];\n", " real uranium[J];\n", " real y[N];\n", "}\n", "\n", "parameters {\n", " real g[2];\n", " // scale parameters: positivity matches the exponential priors in the model block\n", " real<lower=0> sigma_a;\n", " real<lower=0> sigma;\n", " real za_county[J];\n", " real b;\n", "}\n", "\n", "transformed parameters {\n", " real theta[N];\n", " real a[J];\n", " real a_county[J];\n", " \n", " for (i in 1:J) {\n", " a[i] = g[1] + g[2] * uranium[i];\n", " a_county[i] = a[i] + za_county[i] * sigma_a;\n", " }\n", " for (j in 1:N)\n", " theta[j] = a_county[county_idx[j]] + b * floor_idx[j];\n", "}\n", "\n", "model {\n", " g ~ normal(0, 10);\n", " sigma_a ~ exponential(1);\n", " \n", " za_county ~ normal(0, 1);\n", " b ~ normal(0, 1);\n", " sigma ~ exponential(1);\n", " \n", " for (j in 1:N)\n", " y[j] ~ normal(theta[j], sigma);\n", "}\n", "\n", "generated quantities {\n", " real log_lik[N];\n", " real y_hat[N];\n", " for (j in 1:N) {\n", " log_lik[j] = normal_lpdf(y[j] | theta[j], sigma);\n", " y_hat[j] = normal_rng(theta[j], sigma);\n", " }\n", "}\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_3ab5b33b0aee1c122fca5450e04a6494 NOW.\n" ] } ], "source": [ "stan_model = pystan.StanModel(model_code=radon_code, extra_compile_args=['-flto'])" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:pystan:Maximum (flat) parameter count (1000) exceeded: 
skipping diagnostic tests for n_eff and Rhat.\n", "To run all diagnostics call pystan.check_hmc_diagnostics(fit)\n" ] } ], "source": [ "model_data = {key: value for key, value in radon_data.items() if key not in (\"county_name\",)}\n", "# Fixed seed so reruns reproduce the same posterior draws.\n", "fit = stan_model.sampling(data=model_data, control={\"adapt_delta\": 0.99}, iter=1500, warmup=1000, seed=1234)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "coords = {\n", " \"level\": [\"basement\", \"floor\"],\n", " \"obs_id\": np.arange(radon_data[\"y\"].size),\n", " \"county\": radon_data[\"county_name\"],\n", " \"g_coef\": [\"intercept\", \"slope\"],\n", "}\n", "dims = {\n", " \"g\" : [\"g_coef\"],\n", " \"za_county\" : [\"county\"],\n", " \"y\" : [\"obs_id\"],\n", " \"y_hat\" : [\"obs_id\"],\n", " \"floor_idx\" : [\"obs_id\"],\n", " \"county_idx\" : [\"obs_id\"],\n", " \"theta\" : [\"obs_id\"],\n", " \"uranium\" : [\"county\"],\n", " \"a\" : [\"county\"],\n", " \"a_county\" : [\"county\"], \n", "}\n", "idata = az.from_pystan(\n", " posterior=fit,\n", " posterior_predictive=\"y_hat\",\n", " prior=prior,\n", " prior_predictive=\"y_hat\",\n", " observed_data=[\"y\"],\n", " constant_data=[\"floor_idx\", \"county_idx\", \"uranium\"],\n", " log_likelihood={\"y\": \"log_lik\"},\n", " coords=coords,\n", " dims=dims,\n", ").rename({\"y_hat\": \"y\"}) # renames both prior and posterior predictive" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "
\n", "
arviz.InferenceData
\n", "
\n", "
    \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, county: 85, draw: 500, g_coef: 2, obs_id: 919)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * g_coef     (g_coef) <U9 'intercept' 'slope'\n",
             "  * county     (county) <U17 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'\n",
             "  * obs_id     (obs_id) int64 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918\n",
             "Data variables:\n",
             "    g          (chain, draw, g_coef) float64 1.422 0.8193 1.471 ... 1.468 0.553\n",
             "    sigma_a    (chain, draw) float64 0.07058 0.1412 0.1446 ... 0.1085 0.1642\n",
             "    sigma      (chain, draw) float64 0.7718 0.7581 0.7664 ... 0.753 0.7568\n",
             "    za_county  (chain, draw, county) float64 -0.5769 0.272 ... 1.121 2.431\n",
             "    b          (chain, draw) float64 -0.6334 -0.6666 -0.6876 ... -0.7367 -0.7349\n",
             "    theta      (chain, draw, obs_id) float64 0.1833 0.8167 ... 2.064 2.064\n",
             "    a          (chain, draw, county) float64 0.8575 0.7278 1.329 ... 1.419 1.665\n",
             "    a_county   (chain, draw, county) float64 0.8167 0.747 1.254 ... 1.603 2.064\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.368054\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1\n",
             "    args:                       [{"random_seed":"55974540","chain_id":0,"init...\n",
             "    inits:                      [[-0.6890848703967867,1.6651652637888739,0.21...\n",
             "    step_size:                  [0.102456, 0.0748624, 0.0864012, 0.117333]\n",
             "    metric:                     ['diag_e', 'diag_e', 'diag_e', 'diag_e']\n",
             "    inv_metric:                 [[0.00129703,0.00845106,0.138269,0.000598391,...\n",
             "    adaptation_info:            ['# Adaptation terminated\\n# Step size = 0.10...\n",
             "    stan_code:                  \\ndata {\\n  int<lower=0> J;\\n  int<lower=0> N...

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 500, obs_id: 919)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918\n",
             "Data variables:\n",
             "    y        (chain, draw, obs_id) float64 0.3315 1.256 1.357 ... 2.44 0.9659\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.459055\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 500, obs_id: 919)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918\n",
             "Data variables:\n",
             "    y        (chain, draw, obs_id) float64 -0.9673 -0.6605 ... -1.139 -1.512\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.405520\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:      (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain        (chain) int64 0 1 2 3\n",
             "  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "Data variables:\n",
             "    accept_stat  (chain, draw) float64 0.9732 0.9973 0.9986 ... 0.9892 0.991\n",
             "    stepsize     (chain, draw) float64 0.1025 0.1025 0.1025 ... 0.1173 0.1173\n",
             "    treedepth    (chain, draw) int64 5 5 5 5 5 5 5 5 5 5 ... 5 5 5 5 5 5 5 5 5 5\n",
             "    n_leapfrog   (chain, draw) int64 31 31 31 63 63 31 63 ... 31 31 31 31 31 31\n",
             "    diverging    (chain, draw) bool False False False ... False False False\n",
             "    energy       (chain, draw) float64 297.0 308.0 302.0 ... 303.6 313.2 296.2\n",
             "    lp           (chain, draw) float64 -259.5 -257.5 -251.7 ... -253.1 -242.8\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.383098\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1\n",
             "    args:                       [{"random_seed":"55974540","chain_id":0,"init...\n",
             "    inits:                      [[-0.6890848703967867,1.6651652637888739,0.21...\n",
             "    step_size:                  [0.102456, 0.0748624, 0.0864012, 0.117333]\n",
             "    metric:                     ['diag_e', 'diag_e', 'diag_e', 'diag_e']\n",
             "    inv_metric:                 [[0.00129703,0.00845106,0.138269,0.000598391,...\n",
             "    adaptation_info:            ['# Adaptation terminated\\n# Step size = 0.10...\n",
             "    stan_code:                  \\ndata {\\n  int<lower=0> J;\\n  int<lower=0> N...

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:    (chain: 4, county: 85, draw: 500, g_coef: 2)\n",
             "Coordinates:\n",
             "  * chain      (chain) int64 0 1 2 3\n",
             "  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
             "  * g_coef     (g_coef) <U9 'intercept' 'slope'\n",
             "  * county     (county) <U17 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'\n",
             "Data variables:\n",
             "    g          (chain, draw, g_coef) float64 -18.13 12.45 ... 0.1568 2.502\n",
             "    sigma_a    (chain, draw) float64 0.3894 0.4864 1.839 ... 1.087 0.7622 1.325\n",
             "    sigma      (chain, draw) float64 0.674 0.2708 1.013 ... 0.5544 0.6577 0.177\n",
             "    b          (chain, draw) float64 -2.32 0.7922 -0.08347 ... 1.517 -0.3149\n",
             "    za_county  (chain, draw, county) float64 0.7894 -0.819 ... 0.8195 -2.563\n",
             "    a          (chain, draw, county) float64 -26.71 -28.68 ... -0.06844 1.046\n",
             "    a_county   (chain, draw, county) float64 -26.4 -29.0 -19.3 ... 1.017 -2.351\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.470706\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1\n",
             "    args:                       [{"random_seed":"1111484586","chain_id":0,"in...\n",
             "    inits:                      [[-18.130655239984424,12.45383580853858,0.389...\n",
             "    step_size:                  [nan, nan, nan, nan]\n",
             "    metric:                     ['unit_e', 'unit_e', 'unit_e', 'unit_e']\n",
             "    inv_metric:                 [null,null,null,null]\n",
             "    adaptation_info:            ['', '', '', '']\n",
             "    stan_code:                  \\ndata {\\n  int<lower=0> J;\\n  int<lower=0> N...

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (chain: 4, draw: 500, obs_id: 919)\n",
             "Coordinates:\n",
             "  * chain    (chain) int64 0 1 2 3\n",
             "  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 492 493 494 495 496 497 498 499\n",
             "  * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918\n",
             "Data variables:\n",
             "    y        (chain, draw, obs_id) float64 -28.69 -26.65 -25.7 ... -2.576 -2.268\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.533183\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:      (chain: 4, draw: 500)\n",
             "Coordinates:\n",
             "  * chain        (chain) int64 0 1 2 3\n",
             "  * draw         (draw) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499\n",
             "Data variables:\n",
             "    accept_stat  (chain, draw) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0\n",
             "    lp           (chain, draw) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.479487\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1\n",
             "    args:                       [{"random_seed":"1111484586","chain_id":0,"in...\n",
             "    inits:                      [[-18.130655239984424,12.45383580853858,0.389...\n",
             "    step_size:                  [nan, nan, nan, nan]\n",
             "    metric:                     ['unit_e', 'unit_e', 'unit_e', 'unit_e']\n",
             "    inv_metric:                 [null,null,null,null]\n",
             "    adaptation_info:            ['', '', '', '']\n",
             "    stan_code:                  \\ndata {\\n  int<lower=0> J;\\n  int<lower=0> N...

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:  (obs_id: 919)\n",
             "Coordinates:\n",
             "  * obs_id   (obs_id) int64 0 1 2 3 4 5 6 7 ... 911 912 913 914 915 916 917 918\n",
             "Data variables:\n",
             "    y        (obs_id) float64 0.7885 0.7885 1.065 0.0 ... 1.609 1.308 1.065\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.332973\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
  • \n", " \n", " \n", "
    \n", "
    \n", "
      \n", "
      \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
      <xarray.Dataset>\n",
             "Dimensions:     (county: 85, obs_id: 919)\n",
             "Coordinates:\n",
             "  * obs_id      (obs_id) int64 0 1 2 3 4 5 6 7 ... 912 913 914 915 916 917 918\n",
             "  * county      (county) <U17 'AITKIN' 'ANOKA' ... 'WRIGHT' 'YELLOW MEDICINE'\n",
             "Data variables:\n",
             "    floor_idx   (obs_id) int64 1 0 0 0 0 0 0 0 0 0 0 0 ... 0 0 0 1 0 0 0 0 0 0 0\n",
             "    county_idx  (obs_id) int64 1 1 1 1 2 2 2 2 2 ... 84 84 84 84 84 84 84 85 85\n",
             "    uranium     (county) float64 -0.689 -0.8473 -0.1135 ... -0.09002 0.3553\n",
             "Attributes:\n",
             "    created_at:                 2020-10-14T17:54:38.335649\n",
             "    arviz_version:              0.10.0\n",
             "    inference_library:          pystan\n",
             "    inference_library_version:  2.19.1.1

      \n", "
    \n", "
    \n", "
  • \n", " \n", "
\n", "
\n", " " ], "text/plain": [ "Inference data with groups:\n", "\t> posterior\n", "\t> posterior_predictive\n", "\t> log_likelihood\n", "\t> sample_stats\n", "\t> prior\n", "\t> prior_predictive\n", "\t> sample_stats_prior\n", "\t> observed_data\n", "\t> constant_data" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'pystan.nc'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idata.to_netcdf(\"pystan.nc\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }