{ "cells": [ { "cell_type": "markdown", "source": [ "# Line search MM\n", "\n", "Examples illustrating the\n", "line-search method\n", "based on majorize-minimize (MM) principles\n", "in the Julia package\n", "[`MIRT`](https://github.com/JeffFessler/MIRT.jl).\n", "This method is probably most useful\n", "for algorithm developers." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "### Setup" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "Packages needed here." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "using Plots; default(markerstrokecolor = :auto, label=\"\")\n", "using MIRTjim: prompt\n", "using MIRT: line_search_mm, LineSearchMMWork\n", "using LineSearches: BackTracking, HagerZhang, MoreThuente\n", "using LinearAlgebra: norm, dot\n", "using Random: seed!; seed!(0)\n", "using BenchmarkTools: @btime, @benchmark\n", "using InteractiveUtils: versioninfo" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The following line is helpful when running this file as a script;\n", "this way it will prompt user to hit a key after each figure is displayed." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "isinteractive() && prompt(:prompt);" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "# Theory\n", "\n", "Many methods for solving inverse problems\n", "involve optimization problems\n", "of the form\n", "$$\n", "\\hat{x} = \\arg\\min_{x ∈ \\mathbb{F}^N} f(x)\n", ",\\qquad\n", "f(x) = \\sum_{j=1}^J f_j(B_j x)\n", "$$\n", "where $\\mathbb{F}$ denotes the field of real or complex numbers,\n", "matrix $B_j$ has size $M_j × N$,\n", "and $f_j : \\mathbb{F}^{M_j} ↦ \\mathbb{R}$.\n", "\n", "One could apply general-purpose optimization methods here,\n", "like those in\n", "[Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl),\n", "but often we can obtain faster results\n", "by exploiting the specific\n", "(yet still fairly general) structure,\n", "particularly when the problem dimension $N$ is large.\n", "\n", "Many algorithms for solving such problems\n", "require an inner 1D optimization problem\n", "called a\n", "[line search](https://en.wikipedia.org/wiki/Line_search)\n", "of the form\n", "$$\n", "α_* = \\arg\\min_{α ∈ \\mathbb{R}} h(α)\n", ",\\qquad\n", "h(α) = f(x + α d),\n", "$$\n", "for some search direction $d$.\n", "There are general purpose line search algorithms in\n", "[LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl),\n", "but here we focus on the specific form of $f$ given above.\n", "\n", "For that form we see that we have the special structure\n", "$$\n", "h(α) = \\sum_{j=1}^J f_j(u_j + α v_j)\n", ",\\qquad\n", "u_j = B_j x\n", ",\\quad\n", "v_j = B_j d.\n", "$$\n", "\n", "Here we focus further\n", "on the case where each function $f_j(⋅)$\n", "has a quadratic majorizer\n", "of the form\n", "$$\n", "f_j(x) ≤ q_j(x,z) = f_j(z) + \\text{real}(⟨ ∇f_j(z), x - z ⟩)\n", "+ \\frac{1}{2} (x - z)' D_j(z) (x - z),\n", "$$\n", "where $D_j(z)$ is a positive semidefinite matrix\n", "that typically is diagonal.\n", "Often it is a constant times the identity matrix,\n", "e.g.,\n", "a Lipschitz constant for $∇f_j$,\n", "but often there are\n", "sharper majorizers.\n", "\n", "Such quadratic majorizers\n", "induce a quadratic majorizer\n", "for $h(α)$ as well:\n", "$$\n", "h(α) ≤ q(α; α_t) =\n", "\\sum_{j=1}^J q_j(u_j + α v_j; u_j + α_t v_j)\n", "= h(α_t) + c_1(α_t) (α - α_t)\n", "+ \\frac{1}{2} c_2(α_t) (α - α_t)^2\n", "$$\n", "where\n", 
"$$\n", "c_1(α_t) = \\sum_{j=1}^J \\text{real}(⟨ ∇f_j(u_j + α_t v_j), v_j ⟩)\n", ",\\qquad\n", "c_2(α_t) = \\sum_{j=1}^J v_j' D_j(u_j + α_t v_j) v_j.\n", "$$\n", "\n", "The `line_search_mm` function\n", "in this package\n", "uses this quadratic majorizer\n", "to update $α$\n", "using the iteration\n", "$$\n", "α_{t+1}\n", "= \\arg\\min_{α} q(α;α_t)\n", "= α_t - c_1(α_t) / c_2(α_t).\n", "$$\n", "Being an MM algorithm,\n", "it is guaranteed to decrease\n", "$h(α)$ every update.\n", "For an early exposition of this approach,\n", "see\n", "[Fessler & Booth, 1999](http://doi.org/10.1109/83.760336).\n", "\n", "From the above derivation,\n", "the main ingredients needed are\n", "functions for computing\n", "the dot products\n", "$⟨ ∇f_j(u_j + α_t v_j), v_j ⟩$\n", "and\n", "$v_j' D_j(u_j + α_t v_j) v_j$.\n", "\n", "The `line_search_mm` function\n", "can construct such functions\n", "given input gradient functions\n", "$[∇f_1,…,∇f_J]$\n", "and curvature functions\n", "$[ω_1,…,ω_J]$\n", "where\n", "$D_j(z) = \\text{Diag}(ω_j(z))$.\n", "\n", "Alternatively,\n", "the user can provide functions\n", "for computing the dot products.\n", "\n", "All of this is best illustrated\n", "by an example.\n", "\n", "# Smooth LASSO problem\n", "\n", "The usual LASSO optimization problem uses the cost function\n", "$$\n", "f(x) = \\frac{1}{2} \\| A x - y \\|_2^2 + β R(x)\n", ",\\qquad\n", "R(x) = \\| x \\|_1 = \\sum_{n=1}^N |x_n| = 1' \\text{abs.}(x).\n", "$$\n", "\n", "The 1-norm is just a relaxation of the 0-norm\n", "so here we further \"relax\" it\n", "by considering the \"corner rounded\" version\n", "using the Fair potential function\n", "$$\n", "R(x) = \\sum_{n=1}^N ψ(x_n) = 1' ψ.(x),\n", "\\qquad\n", "ψ(z) = δ^2 |z/δ| - \\log(1 + |z/δ|)\n", "$$\n", "for a small value of $δ$.\n", "\n", "The derivative of this potential function\n", "is\n", "$\\dot{ψ}(z) = z / (1 + |z / δ|)$\n", "and Huber's curvature\n", "$ω_{ψ}(z) = 1 / (1 + |z / δ|)$\n", "provides a suitable majorizer.\n", "\n", "This smooth LASSO cost function has the general form above\n", "with $J=2$,\n", "$B_1 = A$,\n", "$B_2 = I$,\n", "$f_1(u) = \\frac{1}{2} \\| u - y \\|_2^2,$\n", "$f_2(u) = β 1' ψ.(u),$\n", "for which\n", "$∇f_1(u) = u - y,$\n", "$∇f_2(u) = β ψ.(u),$\n", "and\n", "$∇^2 f_1(u) = I,$\n", "$∇^2 f_2(u) \\succeq β \\, \\text{diag}(ω_{ψ}(u)).$\n", "\n", "Set up an example and plot $h(α)$." 
], "metadata": {} }, { "cell_type": "markdown", "source": [ "Fair potential, its derivative and Huber weighting function:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "function fair_pot()\n", " fpot(z,δ) = δ^2 * (abs(z/δ) - log(1 + abs(z/δ)))\n", " dpot(z,δ) = z / (1 + abs(z/δ))\n", " wpot(z,δ) = 1 / (1 + abs(z/δ))\n", " return fpot, dpot, wpot\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Data, cost function and gradients for smooth LASSO problem:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "M, N = 1000, 2000\n", "A = randn(M,N)\n", "x0 = randn(N) .* (rand(N) .< 0.4) # sparse vector\n", "y = A * x0 + 0.001 * randn(M)\n", "β = 95\n", "δ = 0.1\n", "fpot, dpot, wpot = Base.Fix2.(fair_pot(), δ)\n", "\n", "f(x) = 0.5 * norm(A * x - y)^2 + β * sum(fpot, x)\n", "∇f(x) = A' * (A * x - y) + β * dpot.(x)\n", "x = randn(N) # random point\n", "d = -∇f(x)/M # some search direction\n", "h(α) = f(x + α * d)\n", "dh(α) = d' * ∇f(x + α * d)\n", "pa = plot(h, xlabel=\"α\", ylabel=\"h(α)\", xlims=(-1, 2))" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Apply MM-based line search: simple version.\n", "The key inputs are the gradient and curvature functions:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "gradf = [\n", " u -> u - y, # ∇f₁ for data-fit term\n", " u -> β * dpot.(u), # ∇f₂ for regularizer\n", "]\n", "curvf = [\n", " 1, # curvature for data-fit term\n", " u -> β * wpot.(u), # Huber curvature for regularizer\n", "]\n", "\n", "uu = [A * x, x] # [u₁ u₂]\n", "vv = [A * d, d] # [v₁ v₂]\n", "fun(state) = state.α # log this\n", "ninner = 7\n", "out = Vector{Any}(undef, ninner+1)\n", "α0 = 0\n", "αstar = line_search_mm(gradf, curvf, uu, vv; ninner, out, fun, α0)\n", "hmin = h(αstar)\n", "scatter!([αstar], [hmin], marker=:star, color=:red)\n", "scatter!([α0], [h(α0)], marker=:circle, color=:green)\n", "ps = plot(0:ninner, out, marker=:circle, xlabel=\"iteration\", ylabel=\"α\",\n", " color = :green)\n", "pd = plot(0:ninner, abs.(dh.(out)), marker=:diamond,\n", " yaxis = :log, color=:red,\n", " xlabel=\"iteration\", ylabel=\"|dh(α)|\")\n", "pu = plot(1:ninner, log10.(max.(abs.(diff(out)), 1e-16)), marker=:square,\n", " color=:blue, xlabel=\"iteration\", ylabel=\"log10(|α_k - α_{k-1}|)\")\n", "plot(pa, ps, pd, pu)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Thanks to Huber's curvatures,\n", "the $α_t$ sequence converges very quickly.\n", "\n", "Now explore a fancier version\n", "that needs less heap memory." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "work = LineSearchMMWork(uu, vv, α0) # pre-allocate\n", "function lsmm1(gradf, curvf)\n", " return line_search_mm(gradf, curvf, uu, vv;\n", " ninner, out, fun, α0, work)\n", "end\n", "function lsmm2(dot_gradf, dot_curvf)\n", " gradn = [() -> nothing, () -> nothing]\n", " return line_search_mm(uu, vv, dot_gradf, dot_curvf;\n", " ninner, out, fun, α0, work)\n", "end;" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "The `let` statements below are a performance trick from the\n", "[Julia manual](https://docs.julialang.org/en/v1/manual/performance-tips/#man-performance-captured-1).\n", "Using `Iterators.map` avoids allocating arrays like `z - y`\n", "and does not even require any work space." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "gradz = [\n", " let y=y; z -> Iterators.map(-, z, y); end, # z - y\n", " let β=β, dpot=dpot; z -> Iterators.map(z -> β * dpot(z), z); end, # β * dψ.(z)\n", "]\n", "curvz = [\n", " 1,\n", " let β=β, wpot=wpot; z -> Iterators.map(z -> β * wpot(z), z); end, # β * ωψ.(z)\n", "]\n", "\n", "function make_grad1c()\n", " w = similar(uu[1]) # work-space\n", " let w=w, y=y\n", " function grad1c(z)\n", " @. w = z - y\n", " return w\n", " end\n", " end\n", "end\n", "\n", "function make_grad2c()\n", " w = similar(uu[2]) # work-space\n", " let w=w, β=β, dpot=dpot\n", " function grad2c(z)\n", " @. w = β * dpot(z)\n", " return w\n", " end\n", " end\n", "end\n", "\n", "function make_curv2c()\n", " w = similar(uu[2]) # work-space\n", " let w=w, β=β, wpot=wpot\n", " function curv2c(z)\n", " @. w = β * wpot(z) # β * ωψ.(z)\n", " return w\n", " end\n", " end\n", "end\n", "\n", "gradc = [ # capture version\n", " make_grad1c(), # z - y\n", " make_grad2c(), # β * dψ.(z)\n", "]\n", "curvc = [\n", " 1,\n", " make_curv2c(), # β * ωψ.(z)\n", "]\n", "\n", "sum_map(f::Function, args...) = sum(Iterators.map(f, args...))\n", "dot_gradz = [\n", " let y=y; (v,z) -> sum_map((v,z,y) -> dot(v, z - y), v, z, y); end, # v'(z - y)\n", " let β=β, dpot=dpot; (v,z) -> β * sum_map((v,z) -> dot(v, dpot(z)), v, z); end, # β * (v'dψ.(z))\n", "]\n", "dot_curvz = [\n", " (v,z) -> norm(v)^2,\n", " let β=β, wpot=wpot; (v,z) -> β * sum_map((v,z) -> abs2(v) * wpot(z), v, z); end, # β * (abs2.(v)'ωψ.(z))\n", "]\n", "\n", "\n", "a1 = lsmm1(gradf, curvf)\n", "a1c = lsmm1(gradc, curvc)\n", "a2 = lsmm1(gradz, curvz)\n", "a3 = lsmm2(dot_gradz, dot_curvz)\n", "@assert a1 ≈ a2 ≈ a3 ≈ a1c\n", "\n", "b1 = @benchmark a1 = lsmm1($gradf, $curvf)" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "bc = @benchmark a1c = lsmm1($gradc, $curvc)" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "b2 = @benchmark a2 = lsmm1($gradz, $curvz)" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "b3 = @benchmark a3 = lsmm2($dot_gradz, $dot_curvz)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Timing results on my Mac:\n", "- 95 μs\n", "- 65 μs # 1c after using `make_`\n", "- 80 μs\n", "- 69 μs (and lowest memory)\n", "\n", "The versions using `gradc` and `dot_gradz`\n", "with their \"properly captured\" variables\n", "are the fastest.\n", "But all the versions here are pretty similar\n", "so even using the simplest version\n", "seems likely to be fine." ], "metadata": {} }, { "cell_type": "markdown", "source": [ "## Compare with LineSearches.jl\n", "\n", "Was all this specialized effort useful?\n", "Let's compare to the general line search methods in\n", "[LineSearches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl).\n", "\n", "It seems that some of those methods do not allow $α₀ = 0$\n", "so we use 1.0 instead.\n", "We use the default arguments for all the solvers,\n", "which means some of them might terminate\n", "before `ninner` iterations,\n", "giving them a potential speed advantage." 
], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "a0 = 1.0 # α0\n", "hdh(α) = h(α), dh(α)\n", "h0 = h(0)\n", "dh0 = dh(0);\n", "function ls_ls(linesearch)\n", " a1, fx = linesearch(h, dh, hdh, a0, h0, dh0)\n", " return a1\n", "end;\n", "\n", "solvers = [\n", " BackTracking( ; iterations = ninner),\n", " HagerZhang( ; linesearchmax = ninner),\n", " MoreThuente( ; maxfev = ninner),\n", "]\n", "for ls in solvers # check that they work properly\n", " als = ls_ls(ls)\n", " @assert isapprox(als, αstar; atol=1e-3)\n", "end;" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "bbt = @benchmark ls_ls($(solvers[1]))" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "bhz = @benchmark ls_ls($(solvers[2]))" ], "metadata": {}, "execution_count": null }, { "outputs": [], "cell_type": "code", "source": [ "bmt = @benchmark ls_ls($(solvers[3]))" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "On my Mac the timings are all much longer\n", "compared to `line_search_mm`:\n", "- 840 μs `BackTracking`\n", "- 2.6 ms `HagerZhang`\n", "- 3.9 ms `MoreThuente`\n", "\n", "This comparison illustrates\n", "the benefit of the \"special purpose\" line search.\n", "\n", "\n", "The fastest version seems to be `BackTracking`,\n", "so plot its iterates:" ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "alpha_bt = zeros(ninner + 1)\n", "alpha_bt[1] = a0\n", "for iter in 1:ninner\n", " tmp = BackTracking( ; iterations = iter)\n", " alpha_bt[iter+1] = ls_ls(tmp)\n", "end\n", "plot(0:ninner, alpha_bt, marker=:square, color=:blue,\n", " xlabel=\"Iteration\", ylabel=\"BackTracking α\")\n", "plot!([0, ninner], [1,1] * αstar, color=:red)" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "Unexpectedly,\n", "`BackTracking` seems to terminate at the first iteration.\n", "But even just that single iteration is slower than 7 iterations\n", "of `line_search_mm`." ], "metadata": {} }, { "outputs": [], "cell_type": "code", "source": [ "prompt()" ], "metadata": {}, "execution_count": null }, { "cell_type": "markdown", "source": [ "---\n", "\n", "*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*" ], "metadata": {} } ], "nbformat_minor": 3, "metadata": { "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.10.5" }, "kernelspec": { "name": "julia-1.10", "display_name": "Julia 1.10.5", "language": "julia" } }, "nbformat": 4 }