{ "cells": [ { "cell_type": "markdown", "id": "3644c83b", "metadata": {}, "source": [ "(stellar-variability)=\n", "\n", "# Gaussian process models for stellar variability" ] }, { "cell_type": "code", "execution_count": null, "id": "2001847a", "metadata": {}, "outputs": [], "source": [ "import exoplanet\n", "\n", "exoplanet.utils.docs_setup()\n", "print(f\"exoplanet.__version__ = '{exoplanet.__version__}'\")" ] }, { "cell_type": "markdown", "id": "8fdcfa7d", "metadata": {}, "source": [ "When fitting exoplanets, we also need to fit for the stellar variability, and Gaussian processes (GPs) are often a good descriptive model for this variation.\n", "[PyMC3 has support for all sorts of general GP models](https://docs.pymc.io/gp.html), but *exoplanet* interfaces with the [celerite2](https://celerite2.readthedocs.io/) library to provide support for scalable 1D GPs that can work with large datasets (take a look at the [Getting started](https://celerite2.readthedocs.io/en/latest/tutorials/first/) tutorial in the *celerite2* docs for a crash course).\n", "In this tutorial, we go through the process of modeling the light curve of a rotating star observed by TESS using *exoplanet* and *celerite2*.\n", "\n", "First, let's download and plot the data:" ] }, { "cell_type": "code", "execution_count": null, "id": "b9d6dcc9", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import lightkurve as lk\n", "import matplotlib.pyplot as plt\n", "\n", "lcf = lk.search_lightcurve(\n", " \"TIC 10863087\", mission=\"TESS\", author=\"SPOC\"\n", ").download_all(quality_bitmask=\"hardest\", flux_column=\"pdcsap_flux\")\n", "lc = lcf.stitch().remove_nans().remove_outliers()\n", "# Keep only the first 5000 cadences\n", "lc = lc[:5000]\n", "# Flag 3-sigma outliers relative to a flattened copy of the light curve\n", "_, mask = lc.flatten().remove_outliers(sigma=3.0, return_mask=True)\n", "lc = lc[~mask]\n", "\n", "# Convert to contiguous float64 arrays\n", "x = np.ascontiguousarray(lc.time.value, dtype=np.float64)\n", "y = np.ascontiguousarray(lc.flux, dtype=np.float64)\n", "yerr = np.ascontiguousarray(lc.flux_err, dtype=np.float64)\n", "# Normalize the flux to relative units of parts per thousand (ppt)\n", "mu 
= np.mean(y)\n", "y = (y / mu - 1) * 1e3\n", "yerr = yerr * 1e3 / mu\n", "\n", "plt.plot(x, y, \"k\")\n", "plt.xlim(x.min(), x.max())\n", "plt.xlabel(\"time [days]\")\n", "plt.ylabel(\"relative flux [ppt]\")\n", "_ = plt.title(\"TIC 10863087\")" ] }, { "cell_type": "markdown", "id": "6ce25da5", "metadata": {}, "source": [ "## A Gaussian process model for stellar variability\n", "\n", "This looks like the light curve of a rotating star, and [it has been shown](https://arxiv.org/abs/1706.05459) that it is possible to model this variability using a quasiperiodic Gaussian process.\n", "To start with, let's get an estimate of the rotation period using the Lomb-Scargle periodogram:" ] }, { "cell_type": "code", "execution_count": null, "id": "23b731ad", "metadata": {}, "outputs": [], "source": [ "import exoplanet as xo\n", "\n", "# Find the strongest periodogram peak at periods between 0.1 and 2 days\n", "results = xo.estimators.lomb_scargle_estimator(\n", " x, y, max_peaks=1, min_period=0.1, max_period=2.0, samples_per_peak=50\n", ")\n", "\n", "peak = results[\"peaks\"][0]\n", "freq, power = results[\"periodogram\"]\n", "plt.plot(1 / freq, power, \"k\")\n", "# Mark the period of the strongest peak\n", "plt.axvline(peak[\"period\"], color=\"k\", lw=4, alpha=0.3)\n", "plt.xlim((1 / freq).min(), (1 / freq).max())\n", "plt.yticks([])\n", "plt.xlabel(\"period [days]\")\n", "_ = plt.ylabel(\"power\")" ] }, { "cell_type": "markdown", "id": "8890345c", "metadata": {}, "source": [ "Now, using this initialization, we can set up the GP model in *exoplanet* and *celerite2*.\n", "We'll use the [RotationTerm](https://celerite2.readthedocs.io/en/latest/api/python/#celerite2.terms.RotationTerm) kernel, which is a mixture of two simple harmonic oscillators with periods separated by a factor of two.\n", "As you can see from the periodogram above, this might be a good model for this light curve, and I've found that it works well in many cases." 
] }, { "cell_type": "code", "execution_count": null, "id": "71f77026", "metadata": {}, "outputs": [], "source": [ "import pymc3 as pm\n", "import pymc3_ext as pmx\n", "import aesara_theano_fallback.tensor as tt\n", "from celerite2.theano import terms, GaussianProcess\n", "\n", "with pm.Model() as model:\n", " # The mean flux of the time series\n", " mean = pm.Normal(\"mean\", mu=0.0, sigma=10.0)\n", "\n", " # A jitter term describing excess white noise\n", " log_jitter = pm.Normal(\"log_jitter\", mu=np.log(np.mean(yerr)), sigma=2.0)\n", "\n", " # A term to describe the non-periodic variability\n", " sigma = pm.InverseGamma(\n", " \"sigma\", **pmx.estimate_inverse_gamma_parameters(1.0, 5.0)\n", " )\n", " rho = pm.InverseGamma(\n", " \"rho\", **pmx.estimate_inverse_gamma_parameters(0.5, 2.0)\n", " )\n", "\n", " # The parameters of the RotationTerm kernel\n", " sigma_rot = pm.InverseGamma(\n", " \"sigma_rot\", **pmx.estimate_inverse_gamma_parameters(1.0, 5.0)\n", " )\n", " # Center the period prior on the Lomb-Scargle estimate\n", " log_period = pm.Normal(\"log_period\", mu=np.log(peak[\"period\"]), sigma=2.0)\n", " period = pm.Deterministic(\"period\", tt.exp(log_period))\n", " # HalfNormal restricts log_Q0 to positive values, i.e. Q0 > 1\n", " log_Q0 = pm.HalfNormal(\"log_Q0\", sigma=2.0)\n", " log_dQ = pm.Normal(\"log_dQ\", mu=0.0, sigma=2.0)\n", " f = pm.Uniform(\"f\", lower=0.1, upper=1.0)\n", "\n", " # Set up the Gaussian Process model\n", " kernel = terms.SHOTerm(sigma=sigma, rho=rho, Q=1 / 3.0)\n", " kernel += terms.RotationTerm(\n", " sigma=sigma_rot,\n", " period=period,\n", " Q0=tt.exp(log_Q0),\n", " dQ=tt.exp(log_dQ),\n", " f=f,\n", " )\n", " gp = GaussianProcess(\n", " kernel,\n", " t=x,\n", " diag=yerr**2 + tt.exp(2 * log_jitter),\n", " mean=mean,\n", " quiet=True,\n", " )\n", "\n", " # Compute the Gaussian Process likelihood and add it into\n", " # the PyMC3 model as a \"potential\"\n", " gp.marginal(\"gp\", observed=y)\n", "\n", " # Compute the mean model prediction for plotting purposes\n", " pm.Deterministic(\"pred\", gp.predict(y))\n", "\n", " # Optimize to find the 
maximum a posteriori parameters\n", " map_soln = pmx.optimize()" ] }, { "cell_type": "markdown", "id": "10672ab7", "metadata": {}, "source": [ "Now that we have the model set up, let's plot the maximum a posteriori model prediction." ] }, { "cell_type": "code", "execution_count": null, "id": "811f3ad6", "metadata": {}, "outputs": [], "source": [ "plt.plot(x, y, \"k\", label=\"data\")\n", "plt.plot(x, map_soln[\"pred\"], color=\"C1\", label=\"model\")\n", "plt.xlim(x.min(), x.max())\n", "plt.legend(fontsize=10)\n", "plt.xlabel(\"time [days]\")\n", "plt.ylabel(\"relative flux [ppt]\")\n", "_ = plt.title(\"TIC 10863087; map model\")" ] }, { "cell_type": "markdown", "id": "50ef0408", "metadata": {}, "source": [ "That looks pretty good!\n", "Now let's sample from the posterior using [the PyMC3 Extras (pymc3-ext) library](https://github.com/exoplanet-dev/pymc3-ext):" ] }, { "cell_type": "code", "execution_count": null, "id": "bb91e006", "metadata": {}, "outputs": [], "source": [ "with model:\n", " trace = pmx.sample(\n", " tune=1000,\n", " draws=1000,\n", " start=map_soln,\n", " cores=2,\n", " chains=2,\n", " target_accept=0.9,\n", " return_inferencedata=True,\n", " random_seed=[10863087, 10863088],\n", " )" ] }, { "cell_type": "markdown", "id": "c94d95a1", "metadata": {}, "source": [ "Now we can do the usual convergence checks:" ] }, { "cell_type": "code", "execution_count": null, "id": "15ca0850", "metadata": {}, "outputs": [], "source": [ "import arviz as az\n", "\n", "az.summary(\n", " trace,\n", " var_names=[\n", " \"f\",\n", " \"log_dQ\",\n", " \"log_Q0\",\n", " \"log_period\",\n", " \"sigma_rot\",\n", " \"rho\",\n", " \"sigma\",\n", " \"log_jitter\",\n", " \"mean\",\n", " ],\n", ")" ] }, { "cell_type": "markdown", "id": "ccf56286", "metadata": {}, "source": [ "And plot the posterior distribution over rotation period:" ] }, { "cell_type": "code", "execution_count": null, "id": "6d33dc06", "metadata": {}, "outputs": [], "source": [ "period_samples = 
np.asarray(trace.posterior[\"period\"]).flatten()\n", "plt.hist(period_samples, 25, histtype=\"step\", color=\"k\", density=True)\n", "plt.yticks([])\n", "plt.xlabel(\"rotation period [days]\")\n", "_ = plt.ylabel(\"posterior density\")" ] }, { "cell_type": "markdown", "id": "d9026bbe", "metadata": {}, "source": [ "## Citations\n", "\n", "As described in the [citation tutorial](https://docs.exoplanet.codes/en/stable/tutorials/citation/), we can use [citations.get_citations_for_model](https://docs.exoplanet.codes/en/stable/user/api/#exoplanet.citations.get_citations_for_model) to construct an acknowledgement and BibTeX listing that includes the relevant citations for this model." ] }, { "cell_type": "code", "execution_count": null, "id": "dc94e171", "metadata": {}, "outputs": [], "source": [ "with model:\n", " txt, bib = xo.citations.get_citations_for_model()\n", "print(txt)" ] }, { "cell_type": "code", "execution_count": null, "id": "f352af9f", "metadata": {}, "outputs": [], "source": [ "print(bib.split(\"\\n\\n\")[0] + \"\\n\\n...\")" ] }, { "cell_type": "code", "execution_count": null, "id": "c5c2040e-5175-4ead-8d07-88c5a56400db", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.4" } }, "nbformat": 4, "nbformat_minor": 5 }