{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp transforms" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# neos.transforms\n", "\n", "> Contains transforms to map from $[-\\infty,\\infty]$ to a bounded space $[a,b]$ and back." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module implements two transforms, taken from the minuit optimizer:\n", "\n", "$$P_{\\mathrm{inf}}=\\arcsin \\left(2 \\frac{P_{\\mathrm{bounded}}-a}{b-a}-1\\right):~[a,b] \\rightarrow [-\\pi/2,\\pi/2] \\subset [-\\infty,\\infty]$$\n", "\n", "$$P_{\\mathrm{bounded}}=a+\\frac{b-a}{2}\\left(\\sin P_{\\mathrm{inf}}+1\\right):~[-\\infty,\\infty]\\rightarrow [a,b] $$\n", "\n", "Because $\\sin$ is periodic, the bounded transform accepts any real $P_{\\mathrm{inf}}$, while the inverse returns the representative in $[-\\pi/2,\\pi/2]$. The round trip bounded $\\rightarrow$ inf $\\rightarrow$ bounded is therefore exact, whereas inf $\\rightarrow$ bounded $\\rightarrow$ inf is exact only for $P_{\\mathrm{inf}} \\in [-\\pi/2,\\pi/2]$.\n", "\n", "The purpose of these transforms is to add stability to the maximum likelihood fits of the model parameters, which are currently done by gradient descent. The minimization is allowed to run over the whole real line, and the result is mapped to a value in a 'sensible' interval $[a,b]$ before the likelihood is evaluated. Without this, the likelihood could be evaluated at negative or very extreme parameter values, potentially causing numerical instability in the likelihood or gradient evaluations."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# export\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# enable float64 to avoid float32 precision errors\n",
"jax.config.update(\"jax_enable_x64\", True)"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# export\n",
"# [-inf, inf] -> [a,b]\n",
"def to_bounded(param, bounds):\n",
"    \"\"\"Map `param` from the real line into the interval `bounds = (a, b)`.\n",
"\n",
"    Computes `a + (b - a) / 2 * (sin(param) + 1)`, which lies in [a, b] for any\n",
"    real `param`. `a` and `b` may be scalars or arrays that broadcast with `param`.\n",
"    \"\"\"\n",
"    a, b = bounds\n",
"    return a + (b - a) * 0.5 * (jnp.sin(param) + 1.0)\n",
"\n",
"\n",
"# [-inf, inf] -> [a,b] (vectors)\n",
"def to_bounded_vec(param, bounds):\n",
"    \"\"\"Vectorised `to_bounded`: row `i` of `bounds` is the `(a, b)` pair for `param[i]`.\"\"\"\n",
"    bounds = jnp.asarray(bounds)\n",
"    return to_bounded(param, (bounds[:, 0], bounds[:, 1]))\n",
"\n",
"\n",
"# [-inf, inf] <- [a,b]\n",
"def to_inf(param, bounds):\n",
"    \"\"\"Inverse of `to_bounded`: map `param` in `bounds = (a, b)` to [-pi/2, pi/2].\n",
"\n",
"    Computes `arcsin(2 * (param - a) / (b - a) - 1)`. Values outside [a, b] push\n",
"    the arcsin argument outside [-1, 1], for which `jnp.arcsin` returns NaN.\n",
"    \"\"\"\n",
"    a, b = bounds\n",
"    # shift by `a` *before* scaling; `(2 * param - a) / (b - a)` is only right for a == 0\n",
"    x = 2.0 * (param - a) / (b - a) - 1.0\n",
"    return jnp.arcsin(x)\n",
"\n",
"\n",
"# [-inf, inf] <- [a,b] (vectors)\n",
"def to_inf_vec(param, bounds):\n",
"    \"\"\"Vectorised `to_inf`: row `i` of `bounds` is the `(a, b)` pair for `param[i]`.\"\"\"\n",
"    bounds = jnp.asarray(bounds)\n",
"    return to_inf(param, (bounds[:, 0], bounds[:, 1]))"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"p = jnp.asarray([1.0, 1.0]) # points\n",
"b = jnp.asarray([[0.0, 10.0], [0.0, 10.0]]) # bounds\n",
"\n",
"# check if 1 is invariant if we transform to bounded space and back\n",
"cond = np.allclose(to_inf(to_bounded(p[0], b[0]), b[0]), p[0])\n",
"assert cond, f\"{to_inf(to_bounded(p[0], b[0]), b[0])} != {p[0]}\"\n",
"\n",
"# check if [1,1] is invariant\n",
"cond = np.allclose(to_inf_vec(to_bounded_vec(p, b), b), p)\n",
"assert cond, f\"{to_inf_vec(to_bounded_vec(p, b), b)} != {p}\"\n",
"\n",
"# repeat with non-zero lower bounds: a == 0 hides mistakes in the offset handling\n",
"b_shifted = jnp.asarray([[2.0, 5.0], [-3.0, 7.0]])\n",
"cond = np.allclose(to_inf(to_bounded(p[0], b_shifted[0]), b_shifted[0]), p[0])\n",
"assert cond, f\"{to_inf(to_bounded(p[0], b_shifted[0]), b_shifted[0])} != {p[0]}\"\n",
"cond = np.allclose(to_inf_vec(to_bounded_vec(p, b_shifted), b_shifted), p)\n",
"assert cond, f\"{to_inf_vec(to_bounded_vec(p, b_shifted), b_shifted)} != {p}\"\n",
"\n",
"# the bounded -> inf -> bounded direction must round-trip too\n",
"q = jnp.asarray([2.5, 0.0])\n",
"cond = np.allclose(to_bounded_vec(to_inf_vec(q, b_shifted), b_shifted), q)\n",
"assert cond, f\"{to_bounded_vec(to_inf_vec(q, b_shifted), b_shifted)} != {q}\""
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "min: 0.0, max: 10.0, to inf:\n", "-1.5707963267948966 1.5707963267948966\n", "min: -10000000000.0, max: 
10000000000.0, to [ 0 10]:\n", "2.5122629992990753e-06 9.999997487737\n" ] } ], "source": [
"# hide\n",
"bounds = jnp.array([[0, 10], [0, 20]])\n",
"\n",
"# check that we map to inf space (i.e. -pi/2 to pi/2)\n",
"w = jnp.linspace(0, 10)\n",
"x = to_inf(w, bounds[0])\n",
"print(f\"min: {w.min()}, max: {w.max()}, to inf:\")\n",
"print(x.min(), x.max())\n",
"assert np.allclose(\n",
"    [x.min(), x.max()], [-np.pi / 2, np.pi / 2]\n",
"), \"The interval ends are not mapped to -pi/2 and pi/2\"\n",
"\n",
"\n",
"# check that we can map very large values to bounded space\n",
"w = jnp.linspace(-1e10, 1e10, 1001)\n",
"x = to_bounded(w, bounds[0])\n",
"print(f\"min: {w.min()}, max: {w.max()}, to {bounds[0]}:\")\n",
"print(x.min(), x.max())\n",
"assert np.allclose(\n",
"    np.asarray([x.min(), x.max()]), bounds[0], atol=1e-5\n",
"), \"Large numbers are not mapped to the bounds of the bounded transform\""
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "check consistency in both spaces:\n", "..good!\n", "gradients in bounded space:\n", "[ -0.96078431 -99.05962385]\n", "gradients in inf space:\n", "[ -2.09398087 -309.31357633]\n", "consistency? check with chain rule:\n", "[ -0.96078431 -99.05962385]\n", "all good here chief\n" ] } ], "source": [
"# hide\n",
"# define NLL functions in both parameter spaces\n",
"\n",
"from neos import models\n",
"\n",
"\n",
"def make_nll_boundspace(hyperpars):\n",
"    \"\"\"Return `pars -> NLL` for a single-bin `models.hepdata_like` model.\n",
"\n",
"    `hyperpars` is `(s, b, db)`; `pars` are given directly in bounded space and\n",
"    the data are the model's expected data at `truth_pars = [0, 1]`.\n",
"    \"\"\"\n",
"    s, b, db = hyperpars\n",
"    # neither the model nor the data depend on `pars`, so build them once here\n",
"    # rather than on every NLL (and gradient) evaluation\n",
"    truth_pars = [0, 1]\n",
"    m = models.hepdata_like(jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db]))\n",
"    data = m.expected_data(truth_pars)\n",
"\n",
"    def nll_boundspace(pars):\n",
"        return -m.logpdf(pars, data)[0]\n",
"\n",
"    return nll_boundspace\n",
"\n",
"\n",
"def make_nll_infspace(hyperpars, param_bounds=None):\n",
"    \"\"\"Return `pars -> NLL` where `pars` are given in inf space.\n",
"\n",
"    `pars` are mapped to bounded space with `to_bounded_vec(pars, param_bounds)`\n",
"    and passed to the NLL from `make_nll_boundspace(hyperpars)`. `param_bounds`\n",
"    defaults to the notebook-level `bounds`, read when this factory is called.\n",
"    \"\"\"\n",
"    if param_bounds is None:\n",
"        param_bounds = bounds\n",
"    nll_bounded = make_nll_boundspace(hyperpars)\n",
"\n",
"    def nll_infspace(pars):\n",
"        return nll_bounded(to_bounded_vec(pars, param_bounds))\n",
"\n",
"    return nll_infspace\n",
"\n",
"\n",
"nll_boundspace = make_nll_boundspace([1, 50, 7])\n",
"nll_infspace = make_nll_infspace([1, 50, 7])\n",
"\n",
"# define a point and compute it in both spaces\n",
"apoint_bnd = jnp.array([0.5, 0.5])\n",
"apoint_inf = to_inf_vec(apoint_bnd, bounds)\n",
"\n",
"# check consistency in both spaces\n",
"print(\"check consistency in both spaces:\")\n",
"point_bound = nll_boundspace(apoint_bnd)\n",
"point_inf = nll_infspace(apoint_inf)\n",
"assert np.allclose(\n",
"    point_bound, point_inf\n",
"), f\"{point_bound} (bounded) should be close to {point_inf} (inf)\"\n",
"print(\"..good!\")\n",
"# check gradients in bounded\n",
"print(\"gradients in bounded space:\")\n",
"dlb_dpb = jax.grad(nll_boundspace)(apoint_bnd)\n",
"print(dlb_dpb)\n",
"\n",
"# check gradients in inf\n",
"print(\"gradients in inf space:\")\n",
"dli_dinf = jax.grad(nll_infspace)(apoint_inf)\n",
"print(dli_dinf)\n",
"\n",
"# check consistency of gradients\n",
"print(\"consistency? check with chain rule:\")\n",
"# the transform acts elementwise, so its Jacobian is diagonal:\n",
"# dL/dp_bnd[i] = dL/dp_inf[i] * dp_inf[i]/dp_bnd[i]\n",
"dli_dpi = dli_dinf * jnp.array(\n",
"    [\n",
"        jax.grad(lambda x, b: to_inf_vec(x, b)[i])(apoint_bnd, bounds)[i]\n",
"        for i in range(2)\n",
"    ]\n",
")\n",
"print(dli_dpi)\n",
"\n",
"# li maps pi to bounded, then becomes lb, so grad should be the same\n",
"cond = np.allclose(dli_dpi, dlb_dpb)\n",
"assert cond, \"Chain rule... doesnt work? :o\"\n",
"print(\"all good here chief\")"
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "global fit grad\n", "(DeviceArray(4.04676292e-13, dtype=float64), [DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64)])\n", "(DeviceArray(1.71529457e-13, dtype=float64), [DeviceArray(-6.85560674e-14, dtype=float64), DeviceArray(9.1365743e-15, dtype=float64), DeviceArray(-2.28414357e-14, dtype=float64)])\n", "(DeviceArray(9.93649607e-14, dtype=float64), [DeviceArray(-3.98545918e-14, dtype=float64), DeviceArray(5.99354644e-15, dtype=float64), DeviceArray(-2.87322456e-14, dtype=float64)])\n", "(DeviceArray(6.38378239e-14, dtype=float64), [DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64), DeviceArray(nan, dtype=float64)])\n", "constrained!\n", "[1. 0.91976327]\n", "[1. 0.93563135]\n", "[1. 0.95299143]\n", "[1. 0.99821931]\n", "[1. 
0.99998647]\n", "constrained fit grad\n", "(DeviceArray(0.91976327, dtype=float64), [DeviceArray(-0.01570885, dtype=float64), DeviceArray(0.00188766, dtype=float64), DeviceArray(-0.00211123, dtype=float64)])\n", "(DeviceArray(0.93563135, dtype=float64), [DeviceArray(-0.01240237, dtype=float64), DeviceArray(0.00169771, dtype=float64), DeviceArray(-0.00457567, dtype=float64)])\n", "(DeviceArray(0.95299143, dtype=float64), [DeviceArray(-0.00890556, dtype=float64), DeviceArray(0.00138767, dtype=float64), DeviceArray(-0.00710027, dtype=float64)])\n", "(DeviceArray(0.99821931, dtype=float64), [DeviceArray(-0.0003251, dtype=float64), DeviceArray(6.73871338e-05, dtype=float64), DeviceArray(-0.00349722, dtype=float64)])\n", "(DeviceArray(0.99998647, dtype=float64), [DeviceArray(-2.78601582e-06, dtype=float64), DeviceArray(4.28295151e-07, dtype=float64), DeviceArray(-0.00022806, dtype=float64)])\n", "reference\n", "[1. 0.91979939]\n", "[1. 0.93570921]\n", "[1. 0.95295097]\n", "[1. 0.9982233]\n", "[1. 
0.99998182]\n", "diffable cls\n", "cross check cls\n" ] } ], "source": [
"# hide\n",
"import scipy.optimize\n",
"\n",
"import pyhf\n",
"from neos import cls, fit\n",
"\n",
"pyhf.set_backend(pyhf.tensor.jax_backend())\n",
"\n",
"# set to True to also run the exploratory, print-only fits below\n",
"RUN_EXPLORATORY_FITS = False\n",
"\n",
"\n",
"def fit_nll_bounded(init, hyperpars):\n",
"    \"\"\"Minimize the bounded-space NLL with scipy, box-constrained by the global `bounds`.\n",
"\n",
"    `hyperpars` is `[mu, s, b, db]`; `mu` is not used by this fit.\n",
"    \"\"\"\n",
"    model_pars = hyperpars[1:]\n",
"    objective = make_nll_boundspace(model_pars)\n",
"    return scipy.optimize.minimize(objective, x0=init, bounds=bounds).x\n",
"\n",
"\n",
"def fit_nll_infspace(init, hyperpars):\n",
"    \"\"\"Minimize the inf-space NLL without box constraints; return the optimum in bounded space.\n",
"\n",
"    `hyperpars` is `[mu, s, b, db]`; `mu` is not used by this fit.\n",
"    \"\"\"\n",
"    model_pars = hyperpars[1:]\n",
"    objective = make_nll_infspace(model_pars)\n",
"    result = scipy.optimize.minimize(objective, x0=init).x\n",
"    return to_bounded_vec(result, bounds)\n",
"\n",
"\n",
"def fit_nll_bounded_constrained(init, hyperpars, fixed_val):\n",
"    \"\"\"Like `fit_nll_bounded`, with the first parameter held at `fixed_val`.\"\"\"\n",
"    model_pars = hyperpars[1:]\n",
"    objective = make_nll_boundspace(model_pars)\n",
"    return scipy.optimize.minimize(\n",
"        objective,\n",
"        x0=init,\n",
"        bounds=bounds,\n",
"        constraints=[{\"type\": \"eq\", \"fun\": lambda v: v[0] - fixed_val}],\n",
"    ).x\n",
"\n",
"\n",
"def nn_model_maker(nn_params):\n",
"    \"\"\"Build the model for `(s, b, db)` and its nominal parameters with the POI set to 0.\"\"\"\n",
"    s, b, db = nn_params\n",
"    m = models.hepdata_like(jnp.asarray([s]), jnp.asarray([b]), jnp.asarray([db]))\n",
"    bonlypars = jnp.asarray(m.config.suggested_init())\n",
"    bonlypars = jax.ops.index_update(bonlypars, m.config.poi_index, 0.0)\n",
"    return m, bonlypars\n",
"\n",
"\n",
"if RUN_EXPLORATORY_FITS:\n",
"    print(\"scipy minim in bounded space\")\n",
"    for db in [7, 2, 1, 0.1, 0.01]:\n",
"        print(fit_nll_bounded(apoint_bnd, [1.0, 5, 50, db]))\n",
"\n",
"    print(\"scipy minim in inf space\")\n",
"    for db in [7, 2, 1, 0.1, 0.01, 0.001]:\n",
"        print(fit_nll_infspace(apoint_inf, [1.0, 5, 50, db]))\n",
"\n",
"\n",
"g_fitter, c_fitter = fit.get_solvers(\n",
"    nn_model_maker, pdf_transform=True, learning_rate=1e-4\n",
")\n",
"\n",
"bounds = jnp.array([[0.0, 10], [0.0, 10.0]])\n",
"\n",
"if RUN_EXPLORATORY_FITS:\n",
"    print(\"diffable minim in inf space\")\n",
"    apoint_bnd = jnp.array([0.5, 0.5])\n",
"    apoint_inf = to_inf_vec(apoint_bnd, bounds)\n",
"    for nn_params in [\n",
"        [5, 50, 7.0],\n",
"        [5, 50, 2.0],\n",
"        [5, 50, 1.0],\n",
"        [5, 50, 0.5],\n",
"        [5, 50, 0.1],\n",
"        [5, 50, 0.01],\n",
"        [5, 55, 1.5],\n",
"        [10, 5, 1.5],\n",
"        [2, 90, 1.5],\n",
"    ]:\n",
"        print(to_bounded_vec(g_fitter(apoint_inf, [1.0, nn_params]), bounds))\n",
"\n",
"\n",
"# NOTE: until it is reassigned below, `apoint_inf` is the start point set by the NLL cell\n",
"print(\"global fit grad\")\n",
"for db in [15.0, 10.0, 7.0, 1.0]:\n",
"    print(\n",
"        jax.value_and_grad(\n",
"            lambda x: to_bounded_vec(g_fitter(apoint_inf, [1.0, x]), bounds)[0]\n",
"        )([5.0, 50.0, db])\n",
"    )\n",
"\n",
"print(\"constrained!\")\n",
"\n",
"apoint_bnd = jnp.array([1.0, 1.0])\n",
"apoint_inf = to_inf_vec(apoint_bnd, bounds)\n",
"for db in [15.0, 10.0, 7.0, 1.0, 0.1]:\n",
"    print(to_bounded_vec(c_fitter(apoint_inf, [1.0, [5, 50, db]]), bounds))\n",
"\n",
"print(\"constrained fit grad\")\n",
"for db in [15.0, 10.0, 7.0, 1.0, 0.1]:\n",
"    print(\n",
"        jax.value_and_grad(\n",
"            lambda x: to_bounded_vec(c_fitter(apoint_inf, [1.0, x]), bounds)[1]\n",
"        )([5.0, 50.0, db])\n",
"    )\n",
"\n",
"# scipy reference for the constrained fits above (first parameter fixed at 1.0)\n",
"print(\"reference\")\n",
"for db in [15.0, 10.0, 7.0, 1.0, 0.1]:\n",
"    print(fit_nll_bounded_constrained(apoint_bnd, [1.0, 5, 50, db], 1.0))\n",
"\n",
"\n",
"# (s, b, db) points at which the differentiable CLs is compared with pyhf\n",
"CLS_TEST_POINTS = [\n",
"    [5.0, 50.0, 15.0],\n",
"    [5.0, 50.0, 10.0],\n",
"    [5.0, 50.0, 7.0],\n",
"    [5.0, 50.0, 1.0],\n",
"    [5.0, 50.0, 0.1],\n",
"    [10.0, 5.0, 0.1],\n",
"    [15.0, 5.0, 0.1],\n",
"]\n",
"\n",
"print(\"diffable cls\")\n",
"\n",
"# value_and_grad also checks that CLs is differentiable; only the value is compared below\n",
"j_cls = [\n",
"    jax.value_and_grad(\n",
"        cls.cls_maker(nn_model_maker, solver_kwargs=dict(pdf_transform=True))\n",
"    )(nn_params, 1.0)[0]\n",
"    for nn_params in CLS_TEST_POINTS\n",
"]\n",
"\n",
"print(\"cross check cls\")\n",
"\n",
"\n",
"def pyhf_cls(nn_params, mu):\n",
"    \"\"\"Reference CLs from pyhf at signal strength `mu`, with the observed count set to `b`.\"\"\"\n",
"    s, b, db = nn_params\n",
"    m = pyhf.simplemodels.hepdata_like([s], [b], [db])\n",
"    return pyhf.infer.hypotest(mu, [b] + m.config.auxdata, m)[0]\n",
"\n",
"\n",
"p_cls = [pyhf_cls(nn_params, 1.0) for nn_params in CLS_TEST_POINTS]\n",
"\n",
"assert np.allclose(np.asarray(j_cls), np.asarray(p_cls)), \"cls values don't match pyhf\""
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": {
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }