{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/lukasheinrich/Code/iml_tutorial/_venv/lib/python3.9/site-packages/jax/_src/lib/__init__.py:33: UserWarning: JAX on Mac ARM machines is experimental and minimally tested. Please see https://github.com/google/jax/issues/5501 in the event of problems.\n", " warnings.warn(\"JAX on Mac ARM machines is experimental and minimally tested. \"\n" ] } ], "source": [ "import pyhf\n", "pyhf.set_backend('jax')\n", "import jax\n", "import jaxlib\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import jax.numpy as jnp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## IML 2022-- Tutorial on Automatic Differentiation\n", "\n", "\n", "### Introduction\n", "\n", "Welcome to this tutorial on automatic differentiation. Automatic differentiation is a method to compute exact derivatives of functions implemented as **programs**. It is widely applicable and famously used in\n", "many machine learning optimization problems. For example, neural networks, which are parametrized by weights $\\text{NN}(\\text{weights})$, are trained by (stochastic) **gradient** descent to find the minimum of the loss function $L$, where \n", "\n", "\n", "$$\\text{weights}_\\text{opt} = \\text{argmin}_\\text{weights} L(\\text{weights}) \\hspace{1cm} \\nabla L(\\text{weights}) = 0$$\n", "\n", "\n", "This means that efficient algorithms to compute derivatives are crucial.\n", "\n", "Aside from ML, many other use-cases require gradients: standard statistical analysis in HEP (fitting, hypothesis testing, ...) requires gradients. Uncertainty propagation (e.g. 
track parameters) uses gradients, etc.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Other approaches to differentiation\n", "\n", "Before diving into automatic differentiation, let's review how we might otherwise compute derivatives.\n", "\n", "\n", "\n", "#### Finite Differences\n", "\n", "\n", "A common approach to approximate gradients of a black-box function is to evaluate it\n", "at close-by points $x$ and $x+Δx$ and form the difference quotient\n", "\n", "$\\frac{\\partial f}{\\partial x} \\approx \\frac{f(x+\\Delta x) - f(x)}{\\Delta x}$ if $\\Delta x$ is sufficiently small\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def black_box_func(x):\n", " return x**3\n", "\n", "def true_gradient_func(x):\n", " return 3*x**2\n", "\n", "\n", "def plot_gradients(nsteps,title):\n", " xi = np.linspace(-5,5,nsteps)\n", " yi = black_box_func(xi)\n", "\n", " approx_gradient = np.gradient(yi,xi)\n", " true_gradient = true_gradient_func(xi)\n", "\n", " plt.plot(xi,yi, label = 'black-box func')\n", " plt.scatter(xi,yi)\n", "\n", " plt.plot(xi,approx_gradient, label = 'finite diff grad')\n", " plt.scatter(xi,approx_gradient)\n", "\n", " plt.plot(xi,true_gradient, label = 'true grad')\n", " plt.scatter(xi,true_gradient)\n", "\n", " plt.legend()\n", " plt.title(title)\n", " plt.show()\n", " \n", " \n", "plot_gradients(11, title = 'it is pretty bad if Δx is too large')\n", "plot_gradients(41, title = 'it gets better at the cost of many evaluations') " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

\n", "while only approximate, finite differences is *simple*. I don't need to know \n", "anything about the function beyond having the ability \n", "to *evaluate* it\n", "
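A minimal sketch of this trade-off, assuming the test function `x**3` and two hypothetical helpers `forward_difference` and `central_difference` that are not part of the tutorial code: both only ever *evaluate* the function, yet the quality of the approximation depends on the step size and on the stencil.

```python
def cubic(x):
    # test function with known derivative 3*x**2
    return x**3

def forward_difference(f, x, dx):
    # one-sided difference quotient, truncation error of order dx
    return (f(x + dx) - f(x)) / dx

def central_difference(f, x, dx):
    # symmetric difference quotient, truncation error of order dx**2
    return (f(x + dx) - f(x - dx)) / (2 * dx)

x0 = 2.0
exact = 3 * x0**2
for dx in [1e-1, 1e-2, 1e-4]:
    err_fwd = abs(forward_difference(cubic, x0, dx) - exact)
    err_ctr = abs(central_difference(cubic, x0, dx) - exact)
    print(f'dx={dx:g}  forward error={err_fwd:.2e}  central error={err_ctr:.2e}')
```

Shrinking the step reduces the truncation error, but for very small steps floating-point round-off in the subtraction starts to dominate, so the step cannot be made arbitrarily small.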

\n", "

\n", "This way I can compute gradients of functions encoded as a computer\n", "program, and it works in any programming language\n", "

\n", "\n", "

\n", "For multivariate (possibly vector-valued) functions $\\vec{f}(\\vec{x}) = f_i(x_1,x_2,\\dots,x_n)$ one needs to compute a finite difference\n", "gradient for each partial derivative $\\frac{\\partial f_i}{\\partial x_k}$ in order to get the\n", "full Jacobian / total derivative $df_i = J_{ik} dx_k,\\; J_{ik} = \\frac{\\partial f_i}{\\partial x_k}$\n", " \n", "In high dimensions this becomes expensive: every input dimension requires at least one additional function evaluation!\n", "
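The cost argument can be made concrete with a small sketch. The helper `finite_difference_jacobian` and the test function `toy_map` are hypothetical names introduced here for illustration (they are not part of the tutorial code), and a simple forward-difference stencil is assumed: each input dimension costs one extra function evaluation, which yields one column of the Jacobian.

```python
import numpy as np

def finite_difference_jacobian(f, x, dx=1e-6):
    # approximate J_ik = d f_i / d x_k with forward differences
    x = np.asarray(x, dtype=float)
    f0 = np.atleast_1d(f(x))              # one baseline evaluation
    jac = np.zeros((f0.size, x.size))
    for k in range(x.size):               # one extra evaluation per input
        step = np.zeros_like(x)
        step[k] = dx
        jac[:, k] = (np.atleast_1d(f(x + step)) - f0) / dx
    return jac

n_calls = 0

def toy_map(x):
    # hypothetical map from R^3 to R^2 that counts its own evaluations
    global n_calls
    n_calls += 1
    return np.asarray([x[0] * x[1], x[1] + x[2]**2])

J = finite_difference_jacobian(toy_map, [1.0, 2.0, 3.0])
print(J)
print(f'function evaluations: {n_calls}')
```

For n inputs this sketch needs n + 1 evaluations, which is why the approach does not scale to functions with very many parameters.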

\n", "\n", "\n", "**Finite Differences**:\n", "\n", "* Pro: easy to do, works in any language, no \"framework\" needed\n", "* Con: inaccurate unless one does a lot of evaluations\n", "* Con: does not scale to large dimensions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Symbolic Differentiation in a CAS\n", "\n", "Computer Algebra Systems (CAS), such as Mathematica (or sympy)\n", "can manipulate functional *expressions* and know about differentiation rules (and many other things).\n", "\n", "If the function / the program which we want to differentiate is available as such an expression, \n", "symbolic differentiation can produce **exact gradients**" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle x^{3}$" ], "text/plain": [ "x**3" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sympy\n", "\n", "def function(x):\n", "    return x**3\n", "\n", "def true_deriv(x):\n", "    return 3*x**2\n", "\n", "symbolic_x = sympy.symbols('x')\n", "symbolic_func = function(symbolic_x)\n", "symbolic_func" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using `lambdify` we can turn it into a normal Python function we can evaluate" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xi = np.linspace(-5,5,11)\n", "yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)\n", "plt.plot(xi,yi)\n", "plt.scatter(xi,yi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`symbolic_func` is now an expression which we can differentiate *symbolically*" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle 3 x^{2}$" ], "text/plain": [ "3*x**2" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "symbolic_deriv = symbolic_func.diff(symbolic_x)\n", "symbolic_deriv" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def plot_symbolic(nsteps,title):\n", "    xi = np.linspace(-5,5,nsteps)\n", "    yi = sympy.lambdify(symbolic_x,symbolic_func)(xi)\n", "    plt.scatter(xi,yi)\n", "    plt.plot(xi,yi, label = 'function')\n", "\n", "\n", "\n", "    yi = true_deriv(xi)\n", "    plt.plot(xi,yi)\n", "    plt.scatter(xi,yi, label = 'true deriv')\n", "\n", "    yi = sympy.lambdify(symbolic_x,symbolic_deriv)(xi)\n", "    plt.plot(xi,yi)\n", "    plt.scatter(xi,yi, label = 'symbolic deriv')\n", "\n", "    plt.legend()\n", "    plt.title(title)\n", "    plt.show()\n", " \n", " \n", "plot_symbolic(11,title = 'the symbolic derivative is always exact')\n", "plot_symbolic(4, title = 'it does not matter where/how often you evaluate it') " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Chain Rule in CAS\n", "\n", "We can even handle function compositions" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle \\cos{\\left(x^{2} \\right)}$" ], "text/plain": [ "cos(x**2)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f1(x):\n", "    #standard operations are overloaded\n", "    return x**2\n", "\n", "def f2(x):\n", "    #note here we use a special cos function from sympy\n", "    #instead of e.g. 
np.cos or math.cos\n", "    return sympy.cos(x) \n", " \n", "\n", "composition = f2(f1(symbolic_x))\n", "\n", "composition" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle - 2 x \\sin{\\left(x^{2} \\right)}$" ], "text/plain": [ "-2*x*sin(x**2)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "composition.diff(symbolic_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since `sympy` knows about the chain rule it can differentiate accordingly" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Problems with Symbolic Differentiation\n", "\n", "This looks great! We get exact derivatives. However,\n", "there are drawbacks\n", "\n", "1. Need to implement it in a CAS\n", "\n", "Most functions we are interested in are not implemented in \n", "e.g. Mathematica. Rather, we have loads of C, C++, Python\n", "code that we want to differentiate. \n", "\n", "But ok, `sympy` alleviates this to some degree. The functions\n", "`f1` and `f2` are fairly generic since they use operator\n", "overloading. 
So a symbolic program and a \"normal\" program\n", "could only differ by a few import statements\n", "\n", "\n", "\n", "```python\n", "from sympy import cos\n", "\n", "def f1(x):\n", " return x**2\n", "\n", "def f2(x):\n", " return cos(x) \n", "```\n", "\n", "versus:\n", "\n", "```python\n", "from math import cos\n", "\n", "def f1(x):\n", " return x**2\n", "\n", "def f2(x):\n", " return cos(x) \n", "```\n", "\n", "\n", "Note the code is almost exactly the same\n", "\n", "But not all our functions are so simple!\n", "\n", "\n", "**Expression swell**\n", "\n", "Let's look at a quadratic map which is applied a few times" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle 243 x^{2} + 729 x + 81 \\left(x^{2} + 3 x + 4\\right)^{2} + 27 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 9 \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 3 \\left(27 x^{2} + 81 x + 9 \\left(x^{2} + 3 x + 4\\right)^{2} + 3 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 160\\right)^{2} + \\left(81 x^{2} + 243 x + 27 \\left(x^{2} + 3 x + 4\\right)^{2} + 9 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 3 \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + \\left(27 x^{2} + 81 x + 9 \\left(x^{2} + 3 x + 4\\right)^{2} + 3 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 160\\right)^{2} + 
484\\right)^{2} + 1456$" ], "text/plain": [ "243*x**2 + 729*x + 81*(x**2 + 3*x + 4)**2 + 27*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 9*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 3*(27*x**2 + 81*x + 9*(x**2 + 3*x + 4)**2 + 3*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + (9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 160)**2 + (81*x**2 + 243*x + 27*(x**2 + 3*x + 4)**2 + 9*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 3*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + (27*x**2 + 81*x + 9*(x**2 + 3*x + 4)**2 + 3*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + (9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 160)**2 + 484)**2 + 1456" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def quadmap(x):\n", "    return x**2 + 3*x + 4\n", "\n", "def func(x):\n", "    for i in range(6):\n", "        x = quadmap(x)\n", "    return x\n", "\n", "quad_6_times = func(symbolic_x)\n", "quad_6_times" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This looks pretty intimidating. What happened? 
\n", "Symbolic programs run through the prgram and \n", "accumulate the full program into a single expression\n", "\n", "If we would just blindly differentiate this it would look like this" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$\\displaystyle 486 x + 81 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 27 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 9 \\cdot \\left(36 x + 6 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 2 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 54\\right) \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right) + 3 \\cdot \\left(108 x + 18 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 6 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 2 \\cdot \\left(36 x + 6 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 2 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 54\\right) \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right) + 162\\right) \\left(27 x^{2} + 81 x + 9 \\left(x^{2} + 3 x + 4\\right)^{2} + 3 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 160\\right) + 
\\left(324 x + 54 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 6 \\cdot \\left(36 x + 6 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 2 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 54\\right) \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right) + 2 \\cdot \\left(108 x + 18 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 6 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 2 \\cdot \\left(36 x + 6 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 2 \\cdot \\left(12 x + 2 \\cdot \\left(4 x + 6\\right) \\left(x^{2} + 3 x + 4\\right) + 18\\right) \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right) + 54\\right) \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right) + 162\\right) \\left(27 x^{2} + 81 x + 9 \\left(x^{2} + 3 x + 4\\right)^{2} + 3 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 160\\right) + 486\\right) \\left(81 x^{2} + 243 x + 27 \\left(x^{2} + 3 x + 4\\right)^{2} + 9 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 3 \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 
52\\right)^{2} + \\left(27 x^{2} + 81 x + 9 \\left(x^{2} + 3 x + 4\\right)^{2} + 3 \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + \\left(9 x^{2} + 27 x + 3 \\left(x^{2} + 3 x + 4\\right)^{2} + \\left(3 x^{2} + 9 x + \\left(x^{2} + 3 x + 4\\right)^{2} + 16\\right)^{2} + 52\\right)^{2} + 160\\right)^{2} + 484\\right) + 729$" ], "text/plain": [ "486*x + 81*(4*x + 6)*(x**2 + 3*x + 4) + 27*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 9*(36*x + 6*(4*x + 6)*(x**2 + 3*x + 4) + 2*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 54)*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52) + 3*(108*x + 18*(4*x + 6)*(x**2 + 3*x + 4) + 6*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 2*(36*x + 6*(4*x + 6)*(x**2 + 3*x + 4) + 2*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 54)*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52) + 162)*(27*x**2 + 81*x + 9*(x**2 + 3*x + 4)**2 + 3*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + (9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 160) + (324*x + 54*(4*x + 6)*(x**2 + 3*x + 4) + 18*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 6*(36*x + 6*(4*x + 6)*(x**2 + 3*x + 4) + 2*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 54)*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52) + 2*(108*x + 18*(4*x + 6)*(x**2 + 3*x + 4) + 6*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 2*(36*x + 6*(4*x + 6)*(x**2 + 3*x + 4) + 2*(12*x + 2*(4*x + 6)*(x**2 + 3*x + 4) + 18)*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16) + 54)*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 
16)**2 + 52) + 162)*(27*x**2 + 81*x + 9*(x**2 + 3*x + 4)**2 + 3*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + (9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 160) + 486)*(81*x**2 + 243*x + 27*(x**2 + 3*x + 4)**2 + 9*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 3*(9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + (27*x**2 + 81*x + 9*(x**2 + 3*x + 4)**2 + 3*(3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + (9*x**2 + 27*x + 3*(x**2 + 3*x + 4)**2 + (3*x**2 + 9*x + (x**2 + 3*x + 4)**2 + 16)**2 + 52)**2 + 160)**2 + 484) + 729" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "quad_6_times.diff(symbolic_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This looks even worse!\n", "\n", "Also note that if we just blindly substitute some value for x,\n", "e.g. x=2, we would be computing a lot of the same terms\n", "many times. E.g. in the above expression $x^2+3x+4$ appears in a \n", "lot of places due to the \"structure\" of the original program\n", "\n", "If you knew the structure of the program you likely could precompute\n", "some of these repeating terms. 
However, once everything has been expanded, all\n", "this knowledge about the structure is gone!\n", "\n", "Modern CAS can recover some of this by finding \"common subexpressions\" (CSE)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([(x0, x**2),\n", " (x1, (3*x + x0 + 4)**2),\n", " (x2, (9*x + 3*x0 + x1 + 16)**2),\n", " (x3, (27*x + 9*x0 + 3*x1 + x2 + 52)**2),\n", " (x4, (81*x + 27*x0 + 9*x1 + 3*x2 + x3 + 160)**2)],\n", " [729*x + 243*x0 + 81*x1 + 27*x2 + 9*x3 + 3*x4 + (243*x + 81*x0 + 27*x1 + 9*x2 + 3*x3 + x4 + 484)**2 + 1456])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sympy.cse(quad_6_times)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But this is not fully automatic and may not find all relevant subexpressions. In any case it's trying hard to recover some\n", "of the structure that is already implicitly present in the program we want to differentiate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Control Flow**\n", "\n", "In addition to looping constructs like above, a lot of the functions we are interested in have \n", "control flow structures like if/else statements, while loops, etc.\n", "\n", "\n", "If we try to create a symbolic expression with conditionals we fail badly\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "cannot determine truth value of Relational", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
x\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m3\u001b[39m\n\u001b[0;32m----> 7\u001b[0m symbolic_result \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[43msymbolic_x\u001b[49m\u001b[43m)\u001b[49m\n", "Input \u001b[0;32mIn [12]\u001b[0m, in \u001b[0;36mfunc\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfunc\u001b[39m(x):\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m2\u001b[39m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", "File \u001b[0;32m~/Code/iml_tutorial/_venv/lib/python3.9/site-packages/sympy/core/relational.py:511\u001b[0m, in \u001b[0;36mRelational.__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 510\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__bool__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 511\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcannot determine truth value of Relational\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "\u001b[0;31mTypeError\u001b[0m: cannot determine truth value of Relational" ] } ], "source": [ "def func(x):\n", "    if x > 2:\n", "        return x**2\n", "    else:\n", "        return x**3\n", " \n", "symbolic_result = func(symbolic_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's too bad because this is a perfectly respectable function *almost everywhere*" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xi = np.linspace(-2,5,1001)\n", "yi = np.asarray([func(xx) for xx in xi])\n", "\n", "plt.plot(xi,yi)\n", "plt.scatter(xi,yi)\n", "plt.title(\"pretty smooth except at x=2\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we could afford finite differences, they would compute gradients *just fine*." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "g = np.gradient(yi,xi)\n", "plt.plot(xi,g)\n", "plt.scatter(xi,g)\n", "plt.ylim(-2,10)\n", "plt.title('''\\\n", "parabolesque gradient in x^3 region,\n", "linear in x^2 region as expected''');" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In short: symbolic differentiation is not our saving grace.\n", "\n", "* Pro: Gradients are exact, if you can compute them\n", "* Con: Need to implement in a CAS. A full-featured CAS is not easily available in all languages\n", "* Con: Leads to expression swell by losing the structure of the program (which needs to be recovered separately)\n", "* Con: Cannot easily handle common control-flow structures like loops and conditionals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## What we need\n", "\n", "To recap: \n", "\n", "Finite differences is\n", "* easy to implement in any language\n", "* handles arbitrary (halting) programs but\n", "* is inaccurate unless we're ready to pay a large computational overhead\n", "\n", "Symbolic differentiation is:\n", "* exact to machine precision\n", "* can lead to excessive / inefficient computation if not careful\n", "* cannot handle complex programs with control flow structures\n", "\n", "\n", "

So what we need is a third approach!

\n", "\n", "One that is \n", "* exact\n", "* efficient\n", "* able to handle arbitrary programs\n", "* easy to implement in many languages\n", "\n", "\n", "This third approach is 'Automatic' differentiation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Short Interlude on Linear Transformations\n", "\n", "Before we start, let's first look at *linear transformations* from ℝᵐ → ℝⁿ:\n", "$$y(x) = Ax$$\n", "\n", "With a given basis, this is representable as a (rectangular) matrix: \n", "$$y_i(x) = A_{ij}x_j$$\n", "\n", "\n", "For a given linear problem, there are a few ways we can run this computation\n", "\n", "\n", "1. **full matrix computation**\n", "\n", "   i.e. we store the full (dense) $nm$ elements of the rectangular matrix and \n", "   compute an explicit matrix multiplication.\n", " \n", "   The computation can be fully generic for any matrix\n", " \n", "```python\n", "def result(matrix, vector):\n", "    return np.matmul(matrix,vector)\n", "```\n", "
\n", "\n", "2. **sparse matrix computation**\n", "\n", "   If many $A_{ij}=0$, it might be wasteful to expend memory on them. We can just \n", "   create a sparse matrix, by\n", " \n", "   * storing only the non-zero elements \n", "   * storing a look-up table that records where those elements are in the matrix\n", " \n", "   The computation can be kept general\n", "\n", "```python\n", "def result(sparse_matrix, vector):\n", "    return sparse_matmul(sparse_matrix,vector)\n", "```\n", "\n", "
\n", " \n", "3. **matrix-free computation**\n", "\n", "   In many cases a linear program is not explicitly given by a matrix, but it's\n", "   given as *code* / a \"black-box\" function. As long as the computation in the body of the function \n", "   sticks to (hard-coded) linear transformations, the program is linear. The matrix elements\n", "   are no longer explicitly enumerated and stored in a data structure\n", "   but implicitly defined in the source code.\n", " \n", "   This is no longer a generic computation; each linear transformation is its own\n", "   program. At the same time this is also the most memory-efficient representation. No\n", "   lookup table is needed since all constants are hard-coded.\n", " \n", " \n", "```python\n", "def linear_program(vector):\n", "    # A_11, A_12, A_22 stand for constants hard-coded in the source\n", "    x1,x2 = vector\n", "    z1,z2 = 0,0\n", "    z1 += A_11*x1\n", "    z1 += A_12*x2\n", "    z2 += A_22*x2\n", "    return [z1,z2]\n", "```\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Recovering Matrix Elements from matrix-free computations\n", "\n", "\n", "#### Matrix-vector products\n", "\n", "In the matrix-free setting, the program does not give access to the matrix elements,\n", "but only computes \"matrix-vector\" products (MVP)\n", "\n", "We can use basis vectors to recover the matrix **one column at a time**\n", "\n" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "M derived from matrix-vector products:\n", "[[2 0 0]\n", " [0 1 3]]\n" ] } ], "source": [ "def matrix_vector_product(x):\n", "    x1,x2,x3 = x\n", "    z1,z2 = 0,0\n", "    z1 += 2*x1 #MVP statement 1\n", "    z2 += 1*x2 #MVP statement 2\n", "    z2 += 3*x3 #MVP statement 3\n", "    return np.asarray([z1,z2])\n", "\n", "M = np.concatenate([\n", "    matrix_vector_product(np.asarray([1,0,0])).reshape(-1,1),\n", "    matrix_vector_product(np.asarray([0,1,0])).reshape(-1,1),\n", "    matrix_vector_product(np.asarray([0,0,1])).reshape(-1,1),\n", "],axis=1)\n", 
"print(f'M derived from matrix-vector products:\\n{M}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Vector Matrix product (VMP)\n", "\n", "The same matrix induces a \"dual\" linear map: ℝⁿ → ℝᵐ \n", "$$ x_k = y_i A_{ik}$$\n", "\n", "i.e. instead of a Matrix-vector product it's now a *vector-Matrix* product (VMP)\n", "\n", "If one has access to a \"vector-Matrix\" program corresponding to a matrix $A$ one\n", "can again -- as in the MVP-case -- recover the matrix elements, by feeding in basis vectors.\n", "\n", "This time the matrix is built **one row at a time**\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "M derived from vector-matrix products:\n", "[[2 0 0]\n", " [0 1 3]]\n" ] } ], "source": [ "def vector_matrix_product(z):\n", "    x1,x2,x3 = 0,0,0\n", "    z1,z2 = z\n", "\n", "    x3 += z2*3 #VMP version of statement 3\n", "    x2 += z2*1 #VMP version of statement 2\n", "    x1 += z1*2 #VMP version of statement 1\n", "\n", "    return np.asarray([x1,x2,x3])\n", "\n", "\n", "M = np.concatenate([\n", "    vector_matrix_product(np.asarray([1,0])).reshape(1,-1),\n", "    vector_matrix_product(np.asarray([0,1])).reshape(1,-1),\n", "],axis=0)\n", "print(f'M derived from vector-matrix products:\\n{M}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Short Recap:\n", "\n", "For a given linear transformation, characterized by a matrix $A_{ij}$ we have a forward (matrix-vector) and backward (vector-matrix) map $$y_i = A_{ij}x_j$$ $$x_j = y_i A_{ij}$$\n", "\n", "and we can use either to recover the full matrix $A_{ij}$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wide versus Tall Transformation\n", "\n", "If you look at the code above, you'll notice that the number of calls to the MVP or VMP program needed\n", "is related to the dimensions of the matrix itself.\n", "\n", "For a 
$n\\times m$ matrix (for a map: ℝᵐ → ℝⁿ), you need $m$ calls to the \"Matrix-vector\" program to \n", "build the full matrix one-column-at-a-time. Likewise you need $n$ calls to the \"vector-Matrix\" program\n", "to build the matrix one-row-at-a-time.\n", "\n", "This becomes relevant for very asymmetric maps: e.g. for scalar maps from very high-dimensional spaces\n", "$\\mathbb{R}^{10000} \\to \\mathbb{R}$ the \"vector-Matrix\" approach is *vastly* more efficient than the\n", "\"Matrix-vector\" one. There's only one row, so only one call to the VMP program is needed to construct the full matrix!\n", "\n", "Similarly, for functions mapping few variables into very high-dimensional spaces $\\mathbb{R} \\to \\mathbb{R}^{10000}$\n", "it's the opposite: the \"Matrix-vector\" approach is much better suited than the \"vector-Matrix\" one (this time it's a single column!).\n", "\n", "\n", "## Function Compositions\n", "\n", "Of course compositions $(f\\circ g)(x) = f(g(x))$ of linear maps are also linear, so the above applies.\n", "\n", "Depending on whether the \"Matrix-vector\" or \"vector-Matrix\" approach is used, the data is propagated **forwards** or **backwards**." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### From Matrices to Graphs\n", "\n", "The \"vector-Matrix\" or \"Matrix-vector\" picture can be generalized to arbitrary directed acyclic graphs.\n", "\n", "* In the \"Matrix-vector\" picture the node value is the edge-weighted sum of the \"upstream nodes\".\n", "* In the \"vector-Matrix\" picture the node value is the edge-weighted sum of its \"downstream nodes\".\n", "\n", "(One could in principle always recover a rectangular/matrix-like version of a DAG by inserting trivial nodes.)\n" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def graph_like(x):\n", "    x1,x2,x3 = x\n", "    y1 = 2*x1+x2\n", "    z1,z2 = y1+2*x3,x3-y1 #note that we reach \"over\" the \"ys\" and directly touch x_n\n", "    return np.asarray([z1,z2])\n", "\n", "def matrix_like(x):\n", "    x1,x2,x3 = x\n", "    y1 = 2*x1+x2\n", "    y2 = x3 #can just introduce a dummy variable to make it matrix-like\n", "    z1,z2 = y1+2*y2,y2-y1\n", "    return np.asarray([z1,z2])\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "M derived from matrix like computation:\n", "[[ 2 1 2]\n", " [-2 -1 1]]\n", "M derived from graph-like products:\n", "[[ 2 1 2]\n", " [-2 -1 1]]\n" ] } ], "source": [ "M = np.concatenate([\n", "    matrix_like(np.asarray([1,0,0])).reshape(-1,1),\n", "    matrix_like(np.asarray([0,1,0])).reshape(-1,1),\n", "    matrix_like(np.asarray([0,0,1])).reshape(-1,1),\n", "],axis=1)\n", "print(f'M derived from matrix like computation:\\n{M}')\n", "\n", "\n", "M = np.concatenate([\n", "    graph_like(np.asarray([1,0,0])).reshape(-1,1),\n", "    graph_like(np.asarray([0,1,0])).reshape(-1,1),\n", "    graph_like(np.asarray([0,0,1])).reshape(-1,1),\n", "],axis=1)\n", "print(f'M derived from graph-like products:\\n{M}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 
Derivatives\n", "\n", "\n", "Why are we talking about linear transformations? After all, a lot of the code we write is non-linear! However, derivatives are always linear.\n", "\n", "And the derivative (the Jacobian) of a composition $f\\circ g$ is the composition of the linear derivatives (the Jacobians\n", "of each map), i.e. the full Jacobian matrix is the result of multiplying all Jacobians of the composition.\n", "$$J = J_0 J_1 J_2 J_3 \\dots J_n $$\n", "\n", "(This is just the chain rule)\n", "$$z = f(y) = f(g(x))\\hspace{1cm} \\frac{\\partial z_i}{\\partial x_j} = \\sum_k \\frac{\\partial f_i}{\\partial y_k}\\frac{\\partial y_k}{\\partial x_j}$$ \n", "\n", "\n", "I.e. finding derivatives means characterizing the Jacobian matrix. From the above discussion, we can use the \"Jacobian-vector product\" (JVP, builds Jacobians column-wise) or the \"vector-Jacobian product\" (VJP, builds Jacobians row-wise) approach.\n", "\n", "In the language of automatic differentiation \n", "\n", "* Jacobian-vector products (JVP) = forward mode (forward propagation)\n", "\n", "$$ Jv_n = J_0 J_1 J_2 J_3 \\dots J_n v_n = \\dots = J_0 J_1 J_2 J_3 v_3 = J_0 J_1 J_2 v_2 = J_0 J_1 v_1 = J_0 v_0 = \\text{col}$$\n", "\n", "* vector-Jacobian products (VJP) = reverse mode (reverse propagation)\n", "\n", "$$ v_0 J = v_0 J_0 J_1 J_2 J_3 \\dots J_n = v_1 J_1 J_2 J_3 \\dots J_n = v_2 J_2 J_3 \\dots J_n = v_3 J_3 \\dots J_n = \\dots = v_n J_n = \\text{row}$$\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example\n", "\n", "Let's work this out on a very simple problem\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the forward pass we use \"Matrix-vector\" products and need to do two evaluations\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the backward pass we use \"vector-Matrix\" products and need to do only a single evaluation\n" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "Both approaches give the same result. Since this is a map from $\\mathbb{R}^2 \\to \\mathbb{R}^1$ the backward pass is more efficient than the forward pass.\n", "\n", "\n", "Let's look at a real-life example\n", "\n", "$$z(x_1,x_2) = y + x_2 = x_1x_2 + x_2$$\n", "\n", "This is easy to write in Python" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "12\n" ] } ], "source": [ "def mul_func(x1,x2):\n", "    return x1*x2\n", "\n", "def sum_func(x1,x2):\n", "    return x1+x2\n", "\n", "def function(x):\n", "    x1,x2 = x\n", "    y = mul_func(x1,x2)\n", "    z = sum_func(y,x2)\n", "    return z\n", "\n", "print(function([2,4]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the forward pass, an autodiff system needs to create a JVP implementation for each elementary operation " ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def mul_jvp(x1,dx1,x2,dx2):\n", "    y = mul_func(x1,x2)\n", "    dy = x1*dx2 + x2*dx1\n", "    return y, dy\n", "\n", "def sum_jvp(x1,dx1,x2,dx2):\n", "    return sum_func(x1,x2), dx1 + dx2\n", "\n", "def function_jvp(x,dx):\n", "    x1,x2 = x\n", "    dx1,dx2 = dx\n", "    y, dy = mul_jvp(x1,dx1,x2,dx2)\n", "    z, dz = sum_jvp(y,dy, x2, dx2)\n", "    return z,dz" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since in the forward pass we build \"column-at-a-time\" and our final Jacobian has shape (1x2), i.e. two columns, we need two forward passes to get the full Jacobian. Note that for each forward pass we also get the fully computed function value delivered on top!\n", "\n", "\n", "Also note that the \"JVP\" version of the function has the same *structure* as the original function. For each call in the original program there is an equivalent call in the JVP program. However the JVP call always does two things at once\n", "\n", "1. compute the nominal result\n", "2. 
compute the differentials\n", "\n", "So it has roughly 2x the run-time of the original program (depending on the complexity of the derivatives). Said another way: computing one pass of the derivative has the same computational complexity as the function itself." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(12, 4)\n", "(12, 3)\n" ] } ], "source": [ "print(function_jvp([2,4],[1,0]))\n", "print(function_jvp([2,4],[0,1]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the backward pass we build \"row-at-a-time\". For each elementary operation we need to build a VJP implementation" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def mul_vjp(x1,x2,dx1,dx2,dout):\n", "    dx2 += dout * x1\n", "    dx1 += dout * x2\n", "    return dx1,dx2\n", "\n", "def sum_vjp(x1,x2,dx1,dx2,dout):\n", "    dx1 += dout * 1\n", "    dx2 += dout * 1\n", "    return dx1,dx2\n", "\n", "def function_vjp(x,dz):\n", "\n", "    #run forward\n", "    x1,x2 = x\n", "    y = mul_func(x1,x2)\n", "    z = sum_func(y,x2)\n", "\n", "    #zero gradients\n", "    dy = 0\n", "    dx1 = 0\n", "    dx2 = 0\n", "\n", "    #run backward\n", "    dy,dx2 = sum_vjp(y,x2, dy, dx2, dz) #z = y + x2, so dz flows into dy and dx2\n", "    dx1,dx2 = mul_vjp(x1,x2, dx1, dx2, dy)\n", "    return z,[dx1,dx2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, we see the power of backward propagation (or the reverse mode): we get all gradients of the single row in one go. Since this Jacobian only has one row, we're done! And we get the function value delivered on top of the gradients as well!" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(12, [4.0, 3.0])\n" ] } ], "source": [ "print(function_vjp([2,4],1.0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, let's look at the \"VJP\" code. 
The forward pass is *exactly* the same as the original function. This just records the final result and all intermediate values, which we will need for the backward pass.\n", "\n", "Moving on to the backward pass, we see again that, as in the JVP case, it has the same *structure* as the forward pass, only traversed in reverse order. For each call to a subroutine there is an equivalent call in the backward pass to compute the VJP. \n", "\n", "\n", "As in the JVP case, the computational complexity of one backward pass is roughly the same as the forward pass. But unlike the JVP case we only needed a single pass for **all the gradients** of this scalar function. So obtaining the **full gradient** of a scalar function is only roughly as expensive as the function itself." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recap:\n", "\n", "Above we have built a *manual* autodiff system. Let's recap what we needed to do\n", "\n", "* define a set of operations we want to be differentiable\n", "* define sub-routines for nominal operations, JVP and VJP\n", "\n", "\n", "\n", "Once given a program, we had to do the following\n", "\n", "**In the forward mode**:\n", "\n", "* just replace the nominal function with the JVP one\n", "* for each variable in the program allocate a \"differential\" variable and pass it\n", "  into the JVP wherever we also pass the nominal variable\n", "\n", "\n", "**In the backward mode**:\n", "\n", "* Run the program forward, keep track of all values\n", "* keep track of the order of operations on a \"record\" of sorts\n", "* allocate \"differential\" variables for all values and initialize to zero\n", "* use the record to replay the order of operations backwards, passing along the \n", "  appropriate differential values, and updating the relevant ones with the result\n", "  of the VJP\n", "\n", "\n", "All of this is pretty mechanistic and hence \"automatable\". 
And given that it's a very narrow\n", "domain of only implementing JVP/VJP operations this is easy to do in any language.\n", "\n", "That's why it's **automatic differentiation**\n", "\n", "\n", "What we gain from this is that we get\n", "\n", "* exact derivatives (to machine precision) for arbitrary programs composed of the operations we define\n", "* complexity of a derivative-pass through the program is of the same order as the complexity of the original program\n", "* often only a single pass is necessary (e.g. scalar multi-variate functions)\n", "* unlike symbolic differentiation, the structure of the program is preserved, which naturally avoids\n", "  repetitive calculations of the same values\n", "* (we will see that) arbitrary control flow is handled naturally\n", "* it's something that is easy for a computer to do and for a programmer to implement\n", "\n", "\n", "\n", "Some notes on pros and cons:\n", "\n", "**In the forward mode**:\n", "\n", "* the signature of each operation basically extends \n", "  ```c++\n", "  float f(float x,float y,float z)\n", "  ```\n", "  to\n", "  ```c++\n", "  pair<float,float> f(float x,float dx,float y,float dy,float z,float dz)\n", "  ```\n", "  * if you use composite types (\"dual numbers\") that hold both x,dx you can basically \n", "    keep the signature unchanged\n", "    ```c++\n", "    dual f(dual x, dual y, dual z)\n", "    ```\n", "  * together with operator overloading on these dual types e.g. `dual * dual` you can \n", "    essentially keep the source code unchanged\n", "    ```c++\n", "    float f(float x, float y) { return x*y; }\n", "    ```\n", "    ->\n", "    ```c++\n", "    dual f(dual x, dual y) { return x*y; }\n", "    ```\n", "\n", "* That means it's very easy to implement. 
It is also memory efficient: no superfluous values are kept once they go out of scope.\n", "* But forward mode is better suited to vector-valued functions of few parameters\n", "\n", "\n", "**In the reverse mode**:\n", "\n", "* very efficient for scalar functions of many parameters, but we need to keep track of the order of operations (need a \"tape\" of sorts)\n", "* since we need to access all intermediate variables, we can run into memory bounds\n", "* the procedure is a bit more complex than fwd: 1) run fwd, 2) zero grads, 3) run bwd\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## I don't want to implement an autodiff system.. Aren't there libraries for this??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Yes there are! And a lot of them in many languages. On the other hand, try finding CAS systems in each of those \n", "\n", "This is PyHEP, so let's focus on Python. Here, basically what you think of as \"Machine Learning frameworks\" are at their core autodiff libraries\n", "\n", "* TensorFlow\n", "* PyTorch\n", "* JAX" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's focus on JAX" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", "    return x**2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`jax.numpy` is almost a drop-in replacement for `numpy`. I do `import jax.numpy as jnp` but if you're daring you could do `import jax.numpy as np`" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] } ], "source": [ "x = jnp.array([1,2,3])\n", "y = jnp.array([2,3,4])" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[3 5 7]\n", "[ 2 6 12]\n", "[0. 0.69314718 1.09861229]\n", "[ 7.3890561 20.08553692 54.59815003]\n" ] } ], "source": [ "print(x+y)\n", "print(x*y)\n", "print(jnp.log(x))\n", "print(jnp.exp(y))" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", " return x**3" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "64.0\n", "48.0\n", "24.0\n", "6.0\n", "0.0\n" ] } ], "source": [ "print(f(4.0))\n", "print(jax.grad(f)(4.0)) #boom!\n", "print(jax.grad(jax.grad(f))(4.0)) #boom!\n", "print(jax.grad(jax.grad(jax.grad(f)))(4.0)) #boom!\n", "print(jax.grad(jax.grad(jax.grad(jax.grad(f))))(4.0)) #boom!" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "xi = jnp.linspace(-5,5)\n", "yi = f(xi)\n", "\n", "plt.plot(xi,yi)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "Gradient only defined for scalar-output functions. Output had shape: (50,).", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "Input \u001b[0;32mIn [31]\u001b[0m, in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgrad\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mxi\u001b[49m\u001b[43m)\u001b[49m\n", " \u001b[0;31m[... skipping hidden 4 frame]\u001b[0m\n", "File \u001b[0;32m~/Code/iml_tutorial/_venv/lib/python3.9/site-packages/jax/_src/api.py:1019\u001b[0m, in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1017\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(aval, ShapedArray):\n\u001b[1;32m 1018\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m aval\u001b[38;5;241m.\u001b[39mshape \u001b[38;5;241m!=\u001b[39m ():\n\u001b[0;32m-> 1019\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad shape: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;241m.\u001b[39mshape\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m 1020\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1021\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(msg(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhad abstract value 
\u001b[39m\u001b[38;5;132;01m{\u001b[39;00maval\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m))\n", "\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (50,)." ] } ], "source": [ "jax.grad(f)(xi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Whoops, jax.grad defaults to reverse mode with a single backward pass, but through broadcasting we get a `vector -> vector` map. We can use some jax magic to \"unbroadcast\" the function, take the gradient and re-broadcast it" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([7.50000000e+01, 6.90024990e+01, 6.32548938e+01,\n", " 5.77571845e+01, 5.25093711e+01, 4.75114536e+01,\n", " 4.27634319e+01, 3.82653061e+01, 3.40170762e+01,\n", " 3.00187422e+01, 2.62703040e+01, 2.27717618e+01,\n", " 1.95231154e+01, 1.65243648e+01, 1.37755102e+01,\n", " 1.12765514e+01, 9.02748855e+00, 7.02832153e+00,\n", " 5.27905040e+00, 3.77967514e+00, 2.53019575e+00,\n", " 1.53061224e+00, 7.80924615e-01, 2.81132861e-01,\n", " 3.12369846e-02, 3.12369846e-02, 2.81132861e-01,\n", " 7.80924615e-01, 1.53061224e+00, 2.53019575e+00,\n", " 3.77967514e+00, 5.27905040e+00, 7.02832153e+00,\n", " 9.02748855e+00, 1.12765514e+01, 1.37755102e+01,\n", " 1.65243648e+01, 1.95231154e+01, 2.27717618e+01,\n", " 2.62703040e+01, 3.00187422e+01, 3.40170762e+01,\n", " 3.82653061e+01, 4.27634319e+01, 4.75114536e+01,\n", " 5.25093711e+01, 5.77571845e+01, 6.32548938e+01,\n", " 6.90024990e+01, 7.50000000e+01], dtype=float64)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.vmap(jax.grad(f))(xi)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "that looks better!\n", "\n", "`jax.grad(f)` just returns another function. Of course we can just \n", "take the gradient of that as well. And so on..." 
] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "g1i = jax.vmap(jax.grad(f))(xi)\n", "g2i = jax.vmap(jax.grad(jax.grad(f)))(xi)\n", "g3i = jax.vmap(jax.grad(jax.grad(jax.grad(f))))(xi)\n", "plt.plot(xi,yi, label = \"f\")\n", "plt.plot(xi,g1i, label = \"f'\")\n", "plt.plot(xi,g2i, label = \"f''\")\n", "plt.plot(xi,g3i, label = \"f'''\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Control Flow\n", "\n", "Back when discussing symbolic differentiation we hit a snag when adding \n", "control flow to our program. In JAX this just passes through\n", "transparently. \n", "\n", "\n", "Let's compare this to finite differences, so far the only approach\n", "we had for computing derivatives of programs with control flow" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "def control_flow_func(x):\n", "    if x > 2:\n", "        return x**2\n", "    else:\n", "        return x**3\n", "\n", "\n", "first_gradient_of_cflow = jax.grad(control_flow_func)\n", "\n", "xi = jnp.linspace(-2,5,101)\n", "yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])\n", "plt.plot(xi,yi,c = 'k')\n", "\n", "xi = jnp.linspace(-2,5,11)\n", "yi = np.asarray([first_gradient_of_cflow(xx) for xx in xi])\n", "plt.scatter(xi,yi, label = 'jax autodiff')\n", "\n", "\n", "\n", "xi = jnp.linspace(-2,5,11)\n", "yi = np.asarray([control_flow_func(xx) for xx in xi])\n", "plt.scatter(xi,np.gradient(yi,xi), label = 'finite differences')\n", "\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can start to see the benefits of autodiff. Among other things, finite differences become\n", "quite sensitive to exactly where the evaluation points are (e.g. with respect to the discontinuity)\n", "\n", "\n", "As we compute higher derivatives, this error compounds badly for finite differences. But for\n", "autodiff, it's smooth sailing!" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "second_gradient_of_cflow = jax.grad(first_gradient_of_cflow)\n", "xi = jnp.linspace(-2,5,101)\n", "yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])\n", "plt.plot(xi,yi,c = 'k')\n", "\n", "xi = jnp.linspace(-2,5,11)\n", "yi = np.asarray([second_gradient_of_cflow(xx) for xx in xi])\n", "plt.scatter(xi,yi, label = '2nd deriv jax autodiff')\n", "\n", "xi = jnp.linspace(-2,5,11)\n", "yi = np.asarray([control_flow_func(xx) for xx in xi])\n", "#pass the grid spacing to both np.gradient calls so the second derivative is scaled correctly\n", "plt.scatter(xi,np.gradient(np.gradient(yi,xi),xi), label = '2nd deriv finite differences')\n", "\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## In HEP\n", "\n", "\n", "Of course we can use automatic differentiation\n", "for neural networks. But other things in HEP also \n", "can make use of gradients. A prime example where this is the \n", "case is statistical analysis\n", "\n", "For a maximum likelihood fit we want to minimize the negative log-likelihood\n", "\n", "$\\theta^* = \\mathrm{argmin}_\\theta(-\\log L)$ " ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np\n", "import pyhf\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "pyhf.set_backend('jax')" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([-4.25748227], dtype=float64)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = pyhf.simplemodels.uncorrelated_background([5.],[10.],[3.5])\n", "pars = jnp.array(m.config.suggested_init())\n", "data = jnp.array([15.] 
+ m.config.auxdata)\n", "m.logpdf(pars,data)" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([1., 1.], dtype=float64)" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bestfit = pyhf.infer.mle.fit(data,m)\n", "bestfit" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "grid = x,y = np.mgrid[.5:1.5:101j,.5:1.5:101j]\n", "\n", "points = np.swapaxes(grid,0,-1).reshape(-1,2)\n", "v = jax.vmap(m.logpdf, in_axes = (0,None))(points,data)\n", "v = np.swapaxes(v.reshape(101,101),0,-1)\n", "plt.contourf(x,y,v, levels = 100)\n", "plt.contour(x,y,v, levels = 20, colors = 'w')\n", "\n", "\n", "\n", "grid = x,y = np.mgrid[.5:1.5:11j,.5:1.5:11j]\n", "points = np.swapaxes(grid,0,-1).reshape(-1,2)\n", "values, gradients = jax.vmap(\n", "    jax.value_and_grad(\n", "        lambda p,d: m.logpdf(p,d)[0]\n", "    ), in_axes = (0,None)\n", ")(points,data)\n", "\n", "plt.quiver(\n", "    points[:,0],\n", "    points[:,1],\n", "    gradients[:,0],\n", "    gradients[:,1],\n", "    angles = 'xy',\n", "    scale = 75\n", ")\n", "plt.scatter(bestfit[0],bestfit[1], c = 'r')\n", "\n", "plt.xlim(0.5,1.5)\n", "plt.ylim(0.5,1.5)\n", "plt.gcf().set_size_inches(5,5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Thanks for joining the Tutorial!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.11" } }, "nbformat": 4, "nbformat_minor": 4 }