{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ソフトマックス関数を用いた混合ガウスモデルの計算\n", "> \"Gaussian Mixture Model with softmax function\"\n", "\n", "- toc: true\n", "- categories: [MachineLearning, SignalProcessing, Python]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## はじめに\n", "\n", "多峰性のある分布表現する方法として,ガウス分布の線形結合を用いる混合ガウスモデルが広く知られています.\n", "しかしがなら,観測したサンプルから混合ガウスモデルのパラメータを解析的の求めることはできません.\n", "そのため,EMアルゴリズム等を用いて数値計算的にパラメータを推定します.\n", "\n", "EMアルゴリズムを用いた混合ガウスモデルのパラメータ推定は多くの資料で解説されています.\n", "これらの資料を読み解こうと何度か挑戦しましたが,どうもしっくりと来ませんでした.\n", "ですが,参考[1]述べられている,線形結合の重みとしてソフトマックス関数を用いる方法がとても理解しやすかったので,自分なりの理解をまとめました." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## やったこと\n", "\n", "* 線形結合の重みとしてソフトマックス関数を用いた場合における各パラメータの導関数を計算\n", "* 勾配法によって混合ガウスモデルのパラメータを推定\n", "* EMアルゴリズムによって混合ガウスモデルのパラメータを推定" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 混合ガウスモデルについて\n", "平均を$\\mu$,標準偏差を$\\sigma$としたとき,ガウス分布$\\phi\\left(x; \\mu, \\sigma\\right)$は以下の様に表されます.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\phi\\left(x; \\mu, \\sigma\\right) & = \\frac{1}{\\sqrt{2 \\pi \\sigma^{2}}} \\exp{\\left(- \\frac{\\left(x - \\mu\\right)^{2}}{2\\sigma^{2}}\\right)}\n", "\\end{aligned}\n", "$$\n", "\n", "この時,ガウス分布の線形結合からなる混合ガウスモデル$q\\left(x; \\vec{\\theta}\\right)$は以下の様に表すことができます.\n", "\n", "\n", "$$\n", "\\begin{aligned}\n", "q\\left(x; \\vec{\\theta}\\right) &= \\sum_{l=1}^{m} w_{l}\\phi\\left(x; \\mu_{l}, \\sigma_{l}\\right)\n", "\\end{aligned}\n", "$$\n", "\n", "ここで,$\\vec{\\theta} = \\left(w_{1}, \\dots, w_{m},\\mu_{1}, \\dots, \\mu_{m}, \\sigma_{1}, \\dots, \\sigma_{m}\\right)$であり,$w_{l}$は線形結合の重みを,$\\mu_{l}$は各ガウス分布の平均を,$\\sigma_{l}$は各ガウス分布の標準偏差をそれぞれ表します.\n", "\n", "$q\\left(x; \\vec{\\theta}\\right)$ が分布であるためには,$x$に対して\n", "\n", "$$\n", "\\begin{aligned}\n", "q\\left(x; \\vec{\\theta}\\right) &\\geq 0 \\\\\n", "\\int q\\left(x; \\vec{\\theta}\\right) &= 1\n", "\\end{aligned}\n", "$$\n", "\n", "である必要があります.従って,$w_{1}, \\dots, 
w_{m}$は以下の条件を満たす必要があります.\n", "\n", "$$\n", "\\begin{aligned}\n", "w_{1} \\dots w_{m} &\\geq 0 \\\\\n", "\\sum_{l=1}^{m} w_{l} &= 1\n", "\\end{aligned}\n", "$$\n", "\n", "次に最尤推定によってパラメータ$\\vec{\\theta}$を推定します.ここでは,尤度関数$L\\left(\\vec{\\theta}\\right)$を以下の様に定義します.\n", "\n", "$$\n", "\\begin{aligned}\n", "L\\left(\\vec{\\theta}\\right) &= \\prod _{i=1}^{n} q\\left(x_{i}; \\vec{\\theta}\\right)\n", "\\end{aligned}\n", "$$\n", "\n", "また,$w_{l}$に対する拘束条件を考慮しなくてはなりません.従って,最尤推定量$\\vec{\\theta}$は以下の様に表されます.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\hat{\\vec{\\theta}} &= \\argmax_{\\vec{\\theta}} L\\left(\\vec{\\theta}\\right) & \\text{subject to } \\begin{cases}\n", " w_{1} \\dots w_{m} &\\geq 0\\\\\n", "\\sum_{l=1}^{m} w_{l} &= 1\n", " \\end{cases}\n", "\\end{aligned}\n", "$$\n", "\n", "\n", "拘束条件付きの最適化にはラグランジュの未定乗数法が多く用いられていると思います.今回のケースにおいても同様です.\n", "\n", "しかしながら,ラグランジュの未定乗数法を用いる方法はあまりしっくり来きませんでした.(理解できなかった)\n", "そこで,ここでは参考[1]にある$w_{l}$としてソフトマック関数を用いる方法を採用します.従って$w_{l}$は$\\gamma_{l}$を用いて\n", "\n", "$$\n", "\\begin{aligned}\n", "w_{l} &= \\frac{\\exp{\\left(\\gamma_{l}\\right)}}{\\sum_{l' = 1}^{m}\\exp{\\left(\\gamma_{l’}\\right)}}\n", "\\end{aligned}\n", "$$\n", "\n", "と表現することにします.ソフトマックス関数は上述の拘束条件を満たすため,拘束条件を考えることなく最適化を行うことができます.また,ここでは1次元の混合ガウスモデルについて述べましたが,多次元の場合も同様です." 
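, "\n",
"\n",
"As a quick sanity check (my own note, using the notation above), both constraints follow directly from the definition of the softmax function:\n",
"\n",
"$$\n",
"\begin{aligned}\n",
"w_{l} &= \frac{\exp{\left(\gamma_{l}\right)}}{\sum_{l'=1}^{m}\exp{\left(\gamma_{l'}\right)}} \geq 0 \quad \text{since } \exp{\left(\cdot\right)} > 0, \\\n",
"\sum_{l=1}^{m} w_{l} &= \frac{\sum_{l=1}^{m}\exp{\left(\gamma_{l}\right)}}{\sum_{l'=1}^{m}\exp{\left(\gamma_{l'}\right)}} = 1\n",
"\end{aligned}\n",
"$$"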
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 勾配法によるパラメータの推定\n", "\n", "\n", "それでは,対数尤度関数$\\log L\\left(\\vec{\\theta}\\right)$の各パラメータに対する導関数を計算していきます.まずは,$\\gamma_{l}$についてです.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial w_{k}}{\\partial \\gamma_{l}} &= \\frac{\\partial}{\\partial \\gamma_{l}} \\frac{\\exp{\\left(\\gamma_{k}\\right)}}{\\sum_{k'=1}^{m}\\exp{\\left(\\gamma_{k'}\\right)}} \\\\\n", "&= \\frac{\\exp{\\left(\\gamma_{k}\\right)}}{\\sum_{k'=1}^{m}\\exp{\\left(\\gamma_{k'}\\right)}}\\delta_{k,l} - \\frac{\\exp{\\left(\\gamma_{k}\\right)}\\exp{\\left(\\gamma_{l}\\right)}}{\\left(\\sum_{k'=1}^{m}\\exp{\\left(\\gamma_{k'}\\right)}\\right)^{2}}\\\\\n", "&= \\frac{\\exp{\\left(\\gamma_{k}\\right)}}{\\sum_{k'=1}^{m}\\exp{\\left(\\gamma_{k'}\\right)}}\\left(\\delta_{k,l} - \\frac{\\exp{\\left(\\gamma_{l}\\right)}}{\\sum_{k'=1}^{m}\\exp{\\left(\\gamma_{k'}\\right)}}\\right)\\\\\n", "&= w_{k}\\left(\\delta_{k,l} - w_{l}\\right)\\\\\n", "\\frac{\\partial}{\\partial \\gamma_{l}} \\log L\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\gamma_{l}} \\sum_{i=1}^{n} \\log \\sum_{k}^{m} w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\\\\n", " &= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\sum_{k=1}^{m} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\frac{\\partial w_{k}}{\\partial \\gamma_{l}}\\\\\n", " &= \\sum_{i=1}^{n} \\frac{\\sum_{k=1}^{m} w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\delta_{k,l} -w_{l}\\sum_{k=1}^{m} w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\\\\n", " &= \\sum_{i=1}^{n} \\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} - w_{l} \\\\\n", "\\end{aligned}\n", "$$\n", "\n", "次に,$\\mu_{l}$についてです.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial 
\\mu_{l}}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) &= \\frac{1}{\\sqrt{2 \\pi \\sigma_{k}^{2}}}\\frac{\\partial}{\\partial \\mu_{l}} \\exp{\\left(- \\frac{\\left(x_{i} - \\mu_{k}\\right)^{2}}{2\\sigma_{k}^{2}}\\right)} \\\\\n", "&= \\frac{1}{\\sqrt{2 \\pi \\sigma_{k}^{2}}} \\exp{\\left(- \\frac{\\left(x_{i} - \\mu_{k}\\right)^{2}}{2\\sigma_{k}^{2}}\\right)} \\frac{x_{i} - \\mu_{k}}{\\sigma_{k}^{2}} \\delta_{k,l} \\\\\n", "&= \\frac{x_{i} - \\mu_{k}}{\\sigma_{k}^{2}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\delta_{k,l}\\\\\n", "\\frac{\\partial}{\\partial \\mu_{l}} \\log L\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\mu_{l}} \\sum_{i=1}^{n} \\log \\sum_{k}^{m} w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\\\\n", "&= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\sum_{k=1}^{m} w_{k} \\frac{\\partial}{\\partial \\mu_{l}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\\\\n", "&= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\sum_{k=1}^{m} w_{k} \\frac{x_{i} - \\mu_{k}}{\\sigma_{k}^{2}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\delta_{k,l}\\\\\n", "&= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} w_{l} \\frac{x_{i} - \\mu_{l}}{\\sigma_{l}^{2}} \\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)\\\\\n", "&= \\frac{1}{\\sigma_{l}^{2}} \\sum_{i=1}^{n} \\left(x_{i} - \\mu_{l}\\right) \\frac{w_{l} \\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right) }{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\\\\n", "\\end{aligned}\n", "$$\n", "\n", "最後に,$\\sigma_{l}$についてです.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial \\sigma_{l}}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) &= \\frac{1}{\\sqrt{2 \\pi}} \\frac{\\partial}{\\partial \\sigma_{l}} \\sigma_{k}^{-1} \\exp{\\left(- \\frac{\\left(x_{i} - 
\\mu_{k}\\right)^{2}}{2\\sigma_{k}^{2}}\\right)} \\\\\n", "&= -\\frac{1}{\\sqrt{2\\pi}\\sigma_{k}^{2}} \\exp{\\left(- \\frac{\\left(x - \\mu_{k}\\right)^{2}}{2\\sigma_{k}^{2}}\\right)} \\delta_{k,l} + \\frac{\\left(x_{i} - \\mu_{k}\\right)^{2}}{\\sqrt{2\\pi}\\sigma_{k}^{4}} \\exp{\\left(- \\frac{\\left(x - \\mu_{k}\\right)^{2}}{2\\sigma_{k}^{2}}\\right)} \\delta_{k,l}\\\\\n", "&= \\frac{\\left(x_{i} - \\mu_{k}\\right)^{2} - \\sigma_{k}^{2}}{\\sigma_{k}^{3}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\delta_{k,l}\\\\\n", "\\frac{\\partial}{\\partial \\sigma_{l}} \\log L\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\sigma_{l}} \\sum_{i=1}^{n} \\log \\sum_{k}^{m} w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\\\\n", "&= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\sum_{k=1}^{m} w_{k} \\frac{\\partial}{\\partial \\sigma_{l}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\\\\n", "&= \\sum_{i=1}^{n}\\frac{1}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \\sum_{k=1}^{m} w_{k} \\frac{\\left(x_{i} - \\mu_{k}\\right)^{2} - \\sigma_{k}^{2}}{\\sigma_{k}^{3}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\delta_{k,l}\\\\\n", "&= \\frac{1}{\\sigma_{l}^{3}}\\sum_{i=1}^{n} \\left(\\left(x_{i} - \\mu_{l}\\right)^{2} - \\sigma_{l}^{2}\\right) \\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)} \n", "\\end{aligned}\n", "$$\n", "\n", "ここで,\n", "\n", "$$\n", "\\begin{aligned}\n", "\\eta_{i,l} &= \\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{\\sum_{k'=1}^{m} w_{k'}\\phi\\left(x_{i}; \\mu_{k'}, \\sigma_{k'}\\right)}\n", "\\end{aligned}\n", "$$\n", "\n", "と置きます.すると最終的な導関数は以下の様に表されます.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial \\gamma_{l}} \\log L\\left(\\vec{\\theta}\\right) &= \\sum_{i=1}^{n} \\eta_{i,l} - w_{l} \\\\\n", 
"\\frac{\\partial}{\\partial \\mu_{l}} \\log L\\left(\\vec{\\theta}\\right) &= \\frac{1}{\\sigma_{l}^{2}} \\sum_{i=1}^{n} \\left(x_{i} - \\mu_{l}\\right) \\eta_{i,l}\\\\\n", "\\frac{\\partial}{\\partial \\sigma_{l}} \\log L\\left(\\vec{\\theta}\\right) &=\\frac{1}{\\sigma_{l}^{3}}\\sum_{i=1}^{n} \\left(\\left(x_{i} - \\mu_{l}\\right)^{2} - \\sigma_{l}^{2}\\right)\\eta_{i,l}\n", "\\end{aligned}\n", "$$\n", "\n", "\n", "次に勾配法を用いてパラメータの推定を行っていきます.具体的な手順は以下の通りです.\n", "\n", "1. $\\hat{\\gamma_{l}}$, $\\hat{\\mu_{l}}$, $\\hat{\\sigma_{l}}$をランダムな値で初期化\n", "2. 以下に示す様にパラメータを更新 \n", "$$\n", "\\begin{aligned}\n", "\\hat{\\gamma_{l}} &= \\hat{\\gamma_{l}} + \\epsilon \\frac{\\partial}{\\partial \\gamma_{l}} \\log L\\left(\\vec{\\theta}\\right) \\\\\n", "\\hat{\\mu_{l}} &= \\hat{\\mu_{l}} + \\epsilon \\frac{\\partial}{\\partial \\mu_{l}} \\log L\\left(\\vec{\\theta}\\right) \\\\\n", "\\hat{\\sigma_{l}} &= \\hat{\\sigma_{l}} + \\epsilon \\frac{\\partial}{\\partial \\sigma_{l}} \\log L\\left(\\vec{\\theta}\\right) \\\\\n", "\\end{aligned}\n", "$$\n", "3. 一定回数または収束するまで2.を繰り返す\n", "\n", "\n", "\n", "実際にパラメータ推定に関するコードを記述する前に,推定を行うデータを作成します.\n", "今回,$\\vec{\\mu} = \\left(3.1, -1.3, 0.5\\right)$, $\\vec{\\sigma} = \\left(0.8, 1.0, 0.5 \\right)$, $\\vec{w} = \\left(0.25, 0.25, 0.5\\right)$として,2048個の標本をサンプリングすることで検証用のデータを作成しました." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "lines_to_next_cell": 0 }, "outputs": [], "source": [ "import numpy as np\n", "np.random.seed(99)\n", "\n", "mus = np.array([3.1, -1.3, 0.5])\n", "sigmas = np.array([0.8, 1.0, 0.5])\n", "ns = np.array([512, 512, 1024])\n", "N = np.sum(ns)\n", "ws = ns / N\n", "\n", "data = [np.random.normal(m, s, n) for n, m, s in zip(ns, mus, sigmas)]\n", "all_data = np.hstack(data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "勾配法によるパラメータ推定を行います.`train_grad` が上述したパラメータ推定の根幹になります.\n", "ここでは,$\\epsilon = 0.0001$とし,パラメータの変化量の絶対値和が一定値以下となる場合を収束と判定しています." 
] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 9.995055674503072e-08, iterations: 16943\n" ] } ], "source": [ "def phi(xs, mus, sigmas):\n", " num = np.exp(-0.5 * ((xs - mus[:,np.newaxis]) / sigmas[:,np.newaxis]) ** 2.0)\n", " den = np.sqrt(2 * np.pi) * sigmas[:,np.newaxis]\n", " return num / den\n", "\n", "def eta(xs, ws, mus, sigmas):\n", " num = ws[:, np.newaxis] * phi(xs, mus, sigmas)\n", " den = np.sum(num, axis=0)\n", " return num / den\n", "\n", "def softmax(gammas):\n", " num = np.exp(gammas)\n", " den = np.sum(num, axis=0)\n", " return num / den\n", "\n", "def grad_gamma(etas, ws):\n", " return np.sum(etas, axis=1) - etas.shape[1] * ws\n", "\n", "def grad_mu(xs, etas, mus, sigmas):\n", " num = np.sum((xs - mus[:, np.newaxis]) * etas, axis=1)\n", " den =sigmas ** 2.0\n", " return num / den\n", "\n", "def grad_sigma(xs, etas, mus, sigmas):\n", " num = np.sum((((xs - mus[:, np.newaxis]) / sigmas[:,np.newaxis]) ** 2.0 - 1.0) * etas, axis=1)\n", " den = sigmas \n", " return num / den\n", "\n", "def train_grad(xs):\n", " mus_hat = 2 * np.max(np.abs(xs)) * (np.random.rand(3) - 0.5)\n", " sigmas_hat = np.random.rand(3)\n", " gammas_hat = np.random.rand(3)\n", "\n", " eps = 0.0001\n", " loss = float('inf')\n", " criteria = 1e-10\n", " iter = 0;\n", " while iter < 30000:\n", " ws_hat = softmax(gammas_hat)\n", " etas_hat = eta(xs, ws_hat, mus_hat, sigmas_hat)\n", "\n", " delta_gammas = eps * grad_gamma(etas_hat, ws_hat)\n", " delta_mus = eps * grad_mu(xs, etas_hat, mus_hat, sigmas_hat)\n", " delta_sigmas = eps * grad_sigma(xs, etas_hat, mus_hat, sigmas_hat)\n", "\n", " gammas_hat += delta_gammas\n", " mus_hat += delta_mus\n", " sigmas_hat += delta_sigmas\n", " loss = np.sum(np.abs(np.hstack([delta_gammas, delta_mus, delta_sigmas])))\n", " if loss < criteria:\n", " break\n", " iter += 1\n", "\n", " ws_hat = softmax(gammas_hat)\n", " return loss, iter, (ws_hat, mus_hat, 
sigmas_hat)\n", "\n", "loss, iter, (ws_hat_grad, mus_hat_grad, sigmas_hat_grad) = train_grad(all_data)\n", "print(f\"loss: {loss}, iterations: {iter}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "最後に推定したパラメータ`ws_hat_grad`, `mus_hat_grad`, `sigmas_hat_grad` を用いた分布と,訓練データとをプロットして比較します.\n", "勾配法による推定は,パラメータの並びについて曖昧さがあります.これは,$\\hat{w_{l}}$と$w_{l}$とが必ずしも対応するとは限らないことを意味します.(他のパラメータも同様)\n", "\n", "そこで,プロットに先立ち`ws_hat_grad`, `mus_hat_grad`, `sigmas_hat_grad`と`ws`, `mus`, `sigmas`とを比較して,並び順を入れ変えることでこの問題に対応します.\n", "この処理は`permute` にて行われます." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import altair as alt\n", "import pandas as pd\n", "\n", "def permute(ws, mus, sigmas, ws_hat, mus_hat, sigmas_hat):\n", " preds = np.vstack([ws_hat, mus_hat, sigmas_hat])\n", " acts = np.vstack([ws, mus, sigmas])\n", " corr = preds.T @ acts / (np.linalg.norm(preds, axis=0)[:, np.newaxis] @ np.linalg.norm(acts, axis=0)[np.newaxis,:])\n", " perm = np.argmax(corr, axis=0)\n", " ws_hat = ws_hat[perm]\n", " mus_hat = mus_hat[perm]\n", " sigmas_hat = sigmas_hat[perm]\n", " return (ws_hat, mus_hat, sigmas_hat)\n", "\n", "def plot(data, ws_hat, mus_hat, sigmas_hat):\n", " bins = np.linspace(-6, 6, 128)\n", " hists = {f\"Class {i}\": (128 / (12 * N)) * np.histogram(d, bins=bins)[0] for i, d in enumerate(data)}\n", "\n", " bin_centers = (bins[:-1] + bins[1:]) / 2\n", " bars = alt.Chart(pd.DataFrame({\n", " \"Bin\": bin_centers,\n", " **hists,\n", " })).transform_fold(\n", " fold=[f\"Class {i}\" for i in range(len(data))],\n", " as_=[\"Class\", \"Probability\"]\n", " ).mark_bar(opacity=0.5).encode(\n", " alt.X(\"Bin:Q\"),\n", " alt.Y(\"Probability:Q\"),\n", " alt.Color('Class:N')\n", " )\n", "\n", " envelopes = {f\"Class {i}\": v for i, v in enumerate(ws_hat[:, np.newaxis] * phi(bin_centers, mus_hat, sigmas_hat))}\n", " lines = alt.Chart(\n", " pd.DataFrame({\n", " \"Bin\": bin_centers,\n", " **envelopes,\n", " })).transform_fold(\n", " fold=[f\"Class {i}\" for i in range(len(data))],\n", " as_=[\"Class\", \"Probability\"]\n", " ).mark_line(size=3).encode(\n", " alt.X(\"Bin:Q\"),\n", " alt.Y(\"Probability:Q\"),\n", " alt.Color('Class:N')\n", " )\n", "\n", " return bars + lines\n", "\n", "ws_hat_grad, mus_hat_grad, sigmas_hat_grad = permute(ws, mus, sigmas, ws_hat_grad, mus_hat_grad, sigmas_hat_grad)\n", "plot(data, ws_hat_grad, mus_hat_grad, sigmas_hat_grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"以上の結果より,適切に推定できていることが確認できました.しかしながら,勾配法による方法は\n", "\n", "* 初期値によっては推定に失敗する場合がある\n", "* $\\epsilon$の選択によっては推定に失敗する場合がある\n", "* 収束が遅い\n", "\n", "という問題があります.これらの問題に対応するため,EMアルゴリズムによるパラメータ推定を以下に述べます." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EMアルゴリズムによるパラメータの推定\n", "次にEMアルゴリズムによるパラメータの推定を行います.EMアルゴリズムは混合ガウスモデルのパラメータ推定以外にも利用可能な汎用的なアルゴリズムです.しかしながら,ここではあまり深入りせずに混合ガウスモデルのパラメータ推定方法の一つとして取り扱います.\n", "\n", "EMアルゴリズムはEステップとMステップを交互に繰り返すことで実現します.両ステップについての詳細を以下に述べます.\n", "\n", "\n", "### Eステップ\n", "Eステップでは,対数尤度関数$\\log L\\left(\\vec{\\theta}\\right)$に対して,現在の推定値$\\hat{\\vec{\\theta}}$で接する下界$b\\left(\\vec{\\theta}\\right)$を求めます.すなわち,$b\\left(\\vec{\\theta}\\right)$は以下の関係を満たす必要があります.\n", "\n", "$$\n", "\\begin{aligned}\n", "b\\left(\\vec{\\theta}\\right) &\\leq \\log L\\left(\\vec{\\theta}\\right) \\\\\n", "b\\left(\\hat{\\vec{\\theta}}\\right) &= \\log L\\left(\\hat{\\vec{\\theta}}\\right) \\\\\n", "\\end{aligned}\n", "$$\n", "\n", "下界$b\\left(\\vec{\\theta}\\right)$を求めると聞くと非常に複雑そうに思えます.しかしながら,混合ガウスモデルのパラメータ推定という用途に限定するとそこまで複雑でもありません.まず,対数尤度関数$\\log L\\left(\\vec{\\theta}\\right)$に対して[イェンセンの不等式](https://ja.wikipedia.org/wiki/%E3%82%A4%E3%82%A7%E3%83%B3%E3%82%BB%E3%83%B3%E3%81%AE%E4%B8%8D%E7%AD%89%E5%BC%8F)を適用し,下界$b'\\left(\\vec{\\theta}; \\bf{A}\\right)$を求めていきます.ここで,$\\bf{A} = \\left\\{a_{i,l}\\right\\}$であり,$b\\left(\\vec{\\theta}\\right) = b'\\left(\\vec{\\theta}; \\hat{\\bf{A}}\\right)$であるとします.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\log L\\left(\\vec{\\theta}\\right) &= \\sum_{i=1}^{n} \\log \\sum_{l}^{m} w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right) \\\\\n", "&= \\sum_{i=1}^{n} \\log \\sum_{l}^{m} a_{i,l}\\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{a_{i,l}} \\\\\n", "&\\geq \\sum_{i=1}^{n} \\sum_{l}^{m} a_{i,l} \\log\\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{a_{i,l}} \\\\\n", "&= b'\\left(\\vec{\\theta}; \\bf{A}\\right)\n", "\\end{aligned}\n", "$$\n", "\n", "次に,$b'\\left(\\hat{\\vec{\\theta}}; 
\\hat{\\bf{A}}\\right) = \\log L\\left(\\hat{\\vec{\\theta}}\\right)$となる$\\hat{\\bf{A}}$を求めます.\n", "$$\n", "\\begin{aligned}\n", "b'\\left(\\hat{\\vec{\\theta}}; \\bf{A}\\right) &= \\sum_{i=1}^{n} \\sum_{l}^{m} a_{i,l} \\log\\frac{\\hat{w_{l}}\\phi\\left(x_{i}; \\hat{\\mu_{l}}, \\hat{\\sigma_{l}}\\right)}{a_{i,l}} \\\\\n", "& \\text{ここで$\\hat{\\vec{\\theta}}$に対する$\\eta_{i,l} $である$\\hat{\\eta}_{i,l}$を考えます} \\\\\n", "\\hat{\\eta}_{i,l} &= \\frac{\\hat{w}_{l}\\phi\\left(x_{i}; \\hat{\\mu}_{l}, \\hat{\\sigma}_{l}\\right)}{\\sum_{k'=1}^{m} \\hat{w}_{k'}\\phi\\left(x_{i}; \\hat{\\mu}_{k'}, \\hat{\\sigma}_{k'}\\right)} \\\\\n", "& \\text{そして$a_{i,l} = \\hat{\\eta}_{i,l}$とすると} \\\\\n", "b'\\left(\\hat{\\vec{\\theta}}; \\bf{A}\\right) &= \\sum_{i=1}^{n} \\sum_{l}^{m} \\hat{\\eta}_{i,l} \\log\\frac{\\hat{w_{l}}\\phi\\left(x_{i}; \\hat{\\mu_{l}}, \\hat{\\sigma_{l}}\\right)}{\\hat{\\eta}_{i,l}} \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{l}^{m} \\hat{\\eta}_{i,l} \\log \\sum_{l'=1}^{m} \\hat{w}_{l'}\\phi\\left(x_{i}; \\hat{\\mu}_{l'}, \\hat{\\sigma}_{l'}\\right) \\\\\n", "&= \\left(\\sum_{l}^{m} \\hat{\\eta}_{i,l}\\right)\\sum_{i=1}^{n} \\log \\sum_{l'=1}^{m} \\hat{w}_{l'}\\phi\\left(x_{i}; \\hat{\\mu}_{l'}, \\hat{\\sigma}_{l'}\\right) \\\\\n", "&= \\sum_{i=1}^{n} \\log \\sum_{l'=1}^{m} \\hat{w}_{l'}\\phi\\left(x_{i}; \\hat{\\mu}_{l'}, \\hat{\\sigma}_{l'}\\right) \\\\\n", "&= \\log L\\left(\\vec{\\theta}\\right)\n", "\\end{aligned}\n", "$$\n", "\n", "従って,$\\hat{\\bf{A}} = \\left\\{\\hat{\\eta}_{i,l}\\right\\}$とすると,\n", "\n", "$$\n", "\\begin{aligned}\n", "b\\left(\\vec{\\theta}\\right) &= b'\\left(\\vec{\\theta}; \\hat{\\bf{A}}\\right)\\\\\n", "&= \\sum_{i=1}^{n} \\sum_{l}^{m} \\hat{\\eta}_{i,l} \\log\\frac{w_{l}\\phi\\left(x_{i}; \\mu_{l}, \\sigma_{l}\\right)}{\\hat{\\eta}_{i,l}} \\\\\n", "\\end{aligned}\n", "$$\n", "\n", "以上の様に,下界$b\\left(\\vec{\\theta}\\right)$を求めることは,現在の推定値$\\hat{\\vec{\\theta}}$を用いて$\\hat{\\eta}_{i,l}$を求めることで実現します.\n", "\n", "\n", "### Mステップ\n", 
"Mステップでは,$b\\left(\\vec{\\theta}\\right)$が最大となる$\\hat{\\theta}'$を求めます.$\\hat{\\theta}'$は$b\\left(\\vec{\\theta}\\right)$の各パラメータに対する導関数が$0$となる$\\theta$を採用します.まずは, $\\gamma_{l}$についてです\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial \\gamma_{l}} b\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\gamma_{l}} \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\log\\frac{w_{k}\\phi\\left(x_{i}; \\mu_{k}, {\\sigma_{k}}\\right)}{\\hat{\\eta}_{i,k}} \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\frac{\\hat{\\eta}_{i,k}}{w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)} \\frac{\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)}{\\gamma_{l}} \\frac{\\partial w_{k}}{\\partial \\gamma_{l}} \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\frac{\\hat{\\eta}_{i,k}}{w_{k}} w_{k}\\left(\\delta_{k,l} - w_{l}\\right) \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k}\\delta_{k,l} - w_{l}\\hat{\\eta}_{i,k} \\\\\n", "&= \\sum_{i=1}^{n} \\hat{\\eta}_{i,l} - w_{l}\\left(\\sum_{k}^{m} \\hat{\\eta}_{i,k}\\right) \\\\\n", "&= \\sum_{i=1}^{n} \\hat{\\eta}_{i,l} - w_{l} \\\\\n", "& \\text{$\\frac{\\partial}{\\partial \\gamma_{l}} b\\left(\\hat{\\vec{\\theta}}\\right) = 0$より}\\\\\n", " \\sum_{i=1}^{n} \\hat{w}_{l} &= \\sum_{i=1}^{n} \\hat{\\eta}_{i,l} \\\\\n", " n\\hat{w}_{l} &= \\sum_{i=1}^{n} \\hat{\\eta}_{i,l} \\\\\n", " \\hat{w}_{l} &= \\frac{1}{n} \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\n", "\\end{aligned}\n", "$$\n", "\n", "次に,$\\mu_{l}$についてです.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial \\mu_{l}} b\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\mu_{l}} \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\log\\frac{w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)}{\\hat{\\eta}_{i,k}} \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\frac{\\hat{\\eta}_{i,k}}{w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)} \\frac{ w_{k}}{\\hat{\\eta}_{i,k} } \\frac{\\partial }{\\partial 
\\mu_{l}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\frac{\\hat{\\eta}_{i,k}}{\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)} \\frac{x_{i} - \\mu_{k}}{\\sigma_{k}^{2}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right) \\delta_{k,l}\\\\\n", "&= \\frac{1}{\\sigma_{l}^{2}} \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}x_{i} - \\hat{\\eta}_{i,l}\\mu_{l}\\\\\n", "& \\text{$\\frac{\\partial}{\\partial \\mu_{l}} b\\left(\\hat{\\vec{\\theta}}\\right) = 0$より}\\\\\n", "\\hat{\\mu}_{l} \\sum_{i=1}^{n}\\hat{\\eta}_{i,l}&= \\sum_{i=1}^{n}\\hat{\\eta}_{i,l}x_{i} \\\\\n", "\\hat{\\mu}_{l} &= \\frac{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}x_{i} }{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}\\\\\n", "\\end{aligned}\n", "$$\n", "\n", "最後に,$\\sigma_{l}$についてです.\n", "\n", "$$\n", "\\begin{aligned}\n", "\\frac{\\partial}{\\partial \\sigma_{l}} b\\left(\\vec{\\theta}\\right) &= \\frac{\\partial}{\\partial \\sigma_{l}} \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\log\\frac{w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)}{\\hat{\\eta}_{i,k}} \\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\hat{\\eta}_{i,k} \\frac{\\hat{\\eta}_{i,k}}{w_{k}\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)} \\frac{ w_{k}}{\\hat{\\eta}_{i,k} } \\frac{\\partial }{\\partial \\sigma_{l}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\\\\n", "&= \\sum_{i=1}^{n} \\sum_{k}^{m} \\frac{\\hat{\\eta}_{i,k}}{\\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)}\\frac{\\left(x_{i} - \\mu_{k}\\right)^{2} - \\sigma_{k}^{2}}{\\sigma_{k}^{3}} \\phi\\left(x_{i}; \\mu_{k}, \\sigma_{k}\\right)\\delta_{k,l}\\\\\n", "&= \\frac{1}{\\sigma_{l}^{3}} \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\left(x_{i} - \\mu_{l}\\right)^{2} - \\hat{\\eta}_{i,l}\\sigma_{l}^{2} \\\\\n", "& \\text{$\\frac{\\partial}{\\partial \\mu_{l}} b\\left(\\hat{\\vec{\\theta}}\\right) = 0$より}\\\\\n", "\\hat{\\sigma}_{l}^{2} \\sum_{i=1}^{n}\\hat{\\eta}_{i,l}&= \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\left(x_{i} - \\mu_{l}\\right)^{2} \\\\\n", 
"\\hat{\\sigma}_{l} &= \\sqrt{\\frac{\\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\left(x_{i} - \\mu_{l}\\right)^{2}}{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}}\\\\\n", "\\end{aligned}\n", "$$\n", "\n", "まとめると,Mステップでは以下の様にパラメータを更新します.\n", "\n", "$$\n", "\\begin{aligned}\n", " \\hat{w}_{l} &= \\frac{1}{n} \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\\\\n", "\\hat{\\mu}_{l} &= \\frac{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}x_{i} }{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}\\\\\n", "\\hat{\\sigma}_{l} &= \\sqrt{\\frac{\\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\left(x_{i} - \\mu_{l}\\right)^{2}}{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}}\\\\\n", "\\end{aligned}\n", "$$\n", "\n", "次にEMアルゴリズムを用いてパラメータの推定を行っていきます.具体的な手順は以下の通りです.\n", "\n", "1. $\\hat{\\gamma_{l}}$, $\\hat{\\mu_{l}}$, $\\hat{\\sigma_{l}}$をランダムな値で初期化\n", "2. $\\hat{\\eta_{i,l}}$を更新し,下界$b\\left(\\vec{\\theta}\\right)$を求める(Eステップ)\n", "$$\n", "\\begin{aligned}\n", "\\hat{\\eta_{i,l}} &= \\frac{\\hat{w}_{l}\\phi\\left(x_{i}; \\hat{\\mu}_{l}, \\hat{\\sigma}_{l}\\right)}{\\sum_{k'=1}^{m} \\hat{w}_{k'}\\phi\\left(x_{i}; \\hat{\\mu}_{k'}, \\hat{\\sigma}_{k'}\\right)} \n", "\\end{aligned}\n", "$$\n", "3. 下界$b\\left(\\vec{\\theta}\\right)$を最大化する$\\hat{\\vec{\\theta}}'$を求める(Mステップ)\n", "$$\n", "\\begin{aligned}\n", " \\hat{w}_{l} &= \\frac{1}{n} \\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\\\\n", "\\hat{\\mu}_{l} &= \\frac{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}x_{i} }{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}\\\\\n", "\\hat{\\sigma}_{l} &= \\sqrt{\\frac{\\sum_{i=1}^{n} \\hat{\\eta}_{i,l}\\left(x_{i} - \\mu_{l}\\right)^{2}}{\\sum_{i=1}^{n}\\hat{\\eta}_{i,l}}}\\\\\n", "\\end{aligned}\n", "$$\n", "4. 一定回数または収束するまで2〜3を繰り返す\n", "\n", "EMアルゴリズムによるパラメータ推定を行うコードは以下の通りです.勾配法の時と同様に`train_em` が上述したパラメータ推定の根幹になります.\n", "そして,全パラメータの変化量の絶対値和が一定値以下となる場合を収束と判定しています.また,入力については勾配法で用いたものを再度使用します." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 9.965148151103165e-11, iterations: 1254\n", "[0.47878854 0.27353509 0.24767637] [ 0.51716133 -1.10900049 3.16175044] [0.51084106 1.06776561 0.76372732]\n" ] } ], "source": [ "def opt_w(etas):\n", " return np.average(etas, axis=1)\n", "\n", "def opt_mu(xs, etas):\n", " num = np.sum(etas * xs, axis=1)\n", " den = np.sum(etas, axis=1)\n", " return num / den\n", "\n", "def opt_sigma(xs, etas, mus):\n", " num = np.sum((xs - mus[:, np.newaxis]) ** 2 * etas, axis=1)\n", " den = np.sum(etas, axis=1)\n", " return np.sqrt(num / den)\n", "\n", "def train_em(xs):\n", " mus_hat = 2 * np.max(np.abs(xs)) * (np.random.rand(3) - 0.5)\n", " sigmas_hat = np.random.rand(3)\n", " ws_hat = np.random.rand(3)\n", "\n", " loss = float('inf')\n", " criteria = 1e-10\n", " iter = 0;\n", " while iter < 10000:\n", " # E step\n", " etas_hat = eta(xs, ws_hat, mus_hat, sigmas_hat)\n", "\n", " # M Step\n", " delta_ws = ws_hat\n", " delta_mus = mus_hat\n", " delta_sigmas = sigmas_hat\n", " ws_hat = opt_w(etas_hat)\n", " mus_hat = opt_mu(xs, etas_hat)\n", " sigmas_hat = opt_sigma(xs, etas_hat, mus_hat)\n", " delta_ws -= ws_hat\n", " delta_mus -= mus_hat\n", " delta_sigmas -= sigmas_hat\n", "\n", " loss = np.sum(np.abs(np.hstack([delta_ws, delta_mus, delta_sigmas])))\n", " if loss < criteria:\n", " break\n", " iter += 1\n", "\n", " return loss, iter, (ws_hat, mus_hat, sigmas_hat)\n", "\n", "loss, iter, (ws_hat_em, mus_hat_em, sigmas_hat_em) = train_em(all_data)\n", "print(f\"loss: {loss}, iterations: {iter}\")\n", "print(ws_hat_em, mus_hat_em, sigmas_hat_em)" ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 0 }, "source": [ "勾配法と比較して,イテレーション回数が10分の1程度であることが確認できます.\n", "勾配法の時と同様に訓練データと推定したパラメータによる分布とをプロットして確認します." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "" ], "text/plain": [ "alt.LayerChart(...)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ws_hat_em, mus_hat_em, sigmas_hat_em = permute(ws, mus, sigmas, ws_hat_em, mus_hat_em, sigmas_hat_em)\n", "plot(data, ws_hat_em, mus_hat_em, sigmas_hat_em)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "以上の結果より,適切に推定できていることが確認できました.EMアルゴリズムの勾配法に対する利点としては,\n", "\n", "* $\\epsilon$を設定する必要がない\n", "* 勾配法よりも高速に収束する\n", "\n", "という点があげられます.しかしながら,\n", "\n", "* 初期値によっては推定に失敗する場合がある\n", "\n", "という問題は継続して存在します.したがって,実用を考えると複数の初期値でパラメータを推定するなどの工夫が必要になるかもしれません." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 参考文献\n", "1. 杉山将,統計的機械学習 : 生成モデルに基づくパターン認識,オーム社,2009" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "encoding": "# -*- coding: utf-8 -*-", "formats": "ipynb,py:hydrogen" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }