{ "cells": [ { "cell_type": "markdown", "id": "703c2864-e0a8-4640-ba92-2bbdb63a1b5b", "metadata": {}, "source": [ "# JAX\n", "\n", "## Grad\n", "\n", "JAX's tracer can compute gradients! Let's try it on a polynomial:\n", "\n", "$$\n", "y = x^3 + x^2 + x \\\\\n", "y' = 3x^2 + 2x + 1 \\\\\n", "y'' = 6x + 2\n", "$$" ] }, { "cell_type": "code", "execution_count": null, "id": "187ebafe-43c5-4192-b4a9-a1719eb1ace2", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": null, "id": "d383b64a-6b86-4f63-938e-273f39901418", "metadata": {}, "outputs": [], "source": [ "def f(x):\n", "    return x**3 + x**2 + x" ] }, { "cell_type": "code", "execution_count": null, "id": "7442ccd4-a79b-4e87-b498-80593dca4f37", "metadata": {}, "outputs": [], "source": [ "f(1.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "6b445caf-50ba-4c8e-ab5c-f5e19967a7e2", "metadata": {}, "outputs": [], "source": [ "fp = jax.grad(f)" ] }, { "cell_type": "code", "execution_count": null, "id": "bab3d9f4-c21e-4311-8aa2-45d8e59e4879", "metadata": {}, "outputs": [], "source": [ "fp(1.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "53f44bb6-9633-450d-8cea-bfd0aea6b34c", "metadata": {}, "outputs": [], "source": [ "fpp = jax.grad(fp)" ] }, { "cell_type": "code", "execution_count": null, "id": "cab68f79-75c1-4797-b85b-8093b914592b", "metadata": {}, "outputs": [], "source": [ "fpp(1.0)" ] }, { "cell_type": "markdown", "id": "4cd5a826-f100-47cb-8720-1ee591970371", "metadata": {}, "source": [ "## Tracer limitations\n" ] }, { "cell_type": "markdown", "id": "9c01cce6-5611-4fee-ade9-9768734aeee0", "metadata": {}, "source": [ "Let's watch the tracer at work by printing inside a jitted function:" ] }, { "cell_type": "code", "execution_count": null, "id": "9d77aaa0-e3ce-4a77-94f4-4b56ec67b9a9", "metadata": {}, "outputs": [], "source": [ "def f(x):\n", "    print(f\"{x = }\")\n", "    y = x**2\n", "    print(f\"{y = }\")\n", "    return y" ] }, { "cell_type": "code", 
"execution_count": null, "id": "e18d72f0-2977-4439-9671-e01687014216", "metadata": {}, "outputs": [], "source": [ "f_jit = jax.jit(f)" ] }, { "cell_type": "code", "execution_count": null, "id": "c093d74b-9fc3-43d3-9d4e-8ddd99d8455b", "metadata": {}, "outputs": [], "source": [ "f_jit(2)" ] }, { "cell_type": "code", "execution_count": null, "id": "f5f4759a-5c7c-4e32-bcef-06a65b727338", "metadata": {}, "outputs": [], "source": [ "f_jit(2)" ] }, { "cell_type": "markdown", "id": "4a795676-90f9-4028-aac6-f515fd893513", "metadata": {}, "source": [ "Notice that the Python code runs only once, on the first call, and that the value passed in is not an integer at all but a tracer object. After that, the compiled function runs without executing the Python code again, as long as the input types and shapes stay the same:" ] }, { "cell_type": "code", "execution_count": null, "id": "43479bc8-ba81-4829-9d83-606ee03027ea", "metadata": {}, "outputs": [], "source": [ "f_jit(1.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "d45aa5d9-6649-483e-a6aa-b1be38990e5a", "metadata": {}, "outputs": [], "source": [ "f_jit(1.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "7daef99c-533c-4bde-be1d-1229f4d941f7", "metadata": {}, "outputs": [], "source": [ "f_jit(1)" ] }, { "cell_type": "markdown", "id": "9c921c20-f61f-4d5e-a933-a7becaaab176", "metadata": {}, "source": [ "You can't trace through Python flow control that depends on the value of a tracer, and you can't dynamically change the shape of an array:" ] }, { "cell_type": "code", "execution_count": null, "id": "47a0d168-eab8-45c2-820b-bbda150799f8", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def broken(x):\n", "    if x == 3:\n", "        return x**3\n", "    return x" ] }, { "cell_type": "code", "execution_count": null, "id": "ab207814-a2bb-497e-8e64-2a4086f19114", "metadata": {}, "outputs": [], "source": [ "broken(2)" ] }, { "cell_type": "markdown", "id": "7176568d-0651-453e-ae46-d63aa4785b1f", "metadata": {}, "source": [ "## JAX is functional\n", "\n", "Unlike NumPy arrays, JAX arrays are 
immutable. You should also write pure functions (ones without side effects or state)." ] }, { "cell_type": "markdown", "id": "271e7b3c-9e10-442c-af5b-00eb2ea8df49", "metadata": {}, "source": [ "For example, an in-place item assignment raises an error:" ] }, { "cell_type": "code", "execution_count": null, "id": "8b2e0d2d-ec0c-4c62-810c-f58b2ae57ad2", "metadata": {}, "outputs": [], "source": [ "jarr = jnp.zeros((3, 3))\n", "jarr[np.diag(np.ones(3, dtype=bool))] = 1\n", "jarr" ] }, { "cell_type": "markdown", "id": "c8d47edc-6cae-400e-8063-56edc29fac34", "metadata": {}, "source": [ "JAX provides the `.at[...]` property for this: it returns an updated copy of the array instead of mutating it in place:" ] }, { "cell_type": "code", "execution_count": null, "id": "32967641-18ff-4b83-8a96-5cbc3afb7cbe", "metadata": {}, "outputs": [], "source": [ "j1 = jnp.zeros((3, 3))\n", "j2 = j1.at[np.diag(np.ones(3, dtype=bool))].set(1)\n", "j2" ] }, { "cell_type": "markdown", "id": "3e0f8b99-e725-4211-a7cc-cc9890393faa", "metadata": {}, "source": [ "## Further reading\n", "\n", "See the JAX docs!\n", "\n", "https://jax.readthedocs.io/en/latest/notebooks/quickstart.html" ] } ], "metadata": { "kernelspec": { "display_name": "performance-minicourse", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "493d32115352708ee205b0d097e176d1d360fe34639a9405e4a1e16a5d39b607" } } }, "nbformat": 4, "nbformat_minor": 5 }