{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 無限関係モデルによる共クラスタリングのJulia実装\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"無限関係モデル (Infinite Relational Model: IRM) によるクラスタリングを\n",
"崩壊型ギプスサンプリングでやる.\n",
"\n",
"* めちゃ参考になるスライド:\n",
" * https://www.slideshare.net/shuyo/infinite-relational-model\n",
" * 『続・わかりやすいパターン認識』13章の式(13.26), (13.27)の致命的な数式の間違いを指摘してくれてるので目を通しておいた方がいい"
]
},
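{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough summary of the model being fit, the IRM for a binary relation matrix $R$ can be sketched as follows. This is inferred from the hyperparameters `α`, `a`, and `b` used later in this notebook and is an assumption, not a quotation of the textbook; $\\theta_{ij}$ is a symbol introduced here for the density of block $(i, j)$.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf s^1 &\\sim \\mathrm{CRP}(\\alpha), \\quad \\mathbf s^2 \\sim \\mathrm{CRP}(\\alpha), \\\\\n",
"\\theta_{ij} &\\sim \\mathrm{Beta}(a, b), \\\\\n",
"R_{kl} \\mid \\mathbf s^1, \\mathbf s^2, \\theta &\\sim \\mathrm{Bernoulli}(\\theta_{s^1_k s^2_l}).\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"Collapsed Gibbs sampling integrates $\\theta$ out analytically, so only the assignments $\\mathbf s^1$ and $\\mathbf s^2$ are sampled."
]
},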
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"using Distributions\n",
"using SpecialFunctions\n",
"\n",
"using Plots\n",
"using DelimitedFiles\n",
"using ProgressBars"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# トイデータの読み込みと可視化"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(150, 100)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"R = readdlm(\"IRM_toydata_R_200402.csv\", ',');\n",
"s1_gt = readdlm(\"IRM_toydata_s1_200402.csv\", ',');\n",
"s2_gt = readdlm(\"IRM_toydata_s2_200402.csv\", ',');\n",
"\n",
"# 行・列数\n",
"K, L = size(R)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"row_idxs_gt = sort(collect(1:K), by=i->s1_gt[i]);\n",
"col_idxs_gt = sort(collect(1:L), by=i->s2_gt[i]);"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(\n",
" heatmap(R, yflip=true, title=\"Input\", c=ColorGradient([:white, :black])),\n",
" heatmap(R[row_idxs_gt, col_idxs_gt], yflip=true, title=\"Ground Truth\", c=ColorGradient([:white, :black]))\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# モデルの初期化"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"IRMを表現する構造体"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"struct IRM\n",
" c1::Int\n",
" c2::Int\n",
" K::Int # number of rows\n",
" L::Int # number of cols\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# モデルのハイパーパラメタ\n",
"α = 1.0;\n",
"a = 1.0;\n",
"b = 2.0;\n",
"\n",
"# 初期クラスタ数\n",
"c1 = 2;\n",
"c2 = 2;"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"IRM(2, 2, 150, 100)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"irm = IRM(c1, c2, K, L)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"行・列の各々のオブジェクトのクラスタ割り当て$\\mathbf s^1$と$\\mathbf s^2$を初期化する."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"π1 = rand(Dirichlet(α * ones(irm.c1)));\n",
"π2 = rand(Dirichlet(α * ones(irm.c2)));\n",
"\n",
"s1 = rand(Categorical(π1), irm.K);\n",
"s2 = rand(Categorical(π2), irm.L);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 崩壊型ギプスサンプリングの実装\n",
"\n",
"関数`sample_s1`がテキストの式(13.26),\n",
"関数`sample_s2`がテキストの式(13.27)の条件付き分布からのサンプリングを表す.\n",
"テキストを見てもわかる通り,ふたつの条件付き分布の違いは行か列のどっちをみているかだけなので\n",
"関数の中身もほぼ同じである.\n",
"\n",
"式の中に出てくる関数$n_{(k, +)(i, j)}$とかの計算方法はNotebookの末尾に付録をつけたので参考にされたい."
]
},
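{
"cell_type": "markdown",
"metadata": {},
"source": [
"For orientation, the row update can be written in the standard Beta-Bernoulli collapsed form below. This is a sketch under assumptions: it uses the count notation from the appendix, $B(\\cdot, \\cdot)$ is the beta function, and $m_i^{-k}$ (the number of rows other than $k$ currently in cluster $i$) is a symbol introduced here. It is not a reproduction of the textbook's printed equations.\n",
"\n",
"$$\n",
"P(s^1_k = i \\mid \\cdot) \\propto m_i^{-k} \\prod_{j=1}^{c_2} \\frac{B\\left(n_{(-k,+)(i,j)} + n_{(k,+)(+,j)} + a,\\; \\bar n_{(-k,+)(i,j)} + \\bar n_{(k,+)(+,j)} + b\\right)}{B\\left(n_{(-k,+)(i,j)} + a,\\; \\bar n_{(-k,+)(i,j)} + b\\right)}\n",
"$$\n",
"\n",
"$$\n",
"P(s^1_k = \\text{new} \\mid \\cdot) \\propto \\alpha \\prod_{j=1}^{c_2} \\frac{B\\left(n_{(k,+)(+,j)} + a,\\; \\bar n_{(k,+)(+,j)} + b\\right)}{B(a, b)}\n",
"$$\n",
"\n",
"Here $n_{(-k,+)(i,j)}$ and $\\bar n_{(-k,+)(i,j)}$ count the ones and zeros of block $(i, j)$ with row $k$ left out. The column update has the same shape with the roles of rows and columns exchanged."
]
},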
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sample_s1 (generic function with 1 method)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function sample_s1(irm::IRM, R::Array{Float64, 2}, s1::Array{Int64, 1}, s2::Array{Int64, 1})\n",
" \n",
" for k in 1:irm.K\n",
" state_ids = Array{Int64,1}(1:irm.c1)\n",
" m = [count(x -> x == i, s1) for i in state_ids]\n",
" m[s1[k]] -= 1\n",
"\n",
" # If the category s1[k] is empty, the category is removed.\n",
" if m[s1[k]] == 0\n",
" \n",
" # 消されたクラスタより後ろの番号のクラスタについて,番号をつめる\n",
" s1[s1 .> s1[k]] .-= 1\n",
"\n",
" # 消されたクラスタを削除する\n",
" m = m[state_ids .!= s1[k]]\n",
" \n",
" irm = IRM(irm.c1-1, irm.c2, irm.K, irm.L)\n",
" end\n",
"\n",
" new_mixing_coeff = [m; α]\n",
" new_mixing_coeff = new_mixing_coeff ./ (irm.K - 1 + α)\n",
"\n",
" # Compute likelihoods\n",
" likelihood = ones(irm.c1+1)\n",
" for i in 1:(irm.c1)\n",
" log_lik = 0.0\n",
" for j in 1:(irm.c2)\n",
" a_post_n = sum(R[s1 .== i, s2 .== j]) + a\n",
" b_post_n = sum(1.0 .- R[s1 .== i, s2 .== j]) + b\n",
" a_post_d = sum(R[s1 .== i, s2 .== j]) - sum(R[k, s2 .== j]) + a\n",
" b_post_d = sum(1.0 .- R[s1 .== i, s2 .== j]) - sum(1.0 .- R[k, s2 .== j]) + b\n",
" log_lik += lbeta(a_post_n, b_post_n) - lbeta(a_post_d, b_post_d)\n",
" end\n",
" likelihood[i] = exp(log_lik)\n",
" end\n",
" \n",
" # For new class\n",
" i = irm.c1+1\n",
" log_lik = 0.0\n",
" for j in 1:(irm.c2)\n",
" a_post_n = sum(R[k, s2 .== j]) + a\n",
" b_post_n = sum(1.0 .- R[k, s2 .== j]) + b\n",
" log_lik += lbeta(a_post_n, b_post_n) - lbeta(a, b)\n",
" end\n",
" likelihood[i] = exp(log_lik)\n",
" \n",
" likelihood = likelihood ./ sum(likelihood)\n",
" new_mixing_coeff = new_mixing_coeff .* likelihood\n",
" new_mixing_coeff = new_mixing_coeff ./ sum(new_mixing_coeff)\n",
" \n",
" # Resample hidden state\n",
" s1[k] = rand(Categorical(new_mixing_coeff))\n",
"\n",
" if s1[k] == irm.c1+1\n",
" irm = IRM(irm.c1+1, irm.c2, irm.K, irm.L)\n",
" else\n",
" irm = IRM(irm.c1, irm.c2, irm.K, irm.L)\n",
" end\n",
" end\n",
" \n",
" return s1, irm\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"sample_s2 (generic function with 1 method)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function sample_s2(irm::IRM, R::Array{Float64, 2}, s1::Array{Int64, 1}, s2::Array{Int64, 1})\n",
" \n",
" for l in 1:irm.L\n",
" state_ids = Array{Int64,1}(1:irm.c2)\n",
" m = [count(x -> x == j, s2) for j in state_ids]\n",
" m[s2[l]] -= 1\n",
"\n",
" # If the category s2[k] is empty, the category is removed.\n",
" if m[s2[l]] == 0\n",
" \n",
" # 消されたクラスタより後ろの番号のクラスタについて,番号をつめる\n",
" s2[s2 .> s2[l]] .-= 1\n",
"\n",
" # 消されたクラスタを削除する\n",
" m = m[state_ids .!= s2[l]]\n",
" \n",
" irm = IRM(irm.c1, irm.c2-1, irm.K, irm.L)\n",
" end\n",
"\n",
" new_mixing_coeff = [m; α]\n",
" new_mixing_coeff = new_mixing_coeff ./ (irm.L - 1 + α)\n",
"\n",
" # Compute likelihoods\n",
" likelihood = ones(irm.c2+1)\n",
" for j in 1:(irm.c2)\n",
" log_lik = 0.0\n",
" for i in 1:(irm.c1)\n",
" a_post_n = sum(R[s1 .== i, s2 .== j]) + a\n",
" b_post_n = sum(1.0 .- R[s1 .== i, s2 .== j]) + b\n",
" a_post_d = sum(R[s1 .== i, s2 .== j]) - sum(R[s1 .== i, l]) + a\n",
" b_post_d = sum(1.0 .- R[s1 .== i, s2 .== j]) - sum(1.0 .- R[s1 .== i, l]) + b\n",
" log_lik += lbeta(a_post_n, b_post_n) - lbeta(a_post_d, b_post_d)\n",
" end\n",
" likelihood[j] = exp(log_lik)\n",
" end\n",
" \n",
" # For new class\n",
" j = irm.c2+1\n",
" log_lik = 0.0\n",
" for i in 1:(irm.c1)\n",
" a_post_n = sum(R[s1 .== i, l]) + a\n",
" b_post_n = sum(1.0 .- R[s1 .== i, l]) + b\n",
" log_lik += lbeta(a_post_n, b_post_n) - lbeta(a, b)\n",
" end\n",
" likelihood[j] = exp(log_lik)\n",
" \n",
" likelihood = likelihood ./ sum(likelihood)\n",
" new_mixing_coeff = new_mixing_coeff .* likelihood\n",
" new_mixing_coeff = new_mixing_coeff ./ sum(new_mixing_coeff)\n",
" \n",
" # Resample hidden state\n",
" s2[l] = rand(Categorical(new_mixing_coeff))\n",
"\n",
" if s2[l] == irm.c2+1\n",
" irm = IRM(irm.c1, irm.c2+1, irm.K, irm.L)\n",
" else\n",
" irm = IRM(irm.c1, irm.c2, irm.K, irm.L)\n",
" end\n",
" end\n",
" \n",
" return s2, irm\n",
"end"
]
},
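{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both samplers exponentiate each log-likelihood before normalizing. On the toy data this works, but for larger matrices `exp(log_lik)` can underflow to zero for every cluster. A minimal sketch of normalizing in log space instead is shown below; `normalize_log_weights` is a hypothetical helper that is not part of the original code.\n",
"\n",
"```julia\n",
"# Hypothetical helper (not part of the original notebook):\n",
"# turn unnormalized log-weights into probabilities without underflow.\n",
"function normalize_log_weights(log_w::Vector{Float64})\n",
"    # Subtract the maximum so the largest entry becomes exp(0) = 1\n",
"    shifted = log_w .- maximum(log_w)\n",
"    w = exp.(shifted)\n",
"    return w ./ sum(w)\n",
"end\n",
"\n",
"# Log-weights this small all underflow to 0.0 under a plain exp()\n",
"log_w = [-1000.0, -1001.0, -1002.0]\n",
"p = normalize_log_weights(log_w)\n",
"```\n",
"\n",
"In the samplers, one would pass the log of the prior weights plus the per-cluster `log_lik` values to such a helper and feed the result to `Categorical`."
]
},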
{
"cell_type": "markdown",
"metadata": {},
"source": [
"崩壊型ギプスサンプリングの実行"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": []
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"100.00%┣███████████████████████████████████████████████████████████▉┫ 200/200 00:44<00:00, 4.53 it/s]5/200 00:03<02:45, 1.18 it/s]7/200 00:04<02:02, 1.59 it/s]11/200 00:05<01:26, 2.19 it/s]20/200 00:06<00:59, 3.04 it/s]21/200 00:06<00:57, 3.12 it/s]22/200 00:07<00:56, 3.18 it/s]27/200 00:08<00:51, 3.41 it/s]29/200 00:08<00:49, 3.48 it/s]32/200 00:09<00:47, 3.59 it/s]35/200 00:09<00:45, 3.68 it/s]39/200 00:10<00:42, 3.81 it/s]41/200 00:10<00:41, 3.86 it/s]44/200 00:11<00:39, 3.95 it/s]45/200 00:11<00:39, 3.97 it/s]46/200 00:11<00:38, 4.00 it/s]50/200 00:12<00:37, 4.08 it/s]51/200 00:12<00:36, 4.10 it/s]53/200 00:13<00:35, 4.15 it/s]55/200 00:13<00:35, 4.19 it/s]56/200 00:13<00:34, 4.20 it/s]57/200 00:13<00:34, 4.21 it/s]58/200 00:13<00:34, 4.22 it/s]63/200 00:14<00:32, 4.29 it/s]64/200 00:15<00:32, 4.29 it/s]67/200 00:16<00:32, 4.13 it/s]68/200 00:16<00:32, 4.14 it/s]69/200 00:16<00:32, 4.15 it/s]71/200 00:17<00:31, 4.14 it/s]72/200 00:17<00:31, 4.15 it/s]76/200 00:18<00:30, 4.14 it/s]80/200 00:19<00:29, 4.16 it/s]84/200 00:20<00:28, 4.17 it/s]89/200 00:21<00:27, 4.16 it/s]90/200 00:21<00:26, 4.17 it/s]┫ 92/200 00:22<00:26, 4.20 it/s]94/200 00:22<00:25, 4.23 it/s]98/200 00:23<00:24, 4.26 it/s]101/200 00:23<00:23, 4.28 it/s]102/200 00:24<00:23, 4.29 it/s]106/200 00:24<00:22, 4.31 it/s]107/200 00:25<00:22, 4.32 it/s]111/200 00:25<00:20, 4.36 it/s]113/200 00:26<00:20, 4.38 it/s]┫ 115/200 00:26<00:19, 4.40 it/s]116/200 00:26<00:19, 4.41 it/s]122/200 00:27<00:18, 4.44 it/s]126/200 00:28<00:17, 4.45 it/s]128/200 00:28<00:16, 4.46 it/s]┫ 130/200 00:29<00:16, 4.47 it/s]132/200 00:29<00:15, 4.47 it/s]134/200 00:30<00:15, 4.48 it/s]135/200 00:30<00:14, 4.49 it/s]┫ 138/200 00:31<00:14, 4.48 it/s]139/200 00:31<00:14, 4.47 it/s]140/200 00:31<00:13, 4.47 it/s]143/200 00:32<00:13, 4.46 it/s]┫ 144/200 00:32<00:13, 4.46 it/s]145/200 00:32<00:12, 4.46 it/s]147/200 00:33<00:12, 4.46 it/s]150/200 00:34<00:11, 4.44 it/s]152/200 00:34<00:11, 4.44 it/s]153/200 00:34<00:11, 4.43 
it/s]154/200 00:35<00:10, 4.42 it/s]156/200 00:35<00:10, 4.43 it/s]157/200 00:35<00:10, 4.44 it/s]158/200 00:35<00:09, 4.44 it/s]160/200 00:36<00:09, 4.45 it/s]161/200 00:36<00:09, 4.46 it/s]164/200 00:36<00:08, 4.48 it/s]165/200 00:37<00:08, 4.48 it/s]167/200 00:37<00:07, 4.49 it/s]┫ 174/200 00:38<00:06, 4.53 it/s]176/200 00:39<00:05, 4.53 it/s]177/200 00:39<00:05, 4.54 it/s]183/200 00:41<00:04, 4.49 it/s]┫ 184/200 00:41<00:04, 4.49 it/s]185/200 00:41<00:03, 4.49 it/s]186/200 00:41<00:03, 4.49 it/s]187/200 00:41<00:03, 4.49 it/s]189/200 00:42<00:02, 4.48 it/s]190/200 00:42<00:02, 4.49 it/s]195/200 00:43<00:01, 4.50 it/s]196/200 00:43<00:01, 4.50 it/s]"
]
}
],
"source": [
"for t in tqdm(1:200)\n",
" s1, irm = sample_s1(irm, R, s1, s2);\n",
" s2, irm = sample_s2(irm, R, s1, s2);\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"クラスタリング結果のプロット.\n",
"一見,Ground Truthと異なるようにみえる(可能性がかなり高い)が,\n",
"クラスタの順序が入れ替わってるだけなので並べ替えるとGround Truthと一致するはず."
]
},
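{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because cluster labels are arbitrary, comparing label vectors element by element is not meaningful. One way to check agreement regardless of label order is sketched below; `same_partition` is a hypothetical helper, not part of the original code.\n",
"\n",
"```julia\n",
"# Hypothetical helper (not part of the original notebook):\n",
"# true when two label vectors induce the same partition, ignoring label names.\n",
"function same_partition(s::AbstractVector, t::AbstractVector)\n",
"    length(s) == length(t) || return false\n",
"    fwd = Dict{eltype(s), eltype(t)}()   # label in s -> label in t\n",
"    bwd = Dict{eltype(t), eltype(s)}()   # label in t -> label in s\n",
"    for (x, y) in zip(s, t)\n",
"        # Each label must map to exactly one label in the other vector\n",
"        get!(fwd, x, y) == y || return false\n",
"        get!(bwd, y, x) == x || return false\n",
"    end\n",
"    return true\n",
"end\n",
"\n",
"permuted = same_partition([1, 1, 2, 3], [2, 2, 3, 1])    # same grouping, labels renamed\n",
"different = same_partition([1, 1, 2, 2], [1, 2, 2, 2])   # grouping differs\n",
"```\n",
"\n",
"With the variables of this notebook, a call such as `same_partition(s1, vec(s1_gt))` would test the row clustering, assuming `s1_gt` holds one label per row."
]
},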
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"row_idxs = sort(collect(1:K), by=i->s1[i]);\n",
"col_idxs = sort(collect(1:L), by=i->s2[i]);"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(\n",
" heatmap(R[row_idxs, col_idxs], yflip=true, title=\"Result\", c=ColorGradient([:white, :black])),\n",
" heatmap(R[row_idxs_gt, col_idxs_gt], yflip=true, title=\"Ground Truth\", c=ColorGradient([:white, :black]))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"plot(\n",
" heatmap(R, yflip=true, title=\"Input\", c=ColorGradient([:white, :black])),\n",
" heatmap(R[row_idxs, col_idxs], yflip=true, title=\"Result\", c=ColorGradient([:white, :black]))\n",
")\n",
"savefig(\"IRM_result_200402.png\")\n",
"savefig(\"IRM_result_200402.pdf\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 付録\n",
"\n",
"## 関数$n_{(k, +)(i, j)}$とかの計算方法について\n",
"\n",
"参考スライド( https://www.slideshare.net/shuyo/infinite-relational-model )の\n",
"P.17にある図のような関係行列について,\n",
"$n_{(+, +)(i, j)}$, $\\bar n_{(+, +)(i, j)}$, \n",
"$n_{(-k,+)(i,j)}$, $\\bar n_{(-k,+)(i,j)}$, \n",
"$n_{(k,+)(+,j)}$, $\\bar n_{(k,+)(+,j)}$\n",
"の値の計算方法を以下に示した.\n",
"$n_{(k,+)(+,j)}$, $\\bar n_{(k,+)(+,j)}$はテキスト中に記号の定義が無い上に\n",
"これらが出てくる式(13.26), (13.27)が間違ってるので注意."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"i, j = 1, 1\n",
"k = 4"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"s1 = [1, 1, 1, 1, 1, 1];\n",
"s2 = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0];"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"6×15 Array{Int64,2}:\n",
" 0 0 1 0 0 1 1 0 1 1 0 0 0 0 1\n",
" 0 1 0 0 0 1 1 1 1 1 0 0 0 0 0\n",
" 1 0 1 0 0 1 1 1 1 1 0 0 0 0 0\n",
" 0 0 0 0 0 1 0 0 1 1 0 0 0 1 1\n",
" 0 0 0 0 0 1 0 1 1 1 0 0 0 0 0\n",
" 0 1 0 1 0 0 1 1 1 1 1 0 0 0 0"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"R = [0 0 1 0 0 1 1 0 1 1 0 0 0 0 1;\n",
" 0 1 0 0 0 1 1 1 1 1 0 0 0 0 0;\n",
" 1 0 1 0 0 1 1 1 1 1 0 0 0 0 0;\n",
" 0 0 0 0 0 1 0 0 1 1 0 0 0 1 1;\n",
" 0 0 0 0 0 1 0 1 1 1 0 0 0 0 0;\n",
" 0 1 0 1 0 0 1 1 1 1 1 0 0 0 0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$n_{(+, +)(i, j)}$"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(R[s1 .== i, s2 .== j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\\bar n_{(+, +)(i, j)}$"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(1 .- R[s1 .== i, s2 .== j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$n_{(k,+)(+,j)}$"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(R[k, s2 .== j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\\bar n_{(k,+)(+,j)}$"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(1 .- R[k, s2 .== j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$n_{(-k,+)(i,j)} = n_{(+,+)(i,j)} - n_{(k,+)(+,j)}$"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(R[s1 .== i, s2 .== j]) - sum(R[k, s2 .== j])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$\\bar n_{(-k,+)(i,j)} = \\bar n_{(+,+)(i,j)} - \\bar n_{(k,+)(+,j)}$"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(1 .- R[s1 .== i, s2 .== j]) - sum(1 .- R[k, s2 .== j])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.2.0",
"language": "julia",
"name": "julia-1.2"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}