{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "0", "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "try:\n", "    import tinygp\n", "except ImportError:\n", "    %pip install -q tinygp\n", "\n", "try:\n", "    import jaxopt\n", "except ImportError:\n", "    %pip install -q jaxopt" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "(means)=\n", "\n", "# Fitting a Mean Function\n", "\n", "It is quite common in the GP literature to (\"without loss of generality\") set the mean of the process to zero and call it a day.\n", "In practice, however, it is often useful to fit for the parameters of a mean model at the same time as the GP parameters.\n", "In some other tutorials, we fit for a constant mean value using the `mean` argument to {class}`tinygp.GaussianProcess`, but in this tutorial we walk through an example of how you might go about fitting a model with a non-trivial parameterized mean function.\n", "\n", "For our example, we'll fit for the offset $b$, amplitude $a$, location $\\ell$, and width $w$ of the following model:\n", "\n", "$$\n", "f(x) = b + a\\,\\exp\\left(-\\frac{(x - \\ell)^2}{2\\,w^2}\\right)\n", "$$\n", "\n", "In `jax`, we might implement such a function as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "\n", "def mean_function(params, X):\n", "    # amps = [b, a]: constant offset and amplitude of the Gaussian bump\n", "    mod = jnp.exp(-0.5 * jnp.square((X - params[\"loc\"]) / jnp.exp(params[\"log_width\"])))\n", "    beta = jnp.array([1, mod])\n", "    return params[\"amps\"] @ beta\n", "\n", "\n", "mean_params = {\n", "    \"amps\": np.array([0.1, 0.3]),\n", "    \"loc\": 5.0,\n", "    \"log_width\": np.log(0.5),\n", "}\n", "\n", "X_grid = np.linspace(0, 10, 200)\n", "model = jax.vmap(partial(mean_function, mean_params))(X_grid)\n", "\n", 
"plt.plot(X_grid, model)\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"y\")\n", "_ = plt.title(\"a parametric mean model\")" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "Our implementation here is somewhat artificially complicated in order to highlight one very important technical point: we must define our mean function to operate on a *single input coordinate*.\n", "What that means is that we don't need to worry about broadcasting and stuff within our mean function: `tinygp` will do all the necessary `vmap`-ing.\n", "More explicitly, if we try to call our `mean_function` on a vector of inputs, it will fail with a strange error (yeah, I know that we could write it in a way that would work, but I'm trying to make a point!):" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": { "tags": [ "raises-exception" ] }, "outputs": [], "source": [ "model = mean_function(mean_params, X_grid)" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "Instead, we need to manually `vmap` as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "model = jax.vmap(partial(mean_function, mean_params))(X_grid)" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "## Simulated data\n", "\n", "Now that we have this mean function defined, let's make some fake data that could benefit from a joint GP + mean function fit.\n", "In this case, we'll add a background trend that's not included in the mean model, as well as some noise." 
] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "random = np.random.default_rng(135)\n", "X = np.sort(random.uniform(0, 10, 50))\n", "y = jax.vmap(partial(mean_function, mean_params))(X)\n", "y += 0.1 * np.sin(2 * np.pi * (X - 5) / 10.0)\n", "y += 0.03 * random.normal(size=len(X))\n", "plt.plot(X, y, \".k\")\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"y\")\n", "_ = plt.title(\"simulated data\")" ] }, { "cell_type": "markdown", "id": "9", "metadata": {}, "source": [ "## The fit\n", "\n", "Then, we set up the usual infrastructure to calculate the loss function for this model.\n", "In this case, you'll notice that we've stacked the mean and GP parameters into one dictionary, but that isn't the only way you could do it.\n", "You'll also notice that we're passing a partially evaluated version of the mean function to our GP object, but we're not doing any `vmap`-ing.\n", "That's because `tinygp` is expecting the mean function to operate on a single input coordinate, and it will handle the appropriate mapping." 
] }, { "cell_type": "code", "execution_count": null, "id": "10", "metadata": {}, "outputs": [], "source": [ "from tinygp import kernels, GaussianProcess\n", "\n", "\n", "def build_gp(params):\n", "    kernel = jnp.exp(params[\"log_gp_amp\"]) * kernels.Matern52(\n", "        jnp.exp(params[\"log_gp_scale\"])\n", "    )\n", "    return GaussianProcess(\n", "        kernel,\n", "        X,\n", "        diag=jnp.exp(params[\"log_gp_diag\"]),\n", "        mean=partial(mean_function, params),\n", "    )\n", "\n", "\n", "@jax.jit\n", "def loss(params):\n", "    gp = build_gp(params)\n", "    return -gp.log_probability(y)\n", "\n", "\n", "params = dict(\n", "    log_gp_amp=np.log(0.1),\n", "    log_gp_scale=np.log(3.0),\n", "    log_gp_diag=np.log(0.03),\n", "    **mean_params,\n", ")\n", "loss(params)" ] }, { "cell_type": "markdown", "id": "11", "metadata": {}, "source": [ "We can minimize the loss using `jaxopt`:" ] }, { "cell_type": "code", "execution_count": null, "id": "12", "metadata": {}, "outputs": [], "source": [ "import jaxopt\n", "\n", "solver = jaxopt.ScipyMinimize(fun=loss)\n", "soln = solver.run(jax.tree_util.tree_map(jnp.asarray, params))\n", "print(f\"Final negative log likelihood: {soln.state.fun_val}\")" ] }, { "cell_type": "markdown", "id": "13", "metadata": {}, "source": [ "## Visualizing the result\n", "\n", "We can then plot the conditional distribution:" ] }, { "cell_type": "code", "execution_count": null, "id": "14", "metadata": {}, "outputs": [], "source": [ "gp = build_gp(soln.params)\n", "_, cond = gp.condition(y, X_grid)\n", "\n", "mu = cond.loc\n", "std = np.sqrt(cond.variance)\n", "\n", "plt.plot(X, y, \".k\", label=\"data\")\n", "plt.plot(X_grid, mu, label=\"model\")\n", "plt.fill_between(X_grid, mu + std, mu - std, color=\"C0\", alpha=0.3)\n", "\n", "plt.xlim(X_grid.min(), X_grid.max())\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"y\")\n", "_ = plt.legend()" ] }, { "cell_type": "markdown", "id": "15", "metadata": {}, "source": [ "That looks pretty good but, when working with mean functions, it is often useful to 
separate the mean model and GP predictions when plotting the conditional.\n", "The interface for doing this in `tinygp` is not its most ergonomic feature, but it shouldn't be too onerous.\n", "To compute the conditional distribution without the mean function included, call {func}`tinygp.GaussianProcess.condition` with the `include_mean=False` flag:" ] }, { "cell_type": "code", "execution_count": null, "id": "16", "metadata": {}, "outputs": [], "source": [ "gp = build_gp(soln.params)\n", "_, cond = gp.condition(y, X_grid, include_mean=False)\n", "\n", "# Add the mean model's zero point back so the GP prediction lines up with the data\n", "mu = cond.loc + soln.params[\"amps\"][0]\n", "std = np.sqrt(cond.variance)\n", "\n", "plt.plot(X, y, \".k\", label=\"data\")\n", "plt.plot(X_grid, mu, label=\"GP model\")\n", "plt.fill_between(X_grid, mu + std, mu - std, color=\"C0\", alpha=0.3)\n", "plt.plot(X_grid, jax.vmap(gp.mean_function)(X_grid), label=\"mean model\")\n", "\n", "plt.xlim(X_grid.min(), X_grid.max())\n", "plt.xlabel(\"x\")\n", "plt.ylabel(\"y\")\n", "_ = plt.legend()" ] }, { "cell_type": "markdown", "id": "17", "metadata": {}, "source": [ "There is one other subtlety that you may notice here: we added the mean model's zero point (`soln.params[\"amps\"][0]`) to the GP prediction.\n", "If we had left this off, the blue line in the above figure would be offset below the data by about `0.1`.\n", "It's pretty common to end up with a workflow like this when visualizing the results of GP fits with non-trivial means.\n", "\n", "## An alternative workflow\n", "\n", "Sometimes it can be easier to manage all the mean function bookkeeping yourself: instead of using the `mean` argument to {class}`tinygp.GaussianProcess`, you could manually subtract the mean function from the data before calling {func}`tinygp.GaussianProcess.log_probability`.\n", "Here's how you might implement such a workflow for our example:" ] }, { "cell_type": "code", "execution_count": null, "id": "18", "metadata": {}, "outputs": [], "source": [ "vmapped_mean_function = 
jax.vmap(mean_function, in_axes=(None, 0))\n", "\n", "\n", "def build_gp_v2(params):\n", "    kernel = jnp.exp(params[\"log_gp_amp\"]) * kernels.Matern52(\n", "        jnp.exp(params[\"log_gp_scale\"])\n", "    )\n", "    return GaussianProcess(kernel, X, diag=jnp.exp(params[\"log_gp_diag\"]))\n", "\n", "\n", "@jax.jit\n", "def loss_v2(params):\n", "    gp = build_gp_v2(params)\n", "    return -gp.log_probability(y - vmapped_mean_function(params, X))\n", "\n", "\n", "loss_v2(params)" ] }, { "cell_type": "markdown", "id": "19", "metadata": {}, "source": [ "In this case, we are now responsible for making sure that the mean function is properly broadcast, and we must not forget to also subtract the mean function from the data when calling {func}`tinygp.GaussianProcess.condition` (and to add it back to the predictions if we want them on the scale of the data)." ] }, { "cell_type": "code", "execution_count": null, "id": "20", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" } }, "nbformat": 4, "nbformat_minor": 5 }