{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Expectation Maximization Algorithm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Expectation maximization (EM, Dempster et al. 1977) uses iterative optimization along with a latent variable model to obtain maximum likelihood estimates for models whose parameters are difficult to estimate directly. The algorithm was motivated by missing data imputation. However, the missing values may also be introduced into the problem deliberately, as a conceptual device that simplifies finding a solution.\n", "\n", "It may not be intuitive how introducing latent (missing) elements to a problem will facilitate its solution, but it works essentially by breaking the optimization into two steps:\n", "\n", "1. generating an **expectation** of the complete-data log-likelihood over the missing variable(s), based on the current parameter estimates\n", "2. **maximizing** the expected log-likelihood from the expectation step, thereby generating updated parameter estimates\n", "\n", "EM is particularly suited to estimating the parameters of *mixture models*, where we do not know from which component each observation is derived." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In general, suppose we have observed quantities $x = x_1,\\ldots,x_n$ and unobserved (latent) quantities $z = z_1,\\ldots,z_m$ that are derived from some joint model:\n", "\n", "$$y = (x,z) \\sim P(x,z|\\theta)$$\n", "\n", "We are interested in obtaining the MLE of $\\theta$ under the marginal distribution of $X$:\n", "\n", "$$x \\sim P(x|\\theta)$$\n", "\n", "However, it is difficult to marginalize over $Z$ and then maximize the resulting likelihood directly. EM gets around this by iteratively improving an initial estimate $\\theta^{(0)}$." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Example: Mixture of normals\n", "\n", "Consider a set of observations, each of which has been drawn from one of two populations:\n", "\n", "$$x^{(a)} \\sim N(\\mu_a, \\sigma^2_a)$$\n", "$$x^{(b)} \\sim N(\\mu_b, \\sigma^2_b)$$\n", "\n", "except we only observe the values for $x = [x^{(a)}, x^{(b)}]$, not the labels which identify which population they are derived from." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAC/JJREFUeJzt3W+IZfV9x/H3J64h0SimOA1WnU4CQQh5UO2QNhWkaCymiuZBHygY0lCYPmitaQvBFIr0QcFCCcmDUljU1BKrtGpoSEKqJLFpoLXdXS3+WUNSuxs3muyGUBJLwab99sGelu1Gd+fec5wz9+v7BcPce/fsPV92l/ee+d1z70lVIUlafW+YewBJ0jQMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJvbs5M7OP//82tjY2MldStLK279///eqau102+1o0Dc2Nti3b99O7lKSVl6Sw9vZziUXSWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJamJH3yn6erRx2+eX/r2H7rh2wkm0W/lvRFPxCF2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTZw26EnuTnI0yVMnPPYTSR5J8o3h+1tf2zElSaeznSP0PwOuOemx24AvVdU7gS8N9yVJMzpt0Kvqq8D3T3r4BuCe4fY9wAcmnkuStKBl19DfVlUvAgzff3K6kSRJy3jNXxRNspVkX5J9x44de613J0mvW8sG/btJLgAYvh99tQ2ram9VbVbV5tra2pK7kySdzrJB/yzwoeH2h4C/nmYcSdKytnPa4n3A3wOXJDmS5NeAO4Crk3wDuHq4L0ma0Z7TbVBVN73KL1018SySpBF8p6gkNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITp/20RUmntnHb5+ceQQI8QpekNgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpoYFfQkv53k6SRPJbkvyZumGkyStJilg57kQuC3gM2qejdwBnDjVINJkhYzdsllD/DmJHuAs4AXxo8kSVrG0pegq6pvJ/lj4FvAfwAPV9XDJ2+XZAvYAlhfX192d7Px8mKrY8zf1aE7rp1wEmkeY5Zc3grcALwd+Cng7CQ3n7xdVe2tqs2q2lxbW1t+UknSKY1Zcnkf8K9Vdayq/hN4CPiFacaSJC1qTNC/Bfx8krOSBLgKODjNWJKkRS0d9Kp6DHgAOAA8OTzX3onmkiQtaOkXRQGq6nbg9olmkSSN4DtFJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMu
iQ1YdAlqQmDLklNGHRJasKgS1ITBl2Smhj1aYvavbwc22K81KA68Ahdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU2MCnqS85I8kOTZJAeTvHeqwSRJixn7eeifBL5YVb+S5I3AWRPMJElawtJBT3IucAXwqwBV9TLw8jRjSZIWNWbJ5R3AMeBTSR5PcmeSsyeaS5K0oDFLLnuAy4BbquqxJJ8EbgN+/8SNkmwBWwDr6+sjdidpSl6msJ8xR+hHgCNV9dhw/wGOB/7/qaq9VbVZVZtra2sjdidJOpWlg15V3wGeT3LJ8NBVwDOTTCVJWtjYs1xuAe4dznB5Dvjw+JEkScsYFfSqegLYnGgWSdIIvlNUkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU2M/Tx0Sa9DYy5fB17C7rXiEbokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2Smhgd9CRnJHk8yeemGEiStJwpjtBvBQ5O8DySpBFGBT3JRcC1wJ3TjCNJWtbYKxZ9AvgocM6rbZBkC9gCWF9fX3pHY6+QotXg37O0vKWP0JNcBxytqv2n2q6q9lbVZlVtrq2tLbs7SdJpjFlyuRy4Pskh4H7gyiSfnmQqSdLClg56VX2sqi6qqg3gRuDLVXXzZJNJkhbieeiS1MTYF0UBqKpHgUeneC5J0nI8QpekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWpikk9blE7kZeR2zqr+WY+Z+9Ad1044SS8eoUtSEwZdkpow6JLUhEGXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpiaWDnuTiJF9JcjDJ00lunXIwSdJixnwe+o+A362qA0nOAfYneaSqnploNknSApY+Qq+qF6vqwHD7h8BB4MKpBpMkLWaSNfQkG8ClwGNTPJ8kaXGpqnFPkLwF+FvgD6vqoVf49S1gC2B9ff1nDx8+vNR+VvVSW5J2j1W9fF2S/VW1ebrtRh2hJzkTeBC495ViDlBVe6tqs6o219bWxuxOknQKY85yCXAXcLCqPj7dSJKkZYw5Qr8c+CBwZZInhq9fnmguSdKClj5tsaq+BmTCWSRJI/hOUUlqwqBLUhMGXZKaMOiS1IRBl6QmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUhEGXpCYMuiQ1sfTH50rS68mYy2Du1KXvPEKXpCYMuiQ1YdAlqQmDLklNGHRJasKgS1ITBl2SmjDoktSEQZekJgy6JDVh0CWpCYMuSU0YdElqwqBLUhOjgp7kmiRfT/LNJLdNNZQkaXFLBz3JGcCfAO8H3gXclORdUw0mSVrMmCP09wDfrKrnqupl4H7ghmnGkiQtakzQLwSeP+H+keExSdIMxlyCLq/wWP3YRskWsDXcfSnJ17fx3OcD3xsx25ycfeet6tzg7Dsqf/R/N3d09hP2u6yf3s5GY4J+BLj4hPsXAS+cvFFV7QX2LvLESfZV1eaI2Wbj7DtvVecGZ5/LKs9+KmOWXP4JeGeStyd5I3Aj8NlpxpIkLWrpI/Sq+lGS3wT+BjgDuLuqnp5sMknSQsYsuVBVXwC+MNEsJ1poiWaXcfadt6pzg7PPZZVnf1Wp+rHXMSVJK8i3/ktSE7sq6EnuTnI0yVNzz7KIJBcn+UqSg0meTnLr3DNtV5I3JfnHJP88zP4Hc8+0qCRnJHk8yefmnmURSQ4leTLJE0n2zT3PdiU5L8kDSZ4d/s2/d+6ZtiPJJcOf9f9+/SDJR+aea0q7asklyRXAS8CfV9W7555nu5JcAFxQVQeSnAPsBz5QV
c/MPNppJQlwdlW9lORM4GvArVX1DzOPtm1JfgfYBM6tquvmnme7khwCNqtqtc7lTu4B/q6q7hzOcDurqv5t7rkWMXx0ybeBn6uqw3PPM5VddYReVV8Fvj/3HIuqqher6sBw+4fAQVbkXbN13EvD3TOHr93zv/xpJLkIuBa4c+5ZXg+SnAtcAdwFUFUvr1rMB1cB/9Ip5rDLgt5Bkg3gUuCxeSfZvmHJ4gngKPBIVa3M7MAngI8C/z33IEso4OEk+4d3VK+CdwDHgE8Ny1x3Jjl77qGWcCNw39xDTM2gTyjJW4AHgY9U1Q/mnme7quq/qupnOP5u3/ckWYnlriTXAUerav/csyzp8qq6jOOfWPobw5LjbrcHuAz406q6FPh3YKU+OntYJroe+Ku5Z5maQZ/IsP78IHBvVT009zzLGH50fhS4ZuZRtuty4PphLfp+4Mokn553pO2rqheG70eBz3D8E0x3uyPAkRN+inuA44FfJe8HDlTVd+ceZGoGfQLDC4t3AQer6uNzz7OIJGtJzhtuvxl4H/DsvFNtT1V9rKouqqoNjv8I/eWqunnmsbYlydnDC+gMSxa/BOz6s7uq6jvA80kuGR66Ctj1L/6f5CYaLrfAyHeKTi3JfcAvAucnOQLcXlV3zTvVtlwOfBB4cliLBvi94Z20u90FwD3Dq/5vAP6yqlbq9L8V9TbgM8ePBdgD/EVVfXHekbbtFuDeYeniOeDDM8+zbUnOAq4Gfn3uWV4Lu+q0RUnS8lxykaQmDLokNWHQJakJgy5JTRh0SWrCoEtSEwZdkpow6JLUxP8ALpuRi13HW8YAAAAASUVORK5CYII=\n", "text/plain": [ "