{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hide\n", "from neos.models import *\n", "from neos.makers import *\n", "from neos.transforms import *\n", "from neos.fit import *\n", "from neos.infer import *\n", "from neos.smooth import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# neos\n", "\n", "> ~~neural~~ nice end-to-end optimized statistics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![DOI](https://zenodo.org/badge/235776682.svg)](https://zenodo.org/badge/latestdoi/235776682) ![CI](https://github.com/pyhf/neos/workflows/CI/badge.svg) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/pyhf/neos/master?filepath=demo_training.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](assets/pyhf_3.gif)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## About\n", "\n", "Leverages the shoulders of giants ([`jax`](https://github.com/google/jax/), [`fax`](https://github.com/gehring/fax), and [`pyhf`](https://github.com/scikit-hep/pyhf)) to differentiate through a high-energy physics analysis workflow, including the construction of the frequentist profile likelihood.\n", "\n", "Documentation can be found at [gradhep.github.io/neos](https://gradhep.github.io/neos)!\n", "\n", "To see examples of `neos` in action, look for the notebooks in the [`nbs`](nbs) folder with the `demo_` prefix.\n", "\n", "If you're more of a video person, see [this talk](https://www.youtube.com/watch?v=3P4ZDkbleKs) given by [Nathan](https://github.com/phinate) on the broader topic of differentiable programming in high-energy physics, which also covers `neos`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install\n", "\n", "Just run\n", "\n", "```\n", "python -m pip install neos\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Contributing\n", "\n", "**Please read** [`CONTRIBUTING.md`](https://github.com/pyhf/neos/blob/master/CONTRIBUTING.md) **before making a PR**, as this project is maintained using [`nbdev`](https://github.com/fastai/nbdev), which operates completely using Jupyter notebooks. Make your changes in the corresponding notebooks in the [`nbs`](nbs) folder (this includes `README` changes -- see `nbs/index.ipynb`), not in the library code, which is automatically generated."
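] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, a typical `nbdev` edit cycle looks roughly like the sketch below. This assumes an nbdev 1.x setup, so treat it as a guideline and defer to [`CONTRIBUTING.md`](https://github.com/pyhf/neos/blob/master/CONTRIBUTING.md) for the project's actual commands:\n", "\n", "```\n", "# edit the notebooks in nbs/, then regenerate the library and docs\n", "nbdev_build_lib\n", "nbdev_build_docs\n", "# run the notebook test suite before opening a PR\n", "nbdev_test_nbs\n", "```"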
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example usage -- train a neural network to optimize an expected p-value" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# bunch of imports:\n", "import time\n", "\n", "import jax\n", "import jax.experimental.optimizers as optimizers\n", "import jax.experimental.stax as stax\n", "import jax.random\n", "from jax.random import PRNGKey\n", "import numpy as np\n", "from functools import partial\n", "\n", "import pyhf\n", "\n", "pyhf.set_backend(\"jax\")\n", "pyhf.default_backend = pyhf.tensor.jax_backend(precision=\"64b\")\n", "\n", "from neos import data, infer, makers\n", "\n", "rng = PRNGKey(22)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start by making a basic neural network for regression with the `stax` module found in `jax`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "init_random_params, predict = stax.serial(\n", "    stax.Dense(1024),\n", "    stax.Relu,\n", "    stax.Dense(1024),\n", "    stax.Relu,\n", "    stax.Dense(1),\n", "    stax.Sigmoid,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's compose a workflow that can make use of this network in a typical high-energy physics statistical analysis.\n", "\n", "Our workflow is as follows:\n", "- From a set of normal distributions with different means, we'll generate four blobs of `(x,y)` points, corresponding to a signal process, a nominal background process, and two variations of the background from varying the background distribution's mean up and down.\n", "- We'll then feed these points into the previously defined neural network for each blob, and construct a histogram of the output using kernel density estimation. The difference between the two background variations is used as a systematic uncertainty on the nominal background.\n", "- We can then leverage the magic of `pyhf` to construct an [event-counting statistical model](https://scikit-hep.org/pyhf/intro.html#histfactory) from the histogram yields.\n", "- Finally, we calculate the p-value of a test between the nominal signal and background-only hypotheses. This uses a [profile likelihood-based test statistic](https://arxiv.org/abs/1007.1727).\n", "\n", "In code, `neos` can specify this workflow through function composition:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# data generator\n", "data_gen = data.generate_blobs(rng, blobs=4)\n", "# histogram maker\n", "hist_maker = makers.hists_from_nn(data_gen, predict, method=\"kde\")\n", "# statistical model maker\n", "model_maker = makers.histosys_model_from_hists(hist_maker)\n", "# CLs value getter\n", "get_cls = infer.expected_CLs(model_maker, solver_kwargs=dict(pdf_transform=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A peculiarity to note is that each of the functions used in this step actually returns a function itself. We do this because we need a skeleton of the workflow, with all of its fixed parameters in place, before calculating the loss function -- the only 'moving parts' here are the weights of the neural network. A small standalone sketch of this pattern is shown below." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`neos` also lets you specify hyperparameters for the histograms (e.g. binning, bandwidth) so that they can be tuned throughout the learning process if necessary (we don't do that here)."
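] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make that 'functions returning functions' pattern a little more concrete, here is a minimal, hypothetical sketch (the names below are invented for illustration and are not part of the `neos` API): everything that stays fixed is captured once, and the returned callable only takes the moving parts -- the network weights -- as its argument." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# hypothetical illustration of the closure pattern used by the makers above;\n", "# make_toy_workflow is not part of neos -- it just mirrors the structure\n", "import jax.numpy as jnp\n", "\n", "\n", "def make_toy_workflow(fixed_inputs, predict_fn):\n", "    # the data and the network's apply function are frozen into the closure\n", "    def toy_loss(nn_params):\n", "        # only the network weights vary between calls\n", "        scores = predict_fn(nn_params, fixed_inputs)\n", "        return jnp.mean(scores)  # stand-in for the real CLs computation\n", "\n", "    return toy_loss\n", "\n", "\n", "# usage (with `predict` from above and the weights `network` initialized below):\n", "# toy_loss = make_toy_workflow(jnp.zeros((10, 2)), predict)\n", "# jax.grad(toy_loss)(network)  # differentiable in the network weights"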
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bins = np.linspace(0, 1, 4)  # three bins in the range [0,1]\n", "bandwidth = 0.27  # smoothing parameter\n", "get_loss = partial(get_cls, hyperparams=dict(bins=bins, bandwidth=bandwidth))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our loss currently returns a list of metrics -- let's just index into the first one (the CLs value)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss(params, test_mu):\n", "    return get_loss(params, test_mu)[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we just need to initialize the network's weights, and construct a training loop & optimizer:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# init weights\n", "_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))\n", "\n", "# init optimizer\n", "opt_init, opt_update, opt_params = optimizers.adam(1e-3)\n", "\n", "# define train loop\n", "def train_network(N):\n", "    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))\n", "    state = opt_init(network)\n", "    losses = []\n", "\n", "    # parameter update function\n", "    def update_and_value(i, opt_state, mu):\n", "        net = opt_params(opt_state)\n", "        value, grad = jax.value_and_grad(loss)(net, mu)\n", "        return opt_update(i, grad, opt_state), value, net\n", "\n", "    for i in range(N):\n", "        start_time = time.time()\n", "        state, value, network = update_and_value(i, state, 1.0)\n", "        epoch_time = time.time() - start_time\n", "        losses.append(value)\n", "        metrics = {\"loss\": losses}\n", "        yield network, metrics, epoch_time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's time to train!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: CLs = 0.06885, took 13.4896s\n", "epoch 1: CLs = 0.03580, took 1.9772s\n", "epoch 2: CLs = 0.01728, took 1.9912s\n", "epoch 3: CLs = 0.00934, took 1.9947s\n", "epoch 4: CLs = 0.00561, took 1.9548s\n", "epoch 5: CLs = 0.00378, took 1.9761s\n", "epoch 6: CLs = 0.00280, took 1.9500s\n", "epoch 7: CLs = 0.00224, took 1.9844s\n", "epoch 8: CLs = 0.00190, took 1.9913s\n", "epoch 9: CLs = 0.00168, took 1.9928s\n" ] } ], "source": [ "maxN = 10  # make me bigger for better results (*nearly* true ;])\n", "\n", "for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):\n", "    print(f\"epoch {i}:\", f'CLs = {metrics[\"loss\"][-1]:.5f}, took {epoch_time:.4f}s')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And there we go!\n", "\n", "You'll notice that the first epoch takes much longer than the rest. This is because `jax` just-in-time compiles parts of the workflow, which is a one-time overhead.\n", "\n", "If you want to reproduce the full animation from the top of this README, a version of this code with plotting helpers can be found in [`demo_kde_pyhf.ipynb`](https://github.com/pyhf/neos/blob/master/demo_kde_pyhf.ipynb)! :D" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Thanks\n", "\n", "A big thanks to the teams behind [`jax`](https://github.com/google/jax/), [`fax`](https://github.com/gehring/fax), and [`pyhf`](https://github.com/scikit-hep/pyhf) for their software and support."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }