{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gaussian Mixture Models\n", "by Marc Deisenroth" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will look at density modeling with Gaussian mixture models (GMMs).\n", "In Gaussian mixture models, we describe the density of the data as\n", "$$\n", "p(\\boldsymbol x) = \\sum_{k=1}^K \\pi_k \\mathcal{N}(\\boldsymbol x|\\boldsymbol \\mu_k, \\boldsymbol \\Sigma_k)\\,,\\quad \\pi_k \\geq 0\\,,\\quad \\sum_{k=1}^K\\pi_k = 1\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The goal of this notebook is to get a better understanding of GMMs and to write some code for training GMMs using the EM algorithm. We provide a code skeleton and mark the bits and pieces that you need to implement yourself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# imports\n", "import numpy as np\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "from scipy.stats import multivariate_normal\n", "import scipy.linalg as la\n", "import matplotlib.cm as cm\n", "from matplotlib import rc\n", "import time\n", "from IPython import display\n", "\n", "%matplotlib inline\n", "\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define a GMM from which we generate data\n", "Set up the true GMM from which we will generate data." 
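, "\n", "\n", "As a sketch of the generative view (the latent indicator $z$ is notation introduced here for illustration and is not used elsewhere in this notebook), a draw from the GMM defined above can be written as\n", "$$\n", "z \\sim \\text{Cat}(\\pi_1, \\dots, \\pi_K)\\,,\\quad \\boldsymbol x \\,|\\, z = k \\sim \\mathcal{N}(\\boldsymbol \\mu_k, \\boldsymbol \\Sigma_k)\\,.\n", "$$\n", "Marginalizing over $z$ recovers the mixture density $p(\\boldsymbol x)$ stated at the top of the notebook. Note that the data-generation cell further down draws a fixed number of points per component instead of sampling $z$, so the empirical component proportions are equal rather than given by the mixture weights."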
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Choose a GMM with 3 components\n", "\n", "# means\n", "m = np.zeros((3,2))\n", "m[0] = np.array([1.2, 0.4])\n", "m[1] = np.array([-4.4, 1.0])\n", "m[2] = np.array([4.1, -0.3])\n", "\n", "# covariances\n", "S = np.zeros((3,2,2))\n", "S[0] = np.array([[0.8, -0.4], [-0.4, 1.0]])\n", "S[1] = np.array([[1.2, -0.8], [-0.8, 1.0]])\n", "S[2] = np.array([[1.2, 0.6], [0.6, 3.0]])\n", "\n", "# mixture weights\n", "w = np.array([0.3, 0.2, 0.5])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate some data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "N_split = 200 # number of data points per mixture component\n", "N = N_split*3 # total number of data points\n", "x = []\n", "y = []\n", "for k in range(3):\n", " x_tmp, y_tmp = np.random.multivariate_normal(m[k], S[k], N_split).T \n", " x = np.hstack([x, x_tmp])\n", " y = np.hstack([y, y_tmp])\n", "\n", "data = np.vstack([x, y])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualization of the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X, Y = np.meshgrid(np.linspace(-10,10,100), np.linspace(-10,10,100))\n", "pos = np.dstack((X, Y))\n", "\n", "mvn = multivariate_normal(m[0,:].ravel(), S[0,:,:])\n", "xx = mvn.pdf(pos)\n", "\n", "# plot the dataset\n", "plt.figure()\n", "plt.title(\"Mixture components\")\n", "plt.plot(x, y, 'ko', alpha=0.3)\n", "plt.xlabel(\"$x_1$\")\n", "plt.ylabel(\"$x_2$\")\n", "\n", "# plot the individual components of the GMM\n", "plt.plot(m[:,0], m[:,1], 'or')\n", "\n", "for k in range(3):\n", " mvn = multivariate_normal(m[k,:].ravel(), S[k,:,:])\n", " xx = mvn.pdf(pos)\n", " plt.contour(X, Y, xx, alpha = 1.0, zorder=10)\n", " \n", "# plot the GMM \n", "plt.figure()\n", "plt.title(\"GMM\")\n", "plt.plot(x, y, 'ko', alpha=0.3)\n", "plt.xlabel(\"$x_1$\")\n", 
"plt.ylabel(\"$x_2$\")\n", "\n", "# build the GMM\n", "gmm = 0\n", "for k in range(3):\n", " mix_comp = multivariate_normal(m[k,:].ravel(), S[k,:,:])\n", " gmm += w[k]*mix_comp.pdf(pos)\n", " \n", "plt.contour(X, Y, gmm, alpha = 1.0, zorder=10);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the GMM via EM\n", "### Initialize the parameters for EM" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "K = 3 # number of clusters\n", "\n", "means = np.zeros((K,2))\n", "covs = np.zeros((K,2,2))\n", "for k in range(K):\n", " means[k] = np.random.normal(size=(2,))\n", " covs[k] = np.eye(2)\n", "\n", "weights = np.ones((K,1))/K\n", "print(\"Initial mean vectors (one per row):\\n\" + str(means))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#EDIT THIS FUNCTION\n", "NLL = [] # log-likelihood of the GMM\n", "gmm_nll = 0\n", "NLL += [gmm_nll] #<-- REPLACE THIS LINE\n", "\n", "plt.figure()\n", "plt.plot(x, y, 'ko', alpha=0.3)\n", "plt.plot(means[:,0], means[:,1], 'oy', markersize=25)\n", "\n", "for k in range(K):\n", " rv = multivariate_normal(means[k,:], covs[k,:,:])\n", " plt.contour(X, Y, rv.pdf(pos), alpha = 1.0, zorder=10)\n", " \n", "plt.xlabel(\"$x_1$\");\n", "plt.ylabel(\"$x_2$\");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we define the responsibilities (which are updated in the E-step), given the model parameters $\\pi_k, \\boldsymbol\\mu_k, \\boldsymbol\\Sigma_k$ as\n", "$$\n", "r_{nk} := \\frac{\\pi_k\\mathcal N(\\boldsymbol\n", " x_n|\\boldsymbol\\mu_k,\\boldsymbol\\Sigma_k)}{\\sum_{j=1}^K\\pi_j\\mathcal N(\\boldsymbol\n", " x_n|\\boldsymbol \\mu_j,\\boldsymbol\\Sigma_j)} \n", " $$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given the responsibilities we just defined, we can update the model parameters in the M-step as follows:\n", "\\begin{align*}\n", "\\boldsymbol\\mu_k^\\text{new} 
&= \\frac{1}{N_k}\\sum_{n = 1}^Nr_{nk}\\boldsymbol x_n\\,,\\\\\n", " \\boldsymbol\\Sigma_k^\\text{new}&= \\frac{1}{N_k}\\sum_{n=1}^Nr_{nk}(\\boldsymbol x_n-\\boldsymbol\\mu_k)(\\boldsymbol x_n-\\boldsymbol\\mu_k)^\\top\\,,\\\\\n", " \\pi_k^\\text{new} &= \\frac{N_k}{N}\n", "\\end{align*}\n", "where\n", "$$\n", "N_k := \\sum_{n=1}^N r_{nk}\n", "$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### EM Algorithm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#EDIT THIS FUNCTION\n", "r = np.zeros((K,N)) # will store the responsibilities\n", "\n", "for em_iter in range(100):\n", "    means_old = means.copy()\n", "\n", "    # E-step: update responsibilities\n", "    r = np.zeros((K,N)) #<-- REPLACE THIS LINE\n", "\n", "    # M-step\n", "    N_k = np.sum(r, axis=1)\n", "\n", "    for k in range(K):\n", "        # update the means\n", "        means[k] = 0 #<-- REPLACE THIS LINE\n", "\n", "        # update the covariances\n", "        covs[k] = 0 #<-- REPLACE THIS LINE\n", "\n", "    # weights\n", "    weights = np.zeros((K,)) #<-- REPLACE THIS LINE\n", "\n", "    # negative log-likelihood\n", "    NLL += [10] #<-- REPLACE THIS LINE\n", "\n", "    plt.figure()\n", "    plt.plot(x, y, 'ko', alpha=0.3)\n", "    plt.plot(means[:,0], means[:,1], 'oy', markersize=25)\n", "    for k in range(K):\n", "        rv = multivariate_normal(means[k,:], covs[k])\n", "        plt.contour(X, Y, rv.pdf(pos), alpha = 1.0, zorder=10)\n", "\n", "    plt.xlabel(\"$x_1$\")\n", "    plt.ylabel(\"$x_2$\")\n", "    plt.text(x=3.5, y=8, s=\"EM iteration \"+str(em_iter+1))\n", "\n", "    # stop once the negative log-likelihood no longer changes\n", "    if la.norm(NLL[em_iter+1]-NLL[em_iter]) < 1e-6:\n", "        print(\"Converged after iteration \", em_iter+1)\n", "        break\n", "\n", "# plot the final mixture model\n", "plt.figure()\n", "gmm = 0\n", "for k in range(K):\n", "    mix_comp = multivariate_normal(means[k,:].ravel(), covs[k,:,:])\n", "    gmm += weights[k]*mix_comp.pdf(pos)\n", "\n", "plt.plot(x, y, 'ko', alpha=0.3)\n", "plt.contour(X, Y, gmm, alpha = 1.0, zorder=10)\n", "plt.xlim([-8,8]);\n", 
"plt.ylim([-6,6]);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.figure()\n", "plt.semilogy(np.linspace(1,len(NLL), len(NLL)), NLL)\n", "plt.xlabel(\"EM iteration\");\n", "plt.ylabel(\"Negative log-likelihood\");\n", "\n", "idx = [0, 1, 9, em_iter+1]\n", "\n", "for i in idx:\n", " plt.plot(i+1, NLL[i], 'or')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 2 }