{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KDE demo, with histosys!\n",
    "\n",
    "> It works :)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](assets/kde_pyhf_animation.gif)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import jax\n",
    "import jax.experimental.optimizers as optimizers\n",
    "import jax.experimental.stax as stax\n",
    "import jax.random\n",
    "from jax.random import PRNGKey\n",
    "import numpy as np\n",
    "from functools import partial\n",
    "\n",
    "import pyhf\n",
    "\n",
    "pyhf.set_backend(\"jax\")\n",
    "pyhf.default_backend = pyhf.tensor.jax_backend(precision=\"64b\")\n",
    "\n",
    "from neos import data, makers\n",
    "from relaxed import infer\n",
    "\n",
    "rng = PRNGKey(22)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# regression net\n",
    "init_random_params, predict = stax.serial(\n",
    "    stax.Dense(1024),\n",
    "    stax.Relu,\n",
    "    stax.Dense(1024),\n",
    "    stax.Relu,\n",
    "    stax.Dense(1),\n",
    "    stax.Sigmoid,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compose differentiable workflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "dgen = data.generate_blobs(rng, blobs=4)\n",
    "\n",
    "# Hyperparameters of the kde histograms, fixed up front\n",
    "bins = np.linspace(0, 1, 4)\n",
    "bandwidth = 0.27\n",
    "reflect_infinite_bins = True\n",
    "\n",
    "hmaker = makers.hists_from_nn(\n",
    "    dgen,\n",
    "    predict,\n",
    "    hpar_dict={\"bins\": bins, \"bandwidth\": bandwidth},\n",
    "    method=\"kde\",\n",
    "    reflect_infinities=reflect_infinite_bins,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nnm = makers.histosys_model_from_hists(hmaker)\n",
    "get_cls = infer.make_hypotest(nnm, solver_kwargs={\"pdf_transform\": True})\n",
    "\n",
    "# get_cls returns several metrics keyed by name -- the loss is just one of them (CLs)\n",
    "def loss(params, test_mu):\n",
    "    \"\"\"Run `get_cls(params, test_mu)` and return only its `CLs` entry.\"\"\"\n",
    "    metrics = get_cls(params, test_mu)\n",
    "    return metrics[\"CLs\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Randomly initialise nn weights and check that we can get the gradient of the loss wrt nn params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "(DeviceArray(0.05989214, dtype=float64),\n [(DeviceArray([[-2.6561793e-05, -1.3431838e-04, 3.7267394e-04, ...,\n -9.5937648e-05, -3.1907926e-04, -4.7556667e-05],\n [ 3.2410408e-05, 1.6153457e-04, -2.3639399e-04, ...,\n -7.6833248e-06, 3.2346524e-04, 2.8453616e-05]], dtype=float32),\n DeviceArray([-1.0775060e-06, 1.6448357e-05, 2.5383127e-04, ...,\n 9.4325784e-05, -2.0307447e-04, -1.4521212e-05], dtype=float32)),\n (),\n (DeviceArray([[ 2.0301204e-07, -2.2220850e-06, 6.1780779e-06, ...,\n -1.0296096e-07, 6.6260043e-07, 4.6643160e-07],\n [ 4.1513945e-07, 9.9888112e-08, 1.8419865e-05, ...,\n 4.3947058e-08, 6.2674635e-06, -1.1988274e-06],\n [-2.0467695e-08, -1.2231366e-05, 5.6882186e-08, ...,\n -6.0378164e-07, -1.1139725e-05, 3.0422871e-06],\n ...,\n [ 4.0854286e-07, -6.6176659e-07, 1.1947506e-05, ...,\n -2.9116174e-08, 3.6848719e-06, 7.2041793e-08],\n [-5.8339246e-08, -1.7854960e-05, 5.4534485e-07, ...,\n -9.0966802e-07, -1.6797023e-05, 4.3758305e-06],\n [ 4.5675144e-08, -3.7403470e-06, 5.2949540e-06, ...,\n -2.3437008e-07, -4.0581904e-06, 8.9615037e-07]], dtype=float32),\n DeviceArray([ 1.49559655e-05, -8.68027928e-05, 5.29851706e-04, ...,\n -3.46148317e-06, 5.94751364e-05, -9.33103820e-06], dtype=float32)),\n (),\n (DeviceArray([[ 3.7021862e-05],\n [-2.0765483e-04],\n [ 3.0189482e-04],\n ...,\n [-4.2103562e-05],\n [ 2.0076217e-04],\n [ 6.4033600e-05]], dtype=float32),\n DeviceArray([-5.677808e-06], dtype=float32)),\n ()])"
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))\n",
    "\n",
    "# gradient wrt nn weights\n",
