{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "normalgamma-intro",
   "metadata": {},
   "source": [
    "# Type-unstable vs. type-stable `NormalGamma` sampling\n",
    "\n",
    "`O.NormalGamma` stores its parameters in untyped fields, so `@code_warntype` infers `rand` as `Any` and every draw allocates.\n",
    "`Q.NormalGamma` gives each field a type parameter; its `rand` is type-stable and, after compilation, allocation-free and several times faster.\n",
    "The final plot checks that both samplers match the Student-t marginal `μ + TDist(2a)/√(λab)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "866ff385-2e86-405f-b74d-0602389f2807",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Main.O"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "module O\n",
    "\n",
    "using Distributions\n",
    "using Random\n",
    "\n",
    "struct NormalGamma <: ContinuousUnivariateDistribution  # untyped (Any) fields, kept as in the original to show the type instability\n",
    "    μ\n",
    "    λ\n",
    "    a\n",
    "    b\n",
    "end\n",
    "\n",
    "function Base.rand(rng::AbstractRNG, d::NormalGamma)\n",
    "    (; μ, λ, a, b) = d\n",
    "    rand(rng, Normal(μ, inv(sqrt(λ * rand(Gamma(a, b))))))  # NOTE: the inner Gamma draw ignores `rng` (fixed in module Q)\n",
    "end\n",
    "\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1cd9bf6e-9089-48c8-9815-49d1f88d4fc8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0.713671 seconds (8.77 M allocations: 194.933 MiB, 15.64% gc time, 31.25% compilation time)\n",
      " 0.401586 seconds (8.00 M allocations: 152.580 MiB, 2.92% gc time)\n",
      " 0.413190 seconds (8.00 M allocations: 152.580 MiB, 3.15% gc time)\n"
     ]
    }
   ],
   "source": [
    "samples = Vector{Float64}(undef, 10^6)\n",
    "# three timings: the first includes compilation; every run allocates because O's rand is type-unstable\n",
    "for _ in 1:3\n",
    "    @time O.rand!(O.NormalGamma(1, 2, 2.1, 0.5), samples)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "41556b14-bd5b-46cd-99b2-2478f77ed40b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MethodInstance for rand(::Random.TaskLocalRNG, ::Main.O.NormalGamma)\n",
      " from rand(rng::Random.AbstractRNG, d::Main.O.NormalGamma) in Main.O at In[1]:13\n",
      "Arguments\n",
      " #self#\u001b[36m::Core.Const(rand)\u001b[39m\n",
      " rng\u001b[36m::Core.Const(Random.TaskLocalRNG())\u001b[39m\n",
      " d\u001b[36m::Main.O.NormalGamma\u001b[39m\n",
      "Locals\n",
      " b\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      " a\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      " λ\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      " μ\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "Body\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m1 ─\u001b[39m (μ = Base.getproperty(d, :μ))\n",
      "\u001b[90m│ \u001b[39m (λ = Base.getproperty(d, :λ))\n",
      "\u001b[90m│ \u001b[39m (a = Base.getproperty(d, :a))\n",
      "\u001b[90m│ \u001b[39m (b = Base.getproperty(d, :b))\n",
      "\u001b[90m│ \u001b[39m %5 = μ\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %6 = λ\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %7 = Main.O.Gamma(a, b)\u001b[91m\u001b[1m::Distributions.Gamma\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %8 = Main.O.rand(%7)\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %9 = (%6 * %8)\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %10 = Main.O.sqrt(%9)\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %11 = Main.O.inv(%10)\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %12 = Main.O.Normal(%5, %11)\u001b[91m\u001b[1m::Distributions.Normal\u001b[22m\u001b[39m\n",
      "\u001b[90m│ \u001b[39m %13 = Main.O.rand(rng, %12)\u001b[91m\u001b[1m::Any\u001b[22m\u001b[39m\n",
      "\u001b[90m└──\u001b[39m return %13\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@code_warntype rand(O.Random.default_rng(), O.NormalGamma(1, 2, 3, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f993d6ad-60d5-431f-880e-6186b0a44fc3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Main.Q"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "module Q\n",
    "\n",
    "using Distributions\n",
    "using Random\n",
    "\n",
    "\"\"\"\n",
    "    NormalGamma(μ, λ, a, b)\n",
    "\n",
    "Sampler for `x` where the precision `τ ~ Gamma(a, b)` (Distributions' shape/scale\n",
    "parameterization) and `x | τ ~ Normal(μ, 1/√(λτ))`.\n",
    "\n",
    "Each field has its own type parameter, so `rand` is type-stable\n",
    "(unlike the untyped fields of `O.NormalGamma`).\n",
    "\"\"\"\n",
    "struct NormalGamma{Tμ, Tλ, Ta, Tb} <: ContinuousUnivariateDistribution\n",
    "    μ::Tμ\n",
    "    λ::Tλ\n",
    "    a::Ta\n",
    "    b::Tb\n",
    "end\n",
    "\n",
    "function Base.rand(rng::AbstractRNG, d::NormalGamma)\n",
    "    (; μ, λ, a, b) = d\n",
    "    # draw the precision from the caller's `rng` (not the global RNG) so seeded calls are reproducible\n",
    "    τ = rand(rng, Gamma(a, b))\n",
    "    rand(rng, Normal(μ, inv(sqrt(λ * τ))))\n",
    "end\n",
    "\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fce4f1db-6d2d-49c1-8974-121225506e8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 0.404534 seconds (8.02 M allocations: 153.928 MiB, 7.86% gc time, 2.03% compilation time)\n",
      " 0.381746 seconds (8.00 M allocations: 152.580 MiB, 4.92% gc time)\n",
      " 0.378809 seconds (8.00 M allocations: 152.580 MiB, 3.18% gc time)\n",
      " 0.074955 seconds (37.05 k allocations: 2.076 MiB, 21.77% compilation time)\n",
      " 0.059850 seconds\n",
      " 0.059976 seconds\n",
      " 0.091595 seconds (128.33 k allocations: 6.742 MiB, 8.84% gc time, 36.64% compilation time)\n",
      " 0.058394 seconds\n",
      " 0.057663 seconds\n"
     ]
    }
   ],
   "source": [
    "using Random\n",
    "using Distributions\n",
    "using StatsPlots\n",
    "default(fmt = :png)\n",
    "Random.seed!(2022)  # make the sampled histograms reproducible\n",
    "\n",
    "# keep only the samples strictly inside the open interval (first(xlim), last(xlim))\n",
    "res(xlim, A) = A[first(xlim) .< A .< last(xlim)]\n",
    "\n",
    "μ, λ, a, b = 1, 2, 1.5, 4\n",
    "ng_org = O.NormalGamma(μ, λ, a, b)\n",
    "ng_rev = Q.NormalGamma(μ, λ, a, b)\n",
    "# reference distributions: the Student-t marginal of x and a normal with the same location/scale\n",
    "normal = Normal(μ, 1/√(λ*a*b))\n",
    "tdist = μ + TDist(2a)/√(λ*a*b)\n",
    "\n",
    "A = Vector{Float64}(undef, 10^6)\n",
    "B = similar(A)\n",
    "C = similar(A)\n",
    "# time each sampler three times (the first call includes compilation); rand! fills the buffer in place\n",
    "for (d, X) in ((ng_org, A), (ng_rev, B), (tdist, C)), _ in 1:3\n",
    "    @time rand!(d, X)\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7ea7e28e-e388-40fd-a4fb-7f601c1f084b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": ""
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xlim = quantile.(Ref(A), (0.001, 0.999))  # central 99.8% of the original samples\n",
    "bin = round(Int, (last(xlim) - first(xlim)) * 50)  # about 50 bins per unit length\n",
    "stephist(res(xlim, A); norm=true, label=\"original NormalGamma\", bin)\n",
    "stephist!(res(xlim, B); norm=true, label=\"revised NormalGamma\", bin, ls=:dash)\n",
    "stephist!(res(xlim, C); norm=true, label=\"μ+TDist(2a)/√(λ*a*b)\", bin, ls=:dashdot)\n",
    "plot!(normal, xlim...; label=\"normal approx.\", c=:red, ls=:dot)\n",
    "plot!(; xlim, xlabel=\"x\", ylabel=\"density\", title=\"NormalGamma(μ=$μ, λ=$λ, a=$a, b=$b)\")"
   ]
  }
 ],
 "metadata": {
  "@webio": {
   "lastCommId": null,
   "lastKernelId": null
  },
  "jupytext": {
   "formats": "ipynb,md"
  },
  "kernelspec": {
   "display_name": "Julia 1.7.2",
   "language": "julia",
   "name": "julia-1.7"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.7.2"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}