{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# How to reproduce (and play with) `neos`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first need to install a special branch of `pyhf`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install git+https://github.com/scikit-hep/pyhf.git@make_difffable_model_ctor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Start with a couple of imports:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from jax.example_libraries import stax # neural network library for JAX\n", "from jax.random import PRNGKey # random number generator\n", "import jax.numpy as jnp # JAX's numpy\n", "import neos # :)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`neos` experiments are designed to run through a flexible `Pipeline` class, which composes the ingredients needed to train a differentiable analysis end-to-end.\n", "\n", "We have other examples in the works, but for now, our current experiments are wrapped up in a module called `nn_observable`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from neos.experiments.nn_observable import (\n", " nn_summary_stat, # create a summary statistic from a neural network\n", " make_model, # use the summary statistic to make a HistFactory-style model\n", " generate_data, # generate Gaussian blobs to feed into the NN\n", " first_epoch, # special plotting callback for the first epoch\n", " last_epoch, # special plotting callback for the last epoch\n", " per_epoch, # generic plotting callback for each epoch\n", " plot_setup, # initial setup for the plotting\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of these functions is pretty lightweight (with the exception of the plotting), so if you want to get experimental and write your own pipeline, 
you'll find the code for those functions a good starting point!\n", "\n", "Now we'll jump into training! First, we set up a neural network (for regression) and a random state:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "rng_state = 0 # random state\n", "\n", "# feel free to modify :)\n", "init_random_params, nn = stax.serial(\n", " stax.Dense(1024),\n", " stax.Relu,\n", " stax.Dense(1024),\n", " stax.Relu,\n", " stax.Dense(1),\n", " stax.Sigmoid, # squash the output into [0, 1], the range that gets binned\n", ")\n", "\n", "# input shape (-1, 2): any batch size, 2 features per data point\n", "_, init_pars = init_random_params(PRNGKey(rng_state), (-1, 2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From there, we compose our pipeline from the relevant ingredients. I'll point out the things you can play with right away in the comments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = neos.Pipeline(\n", " data=generate_data(rng=rng_state, num_points=10000), # total number of points\n", " yield_kwargs=dict(\n", " nn=nn, # the NN we defined above\n", " bandwidth=1e-1, # bandwidth of the KDE (lower = more like a real histogram)\n", " bins=jnp.linspace(0, 1, 5), # 5 bin edges over [0, 1], i.e. 4 bins for the summary stat\n", " ),\n", " loss=lambda x: x[\"CLs\"], # train by minimizing CLs\n", " num_epochs=10, # number of epochs\n", " batch_size=500, # number of points per batch\n", " plotname=\"demo_nn_observable.png\", # save the final plot!\n", " animate=True, # make cool animations!\n", " animationname=\"demo_nn_observable.gif\", # save them!\n", " random_state=rng_state,\n", " yields_from_pars=nn_summary_stat,\n", " model_from_yields=make_model,\n", " init_pars=init_pars,\n", " first_epoch_callback=first_epoch,\n", " last_epoch_callback=last_epoch,\n", " per_epoch_callback=per_epoch,\n", " plot_setup=plot_setup,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we run! 
Each epoch takes around 15s on my local CPU, so expect something similar :)\n", "\n", "You'll see some cool plots and animations, so it's worth it ;)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p.run()" ] } ], "metadata": { "interpreter": { "hash": "add7f649dd775e2836e5066b13fa3fabbf6edcd3797c3e56305cb8e8cb136921" }, "kernelspec": { "display_name": "Python 3.9.0 64-bit ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }