{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "frequent-field", "metadata": { "tags": [] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "%config InlineBackend.figure_format = 'retina'" ] }, { "cell_type": "code", "execution_count": null, "id": "opened-virgin", "metadata": { "tags": [] }, "outputs": [], "source": [ "from IPython.display import YouTubeVideo, display\n", "\n", "YouTubeVideo(\"pepAq_dJIik\")" ] }, { "cell_type": "markdown", "id": "lasting-express", "metadata": {}, "source": [ "# Optimized Learning\n", "\n", "In this notebook, we will take a look at how to transform our numerical programs into their _derivatives_." ] }, { "cell_type": "markdown", "id": "forward-process", "metadata": {}, "source": [ "## Autograd to JAX\n", "\n", "Before they worked on JAX, there was another Python package called `autograd` that some of the JAX developers worked on.\n", "That was where the original idea of building an automatic differentiation system on top of NumPy started." ] }, { "cell_type": "markdown", "id": "correct-cyprus", "metadata": {}, "source": [ "## Example: Transforming a function into its derivative\n", "\n", "Just like `vmap`, `grad` takes in a function and transforms it into another function.\n", "By default, the returned function from `grad`\n", "is the derivative of the function with respect to the first argument.\n", "Let's see an example of it in action using the simple math function:\n", "\n", "$$f(x) = 3x + 1$$" ] }, { "cell_type": "code", "execution_count": null, "id": "demanding-opportunity", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Example 1:\n", "from jax import grad\n", "\n", "\n", "def func(x):\n", " return 3 * x + 1\n", "\n", "\n", "df = grad(func)\n", "\n", "# Pass in any float value of x, you should get back 3.0 as the _gradient_.\n", "df(4.0)" ] }, { "cell_type": "markdown", "id": "forty-lindsay", "metadata": {}, "source": [ "Here's another example using a polynomial function:\n", "\n", "$$f(x) = 3x^2 + 4x -3$$\n", "\n", "Its derivative function is:\n", "\n", "$$f'(x) = 6x + 4$$." ] }, { "cell_type": "code", "execution_count": null, "id": "neutral-neighbor", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Example 2:\n", "\n", "\n", "def polynomial(x):\n", " return 3 * x ** 2 + 4 * x - 3\n", "\n", "\n", "dpolynomial = grad(polynomial)\n", "\n", "# pass in any float value of x\n", "# the result will be evaluated at 6x + 4,\n", "# which is the gradient of the polynomial function.\n", "dpolynomial(3.0)" ] }, { "cell_type": "markdown", "id": "steady-bikini", "metadata": {}, "source": [ "## Using grad to solve minimization problems\n", "\n", "Once we have access to the derivative function that we can evaluate,\n", "we can use it to solve optimization problems.\n", "\n", "Optimization problems are where one wishes to find the maxima or minima of a function.\n", "For example, if we take the polynomial function above, we can calculate its derivative function analytically as:\n", "\n", "$$f'(x) = 6x + 4$$\n", "\n", "At the minima, $f'(x)$ is zero, and solving for the value of $x$, we get $x = -\\frac{2}{3}$." 
] }, { "cell_type": "code", "execution_count": null, "id": "opponent-modification", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Example: find the minimum of the polynomial function by gradient descent.\n", "\n", "start = 3.0\n", "for i in range(200):\n", " # Step against the gradient, scaled by a small learning rate.\n", " start -= dpolynomial(start) * 0.01\n", "start" ] }, { "cell_type": "markdown", "id": "beautiful-theory", "metadata": {}, "source": [ "We know from calculus that the sign of the second derivative tells us whether we have a minimum or a maximum at a stationary point.\n", "\n", "Analytically, the second derivative of our polynomial is:\n", "\n", "$$f''(x) = 6$$\n", "\n", "We can verify that the point is a minimum by calling `grad` again on the derivative function; a positive value confirms it." ] }, { "cell_type": "code", "execution_count": null, "id": "former-syracuse", "metadata": {}, "outputs": [], "source": [ "ddpolynomial = grad(dpolynomial)\n", "\n", "ddpolynomial(start)" ] }, { "cell_type": "markdown", "id": "surrounded-plain", "metadata": {}, "source": [ "`grad` is composable an arbitrary number of times: you can keep calling it on its own output to obtain higher-order derivatives." ] }, { "cell_type": "markdown", "id": "brazilian-atlas", "metadata": {}, "source": [ "## Maximum likelihood estimation\n", "\n", "In statistics, maximum likelihood estimation is used to estimate\n", "the most likely values of a distribution's parameters.\n", "Usually, analytical solutions can be found;\n", "however, for difficult cases, we can always fall back on `grad`.\n", "\n", "Let's see this in action.\n", "Say we draw 1000 random numbers from a Gaussian with $\\mu=-3$ and $\\sigma=2$.\n", "Our task is to pretend we don't know the actual $\\mu$ and $\\sigma$\n", "and instead estimate them from the observed data." ] }, { "cell_type": "code", "execution_count": null, "id": "confidential-sympathy", "metadata": { "tags": [] }, "outputs": [], "source": [ "from functools import partial\n", "\n", "import jax.numpy as np\n", "from jax import random\n", "\n", "key = random.PRNGKey(44)\n", "real_mu = -3.0\n", "real_log_sigma = np.log(2.0) # the real sigma is 2.0\n", "\n", "# Draw 1000 samples from a Gaussian with the true parameters.\n", "data = random.normal(key, shape=(1000,)) * np.exp(real_log_sigma) + real_mu" ] }, { "cell_type": "markdown", "id": "atlantic-excellence", "metadata": {}, "source": [ "Our estimation task requires calculating the total joint log likelihood of our data under a Gaussian model.\n", "We then need to find the values of $\\mu$ and $\\sigma$ that maximize the log likelihood of observing our data.\n", "\n", "Since we have been operating in a function minimization paradigm, we can instead minimize the negative log likelihood."
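, "\n", "\n", "Concretely, for observed data $x_1, \\ldots, x_N$, the quantity we will minimize is (a standard identity, spelled out here for reference):\n", "\n", "$$-\\log L(\\mu, \\sigma) = \\sum_{i=1}^{N} \\left[ \\frac{1}{2} \\log (2 \\pi \\sigma^2) + \\frac{(x_i - \\mu)^2}{2 \\sigma^2} \\right]$$\n", "\n", "`jax.scipy.stats.norm.logpdf` computes each $\\log p(x_i \\mid \\mu, \\sigma)$ term for us, so we only need to sum and negate."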
] }, { "cell_type": "code", "execution_count": null, "id": "known-terrain", "metadata": { "tags": [] }, "outputs": [], "source": [ "from jax.scipy.stats import norm\n", "\n", "\n", "def negloglike(mu, log_sigma, data):\n", " # Total negative log likelihood of the data under N(mu, exp(log_sigma)).\n", " return -np.sum(norm.logpdf(data, loc=mu, scale=np.exp(log_sigma)))" ] }, { "cell_type": "markdown", "id": "terminal-census", "metadata": {}, "source": [ "If you're wondering why we use `log_sigma` rather than `sigma`, it is a choice made for practical reasons.\n", "During optimization, we can run into negative values,\n", "or more generally, values that are \"out of bounds\" for a parameter.\n", "Operating in log-space for a positive-only value allows us to optimize that value in an unbounded space,\n", "and we can use the log/exp transformations to bring our parameter into the correct space when necessary.\n", "\n", "Whenever we do likelihood calculations,\n", "it is good practice to first ensure that we have no NaN issues.\n", "Let's check:" ] }, { "cell_type": "code", "execution_count": null, "id": "dominant-delight", "metadata": { "tags": [] }, "outputs": [], "source": [ "mu = -6.0\n", "log_sigma = np.log(2.0)\n", "negloglike(mu, log_sigma, data)" ] }, { "cell_type": "markdown", "id": "equal-brazilian", "metadata": {}, "source": [ "Now, we can create the gradient function of our negative log likelihood.\n", "\n", "But there's a snag! Doesn't `grad` take the derivative w.r.t. the first argument only?\n", "We need it w.r.t. two arguments, `mu` and `log_sigma`.\n", "Well, `grad` has an `argnums` argument that we can use to specify\n", "which arguments of the function we wish to take the derivative with respect to." ] }, { "cell_type": "code", "execution_count": null, "id": "meaning-scanning", "metadata": { "tags": [] }, "outputs": [], "source": [ "dnegloglike = grad(negloglike, argnums=(0, 1))\n", "\n", "# Condition on the observed data by fixing the `data` argument.\n", "dnegloglike = partial(dnegloglike, data=data)\n", "dnegloglike(mu, log_sigma)" ] }, { "cell_type": "markdown", "id": "hourly-miller", "metadata": {}, "source": [ "Now, we can do the gradient descent step!" ] }, { "cell_type": "code", "execution_count": null, "id": "cosmetic-perception", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Gradient descent on both parameters simultaneously.\n", "for i in range(300):\n", " dmu, dlog_sigma = dnegloglike(mu, log_sigma)\n", " mu -= dmu * 0.0001\n", " log_sigma -= dlog_sigma * 0.0001\n", "mu, np.exp(log_sigma)" ] }, { "cell_type": "markdown", "id": "defensive-family", "metadata": {}, "source": [ "And voila! We have gradient descended our way to the maximum likelihood parameters :)." ] }, { "cell_type": "markdown", "id": "constant-account", "metadata": {}, "source": [ "## Exercise: Where is the gold? It's at the minima!\n", "\n", "We're now going to attempt an exercise.\n", "The task here is to program a robot to find the gold in a field\n", "that is defined by a math function." ] }, { "cell_type": "code", "execution_count": null, "id": "focal-climate", "metadata": { "tags": [] }, "outputs": [], "source": [ "from inspect import getsource\n", "\n", "from dl_workshop.jax_idioms import goldfield\n", "\n", "print(getsource(goldfield))" ] }, { "cell_type": "markdown", "id": "massive-corps", "metadata": {}, "source": [ "It should be evident from here that there are two minima in the function.\n", "Let's find out where they are.\n", "\n", "First, define the gradient function with respect to both `x` and `y`.\n", "To see how to make `grad` take a derivative w.r.t. two arguments,\n", "see [the official tutorial](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html?highlight=grad#jax-first-transformation-grad)\n", "for more information."
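] }, { "cell_type": "markdown", "id": "toy-warmup", "metadata": {}, "source": [ "As a warm-up before the exercise, here is a minimal sketch of the `argnums` pattern on a toy function of our own, $f(x, y) = x^2 + y^2$, whose partial derivatives are $2x$ and $2y$ (the `paraboloid` below is purely illustrative and not part of `dl_workshop`):" ] }, { "cell_type": "code", "execution_count": null, "id": "toy-warmup-code", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Toy illustration of argnums: differentiate w.r.t. both arguments.\n", "def paraboloid(x, y):\n", "    return x ** 2 + y ** 2\n", "\n", "\n", "dparaboloid = grad(paraboloid, argnums=(0, 1))\n", "\n", "# Returns the tuple (2x, 2y) evaluated at (3.0, 4.0), i.e. (6.0, 8.0).\n", "dparaboloid(3.0, 4.0)" ] }, { "cell_type": "markdown", "id": "toy-warmup-next", "metadata": {}, "source": [ "Now apply the same pattern to `goldfield`."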
] }, { "cell_type": "code", "execution_count": null, "id": "opened-beads", "metadata": { "tags": [] }, "outputs": [], "source": [ "from typing import Callable\n", "\n", "\n", "def grad_ex_1():\n", " # your answer here\n", " pass\n", "\n", "\n", "from dl_workshop.jax_idioms import grad_ex_1\n", "\n", "dgoldfield = grad_ex_1()\n", "dgoldfield(3.0, 4.0)" ] }, { "cell_type": "markdown", "id": "brown-violation", "metadata": {}, "source": [ "Now, implement the optimization loop!" ] }, { "cell_type": "code", "execution_count": null, "id": "alternative-wisdom", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Start somewhere\n", "\n", "\n", "def grad_ex_2(x, y, dgoldfield):\n", " # your answer goes here\n", " pass\n", "\n", "\n", "from dl_workshop.jax_idioms import grad_ex_2\n", "\n", "grad_ex_2(x=0.1, y=0.1, dgoldfield=dgoldfield)" ] }, { "cell_type": "markdown", "id": "alternative-iraqi", "metadata": {}, "source": [ "## Exercise: programming a robot that only moves along one axis\n", "\n", "Our robot has had a malfunction, and it can now only move along one axis.\n", "Can you help it find the minima nonetheless?\n", "\n", "(This is effectively a problem of finding the partial derivative! You can fix either `x` or `y` to a value of your choice; a sketch of this pattern, on a toy function, appears at the end of the notebook.)" ] }, { "cell_type": "code", "execution_count": null, "id": "operational-advantage", "metadata": { "tags": [] }, "outputs": [], "source": [ "def grad_ex_3():\n", " # your answer goes here\n", " pass\n", "\n", "\n", "from dl_workshop.jax_idioms import grad_ex_3\n", "\n", "dgoldfield_dx = grad_ex_3()\n", "\n", "\n", "# Start somewhere and optimize!\n", "x = 0.1\n", "for i in range(300):\n", " dx = dgoldfield_dx(x)\n", " x -= dx * 0.01\n", "x" ] }, { "cell_type": "markdown", "id": "ecological-asian", "metadata": {}, "source": [ "For your reference, the function is plotted below." ] }, { "cell_type": "code", "execution_count": null, "id": "convenient-optics", "metadata": { "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib import cm\n", "\n", "fig, ax = plt.subplots(subplot_kw={\"projection\": \"3d\"})\n", "\n", "# Change the limits of the x and y plane here if you'd like to see a zoomed-out view.\n", "X = np.arange(-1.5, 1.5, 0.01)\n", "Y = np.arange(-1.5, 1.5, 0.01)\n", "X, Y = np.meshgrid(X, Y)\n", "Z = goldfield(X, Y)\n", "\n", "# Plot the surface.\n", "surf = ax.plot_surface(\n", " X,\n", " Y,\n", " Z,\n", " cmap=cm.coolwarm,\n", " linewidth=0,\n", " antialiased=False,\n", ")\n", "ax.view_init(elev=20.0, azim=20)" ] }
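, { "cell_type": "markdown", "id": "axis-recap", "metadata": {}, "source": [ "To close the loop on the one-axis exercise, here is a minimal sketch of the fix-one-argument pattern, shown on the toy `paraboloid` from the warm-up rather than on `goldfield` so as not to give the solution away: pinning one argument with a closure yields a single-argument function that `grad` differentiates along a single axis." ] }, { "cell_type": "code", "execution_count": null, "id": "axis-recap-code", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Fix y at an arbitrary constant (0.5 here) and differentiate w.r.t. x only.\n", "dparaboloid_dx = grad(lambda x: paraboloid(x, 0.5))\n", "\n", "# The partial derivative of x**2 + y**2 w.r.t. x is 2x, so this gives 6.0.\n", "dparaboloid_dx(3.0)" ] }, { "cell_type": "code", "execution_count": null, "id": "loaded-labor", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "dl-workshop", "language": "python", "name": "dl-workshop" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 5 }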