{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "95bce5b5",
   "metadata": {},
   "source": [
    "# Differentiation With JAX"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df5aae26",
   "metadata": {},
   "source": [
    "This notebook gives a quick introduction to computing derivatives with `jax` (see the IntroToAutodiff notebook for a much more comprehensive introduction).\n",
    "\n",
    "`jax` is designed as a drop-in replacement for `numpy`, so if you have `numpy` code you can easily make it differentiable by either\n",
    "\n",
    "1. `import jax.numpy as np`\n",
    "\n",
    "or, if you don't like that much black magic and want to know whether you are using `jax` or `numpy`, you can do\n",
    "\n",
    "2. `import jax.numpy as jnp`\n",
    "\n",
    "but then you need to change all occurrences of `np.foo` to `jnp.foo`.\n",
    "\n",
    "I would recommend the latter: the drop-in aspiration works 99% of the time, but there are always edge cases. For example, JAX arrays are immutable, so an in-place update such as `x[0] = 1` has to be written as `x = x.at[0].set(1)`.\n",
    "\n",
    "Let's define a complicated-looking function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "58388491",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def func(x):\n",
    "    return (3*x + jnp.sin(3*x)) * jnp.exp(-x/2.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0d392e1",
   "metadata": {},
   "source": [
    "We can plot it as usual using `matplotlib`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8bb5ef50",
   "metadata": {},
   "outputs": [],
   "source": [
    "xi = jnp.linspace(0, 5)\n",
    "yi = [func(xx) for xx in xi]\n",
    "plt.plot(xi, yi)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38eb1ab9",
   "metadata": {},
   "source": [
    "To get gradients, you simply call `jax.grad()` on the function: `jax.grad(func)` returns a new function that evaluates the derivative of `func`.\n",
    "\n",
    "Note: this works only for scalar-valued functions $\\mathbb{R}^n \\to \\mathbb{R}$. For the full Jacobian matrix of a vector-valued function, use `jax.jacobian()` instead.\n",
    "\n",
    "As a reference point, here is the function value at $x = 2$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "af3c1f06",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(2.1044855, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "func(2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4ee7f83",
   "metadata": {},
   "source": [
    "You can get the gradient at the same $x$ like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bd76dbb3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray(1.1110764, dtype=float32, weak_type=True)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.grad(func)(2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6c4205e",
   "metadata": {},
   "source": [
    "Often you are interested in both the gradient *and* the function value, so there is a convenient API for that as well: `jax.value_and_grad()` returns the pair `(value, gradient)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d5abd122",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray(2.1044855, dtype=float32, weak_type=True),\n",
       " DeviceArray(1.1110764, dtype=float32, weak_type=True))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.value_and_grad(func)(2.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "080c026e",
   "metadata": {},
   "source": [
    "We can now easily plot the function together with its tangents: at each point $x_i$ we draw an arrow in the direction $(1, f'(x_i))$.\n",
    "\n",
    "Note: the list comprehensions `[... for ... in ...]` are plain Python and thus slow. Check out `JaxParallelization.ipynb` to see how to `map` over multiple values much more efficiently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "225dbe10",
   "metadata": {},
   "outputs": [],
   "source": [
    "xi = jnp.linspace(0, 5)\n",
    "yi = jnp.array([func(xx) for xx in xi])\n",
    "gi = jnp.array([jax.grad(func)(xx) for xx in xi])\n",
    "plt.plot(xi, yi)\n",
    "# arrows with direction (1, f'(x)) point along the tangent at each x\n",
    "plt.quiver(xi, yi, jnp.ones_like(gi), gi, units='xy', angles='xy', color='r')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec167db2",
   "metadata": {},
   "source": [
    "## Higher-Order Gradients\n",
    "\n",
    "`jax` lets you compute higher-order derivatives without any issues by simply wrapping the gradient function in `jax.grad(...)` multiple times; for example, `jax.grad(jax.grad(func))` is the second derivative of `func`.\n",
    "\n",
    "Tip: try redefining `func` for a few functions and see whether the derivatives match your expectations. For example, for $f(x) = x^4$ you should get $4x^3$, $12x^2$, $24x$ and the constant $24$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "f081fea6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def func(x):\n",
    "    # return jnp.cos(x)\n",
    "    return x**4\n",
    "\n",
    "xi = jnp.linspace(-5, 5)\n",
    "yi = func(xi)\n",
    "# each extra jax.grad raises the derivative order by one;\n",
    "# jax.vmap maps the scalar derivative function over all points in xi\n",
    "g1 = jax.vmap(jax.grad(func))(xi)\n",
    "g2 = jax.vmap(jax.grad(jax.grad(func)))(xi)\n",
    "g3 = jax.vmap(jax.grad(jax.grad(jax.grad(func))))(xi)\n",
    "g4 = jax.vmap(jax.grad(jax.grad(jax.grad(jax.grad(func)))))(xi)\n",
    "\n",
    "plt.plot(xi, yi, label='f')\n",
    "plt.plot(xi, g1, label='1st derivative')\n",
    "plt.plot(xi, g2, label='2nd derivative')\n",
    "plt.plot(xi, g3, label='3rd derivative')\n",
    "plt.plot(xi, g4, label='4th derivative')\n",
    "plt.legend()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}