{ "cells": [ { "cell_type": "markdown", "id": "a1132dd9-6bcc-44c8-a1f3-cbe82ae20992", "metadata": {}, "source": [ "# Hamiltonian Monte Carlo with leapfrog\n", "\n", "* 黒木玄\n", "* 2021-12-06, 2022-09-03\n", "\n", "Scalar version: https://github.com/genkuroki/public/blob/main/0018/HMC%20leapfrog.ipynb\n", "\n", "2022-09-03: Changed the kernel to Julia v1.8.0 and removed the dependencies on ConcreteStructs.jl and Parameters.jl." ] }, { "cell_type": "code", "execution_count": 1, "id": "ba312428-480f-4965-a914-9cc2b255c32e", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "Main.My" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Sample code of Hamiltonian Monte Carlo with leapfrog\n", "\n", "module My\n", "\n", "using LinearAlgebra: dot\n", "using ForwardDiff: gradient\n", "using Random: default_rng, randn!\n", "using StaticArrays: SVector, MVector\n", "\n", "# Type of Leapfrog Problem\n", "struct LFProblem{dim, F1, F2, F3, T, I}\n", " ϕ::F1\n", " H::F2\n", " F::F3\n", " dt::T\n", " nsteps::I\n", "end\n", "\n", "\"\"\"\n", " LFProblem(dim, ϕ, H, F, dt, nsteps)\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param`, `H(x, v, param)` and `F(x, param)` the Hamiltonian and force functions corresponding to `ϕ`, `dt` a step of discretized time, and `nsteps` the number of steps. Then it returns the problem object for solving Hamilton's equations with the leapfrog method.\n", "\"\"\"\n", "function LFProblem(dim, ϕ, H, F, dt, nsteps)\n", " F1, F2, F3, T, I = typeof(ϕ), typeof(H), typeof(F), typeof(dt), typeof(nsteps)\n", " LFProblem{dim, F1, F2, F3, T, I}(ϕ, H, F, dt, nsteps)\n", "end\n", "\n", "\"\"\"\n", " LFProblem(dim, ϕ; dt = 1.0, nsteps = 40)\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param`. 
Then it defines the Hamiltonian function `H(x, v, param)` and the force function `F(x, param)` corresponding to `ϕ` and returns `LFProblem(dim, ϕ, H, F, dt, nsteps)`.\n", "\"\"\"\n", "function LFProblem(dim, ϕ; dt = 1.0, nsteps = 40)\n", " H(x, v, param) = dot(v, v)/2 + ϕ(x, param)\n", " F(x, param) = -gradient(x -> ϕ(x, param), x)\n", " LFProblem(dim, ϕ, H, F, dt, nsteps)\n", "end\n", "\n", "\"\"\"\n", " LFProblem(dim, ϕ, ∇ϕ; dt = 1.0, nsteps = 40)\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param` and `∇ϕ` its gradient with respect to `x`. Then it defines the Hamiltonian function `H(x, v, param)` and the force function `F(x, param)` corresponding to `ϕ` and `∇ϕ` and returns `LFProblem(dim, ϕ, H, F, dt, nsteps)`.\n", "\"\"\"\n", "function LFProblem(dim, ϕ, ∇ϕ; dt = 1.0, nsteps = 40)\n", " H(x, v, param) = dot(v, v)/2 + ϕ(x, param)\n", " F(x, param) = -∇ϕ(x, param)\n", " LFProblem(dim, ϕ, H, F, dt, nsteps)\n", "end\n", "\n", "\"\"\"\n", " solve(lf::LFProblem, x, v, param)\n", "\n", "numerically solves Hamilton's equations of motion given by `lf` with the leapfrog method, where (`x`, `v`) is the initial value and `param` is the parameter of the potential function `lf.ϕ`.\n", "\"\"\"\n", "function solve(lf::LFProblem, x, v, param)\n", " (; F, dt, nsteps) = lf\n", " v = v + F(x, param)*dt/2\n", " x = x + v*dt\n", " for _ in 2:nsteps\n", " v = v + F(x, param)*dt\n", " x = x + v*dt\n", " end\n", " v = v + F(x, param)*dt/2\n", " x, v\n", "end\n", "\n", "@inline function _update!(lf::LFProblem{dim}, x, vtmp, param, rng) where dim\n", " (; H) = lf\n", " v = SVector{dim}(randn!(rng, vtmp))\n", " xnew, vnew = solve(lf, x, v, param)\n", " dH = H(xnew, vnew, param) - H(x, v, param)\n", " rand(rng) ≤ exp(-dH) ? 
xnew : x\n", "end\n", "\n", "\"\"\"\n", " HMC(lf::LFProblem{dim}, param = nothing;\n", " niters = 10^5, thin = 1, nwarmups = 0, rng = default_rng(),\n", " init = SVector{dim}(randn(rng, dim))) where dim\n", "\n", "generates a sample from the distribution whose probability density function is proportional to exp(-`lf.ϕ(x, param)`) by the Hamiltonian Monte Carlo method.\n", "\"\"\"\n", "function HMC(lf::LFProblem{dim}, param = nothing;\n", " niters = 10^5, thin = 1, nwarmups = 0, rng = default_rng(),\n", " init = SVector{dim}(randn(rng, dim))) where dim\n", " vtmp = MVector{dim}(zeros(eltype(init), dim))\n", " x = init\n", " for _ in 1:nwarmups\n", " x = _update!(lf, x, vtmp, param, rng)\n", " end\n", " sample = Vector{typeof(init)}(undef, niters)\n", " for i in 1:niters\n", " for _ in 1:thin\n", " x = _update!(lf, x, vtmp, param, rng)\n", " end\n", " @inbounds sample[i] = x\n", " end\n", " sample\n", "end\n", "\n", "end" ] }, { "cell_type": "code", "execution_count": 2, "id": "edb62652-9cea-4a8d-ba31-89e25497ae37", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "\\begin{verbatim}\n", "LFProblem(dim, ϕ, H, F, dt, nsteps)\n", "\\end{verbatim}\n", "Assume that \\texttt{ϕ(x, param)} is a potential function of a \\texttt{dim}-dimensional vector \\texttt{x} and a parameter \\texttt{param}, \\texttt{H(x, v, param)} and \\texttt{F(x, param)} the Hamiltonian and force functions corresponding to \\texttt{ϕ}, \\texttt{dt} a step of discretized time, and \\texttt{nsteps} the number of steps. Then it returns the problem object for solving Hamilton's equations with the leapfrog method.\n", "\n", "\\rule{\\textwidth}{1pt}\n", "\\begin{verbatim}\n", "LFProblem(dim, ϕ; dt = 1.0, nsteps = 40)\n", "\\end{verbatim}\n", "Assume that \\texttt{ϕ(x, param)} is a potential function of a \\texttt{dim}-dimensional vector \\texttt{x} and a parameter \\texttt{param}. 
Then it defines the Hamiltonian function \\texttt{H(x, v, param)} and the force function \\texttt{F(x, param)} corresponding to \\texttt{ϕ} and returns \\texttt{LFProblem(dim, ϕ, H, F, dt, nsteps)}.\n", "\n", "\\rule{\\textwidth}{1pt}\n", "\\begin{verbatim}\n", "LFProblem(dim, ϕ, ∇ϕ; dt = 1.0, nsteps = 40)\n", "\\end{verbatim}\n", "Assume that \\texttt{ϕ(x, param)} is a potential function of a \\texttt{dim}-dimensional vector \\texttt{x} and a parameter \\texttt{param} and \\texttt{∇ϕ} its gradient with respect to \\texttt{x}. Then it defines the Hamiltonian function \\texttt{H(x, v, param)} and the force function \\texttt{F(x, param)} corresponding to \\texttt{ϕ} and \\texttt{∇ϕ} and returns \\texttt{LFProblem(dim, ϕ, H, F, dt, nsteps)}.\n", "\n" ], "text/markdown": [ "```\n", "LFProblem(dim, ϕ, H, F, dt, nsteps)\n", "```\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param`, `H(x, v, param)` and `F(x, param)` the Hamiltonian and force functions corresponding to `ϕ`, `dt` a step of discretized time, and `nsteps` the number of steps. Then it returns the problem object for solving Hamilton's equations with the leapfrog method.\n", "\n", "---\n", "\n", "```\n", "LFProblem(dim, ϕ; dt = 1.0, nsteps = 40)\n", "```\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param`. Then it defines the Hamiltonian function `H(x, v, param)` and the force function `F(x, param)` corresponding to `ϕ` and returns `LFProblem(dim, ϕ, H, F, dt, nsteps)`.\n", "\n", "---\n", "\n", "```\n", "LFProblem(dim, ϕ, ∇ϕ; dt = 1.0, nsteps = 40)\n", "```\n", "\n", "Assume that `ϕ(x, param)` is a potential function of a `dim`-dimensional vector `x` and a parameter `param` and `∇ϕ` its gradient with respect to `x`. 
Then it defines the Hamiltonian function `H(x, v, param)` and the force function `F(x, param)` corresponding to `ϕ` and `∇ϕ` and returns `LFProblem(dim, ϕ, H, F, dt, nsteps)`.\n" ], "text/plain": [ "\u001b[36m LFProblem(dim, ϕ, H, F, dt, nsteps)\u001b[39m\n", "\n", " Assume that \u001b[36mϕ(x, param)\u001b[39m is a potential function of a \u001b[36mdim\u001b[39m-dimensional vector\n", " \u001b[36mx\u001b[39m and a parameter \u001b[36mparam\u001b[39m, \u001b[36mH(x, v, param)\u001b[39m and \u001b[36mF(x, param)\u001b[39m the Hamiltonian and\n", " force functions corresponding to \u001b[36mϕ\u001b[39m, \u001b[36mdt\u001b[39m a step of discretized time, and\n", " \u001b[36mnsteps\u001b[39m the number of steps. Then it returns the problem object for solving\n", " Hamilton's equations with the leapfrog method.\n", "\n", " ────────────────────────────────────────────────────────────────────────────\n", "\n", "\u001b[36m LFProblem(dim, ϕ; dt = 1.0, nsteps = 40)\u001b[39m\n", "\n", " Assume that \u001b[36mϕ(x, param)\u001b[39m is a potential function of a \u001b[36mdim\u001b[39m-dimensional vector\n", " \u001b[36mx\u001b[39m and a parameter \u001b[36mparam\u001b[39m. Then it defines the Hamiltonian function \u001b[36mH(x, v,\n", " param)\u001b[39m and the force function \u001b[36mF(x, param)\u001b[39m corresponding to \u001b[36mϕ\u001b[39m and returns\n", " \u001b[36mLFProblem(dim, ϕ, H, F, dt, nsteps)\u001b[39m.\n", "\n", " ────────────────────────────────────────────────────────────────────────────\n", "\n", "\u001b[36m LFProblem(dim, ϕ, ∇ϕ; dt = 1.0, nsteps = 40)\u001b[39m\n", "\n", " Assume that \u001b[36mϕ(x, param)\u001b[39m is a potential function of a \u001b[36mdim\u001b[39m-dimensional vector\n", " \u001b[36mx\u001b[39m and a parameter \u001b[36mparam\u001b[39m and \u001b[36m∇ϕ\u001b[39m its gradient with respect to \u001b[36mx\u001b[39m. 
Then it\n", " defines the Hamiltonian function \u001b[36mH(x, v, param)\u001b[39m and the force function \u001b[36mF(x,\n", " param)\u001b[39m corresponding to \u001b[36mϕ\u001b[39m and \u001b[36m∇ϕ\u001b[39m and returns \u001b[36mLFProblem(dim, ϕ, H, F, dt,\n", " nsteps)\u001b[39m." ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "?My.LFProblem" ] }, { "cell_type": "code", "execution_count": 3, "id": "5893dfe0-dd88-4338-a334-40275427ebbf", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "\\begin{verbatim}\n", "solve(lf::LFProblem, x, v, param)\n", "\\end{verbatim}\n", "numerically solves Hamilton's equations of motion given by \\texttt{lf} with the leapfrog method, where (\\texttt{x}, \\texttt{v}) is the initial value and \\texttt{param} is the parameter of the potential function \\texttt{lf.ϕ}.\n", "\n" ], "text/markdown": [ "```\n", "solve(lf::LFProblem, x, v, param)\n", "```\n", "\n", "numerically solves Hamilton's equations of motion given by `lf` with the leapfrog method, where (`x`, `v`) is the initial value and `param` is the parameter of the potential function `lf.ϕ`.\n" ], "text/plain": [ "\u001b[36m solve(lf::LFProblem, x, v, param)\u001b[39m\n", "\n", " numerically solves Hamilton's equations of motion given by \u001b[36mlf\u001b[39m with the\n", " leapfrog method, where (\u001b[36mx\u001b[39m, \u001b[36mv\u001b[39m) is the initial value and \u001b[36mparam\u001b[39m is the\n", " parameter of the potential function \u001b[36mlf.ϕ\u001b[39m." 
] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "?My.solve" ] }, { "cell_type": "code", "execution_count": 4, "id": "06bff3ae-8f8a-4257-a280-94487e89976d", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "\\begin{verbatim}\n", "HMC(lf::LFProblem{dim}, param = nothing;\n", " niters = 10^5, thin = 1, nwarmups = 0, rng = default_rng(),\n", " init = SVector{dim}(randn(rng, dim))) where dim\n", "\\end{verbatim}\n", "generates a sample from the distribution whose probability density function is proportional to exp(-\\texttt{lf.ϕ(x, param)}) by the Hamiltonian Monte Carlo method.\n", "\n" ], "text/markdown": [ "```\n", "HMC(lf::LFProblem{dim}, param = nothing;\n", " niters = 10^5, thin = 1, nwarmups = 0, rng = default_rng(),\n", " init = SVector{dim}(randn(rng, dim))) where dim\n", "```\n", "\n", "generates a sample from the distribution whose probability density function is proportional to exp(-`lf.ϕ(x, param)`) by the Hamiltonian Monte Carlo method.\n" ], "text/plain": [ "\u001b[36m HMC(lf::LFProblem{dim}, param = nothing;\u001b[39m\n", "\u001b[36m niters = 10^5, thin = 1, nwarmups = 0, rng = default_rng(),\u001b[39m\n", "\u001b[36m init = SVector{dim}(randn(rng, dim))) where dim\u001b[39m\n", "\n", " generates a sample from the distribution whose probability density function\n", " is proportional to exp(-\u001b[36mlf.ϕ(x, param)\u001b[39m) by the Hamiltonian Monte Carlo\n", " method." 
] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "?My.HMC" ] }, { "cell_type": "code", "execution_count": 5, "id": "c58be2a5-8c7f-407d-a3e9-1e7ebff2b780", "metadata": {}, "outputs": [], "source": [ "using Plots\n", "using BenchmarkTools\n", "using StaticArrays\n", "using LinearAlgebra\n", "using KernelDensity\n", "using Statistics\n", "using QuadGK\n", "using Distributions\n", "using Symbolics" ] }, { "cell_type": "markdown", "id": "91c44d89-401b-4fd3-9c05-140119047b63", "metadata": {}, "source": [ "## 2-dimensional normal distribution" ] }, { "cell_type": "code", "execution_count": 6, "id": "9387cc53-834f-400f-950d-88fd9130aa74", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Main.My.LFProblem{2, typeof(ϕ), Main.My.var\"#H#3\"{typeof(ϕ)}, Main.My.var\"#F#4\"{typeof(ϕ)}, Float64, Int64}(ϕ, Main.My.var\"#H#3\"{typeof(ϕ)}(ϕ), Main.My.var\"#F#4\"{typeof(ϕ)}(ϕ), 1.0, 40)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A = @SMatrix [\n", " 1 1/2\n", " 1/2 1\n", "]\n", "param = (; A = A)\n", "ϕ(x, param) = dot(x, param.A, x)/2\n", "lf = My.LFProblem(2, ϕ)" ] }, { "cell_type": "code", "execution_count": 7, "id": "982aae48-8471-45e4-bf2f-c9aaa88d8280", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0.967557 seconds (2.27 M allocations: 126.243 MiB, 4.24% gc time, 94.16% compilation time)\n", " 0.056512 seconds (4 allocations: 1.526 MiB)\n", " 0.057569 seconds (4 allocations: 1.526 MiB)\n" ] } ], "source": [ "@time sample = My.HMC(lf, param)\n", "@time sample = My.HMC(lf, param)\n", "@time sample = My.HMC(lf, param);" ] }, { "cell_type": "code", "execution_count": 8, "id": "da12d7c5-7a32-4ddb-8c5a-88e107501684", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 52.394 ms (4 allocations: 1.53 MiB)\n" ] } ], "source": [ "@btime My.HMC($lf, $param);" ] }, { "cell_type": "code", "execution_count": 9, "id": 
"f19c7cee-d109-4896-aa65-f8e25188530e", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, Y = first.(sample), last.(sample)\n", "d = InterpKDE(kde((X, Y)))\n", "x, y = range(extrema(X)...; length=201), range(extrema(Y)...; length=201)\n", "heatmap(x, y, (x, y) -> pdf(d, x, y); size=(450, 400), right_margin=3Plots.mm)" ] }, { "cell_type": "code", "execution_count": 10, "id": "16ece04e-9ac2-4246-94cb-93d726039a7a", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f(n) = mean(x -> x*x', @view sample[1:n])\n", "\n", "n = 1:1000\n", "S = f.(n)\n", "S11 = (S -> S[1,1]).(S)\n", "S22 = (S -> S[2,2]).(S)\n", "S12 = (S -> S[1,2]).(S)\n", "\n", "ymin = min(-1.5, minimum(S11), minimum(S22), minimum(S12))\n", "ymax = max(2.5, maximum(S11), maximum(S22), maximum(S12))\n", "\n", "plot(ylim = (ymin, ymax))\n", "plot!(S11; label=\"s11\", c=1)\n", "hline!([inv(A)[1,1]]; label=\"\", c=1, ls=:dash)\n", "plot!(S22; label=\"s22\", c=2)\n", "hline!([inv(A)[2,2]]; label=\"\", c=2, ls=:dash)\n", "plot!(S12; label=\"s12\", c=3)\n", "hline!([inv(A)[1,2]]; label=\"\", c=3, ls=:dash)" ] }, { "cell_type": "markdown", "id": 
"5d9c7c9a-4f6e-42d7-a896-bd350a0dacfd", "metadata": {}, "source": [ "## ϕ(x) = a(x² - 1)²" ] }, { "cell_type": "code", "execution_count": 11, "id": "6fbedce8-2c9e-43d8-95ba-97566ebc292e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0.398456 seconds (1.07 M allocations: 61.863 MiB, 77.31% compilation time)\n", " 0.082761 seconds (17 allocations: 1.526 MiB)\n", " 0.087808 seconds (17 allocations: 1.526 MiB)\n", " 0.086475 seconds (17 allocations: 1.526 MiB)\n", " 0.085122 seconds (17 allocations: 1.526 MiB)\n", " 0.085609 seconds (17 allocations: 1.526 MiB)\n" ] }, { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ϕ4(x, a) = a * (x[1]^2 - 1)^2\n", "a = [3, 4, 5, 6, 7, 8]\n", "XX = Vector{Float64}[]\n", "ZZ = Float64[]\n", "PP = []\n", "for i in eachindex(a)\n", " Z = quadgk(x -> exp(-ϕ4((x,), a[i])), -Inf, Inf)[1]\n", " push!(ZZ, Z)\n", " lf = My.LFProblem(1, ϕ4; dt = 0.05, nsteps = 100)\n", " @time X = first.(My.HMC(lf, a[i]))\n", " flush(stdout)\n", " push!(XX, X)\n", " P = plot()\n", " histogram!(X; norm=true, alpha=0.3, label=\"HMC LF sample\", bin=100, c=i)\n", " plot!(x -> exp(-ϕ4(x, a[i]))/Z, -2, 2; label=\"exp(-ϕ4(x))/Z\", lw=2, c=i)\n", " plot!(; legend=false, xtick=-2:0.5:2)\n", " title!(\"ϕ(x) = a(x² - 1)² for a = $(a[i])\", titlefontsize=9)\n", " push!(PP, P)\n", "end\n", "plot(PP...; size=(800, 450), layout=(3, 2))" ] }, { "cell_type": "code", "execution_count": 12, "id": "6ec41866-7337-4932-9442-e1f086fa52e6", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "QQ = []\n", "for i in eachindex(a)\n", " Q = plot(XX[i][1:10000]; ylim=(-1.5, 1.5), label=\"\", c=i, lw=0.5)\n", " title!(\"ϕ(x) = a(x² - 1)² for a = $(a[i])\", titlefontsize=9)\n", " push!(QQ, Q)\n", "end\n", "plot(QQ...; size=(800, 900), layout=(length(a), 1))" ] }, 
{ "cell_type": "markdown", "id": "51247ef8-b3fd-4de8-a386-e5a7e524b800", "metadata": {}, "source": [ "## Bayesian inference for a sample from the standard normal distribution" ] }, { "cell_type": "code", "execution_count": 13, "id": "e2d98daa-3950-4911-9175-e3b7335f28bd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Main.My.LFProblem{2, typeof(negloglik), Main.My.var\"#H#3\"{typeof(negloglik)}, Main.My.var\"#F#4\"{typeof(negloglik)}, Float64, Int64}(negloglik, Main.My.var\"#H#3\"{typeof(negloglik)}(negloglik), 
Main.My.var\"#F#4\"{typeof(negloglik)}(negloglik), 0.1, 30)" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = 10\n", "sample_normal = randn(n)\n", "f(y, m, s) = (y - m)^2/(2s^2) + log(s)\n", "negloglik(w, sample) = sum(y -> f(y, w[1], exp(w[2])), sample)\n", "lf = My.LFProblem(2, negloglik; dt = 0.1, nsteps = 30)" ] }, { "cell_type": "code", "execution_count": 14, "id": "8dddd001-d1a6-4e85-8bf5-d9ddf7069764", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0.928805 seconds (523.08 k allocations: 28.726 MiB, 2.16% gc time, 21.13% compilation time)\n", " 0.738722 seconds (3 allocations: 1.526 MiB)\n", " 0.732966 seconds (3 allocations: 1.526 MiB)\n" ] }, { "data": { "text/plain": [ "100000-element Vector{SVector{2, Float64}}:\n", " [0.2808963383016968, -0.0850139100101338]\n", " [0.10362935168951949, -0.12357443484880323]\n", " [0.33381587080106806, -0.3194664595540615]\n", " [0.08865059391238689, -0.18711549359242713]\n", " [0.4640628065183025, -0.24134729302029093]\n", " [0.044636008825220064, -0.14100619029803105]\n", " [0.3959908387069734, -0.2742529336583942]\n", " [0.13348146420984588, -0.34705068152376395]\n", " [0.14669602220119124, -0.32869017250054006]\n", " [0.4121643380676835, -0.1579548585694553]\n", " [0.4250665035491883, -0.30593240759744883]\n", " [0.04166611134427756, -0.2330193376509495]\n", " [0.47965330500629877, -0.10474723006036145]\n", " ⋮\n", " [-0.08764472165726266, 0.09113787329216483]\n", " [0.5987313387896349, 0.20965789447320693]\n", " [0.34081831941965435, 0.6281018061918996]\n", " [-0.18478463857491062, 0.6063828055588716]\n", " [1.0066149007384484, 0.5467211516531318]\n", " [0.03564943258579456, 0.23211738965114237]\n", " [0.519419263816568, 0.22204375458922126]\n", " [-0.1105239217011224, 0.3291869366389478]\n", " [0.5551195562503037, 0.33500436097357156]\n", " [-0.14675822018009363, 0.3552289244143782]\n", " [0.6125848610579413, 
0.2507218479211541]\n", " [-0.0586160307394845, 0.3267987552833621]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@time sample = My.HMC(lf, sample_normal; init = SVector(0.0, 0.0))\n", "@time sample = My.HMC(lf, sample_normal; init = SVector(0.0, 0.0))\n", "@time sample = My.HMC(lf, sample_normal; init = SVector(0.0, 0.0))" ] }, { "cell_type": "code", "execution_count": 15, "id": "ccc929fa-cc09-4686-ad23-08c308d442f1", "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m, logs = first.(sample), last.(sample)\n", "d = InterpKDE(kde((m, logs)))\n", "x, y = range(extrema(m)...; length=201), range(extrema(logs)...; length=201)\n", "heatmap(x, y, (x, y) -> pdf(d, x, y); size=(450, 400), xlabel=\"μ\", ylabel=\"log(σ)\")" ] }, { "cell_type": "markdown", "id": "4ad4d13d-021b-4f98-b0c3-d9958556a1a8", "metadata": {}, "source": [ "## Symbolics.jl example" ] }, { "cell_type": "code", "execution_count": 16, "id": "5f6cc37e-9545-47e5-8b82-0cf77fc13863", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Num[a[1, 1] a[1, 2]; a[2, 1] a[2, 2]], Num[x[1], x[2]])" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dim = 2\n", "@variables a[1:dim, 1:dim], x[1:dim]\n", "aa, xx = collect.((a, x))" ] }, { "cell_type": "code", "execution_count": 17, "id": "3a6d1ad8-d1e1-4028-b2a3-4f0632fa03fa", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "\\begin{equation}\n", "\\left( \\frac{1}{2} a_{1}ˏ_2 + \\frac{1}{2} a_{2}ˏ_1 \\right) x_1 x_2 + \\frac{1}{2} 
x_1^{2} a_{1}ˏ_1 + \\frac{1}{2} x_2^{2} a_{2}ˏ_2\n", "\\end{equation}\n" ], "text/plain": [ "((1//2)*a[1, 2] + (1//2)*a[2, 1])*x[1]*x[2] + (1//2)*(x[1]^2)*a[1, 1] + (1//2)*(x[2]^2)*a[2, 2]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "phi_sym = dot(xx, aa, xx)/2 |> expand |> simplify" ] }, { "cell_type": "code", "execution_count": 18, "id": "984de92e-01f9-46f6-8057-3d8243311abf", "metadata": {}, "outputs": [ { "data": { "text/latex": [ "\\begin{equation}\n", "\\left[\n", "\\begin{array}{c}\n", "\\left( \\frac{1}{2} a_{1}ˏ_2 + \\frac{1}{2} a_{2}ˏ_1 \\right) x_2 + x_1 a_{1}ˏ_1 \\\\\n", "\\left( \\frac{1}{2} a_{1}ˏ_2 + \\frac{1}{2} a_{2}ˏ_1 \\right) x_1 + x_2 a_{2}ˏ_2 \\\\\n", "\\end{array}\n", "\\right]\n", "\\end{equation}\n" ], "text/plain": [ "2-element Vector{Num}:\n", " ((1//2)*a[1, 2] + (1//2)*a[2, 1])*x[2] + x[1]*a[1, 1]\n", " ((1//2)*a[1, 2] + (1//2)*a[2, 1])*x[1] + x[2]*a[2, 2]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dphi_sym = Symbolics.gradient(phi_sym, xx)" ] }, { "cell_type": "code", "execution_count": 19, "id": "46a3751c-58c1-44eb-8eaa-17896f5ec796", "metadata": {}, "outputs": [], "source": [ "phi_rgf = build_function(phi_sym, xx, aa; expression = Val(false))\n", "dphi_rgf = build_function(dphi_sym, xx, aa; expression = Val(false))[1]\n", "lf = My.LFProblem(dim, phi_rgf, dphi_rgf);" ] }, { "cell_type": "code", "execution_count": 20, "id": "2038e4ca-749a-4613-b6fc-388408d6692a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2×2 SMatrix{2, 2, Float64, 4} with indices SOneTo(2)×SOneTo(2):\n", " 1.0 0.5\n", " 0.5 1.0" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "param = A = @SMatrix [\n", " 1 1/2\n", " 1/2 1\n", "]" ] }, { "cell_type": "code", "execution_count": 21, "id": "c5afb3a3-23ed-4138-9eb8-64576981aaa4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 
0.174535 seconds (314.75 k allocations: 18.494 MiB, 84.32% compilation time)\n", " 0.027935 seconds (4 allocations: 1.526 MiB)\n", " 0.026462 seconds (4 allocations: 1.526 MiB)\n" ] } ], "source": [ "@time sample = My.HMC(lf, param)\n", "@time sample = My.HMC(lf, param)\n", "@time sample = My.HMC(lf, param);" ] }, { "cell_type": "code", "execution_count": 22, "id": "2bb039dc-1d0f-48d3-a5ae-6b3f0fd791a4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 25.828 ms (4 allocations: 1.53 MiB)\n" ] } ], "source": [ "@btime My.HMC($lf, $param);" ] }, { "cell_type": "markdown", "id": "18c0e2be-5b4b-4560-bed3-ea432b710cd3", "metadata": {}, "source": [ "With the symbolically derived gradient, the computation is faster than with automatic differentiation (ForwardDiff.jl)." ] }, { "cell_type": "code", "execution_count": 23, "id": "b9a1e202-8935-4ab9-ac36-d245d9fd8f1e", "metadata": {}, "outputs": [], "source": [ "X, Y = first.(sample), last.(sample)\n", "d = InterpKDE(kde((X, Y)))\n", "x, y = range(extrema(X)...; length=201), range(extrema(Y)...; length=201)\n", "heatmap(x, y, (x, y) -> pdf(d, x, y); size=(450, 400), right_margin=3Plots.mm)" ] }, { "cell_type": "code", "execution_count": 24, "id": "faf87209-d0b0-493f-944a-344c122cf2e5", "metadata": {}, "outputs": [], "source": [ "# Running estimate of E[x x'] from the first n draws. The target density ∝ exp(-x'Ax/2) has mean 0,\n", "# so these estimates should approach the covariance inv(A) (dashed lines) as n grows.\n", "f(n) = mean(x -> x*x', @view sample[1:n])\n", "\n", "n = 1:1000\n", "S = f.(n)\n", "S11 = (S -> S[1,1]).(S)\n", "S22 = (S -> S[2,2]).(S)\n", "S12 = (S -> S[1,2]).(S)\n", "\n", "ymin = min(-1.5, minimum(S11), minimum(S22), minimum(S12))\n", "ymax = max(2.5, maximum(S11), maximum(S22), maximum(S12))\n", "\n", "plot(ylim = (ymin, ymax))\n", "plot!(S11; label=\"s11\", c=1)\n", "hline!([inv(A)[1,1]]; label=\"\", c=1, ls=:dash)\n", "plot!(S22; label=\"s22\", c=2)\n", "hline!([inv(A)[2,2]]; label=\"\", c=2, ls=:dash)\n", "plot!(S12; label=\"s12\", c=3)\n", "hline!([inv(A)[1,2]]; label=\"\", c=3, ls=:dash)" ] }, { "cell_type": "code", "execution_count": null, "id": "56a886c9-dcca-46e1-865a-8f8a8a117e7a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "encoding": "# -*- coding: utf-8 -*-", "formats": "ipynb,jl:hydrogen" }, "kernelspec": { "display_name": "Julia 1.8.0", "language": "julia", "name": "julia-1.8" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.8.0" } }, "nbformat": 4, "nbformat_minor": 5 }