{
"cells": [
{
"cell_type": "markdown",
"id": "4159be0e",
"metadata": {
"hideCode": false,
"hidePrompt": false,
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 2023-04-03 Nonlinear Regression\n",
"\n",
"## Last time\n",
"\n",
"* Discuss projects\n",
"* Gradient-based optimization for linear models\n",
"* Effect of conditioning on convergence rate\n",
"\n",
"## Today\n",
"* Nonlinear models\n",
"* Computing derivatives\n",
" * numeric\n",
" * analytic by hand\n",
" * algorithmic (automatic) differentiation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "eb781a7a",
"metadata": {
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"data": {
"text/plain": [
"vcond (generic function with 1 method)"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using LinearAlgebra\n",
"using Plots\n",
"default(linewidth=4, legendfontsize=12)\n",
"\n",
"function vander(x, k=nothing)\n",
" if isnothing(k)\n",
" k = length(x)\n",
" end\n",
" m = length(x)\n",
" V = ones(m, k)\n",
" for j in 2:k\n",
" V[:, j] = V[:, j-1] .* x\n",
" end\n",
" V\n",
"end\n",
"\n",
"function vander_chebyshev(x, n=nothing)\n",
" if isnothing(n)\n",
" n = length(x) # Square by default\n",
" end\n",
" m = length(x)\n",
" T = ones(m, n)\n",
" if n > 1\n",
" T[:, 2] = x\n",
" end\n",
" for k in 3:n\n",
" #T[:, k] = x .* T[:, k-1]\n",
" T[:, k] = 2 * x .* T[:,k-1] - T[:, k-2]\n",
" end\n",
" T\n",
"end\n",
"\n",
"function chebyshev_regress_eval(x, xx, n)\n",
" V = vander_chebyshev(x, n)\n",
" vander_chebyshev(xx, n) / V\n",
"end\n",
"\n",
"runge(x) = 1 / (1 + 10*x^2)\n",
"runge_noisy(x, sigma) = runge.(x) + randn(size(x)) * sigma\n",
"\n",
"CosRange(a, b, n) = (a + b)/2 .+ (b - a)/2 * cos.(LinRange(-pi, 0, n))\n",
"\n",
"vcond(mat, points, nmax) = [cond(mat(points(-1, 1, n))) for n in 2:nmax]"
]
},
{
"cell_type": "markdown",
"id": "ab251374",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Gradient descent\n",
"\n",
"Instead of solving the least squares problem using linear algebra (QR factorization), we could solve it using gradient descent. That is, on each iteration, we'll take a step in the direction of the negative gradient."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "04841eea",
"metadata": {
"cell_style": "center"
},
"outputs": [
{
"data": {
"text/plain": [
"grad_descent (generic function with 1 method)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function grad_descent(loss, grad, c0; gamma=1e-3, tol=1e-5)\n",
" \"\"\"Minimize loss(c) via gradient descent with initial guess c0\n",
" using learning rate gamma. Declares convergence when gradient\n",
" is less than tol or after 500 steps.\n",
" \"\"\"\n",
" c = copy(c0)\n",
" chist = [copy(c)]\n",
" lhist = [loss(c)]\n",
" for it in 1:500\n",
" g = grad(c)\n",
" c -= gamma * g\n",
" push!(chist, copy(c))\n",
" push!(lhist, loss(c))\n",
" if norm(g) < tol\n",
" break\n",
" end\n",
" end\n",
" (c, hcat(chist...), lhist)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "ca23b5fe",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Quadratic model"
]
},
{
"cell_type": "code",
"execution_count": 112,
"id": "ab3a31f0",
"metadata": {
"cell_style": "split",
"hideCode": false,
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cond(A) = 9.46578492882319\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"A = [1 1; 1 8]\n",
"@show cond(A)\n",
"loss(c) = .5 * c' * A * c\n",
"grad(c) = A * c\n",
"\n",
"c, chist, lhist = grad_descent(loss, grad, [.9, .9],\n",
" gamma=.22)\n",
"plot(lhist, yscale=:log10, xlims=(0, 80))"
]
},
{
"cell_type": "code",
"execution_count": 113,
"id": "7909e35d",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(chist[1, :], chist[2, :], marker=:circle)\n",
"x = LinRange(-1, 1, 30)\n",
"contour!(x, x, (x,y) -> loss([x, y]))"
]
},
{
"cell_type": "markdown",
"id": "ddd68837",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Chebyshev regression via optimization\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 121,
"id": "5520e580",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"8-element Vector{Float64}:\n",
" 0.7805671076161028\n",
" 0.07125885267568355\n",
" -1.8025038787243435\n",
" 0.03382484774772874\n",
" 0.8464406682785461\n",
" -0.17304293350946157\n",
" 0.6962824216982956\n",
" 0.0711130108623033"
]
},
"execution_count": 121,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x = LinRange(-1, 1, 200)\n",
"sigma = 0.5; n = 8\n",
"y = runge_noisy(x, sigma)\n",
"V = vander(x, n) # or vander_chebyshev\n",
"function loss(c)\n",
" r = V * c - y\n",
" .5 * r' * r\n",
"end\n",
"function grad(c)\n",
" r = V * c - y\n",
" V' * r\n",
"end\n",
"c, _, lhist = grad_descent(loss, grad, ones(n),\n",
" gamma=0.008)\n",
"c"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "77957be4",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cond(V) = 230.00549982014527\n",
"cond(V' * V) = 52902.52994792632\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c0 = V \\ y\n",
"l0 = 0.5 * norm(V * c0 - y)^2\n",
"@show cond(V)\n",
"@show cond(V' * V)\n",
"plot(lhist, yscale=:log10, ylim=(15, 50))\n",
"plot!(i -> l0, color=:black)"
]
},
{
"cell_type": "markdown",
"id": "8c8e60c7",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Why use QR vs gradient-based optimization?"
]
},
{
"cell_type": "markdown",
"id": "7a95e2e9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Nonlinear models\n",
"\n",
"Instead of the linear model\n",
"$$ f(x,c) = V(x) c = c_0 + c_1 \\underbrace{x}_{T_1(x)} + c_2 T_2(x) + \\dotsb $$\n",
"let's consider a rational model with only three parameters\n",
"$$ f(x,c) = \\frac{1}{c_1 + c_2 x + c_3 x^2} = (c_1 + c_2 x + c_3 x^2)^{-1} . $$\n",
"We'll use the same loss function\n",
"$$ L(c; x,y) = \\frac 1 2 \\lVert f(x,c) - y \\rVert^2 . $$\n",
"\n",
"We will also need the gradient\n",
"$$ \\nabla_c L(c; x,y) = \\big( f(x,c) - y \\big)^T \\nabla_c f(x,c) $$\n",
"where\n",
"\\begin{align}\n",
"\\frac{\\partial f(x,c)}{\\partial c_1} &= -(c_1 + c_2 x + c_3 x^2)^{-2} = - f(x,c)^2 \\\\\n",
"\\frac{\\partial f(x,c)}{\\partial c_2} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x = - f(x,c)^2 x \\\\\n",
"\\frac{\\partial f(x,c)}{\\partial c_3} &= -(c_1 + c_2 x + c_3 x^2)^{-2} x^2 = - f(x,c)^2 x^2 .\n",
"\\end{align}"
]
},
{
"cell_type": "markdown",
"id": "247e6403",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Fitting a rational function"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "d6c610e7",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"gradient (generic function with 1 method)"
]
},
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f(x, c) = 1 ./ (c[1] .+ c[2].*x + c[3].*x.^2)\n",
"function gradf(x, c)\n",
" f2 = f(x, c).^2\n",
" [-f2 -f2.*x -f2.*x.^2]\n",
"end\n",
"function loss(c)\n",
" r = f(x, c) - y\n",
" 0.5 * r' * r\n",
"end\n",
"function gradient(c)\n",
" r = f(x, c) - y\n",
" vec(r' * gradf(x, c))\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 138,
"id": "ffd37ce6",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c, _, lhist = grad_descent(loss, gradient, 1ones(3),\n",
" gamma=5e-2)\n",
"plot(lhist, yscale=:linear, ylim=(20, 60), title=\"Loss $(lhist[end])\")"
]
},
{
"cell_type": "markdown",
"id": "7f4be7cd",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Compare fits on noisy data"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "11ab1f9e",
"metadata": {
"slideshow": {
"slide_type": ""
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"scatter(x, y)\n",
"V = vander_chebyshev(x, 20)\n",
"plot!(x -> runge(x), color=:black, label=\"Runge\")\n",
"plot!(x, V * (V \\ y), label=\"Chebyshev fit\")\n",
"plot!(x -> f(x, c), label=\"Rational fit\")"
]
},
{
"cell_type": "markdown",
"id": "627277ab",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# How do we compute those derivatives as the model gets complicated?"
]
},
{
"cell_type": "markdown",
"id": "a8f69462",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"$$\\lim_{h\\to 0} \\frac{f(x+h) - f(x)}{h}$$"
]
},
{
"cell_type": "markdown",
"id": "73b1b145",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* How should we choose $h$?"
]
},
{
"cell_type": "markdown",
"id": "6a6d0812",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Taylor series\n",
"\n",
"Classical accuracy analysis assumes that functions are sufficiently smooth, meaning that derivatives exist and Taylor expansions are valid within a neighborhood. In particular,\n",
"$$ f(x+h) = f(x) + f'(x) h + f''(x) \\frac{h^2}{2!} + \\underbrace{f'''(x) \\frac{h^3}{3!} + \\dotsb}_{O(h^3)} . $$\n",
"\n",
"The big-$O$ notation is meant in the limit $h\\to 0$. Specifically, a function $g(h) \\in O(h^p)$ (sometimes written $g(h) = O(h^p)$) when\n",
"there exists a constant $C$ such that\n",
"$$ g(h) \\le C h^p $$\n",
"for all sufficiently *small* $h$."
]
},
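    {
      "cell_type": "markdown",
      "id": "b1a2c3d4",
      "metadata": {
        "slideshow": {
          "slide_type": "fragment"
        }
      },
      "source": [
        "A quick numerical check (our own illustration, not part of the original slides): the truncation error of the one-sided difference quotient is $O(h)$, so halving $h$ should roughly halve the error, provided $h$ is still large enough that rounding is negligible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b1a2c3d5",
      "metadata": {},
      "outputs": [],
      "source": [
        "err(h) = abs((sin(1 + h) - sin(1)) / h - cos(1))\n",
        "# successive halving: each ratio should be close to 2 for a first-order method\n",
        "[err(h) / err(h / 2) for h in (1e-2, 1e-3, 1e-4)]"
      ]
    },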
{
"cell_type": "markdown",
"id": "9c564912",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Rounding error\n",
"\n",
"We have an additional source of error, *rounding error*, which comes from not being able to compute $f(x)$ or $f(x+h)$ exactly, nor subtract them exactly. Suppose that we can, however, compute these functions with a relative error on the order of $\\epsilon_{\\text{machine}}$. This leads to\n",
"$$ \\begin{split}\n",
"\\tilde f(x) &= f(x)(1 + \\epsilon_1) \\\\\n",
"\\tilde f(x \\oplus h) &= \\tilde f((x+h)(1 + \\epsilon_2)) \\\\\n",
"&= f((x + h)(1 + \\epsilon_2))(1 + \\epsilon_3) \\\\\n",
"&= [f(x+h) + f'(x+h)(x+h)\\epsilon_2 + O(\\epsilon_2^2)](1 + \\epsilon_3) \\\\\n",
"&= f(x+h)(1 + \\epsilon_3) + f'(x+h)x\\epsilon_2 + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h)\n",
"\\end{split}\n",
"$$"
]
},
{
"cell_type": "markdown",
"id": "a3b39cf8",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Tedious error propagation\n",
"$$ \\begin{split}\n",
"\\left\\lvert \\frac{\\tilde f(x+h) \\ominus \\tilde f(x)}{h} - \\frac{f(x+h) - f(x)}{h} \\right\\rvert &=\n",
" \\left\\lvert \\frac{f(x+h)(1 + \\epsilon_3) + f'(x+h)x\\epsilon_2 + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h) - f(x)(1 + \\epsilon_1) - f(x+h) + f(x)}{h} \\right\\rvert \\\\\n",
" &\\le \\frac{|f(x+h)\\epsilon_3| + |f'(x+h)x\\epsilon_2| + |f(x)\\epsilon_1| + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}}h)}{h} \\\\\n",
" &\\le \\frac{(2 \\max_{[x,x+h]} |f| + \\max_{[x,x+h]} |f' x| \\epsilon_{\\text{machine}} + O(\\epsilon_{\\text{machine}}^2 + \\epsilon_{\\text{machine}} h)}{h} \\\\\n",
" &= (2\\max|f| + \\max|f'x|) \\frac{\\epsilon_{\\text{machine}}}{h} + O(\\epsilon_{\\text{machine}}) \\\\\n",
"\\end{split} $$\n",
"where we have assumed that $h \\ge \\epsilon_{\\text{machine}}$.\n",
"This error becomes large (relative to $f'$ -- we are concerned with relative error after all)\n",
"* $f$ is large compared to $f'$\n",
"* $x$ is large\n",
"* $h$ is too small"
]
},
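    {
      "cell_type": "markdown",
      "id": "c2b3d4e5",
      "metadata": {
        "slideshow": {
          "slide_type": "fragment"
        }
      },
      "source": [
        "Both error sources appear in a single sweep (our own sketch, not from the original slides): the total error decays like $h$ (truncation) as $h$ shrinks, until about $h \\approx \\sqrt{\\epsilon_{\\text{machine}}}$, after which it grows like $\\epsilon_{\\text{machine}}/h$ (rounding)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "c2b3d4e6",
      "metadata": {},
      "outputs": [],
      "source": [
        "fd_error(h) = abs((sin(1 + h) - sin(1)) / h - cos(1))\n",
        "hs = 10.0 .^ LinRange(-15, -1, 29)\n",
        "# the best step is typically near sqrt(eps()) ≈ 1.5e-8\n",
        "hs[argmin(fd_error.(hs))]"
      ]
    },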
{
"cell_type": "markdown",
"id": "6d442032",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Automatic step size selection\n",
"\n",
"* Walker and Pernice\n",
"* Dennis and Schnabel"
]
},
{
"cell_type": "code",
"execution_count": 143,
"id": "2508cef2",
"metadata": {
"cell_style": "center",
"slideshow": {
"slide_type": ""
}
},
"outputs": [
{
"data": {
"text/plain": [
"-4.139506408429305e-6"
]
},
"execution_count": 143,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"diff(f, x; h=1e-8) = (f(x+h) - f(x)) / h\n",
"\n",
"function diff_wp(f, x; h=1e-8)\n",
" \"\"\"Diff using Walker and Pernice (1998) choice of step\"\"\"\n",
" h *= (1 + abs(x))\n",
" (f(x+h) - f(x)) / h\n",
"end\n",
"\n",
"x = 1000\n",
"diff_wp(sin, x) - cos(x)"
]
},
{
"cell_type": "markdown",
"id": "13598631",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Symbolic differentiation"
]
},
{
"cell_type": "code",
"execution_count": 144,
"id": "fa185815",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\begin{equation}\n",
"\\frac{\\mathrm{d} \\sin\\left( x \\right)}{\\mathrm{d}x}\n",
"\\end{equation}\n",
" $$"
],
"text/plain": [
"Differential(x)(sin(x))"
]
},
"execution_count": 144,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Symbolics\n",
"\n",
"@variables x\n",
"Dx = Differential(x)\n",
"\n",
"y = sin(x)\n",
"Dx(y)"
]
},
{
"cell_type": "code",
"execution_count": 145,
"id": "537e4153",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\begin{equation}\n",
"\\cos\\left( x \\right)\n",
"\\end{equation}\n",
" $$"
],
"text/plain": [
"cos(x)"
]
},
"execution_count": 145,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_derivatives(Dx(y))"
]
},
{
"cell_type": "markdown",
"id": "55f8ae35",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Awesome, what about products?"
]
},
{
"cell_type": "code",
"execution_count": 149,
"id": "0a4243eb",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\begin{equation}\n",
"\\frac{\\left( \\frac{\\cos\\left( x^{\\pi} \\right)}{x} - 3.141592653589793 x^{2.141592653589793} \\log\\left( x \\right) \\sin\\left( x^{\\pi} \\right) \\right) \\cos\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right)}{\\cos\\left( x^{\\pi} \\right) \\log\\left( x \\right)} - \\left( \\frac{3.141592653589793 \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{2.141592653589793}}{x} - 9.869604401089358 \\cos^{2.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} x^{2.141592653589793} \\sin\\left( x^{\\pi} \\right) \\right) \\sin\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right) \\log\\left( \\log\\left( x \\right) \\cos\\left( x^{\\pi} \\right) \\right)\n",
"\\end{equation}\n",
" $$"
],
"text/plain": [
"((x^-1)*cos(x^π) - 3.141592653589793(x^2.141592653589793)*log(x)*sin(x^π))*(log(x)^-1)*(cos(x^π)^-1)*cos((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793)) - (3.141592653589793(x^-1)*(log(x)^2.141592653589793)*(cos(x^π)^3.141592653589793) - 9.869604401089358(x^2.141592653589793)*(log(x)^3.141592653589793)*(cos(x^π)^2.141592653589793)*sin(x^π))*sin((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793))*log(log(x)*cos(x^π))"
]
},
"execution_count": 149,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = x\n",
"for _ in 1:2\n",
" y = cos(y^pi) * log(y)\n",
"end\n",
"expand_derivatives(Dx(y))"
]
},
{
"cell_type": "markdown",
"id": "34b799be",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* The size of these expressions can grow **exponentially**"
]
},
{
"cell_type": "markdown",
"id": "d0a680f1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Hand-coding derivatives\n",
"\n",
"$$ df = f'(x) dx $$"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "c24fb506",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"(-1.5346823414986814, -34.032439961925064)"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function f(x)\n",
" y = x\n",
" for _ in 1:2\n",
" a = y^pi\n",
" b = cos(a)\n",
" c = log(y)\n",
" y = b * c\n",
" end\n",
" y\n",
"end\n",
"\n",
"f(1.9), diff_wp(f, 1.9)"
]
},
{
"cell_type": "code",
"execution_count": 153,
"id": "d8dec50b",
"metadata": {
"cell_style": "split",
"slideshow": {
"slide_type": "fragment"
}
},
"outputs": [
{
"data": {
"text/plain": [
"(-1.5346823414986814, -34.032419599140475)"
]
},
"execution_count": 153,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function df(x, dx)\n",
" y = x\n",
" dy = dx\n",
" for _ in 1:2\n",
" a = y ^ pi\n",
" da = pi * y^(pi - 1) * dy\n",
" b = cos(a)\n",
" db = -sin(a) * da\n",
" c = log(y)\n",
" dc = dy / y\n",
" y = b * c\n",
" dy = db * c + b * dc\n",
" end\n",
" y, dy\n",
"end\n",
"\n",
"df(1.9, 1)"
]
},
{
"cell_type": "markdown",
"id": "3fb72ae9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# We can go the other way\n",
"\n",
"We can differentiate a composition $h(g(f(x)))$ as\n",
"\n",
"\\begin{align}\n",
" \\operatorname{d} h &= h' \\operatorname{d} g \\\\\n",
" \\operatorname{d} g &= g' \\operatorname{d} f \\\\\n",
" \\operatorname{d} f &= f' \\operatorname{d} x.\n",
"\\end{align}\n",
"\n",
"What we've done above is called \"forward mode\", and amounts to placing the parentheses in the chain rule like\n",
"\n",
"$$ \\operatorname d h = \\frac{dh}{dg} \\left(\\frac{dg}{df} \\left(\\frac{df}{dx} \\operatorname d x \\right) \\right) .$$\n",
"\n",
"The expression means the same thing if we rearrange the parentheses,\n",
"\n",
"$$ \\operatorname d h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname d x .$$"
]
},
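    {
      "cell_type": "markdown",
      "id": "d3c4e5f6",
      "metadata": {
        "slideshow": {
          "slide_type": "fragment"
        }
      },
      "source": [
        "Forward mode can be implemented by overloading arithmetic to propagate $(v, \\dot v)$ pairs together. Below is a minimal dual-number sketch (the `Dual` type and its four methods are our own illustration, not a library API), with just enough operations for one layer of the `f` defined above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "d3c4e5f7",
      "metadata": {},
      "outputs": [],
      "source": [
        "struct Dual\n",
        "    val::Float64  # primal value\n",
        "    dot::Float64  # derivative carried alongside\n",
        "end\n",
        "Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.dot * b.val + a.val * b.dot)\n",
        "Base.:^(a::Dual, p::Irrational) = Dual(a.val^p, p * a.val^(p - 1) * a.dot)\n",
        "Base.cos(a::Dual) = Dual(cos(a.val), -sin(a.val) * a.dot)\n",
        "Base.log(a::Dual) = Dual(log(a.val), a.dot / a.val)\n",
        "\n",
        "g(t) = cos(t^pi) * log(t)  # one layer of the f above\n",
        "g(Dual(1.9, 1.0))  # value and derivative in a single forward pass"
      ]
    },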
{
"cell_type": "markdown",
"id": "12e01ae1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Automatic differentiation"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7052e15a",
"metadata": {
"cell_style": "split"
},
"outputs": [],
"source": [
"import Zygote"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "2aabcb46",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"(-34.03241959914049,)"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Zygote.gradient(f, 1.9)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "9cee1ed6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:95 within `gradient`\u001b[39m\n",
"\u001b[95mdefine\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[93m@julia_gradient_12050\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n",
"\u001b[91mtop:\u001b[39m\n",
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:96 within `gradient`\u001b[39m\n",
"\u001b[90m; ┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:42 within `pullback` @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:44\u001b[39m\n",
"\u001b[90m; │┌ @ In[21]:1 within `_pullback` @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl:9\u001b[39m\n",
"\u001b[90m; ││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface2.jl within `macro expansion`\u001b[39m\n",
"\u001b[90m; │││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:218 within `chain_rrule`\u001b[39m\n",
"\u001b[90m; ││││┌ @ /home/jed/.julia/packages/ChainRulesCore/C73ay/src/rules.jl:134 within `rrule` @ /home/jed/.julia/packages/ChainRules/hVHC4/src/rulesets/Base/fastmath_able.jl:56\u001b[39m\n",
" \u001b[0m%1 \u001b[0m= \u001b[96m\u001b[1mcall\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[93m@j_exp_12052\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0\n",
"\u001b[90m; └└└└└\u001b[39m\n",
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:97 within `gradient`\u001b[39m\n",
"\u001b[90m; ┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:45 within `#60`\u001b[39m\n",
"\u001b[90m; │┌ @ In[21]:1 within `Pullback`\u001b[39m\n",
"\u001b[90m; ││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/chainrules.jl:206 within `ZBack`\u001b[39m\n",
"\u001b[90m; │││┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/lib/number.jl:12 within `literal_pow_pullback`\u001b[39m\n",
"\u001b[90m; ││││┌ @ promotion.jl:389 within `*` @ float.jl:385\u001b[39m\n",
" \u001b[0m%2 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[33m2.000000e+00\u001b[39m\n",
"\u001b[90m; │└└└└\u001b[39m\n",
"\u001b[90m; │┌ @ /home/jed/.julia/packages/Zygote/dABKa/src/lib/lib.jl:17 within `accum`\u001b[39m\n",
"\u001b[90m; ││┌ @ float.jl:383 within `+`\u001b[39m\n",
" \u001b[0m%3 \u001b[0m= \u001b[96m\u001b[1mfadd\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%2\u001b[0m, \u001b[0m%1\n",
"\u001b[90m; └└└\u001b[39m\n",
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/dABKa/src/compiler/interface.jl:98 within `gradient`\u001b[39m\n",
" \u001b[0m%.fca.0.insert \u001b[0m= \u001b[96m\u001b[1minsertvalue\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[95mzeroinitializer\u001b[39m\u001b[0m, \u001b[36mdouble\u001b[39m \u001b[0m%3\u001b[0m, \u001b[33m0\u001b[39m\n",
" \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[0m%.fca.0.insert\n",
"\u001b[33m}\u001b[39m\n"
]
}
],
"source": [
"g(x) = exp(x) + x^2\n",
"@code_llvm Zygote.gradient(g, 1.9)"
]
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"celltoolbar": "Slideshow",
"hide_code_all_hidden": false,
"kernelspec": {
"display_name": "Julia 1.7.2",
"language": "julia",
"name": "julia-1.7"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.5"
},
"rise": {
"enable_chalkboard": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}