{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Put your Julia computations inside functions\n", "\n", "Estimating the parameters of an item response theory (IRT) model involves latent variables (also called hidden variables), so the standard practice is to marginalize them out and estimate only the structural parameters.\n", "\n", "The EM algorithm is an efficient way to maximize the likelihood while marginalizing over the latent variables. Unfortunately, the integral of the IRT likelihood over the latent variable cannot be evaluated analytically.\n", "\n", "So we approximate it numerically with quadrature (a Riemann sum over a grid of ability values), and this turns out to be fairly expensive. Running the for loops in plain R is painfully slow.\n", "\n", "The usual remedy is to borrow the speed of C++ or a similar language, but this time I want to try an easy speed-up with Julia. I don't yet understand how to integrate R and Julia, so this time I'll use Julia alone." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Definitions of the parameters and functions used for plotting." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "using Distributions, Random, StatsFuns, Plots" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "icc2pl (generic function with 1 method)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# two-parameter logistic (2PL) item characteristic curve:\n", "# probability of a correct response given ability θ, discrimination α and difficulty β\n", "function icc2pl(θ::Float64, α::Float64, β::Float64)::Float64\n", " x::Float64 = α * (θ - β)\n", " p::Float64 = StatsFuns.logistic(x)\n", " p\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There is probably a smarter way to write this, but for now I'll just loop naively with `for`." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# evaluate the 2PL curve with α = 1, β = 0 on a grid of θ values\n", "theta = collect(-4:0.1:4)\n", "icc = zeros(length(theta))\n", "for i in 
1:length(theta)\n", " icc[i] = icc2pl(theta[i], 1.0, 0.0)\n", "end\n", "plot(theta, icc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We consider a model in which the probability of answering an item correctly increases the further right you go on the x-axis (ability θ): P(θ) = 1 / (1 + exp(-α(θ - β))). Here α represents how well the item discriminates between ability levels, and β represents the difficulty of the item itself. For the details, please look up item response theory; I'll skip them here." ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "resgen_bin (generic function with 1 method)" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# simulated responses together with the true parameters used to generate them\n", "struct simgenClass2\n", " resp::Array{Int64,2}\n", " θ::Array{Float64,1}\n", " α::Array{Float64,1}\n", " β::Array{Float64,1}\n", "end\n", "\n", "# simulate binary (0/1) responses of N examinees to J items under the 2PL model\n", "function resgen_bin(N::Int, J::Int)\n", " θ = rand(Normal(), N) # abilities\n", " α = rand(LogNormal(), J) # discriminations\n", " β = rand(Normal(-1, 1), J) # difficulties\n", " resp = zeros(Int64, length(θ), length(α))\n", " for i in 1:length(θ)\n", " for j in 1:length(α)\n", " # correct (1) with probability icc2pl(θ[i], α[j], β[j]), otherwise 0\n", " resp[i, j] = ifelse(rand() > icc2pl(θ[i], α[j], β[j]), 0, 1)\n", " end\n", " end\n", " simgenClass2(resp, θ, α, β)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate simulated item response data. We generate responses from 50,000 examinees, and this step alone was already so fast that it blew me away." ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 0.058579 seconds (2.76 k allocations: 19.498 MiB)\n" ] }, { "data": { "text/plain": [ "simgenClass2([1 0 … 1 1; 1 1 … 1 0; … ; 1 0 … 1 1; 1 1 … 1 1], [0.16099, -0.678808, 0.457395, 0.453073, 0.310874, -0.822951, 0.271002, -0.147871, -0.999116, -0.363523 … -0.105797, -0.830859, -1.21541, 0.845631, 1.39459, 1.44449, -0.388785, -0.618865, -0.943425, -1.17293], [0.54494, 1.71906, 1.12542, 0.595063, 1.03782, 0.489053, 2.05296, 0.37631, 2.78794, 1.79084 … 2.1822, 0.694498, 1.18709, 0.946057, 6.65905, 3.11907, 0.286872, 1.02368, 1.53182, 0.488699], [-3.1403, -0.49488, 0.362954, -1.3816, -0.371558, -2.09085, -0.268813, -0.417073, -0.937248, 0.857298 … -1.16759, -2.57849, -0.454556, -1.75791, -1.05828, 0.40001, -0.331622, 
-1.17879, -1.81634, -1.33322])" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "N = 50_000\n", "J = 50\n", "@time resgen_bin(N, J)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A function that numerically approximates the marginalization of the likelihood (the E-step). For each examinee i and quadrature node xq[m], it computes the likelihood of the response pattern, L[i, m], multiplies it by the normalized standard-normal weight aq[m], and finally normalizes over the nodes to obtain the posterior weights Gim." ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Estep (generic function with 1 method)" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "struct EstepClass2\n", " resp::Array{Int64,2}\n", " θ::Array{Float64,1}\n", " α::Array{Float64,1}\n", " β::Array{Float64,1}\n", " Gim::Array{Float64,2} # posterior weight of node m for examinee i\n", " Lim::Array{Float64,2} # likelihood of examinee i's responses at node m\n", "end\n", "\n", "function Estep(N::Int, J::Int, M::Int)\n", " # M equally spaced quadrature nodes on [-4, 4]\n", " xq = collect(range(-4.0, stop = 4.0, length = M))\n", " # standard-normal prior weights, normalized to sum to 1\n", " aq = pdf.(Normal(), xq) ./ sum(pdf.(Normal(), xq))\n", " L = zeros(N, M)\n", " Gim = zeros(N, M)\n", " # generate simulated data\n", " sim = resgen_bin(N, J)\n", " x = sim.resp\n", " θ = sim.θ\n", " α = sim.α\n", " β = sim.β\n", " for m in 1:M\n", " println(\"NOW...\", m)\n", " for i in 1:N\n", " Li = zeros(J)\n", " for j in 1:J\n", " Li[j] = ifelse(x[i, j] == 1, icc2pl(xq[m], α[j], β[j]), 1 - icc2pl(xq[m], α[j], β[j]))\n", " end\n", " L[i, m] = prod(Li)\n", " Gim[i, m] = L[i, m] * aq[m]\n", " end\n", " end\n", " # normalize each row over the nodes so that sum(Gim[i, :]) == 1\n", " Gim ./= sum(Gim, dims = 2)\n", " EstepClass2(x, θ, α, β, Gim, L)\n", "end" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NOW...1\n", "NOW...2\n", "NOW...3\n", "NOW...4\n", "NOW...5\n", "NOW...6\n", "NOW...7\n", "NOW...8\n", "NOW...9\n", "NOW...10\n", "NOW...11\n", "NOW...12\n", "NOW...13\n", "NOW...14\n", "NOW...15\n", "NOW...16\n", "NOW...17\n", "NOW...18\n", "NOW...19\n", "NOW...20\n", "NOW...21\n", "NOW...22\n", "NOW...23\n", "NOW...24\n", "NOW...25\n", "NOW...26\n", "NOW...27\n", "NOW...28\n", "NOW...29\n", 
"NOW...30\n", " 2.762779 seconds (1.63 M allocations: 758.630 MiB, 5.04% gc time)\n" ] }, { "data": { "text/plain": [ "EstepClass2([1 1 … 1 1; 1 1 … 1 1; … ; 1 1 … 1 1; 1 1 … 1 1], [-0.37216, 0.394744, 1.0254, 0.0571423, 0.80871, 1.02205, -0.383376, 0.114626, -1.59855, 0.249076 … -0.107049, -0.35597, -0.0211458, 0.996935, -0.308854, 0.345422, 0.31834, 0.709949, -2.01223, 0.0424107], [1.48159, 0.932297, 0.811575, 1.59753, 0.855244, 4.2433, 4.43223, 0.439847, 0.461096, 0.905005 … 0.844219, 0.930126, 1.04635, 0.751366, 2.10459, 13.2816, 0.122729, 0.491111, 0.493516, 1.50265], [-1.54179, -1.3377, -1.79193, -1.00404, -1.69861, -0.991398, -1.41818, -1.01052, -0.331266, -1.77651 … -1.16776, -3.63493, -1.86877, -0.332499, -0.855854, -0.140566, -2.58158, -2.59735, -1.43972, -2.08783], [2.99803e-119 4.0408e-110 … 6.50739e-14 0.0; 1.32238e-122 3.53826e-113 … 1.24229e-8 0.0; … ; 2.26314e-115 1.33017e-106 … 1.73261e-20 0.0; 2.79725e-57 3.96453e-52 … 0.0 0.0], [8.40037e-115 4.0376e-106 … 6.50224e-10 0.0; 3.70526e-118 3.53546e-109 … 0.00012413 0.0; … ; 6.34123e-111 1.32912e-102 … 1.73124e-16 0.0; 7.8378e-53 3.96139e-48 … 0.0 0.0])" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@time Estep(50000, 50, 30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, the same computation without wrapping it in a function. It's a secret that I ran this version first and thought Julia was nothing special." ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NOW...1\n", "NOW...2\n", "NOW...3\n", "NOW...4\n", "NOW...5\n", "NOW...6\n", "NOW...7\n", "NOW...8\n", "NOW...9\n", "NOW...10\n", "NOW...11\n", "NOW...12\n", "NOW...13\n", "NOW...14\n", "NOW...15\n", "NOW...16\n", "NOW...17\n", "NOW...18\n", "NOW...19\n", "NOW...20\n", "NOW...21\n", "NOW...22\n", "NOW...23\n", "NOW...24\n", "NOW...25\n", "NOW...26\n", "NOW...27\n", "NOW...28\n", "NOW...29\n", "NOW...30\n", " 23.034723 seconds (613.47 M allocations: 10.975 GiB, 4.70% gc 
time)\n" ] } ], "source": [ "# the same E-step computation, run at global scope (no function)\n", "N = 50_000\n", "J = 50\n", "M = 30\n", "xq = collect(range(-4.0, stop = 4.0, length = M))\n", "aq = pdf.(Normal(), xq) ./ sum(pdf.(Normal(), xq))\n", "L = zeros(N, M)\n", "Gim = zeros(N, M)\n", "# generate simulated data\n", "sim = resgen_bin(N, J)\n", "x = sim.resp\n", "α = sim.α\n", "β = sim.β\n", "@time begin\n", " for m in 1:M\n", " println(\"NOW...\", m)\n", " for i in 1:N\n", " Li = zeros(J)\n", " for j in 1:J\n", " Li[j] = ifelse(x[i, j] == 1, icc2pl(xq[m], α[j], β[j]), 1 - icc2pl(xq[m], α[j], β[j]))\n", " end\n", " L[i, m] = prod(Li)\n", " Gim[i, m] = L[i, m] * aq[m]\n", " end\n", " end\n", " # normalize each row over the nodes\n", " Gim ./= sum(Gim, dims = 2)\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In short: when you run lots of for loops in Julia, make a habit of putting them inside a function. At global scope, variables such as x, α and β are untyped globals whose types could change at any time, so the compiler cannot generate specialized code for the loop; inside a function, the loop is compiled for the concrete argument types. Here that was the difference between about 23 s (613 M allocations, 11 GiB) and about 2.8 s, even though the function version's timing also includes generating the data." ] } ], "metadata": { "kernelspec": { "display_name": "Julia 1.1.0", "language": "julia", "name": "julia-1.1" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.1.0" } }, "nbformat": 4, "nbformat_minor": 2 }