{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# neos.models\n",
    "\n",
    ">This module implements a very lightweight version of pyhf-like model building. For now, there are some hard-coded numbers (bounds, init) that help with the three Gaussian blobs demonstration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import pyhf\n",
    "\n",
    "pyhf.set_backend(\"jax\")\n",
    "\n",
    "\n",
    "class _Config:\n",
    "    \"\"\"Minimal stand-in for the `config` attribute of a `pyhf.Model`.\n",
    "\n",
    "    All values are hard-coded for a two-parameter model whose parameters\n",
    "    are ordered `[mu, gamma]`, with the parameter of interest at index 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.poi_index = 0\n",
    "        self.npars = 2\n",
    "\n",
    "    def suggested_init(self):\n",
    "        \"\"\"Return the hard-coded initial values `[1.0, 1.0]`, one per parameter.\"\"\"\n",
    "        return jnp.asarray([1.0, 1.0])\n",
    "\n",
    "    def suggested_bounds(self):\n",
    "        \"\"\"Return the hard-coded `[lower, upper]` bounds `[0.0, 10.0]` for each parameter.\"\"\"\n",
    "        return jnp.asarray([[0.0, 10.0], [0.0, 10.0]])\n",
    "\n",
    "\n",
    "class Model:\n",
    "    \"\"\"Dummy class to mimic the functionality of `pyhf.Model`.\n",
    "\n",
    "    `spec` is a sequence `(sig, nominal, uncert)`. Parameters are always\n",
    "    passed as `[mu, gamma]`: `mu` scales `sig` and `gamma` scales `nominal`\n",
    "    in the expected yields.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, spec):\n",
    "        self.sig, self.nominal, self.uncert = spec\n",
    "        # (nominal / uncert)**2 sets the rate of the Poisson constraint on gamma\n",
    "        self.factor = (self.nominal / self.uncert) ** 2\n",
    "        # auxiliary data: the constraint rate evaluated at gamma = 1\n",
    "        self.aux = 1.0 * self.factor\n",
    "        self.config = _Config()\n",
    "\n",
    "    def expected_data(self, pars, include_auxdata=True):\n",
    "        \"\"\"Return the expected data for `pars = [mu, gamma]`.\n",
    "\n",
    "        The first row of the result holds the expected main yields\n",
    "        `gamma * nominal + mu * sig`. When `include_auxdata` is true (the\n",
    "        default) the auxiliary data is appended as a second row; otherwise\n",
    "        only the main row is returned.\n",
    "        \"\"\"\n",
    "        mu, gamma = pars\n",
    "        expected_main = jnp.asarray([gamma * self.nominal + mu * self.sig])\n",
    "        if not include_auxdata:\n",
    "            return expected_main\n",
    "        aux_data = jnp.asarray([self.aux])\n",
    "        return jnp.concatenate([expected_main, aux_data])\n",
    "\n",
    "    def logpdf(self, pars, data):\n",
    "        \"\"\"Return the log-likelihood of `data` given `pars` as a one-element array.\n",
    "\n",
    "        `data` must unpack into `(maindata, auxdata)`, i.e. have the layout\n",
    "        produced by `expected_data` with `include_auxdata=True`.\n",
    "        \"\"\"\n",
    "        maindata, auxdata = data\n",
    "        main, _ = self.expected_data(pars)\n",
    "        _, gamma = pars\n",
    "        main = pyhf.probability.Poisson(main).log_prob(maindata)\n",
    "        constraint = pyhf.probability.Poisson(gamma * self.factor).log_prob(auxdata)\n",
    "        # sum log probs over bins\n",
    "        return jnp.asarray([jnp.sum(main + constraint, axis=0)])\n",
    "\n",
    "\n",
    "def hepdata_like(signal_data, bkg_data, bkg_uncerts, batch_size=None):\n",
    "    \"\"\"Dummy function to mimic the functionality of `pyhf.simplemodels.hepdata_like`.\n",
    "\n",
    "    Builds a `Model` from the signal yields, background yields and background\n",
    "    uncertainties. `batch_size` is accepted but currently ignored.\n",
    "    \"\"\"\n",
    "    return Model([signal_data, bkg_data, bkg_uncerts])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Usage:\n",
    "\n",
    "Let's build an example model, and get gradients of the likelihood function with respect to the model parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray(-27.74804929, dtype=float64),\n",
       " [DeviceArray(-22., dtype=float64), DeviceArray(-19., dtype=float64)])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "import neos\n",
    "from neos.models import hepdata_like\n",
    "\n",
    "# three-bin model example data\n",
    "sig = jnp.asarray([20, 40, 3])\n",
    "bkg = jnp.asarray([40, 20, 3])\n",
    "uncert = jnp.asarray([3, 3, 3])\n",
    "\n",
    "# construct model\n",
    "m = hepdata_like(sig, bkg, uncert)\n",
    "d = m.expected_data([1.0, 1.0])\n",
    "\n",
    "# need scalar output (logpdf returns array w/ one element)\n",
    "def logpdf_scalar(pars):\n",
    "    return m.logpdf(pars, d)[0]\n",
    "\n",
    "# check we can get gradients!\n",
    "jax.value_and_grad(logpdf_scalar)([2.0, 1.0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}