{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "c630076d-430b-493d-ac80-27b2f408c143",
"metadata": {},
"outputs": [],
"source": [
"using Distributions\n",
"using StatsPlots\n",
"using Memoization\n",
"\n",
"# Default figure size for plots in this notebook (individual plots may override it).\n",
"default(size=(500, 400))\n",
"# NOTE(review): presumably a warm-up plot so later cells don't pay the first-plot\n",
"# compilation cost; the trailing `;` suppresses its output.\n",
"plot(sin);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a3372dea-760f-4b07-bd13-da598d0597ee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"abs_error (generic function with 1 method)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expectation E[f(K)] for K ~ Binomial(n, p), summed exactly over the finite support.\n",
"# Deliberately not memoized: in this notebook every caller passes a freshly built\n",
"# closure for each (n, p, a, b) combination, so a cache keyed on (f, n, p) would gain\n",
"# one entry per call without ever being hit. The outer sweep functions carry the\n",
"# memoization instead.\n",
"function E_bin(f, n, p)\n",
"    bin = Binomial(n, p)\n",
"    sum(f(k) * pdf(bin, k) for k in support(bin))\n",
"end\n",
"\n",
"# Expectation E[g(N)] where N is the total number of Bernoulli(p) trials needed to\n",
"# observe k successes: NegativeBinomial(k, p) counts failures, and shifting it by k\n",
"# gives N = k, k+1, ... .\n",
"# The infinite sum is truncated at round(mean + 5 std); the neglected upper tail is\n",
"# assumed small enough for the bounded losses used in this notebook.\n",
"# Deliberately not memoized: every caller passes a freshly built closure for each\n",
"# (k, p, a, b) combination, so a cache keyed on (g, k, p) would only grow, never hit.\n",
"function E_negbin(g, k, p)\n",
"    negbin = LocationScale(k, 1, NegativeBinomial(k, p))\n",
"    m, s = mean(negbin), std(negbin)\n",
"    nmax = round(Int, m + 5s)\n",
"    sum(g(n) * pdf(negbin, n) for n in k:nmax)\n",
"end\n",
"\n",
"# KL divergence KL(Bernoulli(p) || Bernoulli(p̂)) from the true p to the smoothed\n",
"# estimate p̂ = (k + a)/(n + a + b), i.e. the Beta(a, b) posterior mean.\n",
"# Uses the convention 0 log 0 = 0, so the endpoints p = 0 and p = 1 stay finite even\n",
"# when p̂ itself reaches 0 or 1 (possible with a = 0 or b = 0); a naive\n",
"# `p * log(p̂)` would give 0 * -Inf = NaN there.\n",
"function kl(n, k, p; a = 0.5, b = 0.5)\n",
"    p̂ = (k + a)/(n + a + b)\n",
"    # x * log(x / y), defined as 0 when x == 0; +Inf when x > 0 and y == 0.\n",
"    xlogratio(x, y) = iszero(x) ? zero(float(x)) : x * (log(x) - log(y))\n",
"    xlogratio(p, p̂) + xlogratio(1 - p, 1 - p̂)\n",
"end\n",
"\n",
"# Squared error (p̂ - p)^2 of the smoothed estimate p̂ = (k + a)/(n + a + b).\n",
"squared_error(n, k, p; a = 0.5, b = 0.5) = abs2((k + a) / (n + a + b) - p)\n",
"\n",
"# Absolute error |p̂ - p| of the smoothed estimate p̂ = (k + a)/(n + a + b).\n",
"function abs_error(n, k, p; a = 0.5, b = 0.5)\n",
"    shrunk = (k + a) / (n + a + b)\n",
"    return abs(shrunk - p)\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7319b0a0-c626-4f05-aff0-ca7489bb8f71",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 3.539977 seconds (5.40 M allocations: 254.922 MiB, 6.02% gc time, 9.23% compilation time)\n",
"(val, idx) = findmin(z) = (0.047898347587918166, CartesianIndex(51, 51))\n",
"(a[idx[1]], b[idx[2]]) = (0.51, 0.51)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimax search for the pseudo-counts (a, b) of p̂ = (k + a)/(n + a + b):\n",
"# binomial sampling with n = 10 trials, KL loss.\n",
"n = 10\n",
"# Worst-case risk of a given (a, b): max over the p grid of E[kl] with K ~ Binomial(n, p).\n",
"# Memoized on (n, a, b), so re-running the sweep on the same grid reuses cached values.\n",
"@memoize f(n, a, b) = maximum(p -> E_bin(k -> kl(n, k, p; a, b), n, p), 0.0:0.01:1)\n",
"a = b = 0.01:0.01:1\n",
"# Broadcast a (rows) against b' (columns): z[i, j] = f(n, a[i], b[j]).\n",
"@time z = f.(n, a, b')\n",
"# The minimax choice is the grid point with the smallest worst-case risk.\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "53734fc1-1b45-496c-94a5-21d2d7a4072d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 16.230313 seconds (8.41 M allocations: 393.980 MiB, 2.81% gc time, 1.49% compilation time)\n",
"(val, idx) = findmin(z) = (0.09218128401137077, CartesianIndex(69, 66))\n",
"(a[idx[1]], b[idx[2]]) = (0.336, 0.43)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same minimax search under negative-binomial sampling: trials continue until\n",
"# k = 3 successes, so the trial count n is random. KL loss.\n",
"k = 3\n",
"# Worst-case risk over p; the p grid stays strictly inside (0, 1) since\n",
"# NegativeBinomial(k, p) needs p > 0. Memoized on (k, a, b).\n",
"@memoize g(k, a, b) = maximum(p -> E_negbin(n -> kl(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n",
"# NOTE(review): these ranges look hand-narrowed around the minimum; widen them if the\n",
"# reported optimum lands on a grid edge.\n",
"a = 0.2:0.002:0.5\n",
"b = 0.3:0.002:0.55\n",
"@time z = g.(k, a, b')\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3a4cdb61-1cf5-4e1e-b443-9a4213b0dff5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 9.104402 seconds (10.01 M allocations: 488.196 MiB, 8.56% gc time, 2.48% compilation time)\n",
"(val, idx) = findmin(z) = (0.014435380308755486, CartesianIndex(109, 109))\n",
"(a[idx[1]], b[idx[2]]) = (1.58, 1.58)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimax search for (a, b) under binomial sampling (n = 10 trials), squared-error loss.\n",
"n = 10\n",
"# Worst-case expected squared error over the p grid; memoized on (n, a, b).\n",
"@memoize F(n, a, b) = maximum(p -> E_bin(k -> squared_error(n, k, p; a, b), n, p), 0:0.01:1)\n",
"a = b = 0.5:0.01:2\n",
"# z[i, j] = F(n, a[i], b[j]); findmin gives the minimax grid point.\n",
"@time z = F.(n, a, b')\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "348db599-48d3-4009-8fe3-2b59142bd824",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 12.786361 seconds (8.01 M allocations: 380.006 MiB, 1.69% compilation time)\n",
"(val, idx) = findmin(z) = (0.029013989479268445, CartesianIndex(61, 78))\n",
"(a[idx[1]], b[idx[2]]) = (0.13, 0.654)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimax search for (a, b) under negative-binomial sampling (stop at k = 3 successes),\n",
"# squared-error loss.\n",
"k = 3\n",
"# Worst-case expected squared error over p in (0, 1); memoized on (k, a, b).\n",
"@memoize G(k, a, b) = maximum(p -> E_negbin(n -> squared_error(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n",
"# NOTE(review): ranges look hand-narrowed around the minimum; widen them if the\n",
"# reported optimum lands on a grid edge.\n",
"a = 0.01:0.002:0.25\n",
"b = 0.5:0.002:0.8\n",
"@time z = G.(k, a, b')\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "571a86f1-0247-42a3-a864-2b9930673bf5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 11.382526 seconds (10.01 M allocations: 520.117 MiB, 12.49% gc time, 1.92% compilation time)\n",
"(val, idx) = findmin(z) = (0.10265274157994743, CartesianIndex(80, 80))\n",
"(a[idx[1]], b[idx[2]]) = (1.29, 1.29)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimax search for (a, b) under binomial sampling (n = 10 trials), absolute-error loss.\n",
"n = 10\n",
"# Worst-case expected absolute error over the p grid; memoized on (n, a, b).\n",
"@memoize f1(n, a, b) = maximum(p -> E_bin(k -> abs_error(n, k, p; a, b), n, p), 0:0.01:1)\n",
"a = b = 0.5:0.01:2\n",
"# z[i, j] = f1(n, a[i], b[j]); findmin gives the minimax grid point.\n",
"@time z = f1.(n, a, b')\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "632e8c56-e74d-4a2c-ae91-6b790e77123c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 19.878701 seconds (7.22 M allocations: 413.755 MiB, 10.58% gc time, 1.09% compilation time)\n",
"(val, idx) = findmin(z) = (0.1411059886538703, CartesianIndex(43, 122))\n",
"(a[idx[1]], b[idx[2]]) = (0.62, 0.621)\n"
]
},
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Minimax search for (a, b) under negative-binomial sampling (stop at k = 3 successes),\n",
"# absolute-error loss.\n",
"k = 3\n",
"# Worst-case expected absolute error over p in (0, 1); memoized on (k, a, b).\n",
"@memoize g1(k, a, b) = maximum(p -> E_negbin(n -> abs_error(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n",
"# NOTE(review): ranges look hand-narrowed around the minimum (note the finer b step);\n",
"# widen them if the reported optimum lands on a grid edge.\n",
"a = 0.2:0.01:1\n",
"b = 0.5:0.001:0.7\n",
"@time z = g1.(k, a, b')\n",
"@show val, idx = findmin(z)\n",
"@show a[idx[1]], b[idx[2]]\n",
"# Plots expects the matrix indexed [y, x], hence the transpose z'.\n",
"P = contourf(a, b, z'; label=\"\")\n",
"Q = surface(a, b, z'; label=\"\")\n",
"plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f96fcc52-2170-43bf-8c46-7fa1b3b8f1af",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,md"
},
"kernelspec": {
"display_name": "Julia 1.7.1",
"language": "julia",
"name": "julia-1.7"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}