{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 無限関係モデルによる共クラスタリングの実験用のトイデータ生成\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"クラスタ数は『続・わかりやすいパターン認識』の13章P.279のデータと揃えてある.\n",
"ただし,関係行列の行数・列数は比率はそのままで数を10倍に増やしてる."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"using Distributions\n",
"using Plots\n",
"\n",
"using DelimitedFiles"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"行(顧客)と列(商品)のクラスタ数$c_1$, $c_2$"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"c1 = 4;\n",
"c2 = 3;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"行(顧客)と列(商品)のデータ数$K$, $L$"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"K = 15 * 10;\n",
"L = 10 * 10;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"行(顧客)と列(商品)の混合比率$\\pi^1$, $\\pi^2$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"π1 = [4/15; 3/15; 5/15; 3/15];\n",
"π2 = [4/10; 3/10; 3/10];"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"行列$\\Theta = \\{\\theta_ij\\}$の各成分は,\n",
"顧客クラスタ$i$に属する顧客が商品クラスタ$j$に属する商品を購入する確率を表す."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"Θ = [0.2 0.9 0.1;\n",
" 1.0 0.8 0.0;\n",
" 0.1 0.1 0.9;\n",
" 0.2 0.7 0.1];"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"顧客と商品の所属クラスを表す潜在変数$\\mathbf s^1$, $\\mathbf s^2$"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"s1 = rand(Categorical(π1), K);\n",
"s2 = rand(Categorical(π2), L);"
]
},
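{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a sketch, not part of the original notebook), the fraction of sampled labels in each cluster can be compared with the mixing proportions.\n",
"The helper name `empirical_proportions` is hypothetical. Applied to the variables above, `empirical_proportions(s1, c1)` and `empirical_proportions(s2, c2)` should be close to, but not exactly equal to, `π1` and `π2`, since the labels are random draws."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical helper: fraction of labels that fall in each of the c clusters.\n",
"empirical_proportions(s, c) = [count(x -> x == i, s) for i in 1:c] ./ length(s)\n",
"\n",
"# Toy check on a hand-made label vector (not data from this notebook).\n",
"empirical_proportions([1, 1, 2, 3], 3)"
]
},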
{
"cell_type": "markdown",
"metadata": {},
"source": [
"顧客$k$が商品$l$を購入してたら1,そうでなければ0を成分にもつ\n",
"関係行列$\\mathbf R = \\{R_{kl}\\}$"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"R = zeros(K, L);\n",
"\n",
"for k in 1:K\n",
" for l in 1:L\n",
" R[k, l] = rand(Bernoulli(Θ[s1[k], s2[l]]))\n",
" end\n",
"end"
]
},
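{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch for checking the generated matrix: if the sampling above works as described, the mean of $\\mathbf R$ within each (customer cluster, product cluster) block should be close to the corresponding entry of $\\Theta$.\n",
"The helper `block_means` is a hypothetical name introduced here, not part of the original notebook. Calling `block_means(R, s1, s2, c1, c2)` on the variables above should return a $c_1 \\times c_2$ matrix near `Θ`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"using Statistics  # for mean\n",
"\n",
"# Hypothetical helper: mean of R over each (row cluster i, column cluster j) block.\n",
"# Blocks with no members (an empty cluster) are reported as NaN.\n",
"function block_means(R, s1, s2, c1, c2)\n",
"    M = fill(NaN, c1, c2)\n",
"    for i in 1:c1, j in 1:c2\n",
"        rows = findall(x -> x == i, s1)\n",
"        cols = findall(x -> x == j, s2)\n",
"        if !isempty(rows) && !isempty(cols)\n",
"            M[i, j] = mean(R[rows, cols])\n",
"        end\n",
"    end\n",
"    return M\n",
"end\n",
"\n",
"# Toy check: a 2×2 matrix whose rows and columns each form their own cluster.\n",
"block_means([1.0 0.0; 0.0 1.0], [1, 2], [1, 2], 2, 2)"
]
},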
{
"cell_type": "markdown",
"metadata": {},
"source": [
"顧客と商品の関係行列$\\mathbf R$をプロット"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"heatmap(R, yflip=true, c=ColorGradient([:white, :black]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"顧客と商品の関係行列$\\mathbf R$を真のラベルを用いて並べ替えて表示する"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"row_idxs = sort(collect(1:K), by=i->s1[i]);\n",
"col_idxs = sort(collect(1:L), by=i->s2[i]);\n",
"\n",
"heatmap(R[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# データの出力"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://stackoverflow.com/questions/52900232/export-an-array-to-csv-file-in-julia"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"writedlm(\"IRM_toydata_R_200402.csv\", R, ',')\n",
"writedlm(\"IRM_toydata_s1_200402.csv\", s1, ',')\n",
"writedlm(\"IRM_toydata_s2_200402.csv\", s2, ',')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"以下,出力したデータを読み込むテスト"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"R_gt = readdlm(\"IRM_toydata_R_200402.csv\", ',');\n",
"s1_gt = readdlm(\"IRM_toydata_s1_200402.csv\", ',');\n",
"s2_gt = readdlm(\"IRM_toydata_s2_200402.csv\", ',');"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"row_idxs = sort(collect(1:K), by=i->s1_gt[i]);\n",
"col_idxs = sort(collect(1:L), by=i->s2_gt[i]);"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"heatmap(R_gt[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.2.0",
"language": "julia",
"name": "julia-1.2"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.2.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}