{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Automatic differentiation with JAX\n", "\n", "Here we look into automatic differentiation, which can speed up fits with very many parameters.\n", "\n", "iminuit's minimization algorithm MIGRAD uses a mix of gradient descent and Newton's method to find the minimum. Both require a first derivative, which MIGRAD usually computes numerically from finite differences. This requires many function evaluations, and the gradient may not be accurate. As an alternative, iminuit also allows the user to compute the gradient and pass it to MIGRAD.\n", "\n", "Although computing derivatives is often straightforward, doing it by hand is usually too much hassle. Automatic differentiation (AD) is an interesting alternative: it allows one to compute exact derivatives efficiently for pure Python/NumPy functions. We demonstrate automatic differentiation with the JAX module, which not only computes derivatives but also accelerates the computation of Python code (including the gradient code) with a just-in-time compiler.\n", "\n", "[Recommended read: Gentle introduction to AD](https://www.kaggle.com/borisettinger/gentle-introduction-to-automatic-differentiation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit of a Gaussian model to a histogram\n", "\n", "We fit a Gaussian to a histogram using a maximum-likelihood approach based on Poisson statistics. This example is used to investigate how automatic differentiation can accelerate a typical fit in a counting experiment.\n", "\n", "For a fair comparison of fits with and without an analytic gradient, we use `Minuit.strategy = 0`, which prevents Minuit from automatically computing the Hesse matrix after the fit." 
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.436843Z", "start_time": "2020-02-21T10:26:37.432080Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/hdembinski/Extern/iminuit/venv/lib/python3.10/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n", " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "JAX version 0.3.2\n", "numba version 0.56.4\n" ] } ], "source": [ "# !pip install jax jaxlib matplotlib numpy iminuit numba-stats\n", "\n", "import jax\n", "from jax import numpy as jnp # replacement for normal numpy\n", "from jax.scipy.special import erf # replacement for scipy.special.erf\n", "from iminuit import Minuit\n", "import numba as nb\n", "import numpy as np # original numpy still needed, since jax does not cover full API\n", "\n", "jax.config.update(\"jax_enable_x64\", True) # enable float64 precision, default is float32\n", "\n", "print(f\"JAX version {jax.__version__}\")\n", "print(f\"numba version {nb.__version__}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We generate some toy data and write the negative log-likelihood (nll) for a fit to binned data, assuming Poisson-distributed counts.\n", "\n", "**Note:** We write all statistical functions in pure Python code, to demonstrate Jax's ability to automatically differentiate and JIT compile this code. In practice, one should import JIT-able statistical distributions from jax.scipy.stats. The library versions can be expected to have fewer bugs and to be faster and more accurate than hand-written code." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.594856Z", "start_time": "2020-02-21T10:26:37.585943Z" } }, "outputs": [], "source": [ "# generate some toy data\n", "rng = np.random.default_rng(seed=1)\n", "n, xe = np.histogram(rng.normal(size=10000), bins=1000)\n", "\n", "\n", "def cdf(x, mu, sigma):\n", " # cdf of a normal distribution, needed to compute the expected counts per bin\n", " # better alternative for real code: from jax.scipy.stats.norm import cdf\n", " z = (x - mu) / sigma\n", " return 0.5 * (1 + erf(z / np.sqrt(2)))\n", "\n", "\n", "def nll(par): # negative log-likelihood with constants stripped\n", " amp = par[0]\n", " mu, sigma = par[1:]\n", " p = cdf(xe, mu, sigma)\n", " mu = amp * jnp.diff(p)\n", " result = jnp.sum(mu - n + n * jnp.log(n / (mu + 1e-100) + 1e-100))\n", " return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check results from all combinations of using JIT and gradient and then compare the execution times." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.890967Z", "start_time": "2020-02-21T10:26:37.886224Z" } }, "outputs": [], "source": [ "start_values = (1.5 * np.sum(n), 1.0, 2.0)\n", "limits = ((0, None), (None, None), (0, None))\n", "\n", "\n", "def make_and_run_minuit(fcn, grad=None):\n", " m = Minuit(fcn, start_values, grad=grad, name=(\"amp\", \"mu\", \"sigma\"))\n", " m.errordef = Minuit.LIKELIHOOD\n", " m.limits = limits\n", " m.strategy = 0 # do not explicitly compute hessian after minimisation\n", " m.migrad()\n", " return m" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:38.532308Z", "start_time": "2020-02-21T10:26:38.368563Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 496.2 Nfcn = 66
EDM = 1.84e-08 (Goal: 0.0001) time = 0.2 sec
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 496.2 │ Nfcn = 66 │\n", "│ EDM = 1.84e-08 (Goal: 0.0001) │ time = 0.2 sec │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │APPROXIMATE│ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m1 = make_and_run_minuit(nll)\n", "m1.fmin" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:39.371830Z", "start_time": "2020-02-21T10:26:38.797460Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 496.2 Nfcn = 26, Ngrad = 6
EDM = 1.84e-08 (Goal: 0.0001) time = 0.5 sec
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 496.2 │ Nfcn = 26, Ngrad = 6 │\n", "│ EDM = 1.84e-08 (Goal: 0.0001) │ time = 0.5 sec │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │APPROXIMATE│ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m2 = make_and_run_minuit(nll, grad=jax.grad(nll))\n", "m2.fmin" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:39.510553Z", "start_time": "2020-02-21T10:26:39.373728Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 496.2 Nfcn = 26, Ngrad = 6
EDM = 1.88e-08 (Goal: 0.0001)
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 496.2 │ Nfcn = 26, Ngrad = 6 │\n", "│ EDM = 1.88e-08 (Goal: 0.0001) │ │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │APPROXIMATE│ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m3 = make_and_run_minuit(jax.jit(nll), grad=jax.grad(nll))\n", "m3.fmin" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:40.573574Z", "start_time": "2020-02-21T10:26:40.229476Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 496.2 Nfcn = 26, Ngrad = 6
EDM = 1.88e-08 (Goal: 0.0001) time = 0.1 sec
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 496.2 │ Nfcn = 26, Ngrad = 6 │\n", "│ EDM = 1.88e-08 (Goal: 0.0001) │ time = 0.1 sec │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │APPROXIMATE│ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m4 = make_and_run_minuit(jax.jit(nll), grad=jax.jit(jax.grad(nll)))\n", "m4.fmin" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 496.2 Nfcn = 82
EDM = 5.31e-05 (Goal: 0.0001) time = 0.9 sec
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 496.2 │ Nfcn = 82 │\n", "│ EDM = 5.31e-05 (Goal: 0.0001) │ time = 0.9 sec │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │APPROXIMATE│ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from numba_stats import norm # numba jit-able version of norm\n", "\n", "@nb.njit\n", "def nb_nll(par):\n", " amp = par[0]\n", " mu, sigma = par[1:]\n", " p = norm.cdf(xe, mu, sigma)\n", " mu = amp * np.diff(p)\n", " result = np.sum(mu - n + n * np.log(n / (mu + 1e-323) + 1e-323))\n", " return result\n", "\n", "m5 = make_and_run_minuit(nb_nll)\n", "m5.fmin" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:45.031931Z", "start_time": "2020-02-21T10:26:40.674388Z" } }, "outputs": [], "source": [ "from timeit import timeit\n", "\n", "times = {\n", " \"no JIT, no grad\": \"m1\",\n", " \"no JIT, grad\": \"m2\",\n", " \"jax JIT, no grad\": \"m3\",\n", " \"jax JIT, grad\": \"m4\",\n", " \"numba JIT, no grad\": \"m5\",\n", "}\n", "for k, v in times.items():\n", " t = timeit(\n", " f\"{v}.values = start_values; {v}.migrad()\",\n", " f\"from __main__ import {v}, start_values\",\n", " number=1,\n", " )\n", " times[k] = t" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:45.142272Z", 
"start_time": "2020-02-21T10:26:45.033451Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "x = np.fromiter(times.values(), dtype=float)\n", "xmin = np.min(x)\n", "\n", "y = -np.arange(len(times))\n", "plt.barh(y, x)\n", "for yi, k, v in zip(y, times, x):\n", " plt.text(v, yi, f\"{v/xmin:.1f}x\")\n", "plt.yticks(y, times.keys())\n", "for loc in (\"top\", \"right\"):\n", " plt.gca().spines[loc].set_visible(False)\n", "plt.xlabel(\"execution time / s\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusions:\n", "\n", "1. As expected, the best results are obtained by JIT compiling the function and the gradient.\n", "\n", "2. JIT compiling the cost function with JAX but not using the gradient gives a negligible performance improvement. Numba is able to do much better.\n", "\n", "3. JIT compiling the gradient is very important. Using the Python-computed gradient even drastically reduces performance in this example.\n", "\n", "In general, the gain from using a gradient is larger for functions with hundreds of parameters, as is common in machine learning. Human-made models often have fewer than 10 parameters, and then the gain is less dramatic." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computing covariance matrices with JAX\n", "\n", "Automatic differentiation gives us another way to compute uncertainties of fitted parameters. By default, MINUIT computes the uncertainties with the HESSE algorithm, which approximates the matrix of second derivatives using finite differences and inverts it.\n", "\n", "Let's compare the output of HESSE with the exact (within floating point precision) computation using automatic differentiation." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:27:38.715871Z", "start_time": "2020-02-21T10:27:37.907690Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sigma[amp] : HESSE = 100.0, JAX = 100.0\n", "sigma[mu] : HESSE = 0.0100, JAX = 0.0100\n", "sigma[sigma]: HESSE = 0.0071, JAX = 0.0071\n" ] } ], "source": [ "m4.hesse()\n", "cov_hesse = m4.covariance\n", "\n", "\n", "def jax_covariance(par):\n", " return jnp.linalg.inv(jax.hessian(nll)(par))\n", "\n", "\n", "par = np.array(m4.values)\n", "cov_jax = jax_covariance(par)\n", "\n", "print(\n", " f\"sigma[amp] : HESSE = {cov_hesse[0, 0] ** 0.5:6.1f}, JAX = {cov_jax[0, 0] ** 0.5:6.1f}\"\n", ")\n", "print(\n", " f\"sigma[mu] : HESSE = {cov_hesse[1, 1] ** 0.5:6.4f}, JAX = {cov_jax[1, 1] ** 0.5:6.4f}\"\n", ")\n", "print(\n", " f\"sigma[sigma]: HESSE = {cov_hesse[2, 2] ** 0.5:6.4f}, JAX = {cov_jax[2, 2] ** 0.5:6.4f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Success, HESSE and JAX give the same answer within the relevant precision.\n", "\n", "**Note:** If you compute the covariance matrix in this way from a least-squares cost function instead of a negative log-likelihood, you must multiply it by 2.\n", "\n", "Let us compare the performance of HESSE with Jax." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.59 ms ± 595 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3\n", "m = Minuit(nll, par)\n", "m.errordef = Minuit.LIKELIHOOD\n", "m.hesse()" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "14.8 ms ± 523 µs per loop (mean ± std. dev. 
of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3\n", "jax_covariance(par)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The computation with JAX is slower, but it is also more accurate (although the added precision is not relevant).\n", "\n", "Minuit's HESSE algorithm still makes sense today. It has the advantage that it can process any function, while JAX cannot. JAX cannot differentiate a function that calls into C/C++ code or Cython code, for example.\n", "\n", "Final note: If we JIT compile `jax_covariance`, it greatly outperforms Minuit's HESSE algorithm, but that only makes sense if you need to compute the Hessian at different parameter values, so that the extra time spent on compilation is balanced by the time saved over many invocations. This is not what happens here: the Hessian is only needed at the best-fit point." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "104 µs ± 12.8 µs per loop (mean ± std. dev. of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3 jit_jax_covariance = jax.jit(jax_covariance); jit_jax_covariance(par)\n", "jit_jax_covariance(par)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is much faster... but only because the compilation cost is excluded here." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "285 ms ± 0 ns per loop (mean ± std. dev. 
of 1 run, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 1\n", "# if we include the JIT compilation cost, the performance drops dramatically\n", "@jax.jit\n", "def jax_covariance(par):\n", " return jnp.linalg.inv(jax.hessian(nll)(par))\n", "\n", "\n", "jax_covariance(par)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With compilation cost included, it is much slower.\n", "\n", "Conclusion: Using the JIT compiler makes a lot of sense if the covariance matrix has to be computed repeatedly for the same cost function but different parameters, but this is not the case when we use it to compute parameter errors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fit data points with uncertainties in x and y\n", "\n", "Let's say we have some data points $(x_i \\pm \\sigma_{x,i}, y_i \\pm \\sigma_{y,i})$ and a model $y=f(x)$ that we want to fit to these data. If $\\sigma_{x,i}$ were zero, we could use the usual least-squares method, minimizing the sum of squared residuals $r^2_i = (y_i - f(x_i))^2 / \\sigma^2_{y,i}$. Here, we don't know where to evaluate $f(x)$, since the exact $x$-location is only known up to $\\sigma_{x,i}$.\n", "\n", "We can approximately extend the standard least-squares method to handle this case. We use the fact that the uncertainty along the $x$-axis can be converted into an additional uncertainty along the $y$-axis with error propagation,\n", "\n", "$$\n", "f(x_i \\pm \\sigma_{x,i}) \\simeq f(x_i) \\pm f'(x_i)\\,\\sigma_{x,i}.\n", "$$\n", "\n", "Using this, we obtain modified squared residuals\n", "\n", "$$\n", "r^2_i = \\frac{(y_i - f(x_i))^2}{\\sigma^2_{y,i} + (f'(x_i) \\,\\sigma_{x,i})^2}.\n", "$$\n", "\n", "We demonstrate this with a fit of a polynomial." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:43.510168Z", "start_time": "2020-02-21T10:25:43.371319Z" } }, "outputs": [], "source": [ "# polynomial model\n", "def f(x, par):\n", " return jnp.polyval(par, x)\n", "\n", "\n", "# true polynomial f(x) = x^2 + 2 x + 3\n", "par_true = np.array((1, 2, 3))\n", "\n", "\n", "# grad computes derivative with respect to the first argument\n", "f_prime = jax.jit(jax.grad(f))\n", "\n", "\n", "# checking first derivative f'(x) = 2 x + 2\n", "assert f_prime(0.0, par_true) == 2\n", "assert f_prime(1.0, par_true) == 4\n", "assert f_prime(2.0, par_true) == 6\n", "# ok!\n", "\n", "# generate toy data\n", "n = 30\n", "data_x = np.linspace(-4, 7, n)\n", "data_y = f(data_x, par_true)\n", "\n", "rng = np.random.default_rng(seed=1)\n", "sigma_x = 0.5\n", "sigma_y = 5\n", "data_x += rng.normal(0, sigma_x, n)\n", "data_y += rng.normal(0, sigma_y, n)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:43.646212Z", "start_time": "2020-02-21T10:25:43.512384Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt=\"o\");" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.032210Z", "start_time": "2020-02-21T10:25:43.648365Z" } }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(876.49545695, dtype=float64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# define the cost function\n", "@jax.jit\n", "def cost(par):\n", " result = 0.0\n", " for xi, yi in zip(data_x, data_y):\n", " y_var = sigma_y ** 2 + (f_prime(xi, par) * sigma_x) ** 2\n", " result += (yi - f(xi, par)) ** 2 / y_var\n", " return result\n", "\n", "cost.errordef = Minuit.LEAST_SQUARES\n", "\n", "# test the jit-ed function\n", "cost(np.zeros(3))" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.059729Z", "start_time": "2020-02-21T10:25:44.034029Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Migrad
FCN = 23.14 Nfcn = 91
EDM = 3.12e-05 (Goal: 0.0002)
Valid Minimum No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok Accurate Pos. def. Not forced
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Name Value Hesse Error Minos Error- Minos Error+ Limit- Limit+ Fixed
0 x0 1.25 0.15
1 x1 1.5 0.5
2 x2 1.6 1.5
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x0 x1 x2
x0 0.0223 -0.0388 (-0.530) -0.15 (-0.657)
x1 -0.0388 (-0.530) 0.24 0.172 (0.230)
x2 -0.15 (-0.657) 0.172 (0.230) 2.32
" ], "text/plain": [ "┌─────────────────────────────────────────────────────────────────────────┐\n", "│ Migrad │\n", "├──────────────────────────────────┬──────────────────────────────────────┤\n", "│ FCN = 23.14 │ Nfcn = 91 │\n", "│ EDM = 3.12e-05 (Goal: 0.0002) │ │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Valid Minimum │ No Parameters at limit │\n", "├──────────────────────────────────┼──────────────────────────────────────┤\n", "│ Below EDM threshold (goal x 10) │ Below call limit │\n", "├───────────────┬──────────────────┼───────────┬─────────────┬────────────┤\n", "│ Covariance │ Hesse ok │ Accurate │ Pos. def. │ Not forced │\n", "└───────────────┴──────────────────┴───────────┴─────────────┴────────────┘\n", "┌───┬──────┬───────────┬───────────┬────────────┬────────────┬─────────┬─────────┬───────┐\n", "│ │ Name │ Value │ Hesse Err │ Minos Err- │ Minos Err+ │ Limit- │ Limit+ │ Fixed │\n", "├───┼──────┼───────────┼───────────┼────────────┼────────────┼─────────┼─────────┼───────┤\n", "│ 0 │ x0 │ 1.25 │ 0.15 │ │ │ │ │ │\n", "│ 1 │ x1 │ 1.5 │ 0.5 │ │ │ │ │ │\n", "│ 2 │ x2 │ 1.6 │ 1.5 │ │ │ │ │ │\n", "└───┴──────┴───────────┴───────────┴────────────┴────────────┴─────────┴─────────┴───────┘\n", "┌────┬─────────────────────────┐\n", "│ │ x0 x1 x2 │\n", "├────┼─────────────────────────┤\n", "│ x0 │ 0.0223 -0.0388 -0.15 │\n", "│ x1 │ -0.0388 0.24 0.172 │\n", "│ x2 │ -0.15 0.172 2.32 │\n", "└────┴─────────────────────────┘" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = Minuit(cost, np.zeros(3))\n", "m.migrad()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.566228Z", "start_time": "2020-02-21T10:25:44.065443Z" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt=\"o\", label=\"data\")\n", "x = np.linspace(data_x[0], data_x[-1], 200)\n", "par = np.array(m.values)\n", "plt.plot(x, f(x, par), label=\"fit\")\n", "plt.legend()\n", "\n", "# check fit quality\n", "chi2 = m.fval\n", "ndof = len(data_y) - 3\n", "plt.title(f\"$\\\\chi^2 / n_\\\\mathrm{{dof}} = {chi2:.2f} / {ndof} = {chi2/ndof:.2f}$\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We obtained a good fit." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.14 ('venv': venv)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "python3", "version": "3.10.8 (main, Oct 13 2022, 09:48:40) [Clang 14.0.0 (clang-1400.0.29.102)]" }, "vscode": { "interpreter": { "hash": "bdbf20ff2e92a3ae3002db8b02bd1dd1b287e934c884beb29a73dced9dbd0fa3" } } }, "nbformat": 4, "nbformat_minor": 2 }