def gaussian_pdf(x, mu, var):
    """Evaluate an isotropic Gaussian density N(x | mu, var * I) at one point.

    Parameters
    ----------
    x : 1-D array
        Point at which the density is evaluated; its length sets the
        dimensionality used in the normalisation terms.
    mu : 1-D array
        Mean, same length as ``x``.
    var : scalar
        Shared per-dimension variance.

    Returns
    -------
    float
        ``exp`` of the log-density computed below.

    Notes
    -----
    A small constant is added to ``var`` in the denominator of the quadratic
    term only, and the log term uses ``abs(var)``, so a negative ``var`` does
    not produce NaN from the logarithm.
    """
    D = len(x)
    e = 1e-8  # keeps the quadratic term finite when var is exactly 0
    diff = x - mu
    ln_pdf = -0.5 * (np.dot(diff, diff) / (var + e)
                     + D * np.log(np.abs(var))
                     + D * np.log(2 * np.pi))
    return np.exp(ln_pdf)


def make_clutter_data(n_points, dim, clutter_var, signal_frac, signal_mean, seed=None):
    """Draw a two-component synthetic data set.

    The first ``int(signal_frac * n_points)`` rows are sampled from
    N(signal_mean, I); the remaining rows are sampled from
    N(0, clutter_var * I).

    Parameters
    ----------
    n_points : int
        Total number of rows.
    dim : int
        Dimensionality of each row.
    clutter_var : scalar
        Variance multiplier of the zero-mean component.
    signal_frac : float
        Fraction of rows drawn around ``signal_mean``.
    signal_mean : 1-D array of length ``dim``
        Mean of the first component.
    seed : int or None
        Seed for a private ``RandomState``; ``None`` gives a fresh random draw.

    Returns
    -------
    ndarray of shape (n_points, dim)
    """
    # A private generator keeps the draw reproducible without touching
    # numpy's global random state.
    rng = np.random.RandomState(seed)
    n_signal = int(signal_frac * n_points)

    data = np.zeros((n_points, dim))
    data[:n_signal] = rng.multivariate_normal(
        mean=signal_mean, cov=np.identity(dim), size=n_signal)
    data[n_signal:] = rng.multivariate_normal(
        mean=np.zeros(dim), cov=np.identity(dim) * clutter_var,
        size=n_points - n_signal)
    return data


# ---- Experiment configuration ----
SEED = 0     # fixed so that Restart & Run All reproduces the same data set
N = 100      # number of data points
D = 2        # dimensionality
a = 10       # variance of the zero-mean component
b = 100      # assigned to var_prior by the EP cell
w = 0.5      # mixing weight
# NOTE(review): here w is the fraction of points drawn around theta, whereas the
# EP cell weights the theta component by (1 - w). The two conventions coincide
# only for w = 0.5 -- confirm which one is intended before changing w.

# Alternative: draw the true mean at random instead of fixing it.
#theta = np.random.multivariate_normal(mean=np.zeros(D), cov=np.identity(D))
theta = np.array([1, 1])

X = make_clutter_data(N, D, a, w, theta, seed=SEED)
# ---- Scatter plot of the synthetic data ----
# Rows [0, n_signal) were drawn around theta, the remaining rows around the origin.
n_signal = int(w * N)
plt.figure(figsize=(5, 5))
plt.scatter(X[:n_signal, 0], X[:n_signal, 1], s=10, color=cmap(0))
plt.scatter(X[n_signal:, 0], X[n_signal:, 1], s=10, color=cmap(1))
plt.xlim(-6, 6)
plt.ylim(-6, 6)
plt.tight_layout()
#plt.savefig('data.png', bbox_inches='tight')


# ---- Expectation propagation for the mean of the theta component ----
# The approximate posterior is an isotropic Gaussian with mean mu_new and
# variance var_new. Each data point n owns one Gaussian "site" with mean mu[n],
# variance var[n] and scale scale[n]. One sweep visits every point once:
#   1. divide the point's site out of the posterior (cavity),
#   2. moment-match the cavity times the exact mixture likelihood,
#   3. recompute the site so that cavity * site equals the new posterior.
max_iter = 100
e = 1e-8                      # keeps rho finite if z_norm underflows to 0
mu_prior = np.zeros(D)
var_prior = b
mu_new = np.copy(mu_prior)
var_new = float(var_prior)

# Sites start exactly flat (infinite variance, zero precision), so the initial
# posterior equals the prior times the product of the sites. The updates below
# preserve the gap between the posterior precision and the summed prior + site
# precisions; starting from a finite site variance would leave a permanent
# offset of N / var_init in the reported posterior precision.
mu = np.zeros((N, D))
var = np.full(N, np.inf)
scale = np.sqrt((2 * np.pi * var) ** D)   # every entry is overwritten in the first sweep

# w * N(x_n | 0, a I) does not depend on the posterior; evaluate it once per point.
clutter_lik = np.array([w * gaussian_pdf(x_n, np.zeros(D), a) for x_n in X])

for sweep in range(max_iter):
    for n in range(N):
        # 1. Cavity, in precision form: 1/var_cav = 1/var_new - 1/var[n].
        #    For finite var[n] this is the same quantity as
        #    var[n] * var_new / (var[n] - var_new), but it is also well defined
        #    for a flat site (1/inf == 0, var_cav / inf == 0).
        prec_cav = 1.0 / var_new - 1.0 / var[n]
        var_cav = 1.0 / prec_cav
        mu_cav = mu_new + (mu_new - mu[n]) * (var_cav / var[n])

        # 2. Normaliser of cavity * likelihood and the weight rho of the
        #    theta-dependent component.
        z_norm = (1 - w) * gaussian_pdf(X[n], mu_cav, (1 + var_cav)) + clutter_lik[n]
        rho = 1 - clutter_lik[n] / (z_norm + e)

        resid = X[n] - mu_cav
        var_new = (var_cav - rho * var_cav**2 / (var_cav + 1)
                   + rho * (1 - rho) * var_cav**2 * np.dot(resid, resid) / (D * (var_cav + 1)**2))
        mu_new = mu_cav + rho * var_cav * resid / (var_cav + 1)

        # 3. Refined site: 1/var[n] = 1/var_new - 1/var_cav (may be negative).
        var[n] = var_new * var_cav / (var_cav - var_new)
        mu[n] = mu_cav + (var[n] + var_cav) * (mu_new - mu_cav) / var_cav
        scale[n] = z_norm / (np.sqrt((2 * np.pi * var[n])**D)
                             * gaussian_pdf(mu[n], mu_cav, (var[n] + var_cav)))

print(theta)
print(mu_new, var_new)
# ---- Scatter plot of the EP result ----
# Small dots: the per-point site means mu[n]; cross: the posterior mean mu_new.
fig, ax = plt.subplots(figsize=(5, 5))
site_x, site_y = mu[:, 0], mu[:, 1]
ax.scatter(site_x, site_y, s=5.0, color=cmap(0), alpha=0.8)
ax.scatter(mu_new[0], mu_new[1], color=cmap(3), marker='x')
# Optional overlay of the true mean:
#ax.scatter(theta[0], theta[1], color=cmap(3))
ax.set_xlim(-6, 6)
ax.set_ylim(-6, 6)
fig.tight_layout()
#fig.savefig('ep_result.png', bbox_inches='tight')