{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Automatic differentiation with JAX\n", "\n", "See the slides for a general introduction to automatic differentiation.\n", "\n", "## First steps" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX version 0.1.72\n", "numba version 0.51.1\n" ] } ], "source": [ "# !pip install jax jaxlib matplotlib numpy iminuit numba-stats\n", "\n", "import jax\n", "from jax import numpy as jnp # replacement for normal numpy\n", "import numpy as np # original numpy still needed, since jax does not cover full API\n", "import numba as nb # will use that later\n", "from matplotlib import pyplot as plt\n", "\n", "jax.config.update(\"jax_enable_x64\", True) # enable float64 precision, default is float32\n", "\n", "print(f\"JAX version {jax.__version__}\")\n", "print(f\"numba version {nb.__version__}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's try something simple." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "def f(x):\n", " return 2 + 3 * x ** 2\n", "\n", "fprime = jax.grad(f)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As with the result of a symbolic computation, `fprime` is a genuine function: you can call it with different values." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "5.0" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f(1.0)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.8/site-packages/jax/lib/xla_bridge.py:125: UserWarning: No GPU/TPU found, falling back to CPU.\n", " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" ] }, { "data": { "text/plain": [ "DeviceArray(6., dtype=float64)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fprime(1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX is designed to run on GPUs and TPUs (it falls back to the CPU if none is found, as the warning above shows), so its results are `DeviceArray`s rather than plain Python floats or NumPy arrays." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize this." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x=-2.0 y=14.0 m=-12.0\n", "x=-1.0 y= 5.0 m= -6.0\n", "x= 0.0 y= 2.0 m= 0.0\n", "x= 1.0 y= 5.0 m= 6.0\n", "x= 2.0 y=14.0 m= 12.0\n" ] }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXAAAAD6CAYAAAC4RRw1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAs4UlEQVR4nO3deVxVdf7H8deHHfcNcRdTR9NcKjTXdrWcyiybNhutnNLJzElbpprUyWwzlzZL07Q0NUszy0zNfQFBUXPfkE1ERFAQQeB+f39A83McU8QL37t8no/HfcA9XLjvk/b2yznfc75ijEEppZT78bEdQCmlVMlogSullJvSAldKKTelBa6UUm5KC1wppdyUFrhSSrmpSxa4iASJyCYR2SYiO0VkVNH2RiISKSIHRGSuiASUflyllFK/k0vNAxcRAcobY7JExB9YBzwHPA/MN8bMEZFPgW3GmEkX+1k1atQwYWFhzkmulFJeYvPmzceNMSHnb/e71DeawobPKnrqX/QwwK3AI0XbZwAjgYsWeFhYGNHR0cVPrZRSChGJu9D2Yh0DFxFfEdkKHAOWAQeBDGNMftFLEoG6TsiplFKqmIpV4MaYAmNMW6Ae0B5oXtw3EJGnRCRaRKJTU1NLllIppdT/uKxZKMaYDGAl0BGoIiK/H4KpByT9wfdMNsaEG2PCQ0L+5xCOUkqpEirOLJQQEalS9Hkw0A3YTWGR9yl6WT9gYSllVEopdQGXPIkJ1AZmiIgvhYX/jTHmRxHZBcwRkdFADDC1FHMqpZQ6T3FmoWwHrr3A9kMUHg9XSillgV6JqZRSbsotCnzt/lQ+XnnAdgyllLps2WfzGbVoJ3Fpp53+s92iwNftP864Zfs4lpljO4pSSl2WH7cn88X6wxzLzHX6z3aLAv9Lu/oUOAzfbk60HUUppS7L3KgEGoeUJ7xhVaf/bLco8MYhFWjfqBpzoxJwOHQNT6WUe9iXksnmuHQeateAwttKOZdbFDjAQ+3qE5eWTURsmu0oSilVLHOjEvD3Fe67rnTuNOI2Bd6zVW0qBfkxZ1OC7ShKKXVJufkFzN+SSPcWtaheIbBU3sNtCjzI35fe19ZlyY6jpJ8+azuOUkpd1C87U0jPzuOh9vVL7T3cpsABHmzXgLMFDhbEXPC2K0op5TLmRsVTr2ownRvXKLX3cKsCb1GnEm3qVWZuVAKXWohCKaVsiUs7zfoDaTwYXh8fH+efvPydWxU4FI7C96ZkEpOQYTuKUkpd0DfRCfgI9AmvV6rv43YFfk/bOpQL8GWunsxUSrmg/AIH86ITuaVZTWpXDi7V93K7Aq8Q6MfdreuwaPsRsnLzL/0NSilVhlbsOcaxzFweat+g1N/L7Qoc4MH29ck+W8CibUdsR1FKqf8yNyqBmhUDuaVZ6S9g45YFfm39KjQLrcicTfG2oyil1H8knzzDyr3HeCC8Hn6+pV+vblngIsKD7eqzLfEku46csh1HKaUA+DY6EYeBv4SX3tzvc7llgQPcd11dAvx8mBulo3CllH0Oh2FudAKdm1SnYfXyZfKeblvgVcoFcOc1tVgQk0ROXoHtOEopL7fuwHES08/wULvSP3n5O7ctcICH2jXgVE6+nsxUSlk3KzKOauUD6N4ytMze060LvMNV1WhSswIzI/UwilLKnqMnc1i+u/DkZaCfb5m9r1sXuIjw6A0N2JaQwY6kk7bjKKW81OxN8TiM4dH2Dcv0fd26wAHuu64ewf6+zIqMsx1FKeWF8gsczImK58amITSoXq5M39vtC7xysD/3tKnD9zFHOJWTZzuOUsrLLN99jJRTufTtULajb/CAAgfo26EhZ/IKWLBFbzOrlCpbsyLjqFM5iFub1yzz9/aIAm9VrzKt61VmZkSc3mZWKVVmYo+fZu3+4zzUvgG+pXjb2D/iEQUO0PeGhuw/lsWm2BO2oyilvMTXkXH4+QgPtSubKy/P5zEFfnebOlQK8tMphUqpMpGTV8C8zYl0bxl
KzUpBVjJ4TIEHB/hy//X1WLIjmdTMXNtxlFIe7qftyWRk59H3hrI/efm7Sxa4iNQXkZUisktEdorIc0XbR4pIkohsLXr0LP24F/foDQ3JKzB8E62LPSilStfMyDiuCilPx8bVrWUozgg8HxhmjGkBdACeEZEWRV8bb4xpW/RYXGopi6lJzQp0vKo6X0fGU+DQk5lKqdKx88hJYuIzePSGhoiU/cnL312ywI0xycaYLUWfZwK7gbqlHayk+nZoSFLGGVbvO2Y7ilLKQ82MiCfI34c+15XumpeXclnHwEUkDLgWiCzaNFhEtovINBGp6uxwJdG9ZSghFQOZGaEnM5VSzpeZk8fCrUnc3boOlcv5W81S7AIXkQrAd8BQY8wpYBLQGGgLJAPv/8H3PSUi0SISnZqaeuWJL8Hf14eH2tVn5d5jJJzILvX3U0p5lwUxSWSfLbBy5eX5ilXgIuJPYXnPMsbMBzDGpBhjCowxDmAK0P5C32uMmWyMCTfGhIeElP4acQAPt2+AjwizdEqhUsqJjDF8uTGO1vUq06Z+FdtxijULRYCpwG5jzLhzttc+52W9gR3Oj1cydaoE071FKHOi4jlzVhd7UEo5x/oDaRw4lkW/jmG2owDFG4F3Bh4Dbj1vyuC7IvKbiGwHbgH+UZpBL1f/TmFkZBceq1JKKWeYviGWGhUCuKtN7Uu/uAz4XeoFxph1wIXmyVifNngx7RtV4+ralZi+4TAPtqtvdaqPUsr9xadl8+ueYzx7S5MyXbThYjzmSszziQiPdwpjz9FMIvX+KEqpK/TlxsP4ivCoC5y8/J3HFjjAPW3rULWcP9PXH7YdRSnlxk7n5jM3OoE7W9Um1NJ9Ty7Eows8yN+Xh9o3YOmuoySm65RCpVTJzI9JIjMnn/6dwmxH+S8eXeBQeGWmiPBVhC65ppS6fMYYZmw4TKu6lbmuQRXbcf6Lxxd43SrB9GgZytyoBJ1SqJS6bL9PHezfKczlJkN4fIED9O/USKcUKqVKxNWmDp7LKwq8XVjV/0wp1CXXlFLFFZd2ml/3HOOR9g1cZurgubyiwM+dUhhxSKcUKqWK58uNcS43dfBcXlHgcM6Uwg2xtqMopdzA6dx8volyvamD5/KaAg/y9+Xh9g1YtitFpxQqpS5pfkwSmbmuN3XwXF5T4PD/Uwq/3KhTCpVSf8zhMExfH+uSUwfP5VUFXqdKMHdcU4vZm+LJys23HUcp5aJW7TvGwdTTPNHF9aYOnsurChzgb12vIjMnn7lRuvCxUurCpqyJpValIO5qXcd2lIvyugJvW78K7cKqMm1dLPkFDttxlFIuZkfSSTYeSqN/5zD8fV27Il07XSkZ0PUqkjLOsGTnUdtRlFIu5vO1hygfUDjpwdV5ZYHffnUoYdXLMWVtrF7Yo5T6j+STZ/hxezJ/aVefysF2FywuDq8scF8f4ckujdiWkEF0XLrtOEopFzF9/WEcxvBE50a2oxSLVxY4QJ/r61OlnD9T1hyyHUUp5QKycvP5elM8d15Tm/rVytmOUyxeW+DBAb70vaEhy3anEHv8tO04SinL5kYlkJmTz4Cu7jH6Bi8ucIC/dmqIv48P09bp5fVKebP8AgfT1sUS3rAq1zaoajtOsXl1gdesGESvtnWYtzmB9NNnbcdRSlmyZOdRkjLOMKDrVbajXBavLnAonFKYk+dgVqReXq+UNzLGMGVtLA2rl6Nbi1DbcS6L1xd4s1oVufFPIczYGEduvq7Yo5S3iY5LZ1tCBk92aYSvj+teNn8hXl/gAH/r2ojUzFwWbj1iO4pSqoxNWXOIysH+9Lm+nu0ol00LHOjSpAbNa1VkyppDOBx6YY9S3uJQahbLdqfQt0MDygX42Y5z2bTAKVyx5+mbrmL/sSx+3XPMdhylVBn5bPUhAnx96N/JfaYOnksLvMjdretQr2own6w6oJfXK+UFkk+eYX5MIn8Jr09IxUDbcUrkkgUuIvVFZKWI7BKRnSLyXNH2aiK
yTET2F310n8mTF+Dn68PTN15FTHwGkbG6bqZSnm7q2lgcBp660b2mDp6rOCPwfGCYMaYF0AF4RkRaAC8DvxpjmgK/Fj13aw+E16dGhQA+WXXQdhSlVClKP32WrzfFc0+bOm5z2fyFXLLAjTHJxpgtRZ9nAruBukAvYEbRy2YA95ZSxjIT5O/L450bsWZfKjuSTtqOo5QqJTM2Hib7bAEDb2psO8oVuaxj4CISBlwLRAKhxpjkoi8dBdxrBvwfeKxjQyoG+jFJR+FKeaTss/lM33CY26+uSbNaFW3HuSLFLnARqQB8Bww1xpw692um8KzfBc/8ichTIhItItGpqalXFLYsVAryp2/Hhizekaw3uVLKA83elEBGdh6Dbm5iO8oVK1aBi4g/heU9yxgzv2hziojULvp6beCC8++MMZONMeHGmPCQkBBnZC51T3RuhL+vD5+t1lG4Up7kbL6Dz9ce4oZG1bi+oVvPuwCKNwtFgKnAbmPMuHO+9APQr+jzfsBC58ezI6RiIH8Jr8d3WxI5ejLHdhyllJN8vzWJ5JM5DLrZvY99/644I/DOwGPArSKytejRE3gb6CYi+4Hbi557jKdvbIzDFK6Pp5RyfwUOw6erD9KyTiVu+pN7HA24lEteO2qMWQf80R1ebnNuHNdRv1o57m5dm683xTP41iZUKRdgO5JS6gos3XmUQ6mn+eiRayk8sOD+9ErMixh4c2OyzxYwY4PealYpd2aMYdLqg4RVL8ed19S2HcdptMAvonmtStzWvCZfbIglKzffdhylVAmtO3Cc7Yknefqmxm53y9iL0QK/hIE3hpGRncdXG3UUrpQ7MsYwcfl+alXw577r6tqO41Ra4Bexc+dOHu7ekeaSxJS1hzito3Cl3M6Gg2lEx6VTY+8C7ut1D3l5ebYjOY0W+EWEhYWRk5ND3I+fkJaVw8wIHYUr5U5+H31XyUtj2TfTCAkJwd/f33Ysp9ECv4jy5cvz7rvvsmfHNuoci2TymkNkn9VRuFLuYuOhNDYdPoFP1EwCAgJ46623bEdyKi3wS3j44Yfp1KkT+xZNJvVEBrMi4m1HUkoV08Tl+wk6tpOYtct49dVXqV3bc2aggBb4JYkIEyZM4ERaKhV3/8Bnaw5y5qwufqyUq9t4MI2Ig6lkrZpKo0aNGDp0qO1ITqcFXgzt2rWjf//+7Fsxl+T4WGZF6rFwpVzdxF/3we5lJMXu4/333ycoKMh2JKfTAi+mMWPGEBQYiGyayWdrDpGTp6NwpVxV5KE01u88zPHVX3Hrrbdy77332o5UKrTAi6l27dq89tprJG1bS/z2jXwdqcfClXJVE3/dz9lNc8k5ncmECRM85tL582mBX4ahQ4fSuHFjctZ+waQV+3QUrpQLijp8glURWzge9SMDBw6kVatWtiOVGi3wyxAYGMjYsWM5lRzLwbULmLNJR+FKuZoJy/aRtXoqlSpWZNSoUbbjlCot8MvUq1cvbrvtNrI2fM0HP8foKFwpF7I57gTLliwm8+AWRo0aRY0aNWxHKlVa4Jfp92mFBTmnOfDzND0WrpQLefenHZxaNZXmza9m0KBBtuOUOi3wErjmmmsYNGgQWTE/M3bOMr1ToVIuYN3+4yyd9wW5J44wYcJ4j7pk/o9ogZfQqFGjqFipErE/fsw0XbVHKauMMbwxbz2ZG7/hz3fdRY8ePWxHKhNa4CVUvXp13hz9Bjlx23l/yiwyss/ajqSU11q6K4X1cz6GgjzGjxt36W/wEFrgV2DgwIE0+VNzjiydzEfL99iOo5RXKnAYXp+6kNM7lvPckCE0bdrUdqQyowV+Bfz8/Pjkow/IzzjKxIkTOHZKV7BXqqwt3JrItm8mUqVqdV5//V+245QpLfAr1K1bN26/48+krZvDW99ttB1HKa9yNt/BK2M/IzdpN+++/RaVK1e2HalMaYE7waQPJyCOPKZOGEPCiWzbcZTyGl+t28eBHz+jaYvWPPHE47bjlDktcCdo0qQ
Jg54ZQub25bz86QLbcZTyCjl5Bbz+7zEUZB5n6qcf4evraztSmdMCd5Ix/x5BhSrV+e6T0ew7esp2HKU83vgFG0heO5fb7+pN165dbcexQgvcSSpVqsSYMWPITdrDoFEf2o6jlEc7lZPHu2/8Cx8Rpn48wXYca7TAneiZpwdQr2lLVs8cT8TeJNtxlPJYr3wyj4wdq/nb4KE0aNDAdhxrtMCdyMfHh2mffkxBVhpPDvsXxhjbkZTyOMnpp/n8vdcpXy2U90e/bjuOVVrgTtbt1pvo0O0edv8yi1nLo23HUcrjPPHqe+QePchbb79NuXLlbMex6pIFLiLTROSYiOw4Z9tIEUkSka1Fj56lG9O9zJr8ASLC88NfIK/AYTuOUh4jal8CS2dMpMHV1zJ4QD/bcawrzgh8OnDHBbaPN8a0LXosdm4s93ZVWEMefepZUrevZuTkb23HUcpj9B/yTxzZJ5kx5ROPXSbtclyywI0xa4ATZZDFo0x6ZyTlqoYy/t+vkHFaL7FX6kp9vTSSXcvm0rlnH27u3MF2HJdwJcfAB4vI9qJDLFWdlshDlC9fnhGjx3Dm6EEGvPqu7ThKubUCh2Ho88/j6x/ArM/G247jMkpa4JOAxkBbIBl4/49eKCJPiUi0iESnpqaW8O3c0wuDHqdOs7Z8P2Ucu+OSbcdRym396+NZpO7cwF///jwN69W1HcdlSHGmuolIGPCjMeaay/na+cLDw010tHfNzPh51Xp63tKV6/78KJt//Mp2HKXczsnTZ6jVqDm+ODgev5+goCDbkcqciGw2xoSfv71EI3ARqX3O097Ajj96rbe78+bOtOtxP1t+ns2iNd71j5dSzvD4i2PISY3n32Pe8cryvphLjsBFZDZwM1ADSAFGFD1vCxjgMPC0MeaSxwi8cQQOcCAukWbNmhH6p7YkbVunZ8+VKqbdsYlc0+Jq6jRpSfz2jV77/06JR+DGmIeNMbWNMf7GmHrGmKnGmMeMMa2MMa2NMfcUp7y9WZOG9Xjgb8+R/NsG3po823YcpdxG378Px5GbzdRJH3pteV9MsY6BO4u3jsABTp/JIaRBUxAhNW4/5YMDbUdSyqXNX76B+7t3pf2fHyJy0Szbcaxy6jFwdfnKBwfxrzfe4kxqAv1f+LftOEq5NIfDwcDBQ/ANKs/cSd6zSPHl0gIvQy8//Sh1r+nA/M8nsvNQgu04SrmsV8ZPI3XvZv76zAuE1Qu1Hcdl6SGUMrZ8w2a6dWnPNbf25rflepm9Uuc7fjKLeo2aEhBUjuOH9xAQ4G87knV6CMVF3N7perrc8yg7Vixg1uI1tuMo5XL6Dn2d3PSjvDv2fS3vS9ARuAVJKamEXdWESnUak7InGj9f/XdUKYANv+2ny/Wtady2I/s3rbAdx2XoCNyF1A0N4W//eJkTB2J4/t3JtuMo5RKMMfx14D/AUcCsKR/ZjuMWtMAtmThiOJXrNuazd0dy5PhJ23GUsu7juUs4uOEn7njoSdq3aWE7jlvQArfE39+fiRMmcDYjhUeee812HKWsys7N45UXhxFQsRozP3rbdhy3oQVuUb8+d9G8w22smfc5Kzbvth1HKWsGjZxIZsJuhr06kmpVKtuO4zb0JKZl23bu5do2rah//W3EblyMj49eLqy8y96EY1zTsgVVa4Ry9MBv+PjouPJ8ehLTRbVp2Yxejz1F/KYlvPHFQttxlCpTxhgefOaf5GemMWXSR1rel0lH4C7g1KlThDZojFSozqGdW6hV2btX2lbeY9rPm3jy7i7ccGtPIpZ+bzuOy9IRuAurVKkSb4x+kzNJe+n78ljbcZQqEyfP5DH8heH4+voyZ8oHtuO4JS1wF/H83wfQoHlrVs2cyOItsbbjKFXqnhn7Fek71zLw2ecJa9jAdhy3pAXuInx8fPhy8scUZJ3gqWGvkX0233YkpUrNpkOpfPPRaCqH1Oa9N3QabUlpgbuQm7p24Y57HyBp7TxGfKWXESvPlFfg4PG
X3yHvWCwfThhHcHCw7UhuSwvcxXz+0Xj8/Xz55L1R7DyiV2gqzzNh8Vb2/DiFVuEd6Pvwg7bjuDUtcBdTt25dXnr5n2Tv3cBTb8+gwFF2s4SUKm1xaacZM/oNHDmZfDllki6TdoW0wF3QKy+9QM3a9YiaM57paw/YjqOUUxhjGPzJj2REL6LvXx+nbdu2tiO5PS1wFxQcHMxHE8eRl3qY1979gMT0bNuRlLpi321JYvn09wgOLsf7775lO45H0AJ3UX369KFD566krvySoTPW4dBDKcqNJZ88w/BxX5BzaDNvjBpBzZo1bUfyCFrgLkpEmPTRBzhys1g682NmRcbZjqRUiRhjGD53C8m/TOaqxk159tlnbUfyGFrgLqxt27b8bcAAsmJ+YsSMX4hP00Mpyv3MjUrg57nTOZuWyIcfTCAgIMB2JI+hBe7iRo8eTcUKFUhdPoVh87bqoRTlVhLTsxn5zUayNs6hR4876Nmzp+1IHkUL3MWFhIQwcuQIsg5uZvWvS5m+4bDtSEoVi8NhePHb7aSsnIHJy2H8+HG2I3kcLXA38Mwzz9CsWTNy107jncW/cSg1y3YkpS5pVmQcqzZEcTJmCYMHD+bqq6+2HcnjaIG7gYCAAMaPH8+plAQyN//I8Hnb9AIf5dLi07J586fdFGz4gurVq/P666/bjuSRLlngIjJNRI6JyI5ztlUTkWUisr/oY9XSjanuvPNOevbsyckNc4jaHcvnaw/ZjqTUBTkchuHfbiN773qO7Yth9OjRVK2qFVEaijMCnw7ccd62l4FfjTFNgV+LnqtSNm7cOM7mnKHCb9/y/rJ97D2aaTuSUv9j2vpYIvYlk7N+Bq1bt2bAgAG2I3msSxa4MWYNcOK8zb2AGUWfzwDudW4sdSHNmjXj2WefZe+aH/A7cZghs2PIySuwHUup/9iRdJJ3l+wl5PAyUpMTmThxIr6+vrZjeaySHgMPNcYkF31+FAj9oxeKyFMiEi0i0ampqSV8O/W7119/nerVq+Mf9SV7jp7izZ90NXvlGk7n5jNkdgzl8zPYs+RL7r//fm6++WbbsTzaFZ/ENIWLav7hGTVjzGRjTLgxJjwkJORK387rValShTfffJPt0RF0lH18FRHHLzuP2o6lFKMW7SQ27TQ198ynoKCA9957z3Ykj1fSAk8RkdoARR+POS+SupQnn3ySNm3asHH2RFqEBPLSd9tJPnnGdizlxRZtO8I30Yn0rHmKpT98y/Dhw2nUqJHtWB6vpAX+A9Cv6PN+wELnxFHF4evry8SJE4mPj6dZ6irO5jsYOmerTi1UViScyOaV+b/Rtl4lIr8eR506dXj5ZZ3XUBaKM41wNrARaCYiiSLyJPA20E1E9gO3Fz1XZeimm26iT58+fPbBOIbcUI3I2BN8slLvHa7KVl6BgyFzYgC4UXazOTqat99+mwoVKlhO5h2k8BB22QgPDzfR0dFl9n6e7vDhwzRv3pz777+fGncP58ftyXzzdAeub1jNdjTlJcb+spePVh7g7bub8I8+N9OwYUM2bNiAj49eI+hMIrLZGBN+/nb9r+zGwsLCeOGFF4iMjOTFWxtQp0oQQ2Zv5eSZPNvRlBfYeDCNj1cd4IHr6xF69ghZWVlMnDhRy7sM6QjczZ05cwYfHx8CAwOJiU/ngU83ckvzmnzW93p8fHS9QVU6jp3K4a4P11E+0I8fn+1C+UA/MjIyqFKliu1oHklH4B4qODiYwMBAAK5tUJVXel7Nsl0pTFp90HIy5anyChw88/UWMnPymdT3OsoH+gFoeVugBe5hHu8cRq+2dRi7dC9r9umFU8r53vxpN1GH03n7/lY0r1XJdhyvpgXuYUSEt+5rRbPQigyZE0PCCV3FRznP9zFJTN9wmCc6N6JX27q243g9LXAPVC7Aj0/7Xk+BwzBw5ma9X4pyil1HTvHy/O20b1SNf/ZsbjuOQgvcY4XVKM+EB9uy88gpXvt+B2V5slp5npPZeQycuZn
Kwf589Mi1+PtqdbgC/VPwYLddHcqQ25ry7eZEZkXG246j3JTDYRg6N4bkk2f45NHrqFkxyHYkVUQL3MMNva0pNzcLYdSinWyJT7cdR7mhD1bsZ+XeVP51Vwu9SMzFaIF7OB8fYeKD11K7cjBPf7WZxHQ9qamK76ftyUxYvp/7rqvLYx0a2o6jzqMF7gUql/NnWv9wcvIKeHJ6NKdy9EpNdWlb4tP5xzdbCW9YlTG9WyGiF4a5Gi1wL9GkZkU+7Xs9B1OzeGbWFvIKHLYjKReWcCKbv82IpnblID577HqC/HVVHVekBe4l0tPT2bliPqPvbcna/ccZ8cNOnZmiLujkmTz6f7GJfIfhgz7NmfvlVP274qK0wL3EZ599xtNPP83KqWN4qmtDvo6M5/O1sbZjKRdzNt/BoJmbiT+RzRu316Ff7zsZMmQI27dvtx1NXYCf7QCqbLz44otkZmYyZswYuh+Oo1ufVxjz827qVyvHHdfUsh1PuQBjDK99/xsbDqbxbBt//v7gHZw6dYqffvqJNm3a2I6nLkBH4F7Cx8eHN998kylTpvDrr8uJ/PBZmlU4y9C5MWxLyLAdT7mASasP8k10It0rJjPyqfsREdatW0ePHj1sR1N/QAvcywwYMIDFixdz+HAs2z4eTNCpBJ6cEcWh1Czb0ZRF321O5N0le7nq+EamjRhIkyZNiIiIoHXr1rajqYvQAvdC3bt3Z/369fj5+nBg6jAy9m6i7+eRJGXowsjeaMmOZIbPi6H81jmsnPomPXr0YM2aNdStqzercnVa4F6qVatWREZG8qc/NeXw7BHErV9I388jSc3MtR1NlaE1+1J55qtI8paOZ9cvMxk4cCALFy6kYsWKtqOpYtCTmF6sTp06rFmzhgcffJDFP31I7olk+vrA3IGdqFIuwHY8VcqiDp/gic9+5cSCN8mI3cF7773HsGHD9IIdN6IjcC9XoUIFFi5cyKBBg0jb+C0bpvyLxz5bR1Zuvu1oqhTtSDrJo2O/58iXw8lJPsC8efMYPny4lreb0QJX+Pn58fHHHzN27Fiy9qxj2fvP8NePl+l9xD3UgWOZ3DdiGrHThhJsclmxYgV9+vSxHUuVgBa4AgpX8hk2bBjz5s3DHI9l0RtP8ti4BZzN10vuPUl8WjZ3DnmHA9Nfon6dWmyKjKBjx462Y6kS0gJX/6VPnz6sXrWKYDnL/JFP0OtfU3Qk7iH2p2TS+ZFnOTD3Ta4Lb0dU5EYaN25sO5a6Alrg6n906NCBmKhIQmuGsOS9wdwyaLQeE3dz2+LS6PjnB0lcOo27ej/A+tUrqF69uu1Y6gppgasLaty4MTtiomjR+noipo6k/QPPkHH6rO1YqgTW7oyjy23dSdv8M4Off5EfvptLYGCg7VjKCbTA1R+qVq0aWzau5pY/92b3osm0uv0+jqbrFZvuZOG67XS79WZOx27jvQ8+4cP339GZJh7kigpcRA6LyG8islVEop0VSrmOwMBAfl30HY8Nep7EiJ9o2eEW9iek2I6limHKgl+5/85byD+VytzvFjL82UG2Iyknc8YI/BZjTFtjTLgTfpZyQSLCl5+8z6tvT+TEgRjatu/Ihm17bMdSFzHqk5k8/eBd+Pv5sXLVah6498+2I6lSoIdQVLGNfmkIk776lpz0FG7q2pkvF620HUmdxxjDQ/8YxcjB/ahcqyEx0ZvoesP1tmOpUnKlBW6ApSKyWUSeckYg5doGPnIvi5etxNfXj/739+SVCdNtR1JFcs7m075Xf+ZOGElY207s3xZJ88a6ELEnu9IC72KMuQ64E3hGRG48/wUi8pSIRItIdGpq6hW+nXIFPbq2Z1t0FFVqN+St55/kvsEjcDh0yS2bjhw/SdMO3Yhe9CWd73mUfZErqFG1su1YqpRdUYEbY5KKPh4DFgDtL/CaycaYcGNMeEhIyJW8nXIhzRo34MC2TTS+rgsLPv431931GNm5utq9DZt2xXL19Z1IjFnN48+PYO3
3X+Hv7287lioDJS5wESkvIhV//xzoDuxwVjDl+qpVqcSeiF+5ufdjbPt5Fk1u6MbhlBO2Y3mV2csi6NqlM5lHDvDOpC+Y9v5InSboRa5kBB4KrBORbcAm4CdjzBLnxFLuws/PjxXfzeDpl0aRvG0N14R3ZlHEbtuxPF6BwzB0wtf0vacbJu8MC35ayosD+9mOpcpYiQvcGHPIGNOm6NHSGPOmM4Mp9yEifPr263w4dSZnUmK5745beOnzn8kv0BthlYZjmTnc+PQbTBzej0pVaxC9KYJe3W+yHUtZoNMIldMMfuIRVq5cQQB5jB38AN1e+JTkk7pMmzOt3ZfKtb0HseHzEVzdJpwDv0XT+upmtmMpS7TAlVPd2LkTO2KiqV27NqsmPkeHJ0ayYo9euXml8gscvLt4Bz37PMKRX6dzz/0PErNhld6QystpgSuna9SoEb9t2cQNHTsSP/9d7vvbMEYs3MFpvaNhicQeP02fD5Yz4u+PkfXbcv756mt8P2+23pBK6ZqYqnRUrVqVNSuW8/gTT/L1rJm8n5HC0t+GM+aBa7mlWU3b8dxCXoGDyWsOMXb+OpLnjiQ/PYkvvviC/v37246mXIQWuCo1AQEBzPzqS5o2acyoUaPYm5NGv+Mvce8Nf+L1u1tQo4KOIP9ITHw6/5z/G9u2xnDq+zcIJJ+flizhtttusx1NuRAtcFWqRISRI0fSqFEjBgwYQMD3r/FD/qus2Z/Kqz2vps/19XTe8jmycvMZ+8teZmw8TEBiDOnfjSE0JITFixfTsmVL2/GUi9Fj4KpM9OvXj19++YWcjFROf/MS1bMTeOHb7Tw8JYLtiRm241lX4DAsiEmk+7jVzNh4mOZp6znw9QhatWxJZGSklre6IC1wVWZuvfVWNmzYQPlywUR8OIT7qiezLyWLez5az99nbebAMe9bLMIYw/JdKfScuJZ/zN1GlWBfOqX+xJIpb3HXXXexatUqatWqZTumclFa4KpMtWjRgoiICFq0aMGEl57ioXI7eO62pqzem0r38at56dvtHMnwjrnjm2JP0OfTjQz4MpqzBQ7G9m5O4OoP+HrqJIYMGcL8+fMpX7687ZjKhekxcFXmatWqxapVq3j00Ud5afjzPPdcHCtHjeHTNYeZGRHHgq1JPNahIU90aUTdKsG24zqVMYbNcel8vPIAK/emElopkDG9W3FjfX/u730vUVFRTJgwgeeee852VOUGxJiyuw1oeHi4iY7WlddUoYKCAoYNG8bEiRO59957mTVrFidyYcLy/czfkghAtxah9OsURserqrv1yc6cvAJ+2HqEGRsPs/PIKSoH+/P3mxvTr1MYsQf20bNnT1JSUpg9eza9evWyHVe5GBHZfKFVz7TAlXUffPABQ4cOJTw8nEWLFhEaGkpiejYzI+KZExVPRnYefwqtwF87hnHfdXUpF+A+vzj+vh9zo+JJz86jWWhF/tqpIb2vLdyPVatW0bt3bwICAvjxxx9p166d7cjKBWmBK5e2cOFCHn74Ybp06cLSpUv/sz0nr4Afth1hxobCkWvFID+6t6jFHdfUomvTGgT5+1pMfWHHMnNYtiuFX3amsG5/4SIm3VvUol+nMDpcVe0/v0kcP36csLAwGjRowOLFiwkLC7OYWrkyLXDl8qKioqhatSpNmjT5n68ZY9gSn86syHiW70rhVE4+5QJ8ublZCD1a1uKW5jWpFGRvEYP4tGx+2XmUX3YeZXN8OsZAWPVy/Ll1bR65oeEfHstfsmQJN9xwA1WrVi3jxMqdaIErj5FX4CDiUBpLdhxl6a4UUjNz8fcV2tSrQpv6hY9r61ehXtXgUjlunlfgYO/RTLYmZLAtIYOYhIz/TIFsUbsSPVrWosc1oTQLrejWx+2V69ACVx7J4TDEJKSzdGcK0XHp7Eg6SW5+4X3Iq5UPoE29yjSrVYlalQKpVTmI0EpB1K4cTI0KAfj5/vEs2py8Ao6ezOHoqRxSTuVw9GQORzLOsOP
IqQu+R6fGNejRshYNqpcrk/1W3uWPCtx9zgYpVWTo0KFs3br1gl+rZSD7bD5Zufmk5ubzfW4+OXkOzh+oCODr68OFxsfGGPIvsEizr49QLsCPCoF+VAz0o0KQH4F+PsQCscCs817ftm1bJkyYcPk7qFQxaYErjyIC5QP9KB/oR+g52/MKHJzNd3C26GNegYO8ggv/9ikC/r4+BPj6EOBX9PD1wddHD4co16IFrtyOjmqVKqSX0iullJvSAldKKTelBa6UUm5KC1wppdyUFrhSSrkpLXCllHJTWuBKKeWmtMCVUspNlem9UEQkFYgr4bfXAI47MY5Nui+ux1P2A3RfXNWV7EtDY0zI+RvLtMCvhIhEX+hmLu5I98X1eMp+gO6LqyqNfdFDKEop5aa0wJVSyk25U4FPth3AiXRfXI+n7Afovrgqp++L2xwDV0op9d/caQSulFLqHG5V4CLyhohsF5GtIrJUROrYzlRSIvKeiOwp2p8FIlLFdqaSEJEHRGSniDhExC1nC4jIHSKyV0QOiMjLtvOUlIhME5FjIrLDdpYrISL1RWSliOwq+rv1nO1MJSUiQSKySUS2Fe3LKKf+fHc6hCIilYwxp4o+HwK0MMYMtByrRESkO7DCGJMvIu8AGGNeshzrsonI1YAD+AwYboxxq0VPRcQX2Ad0AxKBKOBhY8wuq8FKQERuBLKAL40x19jOU1IiUhuobYzZIiIVgc3AvW76ZyJAeWNMloj4A+uA54wxEc74+W41Av+9vIuUB9znX5/zGGOWGmPyi55GAPVs5ikpY8xuY8xe2zmuQHvggDHmkDHmLDAH6GU5U4kYY9YAJ2znuFLGmGRjzJaizzOB3UBdu6lKxhTKKnrqX/RwWm+5VYEDiMibIpIAPAq8bjuPkzwB/Gw7hJeqCySc8zwRNy0LTyQiYcC1QKTlKCUmIr4ishU4BiwzxjhtX1yuwEVkuYjsuMCjF4Ax5lVjTH0KFwEfbDftxV1qX4pe8yqQz/8uau4yirMfSjmbiFQAvgOGnvfbt1sxxhQYY9pS+Ft2exFx2uEtl1vU2BhzezFfOgtYDIwoxThX5FL7IiL9gbuA24wLn4y4jD8Td5QE1D/neb2ibcqiouPF3wGzjDHzbedxBmNMhoisBO4AnHKi2eVG4BcjIk3PedoL2GMry5USkTuAF4F7jDHZtvN4sSigqYg0EpEA4CHgB8uZvFrRib+pwG5jzDjbea6EiIT8PsNMRIIpPFnutN5yt1ko3wHNKJz1EAcMNMa45WhJRA4AgUBa0aYId5xRIyK9gQ+BECAD2GqM6WE11GUSkZ7ABMAXmGaMedNuopIRkdnAzRTe9S4FGGGMmWo1VAmISBdgLfAbhf+vA7xijFlsL1XJiEhrYAaFf7d8gG+MMf922s93pwJXSin1/9zqEIpSSqn/pwWulFJuSgtcKaXclBa4Ukq5KS1wpZRyU1rgSinlprTAlVLKTWmBK6WUm/o/wlv7MX2/AH0AAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "x = np.linspace(-3, 3)\n", "plt.plot(x, f(x))\n", "for xi in np.linspace(-2, 2, 5):\n", " dx = 0.4\n", " mi = fprime(xi)\n", " yi = f(xi)\n", " print(f\"x={xi:4} y={yi:4} m={mi:5}\")\n", " plt.plot([xi-dx, xi, xi+dx], [yi - mi*dx, yi, yi+mi*dx], \"-k\")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray(12., dtype=float64)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fprime(2.0)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 5., 14.])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x = np.array((1.0, 2.0))\n", "f(x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "Gradient only defined for scalar-output functions. 
Output had shape: (2,).", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfprime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 412\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 413\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, 
**kwargs)\u001b[0m\n\u001b[1;32m 472\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 474\u001b[0;31m \u001b[0m_check_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 475\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdtypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 476\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_output_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_scalar\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 493\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mShapedArray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 494\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m!=\u001b[0m 
\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 495\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"had shape: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 496\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"had abstract value {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: Gradient only defined for scalar-output functions. Output had shape: (2,)." ] } ], "source": [ "fprime(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Ok, we cannot compute a gradient if the function is not $\\mathcal{R}^n \\to \\mathcal{R}$. We need to compute the Jacobian." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "fjac = jax.jacfwd(f)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 6., 0.],\n", " [ 0., 12.]], dtype=float64)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fjac(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is also `jacrev`." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[ 6., 0.],\n", " [ 0., 12.]], dtype=float64)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.jacrev(f)(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Both always give the same result (within floating-point precision), but their cost differs. For $\\mathcal{R}^n \\to \\mathcal{R}^m$, use `jax.jacfwd` (forward mode) if $m \\ge n$ (at least as many outputs as inputs) and `jax.jacrev` (reverse mode) if $m < n$ (more inputs than outputs)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def g(x):\n", " A = ((1., 0.1), (0.1, 2.))\n", " return np.dot(x, np.dot(A, x))\n", "\n", "gprime = jax.grad(g)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "9.399999999999999" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "g(x)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "Exception", "evalue": "The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Tracedwith\n with primal = Traced\n tangent = Traced.\n\nThis error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. 
If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mException\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgprime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 412\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 413\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_input_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 471\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 472\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 1553\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1555\u001b[0;31m \u001b[0mout_primal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_vjp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1556\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_primal_pval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_known\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mout_primal_pval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr\u001b[0;34m(fun, pvals, instantiate, stage_out, bottom, trace_type)\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_master\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbottom\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbottom\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mtrace_to_subjaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 430\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mgen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 151\u001b[0m \u001b[0;32mdel\u001b[0m 
\u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mA\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mgprime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36mdot\u001b[0;34m(*args, **kwargs)\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/core.py\u001b[0m in \u001b[0;36m__array__\u001b[0;34m(self, *args, **kw)\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[0;34m\"JAX Tracer instance; in that case, you 
can instead write \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 449\u001b[0m \"`jax.device_put(x)[idx]`.\")\n\u001b[0;32m--> 450\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 451\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTrace\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mException\u001b[0m: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Tracedwith\n with primal = Traced\n tangent = Traced.\n\nThis error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`." ] } ], "source": [ "gprime(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX cannot trace plain `numpy` functions: they try to convert JAX's tracer objects into ordinary arrays, which fails. You need to replace the `numpy` calls that act on the differentiated variables with their `jax.numpy` equivalents."
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "dot requires ndarray or scalar arguments, got at position 0.", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mgprime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mgprime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 412\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 413\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 469\u001b[0m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_input_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 470\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 471\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 472\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 473\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 1553\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1555\u001b[0;31m \u001b[0mout_primal\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_vjp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mad\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals_flat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1556\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1557\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mvjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 109\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 110\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 111\u001b[0m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maux\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m 
\u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 96\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 97\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_primal_pval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_known\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mout_primal_pval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr\u001b[0;34m(fun, pvals, instantiate, stage_out, bottom, trace_type)\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mnew_master\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbottom\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbottom\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m 
\u001b[0mmaster\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 430\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mmaster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 148\u001b[0m \u001b[0mgen\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 151\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 152\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mg\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mA\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0.1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mA\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mgprime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/numpy/lax_numpy.py\u001b[0m in \u001b[0;36mdot\u001b[0;34m(a, b, 
precision)\u001b[0m\n\u001b[1;32m 2757\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_wraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlax_description\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_PRECISION_DOC\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2758\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprecision\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# pylint: disable=missing-docstring\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2759\u001b[0;31m \u001b[0m_check_arraylike\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"dot\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2760\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_promote_dtypes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2761\u001b[0m \u001b[0ma_ndim\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mb_ndim\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mndim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mndim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_check_arraylike\u001b[0;34m(fun_name, *args)\u001b[0m\n\u001b[1;32m 273\u001b[0m if not _arraylike(arg))\n\u001b[1;32m 
274\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"{} requires ndarray or scalar arguments, got {} at position {}.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 275\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 276\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 277\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: dot requires ndarray or scalar arguments, got at position 0." ] } ], "source": [ "def g(x):\n", " A = ((1., 0.1), (0.1, 2.))\n", " return jnp.dot(x, jnp.dot(A, x))\n", "\n", "gprime = jax.grad(g)\n", "gprime(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "... and it does not do implicit conversions like `numpy` for performance reasons." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([2.4, 8.2], dtype=float64)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def g(x):\n", " A = np.array(((1., 0.1), (0.1, 2.)))\n", " return jnp.dot(x, jnp.dot(A, x))\n", "\n", "gprime = jax.grad(g)\n", "gprime(x)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[2. , 0.2],\n", " [0.2, 4. ]], dtype=float64)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ghessian = jax.hessian(g)\n", "ghessian(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rule: Use `np` for constant arguments. 
Use `jnp` for variable arguments. That is the official recommendation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What happens if a function has several parameters?" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray(18., dtype=float64)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def f(x, a):\n", " return 3 * x ** 2 + a\n", "\n", "fprime = jax.grad(f)\n", "fprime(3.0, 1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX computes the derivative with respect to the first argument by default. Use the `argnums` argument of `jax.grad` to select other arguments, e.g. `jax.grad(f, argnums=1)` differentiates with respect to `a`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Second steps" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from numba_stats import norm_cdf\n", "\n", "norm_pdf = jax.grad(norm_cdf)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.15865525393145707" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "norm_cdf(0, 1, 1)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "ename": "TypeError", "evalue": "grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int64.
", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnorm_pdf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 411\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mwraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdocstr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdocstr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0margnums\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 412\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 413\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 414\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m 
in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 467\u001b[0m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrap_init\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 468\u001b[0m \u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdyn_args\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0margnums_partial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margnums\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 469\u001b[0;31m \u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_input_dtype_grad\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mholomorphic\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 470\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 471\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree)\u001b[0m\n\u001b[1;32m 162\u001b[0m \"\"\"\n\u001b[1;32m 163\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mpytree\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtree_multimap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mrest\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.8/site-packages/jax/api.py\u001b[0m in \u001b[0;36m_check_input_dtype_revderiv\u001b[0;34m(name, holomorphic, x)\u001b[0m\n\u001b[1;32m 510\u001b[0m \u001b[0;34m\"is a sub-dtype of np.floating or np.complexfloating), \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 511\u001b[0m f\"but got {aval.dtype.name}. 
")\n\u001b[0;32m--> 512\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 513\u001b[0m \u001b[0m_check_input_dtype_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpartial\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_check_input_dtype_revderiv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"grad\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 514\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mTypeError\u001b[0m: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.floating or np.complexfloating), but got int64. " ] } ], "source": [ "norm_pdf(0, 1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "JAX cannot differentiate code that is not expressible in JAX primitives. You have to express your code using `jax.numpy` and `jax.scipy` (see next example).\n", "\n", "It is possible to teach JAX about new primitives. This is well explained in the JAX docs with a nice tutorial, but that is too involved to cover here.\n", "\n", "There is a comparably simple way to help JAX with derivatives in cases where a human knows a faster or more accurate way to calculate them (custom derivative rules via `jax.custom_jvp` and `jax.custom_vjp`). I cannot present this here, but it is fairly easy to set up. Please see the excellent JAX docs/tutorials on this topic." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Practical example A: Fit of a Gaussian model to a histogram\n", "\n", "We fit a Gaussian to a histogram using a maximum-likelihood approach based on Poisson statistics. This example is used to investigate how automatic differentiation can accelerate a typical fit in a counting experiment.\n", "\n", "To compare fits with and without passing an analytic gradient fairly, we use `Minuit.strategy = 0`, which prevents Minuit from automatically computing the Hesse matrix after the fit."
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.436843Z", "start_time": "2020-02-21T10:26:37.432080Z" } }, "outputs": [], "source": [ "from jax.scipy.special import erf # replacement for scipy.special.erf\n", "from iminuit import Minuit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We generate some toy data and write the negative log-likelihood (nll) for a fit to binned data, assuming Poisson-distributed counts.\n", "\n", "**Note:** We write all statistical functions in pure Python code, to demonstrate Jax's ability to automatically differentiate and JIT compile this code. In practice, one should import JIT-able statistical distributions from jax.scipy.stats. The library versions can be expected to have fewer bugs and to be faster and more accurate than hand-written code." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.594856Z", "start_time": "2020-02-21T10:26:37.585943Z" } }, "outputs": [], "source": [ "# generate some toy data\n", "rng = np.random.default_rng(seed=1)\n", "n, xe = np.histogram(rng.normal(size=10000), bins=1000)\n", "\n", "\n", "def cdf(x, mu, sigma):\n", " # cdf of a normal distribution, needed to compute the expected counts per bin\n", " # better alternative for real code: from jax.scipy.stats.norm import cdf\n", " z = (x - mu) / sigma\n", " return 0.5 * (1 + erf(z / np.sqrt(2)))\n", "\n", "\n", "def nll(par): # negative log-likelihood with constants stripped\n", " amp = par[0]\n", " mu, sigma = par[1:]\n", " p = cdf(xe, mu, sigma)\n", " mu = amp * jnp.diff(p)\n", " result = jnp.sum(mu - n + n * jnp.log(n / (mu + 1e-100) + 1e-100))\n", " return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's check results from all combinations of using JIT and gradient and then compare the execution times.\n", "\n", "| | | |\n", "|:----:|:---:|:---:|\n", "| |~JIT~| JIT |\n", "|~grad~| m1 
| m3 |\n", "| grad | m2 | m4 |" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:37.890967Z", "start_time": "2020-02-21T10:26:37.886224Z" } }, "outputs": [], "source": [ "start_values = (1.5 * np.sum(n), 1.0, 2.0)\n", "limits = ((0, None), (None, None), (0, None))\n", "\n", "\n", "def make_and_run_minuit(fcn, grad=None):\n", " m = Minuit(fcn, start_values, grad=grad, name=(\"amp\", \"mu\", \"sigma\"))\n", " m.errordef = Minuit.LIKELIHOOD\n", " m.limits = limits\n", " m.strategy = 0 # do not explicitly compute hessian after minimisation\n", " m.migrad()\n", " return m" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:38.532308Z", "start_time": "2020-02-21T10:26:38.368563Z" }, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 496.2 Nfcn = 66
EDM = 1.84e-08 (Goal: 0.0001)
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 496.2 β”‚ Nfcn = 66 β”‚\n", "β”‚ EDM = 1.84e-08 (Goal: 0.0001) β”‚ β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚APPROXIMATEβ”‚ Pos. def. 
β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m1 = make_and_run_minuit(nll)\n", "m1.fmin" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:39.371830Z", "start_time": "2020-02-21T10:26:38.797460Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 496.2 Nfcn = 26
EDM = 1.84e-08 (Goal: 0.0001) Ngrad = 6
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 496.2 β”‚ Nfcn = 26 β”‚\n", "β”‚ EDM = 1.84e-08 (Goal: 0.0001) β”‚ Ngrad = 6 β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚APPROXIMATEβ”‚ Pos. def. 
β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m2 = make_and_run_minuit(nll, grad=jax.grad(nll))\n", "m2.fmin" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:39.510553Z", "start_time": "2020-02-21T10:26:39.373728Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 496.2 Nfcn = 26
EDM = 1.88e-08 (Goal: 0.0001) Ngrad = 6
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 496.2 β”‚ Nfcn = 26 β”‚\n", "β”‚ EDM = 1.88e-08 (Goal: 0.0001) β”‚ Ngrad = 6 β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚APPROXIMATEβ”‚ Pos. def. 
β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m3 = make_and_run_minuit(jax.jit(nll))\n", "m3.fmin" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:40.573574Z", "start_time": "2020-02-21T10:26:40.229476Z" }, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 496.2 Nfcn = 26
EDM = 1.88e-08 (Goal: 0.0001) Ngrad = 6
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 496.2 β”‚ Nfcn = 26 β”‚\n", "β”‚ EDM = 1.88e-08 (Goal: 0.0001) β”‚ Ngrad = 6 β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚APPROXIMATEβ”‚ Pos. def. β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m4 = make_and_run_minuit(jax.jit(nll), grad=jax.jit(jax.grad(nll)))\n", "m4.fmin" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 496.2 Nfcn = 82
EDM = 5.31e-05 (Goal: 0.0001)
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok APPROXIMATE Pos. def. Not forced
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 496.2 β”‚ Nfcn = 82 β”‚\n", "β”‚ EDM = 5.31e-05 (Goal: 0.0001) β”‚ β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚APPROXIMATEβ”‚ Pos. def. 
β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from numba_stats import norm_cdf # numba jit-able version of norm.cdf\n", "\n", "@nb.njit\n", "def nb_nll(par): # negative log-likelihood with constants stripped\n", " amp = par[0]\n", " mu, sigma = par[1:]\n", " p = norm_cdf(xe, mu, sigma)\n", " mu = amp * np.diff(p)\n", " result = np.sum(mu - n + n * np.log(n / (mu + 1e-323) + 1e-323))\n", " return result\n", "\n", "m5 = make_and_run_minuit(nb_nll)\n", "m5.fmin" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:45.031931Z", "start_time": "2020-02-21T10:26:40.674388Z" } }, "outputs": [], "source": [ "from timeit import timeit\n", "\n", "times = {\n", " \"no JIT, no grad\": \"m1\",\n", " \"no JIT, grad\": \"m2\",\n", " \"jax JIT, no grad\": \"m3\",\n", " \"jax JIT, grad\": \"m4\",\n", " \"numba JIT, no grad\": \"m5\",\n", "}\n", "for k, v in times.items():\n", " t = timeit(\n", " f\"{v}.values = start_values; {v}.migrad()\",\n", " f\"from __main__ import {v}, start_values\",\n", " number=1,\n", " )\n", " times[k] = t" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:26:45.142272Z", "start_time": "2020-02-21T10:26:45.033451Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from matplotlib import pyplot as plt\n", "\n", "x = np.fromiter(times.values(), dtype=float)\n", "xmin = np.min(x)\n", "\n", "y = -np.arange(len(times))\n", "plt.barh(y, x)\n", "for yi, k, v in zip(y, times, x):\n", " plt.text(v, yi, f\"{v/xmin:.1f}x\")\n", "plt.yticks(y, times.keys())\n", "for loc in (\"top\", \"right\"):\n", " plt.gca().spines[loc].set_visible(False)\n", "plt.xlabel(\"execution time / s\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusions:\n", "\n", "1. As expected, the best results are obtained by JIT compiling both the function and the gradient.\n", "\n", "2. Without a gradient, JIT compiling the cost function with Jax gives only a minor performance improvement. Numba does much better here, because it directly generates optimized machine code, using the powerful optimizer from LLVM.\n", "\n", "3. JIT compiling the gradient is very important. Using a gradient that is not JIT-compiled gives only a minor performance improvement in this example. This may change if the model has a large number of parameters (>> 10).\n", "\n", "The gain from using a gradient is larger for functions with hundreds of parameters, as is common in machine learning. Human-made models often have fewer than 10 parameters, and then the gain is not so dramatic." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Computing covariance matrices with JAX\n", "\n", "Automatic differentiation gives us another way to compute uncertainties of fitted parameters. By default, MINUIT computes the uncertainties with the HESSE algorithm, which approximates the matrix of second derivatives using finite differences and then inverts it.\n", "\n", "Let's compare the output of HESSE with the exact (within floating-point precision) computation using automatic differentiation."
] }, { "cell_type": "code", "execution_count": 32, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:27:38.715871Z", "start_time": "2020-02-21T10:27:37.907690Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "sigma[amp] : HESSE = 100.0, JAX = 100.0\n", "sigma[mu] : HESSE = 0.0100, JAX = 0.0100\n", "sigma[sigma]: HESSE = 0.0071, JAX = 0.0071\n" ] } ], "source": [ "m4.hesse()\n", "cov_hesse = m4.covariance\n", "\n", "\n", "def jax_covariance(par):\n", " return jnp.linalg.inv(jax.hessian(nll)(par))\n", "\n", "\n", "par = np.array(m4.values)\n", "cov_jax = jax_covariance(par)\n", "\n", "print(\n", " f\"sigma[amp] : HESSE = {cov_hesse[0, 0] ** 0.5:6.1f}, JAX = {cov_jax[0, 0] ** 0.5:6.1f}\"\n", ")\n", "print(\n", " f\"sigma[mu] : HESSE = {cov_hesse[1, 1] ** 0.5:6.4f}, JAX = {cov_jax[1, 1] ** 0.5:6.4f}\"\n", ")\n", "print(\n", " f\"sigma[sigma]: HESSE = {cov_hesse[2, 2] ** 0.5:6.4f}, JAX = {cov_jax[2, 2] ** 0.5:6.4f}\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Success, HESSE and JAX give the same answer within the relevant precision.\n", "\n", "**Note:** If you compute the covariance matrix in this way from a least-squares cost function, you must multiply it by 2.\n", "\n", "Let us compare the performance of HESSE with Jax." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "90.3 ms Β± 3.37 ms per loop (mean Β± std. dev. of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3\n", "m = Minuit(nll, par)\n", "m.errordef = Minuit.LIKELIHOOD\n", "m.hesse()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "53.2 ms Β± 4.73 ms per loop (mean Β± std. dev. 
of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3\n", "jax_covariance(par)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The computation with Jax is faster, but not by much. It is also more accurate (although the added precision is not relevant here).\n", "\n", "Altogether, Minuit's HESSE algorithm still makes sense today. It has the advantage that it can process any function, while Jax cannot. For example, Jax cannot differentiate a function that calls into C/C++ or Cython code.\n", "\n", "Final note: If we JIT compile `jax_covariance`, it greatly outperforms Minuit's HESSE algorithm, but that only makes sense if you need to compute the Hessian at many different parameter values, so that the extra time spent on compilation is offset by the time saved over many invocations. This is not the case here: the Hessian is only needed at the best-fit point." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "253 Β΅s Β± 19.1 Β΅s per loop (mean Β± std. dev. of 3 runs, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 3 jit_jax_covariance = jax.jit(jax_covariance); jit_jax_covariance(par)\n", "jit_jax_covariance(par)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is much faster... but only because the compilation cost is excluded here." ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "565 ms Β± 0 ns per loop (mean Β± std. dev.
of 1 run, 1 loop each)\n" ] } ], "source": [ "%%timeit -n 1 -r 1\n", "# if we include the JIT compilation cost, the performance drops dramatically\n", "@jax.jit\n", "def jax_covariance(par):\n", " return jnp.linalg.inv(jax.hessian(nll)(par))\n", "\n", "\n", "jax_covariance(par)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With compilation cost included, it is much slower.\n", "\n", "Conclusion: Using the JIT compiler makes a lot of sense if the covariance matrix has to be computed repeatedly for the same cost function but different parameters, but this is not the case when we use it to compute parameter errors." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Practical example B: Fit data points with uncertainties in x and y\n", "\n", "Let's say we have some data points $(x_i \\pm \\sigma_{x,i}, y_i \\pm \\sigma_{y,i})$ and we have a model $y=f(x)$ that we want to adapt to this data. If $\\sigma_{x,i}$ was zero, we could use the usual least-squares method, minimizing the sum of squared residuals $r^2_i = (y_i - f(x_i))^2 / \\sigma^2_{y,i}$. Here, we don't know where to evaluate $f(x)$, since the exact $x$-location is only known up to $\\sigma_{x,i}$.\n", "\n", "We can approximately extend the standard least-squares method to handle this case. We use that the uncertainty along the $x$-axis can be converted into an additional uncertainty along the $y$-axis with error propagation,\n", "$$\n", "f(x_i \\pm \\sigma_{x,i}) \\simeq f(x_i) \\pm f'(x_i)\\,\\sigma_{x,i}.\n", "$$\n", "Using this, we obtain modified squared residuals\n", "$$\n", "r^2_i = \\frac{(y_i - f(x_i))^2}{\\sigma^2_{y,i} + (f'(x_i) \\,\\sigma_{x,i})^2}.\n", "$$\n", "\n", "We demonstrate this with a fit of a polynomial." 
] }, { "cell_type": "code", "execution_count": 37, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:43.510168Z", "start_time": "2020-02-21T10:25:43.371319Z" } }, "outputs": [], "source": [ "# polynomial model\n", "def f(x, par):\n", " return jnp.polyval(par, x)\n", "\n", "\n", "# true polynomial f(x) = x^2 + 2 x + 3\n", "par_true = (1, 2, 3)\n", "\n", "\n", "# grad computes derivative with respect to the first argument\n", "f_prime = jax.jit(jax.grad(f))\n", "\n", "\n", "# checking first derivative f'(x) = 2 x + 2\n", "assert f_prime(0.0, par_true) == 2\n", "assert f_prime(1.0, par_true) == 4\n", "assert f_prime(2.0, par_true) == 6\n", "# ok!\n", "\n", "# generate toy data\n", "n = 30\n", "data_x = np.linspace(-4, 7, n)\n", "data_y = f(data_x, par_true)\n", "\n", "rng = np.random.default_rng(seed=1)\n", "sigma_x = 0.5\n", "sigma_y = 5\n", "data_x += rng.normal(0, sigma_x, n)\n", "data_y += rng.normal(0, sigma_y, n)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:43.646212Z", "start_time": "2020-02-21T10:25:43.512384Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt=\"o\");" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.032210Z", "start_time": "2020-02-21T10:25:43.648365Z" } }, "outputs": [ { "data": { "text/plain": [ "DeviceArray(876.49545695, dtype=float64)" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# define the cost function\n", "@jax.jit\n", "def cost(par):\n", " result = 0.0\n", " for xi, yi in zip(data_x, data_y):\n", " y_var = sigma_y ** 2 + (f_prime(xi, par) * sigma_x) ** 2\n", " result += (yi - f(xi, par)) ** 2 / y_var\n", " return result\n", "\n", "cost.errordef = Minuit.LEAST_SQUARES\n", "\n", "# test the jit-ed function\n", "cost(np.zeros(3))" ] }, { "cell_type": "code", "execution_count": 40, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.059729Z", "start_time": "2020-02-21T10:25:44.034029Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
FCN = 23.14 Nfcn = 91
EDM = 3.12e-05 (Goal: 0.0002)
Valid Minimum Valid Parameters No Parameters at limit
Below EDM threshold (goal x 10) Below call limit
Covariance Hesse ok Accurate Pos. def. Not forced
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Name Value Hesse Error Minos Error- Minos Error+ Limit- Limit+ Fixed
0 x0 1.25 0.15
1 x1 1.5 0.5
2 x2 1.6 1.5
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
x0 x1 x2
x0 0.0223 -0.0388 (-0.530) -0.15 (-0.657)
x1 -0.0388 (-0.530) 0.24 0.172 (0.230)
x2 -0.15 (-0.657) 0.172 (0.230) 2.32
" ], "text/plain": [ "β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ FCN = 23.14 β”‚ Nfcn = 91 β”‚\n", "β”‚ EDM = 3.12e-05 (Goal: 0.0002) β”‚ β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Valid Minimum β”‚ Valid Parameters β”‚ No Parameters at limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Below EDM threshold (goal x 10) β”‚ Below call limit β”‚\n", "β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ Covariance β”‚ Hesse ok β”‚ Accurate β”‚ Pos. def. 
β”‚ Not forced β”‚\n", "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n", "β”Œβ”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ β”‚ Name β”‚ Value β”‚ Hesse Err β”‚ Minos Err- β”‚ Minos Err+ β”‚ Limit- β”‚ Limit+ β”‚ Fixed β”‚\n", "β”œβ”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ 0 β”‚ x0 β”‚ 1.25 β”‚ 0.15 β”‚ β”‚ β”‚ β”‚ β”‚ β”‚\n", "β”‚ 1 β”‚ x1 β”‚ 1.5 β”‚ 0.5 β”‚ β”‚ β”‚ β”‚ β”‚ β”‚\n", "β”‚ 2 β”‚ x2 β”‚ 1.6 β”‚ 1.5 β”‚ β”‚ β”‚ β”‚ β”‚ β”‚\n", "β””β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”˜\n", "β”Œβ”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”\n", "β”‚ β”‚ x0 x1 x2 β”‚\n", "β”œβ”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€\n", "β”‚ x0 β”‚ 0.0223 -0.0388 -0.15 β”‚\n", "β”‚ x1 β”‚ -0.0388 0.24 0.172 β”‚\n", "β”‚ x2 β”‚ -0.15 0.172 2.32 β”‚\n", "β””β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m = Minuit(cost, np.zeros(3))\n", "m.migrad()" ] }, { "cell_type": "code", "execution_count": 
41, "metadata": { "ExecuteTime": { "end_time": "2020-02-21T10:25:44.566228Z", "start_time": "2020-02-21T10:25:44.065443Z" } }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAEMCAYAAADd+e2FAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAsp0lEQVR4nO3deXhU5dnH8e+djUQW2RUCSFRUUMQg4oLiLmoRERewVnFp3W21FgXltVZrodVWrbhW61IRUMSIC+BeUdACAqLggmwmgiyyhAAJkOf940xiCDPJZDKTM8vvc11zkTnb3GdI7nnmWc05h4iIJJ40vwMQEZHIKIGLiCQoJXARkQSlBC4ikqCUwEVEEpQSuIhIglICFxFJUErgKc7MepvZTDP70MzGmVmm3zGJSHiUwOV74CTnXF9gGXC2v+GISLgy/A5A/OWcW1nlaRlQ7lcsIlI3KoELAGa2D3Aa8FqVbe3NrNC/qEIzs0Zm9pSZLTezYjObZ2ZnVNn/vJmtNLNNZvaNmf26lutdb2azzazUzJ4JcUwXM9tmZs8H2dfezArDiGtztcdOM3uoHm9F1RhamtkrZlYSeP1f1nBsZzN708zWm9kqMxtjZhlV9g8xs0WBa31nZsdFI0aJLiVwwcyaAf8BLnXOba+y60xgqj9R1SoDr/rneGBPYCTwopl1DuwfBXR2zjUDBgB/NrPDa7jeD8CfgX/XcMzDwKwQ+yreqxrjcs41qXgAewNbgZdqvNPwPYz3LWov4CLgUTM7OMSxjwCrgXbAYYF4rwUws1OBvwKXAU2BvsCSKMUoUaQEngLM7CszW2BmeweeH2JmG82sW6DUNR74k3Pu62qnngm8GTjnqkCJ7WEzW2tmPwT+0H3hnCtxzt3pnFvmnCt3zr0OLAUOD+z/0jlXWnF44LFfDdeb5JwrANYF229mQ4ANwLshLnEm8GZtcVVzLl4SnV7L7dbKzBoHrvd/zrnNzrmPgMnAxSFOyQNedM5tc86twvvwqUj2fwLucs59EriHIudcUX1jlOhTAk8N+cBm4KxAL5PngL845xYCFwJHAv9nZh+Y2WCAwHF9gbcD1+gBHIWXFNoCjwO3RitAM3vdzDaEeLwexvl7AQcAX1bZ9oiZbQG+AlYS+DCKILZmwF3A70Psr/5e1RhXFUOB51yQKUEjeD8OAHY4576psm0+Pyfl6h4AhpjZHmaWC5wBTDWzdKAX0MbMFgeqhcaYWU6I64iP1IiZApxzW81sGtAduB3YDtwX2PcfvOqT6voC851zxYHnhwKjnXPTAMxsIbBbvaiZHQL8wTl3abBYzMyAAqAZcIFzbk0gjv6R3l8ggY4FnnXOfVWx3Tl3rZndABwNnACUBr9Cre4GnnLOFXrh76b6e1VjXIF9++BVW1wR7IIRvB9NgE3Vtm3EqwIJ5kPgysA56cCzeP8v7YBM4Dy8/9/twKt4VUG31zEmiTGVwFPHF0B/4Ga8uu6dtRxftfrE8JL/a1X2HwIsjCCOvQGccydWJO/6MLM0vA+gMuD66vudczsD1QkdgGsiuP5hwCnA/TUcVvlehRsXXtXGR865pXWNKYTNeB+KVTUDiqsfGIhtKjAJaAy0Blrg1XtvDRz2kHNupXNuLfAPvHuUOKMEnjoW4tV73u2cWxTG8VWTUme8b2tV68jzgXkAZpZhZi+a2TvATYFt6YGeIP81szfMrEXgvAeBY8xsUtUXM7MpQXpoVDymBAsw8MHyFF6j3bnVGmCry6CGOvAanIB3/yvMbBXwB+BcM/usyjG7JPAw47oEr9QbVATvxzdAhpl1qbKtB8GrbloCnYAxzrlS59w64G
ngTOfceqAQr82gglZ9iVfOOT1S4AGMxvtDbBPGsXnAkirPzwY+qXbM98DhgZ/Pw6tTB7gaeCawbXRg28XAHYGfOwMTo3RPjwGfAE2qbW8LDMGrVkgH+gElwIAarpUBZOP1XvlP4OcMYA+8bw0Vj/uAiRXvY/X3qqa4quw/JhBP0yj/H48HxuGVqvvgVaEcHOLYJcDwwD02B14BXgjsuwuvt01bvJL5dLwPft9/j/XY9aESeAows954XcQKga5hnPILdq0SOJRAaTtwvdZ4yeyLwKb9gTmBn2dV2TaryraqJcN6C9QhX4XXBW5VldLpRXgfVNfg3e96vKR7o3NucpXzp5jZbVUuORKv+mA48KvAzyOdc1ucc6sqHnhVFdvcz9U/u7xXtcRVYSgwyVWrM4+Ca4EcvJ4t44BrnHNfhrjfQcDpwBpgMV5d902BfXfj/Z99AywC5gL3RDlWiQILfOJKkjKzRsBneKXC3ngl6YdrOedNvK/XYfXaMLPzgMOccyPN7Cq8RsPXgSOcc7ea2cVAnnPurkB/6Pucc+dFflfxo67vlUg0qQSe/P4E/AiMAT4HfmFmWbWc8wHwfh1eowA4yMzexSt5VmzraGYf4nVVHFOH6yWSD6jbeyUSNSqBJ7FA1ck0vNLxcjPLCzxf55w72t/oRKS+lMBFRBKUqlBERBKUEriISIJq0KH0rVu3dp07d27IlxQRSXhz5sxZ65xrU317gybwzp07M3v27IZ8SRGRhGdmy4NtVxWKiEiCUgIXEUlQSuAiIgnK9/nAt2/fTmFhIdu2bfM7lJjJzs6mQ4cOZGZm+h2KiCQR3xN4YWEhTZs2pXPnzoSYLD+hOedYt24dhYWF5OXl+R2OiCQR36tQtm3bRqtWrZIyeQOYGa1atUrqbxgi4g/fEziQtMm7QrLfn4j4Iy4SuIhIQigvh2m3w5pvaj+2ASiBA//85z/p2rUrLVq0YPTo0QAUFBSwcGEkSz6KSNJa8CLMHAOFs2o/tgH43ogZDx555BHeeecdOnToULmtoKCA/v37061bNx8jE5G4UVoMb98BuYdDjwv9jgZQCZyrr76aJUuWcMYZZ3D//fdz/fXXM2PGDCZPnsywYcM47LDD+O677/wOU0T89uG9sPlHOONeSIuP1BlfJfApw2HVguhec+/ucMbokLsfe+wxpk6dyvvvv8/rr78OwDHHHMOAAQPo378/552XFCt/iUh9rF0MMx+Bw34FHQ73O5pK8fExIiISz6aNgMwcOOWPfkeyi/gqgddQUhYR8cU30+Dbt+C0e6BJW7+j2YVK4CE0bdqU4uJiv8MQET/tKIWpw6H1AdD7Sr+j2Y0SeAhDhgzh3nvvJT8/X42YIqlq5hj4aQmcPgoysvyOZjcNuqhxr169XPUFHRYtWkTXrl0bLAa/pMp9iiSNDStgTG/ocgoMft7XUMxsjnOuV/XtKoGLiAQzdYT3b79R/sZRAyVwEZHqvn0bvnodjh8GzTv6HU1ISuAiIlVt3wZvDoNWXeDoG/yOpkbx1Y1QRMRvM/4J65fCxQVx2XBZlUrgIiIV1i+D6X+Hg8+B/U70O5paJWQCH/z4TAY/PtPvMEQk2UwZDpbuDdoJIZ7yT0Im8Fi68847ue+++0Lu1zSzIknq6ynwzRQ4YTjsmet3NGFJuAReMLeIuSs28OnSn+gz+j0K5hY17OsrgYskn+1bYcqt0OYgOOoav6MJW0Il8IK5RYyYtICyneUAFG3YyohJC+qdxO+55x4OOOAAjj32WL7++msA/vWvf3HEEUfQo0cPzj33XLZs2RJ0mtlgx4lIgvnoftiwHM68D9Iz/Y4mbAmRwCvqnG6Z+Dlbt+/cZd/W7Tu5ZeLnEddJzZkzh/HjxzNv3jzefPNNZs3yVtoYNGgQs2bNYv78+XTt2pWnnnqqcprZe++9l3nz5rHffvsFPU5EEsi67+CjB6D7+ZB3nN/R1ElCdSOsKHmHuz0c06dP55xzzmGPPfYAYM
CAAQB88cUXjBw5kg0bNrB582b69esX9PxwjxOROOQcvPF7yGgEp97tdzR1lhAJfMJVRwPQZ/R7FG3Yutv+3OY5lcdEy6WXXkpBQQE9evTgmWee4YMPPqjXcSIShxa8BEs+8KpOmrXzO5o6S4gqlArD+h1ITmb6LttyMtMZ1u/AiK/Zt29fCgoK2Lp1K8XFxbz22msAFBcX065dO7Zv387YsWMrj68+zWyo40Qkzm35yZvvJLcX9Lrc72giklAJfGB+LqMGdScr3Qs7t3kOowZ1Z2B+5F1+evbsyeDBg+nRowdnnHEGRxxxBAB33303Rx55JH369OGggw6qPL76NLOhjhOROPf2HbB1PZz1IKSl1348/veCqy4hp5OtaLCMdrVJLGk6WZE4suxjeOZM6PM7OPWusE6p6AVXtSNFTmZ6vQuR4Qg1nWxC1IFXl0iJW0TizI5SeP1GaN4Jjh8edg+2uSs27NZhoqIX3Lj/rajx3FjlrIRM4CIiEfv4QVj7DVz0MmTtEfZpsegFV19xkcCdc5iZ32HETENWU4lIDdYuhg/vg0PO9VbaIfzScUP2gguX742Y2dnZrFu3LmmTnHOOdevWkZ2d7XcoIqnNOa/qJCM7olV2YtELrr58L4F36NCBwsJC1qxZ43coMZOdnU2HDh38DkMktc0fB8umQ//7oeledT69oqHylomfU7aznNzmOQzrd2DMGzBr4nsCz8zMJC8vz+8wRCSZlayDabdDxyOh56URX2Zgfm5lg2U8dKbwvQpFRCTmpg6H0mLo/wCkJU/aS547EREJ5ptpsOBFOO5m2Kub39FEVVgJ3Myam9lEM/vKzBaZ2dFm1tLM3jazbwP/toh1sCIidbJtE7x2I7Tt5iXwJBNuCfxBYKpz7iCgB7AIGA6865zrArwbeC4iEj/evgM2r4IBY+J+geJI1JrAzWxPoC/wFIBzrsw5twE4G3g2cNizwMDYhCgiEoGl02HO03DUtdDhcL+jiYlwSuB5wBrgaTOba2ZPmlljYC/n3MrAMauAuvfLERGJhbItMPkGaLkvnHi739HETDgJPAPoCTzqnMsHSqhWXeK8UThBR+KY2ZVmNtvMZidzX28RiSPv3wPrl8JZ/6zTcPlEE04/8EKg0Dn3aeD5RLwE/qOZtXPOrTSzdsDqYCc7554AngBvNsIoxCwiUmm32UkL58Anj8Dhl8VkibR46P9dodYSuHNuFfC9mVWMFz0ZWAhMBoYGtg0FXo1JhCIi4dpRCq9eB03bhT1NbIWKtXcTSbgjMW8AxppZFrAEuAwv+b9oZlcAy4ELYhOiiEiY6wBM/zusWQS/fAmymzVQZP4JK4E75+YBu00mjlcaFxHx36ovvAR+6GA44DS/o2kQGokpIgkv3W2HV66GnBYRzTSYqHyfzEpEpL4GbR4HmxfAkHHQuJXf4TQYlcBFJKHtV/Y152weDz1+CQed6Xc4DUoJXEQS1/atXLvx76xPawmnp07VSQUlcBFJWAUvPMbgzTfTa8uD9HnwMwrmFvkdUoNSHbiIJKSCt99nxKJ92EojAIo2bGXEpAUAvq6S05CsIdei7NWrl5s9e3aDvZ6IJJ5gg2nWFpeyZG0JDshKTyOveTobf/qRVa7lbsdmpaeR36l5ra9TtT95wdyiuFoqrTozm+Oc260rt0rgIhLX1haXsnRdSeVkS2U7y1m6rowygi9BULazvE7XL5hbxIhJCyrPS6SSvErgIuKLcIetz12xIWhSTqOc8iDNeBUl8IUrN9GtXe2jMUNdP5KSfKyEKoGrEVNE4lqoEnU5aaTZrtvSDDq2yInK9etakveDqlBExBfhllz7jH6Pog1bd9teUVdd37rrmq4fTzMPBqMSuIjEtWH9DiQnM32XbTmZ6ZXJOr9Tc47Ma8nHw0+KqM66puvHOyVwEYlrA/NzGXV6O9rZWgxHbvNsRg3qHrUGxoH5uYwa1J2sdC8d5jbPier1Y0lVKCIS38p3MvDr4ZyWPY
9bWz/MQ9f1j/pLDMzPZdz/VgDxtWBDbVQCF5H49vEDsGIGTze7htUZ7fyOJq4ogYtI/Cr6DN7/Cxx8Dh/mnOJ3NHFHCVxE4lNZCUz6DTTZC/rfD2a1n5NiVAcuIvFp2u2w7jsYOtlbqEF2oxK4iMSfr96AOU/DMTdAXl+/o4lbKoGLSHwpXgWTb4C9D4WTRtZ6eCL1Gok2JXCRJBXWKu7xpnynV+9dtgXOfRIyGlXuivV9JNT7FKAELiLxY/o/YOmHMGAMtIn/kZB+Ux24iMSH5TPgg79A9/Mh/1d+R5MQlMBFxH8l62DiFdCis7oM1oGqUETEX85BwTWwZS38+h1o1NTviBKGEriI+Gvmw/DtNDjjXmjXw+9oEoqqUETEP4Vz4J074aD+0Ps3fkeTcJTARcQf2zbCxMugaTs4e4zqvSOgKhQRaXjOeYN1NhbC5VM1VD5CKoGLJKGCuUXMXbGBT5f+RJ/R71Ewt8jvkHY160lY+CqcfAd07O13NAlLCVwkyRTMLWLEpAWVi/IWbdjKiEkL4ieJfz8Lpo6ALv3gmN/6HU1CUxWKSIKoGBpfm7krNuy2ovrW7Tu5ZeLnlavO1CSmQ8pL1sJLQ6FZexj0OKSpDFkfSuAiSaZ68q5te7jqPbdK+U6YeDlsWQdXvKV67yhQAhdJEOEmzj6j36Now9bdtuc2z/F3wqb374Gl/4WzH1Z/7yjR9xeRJDOs34HkZKbvsi0nM51h/XycHOqrN2H636HnUM1zEkUqgYskmYH5uQDcMvFzynaWk9s8h2H9Dqzc3uDWfQevXA3tDoMz/uZPDElKCVwkCQ3Mz61ssPS12qRsC7x4iddYecFzkJntXyxJKOwqFDNLN7O5ZvZ64HmemX1qZovNbIKZZcUuTBFJOM7BGzfDj1/CoCehxT5+R5R06lIH/jtgUZXnfwXud87tD6wHrohmYCKS4GY/BfNfgBOGQ5dT/I4mKYWVwM2sA/AL4MnAcwNOAiYGDnkWGBiD+EQkES37CKbc6g3W6XuL39EkrXBL4A8AtwAVHUlbARucczsCzwuBoC0kZnalmc02s9lr1qypT6wi4pM6Dc1fv9yr9265L5z7Lw3WiaFa31kz6w+sds7NieQFnHNPOOd6Oed6tWnTJpJLiIiP6jQ0v6wExv8SynfAheMhe88Gjja1hNMLpQ8wwMzOBLKBZsCDQHMzywiUwjsAcTLRgoiEI+pD853jpg330HvbQka3vItxDy0GFrPgzn5RjFqqqjWBO+dGACMAzOwE4A/OuYvM7CXgPGA8MBR4NXZhikhdRav7YLhD8wdtHsdR2z7iP01/w/xGvYBNUXl9Ca0+/cBvBcab2Z+BucBT0QlJRBpCVIfmf/UGjH8ODh3Mxefcy8VmYZfwJXJ1al1wzn3gnOsf+HmJc663c25/59z5zrnS2IQoIn6qdWj+6kUw6Upo3xPOelAr6zQgjcQUkRrVODR/y08w7kLIagxDxkJmTq3Xq/eshlJJCVxEahV0aP6OMq+74KYiuPQNb45vaVBK4CJSd87Ba7+DZdO9YfJaFs0X6mEvInU3/e+BYfIj4NDz/Y4mZSmBi0jdfDEJ3rsbul8Ax98a9JC4X1Q5SSiBi0jYupQt8ub27nQ0nD0maI+TuF9UOYmYc67BXqxXr15u9uzZDfZ6IhI9p9/5H8ZyG9vTGzOy9QMUpwUfJh9s5CZAVnoa+Z2as3ClN8CnW7tmux2jninBmdkc51yv6ttVAheR2m1dzxhGk8lO/tryrpDJG2K3qLLsTr1QRKRm27fBuF/SmVX8peU93H/dBTUeXtvITfUDjx6VwEUktPKd8MqVsGIGGec+zh2/vbrWU2oauanGzehSCVxEgnMOpo6Aha/CafdA9/PCOi3UyE0gaONm1XOkbpTARSS4Gf+E/z0OR18Px1wf8rBQk1Y1ykyjUWYaHVrkMO5/K8KfljYIVbcEpyoUEdnd5y/C23fAwYPg1Lujckk1bkafSuAisq
vv3oeCa6HzcXDOY7UuiRaqdFy9sTKsaWmlTlQCF5GfFc6G8RdB6wNg8POQ0Shql651WlqpM5XARcTz40J4/lxo0hYungQ5zaN6+RqnpZWIKIGLCKxfBv85BzKy4ZICaLp3TF4m6LS0EjElcJFUV7wKnjsbdpbCZVOgRWe/I5IwKYGLJImIRjhu+ckreW9eA0Nfg7ZdYxSdxIISuEgcapDh5qWb4YULYN1iuOgl6HB4VC+vKpLYUwIXSUVlW+CFwVD0GVzwLOx7gt8RSQSUwEVSzfatMP5CWDEDBv0Lup7ld0QSIfUDF0klO0phwq9gyX/h7EfCnt9E4pNK4CJJrrI+/YrD4cWhsPgdGPAQHHahL/Gobjx6VAIXSQHpbgdMvAy+mQK/+Af0vMTvkCQKkrYErknjRTzpbgfXb/gbrPoQTv8rHHGF3yFJlKgELpIEQi6UsKOMG9f/hWO2fQin/RmOqn1BBkkcSuAicaauq9aEXAV+9jJ48WJ6l87g6WZXwzE3NED00pCStgpFJBEFS8a/f3EeD737La2bBp8ZMNRCCfe8PIOBjaZyF7/mpU0nMTXEwguqZkxcSuAiMVC9DSbUqjXVBUvG5Q6WrC1h9ebSoOeEWhBhrWvKo3vexEsblaCTlRK4SBwJlYwd0K1ds6D7giV9gPZ7wDU33ckHatBPWkrgIg0g3OQZyao1FdUuW7fvrNyWk+4YdlbPyvr0sp3l9Bn9nubfTjJqxBSJI5GsWjMwP5dRZ+Syl63HcOQ2hlHn5QPBV4GvrVFUEkdSlsBV6pBEFc6qNdXr09vtKOT2n27j3UbF3MCtbGl7VL1XgQdVuSSCpEvgobpUAUrikhDqsmpN5+2Lue2n2wG4nD+yiH3pFtinVeCTX8Ik8Pq04qvUIcmk8nd02ccw7jZo0gwuKcBeXks3tAp8Kkm6OnCVOiQlLHzVW0mn6d5wxTRo3WW3Q7QKfPKrtQRuZh2B54C98HozPeGce9DMWgITgM7AMuAC59z6WAUay1Z8kYQy82GYdjt0OAIuHA+NWwU9TKvAJ79wSuA7gJudc92Ao4DrzKwbMBx41znXBXg38Nx34ZY6Bj8+M+xqGZG6qOtQ+LCV74Qpw2HabdC1PwydHDJ5VxiYn0t+p+YcmdeSj4efpOSdZGotgTvnVgIrAz8Xm9kiIBc4GzghcNizwAfArTGJsg5U6hA/xawRfftWePnX8NXrcNS13sRUaem1nydJrU6NmGbWGcgHPgX2CiR3gFV4VSzBzrkSuBKgU6dOEQdaF3VpxRcJR0M0oof8XS1ZB+MGQ+Fs6DcKjr42rFgk+YWdwM2sCfAycKNzbpOZVe5zzjkzc8HOc849ATwB0KtXr6DHiCSLqDeir17kLT68+Udv8eFuZ4c8VIWV1BNWAjezTLzkPdY5Nymw+Ucza+ecW2lm7YDVsQpSxG8N3Yg+4aqj4eup8OT5kJkDQ1+HjkeEfb6khlobMc0raj8FLHLO/aPKrsnA0MDPQ4FXox+eSGKJStc95+DjB2HcEGi1L1z5vpK3BBVOCbwPcDGwwMzmBbbdBowGXjSzK4DlwAUxiVAkgdS7EX1HKbx2I8x/AboNhIGPQtYeMYtXEls4vVA+AizE7pOjG45I4ou4Eb14FUy4GAr/ByfcBsffAhbqTy98qhtPXgkzlF4kqS37GF66FMo2w/nPwMHn+B2RJICkG0ofjpgNtBCpK+dgxhh49izIbga/flfJW8KWtCXw2ia/12yFEqnqy6WF2lar0mJ49XpYWAAH9ffqu7ODr7ojEkzSJHDNVigJZc3XMOFXsG4xnHoXHPPbqNR3S2pJmgQeLs1WKL5yDuaNhTeHQVZjuORVyOvrd1SSoJImgWu2Qol72zbC67+HLyZC5+Ng0BPQrL3fUUkCS95GzP/+DX6Yu9tmzZEskajr7JUTrjp61wJB4Rx47Dj48hU4aaRX8lbylnpKzgS+5SeY8ww8eQr8917YuaNy18D8XEYN6k5Wun
fruc1zGDWouxowJTbKy71Rlf8+DVw5XDYF+g7TTIISFUlThbKLPVrCNR/DG3+A9/8M374Fgx6HlvsCmq1QGsj65fDqdbBsOnQdAAP+CTkt/I5KkkhylsDB+0M57yk49ylY+zU8eizMedZrRBKJJee837VHj4Ef5sGAh+CC55S8JeqSN4FX6H4eXDMDOvSC134L4y70hiyLRKDWQWCbVsILF3i/a+3z4doZ0PMSdRGUmEjOKpTq9uwAFxfAp4/Bu3+Ch3tzQqPL+SDnNL8j201EA0KkQawtLg06CGz28p8C4wt20mfUtwzLKmPgmX+DI34DaclfRhL/mGvAKoVevXq52bNnN9jrBbXuO5j8W1j+EQuyDqP71c9Ayzx/Y6pCCTxykb534fQuWbhyE5u37SD4X4uj6nxv6ebo3KoJrZs2qtym/0+pDzOb45zrVX17apTAq2q1Hwx9jX89cAcXFT/l1VOeNBKOvDruegYomftn4cpNuzzfUhoqeUP1yTp3OmPJ2hJWby6t3BbqQ0L/t1IfqZfAAdLSeKfxL/gsuzePNh/rrfK94CU48+/Q4XC/o5MGFiyJVk24a4tLWbJtx27H1MQB3dppXhOJrdRM4FT5o3VnwRcvw7Tb4cmToefFcPKd0LiVr/GJvyp+PyomPwtV+jYcLsh0+RrZKw0hZRN4JTOvp0qX0+C/f4VPHoWFk+HkO+DwSyEtXVUZSag+k595HI3SjT1zGrG2pIzyKhk+zSA7I43Bj8/U74zElJrIK2Q3g373eAOA9u4Ob/we/nUiLJ/hd2Tio9CTnBmHdWpJXpsm5LVqXFkGz0pPI69V410aMEViRSXw6tp2haGvedUqb42Ep8/gD42O5oVmlwMqTSWLsErGPy2lzwPfU1TWeLddWelp9Z8PXKSeVAIPpqJa5YbP4KSRHFw2n/vWXAWv3wSbV/sdXUKr66RQvti82pvudcwRDLOx5KTtWgpPM+jYIsen4ER+phJ4TbL2gL7D+N0Xh3Du5rGc/tlz8PmLcMwNcNQ1kL1nVF+uYpRf2c5y+ox+j+yMNH0Vb0jbNsGMh2Dmw7BjG/S8hIHH3wrfle+yyrz+XyReqAQeQtWS4qb05jy953Vw7aew34nwwSh44FBvpsNtG6PyesGWelu6roS1xaW1nClQz3VOy0q8xP1gD/jwb3DAaXD9LDjrAWjWjoH5ueR3as6ReS35ePhJSt4SN1QCr4vW+8Pg5715xv/7N2+mw5lj4Ojr4cgrg5bI69PbodzBd2tLOOD2KXRskRMycUSr3jWcetx4rOuNeJ3TbRvhf0/AzEdg60+w74lwyh+9OUxEEoASeCTa58OF43ZP5Ede5c1/0aRNnS9Z05JuZTvLWbquBCClSn+xWue0aflGzix5hX4lr9HYlUCXftD3D9Cxd1TiFmkoSuD1sUsiv9frR/7RA9BjiFcqb3PAbiXVUEkpKz2txiRe7uD79VuDJvDq14yn0nFDCHed0713FHF6yaucuPUtst02Psk+lqOG/gXa9ajT66Xa+yvxSwm8FtUbFof1O3D3r+Xt8+HCF2Dtt14D2Pxx8NmzcMDpXiLvfGyt04l2bJHD0nUluwwIqS7VFl6OyjqnVx4FS96HTx7zFvZIy4Du58KxN3JU264NHqtINCmB1yDU9KEQom61dRev4eukkTDrSa9+9dn+0PpA6HUZ9BhS4x96wdyiyt4OwWh4tqd6PfywfgcyYtICtm7fWXlMTmYaww5YBY8cDWsWQeM2cPwt0OtyaLq3L3GLRFtKJvD6TB9aU93qro5nwk2/gwUTvfU5pw6Hd+6Eg8+Bwy/z6lurlcorlnpbW1zKDxu3VUtIib/wcljfZiJQcQ3vw28nuVlbGWZjGfj5+7D3oTDwUTjkXMhInfYDSQ0pmcDDFao2I+yqjMwcb3KsnhfDys+9RP75i14VS+sDofv53oChavORr95cSvs9s/l+/dbKvsfRSnZ+CaenSMQ9XDYWMXDzBHpnP0
n7nUWQ1RQOGQQ934Xcw7UajiStlEzgtSWIiqqMUCKqymh3KPT/B5x6lzdM//MJXu+V9/8MHY7wkvnB51Qe3rppo8oGy3iuNolmT5GKObiDXbP6e9Bi5zqvXvvLV+D7TwDYkNWdV5oM4brrboas3Ye/10c8/x9I6krJBF6T6iXF6updldGoCRw+1Hts+N5L5gsmwpRbYOpw/ph5MG9wGIU7TmBVRsOVuMOp3gh2TLjC7SlSo/XL4Ztp3LnuWQ4s+xKmOmh7MJw4Eg4ZxJ8metMcXBfl5C0Sr1JmSbX6Tx/qdfWL2YCa1Yvgi5dZPmMi++xYCkBRekfmZB/JgPMv90rpGVmRX78GFR9a1evcRw3qXpnEwzmmJjX1FPl4+ElAkCqUHaXebJCL34Fv34a1XwPwfcY+zMzuywVDb4A2P3+IxOMgI5Fo0JJqYaqpRJjfqXnsXrhtVzhpJLd8ezIbVi5mSLMv6bntU84sKYBnJkJGDnQ6EjofB3l9va6L6Zm1Xra2D66aGmtvmjCP/3v1C4Aaj6m9QTdUT5Fdv82ku+3st/1b+OgTWD4Tln0E20sgPQv26eN9a9n/VP4waR0AF7RJ7EZdkfpKmQRe3z7F1acPjaUfaMvUxvsztfHZ5JSX8MwJ22DZdFg6Hd672zsos7G3/Fv7fGjf0/u3eaeIGuxCfQdzdTymJrv2FCknt3k2w45rw8Cc+fDe07B8Js+s+h9ZlME7QKv9ocdgb6GNzsd5VU+V4nw2Q5EGkjIJPFzBSooNOX3ohKuO3qXUvDWtMXQ9Bbr29zaUrPVKpsumQ+Fsbx6P8u3evj1aQbvDoM1BXp/01gcw4aIDoHHrGhN7ONUbNQ6WqemDzTkv5hWfMtCWkNH8LfbZvoTuthzeXu8dY2mw96G83fgXfJV5MDf/eig0aRv0crHqiiiSiJTAq9m9pBhn04c2bg0HD/Qe4NUT//iFN5y/aC6snA/LP/amQ62Q3Rxa7gvN2kPTdtCsHTTL9X7OacGwY1swYlopW7f/XH1UvXoj5GCZvm3gh3lQssabR7tkNWxeA5sK4ael3qOsuPKc08hiRWZn6DrAW/moXQ9o2w0aNeG5ig+uGpJ3TV0RVfctqaZejZhmdjrwIJAOPOmcG13T8X42YtZV1Qaxhm4cq/frlZfDxu9h3bfe8P41X8OG5bBpJRT/EHQK3IIdx3DvjsH8QGvap61n2B5vMDBnnlc6tjQwo6DkEP66+XRWuVa0Zy3DMiYwMCPIknOZjb0PiZb7Qos879+W+0LLPC6c+CPllh50NZuKboShVnMP1cCclZ4WVvuEErwkqqg3YppZOvAwcCpQCMwys8nOuYWRhylRqSJIS4MW+3iP/U/ZfX9ZSSCZr4RtG6C0mIHbNlH60RxyXAkDujYHtze408CVVz4GpmeS9dW7lFo25/TuAln9IfMCaNya/3tnNRvSWvDQb/rV2Ae73NYG3b62uLSyoXTuig1Be/tEpSuiSBKpTxVKb2Cxc24JgJmNB84GlMAjFPG81nWV1dib27z1/rtsnjTfKwkP6B+6pPrsSu+Yc47f9ZhvPpz587Xr6MLenRgxaUFlg2jZznJ+2LiNG07usst9R1wPL5Kk6pPAc4HvqzwvBI6sXzjxqb7JIVbzWlcVjwks2H0HG20Z7n1nZ6SRZuwyY2OakfBzxIhEKuZLqpnZlWY228xmr1mzJtYvl9BStYog3Ptu3bQRea0aU9GfJis9jbxWjdULRVJWfUrgRUDHKs87BLbtwjn3BPAEeI2Y9Xi9hBWVea3jsIRdm2AxB2ugret9a8SliKc+JfBZQBczyzOzLGAIMDk6YaWmYf0OJCczfZdtsZhGtuqCzQ1twlVH75Z4G+q+RZJNxCVw59wOM7semIbXjfDfzrkvoxaZz/wo3QXrg54KA1VS9b5F6qteA3mcc28Cb0YpFuHnRR2g4T9E/KyS8PO+RRKVRmJKnQRLrhreLuKPmPdCkeQWqu96wdzd2rNFJM
pUApeQwmnorE/fdVB1iUh9qAQu9ZKqfddF4oFK4BJSOKVjP/quq9Qu4lEJXOpFfbhF/JMya2KKp2BuUdT7W8fimiLyM62JKTGb7VB9uEX8oQSeBFJ1tkORVKc68BSiHiMiyUUl8CSQqrMdiqQ6lcBTiHqMiCQXlcBTiGb9E0kuSuApRj1GRJKHErhEhT4MRBqe6sBFRBKUEriISIJSAhcRSVBK4CIiCUoJXEQkQakXSgpSjxGR5KASuIhIglICFxFJUErgIiIJSglcRCRBKYGLiCQoJXARkQSlBC4ikqCUwEVEEpQSuIhIgjLnXMO9mNkaYHmMLt8aWBujazc03Ut80r3Ep1S4l32cc22qb2zQBB5LZjbbOdfL7ziiQfcSn3Qv8SmV70VVKCIiCUoJXEQkQSVTAn/C7wCiSPcSn3Qv8Sll7yVp6sBFRFJNMpXARURSihK4iEiCSsoEbmY3m5kzs9Z+xxIpM7vXzL4ys8/N7BUza+53THVlZqeb2ddmttjMhvsdT6TMrKOZvW9mC83sSzP7nd8x1ZeZpZvZXDN73e9Y6sPMmpvZxMDfyiIzS9jlpszspsDv1xdmNs7Msms7J+kSuJl1BE4DVvgdSz29DRzinDsU+AYY4XM8dWJm6cDDwBlAN+BCM+vmb1QR2wHc7JzrBhwFXJfA91Lhd8Aiv4OIggeBqc65g4AeJOg9mVku8Fugl3PuECAdGFLbeUmXwIH7gVuAhG6ddc695ZzbEXj6CdDBz3gi0BtY7Jxb4pwrA8YDZ/scU0Sccyudc58Ffi7GSxK5/kYVOTPrAPwCeNLvWOrDzPYE+gJPATjnypxzG3wNqn4ygBwzywD2AH6o7YSkSuBmdjZQ5Jyb73csUXY5MMXvIOooF/i+yvNCEjjpVTCzzkA+8KnPodTHA3iFnHKf46ivPGAN8HSgOuhJM2vsd1CRcM4VAffh1RysBDY6596q7byES+Bm9k6gjqj642zgNuAOv2MMVy33UnHM7Xhf4cf6F6kAmFkT4GXgRufcJr/jiYSZ9QdWO+fm+B1LFGQAPYFHnXP5QAmQkG0tZtYC7xtqHtAeaGxmv6rtvIxYBxZtzrlTgm03s+54Nz/fzMCrcvjMzHo751Y1YIhhC3UvFczsUqA/cLJLvA77RUDHKs87BLYlJDPLxEveY51zk/yOpx76AAPM7EwgG2hmZs8752pNFnGoECh0zlV8G5pIgiZw4BRgqXNuDYCZTQKOAZ6v6aSEK4GH4pxb4Jxr65zr7JzrjPef2zNek3dtzOx0vK+5A5xzW/yOJwKzgC5mlmdmWXgNMpN9jiki5pUIngIWOef+4Xc89eGcG+Gc6xD4GxkCvJegyZvA3/b3ZnZgYNPJwEIfQ6qPFcBRZrZH4PftZMJokE24EngKGQM0At4OfKP4xDl3tb8hhc85t8PMrgem4bWo/9s596XPYUWqD3AxsMDM5gW23eace9O/kCTgBmBsoJCwBLjM53gi4pz71MwmAp/hVZnOJYxh9RpKLyKSoJKmCkVEJNUogYuIJCglcBGRBKUELiKSoJTARUQSlBK4iEiCUgIXEUlQ/w9O4v50V6rPQgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.errorbar(data_x, data_y, sigma_y, sigma_x, fmt=\"o\", label=\"data\")\n", "x = np.linspace(data_x[0], data_x[-1], 200)\n", "plt.plot(x, f(x, m.values), label=\"fit\")\n", "plt.legend()\n", "\n", "# check fit quality\n", "chi2 = m.fval\n", "ndof = len(data_y) - 3\n", "plt.title(f\"$\\\\chi^2 / n_\\\\mathrm{{dof}} = {chi2:.2f} / {ndof} = {chi2/ndof:.2f}$\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We obtained a good fit." ] } ], "metadata": { "kernelspec": { "display_name": "py38", "language": "python", "name": "py38" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 2 }