{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing a Poisson mixture model in Julia (yet another take)\n", "\n", "~~I hear that if you post code online, Julia experts will show you faster code, so~~ I implemented a Poisson mixture model in Julia.\n", "\n", "This follows Chapter 4 of Suyama's 「ベイズ推論による機械学習入門」 (*An Introduction to Machine Learning by Bayesian Inference*).\n", "\n", "I was studying variational inference, so this uses variational Bayes rather than Gibbs sampling. Honestly, I also just prefer deterministic methods.\n", "\n", "Parts of the code are based on Suyama's [support page](https://github.com/sammy-suyama/BayesBook/blob/master/src/PoissonMixtureModel.jl).\n", "\n", "I don't know much about practical applications of Poisson mixture models, but since the Poisson distribution models non-negative counts, one could use it, for example, to group students by test score." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Plot: gamma densities of λ for the three clusters (series y1, y2, y3); x axis 0 to 40, y axis 0.00 to 0.12]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using Distributions, StatsBase, StatsFuns, Plots, Random, SpecialFunctions\n", "\n", "N = 10000 # number of samples\n", "Sn = rand(Categorical([0.3, 0.3, 0.4]), N) # true latent cluster of each sample\n", "\n", "# gamma hyperparameters of each cluster's λ: column 1 = shape, column 2 = scale (Distributions.jl convention)\n", "Khyper = [2 3; 8 4; 15 1]\n", "xs = 0:0.01:40\n", "plot(xs, pdf.(Gamma(Khyper[1,1], Khyper[1,2]), xs))\n", "plot!(xs, pdf.(Gamma(Khyper[2,1], Khyper[2,2]), xs))\n", "plot!(xs, pdf.(Gamma(Khyper[3,1], Khyper[3,2]), xs))" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1×10000 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:\n", " 2 37 5 37 
1 15 16 29 3 16 … 16 13 22 8 3 16 1 38 4 37" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Random.seed!(204)\n", "# Poisson rate of each cluster, drawn from its gamma prior\n", "λ = [rand(Gamma(Khyper[i,1], Khyper[i,2])) for i in 1:3]\n", "# one Poisson draw per sample, using the rate of its true cluster\n", "X = [rand(Poisson(λ[Sn[i]]), 1)[1] for i in 1:length(Sn)]\n", "X'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Plot: histogram of X with 100 bins (series y1); x axis 0 to 60, counts 0 to 600]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "plot(fit(Histogram, X, nbins = 100))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the data we use: one-dimensional non-negative integers generated from a Poisson mixture. The mixing proportions are set by `Categorical([0.3, 0.3, 0.4])`, and each cluster's Poisson parameter `λ` is itself drawn from a gamma distribution `Gamma()`.\n", "\n", "Next, we write the function that runs variational inference." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "vb (generic function with 5 methods)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [
"struct MixturePoissonVB\n",
"    η      # N×K matrix of responsibilities E[s_nk]\n",
"    λ      # posterior means E[λ_k] of the Poisson rates\n",
"    shape  # posterior gamma shape parameters\n",
"    rate   # posterior gamma rate parameters\n",
"    α      # posterior Dirichlet parameters\n",
"end\n",
"\n",
"# Variational Bayes for a K-component Poisson mixture with priors\n",
"# λ_k ~ Gamma(shape = a_k, rate = b_k) and π ~ Dirichlet(α)\n",
"function vb(X, nK, MAXITER = 10,\n",
"            a = sample([1:1:40;], nK), # gamma prior shape\n",
"            b = sample([1:1:40;], nK), # gamma prior rate\n",
"            α = sample([1:1:40;], nK)  # Dirichlet prior parameters\n",
"           )\n",
"    N = length(X)\n",
"    # initialize q(S) with a random hard assignment of samples to clusters\n",
"    Sn = rand(Categorical(ones(nK)/nK), N)\n",
"    S = zeros(Int64, N, nK)\n",
"    println(\"Initialize latent matrix\")\n",
"    for k in 1:nK\n",
"        S[findall(x->x==k, Sn), k] .= 1\n",
"    end\n",
"    SnX = S' * X             # sum of the x_n assigned to each cluster\n",
"    sumS = sum(S', dims = 2) # number of samples in each cluster\n",
"    println(\"Initialize parameter vectors\")\n",
"    # preallocate posterior parameters and expectations\n",
"    a1 = zeros(nK); b1 = zeros(nK); α1 = zeros(nK)\n",
"    lnλ1 = zeros(nK); lnπ1 = zeros(nK); λ1 = zeros(nK)\n",
"    η1 = zeros(N, nK)\n",
"    for k in 1:nK\n",
"        a1[k] = SnX[k] + a[k]\n",
"        b1[k] = sumS[k] + b[k]\n",
"        α1[k] = sumS[k] + α[k]\n",
"    end\n",
"    λ1 .= a1 ./ b1                           # E[λ_k]\n",
"    lnλ1 .= digamma.(a1) .- log.(b1)         # E[ln λ_k]\n",
"    lnπ1 .= digamma.(α1) .- digamma(sum(α1)) # E[ln π_k], once all α1 are set\n",
"    # VB iterations\n",
"    ITER = 0\n",
"    while ITER < (MAXITER + 1)\n",
"        print(\"Iteration\", ITER, \"... 
λ is \")\n",
"        println(λ1)\n",
"        ITER += 1\n",
"        # update q(S): η_nk ∝ exp(x_n E[ln λ_k] - E[λ_k] + E[ln π_k]),\n",
"        # normalized in log space with logsumexp (StatsFuns) to avoid underflow\n",
"        for i in 1:N\n",
"            logη = X[i] .* lnλ1 .- λ1 .+ lnπ1\n",
"            η1[i,:] = exp.(logη .- logsumexp(logη))\n",
"        end\n",
"        # update q(λ) and q(π) from the expected sufficient statistics\n",
"        ηX = η1' * X\n",
"        sumη = sum(η1', dims = 2) # expected number of samples in each cluster\n",
"        for k in 1:nK\n",
"            a1[k] = ηX[k] + a[k]\n",
"            b1[k] = sumη[k] + b[k]\n",
"            α1[k] = sumη[k] + α[k]\n",
"        end\n",
"        λ1 .= a1 ./ b1                           # E[λ_k]\n",
"        lnλ1 .= digamma.(a1) .- log.(b1)         # E[ln λ_k]\n",
"        lnπ1 .= digamma.(α1) .- digamma(sum(α1)) # E[ln π_k], after all α1 are updated\n",
"    end # of while\n",
"    MixturePoissonVB(η1, λ1, a1, b1, α1)\n",
"end # of function" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function is still rough: it does not yet use the ELBO to check for convergence and simply runs a fixed number of iterations. The quantities to estimate are the latent variables `Sn` and the parameters of the Poisson, gamma, and Dirichlet distributions (in `vb`: the responsibilities `η`, the rates `λ` with their gamma posterior parameters, and the Dirichlet parameters `α`)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3-element Array{Float64,1}:\n", " 3.1272608113496094\n", " 41.148909106061154 \n", " 16.084052566969888 " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "λ # true values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we set the hyperparameters of the gamma and Dirichlet priors at random (integers from 1 to 40). With this much data, almost any prior should yield estimates that fit the data." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialize latent matrix\n", "Initialize parameter vectors\n", "Itaration0... λ is [19.4747, 20.128, 20.019]\n", "Itaration1... λ is [16.5266, 24.5148, 22.9295]\n", "Itaration2... λ is [9.06952, 37.9468, 29.3673]\n", "Itaration3... λ is [8.63289, 41.2114, 23.8154]\n", "Itaration4... λ is [7.41936, 41.6757, 20.1111]\n", "Itaration5... λ is [5.69185, 41.4263, 18.1008]\n", "Itaration6... λ is [4.1456, 41.2581, 16.9235]\n", "Itaration7... λ is [3.41642, 41.1582, 16.328]\n", "Itaration8... λ is [3.22317, 41.1039, 16.098]\n", "Itaration9... λ is [3.17612, 41.0804, 16.023]\n", "Itaration10... 
λ is [3.16385, 41.0717, 16.0007]\n", "Itaration11... λ is [3.16052, 41.0688, 15.9942]\n", "Itaration12... λ is [3.15959, 41.0679, 15.9924]\n", "Itaration13... λ is [3.15934, 41.0677, 15.9918]\n", "Itaration14... λ is [3.15926, 41.0676, 15.9917]\n", "Itaration15... λ is [3.15924, 41.0676, 15.9916]\n", "Itaration16... λ is [3.15924, 41.0676, 15.9916]\n", "Itaration17... λ is [3.15924, 41.0676, 15.9916]\n", "Itaration18... λ is [3.15924, 41.0676, 15.9916]\n", "Itaration19... λ is [3.15924, 41.0676, 15.9916]\n", "Itaration20... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration21... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration22... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration23... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration24... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration25... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration26... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration27... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration28... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration29... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration30... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration31... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration32... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration33... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration34... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration35... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration36... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration37... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration38... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration39... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration40... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration41... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration42... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration43... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration44... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration45... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration46... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration47... 
λ is [3.15923, 41.0676, 15.9916]\n", "Itaration48... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration49... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration50... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration51... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration52... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration53... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration54... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration55... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration56... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration57... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration58... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration59... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration60... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration61... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration62... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration63... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration64... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration65... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration66... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration67... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration68... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration69... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration70... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration71... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration72... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration73... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration74... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration75... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration76... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration77... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration78... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration79... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration80... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration81... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration82... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration83... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration84... 
λ is [3.15923, 41.0676, 15.9916]\n", "Itaration85... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration86... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration87... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration88... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration89... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration90... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration91... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration92... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration93... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration94... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration95... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration96... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration97... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration98... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration99... λ is [3.15923, 41.0676, 15.9916]\n", "Itaration100... λ is [3.15923, 41.0676, 15.9916]\n", " 3.155149 seconds (11.59 M allocations: 981.732 MiB, 11.90% gc time)\n" ] }, { "data": { "text/plain": [ "MixturePoissonVB([0.99991 5.98437e-15 9.00214e-5; 1.71818e-25 0.999931 6.93452e-5; … ; 0.997698 1.0091e-12 0.00230168; 1.71818e-25 0.999931 6.93452e-5], [3.15923, 41.0676, 15.9916], [9470.59, 1.26779e5, 63329.7], [2997.75, 3087.08, 3960.18], [3009.75, 3098.08, 3956.18])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@time res1 = vb(X, 3, 100)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's set the priors to match the generative model. The generative model did not use a Dirichlet distribution, so I simply set its parameters to 1. Note that `Gamma(shape, scale)` in Distributions.jl takes a scale, whereas `vb` treats `b` as a rate, so passing `Khyper[:,2]` as `b` does not exactly reproduce the generating prior; with N = 10,000 this has little effect on the estimates." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Initialize latent matrix\n", "Initialize parameter vectors\n", "Itaration0... λ is [20.1854, 19.6025, 20.0288]\n", "Itaration1... λ is [22.1964, 15.4098, 20.2071]\n", "Itaration2... λ is [33.8609, 7.35802, 20.6835]\n", "Itaration3... λ is [39.4515, 6.06035, 17.7006]\n", "Itaration4... λ is [41.1095, 4.39839, 16.9881]\n", "Itaration5... 
λ is [41.2965, 3.49929, 16.4746]\n", "Itaration6... λ is [41.2815, 3.24416, 16.2233]\n", "Itaration7... λ is [41.257, 3.1829, 16.1316]\n", "Itaration8... λ is [41.2452, 3.1671, 16.1026]\n", "Itaration9... λ is [41.241, 3.16281, 16.0939]\n", "Itaration10... λ is [41.2396, 3.16163, 16.0914]\n", "Itaration11... λ is [41.2392, 3.16129, 16.0907]\n", "Itaration12... λ is [41.2391, 3.1612, 16.0905]\n", "Itaration13... λ is [41.2391, 3.16117, 16.0904]\n", "Itaration14... λ is [41.239, 3.16117, 16.0904]\n", "Itaration15... λ is [41.239, 3.16116, 16.0904]\n", "Itaration16... λ is [41.239, 3.16116, 16.0904]\n", "Itaration17... λ is [41.239, 3.16116, 16.0904]\n", "Itaration18... λ is [41.239, 3.16116, 16.0904]\n", "Itaration19... λ is [41.239, 3.16116, 16.0904]\n", "Itaration20... λ is [41.239, 3.16116, 16.0904]\n", "Itaration21... λ is [41.239, 3.16116, 16.0904]\n", "Itaration22... λ is [41.239, 3.16116, 16.0904]\n", "Itaration23... λ is [41.239, 3.16116, 16.0904]\n", "Itaration24... λ is [41.239, 3.16116, 16.0904]\n", "Itaration25... λ is [41.239, 3.16116, 16.0904]\n", "Itaration26... λ is [41.239, 3.16116, 16.0904]\n", "Itaration27... λ is [41.239, 3.16116, 16.0904]\n", "Itaration28... λ is [41.239, 3.16116, 16.0904]\n", "Itaration29... λ is [41.239, 3.16116, 16.0904]\n", "Itaration30... λ is [41.239, 3.16116, 16.0904]\n", "Itaration31... λ is [41.239, 3.16116, 16.0904]\n", "Itaration32... λ is [41.239, 3.16116, 16.0904]\n", "Itaration33... λ is [41.239, 3.16116, 16.0904]\n", "Itaration34... λ is [41.239, 3.16116, 16.0904]\n", "Itaration35... λ is [41.239, 3.16116, 16.0904]\n", "Itaration36... λ is [41.239, 3.16116, 16.0904]\n", "Itaration37... λ is [41.239, 3.16116, 16.0904]\n", "Itaration38... λ is [41.239, 3.16116, 16.0904]\n", "Itaration39... λ is [41.239, 3.16116, 16.0904]\n", "Itaration40... λ is [41.239, 3.16116, 16.0904]\n", "Itaration41... λ is [41.239, 3.16116, 16.0904]\n", "Itaration42... λ is [41.239, 3.16116, 16.0904]\n", "Itaration43... 
λ is [41.239, 3.16116, 16.0904]\n", "Itaration44... λ is [41.239, 3.16116, 16.0904]\n", "Itaration45... λ is [41.239, 3.16116, 16.0904]\n", "Itaration46... λ is [41.239, 3.16116, 16.0904]\n", "Itaration47... λ is [41.239, 3.16116, 16.0904]\n", "Itaration48... λ is [41.239, 3.16116, 16.0904]\n", "Itaration49... λ is [41.239, 3.16116, 16.0904]\n", "Itaration50... λ is [41.239, 3.16116, 16.0904]\n", "Itaration51... λ is [41.239, 3.16116, 16.0904]\n", "Itaration52... λ is [41.239, 3.16116, 16.0904]\n", "Itaration53... λ is [41.239, 3.16116, 16.0904]\n", "Itaration54... λ is [41.239, 3.16116, 16.0904]\n", "Itaration55... λ is [41.239, 3.16116, 16.0904]\n", "Itaration56... λ is [41.239, 3.16116, 16.0904]\n", "Itaration57... λ is [41.239, 3.16116, 16.0904]\n", "Itaration58... λ is [41.239, 3.16116, 16.0904]\n", "Itaration59... λ is [41.239, 3.16116, 16.0904]\n", "Itaration60... λ is [41.239, 3.16116, 16.0904]\n", "Itaration61... λ is [41.239, 3.16116, 16.0904]\n", "Itaration62... λ is [41.239, 3.16116, 16.0904]\n", "Itaration63... λ is [41.239, 3.16116, 16.0904]\n", "Itaration64... λ is [41.239, 3.16116, 16.0904]\n", "Itaration65... λ is [41.239, 3.16116, 16.0904]\n", "Itaration66... λ is [41.239, 3.16116, 16.0904]\n", "Itaration67... λ is [41.239, 3.16116, 16.0904]\n", "Itaration68... λ is [41.239, 3.16116, 16.0904]\n", "Itaration69... λ is [41.239, 3.16116, 16.0904]\n", "Itaration70... λ is [41.239, 3.16116, 16.0904]\n", "Itaration71... λ is [41.239, 3.16116, 16.0904]\n", "Itaration72... λ is [41.239, 3.16116, 16.0904]\n", "Itaration73... λ is [41.239, 3.16116, 16.0904]\n", "Itaration74... λ is [41.239, 3.16116, 16.0904]\n", "Itaration75... λ is [41.239, 3.16116, 16.0904]\n", "Itaration76... λ is [41.239, 3.16116, 16.0904]\n", "Itaration77... λ is [41.239, 3.16116, 16.0904]\n", "Itaration78... λ is [41.239, 3.16116, 16.0904]\n", "Itaration79... λ is [41.239, 3.16116, 16.0904]\n", "Itaration80... λ is [41.239, 3.16116, 16.0904]\n", "Itaration81... 
λ is [41.239, 3.16116, 16.0904]\n", "Itaration82... λ is [41.239, 3.16116, 16.0904]\n", "Itaration83... λ is [41.239, 3.16116, 16.0904]\n", "Itaration84... λ is [41.239, 3.16116, 16.0904]\n", "Itaration85... λ is [41.239, 3.16116, 16.0904]\n", "Itaration86... λ is [41.239, 3.16116, 16.0904]\n", "Itaration87... λ is [41.239, 3.16116, 16.0904]\n", "Itaration88... λ is [41.239, 3.16116, 16.0904]\n", "Itaration89... λ is [41.239, 3.16116, 16.0904]\n", "Itaration90... λ is [41.239, 3.16116, 16.0904]\n", "Itaration91... λ is [41.239, 3.16116, 16.0904]\n", "Itaration92... λ is [41.239, 3.16116, 16.0904]\n", "Itaration93... λ is [41.239, 3.16116, 16.0904]\n", "Itaration94... λ is [41.239, 3.16116, 16.0904]\n", "Itaration95... λ is [41.239, 3.16116, 16.0904]\n", "Itaration96... λ is [41.239, 3.16116, 16.0904]\n", "Itaration97... λ is [41.239, 3.16116, 16.0904]\n", "Itaration98... λ is [41.239, 3.16116, 16.0904]\n", "Itaration99... λ is [41.239, 3.16116, 16.0904]\n", "Itaration100... λ is [41.239, 3.16116, 16.0904]\n", " 1.945381 seconds (8.38 M allocations: 823.399 MiB, 9.55% gc time)\n" ] }, { "data": { "text/plain": [ "MixturePoissonVB([5.07029e-15 0.999917 8.28694e-5; 0.999919 1.7906e-25 8.07822e-5; … ; 8.61198e-13 0.997857 0.00214278; 0.999919 1.7906e-25 8.07822e-5], [41.239, 3.16116, 16.0904], [1.26629e5, 9464.78, 63449.6], [3070.6, 2994.08, 3943.32], [3068.6, 2991.08, 3943.32])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@time res2 = vb(X, 3, 100, Khyper[:,1], Khyper[:,2], [1.0,1.0,1.0])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "10000×5 Array{Float64,2}:\n", " 5.07029e-15 0.999917 8.28694e-5 2.0 1.0\n", " 0.999919 1.7906e-25 8.07822e-5 37.0 2.0\n", " 1.11377e-11 0.989187 0.0108125 5.0 1.0\n", " 0.999919 1.7906e-25 8.07822e-5 37.0 2.0\n", " 3.88668e-16 0.999984 1.62811e-5 1.0 1.0\n", " 1.25977e-5 7.83332e-6 0.99998 15.0 3.0\n", " 3.2287e-5 1.53886e-6 
0.999966 16.0 3.0\n", " 0.86925 1.30629e-16 0.13075 29.0 2.0\n", " 6.61254e-14 0.999578 0.000421684 3.0 1.0\n", " 3.2287e-5 1.53886e-6 0.999966 16.0 3.0\n", " 0.00139127 2.28903e-9 0.998609 20.0 3.0\n", " 0.999968 1.37258e-26 3.15206e-5 38.0 2.0\n", " 7.47518e-7 0.00103211 0.998967 12.0 3.0\n", " ⋮ \n", " 0.022918 1.69812e-11 0.977082 23.0 3.0\n", " 8.61198e-13 0.997857 0.00214278 4.0 1.0\n", " 3.2287e-5 1.53886e-6 0.999966 16.0 3.0\n", " 1.91745e-6 0.000202929 0.999795 13.0 3.0\n", " 0.00906871 8.76639e-11 0.990931 22.0 3.0\n", " 1.02395e-8 0.409555 0.590445 8.0 3.0\n", " 6.61254e-14 0.999578 0.000421684 3.0 1.0\n", " 3.2287e-5 1.53886e-6 0.999966 16.0 3.0\n", " 3.88668e-16 0.999984 1.62811e-5 1.0 1.0\n", " 0.999968 1.37258e-26 3.15206e-5 38.0 2.0\n", " 8.61198e-13 0.997857 0.00214278 4.0 1.0\n", " 0.999919 1.7906e-25 8.07822e-5 37.0 2.0" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "[res2.η X Sn]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
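<p>10,000 rows × 5 columns</p>\n",
"<table class=\"data-frame\">\n",
"<thead>\n",
"<tr><th></th><th>x1</th><th>x2</th><th>x3</th><th>x4</th><th>x5</th></tr>\n",
"<tr><th></th><th>Float64</th><th>Float64</th><th>Float64</th><th>Float64</th><th>Float64</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.999917</td><td>5.07029e-15</td><td>8.28694e-5</td><td>2.0</td><td>1.0</td></tr>\n",
"<tr><th>2</th><td>1.7906e-25</td><td>0.999919</td><td>8.07822e-5</td><td>37.0</td><td>2.0</td></tr>\n",
"<tr><th>3</th><td>0.989187</td><td>1.11377e-11</td><td>0.0108125</td><td>5.0</td><td>1.0</td></tr>\n",
"<tr><th>4</th><td>1.7906e-25</td><td>0.999919</td><td>8.07822e-5</td><td>37.0</td><td>2.0</td></tr>\n",
"<tr><th>5</th><td>0.999984</td><td>3.88668e-16</td><td>1.62811e-5</td><td>1.0</td><td>1.0</td></tr>\n",
"<tr><th>6</th><td>7.83332e-6</td><td>1.25977e-5</td><td>0.99998</td><td>15.0</td><td>3.0</td></tr>\n",
"<tr><th>7</th><td>1.53886e-6</td><td>3.2287e-5</td><td>0.999966</td><td>16.0</td><td>3.0</td></tr>\n",
"<tr><th>8</th><td>1.30629e-16</td><td>0.86925</td><td>0.13075</td><td>29.0</td><td>2.0</td></tr>\n",
"<tr><th>9</th><td>0.999578</td><td>6.61254e-14</td><td>0.000421684</td><td>3.0</td><td>1.0</td></tr>\n",
"<tr><th>10</th><td>1.53886e-6</td><td>3.2287e-5</td><td>0.999966</td><td>16.0</td><td>3.0</td></tr>\n",
"<tr><th>11</th><td>2.28903e-9</td><td>0.00139127</td><td>0.998609</td><td>20.0</td><td>3.0</td></tr>\n",
"<tr><th>12</th><td>1.37258e-26</td><td>0.999968</td><td>3.15206e-5</td><td>38.0</td><td>2.0</td></tr>\n",
"<tr><th>13</th><td>0.00103211</td><td>7.47518e-7</td><td>0.998967</td><td>12.0</td><td>3.0</td></tr>\n",
"<tr><th>14</th><td>4.48713e-10</td><td>0.00355804</td><td>0.996442</td><td>21.0</td><td>3.0</td></tr>\n",
"<tr><th>15</th><td>4.73826e-31</td><td>0.999999</td><td>7.30524e-7</td><td>42.0</td><td>2.0</td></tr>\n",
"<tr><th>16</th><td>4.73826e-31</td><td>0.999999</td><td>7.30524e-7</td><td>42.0</td><td>2.0</td></tr>\n",
"<tr><th>17</th><td>0.999578</td><td>6.61254e-14</td><td>0.000421684</td><td>3.0</td><td>1.0</td></tr>\n",
"<tr><th>18</th><td>1.25373e-36</td><td>1.0</td><td>6.60571e-9</td><td>47.0</td><td>2.0</td></tr>\n",
"<tr><th>19</th><td>1.53886e-6</td><td>3.2287e-5</td><td>0.999966</td><td>16.0</td><td>3.0</td></tr>\n",
"<tr><th>20</th><td>0.999917</td><td>5.07029e-15</td><td>8.28694e-5</td><td>2.0</td><td>1.0</td></tr>\n",
"<tr><th>21</th><td>0.989187</td><td>1.11377e-11</td><td>0.0108125</td><td>5.0</td><td>1.0</td></tr>\n",
"<tr><th>22</th><td>0.999917</td><td>5.07029e-15</td><td>8.28694e-5</td><td>2.0</td><td>1.0</td></tr>\n",
"<tr><th>23</th><td>0.000202929</td><td>1.91745e-6</td><td>0.999795</td><td>13.0</td><td>3.0</td></tr>\n",
"<tr><th>24</th><td>0.997857</td><td>8.61198e-13</td><td>0.00214278</td><td>4.0</td><td>1.0</td></tr>\n",
"<tr><th>25</th><td>1.53886e-6</td><td>3.2287e-5</td><td>0.999966</td><td>16.0</td><td>3.0</td></tr>\n",
"<tr><th>26</th><td>0.119926</td><td>3.91165e-8</td><td>0.880074</td><td>9.0</td><td>3.0</td></tr>\n",
"<tr><th>27</th><td>0.00103211</td><td>7.47518e-7</td><td>0.998967</td><td>12.0</td><td>3.0</td></tr>\n",
"<tr><th>28</th><td>5.16956e-21</td><td>0.996526</td><td>0.00347386</td><td>33.0</td><td>2.0</td></tr>\n",
"<tr><th>29</th><td>1.53886e-6</td><td>3.2287e-5</td><td>0.999966</td><td>16.0</td><td>3.0</td></tr>\n",
"<tr><th>30</th><td>0.000202929</td><td>1.91745e-6</td><td>0.999795</td><td>13.0</td><td>3.0</td></tr>\n",
"</tbody>\n",
"</table>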
" ], "text/latex": [ "\\begin{tabular}{r|ccccc}\n", "\t& x1 & x2 & x3 & x4 & x5\\\\\n", "\t\\hline\n", "\t& Float64 & Float64 & Float64 & Float64 & Float64\\\\\n", "\t\\hline\n", "\t1 & 0.999917 & 5.07029e-15 & 8.28694e-5 & 2.0 & 1.0 \\\\\n", "\t2 & 1.7906e-25 & 0.999919 & 8.07822e-5 & 37.0 & 2.0 \\\\\n", "\t3 & 0.989187 & 1.11377e-11 & 0.0108125 & 5.0 & 1.0 \\\\\n", "\t4 & 1.7906e-25 & 0.999919 & 8.07822e-5 & 37.0 & 2.0 \\\\\n", "\t5 & 0.999984 & 3.88668e-16 & 1.62811e-5 & 1.0 & 1.0 \\\\\n", "\t6 & 7.83332e-6 & 1.25977e-5 & 0.99998 & 15.0 & 3.0 \\\\\n", "\t7 & 1.53886e-6 & 3.2287e-5 & 0.999966 & 16.0 & 3.0 \\\\\n", "\t8 & 1.30629e-16 & 0.86925 & 0.13075 & 29.0 & 2.0 \\\\\n", "\t9 & 0.999578 & 6.61254e-14 & 0.000421684 & 3.0 & 1.0 \\\\\n", "\t10 & 1.53886e-6 & 3.2287e-5 & 0.999966 & 16.0 & 3.0 \\\\\n", "\t11 & 2.28903e-9 & 0.00139127 & 0.998609 & 20.0 & 3.0 \\\\\n", "\t12 & 1.37258e-26 & 0.999968 & 3.15206e-5 & 38.0 & 2.0 \\\\\n", "\t13 & 0.00103211 & 7.47518e-7 & 0.998967 & 12.0 & 3.0 \\\\\n", "\t14 & 4.48713e-10 & 0.00355804 & 0.996442 & 21.0 & 3.0 \\\\\n", "\t15 & 4.73826e-31 & 0.999999 & 7.30524e-7 & 42.0 & 2.0 \\\\\n", "\t16 & 4.73826e-31 & 0.999999 & 7.30524e-7 & 42.0 & 2.0 \\\\\n", "\t17 & 0.999578 & 6.61254e-14 & 0.000421684 & 3.0 & 1.0 \\\\\n", "\t18 & 1.25373e-36 & 1.0 & 6.60571e-9 & 47.0 & 2.0 \\\\\n", "\t19 & 1.53886e-6 & 3.2287e-5 & 0.999966 & 16.0 & 3.0 \\\\\n", "\t20 & 0.999917 & 5.07029e-15 & 8.28694e-5 & 2.0 & 1.0 \\\\\n", "\t21 & 0.989187 & 1.11377e-11 & 0.0108125 & 5.0 & 1.0 \\\\\n", "\t22 & 0.999917 & 5.07029e-15 & 8.28694e-5 & 2.0 & 1.0 \\\\\n", "\t23 & 0.000202929 & 1.91745e-6 & 0.999795 & 13.0 & 3.0 \\\\\n", "\t24 & 0.997857 & 8.61198e-13 & 0.00214278 & 4.0 & 1.0 \\\\\n", "\t25 & 1.53886e-6 & 3.2287e-5 & 0.999966 & 16.0 & 3.0 \\\\\n", "\t26 & 0.119926 & 3.91165e-8 & 0.880074 & 9.0 & 3.0 \\\\\n", "\t27 & 0.00103211 & 7.47518e-7 & 0.998967 & 12.0 & 3.0 \\\\\n", "\t28 & 5.16956e-21 & 0.996526 & 0.00347386 & 33.0 & 2.0 \\\\\n", "\t29 & 1.53886e-6 
& 3.2287e-5 & 0.999966 & 16.0 & 3.0 \\\\\n", "\t30 & 0.000202929 & 1.91745e-6 & 0.999795 & 13.0 & 3.0 \\\\\n", "\t$\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ & $\\dots$ \\\\\n", "\\end{tabular}\n" ], "text/plain": [ "10000×5 DataFrame\n", "│ Row │ x1 │ x2 │ x3 │ x4 │ x5 │\n", "│ │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │ \u001b[90mFloat64\u001b[39m │\n", "├───────┼─────────────┼─────────────┼─────────────┼─────────┼─────────┤\n", "│ 1 │ 0.999917 │ 5.07029e-15 │ 8.28694e-5 │ 2.0 │ 1.0 │\n", "│ 2 │ 1.7906e-25 │ 0.999919 │ 8.07822e-5 │ 37.0 │ 2.0 │\n", "│ 3 │ 0.989187 │ 1.11377e-11 │ 0.0108125 │ 5.0 │ 1.0 │\n", "│ 4 │ 1.7906e-25 │ 0.999919 │ 8.07822e-5 │ 37.0 │ 2.0 │\n", "│ 5 │ 0.999984 │ 3.88668e-16 │ 1.62811e-5 │ 1.0 │ 1.0 │\n", "│ 6 │ 7.83332e-6 │ 1.25977e-5 │ 0.99998 │ 15.0 │ 3.0 │\n", "│ 7 │ 1.53886e-6 │ 3.2287e-5 │ 0.999966 │ 16.0 │ 3.0 │\n", "│ 8 │ 1.30629e-16 │ 0.86925 │ 0.13075 │ 29.0 │ 2.0 │\n", "│ 9 │ 0.999578 │ 6.61254e-14 │ 0.000421684 │ 3.0 │ 1.0 │\n", "│ 10 │ 1.53886e-6 │ 3.2287e-5 │ 0.999966 │ 16.0 │ 3.0 │\n", "⋮\n", "│ 9990 │ 0.997857 │ 8.61198e-13 │ 0.00214278 │ 4.0 │ 1.0 │\n", "│ 9991 │ 1.53886e-6 │ 3.2287e-5 │ 0.999966 │ 16.0 │ 3.0 │\n", "│ 9992 │ 0.000202929 │ 1.91745e-6 │ 0.999795 │ 13.0 │ 3.0 │\n", "│ 9993 │ 8.76639e-11 │ 0.00906871 │ 0.990931 │ 22.0 │ 3.0 │\n", "│ 9994 │ 0.409555 │ 1.02395e-8 │ 0.590445 │ 8.0 │ 3.0 │\n", "│ 9995 │ 0.999578 │ 6.61254e-14 │ 0.000421684 │ 3.0 │ 1.0 │\n", "│ 9996 │ 1.53886e-6 │ 3.2287e-5 │ 0.999966 │ 16.0 │ 3.0 │\n", "│ 9997 │ 0.999984 │ 3.88668e-16 │ 1.62811e-5 │ 1.0 │ 1.0 │\n", "│ 9998 │ 1.37258e-26 │ 0.999968 │ 3.15206e-5 │ 38.0 │ 2.0 │\n", "│ 9999 │ 0.997857 │ 8.61198e-13 │ 0.00214278 │ 4.0 │ 1.0 │\n", "│ 10000 │ 1.7906e-25 │ 0.999919 │ 8.07822e-5 │ 37.0 │ 2.0 │" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using DataFrames\n", "df_tmp = [res2.η[:,2] res2.η[:,1] 
res2.η[:,3]] # swap columns 1 and 2 so that column k lines up with true cluster k (VB cluster labels are arbitrary)\n", "df = convert(DataFrames.DataFrame, [df_tmp X Sn])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's visualize the misclassifications." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1×10000 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:\n", " 1 2 1 2 1 3 3 2 1 3 3 2 3 … 3 1 3 3 3 3 1 3 1 2 1 2" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# most probable cluster of each sample under q(S)\n", "which_max = [argmax(df_tmp[i,:]) for i in 1:N]\n", "which_max'" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dict{Int64,Int64} with 3 entries:\n", " 2 => 25\n", " 3 => 67\n", " 1 => 40" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "misclassified = Sn[Sn .!= which_max] # true labels of the misclassified samples\n", "countmap(misclassified)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Error rate is 1.32%\n" ] } ], "source": [ "println(\"Error rate is \", count(which_max .!= Sn) / N * 100, \"%\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot(fit(Histogram, misclassified), xlab = \"Latent Cluster\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The x-axis formatting isn't great, though..." ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.1.1", "language": "julia", "name": "julia-1.1" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.1.1" } }, "nbformat": 4, "nbformat_minor": 2 }