{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Toy data generation for co-clustering experiments with the infinite relational model\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The number of clusters matches the data on p. 279 of Chapter 13 of 『続・わかりやすいパターン認識』.\n", "The numbers of rows and columns of the relational matrix, however, are scaled up by a factor of 10 while keeping their ratio." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "using Distributions\n", "using Plots\n", "\n", "using DelimitedFiles" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Number of clusters $c_1$, $c_2$ for the rows (customers) and columns (products)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "c1 = 4;\n", "c2 = 3;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Number of data points $K$, $L$ for the rows (customers) and columns (products)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "K = 15 * 10;\n", "L = 10 * 10;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mixing proportions $\\pi^1$, $\\pi^2$ for the rows (customers) and columns (products)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "π1 = [4/15; 3/15; 5/15; 3/15];\n", "π2 = [4/10; 3/10; 3/10];" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each entry of the matrix $\\Theta = \\{\\theta_{ij}\\}$ is\n", "the probability that a customer in customer cluster $i$ purchases a product in product cluster $j$."
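, "\n", "\n", "Written out, the generative process implied by the sampling cells in this notebook is roughly the following. This is a sketch inferred from the code rather than a formula quoted from the book; $s^1_k$ and $s^2_l$ denote the cluster labels of customer $k$ and product $l$:\n", "\n", "$$\n", "s^1_k \\sim \\mathrm{Categorical}(\\pi^1), \\qquad s^2_l \\sim \\mathrm{Categorical}(\\pi^2), \\qquad R_{kl} \\mid s^1_k, s^2_l \\sim \\mathrm{Bernoulli}\\left(\\theta_{s^1_k s^2_l}\\right)\n", "$$\n", "\n", "For example, with the values of $\\Theta$ defined in the next cell, a customer in cluster 2 purchases a product in cluster 1 with probability $\\theta_{21} = 1.0$ and a product in cluster 3 with probability $\\theta_{23} = 0.0$."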
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Rows index customer clusters (c1 = 4); columns index product clusters (c2 = 3).\n", "Θ = [0.2 0.9 0.1;\n", "     1.0 0.8 0.0;\n", "     0.1 0.1 0.9;\n", "     0.2 0.7 0.1];" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Latent variables $\\mathbf s^1$, $\\mathbf s^2$ giving the cluster assignments of the customers and products" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "s1 = rand(Categorical(π1), K);\n", "s2 = rand(Categorical(π2), L);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Relational matrix $\\mathbf R = \\{R_{kl}\\}$, whose entry is 1 if customer $k$ purchased product $l$\n", "and 0 otherwise" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "R = zeros(K, L);\n", "\n", "# Draw each entry from a Bernoulli whose parameter is set by the two cluster labels.\n", "for k in 1:K\n", "    for l in 1:L\n", "        R[k, l] = rand(Bernoulli(Θ[s1[k], s2[l]]))\n", "    end\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the customer-product relational matrix $\\mathbf R$" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: heatmap of R; x-axis (products) ticks 20 to 100, y-axis (customers) ticks 20 to 140, colorbar 0 to 1.0]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "heatmap(R, yflip=true, c=ColorGradient([:white, :black]))" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "顧客と商品の関係行列$\\mathbf R$を真のラベルを用いて並べ替えて表示する" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "20\n", "\n", "\n", "40\n", "\n", "\n", "60\n", "\n", "\n", "80\n", "\n", "\n", "100\n", "\n", "\n", "20\n", "\n", "\n", "40\n", "\n", "\n", "60\n", "\n", "\n", "80\n", "\n", "\n", "100\n", "\n", "\n", "120\n", "\n", "\n", "140\n", "\n", "\n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "0\n", "\n", "\n", "0.1\n", "\n", "\n", "0.2\n", "\n", "\n", "0.3\n", "\n", "\n", "0.4\n", "\n", "\n", "0.5\n", "\n", "\n", "0.6\n", "\n", "\n", "0.7\n", "\n", "\n", "0.8\n", "\n", "\n", "0.9\n", "\n", "\n", "1.0\n", "\n", "\n", "\n" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "row_idxs = sort(collect(1:K), by=i->s1[i]);\n", "col_idxs = sort(collect(1:L), by=i->s2[i]);\n", "\n", "heatmap(R[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# データの出力" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "https://stackoverflow.com/questions/52900232/export-an-array-to-csv-file-in-julia" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "writedlm(\"IRM_toydata_R_200402.csv\", R, ',')\n", "writedlm(\"IRM_toydata_s1_200402.csv\", s1, ',')\n", "writedlm(\"IRM_toydata_s2_200402.csv\", s2, ',')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以下,出力したデータを読み込むテスト" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "R_gt = 
readdlm(\"IRM_toydata_R_200402.csv\", ',');\n", "s1_gt = readdlm(\"IRM_toydata_s1_200402.csv\", ',');\n", "s2_gt = readdlm(\"IRM_toydata_s2_200402.csv\", ',');" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "row_idxs = sort(collect(1:K), by=i->s1_gt[i]);\n", "col_idxs = sort(collect(1:L), by=i->s2_gt[i]);" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Figure: heatmap of R_gt reordered by the reloaded labels; x-axis (products) ticks 20 to 100, y-axis (customers) ticks 20 to 140, colorbar 0 to 1.0]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "heatmap(R_gt[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))" ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.2.0", "language": "julia", "name": "julia-1.2" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.2.0" } }, "nbformat": 4, "nbformat_minor": 4 }