"jax.value_and_grad(loss)(network, test_mu=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define training loop!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_init, opt_update, opt_params = optimizers.adam(1e-3)\n",
    "\n",
    "\n",
    "def train_network(N):\n",
    "    \"\"\"Train the network for `N` steps, yielding progress after every step.\n",
    "\n",
    "    Each step yields `(network, metrics, epoch_time)`: the weights the loss was\n",
    "    evaluated at, a dict whose \"loss\" entry is the list of all loss values so\n",
    "    far (the same list object every time), and the seconds the step took.\n",
    "    \"\"\"\n",
    "    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))\n",
    "    state = opt_init(network)\n",
    "    losses = []\n",
    "\n",
    "    # parameter update function\n",
    "    # @jax.jit\n",
    "    def update_and_value(i, opt_state, mu):\n",
    "        net = opt_params(opt_state)\n",
    "        value, grad = jax.value_and_grad(loss)(net, mu)\n",
    "        # Update the state that was passed in, not the enclosing `state` variable,\n",
    "        # so the result depends only on the arguments.\n",
    "        return opt_update(i, grad, opt_state), value, net\n",
    "\n",
    "    for i in range(N):\n",
    "        start_time = time.time()\n",
    "        state, value, network = update_and_value(i, state, 1.0)\n",
    "        epoch_time = time.time() - start_time\n",
    "        losses.append(value)\n",
    "        metrics = {\"loss\": losses}\n",
    "\n",
    "        yield network, metrics, epoch_time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plotting helper function for awesome animations :)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot(axarr, network, metrics, maxN):\n",
    "    \"\"\"Draw one animation frame onto the three axes in `axarr`.\n",
    "\n",
    "    axarr[0]: network output contours over the input plane plus the samples from `dgen()`.\n",
    "    axarr[1]: the loss history stored in `metrics[\"loss\"]`, x-range `[0, maxN]`.\n",
    "    axarr[2]: the histograms returned by `hmaker(network)`.\n",
    "\n",
    "    Relies on the module-level `bins`, `predict`, `dgen` and `hmaker`.\n",
    "    \"\"\"\n",
    "    ax = axarr[0]\n",
    "    g = np.mgrid[-5:5:101j, -5:5:101j]\n",
    "    levels = bins\n",
    "    # evaluate the network on the grid once; both contour calls share the result\n",
    "    nn_output = predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 1)[:, :, 0]\n",
    "    ax.contourf(g[0], g[1], nn_output, levels=levels, cmap=\"binary\")\n",
    "    ax.contour(g[0], g[1], nn_output, colors=\"w\", levels=levels)\n",
    "    sig, bkg_nom, bkg_up, bkg_down = dgen()\n",
    "\n",
    "    ax.scatter(sig[:, 0], sig[:, 1], alpha=0.3, c=\"C9\")\n",
    "    ax.scatter(bkg_up[:, 0], bkg_up[:, 1], alpha=0.1, c=\"C1\", marker=6)\n",
    "    ax.scatter(bkg_down[:, 0], bkg_down[:, 1], alpha=0.1, c=\"C1\", marker=7)\n",
    "    ax.scatter(bkg_nom[:, 0], bkg_nom[:, 1], alpha=0.3, c=\"C1\")\n",
    "\n",
    "    ax.set_xlim(-5, 5)\n",
    "    ax.set_ylim(-5, 5)\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_ylabel(\"y\")\n",
    "\n",
    "    ax = axarr[1]\n",
    "    # ax.axhline(0.05, c=\"slategray\", linestyle=\"--\")\n",
    "    ax.plot(metrics[\"loss\"], c=\"steelblue\", linewidth=2.0)\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.set_ylim(1e-4, 0.06)\n",
    "    ax.set_xlim(0, maxN)\n",
    "    ax.set_xlabel(\"epoch\")\n",
    "    ax.set_ylabel(r\"$CL_s$\")\n",
    "\n",
    "    ax = axarr[2]\n",
    "    s, b, bup, bdown = hmaker(network)\n",
    "\n",
    "    bin_width = 1 / (len(bins) - 1)\n",
    "    centers = bins[:-1] + np.diff(bins) / 2.0\n",
    "\n",
    "    # nominal background, with the signal stacked on top of it\n",
    "    ax.bar(centers, b, color=\"C1\", width=bin_width)\n",
    "    ax.bar(centers, s, bottom=b, color=\"C9\", width=bin_width)\n",
    "\n",
    "    # up/down variations drawn as hatched bands relative to the nominal background\n",
    "    ax.bar(\n",
    "        centers, bup - b, bottom=b, alpha=0.4, color=\"red\", width=bin_width, hatch=\"+\"\n",
    "    )\n",
    "    ax.bar(\n",
    "        centers,\n",
    "        b - bdown,\n",
    "        bottom=bdown,\n",
    "        alpha=0.4,\n",
    "        color=\"green\",\n",
    "        width=bin_width,\n",
    "        hatch=\"-\",\n",
    "    )\n",
    "\n",
    "    ax.set_ylim(0, 120)\n",
    "    ax.set_ylabel(\"frequency\")\n",
    "    ax.set_xlabel(\"nn output\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's run it!!"
]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# slow\n",
    "import numpy as np\n",
    "from IPython.display import HTML\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# figure-wide matplotlib styling\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"axes.labelsize\": 13,\n",
    "        \"axes.linewidth\": 1.2,\n",
    "        \"xtick.labelsize\": 13,\n",
    "        \"ytick.labelsize\": 13,\n",
    "        \"figure.figsize\": [13.0, 4.0],\n",
    "        \"font.size\": 13,\n",
    "        \"xtick.major.size\": 3,\n",
    "        \"ytick.major.size\": 3,\n",
    "        \"legend.fontsize\": 11,\n",
    "    }\n",
    ")\n",
    "\n",
    "\n",
    "fig, axarr = plt.subplots(1, 3, dpi=120)\n",
    "\n",
    "maxN = 500  # make me bigger for better results!\n",
    "\n",
    "animate = True  # animations fail tests...\n",
    "\n",
    "if animate:\n",
    "    # celluloid is imported only when frames are actually recorded\n",
    "    from celluloid import Camera\n",
    "\n",
    "    camera = Camera(fig)\n",
    "\n",
    "# Training\n",
    "for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):\n",
    "    # print(f\"epoch {i}:\", f'CLs = {metrics[\"loss\"][-1]}, took {epoch_time}s')\n",
    "    if not animate:\n",
    "        # keep training, just skip drawing the frame\n",
    "        continue\n",
    "    plot(axarr, network, metrics, maxN=maxN)\n",
    "    plt.tight_layout()\n",
    "    camera.snap()\n",
    "    # if i % 10 == 0:\n",
    "    #     camera.animate().save(\"animation.gif\", writer=\"imagemagick\", fps=8)\n",
    "    #     HTML(camera.animate().to_html5_video())\n",
    "    # break\n",
    "\n",
    "# write all recorded frames out as a gif\n",
    "if animate:\n",
    "    camera.animate().save(\"animationinfesoft.gif\", writer=\"imagemagick\", fps=15)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "653c4ff722aa20b955619ff1fd61a2acdddc645229f21af3c6e92ae1ff77e485"
  },
  "kernelspec": {
   "display_name": "Python 3.8.3 ('venv': venv)",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": ""
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}