{ "cells": [ { "cell_type": "markdown", "id": "f1d51e31", "metadata": { "papermill": { "duration": 0.052042, "end_time": "2021-09-18T10:51:56.464794", "exception": false, "start_time": "2021-09-18T10:51:56.412752", "status": "completed" }, "tags": [] }, "source": [ "# Bayesian Hierarchical Stacking: Well Switching Case Study\n", "\n", "
\n", " \n", "
Photo by Belinda Fewings, https://unsplash.com/photos/6p-KtXCBGNw.
\n", "
\n", "\n", "## Table of Contents\n", "\n", "* [Intro](#intro)\n", "* [1. Exploratory Data Analysis](#1)\n", "* [2. Prepare 6 Different Models](#2)\n", " * [2.1 Feature Engineering](#2.1)\n", " * [2.2 Training](#2.2)\n", "* [3. Bayesian Hierarchical Stacking](#3)\n", " * [3.1 Prepare stacking datasets](#3.1)\n", " * [3.2 Define stacking model](#3.2)\n", "* [4. Evaluate on test set](#4)\n", " * [4.1 Stack predictions](#4.1)\n", " * [4.2 Compare methods](#4.2)\n", "* [Conclusion](#conclusion)\n", "* [References](#references)\n", "\n", "## Intro \n", "\n", "Suppose you have just fit 6 models to a dataset, and need to choose which one to use to make predictions on your test set. How do you decide? A couple of common tactics are:\n", "- choose the best model based on cross-validation;\n", "- average the models, using weights based on cross-validation scores.\n", "\n", "In the paper [Bayesian hierarchical stacking: Some models are (somewhere) useful](https://arxiv.org/abs/2101.08954), a new technique is introduced: average models using weights which are allowed to vary according to the input data, based on a hierarchical structure.\n", "\n", "\n", "Here, we'll implement the first case study from that paper - readers are nonetheless encouraged to look at the original paper to find other case studies, as well as theoretical results. Code from the article (in R / Stan) can be found [here](https://github.com/yao-yl/hierarchical-stacking-code)." 
] }, { "cell_type": "code", "execution_count": 1, "id": "2d43427d-0ac3-4383-8441-375164cbecb0", "metadata": {}, "outputs": [], "source": [ "!pip install -q numpyro@git+https://github.com/pyro-ppl/numpyro" ] }, { "cell_type": "code", "execution_count": 2, "id": "7a71e927", "metadata": { "papermill": { "duration": 4.069199, "end_time": "2021-09-18T10:52:00.594720", "exception": false, "start_time": "2021-09-18T10:51:56.525521", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "from IPython.display import set_matplotlib_formats\n", "import arviz as az\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "from scipy.interpolate import BSpline\n", "import seaborn as sns\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "import numpyro\n", "import numpyro.distributions as dist\n", "\n", "plt.style.use(\"seaborn\")\n", "if \"NUMPYRO_SPHINXBUILD\" in os.environ:\n", " set_matplotlib_formats(\"svg\")\n", "\n", "numpyro.set_host_device_count(4)\n", "assert numpyro.__version__.startswith(\"0.10.1\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "227f2ff1-63f3-4529-89ba-4c92fc7bb518", "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "id": "255e8d79", "metadata": { "papermill": { "duration": 0.043256, "end_time": "2021-09-18T10:52:00.780796", "exception": false, "start_time": "2021-09-18T10:52:00.737540", "status": "completed" }, "tags": [] }, "source": [ "## 1. Exploratory Data Analysis \n", "\n", "The data we have to work with looks at households in Bangladesh, some of which were affected by high levels of arsenic in their water. Would affected households want to switch to a neighbour's well?\n", "\n", "We'll split the data into a train and test set, and then we'll train six different models to try to predict whether households would switch wells. 
Then, we'll see how we can stack them when predicting on the test set!\n", "\n", "But first, let's load the data and visualise it! Each row represents a household, and the columns we have available to us are:\n", "\n", "- switch: whether a household switched to another well (this is the target);\n", "- arsenic: level of arsenic in drinking water;\n", "- educ: level of education of \"head of household\";\n", "- dist: distance to nearest safe-drinking well;\n", "- assoc: whether the household participates in any community activities." ] }, { "cell_type": "code", "execution_count": 4, "id": "01d1703b", "metadata": { "papermill": { "duration": 0.078754, "end_time": "2021-09-18T10:52:00.905455", "exception": false, "start_time": "2021-09-18T10:52:00.826701", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "wells = pd.read_csv(\n", " \"http://stat.columbia.edu/~gelman/arm/examples/arsenic/wells.dat\", sep=\" \"\n", ")" ] }, { "cell_type": "code", "execution_count": 5, "id": "2bf6c000-cb9a-4c81-898f-5ac4cdd1020a", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table>\n",
" <thead>\n",
" <tr><th></th><th>switch</th><th>arsenic</th><th>dist</th><th>assoc</th><th>educ</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>1</td><td>2.36</td><td>16.826000</td><td>0</td><td>0</td></tr>\n",
" <tr><th>2</th><td>1</td><td>0.71</td><td>47.321999</td><td>0</td><td>0</td></tr>\n",
" <tr><th>3</th><td>0</td><td>2.07</td><td>20.966999</td><td>0</td><td>10</td></tr>\n",
" <tr><th>4</th><td>1</td><td>1.15</td><td>21.486000</td><td>0</td><td>12</td></tr>\n",
" <tr><th>5</th><td>1</td><td>1.10</td><td>40.874001</td><td>1</td><td>14</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ " switch arsenic dist assoc educ\n", "1 1 2.36 16.826000 0 0\n", "2 1 0.71 47.321999 0 0\n", "3 0 2.07 20.966999 0 10\n", "4 1 1.15 21.486000 0 12\n", "5 1 1.10 40.874001 1 14" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wells.head()" ] }, { "cell_type": "code", "execution_count": 6, "id": "5dec77a2", "metadata": { "papermill": { "duration": 1.122344, "end_time": "2021-09-18T10:52:02.072825", "exception": false, "start_time": "2021-09-18T10:52:00.950481", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(2, 2, figsize=(12, 6))\n", "fig.suptitle(\"Target variable plotted against various predictors\")\n", "sns.scatterplot(data=wells, x=\"arsenic\", y=\"switch\", ax=ax[0][0])\n", "sns.scatterplot(data=wells, x=\"dist\", y=\"switch\", ax=ax[0][1])\n", "sns.barplot(\n", " data=wells.groupby(\"assoc\")[\"switch\"].mean().reset_index(),\n", " x=\"assoc\",\n", " y=\"switch\",\n", " ax=ax[1][0],\n", ")\n", "ax[1][0].set_ylabel(\"Proportion switch\")\n", "sns.barplot(\n", " data=wells.groupby(\"educ\")[\"switch\"].mean().reset_index(),\n", " x=\"educ\",\n", " y=\"switch\",\n", " ax=ax[1][1],\n", ")\n", "ax[1][1].set_ylabel(\"Proportion switch\");" ] }, { "cell_type": "markdown", "id": "05c9daff", "metadata": { "papermill": { "duration": 0.046834, "end_time": "2021-09-18T10:52:02.167845", "exception": false, "start_time": "2021-09-18T10:52:02.121011", "status": "completed" }, "tags": [] }, "source": [ "Next, we'll choose 200 observations to be part of our train set, and 1500 to be part of our test set." ] }, { "cell_type": "code", "execution_count": 7, "id": "e6b41da0", "metadata": { "papermill": { "duration": 0.058671, "end_time": "2021-09-18T10:52:02.274078", "exception": false, "start_time": "2021-09-18T10:52:02.215407", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "np.random.seed(1)\n", "train_id = wells.sample(n=200).index\n", "test_id = wells.loc[~wells.index.isin(train_id)].sample(n=1500).index\n", "y_train = wells.loc[train_id, \"switch\"].to_numpy()\n", "y_test = wells.loc[test_id, \"switch\"].to_numpy()" ] }, { "cell_type": "markdown", "id": "01c56e27", "metadata": { "papermill": { "duration": 0.047031, "end_time": "2021-09-18T10:52:02.368998", "exception": false, "start_time": "2021-09-18T10:52:02.321967", "status": "completed" }, "tags": [] }, "source": [ "## 2. 
Prepare 6 different candidate models \n", "\n", "### 2.1 Feature Engineering \n", "\n", "First, let's add a few new columns:\n", "- `edu0`: whether `educ` is `0`,\n", "- `edu1`: whether `educ` is between `1` and `5`,\n", "- `edu2`: whether `educ` is between `6` and `11`,\n", "- `edu3`: whether `educ` is between `12` and `17`,\n", "- `logarsenic`: natural logarithm of `arsenic`,\n", "- `assoc_half`: half of `assoc`,\n", "- `as_square`: natural logarithm of `arsenic`, squared,\n", "- `as_third`: natural logarithm of `arsenic`, cubed,\n", "- `dist100`: `dist` divided by `100`,\n", "- `intercept`: just a column of `1`s.\n", "\n", "We're going to start by fitting 6 different models to our train set:\n", "\n", "- logistic regression using `intercept`, `dist100`, `arsenic`, `assoc`, `edu1`, `edu2`, and `edu3`;\n", "- same as above, but with `logarsenic` instead of `arsenic`;\n", "- same as the first one, but with the square and cubic features `as_square` and `as_third` as well;\n", "- same as the first one, but with spline features derived from `logarsenic` instead of `arsenic`;\n", "- same as the second one, but with spline features derived from `dist100` instead of `dist100` itself;\n", "- same as the second one, but with `educ` instead of the binary `edu` variables." 
] }, { "cell_type": "code", "execution_count": 8, "id": "fa79c0ee-54b9-458d-9f97-c9e91ae83e7a", "metadata": {}, "outputs": [], "source": [ "wells[\"edu0\"] = wells[\"educ\"].isin(np.arange(0, 1)).astype(int)\n", "wells[\"edu1\"] = wells[\"educ\"].isin(np.arange(1, 6)).astype(int)\n", "wells[\"edu2\"] = wells[\"educ\"].isin(np.arange(6, 12)).astype(int)\n", "wells[\"edu3\"] = wells[\"educ\"].isin(np.arange(12, 18)).astype(int)\n", "wells[\"logarsenic\"] = np.log(wells[\"arsenic\"])\n", "wells[\"assoc_half\"] = wells[\"assoc\"] / 2.0\n", "wells[\"as_square\"] = wells[\"logarsenic\"] ** 2\n", "wells[\"as_third\"] = wells[\"logarsenic\"] ** 3\n", "wells[\"dist100\"] = wells[\"dist\"] / 100.0\n", "wells[\"intercept\"] = 1" ] }, { "cell_type": "code", "execution_count": 9, "id": "6726d0fa", "metadata": { "papermill": { "duration": 0.062523, "end_time": "2021-09-18T10:52:02.478421", "exception": false, "start_time": "2021-09-18T10:52:02.415898", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def bs(x, knots, degree):\n", " \"\"\"\n", " Generate the B-spline basis matrix for a polynomial spline.\n", "\n", " Parameters\n", " ----------\n", " x\n", " predictor variable.\n", " knots\n", " locations of internal breakpoints (not padded).\n", " degree\n", " degree of the piecewise polynomial.\n", "\n", " Returns\n", " -------\n", " pd.DataFrame\n", " Spline basis matrix.\n", "\n", " Notes\n", " -----\n", " This mirrors ``bs`` from splines package in R.\n", " \"\"\"\n", " padded_knots = np.hstack(\n", " [[x.min()] * (degree + 1), knots, [x.max()] * (degree + 1)]\n", " )\n", " return pd.DataFrame(\n", " BSpline(padded_knots, np.eye(len(padded_knots) - degree - 1), degree)(x)[:, 1:],\n", " index=x.index,\n", " )\n", "\n", "\n", "knots = np.quantile(wells.loc[train_id, \"logarsenic\"], np.linspace(0.1, 0.9, num=10))\n", "spline_arsenic = bs(wells[\"logarsenic\"], knots=knots, degree=3)\n", "knots = np.quantile(wells.loc[train_id, \"dist100\"], np.linspace(0.1, 
0.9, num=10))\n", "spline_dist = bs(wells[\"dist100\"], knots=knots, degree=3)" ] }, { "cell_type": "code", "execution_count": 10, "id": "064a3de6", "metadata": { "papermill": { "duration": 0.081958, "end_time": "2021-09-18T10:52:02.608879", "exception": false, "start_time": "2021-09-18T10:52:02.526921", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "features_0 = [\"intercept\", \"dist100\", \"arsenic\", \"assoc\", \"edu1\", \"edu2\", \"edu3\"]\n", "features_1 = [\"intercept\", \"dist100\", \"logarsenic\", \"assoc\", \"edu1\", \"edu2\", \"edu3\"]\n", "features_2 = [\n", " \"intercept\",\n", " \"dist100\",\n", " \"arsenic\",\n", " \"as_third\",\n", " \"as_square\",\n", " \"assoc\",\n", " \"edu1\",\n", " \"edu2\",\n", " \"edu3\",\n", "]\n", "features_3 = [\"intercept\", \"dist100\", \"assoc\", \"edu1\", \"edu2\", \"edu3\"]\n", "features_4 = [\"intercept\", \"logarsenic\", \"assoc\", \"edu1\", \"edu2\", \"edu3\"]\n", "features_5 = [\"intercept\", \"dist100\", \"logarsenic\", \"assoc\", \"educ\"]\n", "\n", "X0 = wells.loc[train_id, features_0].to_numpy()\n", "X1 = wells.loc[train_id, features_1].to_numpy()\n", "X2 = wells.loc[train_id, features_2].to_numpy()\n", "X3 = (\n", " pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)\n", " .loc[train_id]\n", " .to_numpy()\n", ")\n", "X4 = pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[train_id].to_numpy()\n", "X5 = wells.loc[train_id, features_5].to_numpy()\n", "\n", "X0_test = wells.loc[test_id, features_0].to_numpy()\n", "X1_test = wells.loc[test_id, features_1].to_numpy()\n", "X2_test = wells.loc[test_id, features_2].to_numpy()\n", "X3_test = (\n", " pd.concat([wells.loc[:, features_3], spline_arsenic], axis=1)\n", " .loc[test_id]\n", " .to_numpy()\n", ")\n", "X4_test = (\n", " pd.concat([wells.loc[:, features_4], spline_dist], axis=1).loc[test_id].to_numpy()\n", ")\n", "X5_test = wells.loc[test_id, features_5].to_numpy()" ] }, { "cell_type": "code", "execution_count": 11, 
"id": "64fa1b43", "metadata": { "papermill": { "duration": 0.055757, "end_time": "2021-09-18T10:52:02.713347", "exception": false, "start_time": "2021-09-18T10:52:02.657590", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "train_x_list = [X0, X1, X2, X3, X4, X5]\n", "test_x_list = [X0_test, X1_test, X2_test, X3_test, X4_test, X5_test]\n", "K = len(train_x_list)" ] }, { "cell_type": "markdown", "id": "e7d1a65d", "metadata": { "papermill": { "duration": 0.049466, "end_time": "2021-09-18T10:52:02.811950", "exception": false, "start_time": "2021-09-18T10:52:02.762484", "status": "completed" }, "tags": [] }, "source": [ "### 2.2 Training \n", "\n", "Each model will be trained in the same way - with a Bernoulli likelihood and a logit link function." ] }, { "cell_type": "code", "execution_count": 12, "id": "c070567f", "metadata": { "papermill": { "duration": 0.056796, "end_time": "2021-09-18T10:52:02.917713", "exception": false, "start_time": "2021-09-18T10:52:02.860917", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def logistic(x, y=None):\n", " beta = numpyro.sample(\"beta\", dist.Normal(0, 3).expand([x.shape[1]]))\n", " logits = numpyro.deterministic(\"logits\", jnp.matmul(x, beta))\n", "\n", " numpyro.sample(\n", " \"obs\",\n", " dist.Bernoulli(logits=logits),\n", " obs=y,\n", " )" ] }, { "cell_type": "code", "execution_count": 13, "id": "b29ed6c2", "metadata": { "papermill": { "duration": 820.388941, "end_time": "2021-09-18T11:05:43.355092", "exception": false, "start_time": "2021-09-18T10:52:02.966151", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "fit_list = []\n", "for k in range(K):\n", " sampler = numpyro.infer.NUTS(logistic)\n", " mcmc = numpyro.infer.MCMC(\n", " sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False\n", " )\n", " rng_key = jax.random.fold_in(jax.random.PRNGKey(13), k)\n", " mcmc.run(rng_key, x=train_x_list[k], y=y_train)\n", " fit_list.append(mcmc)" ] }, { 
"cell_type": "markdown", "id": "c2ac5012", "metadata": { "papermill": { "duration": 0.051074, "end_time": "2021-09-18T11:05:43.479751", "exception": false, "start_time": "2021-09-18T11:05:43.428677", "status": "completed" }, "tags": [] }, "source": [ "### 2.3 Estimate leave-one-out cross-validated score for each training point \n", "\n", "Rather than refitting each model 200 times (once per left-out training point), we will estimate the leave-one-out cross-validated score using [LOO](https://arxiv.org/abs/2001.00980)." ] }, { "cell_type": "code", "execution_count": 14, "id": "0dfe6166", "metadata": { "papermill": { "duration": 14.787853, "end_time": "2021-09-18T11:05:58.318434", "exception": false, "start_time": "2021-09-18T11:05:43.530581", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def find_point_wise_loo_score(fit):\n", " return az.loo(az.from_numpyro(fit), pointwise=True, scale=\"log\").loo_i.values\n", "\n", "\n", "lpd_point = np.vstack([find_point_wise_loo_score(fit) for fit in fit_list]).T\n", "exp_lpd_point = np.exp(lpd_point)" ] }, { "cell_type": "markdown", "id": "e3f7a74a", "metadata": { "papermill": { "duration": 0.051972, "end_time": "2021-09-18T11:05:58.422802", "exception": false, "start_time": "2021-09-18T11:05:58.370830", "status": "completed" }, "tags": [] }, "source": [ "## 3. Bayesian Hierarchical Stacking \n", "\n", "### 3.1 Prepare stacking datasets \n", "\n", "To determine how the stacking weights should vary across training and test sets, we will need to create \"stacking datasets\" which include all the features which we want the stacking weights to depend on. How should such features be included? For discrete features, this is easy: we just one-hot-encode them. But for continuous features, we need a trick. 
In Equation (16) of the paper, the authors recommend the following: if you have a continuous feature `f`, then replace it with the following two features:\n", "\n", "- `f_l`: `f` minus the median of `f`, clipped above at 0;\n", "- `f_r`: `f` minus the median of `f`, clipped below at 0." ] }, { "cell_type": "code", "execution_count": 15, "id": "8450ac11", "metadata": { "papermill": { "duration": 0.078407, "end_time": "2021-09-18T11:05:58.566113", "exception": false, "start_time": "2021-09-18T11:05:58.487706", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# train_id holds index labels, so select the training rows with .loc directly\n", "dist100_median = wells.loc[train_id, \"dist100\"].median()\n", "logarsenic_median = wells.loc[train_id, \"logarsenic\"].median()\n", "wells[\"dist100_l\"] = (wells[\"dist100\"] - dist100_median).clip(upper=0)\n", "wells[\"dist100_r\"] = (wells[\"dist100\"] - dist100_median).clip(lower=0)\n", "wells[\"logarsenic_l\"] = (wells[\"logarsenic\"] - logarsenic_median).clip(upper=0)\n", "wells[\"logarsenic_r\"] = (wells[\"logarsenic\"] - logarsenic_median).clip(lower=0)\n", "\n", "stacking_features = [\n", " \"edu0\",\n", " \"edu1\",\n", " \"edu2\",\n", " \"edu3\",\n", " \"assoc_half\",\n", " \"dist100_l\",\n", " \"dist100_r\",\n", " \"logarsenic_l\",\n", " \"logarsenic_r\",\n", "]\n", "X_stacking_train = wells.loc[train_id, stacking_features].to_numpy()\n", "X_stacking_test = wells.loc[test_id, stacking_features].to_numpy()" ] }, { "cell_type": "markdown", "id": "cb323c68", "metadata": { "papermill": { "duration": 0.052318, "end_time": "2021-09-18T11:05:58.671602", "exception": false, "start_time": "2021-09-18T11:05:58.619284", "status": "completed" }, "tags": [] }, "source": [ "### 3.2 Define stacking model \n", "\n", "What we seek to find is a matrix of weights $W$ with which to multiply the models' predictions. Let's define a matrix $Pred$ such that $Pred_{i,k}$ represents the prediction made for point $i$ by model $k$. 
Then the final prediction for point $i$ will be:\n", "\n", "$$ \\sum_k W_{i, k}Pred_{i,k} $$\n", "\n", "Such a matrix $W$ would be required to have each row sum to $1$. Hence, we calculate each row $W_i$ of $W$ as:\n", "\n", "$$ W_i = \\text{softmax}(X\\_\\text{stacking}_i \\cdot \\beta), $$\n", "\n", "where $\\beta$ is a matrix whose values we seek to determine. For the discrete features, $\\beta$ is given a hierarchical structure over the possible inputs. Continuous features, on the other hand, get no hierarchical structure in this case study and just vary according to the input values.\n", "\n", "Notice how, for the discrete features, a [non-centered parametrisation is used](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/). Also note that we only need to estimate $K-1$ columns of $\\beta$, because the weights $W_{i, k}$ will have to sum to $1$ for each $i$." ] }, { "cell_type": "code", "execution_count": 16, "id": "f2203a8c", "metadata": { "papermill": { "duration": 0.075301, "end_time": "2021-09-18T11:05:58.799743", "exception": false, "start_time": "2021-09-18T11:05:58.724442", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def stacking(\n", " X,\n", " d_discrete,\n", " X_test,\n", " exp_lpd_point,\n", " tau_mu,\n", " tau_sigma,\n", " *,\n", " test,\n", "):\n", " \"\"\"\n", " Get weights with which to stack candidate models' predictions.\n", "\n", " Parameters\n", " ----------\n", " X\n", " Training stacking matrix: features on which stacking weights should depend, for the\n", " training set.\n", " d_discrete\n", " Number of discrete features in `X` and `X_test`. 
The first `d_discrete` features\n", " from these matrices should be the discrete ones, with the continuous ones coming\n", " after them.\n", " X_test\n", " Test stacking matrix: features on which stacking weights should depend, for the\n", " testing set.\n", " exp_lpd_point\n", " Exponentiated LOO score evaluated at each point in the training set, for each\n", " candidate model.\n", " tau_mu\n", " Hyperprior for mean of `beta`, for discrete features.\n", " tau_sigma\n", " Hyperprior for standard deviation of `beta`, for discrete features.\n", " test\n", " Whether to calculate stacking weights for test set.\n", "\n", " Notes\n", " -----\n", " Naming of variables mirrors what's used in the original paper.\n", " \"\"\"\n", " N = X.shape[0]\n", " d = X.shape[1]\n", " N_test = X_test.shape[0]\n", " K = exp_lpd_point.shape[1] # number of candidate models\n", "\n", " with numpyro.plate(\"Candidate models\", K - 1, dim=-2):\n", " # mean effect of discrete features on stacking weights\n", " mu = numpyro.sample(\"mu\", dist.Normal(0, tau_mu))\n", " # standard deviation effect of discrete features on stacking weights\n", " sigma = numpyro.sample(\"sigma\", dist.HalfNormal(scale=tau_sigma))\n", " with numpyro.plate(\"Discrete features\", d_discrete, dim=-1):\n", " # effect of discrete features on stacking weights\n", " tau = numpyro.sample(\"tau\", dist.Normal(0, 1))\n", " with numpyro.plate(\"Continuous features\", d - d_discrete, dim=-1):\n", " # effect of continuous features on stacking weights\n", " beta_con = numpyro.sample(\"beta_con\", dist.Normal(0, 1))\n", "\n", " # effects of features on stacking weights\n", " beta = numpyro.deterministic(\n", " \"beta\", jnp.hstack([(sigma.squeeze() * tau.T + mu.squeeze()).T, beta_con])\n", " )\n", " assert beta.shape == (K - 1, d)\n", "\n", " # stacking weights (in unconstrained space)\n", " f = jnp.hstack([X @ beta.T, jnp.zeros((N, 1))])\n", " assert f.shape == (N, K)\n", "\n", " # log probability of LOO training scores weighted by stacking 
weights.\n", " log_w = jax.nn.log_softmax(f, axis=1)\n", " # stacking weights (constrained to sum to 1)\n", " numpyro.deterministic(\"w\", jnp.exp(log_w))\n", " # use the function argument rather than the global `lpd_point`\n", " logp = jax.nn.logsumexp(jnp.log(exp_lpd_point) + log_w, axis=1)\n", " numpyro.factor(\"logp\", jnp.sum(logp))\n", "\n", " if test:\n", " # test set stacking weights (in unconstrained space)\n", " f_test = jnp.hstack([X_test @ beta.T, jnp.zeros((N_test, 1))])\n", " # test set stacking weights (constrained to sum to 1)\n", " w_test = numpyro.deterministic(\"w_test\", jax.nn.softmax(f_test, axis=1))" ] }, { "cell_type": "code", "execution_count": 17, "id": "9827977d", "metadata": { "papermill": { "duration": 296.084187, "end_time": "2021-09-18T11:10:54.936288", "exception": false, "start_time": "2021-09-18T11:05:58.852101", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "sampler = numpyro.infer.NUTS(stacking)\n", "mcmc = numpyro.infer.MCMC(\n", " sampler, num_chains=4, num_samples=1000, num_warmup=1000, progress_bar=False\n", ")\n", "mcmc.run(\n", " jax.random.PRNGKey(17),\n", " X=X_stacking_train,\n", " d_discrete=4,\n", " X_test=X_stacking_test,\n", " exp_lpd_point=exp_lpd_point,\n", " tau_mu=1.0,\n", " tau_sigma=0.5,\n", " test=True,\n", ")\n", "trace = mcmc.get_samples()" ] }, { "cell_type": "markdown", "id": "c7ede764", "metadata": { "papermill": { "duration": 0.052553, "end_time": "2021-09-18T11:10:55.042375", "exception": false, "start_time": "2021-09-18T11:10:54.989822", "status": "completed" }, "tags": [] }, "source": [ "We can now extract the weights with which to weight the different models from the posterior, and then visualise how they vary across the training set.\n", "\n", "Let's compare them with what the weights would've been if we'd just used fixed stacking weights (computed using ArviZ - see [their docs](https://arviz-devs.github.io/arviz/api/generated/arviz.compare.html) for details)." 
] }, { "cell_type": "code", "execution_count": 18, "id": "812117cb", "metadata": { "papermill": { "duration": 2.523295, "end_time": "2021-09-18T11:10:57.979955", "exception": false, "start_time": "2021-09-18T11:10:55.456660", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), sharey=True)\n", "training_stacking_weights = trace[\"w\"].mean(axis=0)\n", "sns.scatterplot(data=pd.DataFrame(training_stacking_weights), ax=ax[0])\n", "fixed_weights = (\n", " az.compare({idx: fit for idx, fit in enumerate(fit_list)}, method=\"stacking\")\n", " .sort_index()[\"weight\"]\n", " .to_numpy()\n", ")\n", "fixed_weights_df = pd.DataFrame(\n", " np.repeat(\n", " fixed_weights[jnp.newaxis, :],\n", " len(X_stacking_train),\n", " axis=0,\n", " )\n", ")\n", "sns.scatterplot(data=fixed_weights_df, ax=ax[1])\n", "ax[0].set_title(\"Training weights from Bayesian Hierarchical stacking\")\n", "ax[1].set_title(\"Fixed weights stacking\")\n", "ax[0].set_xlabel(\"Index\")\n", "ax[1].set_xlabel(\"Index\")\n", "fig.suptitle(\n", " \"Bayesian Hierarchical Stacking weights can vary according to the input\",\n", " fontsize=18,\n", ")\n", "fig.tight_layout();" ] }, { "cell_type": "markdown", "id": "c60e0c01", "metadata": { "papermill": { "duration": 0.065143, "end_time": "2021-09-18T11:10:58.110931", "exception": false, "start_time": "2021-09-18T11:10:58.045788", "status": "completed" }, "tags": [] }, "source": [ "## 4. Evaluate on test set \n", "\n", "### 4.1 Stack predictions \n", "\n", "Now, for each model, let's compute predictions (posterior mean logits) for each point in the test set. Once we have predictions for each model, we need to think about how to combine them, such that for each test point, we get a single prediction.\n", "\n", "We'll do this in three ways:\n", "- Bayesian Hierarchical Stacking (`bhs_predictions`);\n", "- choosing the model with the best training set LOO score (`model_selection_preds`);\n", "- fixed-weights stacking (`fixed_weights_preds`)." 
] }, { "cell_type": "code", "execution_count": 19, "id": "ce86bd9e-b2c6-4947-9675-92c925b6088d", "metadata": {}, "outputs": [], "source": [ "# for each candidate model, extract the posterior predictive logits\n", "train_preds = []\n", "for k in range(K):\n", " predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())\n", " rng_key = jax.random.fold_in(jax.random.PRNGKey(19), k)\n", " train_pred = predictive(rng_key, x=train_x_list[k])[\"logits\"]\n", " train_preds.append(train_pred.mean(axis=0))\n", "# reshape, so we have (N, K)\n", "train_preds = np.vstack(train_preds).T" ] }, { "cell_type": "code", "execution_count": 20, "id": "5b686b7c", "metadata": { "papermill": { "duration": 0.54285, "end_time": "2021-09-18T11:10:59.694998", "exception": false, "start_time": "2021-09-18T11:10:59.152148", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# same as previous cell, but for test set\n", "test_preds = []\n", "for k in range(K):\n", " predictive = numpyro.infer.Predictive(logistic, fit_list[k].get_samples())\n", " rng_key = jax.random.fold_in(jax.random.PRNGKey(20), k)\n", " test_pred = predictive(rng_key, x=test_x_list[k])[\"logits\"]\n", " test_preds.append(test_pred.mean(axis=0))\n", "test_preds = np.vstack(test_preds).T" ] }, { "cell_type": "code", "execution_count": 21, "id": "436f8789", "metadata": { "papermill": { "duration": 0.145066, "end_time": "2021-09-18T11:11:00.042707", "exception": false, "start_time": "2021-09-18T11:10:59.897641", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# get the stacking weights for the test set\n", "test_stacking_weights = trace[\"w_test\"].mean(axis=0)\n", "# get predictions using the stacking weights\n", "bhs_predictions = (test_stacking_weights * test_preds).sum(axis=1)\n", "# get predictions using only the model with the best LOO score\n", "model_selection_preds = test_preds[:, lpd_point.sum(axis=0).argmax()]\n", "# get predictions using fixed stacking weights, 
dependent on the LOO score\n", "fixed_weights_preds = (fixed_weights * test_preds).sum(axis=1)" ] }, { "cell_type": "markdown", "id": "76233762", "metadata": { "papermill": { "duration": 0.064289, "end_time": "2021-09-18T11:11:00.170538", "exception": false, "start_time": "2021-09-18T11:11:00.106249", "status": "completed" }, "tags": [] }, "source": [ "### 4.2 Compare methods " ] }, { "cell_type": "markdown", "id": "c2d889c2", "metadata": { "papermill": { "duration": 0.06178, "end_time": "2021-09-18T11:11:00.293209", "exception": false, "start_time": "2021-09-18T11:11:00.231429", "status": "completed" }, "tags": [] }, "source": [ "Let's compare the negative log predictive density scores on the test set (note - lower is better):" ] }, { "cell_type": "code", "execution_count": 22, "id": "33e15689", "metadata": { "papermill": { "duration": 0.463508, "end_time": "2021-09-18T11:11:00.819086", "exception": false, "start_time": "2021-09-18T11:11:00.355578", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(figsize=(12, 6))\n", "neg_log_pred_densities = np.vstack(\n", " [\n", " -dist.Bernoulli(logits=bhs_predictions).log_prob(y_test),\n", " -dist.Bernoulli(logits=model_selection_preds).log_prob(y_test),\n", " -dist.Bernoulli(logits=fixed_weights_preds).log_prob(y_test),\n", " ]\n", ").T\n", "neg_log_pred_density = pd.DataFrame(\n", " neg_log_pred_densities,\n", " columns=[\n", " \"Bayesian Hierarchical Stacking\",\n", " \"Model selection\",\n", " \"Fixed stacking weights\",\n", " ],\n", ")\n", "sns.barplot(\n", " data=neg_log_pred_density.reindex(\n", " columns=neg_log_pred_density.mean(axis=0).sort_values(ascending=False).index\n", " ),\n", " orient=\"h\",\n", " ax=ax,\n", ")\n", "ax.set_title(\n", " \"Bayesian Hierarchical Stacking performs best here\", fontdict={\"fontsize\": 18}\n", ")\n", "ax.set_xlabel(\"Negative mean log predictive density (lower is better)\");" ] }, { "cell_type": "markdown", "id": "dc7517ce", "metadata": { "papermill": { "duration": 0.066707, "end_time": "2021-09-18T11:11:01.178051", "exception": false, "start_time": "2021-09-18T11:11:01.111344", "status": "completed" }, "tags": [] }, "source": [ "So, in this dataset, with this particular train-test split, Bayesian Hierarchical Stacking does indeed bring a small gain compared with model selection and compared with fixed-weight stacking.\n", "\n", "### 4.3 Does this prove that Bayesian Hierarchical Stacking works? \n", "\n", "No, a single train-test split doesn't prove anything. Check the original paper for results with varying training set sizes, repeated with different train-test splits, in which they show that Bayesian Hierarchical Stacking consistently outperforms model selection and fixed-weight stacking.\n", "\n", "The goal of this notebook was just to show how to implement this technique in NumPyro." 
] }, { "cell_type": "markdown", "id": "29cf8140", "metadata": { "papermill": { "duration": 0.066367, "end_time": "2021-09-18T11:11:01.310721", "exception": false, "start_time": "2021-09-18T11:11:01.244354", "status": "completed" }, "tags": [] }, "source": [ "## Conclusion \n", "\n", "We've seen how Bayesian Hierarchical Stacking can help us average models with input-dependent weights, in a manner which doesn't overfit. We only implemented the first case study from the paper, but readers are encouraged to check out the other two as well. Also check the paper for theoretical results and results from more experiments.\n", "\n", "## References\n", "\n", "1. Yuling Yao, Gregor Pirš, Aki Vehtari, Andrew Gelman (2021). [Bayesian hierarchical stacking: Some models are (somewhere) useful](https://arxiv.org/abs/2101.08954)\n", "2. Måns Magnusson, Michael Riis Andersen, Johan Jonasson, Aki Vehtari (2020). [Leave-One-Out Cross-Validation for Bayesian Model Comparison in Large Data](https://arxiv.org/abs/2001.00980)\n", "3. https://github.com/yao-yl/hierarchical-stacking-code.\n", "4. Thomas Wiecki (2017). 
[Why hierarchical models are awesome, tricky, and Bayesian](https://twiecki.io/blog/2017/02/08/bayesian-hierchical-non-centered/)" ] }, { "cell_type": "code", "execution_count": null, "id": "32694722", "metadata": { "papermill": { "duration": 0.068268, "end_time": "2021-09-18T11:11:03.218672", "exception": false, "start_time": "2021-09-18T11:11:03.150404", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.8" }, "papermill": { "default_parameters": {}, "duration": 1224.415684, "end_time": "2021-09-18T11:11:05.150406", "environment_variables": {}, "exception": null, "input_path": "__notebook__.ipynb", "output_path": "__notebook__.ipynb", "parameters": {}, "start_time": "2021-09-18T10:50:40.734722", "version": "2.3.3" } }, "nbformat": 4, "nbformat_minor": 5 }