{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "c630076d-430b-493d-ac80-27b2f408c143", "metadata": {}, "outputs": [], "source": [ "using Distributions\n", "using StatsPlots\n", "using Memoization\n", "\n", "default(size=(500, 400))\n", "plot(sin);" ] }, { "cell_type": "code", "execution_count": 2, "id": "a3372dea-760f-4b07-bd13-da598d0597ee", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "abs_error (generic function with 1 method)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@memoize function E_bin(f, n, p)\n", " bin = Binomial(n, p)\n", " sum(f(k) * pdf(bin, k) for k in support(bin))\n", "end\n", "\n", "@memoize function E_negbin(g, k, p)\n", " negbin = LocationScale(k, 1, NegativeBinomial(k, p))\n", " m, s = mean(negbin), std(negbin)\n", " nmax = round(Int, m + 5s)\n", " sum(g(n) * pdf(negbin, n) for n in k:nmax)\n", "end\n", "\n", "function kl(n, k, p; a = 0.5, b = 0.5)\n", " p̂ = (k + a)/(n + a + b)\n", " -(p*log(p̂) + (1 - p)*log(1 - p̂)) - entropy(Bernoulli(p)) \n", "end\n", "\n", "function squared_error(n, k, p; a = 0.5, b = 0.5)\n", " p̂ = (k + a)/(n + a + b)\n", " (p̂ - p)^2\n", "end\n", "\n", "function abs_error(n, k, p; a = 0.5, b = 0.5)\n", " p̂ = (k + a)/(n + a + b)\n", " abs(p̂ - p)\n", "end" ] }, { "cell_type": "code", "execution_count": 3, "id": "7319b0a0-c626-4f05-aff0-ca7489bb8f71", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 3.539977 seconds (5.40 M allocations: 254.922 MiB, 6.02% gc time, 9.23% compilation time)\n", "(val, idx) = findmin(z) = (0.047898347587918166, CartesianIndex(51, 51))\n", "(a[idx[1]], b[idx[2]]) = (0.51, 0.51)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = 10\n", "@memoize f(n, a, b) = maximum(p -> E_bin(k -> kl(n, k, p; a, b), n, p), 0.0:0.01:1)\n", "a = b = 0.01:0.01:1\n", "@time z = f.(n, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": 4, "id": "53734fc1-1b45-496c-94a5-21d2d7a4072d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 16.230313 seconds (8.41 M allocations: 393.980 MiB, 2.81% gc time, 1.49% compilation time)\n", "(val, idx) = findmin(z) = (0.09218128401137077, CartesianIndex(69, 66))\n", "(a[idx[1]], b[idx[2]]) = (0.336, 0.43)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k = 3\n", "@memoize g(k, a, b) = maximum(p -> E_negbin(n -> kl(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n", "a = 0.2:0.002:0.5\n", "b = 0.3:0.002:0.55\n", "@time z = g.(k, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": 5, "id": "3a4cdb61-1cf5-4e1e-b443-9a4213b0dff5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 9.104402 seconds (10.01 M allocations: 488.196 MiB, 8.56% gc time, 2.48% compilation time)\n", "(val, idx) = findmin(z) = (0.014435380308755486, CartesianIndex(109, 109))\n", "(a[idx[1]], b[idx[2]]) = (1.58, 1.58)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = 10\n", "@memoize F(n, a, b) = maximum(p -> E_bin(k -> squared_error(n, k, p; a, b), n, p), 0:0.01:1)\n", "a = b = 0.5:0.01:2\n", "@time z = F.(n, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": 6, "id": "348db599-48d3-4009-8fe3-2b59142bd824", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 12.786361 seconds (8.01 M allocations: 380.006 MiB, 1.69% compilation time)\n", "(val, idx) = findmin(z) = (0.029013989479268445, CartesianIndex(61, 78))\n", "(a[idx[1]], b[idx[2]]) = (0.13, 0.654)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k = 3\n", "@memoize G(k, a, b) = maximum(p -> E_negbin(n -> squared_error(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n", "a = 0.01:0.002:0.25\n", "b = 0.5:0.002:0.8\n", "@time z = G.(k, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": 7, "id": "571a86f1-0247-42a3-a864-2b9930673bf5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 11.382526 seconds (10.01 M allocations: 520.117 MiB, 12.49% gc time, 1.92% compilation time)\n", "(val, idx) = findmin(z) = (0.10265274157994743, CartesianIndex(80, 80))\n", "(a[idx[1]], b[idx[2]]) = (1.29, 1.29)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n = 10\n", "@memoize f1(n, a, b) = maximum(p -> E_bin(k -> abs_error(n, k, p; a, b), n, p), 0:0.01:1)\n", "a = b = 0.5:0.01:2\n", "@time z = f1.(n, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": 8, "id": "632e8c56-e74d-4a2c-ae91-6b790e77123c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 19.878701 seconds (7.22 M allocations: 413.755 MiB, 10.58% gc time, 1.09% compilation time)\n", "(val, idx) = findmin(z) = (0.1411059886538703, CartesianIndex(43, 122))\n", "(a[idx[1]], b[idx[2]]) = (0.62, 0.621)\n" ] }, { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "k = 3\n", "@memoize g1(k, a, b) = maximum(p -> E_negbin(n -> abs_error(n, k, p; a, b), k, p), 0.01:0.01:0.99)\n", "a = 0.2:0.01:1\n", "b = 0.5:0.001:0.7\n", "@time z = g1.(k, a, b')\n", "@show val, idx = findmin(z)\n", "@show a[idx[1]], b[idx[2]]\n", "P = contourf(a, b, z'; label=\"\")\n", "Q = surface(a, b, z'; label=\"\")\n", "plot(P, Q; size=(800, 400), colorbar=false, camera=(60, 60))" ] }, { "cell_type": "code", "execution_count": null, "id": "f96fcc52-2170-43bf-8c46-7fa1b3b8f1af", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,md" }, "kernelspec": { "display_name": "Julia 1.7.1", "language": "julia", "name": "julia-1.7" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.7.1" } }, "nbformat": 4, "nbformat_minor": 5 }