{ "cells": [ { "cell_type": "markdown", "id": "4159be0e", "metadata": { "hideCode": false, "hidePrompt": false, "slideshow": { "slide_type": "slide" } }, "source": [ "# 2023-04-03 Nonlinear Regression\n", "\n", "## Last time\n", "\n", "* Discuss projects\n", "* Gradient-based optimization for linear models\n", "* Effect of conditioning on convergence rate\n", "\n", "## Today\n", "* Nonlinear models\n", "* Computing derivatives\n", " * numeric\n", " * analytic by hand\n", " * algorithmic (automatic) differentiation" ] }, { "cell_type": "code", "execution_count": 1, "id": "eb781a7a", "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [ { "data": { "text/plain": [ "vcond (generic function with 1 method)" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using LinearAlgebra\n", "using Plots\n", "default(linewidth=4, legendfontsize=12)\n", "\n", "function vander(x, k=nothing)\n", " if isnothing(k)\n", " k = length(x)\n", " end\n", " m = length(x)\n", " V = ones(m, k)\n", " for j in 2:k\n", " V[:, j] = V[:, j-1] .* x\n", " end\n", " V\n", "end\n", "\n", "function vander_chebyshev(x, n=nothing)\n", " if isnothing(n)\n", " n = length(x) # Square by default\n", " end\n", " m = length(x)\n", " T = ones(m, n)\n", " if n > 1\n", " T[:, 2] = x\n", " end\n", " for k in 3:n\n", " #T[:, k] = x .* T[:, k-1]\n", " T[:, k] = 2 * x .* T[:,k-1] - T[:, k-2]\n", " end\n", " T\n", "end\n", "\n", "function chebyshev_regress_eval(x, xx, n)\n", " V = vander_chebyshev(x, n)\n", " vander_chebyshev(xx, n) / V\n", "end\n", "\n", "runge(x) = 1 / (1 + 10*x^2)\n", "runge_noisy(x, sigma) = runge.(x) + randn(size(x)) * sigma\n", "\n", "CosRange(a, b, n) = (a + b)/2 .+ (b - a)/2 * cos.(LinRange(-pi, 0, n))\n", "\n", "vcond(mat, points, nmax) = [cond(mat(points(-1, 1, n))) for n in 2:nmax]" ] }, { "cell_type": "markdown", "id": "ab251374", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Gradient descent\n", "\n", "Instead 
of solving the least squares problem using linear algebra (QR factorization), we could solve it using gradient descent. That is, on each iteration, we'll take a step in the direction of the negative gradient." ] }, { "cell_type": "code", "execution_count": 2, "id": "04841eea", "metadata": { "cell_style": "center" }, "outputs": [ { "data": { "text/plain": [ "grad_descent (generic function with 1 method)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n", " \"\"\"Minimize loss(c) via gradient descent with initial guess c0\n", " using learning rate gamma. Declares convergence when the norm of\n", " the gradient is less than tol, or gives up after 500 steps.\n", " \"\"\"\n", " c = copy(c0)\n", " chist = [copy(c)]\n", " lhist = [loss(c)]\n", " for it in 1:500\n", " g = grad(c)\n", " c -= gamma * g\n", " push!(chist, copy(c))\n", " push!(lhist, loss(c))\n", " if norm(g) < tol\n", " break\n", " end\n", " end\n", " (c, hcat(chist...), lhist)\n", "end" ] }, { "cell_type": "markdown", "id": "ca23b5fe", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Quadratic model" ] }, { "cell_type": "code", "execution_count": 112, "id": "ab3a31f0", "metadata": { "cell_style": "split", "hideCode": false, "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cond(A) = 9.46578492882319\n" ] }, { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A = [1 1; 1 8]\n", "@show cond(A)\n", "loss(c) = .5 * c' * A * c\n", "grad(c) = A * c\n", "\n", "c, chist, lhist = grad_descent(loss, grad, [.9, .9],\n", " gamma=.22)\n", "plot(lhist, yscale=:log10, xlims=(0, 80))" ] }, { "cell_type": "code", "execution_count": 113, "id": "7909e35d", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot(chist[1, :], chist[2, :], marker=:circle)\n", "x = LinRange(-1, 1, 30)\n", "contour!(x, x, (x,y) -> loss([x, y]))" ] }, { "cell_type": "markdown", "id": "ddd68837", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Chebyshev regression via optimization\n", "\n" ] }, { "cell_type": "code", "execution_count": 121, "id": "5520e580", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/plain": [ "8-element Vector{Float64}:\n", " 0.7805671076161028\n", " 0.07125885267568355\n", " -1.8025038787243435\n", " 0.03382484774772874\n", " 0.8464406682785461\n", " -0.17304293350946157\n", " 0.6962824216982956\n", " 0.0711130108623033" ] }, "execution_count": 121, "metadata": {}, "output_type": "execute_result"
} ], "source": [ "x = LinRange(-1, 1, 200)\n", "sigma = 0.5; n = 8\n", "y = runge_noisy(x, sigma)\n", "V = vander(x, n) # or vander_chebyshev\n", "function loss(c)\n", " r = V * c - y\n", " .5 * r' * r\n", "end\n", "function grad(c)\n", " r = V * c - y\n", " V' * r\n", "end\n", "c, _, lhist = grad_descent(loss, grad, ones(n),\n", " gamma=0.008)\n", "c" ] }, { "cell_type": "code", "execution_count": 122, "id": "77957be4", "metadata": { "cell_style": "split" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cond(V) = 230.00549982014527\n", "cond(V' * V) = 52902.52994792632\n" ] }, { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 122, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c0 = V \\ y\n", "l0 = 0.5 * norm(V * c0 - y)^2\n", "@show cond(V)\n", "@show cond(V' * V)\n", "plot(lhist, yscale=:log10, ylim=(15, 50))\n", "plot!(i -> l0, color=:black)" ] }, { "cell_type": "markdown", "id": "8c8e60c7", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Why use QR vs gradient-based optimization?" ] }, { "cell_type": "markdown", "id": "7a95e2e9", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Nonlinear models\n", "\n", "Instead of the linear model\n", "$$ f(x,c) = V(x) c = c_0 + c_1 \\underbrace{x}_{T_1(x)} + c_2 T_2(x) + \\dotsb $$\n", "let's consider a rational model with only three parameters\n", "$$ f(x,c) = \\frac{1}{c_1 + c_2 x + c_3 x^2} = (c_1 + c_2 x + c_3 x^2)^{-1} . $$\n", "We'll use the same loss function\n", "$$ L(c; x,y) = \\frac 1 2 \\lVert f(x,c) - y \\rVert^2 . 
$$\n", "\n", "We will also need the gradient\n", "$$ \\nabla_c L(c; x,y) = \\big( f(x,c) - y \\big)^T \\nabla_c f(x,c) $$\n", "where\n", "\\begin{align}\n", "\\frac{\\partial f(x,c)}{\\partial c_1} &= -(c_1 + c_2 x + c_3 x^2)^{-2} = - f(x,c)^2 \\\\\n", "\\frac{\\partial f(x,c)}{\\partial c_2} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x = - f(x,c)^2 x \\\\\n", "\\frac{\\partial f(x,c)}{\\partial c_3} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x^2 = - f(x,c)^2 x^2 .\n", "\\end{align}" ] }, { "cell_type": "markdown", "id": "247e6403", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Fitting a rational function" ] }, { "cell_type": "code", "execution_count": 123, "id": "d6c610e7", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/plain": [ "gradient (generic function with 1 method)" ] }, "execution_count": 123, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f(x, c) = 1 ./ (c[1] .+ c[2].*x + c[3].*x.^2)\n", "function gradf(x, c)\n", " f2 = f(x, c).^2\n", " [-f2 -f2.*x -f2.*x.^2]\n", "end\n", "function loss(c)\n", " r = f(x, c) - y\n", " 0.5 * r' * r\n", "end\n", "function gradient(c)\n", " r = f(x, c) - y\n", " vec(r' * gradf(x, c))\n", "end" ] }, { "cell_type": "code", "execution_count": 138, "id": "ffd37ce6", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 138, "metadata": {}, "output_type": "execute_result" } ], "source": [ "c, _, lhist = grad_descent(loss, gradient, ones(3),\n", " gamma=5e-2)\n", "plot(lhist, yscale=:identity, ylim=(20, 60), title=\"Loss $(lhist[end])\")" ] }, { "cell_type": "markdown", "id": "7f4be7cd", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Compare fits on noisy data" ] }, { "cell_type": "code", "execution_count": 141, "id": "11ab1f9e", "metadata": { "slideshow": { "slide_type": "" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 141, "metadata": {}, "output_type": "execute_result" } ], "source": [ "scatter(x, y)\n", "V = vander_chebyshev(x, 20)\n", "plot!(x -> runge(x), color=:black, label=\"Runge\")\n", "plot!(x, V * (V \\ y), label=\"Chebyshev fit\")\n", "plot!(x -> f(x, c), label=\"Rational fit\")" ] }, { "cell_type": "markdown", "id": "627277ab", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# How do we compute those derivatives as the model gets complicated?" ] }, { "cell_type": "markdown", "id": "a8f69462", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\lim_{h\\to 0} \\frac{f(x+h) - f(x)}{h}$$" ] }, { "cell_type": "markdown", "id": "73b1b145", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "* How should we choose $h$?" ] }, { "cell_type": "markdown", "id": "6a6d0812", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Taylor series\n", "\n", "Classical accuracy analysis assumes that functions are sufficiently smooth, meaning that derivatives exist and Taylor expansions are valid within a neighborhood. In particular,\n", "$$ f(x+h) = f(x) + f'(x) h + f''(x) \\frac{h^2}{2!} + \\underbrace{f'''(x) \\frac{h^3}{3!} + \\dotsb}_{O(h^3)} . $$\n", "\n", "The big-$O$ notation is meant in the limit $h\\to 0$. 
Specifically, a function $g(h) \\in O(h^p)$ (sometimes written $g(h) = O(h^p)$) when\n", "there exists a constant $C$ such that\n", "$$ |g(h)| \\le C h^p $$\n", "for all sufficiently *small* $h$." ] }, { "cell_type": "markdown", "id": "9c564912", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Rounding error\n", "\n", "We have an additional source of error, *rounding error*, which comes from not being able to compute $f(x)$ or $f(x+h)$ exactly, nor subtract them exactly. Suppose that we can, however, compute these functions with a relative error on the order of $\\epsilon_{\\text{machine}}$. This leads to\n", "$$ \\begin{split}\n", "\\tilde f(x) &= f(x)(1 + \\epsilon_1) \\\\\n", "\\tilde f(x \\oplus h) &= \\tilde f((x+h)(1 + \\epsilon_2)) \\\\\n", "&= f((x + h)(1 + \\epsilon_2))(1 + \\epsilon_3) \\\\\n", "&= [f(x+h) + f'(x+h)(x+h)\\epsilon_2 + O(\\epsilon_2^2)](1 + \\epsilon_3) \\\\\n", "&= f(x+h)(1 + \\epsilon_3) + f'(x+h)x\\epsilon_2 + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h)\n", "\\end{split}\n", "$$" ] }, { "cell_type": "markdown", "id": "a3b39cf8", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Tedious error propagation\n", "$$ \\begin{split}\n", "\\left\\lvert \\frac{\\tilde f(x+h) \\ominus \\tilde f(x)}{h} - \\frac{f(x+h) - f(x)}{h} \\right\\rvert &=\n", " \\left\\lvert \\frac{f(x+h)(1 + \\epsilon_3) + f'(x+h)x\\epsilon_2 + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h) - f(x)(1 + \\epsilon_1) - f(x+h) + f(x)}{h} \\right\\rvert \\\\\n", " &\\le \\frac{|f(x+h)\\epsilon_3| + |f'(x+h)x\\epsilon_2| + |f(x)\\epsilon_1| + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}}h)}{h} \\\\\n", " &\\le \\frac{(2 \\max_{[x,x+h]} |f| + \\max_{[x,x+h]} |f' x|) \\epsilon_{\\text{machine}} + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h)}{h} \\\\\n", " &= (2\\max|f| + \\max|f'x|) \\frac{\\epsilon_{\\text{machine}}}{h} + O(\\epsilon_{\\text{machine}}) \\\\\n", 
"\\end{split} $$\n", "where we have assumed that $h \\ge \\epsilon_{\\text{machine}}$.\n", "This error becomes large (relative to $f'$ -- we are concerned with relative error after all) when\n", "* $f$ is large compared to $f'$\n", "* $x$ is large\n", "* $h$ is too small" ] }, { "cell_type": "markdown", "id": "6d442032", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Automatic step size selection\n", "\n", "* Walker and Pernice\n", "* Dennis and Schnabel" ] }, { "cell_type": "code", "execution_count": 143, "id": "2508cef2", "metadata": { "cell_style": "center", "slideshow": { "slide_type": "" } }, "outputs": [ { "data": { "text/plain": [ "-4.139506408429305e-6" ] }, "execution_count": 143, "metadata": {}, "output_type": "execute_result" } ], "source": [ "diff(f, x; h=1e-8) = (f(x+h) - f(x)) / h\n", "\n", "function diff_wp(f, x; h=1e-8)\n", " \"\"\"Diff using Walker and Pernice (1998) choice of step\"\"\"\n", " h *= (1 + abs(x)) # make the step relative to the magnitude of x\n", " (f(x+h) - f(x)) / h\n", "end\n", "\n", "x = 1000\n", "diff_wp(sin, x) - cos(x)" ] }, { "cell_type": "markdown", "id": "13598631", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Symbolic differentiation" ] }, { "cell_type": "code", "execution_count": 144, "id": "fa185815", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/latex": [ "$$ \\begin{equation}\n", "\\frac{\\mathrm{d} \\sin\\left( x \\right)}{\\mathrm{d}x}\n", "\\end{equation}\n", " $$" ], "text/plain": [ "Differential(x)(sin(x))" ] }, "execution_count": 144, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using Symbolics\n", "\n", "@variables x\n", "Dx = Differential(x)\n", "\n", "y = sin(x)\n", "Dx(y)" ] }, { "cell_type": "code", "execution_count": 145, "id": "537e4153", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/latex": [ "$$ \\begin{equation}\n", "\\cos\\left( x \\right)\n", "\\end{equation}\n", " $$" ], "text/plain": [ "cos(x)" ] }, "execution_count": 145,
"metadata": {}, "output_type": "execute_result" } ], "source": [ "expand_derivatives(Dx(y))" ] }, { "cell_type": "markdown", "id": "55f8ae35", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Awesome, what about products?" ] }, { "cell_type": "code", "execution_count": 149, "id": "0a4243eb", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "$$ \\begin{equation}\n", "\\frac{\\left( \\frac{\\cos\\left( x^{\\pi} \\right)}{x} - 3.141592653589793 x^{2.141592653589793} \\log\\left( x \\right) \\sin\\left( x^{\\pi} \\right) \\right) \\cos\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right)}{\\cos\\left( x^{\\pi} \\right) \\log\\left( x \\right)} - \\left( \\frac{3.141592653589793 \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{2.141592653589793}}{x} - 9.869604401089358 \\cos^{2.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} x^{2.141592653589793} \\sin\\left( x^{\\pi} \\right) \\right) \\sin\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right) \\log\\left( \\log\\left( x \\right) \\cos\\left( x^{\\pi} \\right) \\right)\n", "\\end{equation}\n", " $$" ], "text/plain": [ "((x^-1)*cos(x^π) - 3.141592653589793(x^2.141592653589793)*log(x)*sin(x^π))*(log(x)^-1)*(cos(x^π)^-1)*cos((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793)) - (3.141592653589793(x^-1)*(log(x)^2.141592653589793)*(cos(x^π)^3.141592653589793) - 9.869604401089358(x^2.141592653589793)*(log(x)^3.141592653589793)*(cos(x^π)^2.141592653589793)*sin(x^π))*sin((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793))*log(log(x)*cos(x^π))" ] }, "execution_count": 149, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y = x\n", "for _ in 1:2\n", " y = cos(y^pi) * log(y)\n", "end\n", "expand_derivatives(Dx(y))" ] }, { 
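"cell_type": "markdown", "id": "b2f1a9c0", "metadata": {}, "source": [ "One way to see the growth (a rough sketch reusing the `x` and `Dx` defined above): print how long the expanded derivative's string form gets as the nesting depth increases by one." ] }, { "cell_type": "code", "execution_count": null, "id": "c3d2e1f0", "metadata": {}, "outputs": [], "source": [ "# Sketch: measure expression swell in the symbolic derivative.\n", "# Reuses the Symbolics variable x and operator Dx from above.\n", "for depth in 1:3\n", "    y = x\n", "    for _ in 1:depth\n", "        y = cos(y^pi) * log(y)\n", "    end\n", "    dy = expand_derivatives(Dx(y))\n", "    println(depth, \" => \", length(string(dy)))\n", "end" ] }, {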
"cell_type": "markdown", "id": "34b799be", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "* The size of these expressions can grow **exponentially**" ] }, { "cell_type": "markdown", "id": "d0a680f1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Hand-coding derivatives\n", "\n", "$$ df = f'(x) dx $$" ] }, { "cell_type": "code", "execution_count": 91, "id": "c24fb506", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/plain": [ "(-1.5346823414986814, -34.032439961925064)" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function f(x)\n", " y = x\n", " for _ in 1:2\n", " a = y^pi\n", " b = cos(a)\n", " c = log(y)\n", " y = b * c\n", " end\n", " y\n", "end\n", "\n", "f(1.9), diff_wp(f, 1.9)" ] }, { "cell_type": "code", "execution_count": 153, "id": "d8dec50b", "metadata": { "cell_style": "split", "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/plain": [ "(-1.5346823414986814, -34.032419599140475)" ] }, "execution_count": 153, "metadata": {}, "output_type": "execute_result" } ], "source": [ "function df(x, dx)\n", " y = x\n", " dy = dx\n", " for _ in 1:2\n", " a = y ^ pi\n", " da = pi * y^(pi - 1) * dy\n", " b = cos(a)\n", " db = -sin(a) * da\n", " c = log(y)\n", " dc = dy / y\n", " y = b * c\n", " dy = db * c + b * dc\n", " end\n", " y, dy\n", "end\n", "\n", "df(1.9, 1)" ] }, { "cell_type": "markdown", "id": "3fb72ae9", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# We can go the other way\n", "\n", "We can differentiate a composition $h(g(f(x)))$ as\n", "\n", "\\begin{align}\n", " \\operatorname{d} h &= h' \\operatorname{d} g \\\\\n", " \\operatorname{d} g &= g' \\operatorname{d} f \\\\\n", " \\operatorname{d} f &= f' \\operatorname{d} x.\n", "\\end{align}\n", "\n", "What we've done above is called \"forward mode\", and amounts to placing the parentheses in the chain rule like\n", "\n", "$$ 
\\operatorname d h = \\frac{dh}{dg} \\left(\\frac{dg}{df} \\left(\\frac{df}{dx} \\operatorname d x \\right) \\right) .$$\n", "\n", "The expression means the same thing if we rearrange the parentheses,\n", "\n", "$$ \\operatorname d h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname d x .$$" ] }, { "cell_type": "markdown", "id": "12e01ae1", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Automatic differentiation" ] }, { "cell_type": "code", "execution_count": 18, "id": "7052e15a", "metadata": { "cell_style": "split" }, "outputs": [], "source": [ "import Zygote" ] }, { "cell_type": "code", "execution_count": 17, "id": "2aabcb46", "metadata": { "cell_style": "split" }, "outputs": [ { "data": { "text/plain": [ "(-34.03241959914049,)" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Zygote.gradient(f, 1.9)" ] }, { "cell_type": "code", "execution_count": 21, "id": "9cee1ed6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:95 within `gradient`\u001b[39m\n", "\u001b[95mdefine\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[93m@julia_gradient_12050\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n", "\u001b[91mtop:\u001b[39m\n", "\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96 within `gradient`\u001b[39m\n", "\u001b[90m; ┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42 within `pullback` @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44\u001b[39m\n", "\u001b[90m; │┌ @ In[21]:1 within `_pullback` @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:9\u001b[39m\n", "\u001b[90m; ││┌ @ 
/home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl within `macro expansion`\u001b[39m\n", "\u001b[90m; │││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:218 within `chain_rrule`\u001b[39m\n", "\u001b[90m; ││││┌ @ /home/jed/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 within `rrule` @ /home/jed/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/fastmath_able.jl:56\u001b[39m\n", " \u001b[0m%1 \u001b[0m= \u001b[96m\u001b[1mcall\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[93m@j_exp_12052\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0\n", "\u001b[90m; └└└└└\u001b[39m\n", "\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97 within `gradient`\u001b[39m\n", "\u001b[90m; ┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45 within `#60`\u001b[39m\n", "\u001b[90m; │┌ @ In[21]:1 within `Pullback`\u001b[39m\n", "\u001b[90m; ││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 within `ZBack`\u001b[39m\n", "\u001b[90m; │││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/lib/number.jl:12 within `literal_pow_pullback`\u001b[39m\n", "\u001b[90m; ││││┌ @ promotion.jl:389 within `*` @ float.jl:385\u001b[39m\n", " \u001b[0m%2 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[33m2.000000e+00\u001b[39m\n", "\u001b[90m; │└└└└\u001b[39m\n", "\u001b[90m; │┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/lib/lib.jl:17 within `accum`\u001b[39m\n", "\u001b[90m; ││┌ @ float.jl:383 within `+`\u001b[39m\n", " \u001b[0m%3 \u001b[0m= \u001b[96m\u001b[1mfadd\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%2\u001b[0m, \u001b[0m%1\n", "\u001b[90m; └└└\u001b[39m\n", "\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:98 within `gradient`\u001b[39m\n", " \u001b[0m%.fca.0.insert \u001b[0m= 
\u001b[96m\u001b[1minsertvalue\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[95mzeroinitializer\u001b[39m\u001b[0m, \u001b[36mdouble\u001b[39m \u001b[0m%3\u001b[0m, \u001b[33m0\u001b[39m\n", " \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[0m%.fca.0.insert\n", "\u001b[33m}\u001b[39m\n" ] } ], "source": [ "g(x) = exp(x) + x^2\n", "@code_llvm Zygote.gradient(g, 1.9)" ] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "celltoolbar": "Slideshow", "hide_code_all_hidden": false, "kernelspec": { "display_name": "Julia 1.7.2", "language": "julia", "name": "julia-1.7" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.8.5" }, "rise": { "enable_chalkboard": true } }, "nbformat": 4, "nbformat_minor": 5 }