{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "0", "metadata": { "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "try:\n", " import tinygp\n", "except ImportError:\n", " %pip install -q tinygp\n", "\n", "try:\n", " import jaxopt\n", "except ImportError:\n", " %pip install -q jaxopt" ] }, { "cell_type": "markdown", "id": "1", "metadata": {}, "source": [ "(quickstart)=\n", "\n", "# Getting Started\n", "\n", "```{note}\n", "This tutorial introduces some of the basic usage of `tinygp`, but since we're going to lean pretty heavily on `jax`, it might be useful to also take a look at [the `jax` docs](https://jax.readthedocs.io) for a more basic introduction to `jax` programming patterns.\n", "```\n", "\n", "In the following, we'll reproduce the analysis for Figure 5.6 in [Chapter 5 of Rasmussen & Williams (R&W)](http://www.gaussianprocess.org/gpml/chapters/RW5.pdf).\n", "The data are measurements of the atmospheric CO2 concentration made at Mauna Loa, Hawaii (Keeling & Whorf 2004).\n", "The dataset is said to be available online, but I couldn't download it from the original source.\n", "Luckily, the [statsmodels](http://statsmodels.sourceforge.net/) package [includes a copy](http://statsmodels.sourceforge.net/devel/datasets/generated/co2.html) that we can load as follows: " ] }, { "cell_type": "code", "execution_count": null, "id": "2", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from statsmodels.datasets import co2\n", "\n", "data = co2.load_pandas().data\n", "t = 2000 + (np.array(data.index.to_julian_date()) - 2451545.0) / 365.25\n", "y = np.array(data.co2)\n", "m = np.isfinite(t) & np.isfinite(y) & (t < 1996)\n", "t, y = t[m][::4], y[m][::4]\n", "\n", "plt.plot(t, y, \".k\")\n", "plt.xlim(t.min(), t.max())\n", "plt.xlabel(\"year\")\n", "_ = plt.ylabel(\"CO$_2$ in ppm\")" ] }, { "cell_type": "markdown", "id": "3", "metadata": {}, "source": [ "In this figure, you 
can see that there is a periodic (or quasi-periodic) signal with a year-long period superimposed on a long-term trend. We will follow R&W and model these effects non-parametrically using a complicated covariance function. The covariance function that we’ll use is:\n", "\n", "$$k(r) = k_1(r) + k_2(r) + k_3(r) + k_4(r)$$\n", "\n", "where\n", "\n", "$$\n", "\\begin{eqnarray}\n", " k_1(r) &=& \\theta_1^2 \\, \\exp \\left(-\\frac{r^2}{2\\,\\theta_2} \\right) \\\\\n", " k_2(r) &=& \\theta_3^2 \\, \\exp \\left(-\\frac{r^2}{2\\,\\theta_4}\n", " -\\theta_5\\,\\sin^2\\left(\n", " \\frac{\\pi\\,r}{\\theta_6}\\right)\n", " \\right) \\\\\n", " k_3(r) &=& \\theta_7^2 \\, \\left [ 1 + \\frac{r^2}{2\\,\\theta_8\\,\\theta_9}\n", " \\right ]^{-\\theta_8} \\\\\n", " k_4(r) &=& \\theta_{10}^2 \\, \\exp \\left(-\\frac{r^2}{2\\,\\theta_{11}} \\right)\n", " + \\theta_{12}^2\\,\\delta_{ij}\n", "\\end{eqnarray}\n", "$$\n", "\n", "We can implement this kernel in `tinygp` as follows (we'll use the R&W results as the hyperparameters for now):" ] }, { "cell_type": "code", "execution_count": null, "id": "4", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "from tinygp import kernels, GaussianProcess\n", "\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "\n", "def build_gp(theta, X):\n", " # We want most of our parameters to be positive so we take the `exp` here\n", " # Note that we're using `jnp` instead of `np`\n", " amps = jnp.exp(theta[\"log_amps\"])\n", " scales = jnp.exp(theta[\"log_scales\"])\n", "\n", " # Construct the kernel by multiplying and adding `Kernel` objects\n", " k1 = amps[0] * kernels.ExpSquared(scales[0])\n", " k2 = (\n", " amps[1]\n", " * kernels.ExpSquared(scales[1])\n", " * kernels.ExpSineSquared(\n", " scale=jnp.exp(theta[\"log_period\"]),\n", " gamma=jnp.exp(theta[\"log_gamma\"]),\n", " )\n", " )\n", " k3 = amps[2] * kernels.RationalQuadratic(\n", " alpha=jnp.exp(theta[\"log_alpha\"]), scale=scales[2]\n", " 
)\n", " k4 = amps[3] * kernels.ExpSquared(scales[3])\n", " kernel = k1 + k2 + k3 + k4\n", "\n", " return GaussianProcess(\n", " kernel, X, diag=jnp.exp(theta[\"log_diag\"]), mean=theta[\"mean\"]\n", " )\n", "\n", "\n", "def neg_log_likelihood(theta, X, y):\n", " gp = build_gp(theta, X)\n", " return -gp.log_probability(y)\n", "\n", "\n", "theta_init = {\n", " \"mean\": np.float64(340.0),\n", " \"log_diag\": np.log(0.19),\n", " \"log_amps\": np.log([66.0, 2.4, 0.66, 0.18]),\n", " \"log_scales\": np.log([67.0, 90.0, 0.78, 1.6]),\n", " \"log_period\": np.float64(0.0),\n", " \"log_gamma\": np.log(4.3),\n", " \"log_alpha\": np.log(1.2),\n", "}\n", "\n", "# `jax` can be used to differentiate functions, and also note that we're calling\n", "# `jax.jit` for the best performance.\n", "obj = jax.jit(jax.value_and_grad(neg_log_likelihood))\n", "\n", "print(f\"Initial negative log likelihood: {obj(theta_init, t, y)[0]}\")\n", "print(\n", " f\"Gradient of the negative log likelihood, wrt the parameters:\\n{obj(theta_init, t, y)[1]}\"\n", ")" ] }, { "cell_type": "markdown", "id": "5", "metadata": {}, "source": [ "Some things to note here:\n", "\n", "1. If you're new to `jax`, the way that I'm mixing `np` and `jnp` (the `jax` version of `numpy`) might seem a little confusing. In this example, I'm using regular `numpy` to load and prepare our dataset, and then using `jax.numpy` everywhere else. The important point is that within the `neg_log_likelihood` function (and all the functions it calls), `np` is never used.\n", "\n", "2. This pattern of writing a `build_gp` function is a pretty common workflow in these docs. It's useful to have a way of instantiating our GP model at a new set of parameters, as we'll see below when we plot the conditional model. 
This might seem a little strange if you're coming from other libraries (like `george`, for example), but if you `jit` the function (see below), each model evaluation won't actually instantiate all these classes, so you don't need to worry about performance implications. Check out {ref}`modeling` for some alternative workflows.\n", "\n", "3. Make sure that you remember to wrap your function in `jax.jit`. [The `jax` docs](https://jax.readthedocs.io) have more details about how this works, but for our purposes, the key thing is that this allows us to use the expressive `tinygp` kernel building syntax without worrying about the performance costs of all of these allocations.\n", "\n", "Using our loss function defined above, we'll run a gradient-based optimization routine from `jaxopt` to fit this model as follows:" ] }, { "cell_type": "code", "execution_count": null, "id": "6", "metadata": {}, "outputs": [], "source": [ "import jaxopt\n", "\n", "solver = jaxopt.ScipyMinimize(fun=neg_log_likelihood)\n", "soln = solver.run(theta_init, X=t, y=y)\n", "print(f\"Final negative log likelihood: {soln.state.fun_val}\")" ] }, { "cell_type": "markdown", "id": "7", "metadata": {}, "source": [ "**Warning:** *Optimization code like this should work on most problems, but the results can be very sensitive to your choice of initialization and algorithm. 
If the results are nonsense, try choosing a better initial guess or try a different value of the ``method`` parameter in ``jaxopt.ScipyMinimize``.*\n", "\n", "We can plot our prediction of the CO2 concentration into the future using our optimized Gaussian process model by running:" ] }, { "cell_type": "code", "execution_count": null, "id": "8", "metadata": {}, "outputs": [], "source": [ "x = np.linspace(max(t), 2025, 2000)\n", "gp = build_gp(soln.params, t)\n", "cond_gp = gp.condition(y, x).gp\n", "mu, var = cond_gp.loc, cond_gp.variance\n", "\n", "plt.plot(t, y, \".k\")\n", "plt.fill_between(x, mu + np.sqrt(var), mu - np.sqrt(var), color=\"C0\", alpha=0.5)\n", "plt.plot(x, mu, color=\"C0\", lw=2)\n", "\n", "plt.xlim(t.min(), 2025)\n", "plt.xlabel(\"year\")\n", "_ = plt.ylabel(\"CO$_2$ in ppm\")" ] }, { "cell_type": "code", "execution_count": null, "id": "9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.9" } }, "nbformat": 4, "nbformat_minor": 5 }