{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# neos.data\n",
    "\n",
    "> Helper module to easily generate example data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def generate_blobs(\n",
    "    rng,\n",
    "    blobs,\n",
    "    NMC=500,\n",
    "    sig_mean=jnp.asarray([-1, 1]),\n",
    "    bup_mean=jnp.asarray([2.5, 2]),\n",
    "    bdown_mean=jnp.asarray([-2.5, -1.5]),\n",
    "    b_mean=jnp.asarray([1, -1]),\n",
    "):\n",
    "    \"\"\"\n",
    "    Function that returns a callable to generate a set of 2D normally\n",
    "    distributed blobs, corresponding to signal, background, and background\n",
    "    uncertainty modes.\n",
    "\n",
    "    Each blob is sampled with identity covariance using its own subkey\n",
    "    split off `rng`, so the blobs do not share the same underlying noise.\n",
    "    The subkeys are fixed when this function is called, so repeated calls\n",
    "    to the returned callable give the same samples.\n",
    "\n",
    "    Args:\n",
    "        rng: jax PRNG key (random seed).\n",
    "        blobs: Number of blobs to generate (3 or 4).\n",
    "        NMC: Number of 'monte carlo' samples to generate per blob.\n",
    "        sig_mean: jax array of the mean of the signal distribution.\n",
    "        bup_mean: jax array of the mean of the 'up' background distribution.\n",
    "        bdown_mean: jax array of the mean of the 'down' background distribution.\n",
    "        b_mean: jax array of the mean of the nominal background distribution\n",
    "            (only used when `blobs` is 4).\n",
    "\n",
    "    Returns:\n",
    "        A zero-argument callable returning a tuple of arrays of shape\n",
    "        (NMC, 2): `(sig, bkg_up, bkg_down)` for 3 blobs, or\n",
    "        `(sig, bkg_nom, bkg_up, bkg_down)` for 4 blobs.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `blobs` is neither 3 nor 4.\n",
    "    \"\"\"\n",
    "    # Means listed in the order the corresponding blobs are returned.\n",
    "    if blobs == 3:\n",
    "        means = (sig_mean, bup_mean, bdown_mean)\n",
    "    elif blobs == 4:\n",
    "        means = (sig_mean, b_mean, bup_mean, bdown_mean)\n",
    "    else:\n",
    "        # Raised explicitly (instead of `assert False`) so the check is not\n",
    "        # stripped under `python -O`; the exception type is unchanged.\n",
    "        raise AssertionError(\n",
    "            f\"Unsupported number of blobs: {blobs}\"\n",
    "            \" (only using 3 or 4 blobs for these examples).\"\n",
    "        )\n",
    "\n",
    "    # One subkey per blob: passing the same key to every draw would give all\n",
    "    # blobs identical noise, making them exact translations of each other.\n",
    "    keys = jax.random.split(rng, len(means))\n",
    "    cov = jnp.asarray([[1, 0], [0, 1]])\n",
    "\n",
    "    def gen_blobs():\n",
    "        return tuple(\n",
    "            jax.random.multivariate_normal(key, mean, cov, shape=(NMC,))\n",
    "            for key, mean in zip(keys, means)\n",
    "        )\n",
    "\n",
    "    return gen_blobs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Usage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(100, 2), (100, 2), (100, 2)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "from jax.random import PRNGKey\n",
    "\n",
    "import neos\n",
    "from neos.data import generate_blobs\n",
    "\n",
    "# Build the sampler first; calling it returns one array per blob.\n",
    "data_gen = generate_blobs(rng=PRNGKey(2), blobs=3, NMC=100)\n",
    "data = data_gen()\n",
    "\n",
    "# Show the shape of each generated blob.\n",
    "[blob.shape for blob in data]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same for 4 blobs, but half the number of samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(50, 2), (50, 2), (50, 2), (50, 2)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Request four blobs this time, with 50 samples in each.\n",
    "data_gen = generate_blobs(rng=PRNGKey(2), blobs=4, NMC=50)\n",
    "data = data_gen()\n",
    "\n",
    "# Show the shape of each generated blob.\n",
    "[blob.shape for blob in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}