{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Eclipsing binary: Linear solution for the maps"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we're going to use a linear solve to infer the surface maps of two stars in an eclipsing binary given the light curve of the system. We generated the data in [this notebook](EclipsingBinary_Generate.ipynb). This is a follow-up to the [notebook](EclipsingBinary_PyMC3.ipynb) in which we solved the system using `pymc3`. Because the `starry` flux model is linear in the spherical harmonic coefficients, we can actually solve the same problem *analytically*, with no sampling and in very little time.\n",
    "\n",
    "Let's begin with some imports. Note that we're again disabling `lazy` evaluation to make things a bit easier, although this notebook would also work with it enabled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide_input"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide_input"
    ]
   },
   "outputs": [],
   "source": [
    "%run notebook_setup.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import starry\n",
    "from scipy.linalg import cho_solve\n",
    "from corner import corner\n",
    "\n",
    "np.random.seed(12)\n",
    "starry.config.lazy = False\n",
    "starry.config.quiet = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data\n",
    "\n",
    "Let's load the EB dataset as before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide_input",
     "hide_output"
    ]
   },
   "outputs": [],
   "source": [
    "# Run the Generate notebook if needed\n",
    "if not os.path.exists(\"eb.npz\"):\n",
    "    import nbformat\n",
    "    from nbconvert.preprocessors import ExecutePreprocessor\n",
    "\n",
    "    with open(\"EclipsingBinary_Generate.ipynb\") as f:\n",
    "        nb = nbformat.read(f, as_version=4)\n",
    "    ep = ExecutePreprocessor(timeout=600, kernel_name=\"python3\")\n",
    "    ep.preprocess(nb);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = np.load(\"eb.npz\", allow_pickle=True)\n",
    "A = data[\"A\"].item()\n",
    "B = data[\"B\"].item()\n",
    "t = data[\"t\"]\n",
    "flux = data[\"flux\"]\n",
    "sigma = data[\"sigma\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the primary, secondary, and system objects. As before, we assume we know the true values of all the orbital parameters and star properties, *except* for the two surface maps. (If you just read the [PyMC3 notebook](EclipsingBinary_PyMC3.ipynb), note that we're no longer instantiating these within a `pymc3.Model` context.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Primary\n",
    "pri = starry.Primary(\n",
    "    starry.Map(ydeg=A[\"ydeg\"], udeg=A[\"udeg\"], inc=A[\"inc\"]),\n",
    "    r=A[\"r\"],\n",
    "    m=A[\"m\"],\n",
    "    prot=A[\"prot\"],\n",
    ")\n",
    "pri.map[1:] = A[\"u\"]\n",
    "\n",
    "# Secondary\n",
    "sec = starry.Secondary(\n",
    "    starry.Map(ydeg=B[\"ydeg\"], udeg=B[\"udeg\"], inc=B[\"inc\"]),\n",
    "    r=B[\"r\"],\n",
    "    m=B[\"m\"],\n",
    "    porb=B[\"porb\"],\n",
    "    prot=B[\"prot\"],\n",
    "    t0=B[\"t0\"],\n",
    "    inc=B[\"inc\"],\n",
    ")\n",
    "sec.map[1:] = B[\"u\"]\n",
    "\n",
    "# System\n",
    "sys = starry.System(pri, sec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the light curve we're going to do inference on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(1, figsize=(12, 5))\n",
    "ax.plot(t, flux, \"k.\", alpha=0.5, ms=4)\n",
    "ax.set_xlabel(\"time [days]\", fontsize=24)\n",
    "ax.set_ylabel(\"normalized flux\", fontsize=24);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Linear solve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "hide_input",
     "hide_output"
    ]
   },
   "outputs": [],
   "source": [
    "# HACK: Pre-compile the solve function\n",
    "# to get an accurate timing test below!\n",
    "sys.set_data(np.array([0.0]), C=1.0)\n",
    "pri.map.set_prior(L=1)\n",
    "sec.map.set_prior(L=1)\n",
    "sys.solve(t=np.array([0.0]));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to compute the posterior over maps, we need a prior for the spherical harmonic coefficients of each star. The linear solve **requires** Gaussian priors on the spherical harmonic coefficients; these are specified in terms of a mean vector $\\mu$ (``mu``) and a covariance matrix $\\Lambda$ (``L``). Recall that this is similar to what we did in the [PyMC3 notebook](EclipsingBinary_PyMC3.ipynb).\n",
    "\n",
    "It is important to note that when using the linear solve feature in ``starry``, the prior is actually placed on the \n",
    "**amplitude-weighted** spherical harmonic coefficients. \n",
    "In other words, if $\\alpha$ is the map amplitude (``map.amp``) and $y$ is the vector of spherical harmonic coefficients (``map.y``), we place a prior on the quantity $x \\equiv \\alpha y$. While this may be confusing at first, recall that the coefficient of the $Y_{0,0}$ harmonic is always **fixed at unity** in ``starry``, so we can't really solve for it. But we *can* solve for all elements of the vector $x$. Once we have the posterior for $x$, we can easily obtain both the amplitude (equal to $x_0$) and the spherical harmonic coefficient vector (equal to $x / x_0$). This allows us to simultaneously obtain both the amplitude and the coefficients using a single efficient linear solve.\n",
    "\n",
    "Because of this convention, the first element of the mean and the first row/column of the covariance are special: they control the amplitude of the map. For maps whose baseline has been properly normalized, the mean $\\mu_\\alpha$ of this term should be equal to (or close to) one. Its variance $\\lambda_\\alpha$ (the first diagonal entry of the covariance) is the square of the uncertainty on the amplitude of the map.\n",
    "\n",
    "The remaining elements are the prior on the $l>0$ spherical harmonic coefficients, weighted by the amplitude. For these, the easiest kind of prior we can place is an isotropic prior (no preferred direction), in which $\\mu = 0$ and the corresponding block of $\\Lambda$ is a diagonal matrix. In this case, the diagonal entries of $\\Lambda$ are related to the power spectrum of the map. We'll discuss this in more detail later, but for now let's assume a flat power spectrum, in which there is no preferred scale, so that this block of $\\Lambda$ is simply $\\lambda I$. The quantity $\\lambda$ is essentially a regularization parameter, whose magnitude controls the relative weighting of the data and the prior in determining the posterior.\n",
    "\n",
    "For definiteness, we'll choose $\\mu_\\alpha = 1$ and $\\lambda_\\alpha = \\lambda = 10^{-2}$ for the primary and $\\mu_\\alpha = 0.1$ and $\\lambda_\\alpha = \\lambda = 10^{-4}$ for the secondary (i.e., we assume we know the secondary has one-tenth the luminosity of the primary, but we allow for some uncertainty in that value). Readers are encouraged to experiment with different values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prior on primary\n",
    "pri_mu = np.zeros(pri.map.Ny)\n",
    "pri_mu[0] = 1.0\n",
    "pri_L = np.zeros(pri.map.Ny)\n",
    "pri_L[0] = 1e-2\n",
    "pri_L[1:] = 1e-2\n",
    "pri.map.set_prior(mu=pri_mu, L=pri_L)\n",
    "\n",
    "# Prior on secondary\n",
    "sec_mu = np.zeros(sec.map.Ny)\n",
    "sec_mu[0] = 0.1\n",
    "sec_L = np.zeros(sec.map.Ny)\n",
    "sec_L[0] = 1e-4\n",
    "sec_L[1:] = 1e-4\n",
    "sec.map.set_prior(mu=sec_mu, L=sec_L)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(Note that `L` may be a scalar, vector, or matrix, and `starry` will construct the covariance matrix for you. Alternatively, users may instead specify `cho_L`, the Cholesky factorization of the covariance matrix.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we specify the data and data covariance $C$ (the measurement uncertainty):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.set_data(flux, C=sigma ** 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(As before, users can pass a scalar, vector, or matrix as the data covariance, or the Cholesky factorization `cho_C`.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, let's solve the linear problem! We do this by calling `sys.solve()` and passing the array of times at which to evaluate the light curve. The method returns the posterior mean $\\mu$ and the Cholesky factorization of the posterior covariance for the amplitude-weighted coefficients of all bodies in the system, stacked into a single vector (the primary's coefficients first, followed by the secondary's). Let's time how long this takes:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "mu, cho_cov = sys.solve(t=t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The linear solve is **extremely fast**! Note that once we run the `solve` method, we can call the `draw` method to draw samples from the posterior. Let's do that and visualize a random sample from each map:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.draw()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pri.map.show(theta=np.linspace(0, 360, 50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sec.map.show(theta=np.linspace(0, 360, 50))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can compare these maps to the true maps:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# true values\n",
    "pri.map[1:, :] = A[\"y\"]\n",
    "pri.map.amp = A[\"amp\"]\n",
    "pri_true = pri.map.render(projection=\"rect\")\n",
    "sec.map[1:, :] = B[\"y\"]\n",
    "sec.map.amp = B[\"amp\"]\n",
    "sec_true = sec.map.render(projection=\"rect\")\n",
    "\n",
    "# mean values\n",
    "pri.map.amp = mu[0]\n",
    "pri.map[1:, :] = mu[1 : pri.map.Ny] / pri.map.amp\n",
    "pri_mu = pri.map.render(projection=\"rect\")\n",
    "sec.map.amp = mu[pri.map.Ny]\n",
    "sec.map[1:, :] = mu[pri.map.Ny + 1 :] / sec.map.amp\n",
    "sec_mu = sec.map.render(projection=\"rect\")\n",
    "\n",
    "# a random draw\n",
    "sys.draw()\n",
    "pri_draw = pri.map.render(projection=\"rect\")\n",
    "sec_draw = sec.map.render(projection=\"rect\")\n",
    "\n",
    "fig, ax = plt.subplots(3, 2, figsize=(8, 7))\n",
    "ax[0, 0].imshow(\n",
    "    pri_true,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.4,\n",
    ")\n",
    "ax[1, 0].imshow(\n",
    "    pri_mu,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.4,\n",
    ")\n",
    "ax[2, 0].imshow(\n",
    "    pri_draw,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.4,\n",
    ")\n",
    "ax[0, 1].imshow(\n",
    "    sec_true,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.04,\n",
    ")\n",
    "ax[1, 1].imshow(\n",
    "    sec_mu,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.04,\n",
    ")\n",
    "ax[2, 1].imshow(\n",
    "    sec_draw,\n",
    "    origin=\"lower\",\n",
    "    extent=(-180, 180, -90, 90),\n",
    "    cmap=\"plasma\",\n",
    "    vmin=0,\n",
    "    vmax=0.04,\n",
    ")\n",
    "ax[0, 0].set_title(\"primary\")\n",
    "ax[0, 1].set_title(\"secondary\")\n",
    "ax[0, 0].set_ylabel(\"true\", rotation=0, labelpad=20)\n",
    "ax[1, 0].set_ylabel(\"mean\", rotation=0, labelpad=20)\n",
    "ax[2, 0].set_ylabel(\"draw\", rotation=0, labelpad=20);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad! Also note how similar these are to the results we got in the [PyMC3 notebook](EclipsingBinary_PyMC3.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The other thing we can do is draw samples from this solution and make the traditional corner plot for the posterior. Armed with the posterior mean `mu` and the Cholesky factorization of the covariance `cho_cov`, this is [super easy to do](https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Drawing_values_from_the_distribution). Let's generate 10,000 samples from the joint posterior over the surface maps of both stars:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nsamples = 10000\n",
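    "# Standard normal draws, one column per sample\n",
    "u = np.random.randn(len(mu), nsamples)\n",
    "# Shift and scale the draws: x = mu + cho_cov . u (one sample per row)\n",
    "samples = mu.reshape(1, -1) + np.dot(cho_cov, u).T"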
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's the posterior for the amplitude and the first eight $l > 0$ coefficients of the primary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(9, 9, figsize=(7, 7))\n",
    "labels = [r\"$\\alpha$\"] + [\n",
    "    r\"$Y_{%d,%d}$\" % (l, m)\n",
    "    for l in range(1, pri.map.ydeg + 1)\n",
    "    for m in range(-l, l + 1)\n",
    "]\n",
    "\n",
    "# De-weight the samples to get\n",
    "# samples of the actual Ylm coeffs\n",
    "samps = np.array(samples[:, :9])\n",
    "samps[:, 1:] /= samps[:, 0].reshape(-1, 1)\n",
    "\n",
    "corner(samps, fig=fig, labels=labels)\n",
    "for axis in ax.flatten():\n",
    "    axis.xaxis.set_tick_params(labelsize=6)\n",
    "    axis.yaxis.set_tick_params(labelsize=6)\n",
    "    axis.xaxis.label.set_size(12)\n",
    "    axis.yaxis.label.set_size(12)\n",
    "    axis.xaxis.set_label_coords(0.5, -0.6)\n",
    "    axis.yaxis.set_label_coords(-0.6, 0.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that these posterior covariances match the ones we got in the [PyMC3 notebook](EclipsingBinary_PyMC3.ipynb), to within sampling error!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, just to recap: the spherical harmonic coefficients can be computed *linearly* given a light curve, provided we know everything else about the system. In most realistic cases we don't know the orbital parameters, limb darkening coefficients, etc. exactly, so the thing to do is to *combine* the linear solve with `pymc3` sampling. We'll do that in the [next notebook](EclipsingBinary_FullSolution.ipynb)."
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}