{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# neos.data\n",
    "\n",
    "> Helper module to easily generate example data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def generate_blobs(\n",
    "    rng,\n",
    "    blobs,\n",
    "    NMC=500,\n",
    "    sig_mean=jnp.asarray([-1, 1]),\n",
    "    bup_mean=jnp.asarray([2.5, 2]),\n",
    "    bdown_mean=jnp.asarray([-2.5, -1.5]),\n",
    "    b_mean=jnp.asarray([1, -1]),\n",
    "):\n",
    "    \"\"\"\n",
    "    Function that returns a callable to generate a set of 2D normally\n",
    "    distributed blobs, corresponding to signal, background, and background\n",
    "    uncertainty modes.\n",
    "\n",
    "    Each blob is drawn from a 2D normal distribution with identity covariance\n",
    "    using its own subkey split from `rng`, so the blobs are independent of\n",
    "    each other. The subkeys are fixed when `generate_blobs` is called, so\n",
    "    every call of the returned callable yields the same samples.\n",
    "\n",
    "    Args:\n",
    "        rng: jax PRNG key (random seed).\n",
    "        blobs: Number of blobs to generate (3 or 4).\n",
    "        NMC: Number of 'monte carlo' samples to generate per blob.\n",
    "        sig_mean: jax array of the mean of the signal distribution.\n",
    "        bup_mean: jax array of the mean of the 'up' background distribution.\n",
    "        bdown_mean: jax array of the mean of the 'down' background distribution.\n",
    "        b_mean: jax array of the mean of the nominal background distribution\n",
    "            (only used when blobs == 4).\n",
    "\n",
    "    Returns:\n",
    "        A callable taking no arguments. For blobs == 3 it returns\n",
    "        (sig, bkg_up, bkg_down); for blobs == 4 it returns\n",
    "        (sig, bkg_nom, bkg_up, bkg_down). Each array has shape (NMC, 2).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `blobs` is neither 3 nor 4.\n",
    "    \"\"\"\n",
    "    if blobs not in (3, 4):\n",
    "        # Explicit raise instead of `assert` so the check survives `python -O`;\n",
    "        # the AssertionError type is kept for backward compatibility.\n",
    "        raise AssertionError(\n",
    "            f\"Unsupported number of blobs: {blobs}\"\n",
    "            \" (only using 3 or 4 blobs for these examples).\"\n",
    "        )\n",
    "\n",
    "    # One subkey per blob: drawing every blob with the same `rng` would give\n",
    "    # them all identical standard-normal noise, merely shifted by their means.\n",
    "    # Always split four ways so sig/up/down are the same in 3- and 4-blob mode.\n",
    "    k_sig, k_up, k_down, k_nom = jax.random.split(rng, 4)\n",
    "    cov = jnp.eye(2)\n",
    "\n",
    "    def _sample(key, mean):\n",
    "        # NMC draws from a 2D normal centered on `mean` with identity covariance.\n",
    "        return jax.random.multivariate_normal(key, mean, cov, shape=(NMC,))\n",
    "\n",
    "    def gen_blobs():\n",
    "        sig = _sample(k_sig, sig_mean)\n",
    "        bkg_up = _sample(k_up, bup_mean)\n",
    "        bkg_down = _sample(k_down, bdown_mean)\n",
    "        if blobs == 3:\n",
    "            return sig, bkg_up, bkg_down\n",
    "        bkg_nom = _sample(k_nom, b_mean)\n",
    "        return sig, bkg_nom, bkg_up, bkg_down\n",
    "\n",
    "    return gen_blobs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Usage:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(100, 2), (100, 2), (100, 2)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "from jax.random import PRNGKey\n",
    "\n",
    "import neos\n",
    "from neos.data import generate_blobs\n",
    "\n",
    "# initialize generator, then call for data\n",
    "data_gen = generate_blobs(rng=PRNGKey(2), blobs=3, NMC=100)\n",
    "data = data_gen()\n",
    "\n",
    "[x.shape for x in data]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Same for 4 blobs, but half the number of samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(50, 2), (50, 2), (50, 2), (50, 2)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# initialize generator, then call for data\n",
    "data_gen = generate_blobs(rng=PRNGKey(2), blobs=4, NMC=50)\n",
    "data = data_gen()\n",
    "\n",
    "[x.shape for x in data]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}