{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from neos.models import *\n", "from neos.makers import *\n", "from neos.transforms import *\n", "from neos.fit import *\n", "from neos.cls import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# neos\n", "\n", "> nice end-to-end optimized statistics ;)\n", "\n", "Leverages the shoulders of giants (pyhf, jax).\n", "\n", "Documentation can be found within the notebooks in the nbs folder, or (temporarily) at [phinate.github.io/neos](phinate.github.io/neos).\n", "\n", "To see examples of neos in action, look for the notebooks in the nbs folder with the `demo_` prefix." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![DOI](https://zenodo.org/badge/235776682.svg)](https://zenodo.org/badge/latestdoi/235776682) ![CI](https://github.com/pyhf/neos/workflows/CI/badge.svg) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/pyhf/neos/master?filepath=demo_training.ipynb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"neos" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](assets/pyhf_3.gif)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install\n", "\n", "Just run\n", "\n", "```\n", "python -m pip install neos\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Contributing\n", "\n", "**Please read** [`CONTRIBUTING.md`](https://github.com/pyhf/neos/blob/master/CONTRIBUTING.md) **before making a PR**, as this project is maintained using [`nbdev`](https://github.com/fastai/nbdev), which operates completely using Jupyter notebooks. One should make their changes in the corresponding notebooks in the [`nbs`](nbs) folder (including `README` changes -- see `nbs/index.ipynb`), and not in the library code, which is automatically generated." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example using a sigmoid-based neural network:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import neos.makers as makers\n", "import neos.cls as cls\n", "import numpy as np\n", "import jax.experimental.stax as stax\n", "import jax.experimental.optimizers as optimizers\n", "import jax.random\n", "import time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialise network using `jax.experimental.stax`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "init_random_params, predict = stax.serial(\n", " stax.Dense(1024),\n", " stax.Relu,\n", " stax.Dense(1024),\n", " stax.Relu,\n", " stax.Dense(2),\n", " stax.Softmax,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialse tools from `neos`:\n", "\n", "The way we initialise in `neos` is to define functions that make a statistical model from histograms, which in turn are themselves made from a predictive model, such as a neural network. Here's some detail on the unctions used below:\n", "\n", "- `hists_from_nn_three_blobs(predict)` uses the nn decision function `predict` defined in the cell above to form histograms from signal and background data, all drawn from multivariate normal distributions with different means. Two background distributions are sampled from, which is meant to mimic the situation in particle physics where one has a 'nominal' prediction for a nuisance parameter and then an alternate value (e.g. 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# histogram maker: network predictions -> signal/background histograms\n", "hmaker = makers.hists_from_nn_three_blobs(predict)\n", "# model maker: histograms -> pyhf-like statistical model\n", "nnm = makers.nn_hepdata_like(hmaker)\n", "# loss: differentiable CLs for a given network and signal strength mu\n", "loss = cls.cls_maker(nnm, solver_kwargs=dict(pdf_transform=True))" ] },
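{ "cell_type": "markdown", "metadata": {}, "source": [ "As promised above, here is a toy illustration of the fixed-point differentiation trick, written with `jax.custom_vjp` and the implicit function theorem rather than the actual `fax`/`neos.fit` machinery. If `x*` solves `x = f(x, theta)`, then `dx*/dtheta = (df/dtheta) / (1 - df/dx)` evaluated at `x*`, so gradients can flow through an iterative solve without unrolling it. The particular `f` and the iteration count below are assumptions made for the sketch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "def f(x, theta):\n", "    # a contraction in x, so naive iteration converges to the fixed point\n", "    return jnp.cos(x) * theta\n", "\n", "def _iterate(theta):\n", "    x = 0.0\n", "    for _ in range(100):\n", "        x = f(x, theta)\n", "    return x\n", "\n", "@jax.custom_vjp\n", "def solve(theta):\n", "    return _iterate(theta)\n", "\n", "def solve_fwd(theta):\n", "    x_star = _iterate(theta)\n", "    return x_star, (x_star, theta)\n", "\n", "def solve_bwd(res, g):\n", "    x_star, theta = res\n", "    # implicit function theorem at the fixed point:\n", "    # dx*/dtheta = (df/dtheta) / (1 - df/dx), evaluated at x*\n", "    dfdx = jax.grad(f, argnums=0)(x_star, theta)\n", "    dfdth = jax.grad(f, argnums=1)(x_star, theta)\n", "    return (g * dfdth / (1.0 - dfdx),)\n", "\n", "solve.defvjp(solve_fwd, solve_bwd)\n", "\n", "# the gradient flows through the solve without differentiating the loop\n", "print(jax.grad(solve)(0.5))" ] },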
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/phinate/envs/neos/lib/python3.7/site-packages/jax-0.1.59-py3.7.egg/jax/lib/xla_bridge.py:122: UserWarning: No GPU/TPU found, falling back to CPU.\n" ] } ], "source": [ "_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define the training loop!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt_init, opt_update, opt_params = optimizers.adam(1e-3)\n", "\n", "def update_and_value(i, opt_state, mu):\n", "    # one optimisation step that also returns the current loss (CLs) value\n", "    net = opt_params(opt_state)\n", "    value, grad = jax.value_and_grad(loss)(net, mu)\n", "    return opt_update(i, grad, opt_state), value, net\n", "\n", "def train_network(N):\n", "    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))\n", "    state = opt_init(network)\n", "    losses = []\n", "\n", "    for i in range(N):\n", "        start_time = time.time()\n", "        # mu = 1.0 is the tested value of the signal strength\n", "        state, value, network = update_and_value(i, state, 1.0)\n", "        epoch_time = time.time() - start_time\n", "        losses.append(value)\n", "        metrics = {\"loss\": losses}\n", "        yield network, metrics, epoch_time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Let's run it!!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch 0: CLs = 0.06680655092981347, took 5.355436325073242s\n", "epoch 1: CLs = 0.4853891149072429, took 1.5733795166015625s\n", "epoch 2: CLs = 0.3379355596004474, took 1.5171947479248047s\n", "epoch 3: CLs = 0.1821927415636535, took 1.5081253051757812s\n", "epoch 4: CLs = 0.09119136931683047, took 1.5193650722503662s\n", "epoch 5: CLs = 0.04530559823843272, took 1.5008423328399658s\n", "epoch 6: CLs = 0.022572851867672883, took 1.499192476272583s\n", "epoch 7: CLs = 0.013835564056077887, took 1.5843737125396729s\n", "epoch 8: CLs = 0.01322058601444187, took 1.520324468612671s\n", "epoch 9: CLs = 0.013407422454837725, took 1.5050244331359863s\n", "epoch 10: CLs = 0.011836452218993765, took 1.509469985961914s\n", "epoch 11: CLs = 0.00948507486266359, took 1.5089364051818848s\n", "epoch 12: CLs = 0.007350505632595539, took 1.5106918811798096s\n", "epoch 13: CLs = 0.005755974539907838, took 1.5267891883850098s\n", "epoch 14: CLs = 0.0046464301411786035, took 1.5851080417633057s\n", "epoch 15: CLs = 0.0038756402968267434, took 1.8452086448669434s\n", "epoch 16: CLs = 0.003323640670405803, took 1.9116990566253662s\n", "epoch 17: CLs = 0.0029133909840759475, took 1.7648999691009521s\n", "epoch 18: CLs = 0.002596946123608612, took 1.6314191818237305s\n", "epoch 19: CLs = 0.0023454051342963744, took 1.5911424160003662s\n" ] } ], "source": [ "maxN = 20  # make me bigger for better results!\n", "\n", "# Training\n", "for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):\n", "    print(f\"epoch {i}:\", f'CLs = {metrics[\"loss\"][-1]}, took {epoch_time}s')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And there we go!!\n", "\n", "If you want to reproduce the full animation, a version of this code with plotting helpers can be found in [`demo_training.ipynb`](https://github.com/pyhf/neos/blob/master/demo_training.ipynb)! :D" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Thanks\n", "\n", "A big thanks to the teams behind [`jax`](https://github.com/google/jax/), [`fax`](https://github.com/gehring/fax), and [`pyhf`](https://github.com/scikit-hep/pyhf) for their software and support." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }