{ "cells": [ { "cell_type": "markdown", "id": "3cf2db77", "metadata": {}, "source": [ "Title: EM on 1D Mixture Densities\n", "Author: Thomas Breuel\n", "Institution: UniKL" ] }, { "cell_type": "code", "execution_count": 2, "id": "5f103f05", "metadata": { "collapsed": true }, "outputs": [], "source": [ "\n", "import pylab\n", "from pylab import *\n", "from scipy import random\n", "from scipy.stats import distributions" ] }, { "cell_type": "markdown", "id": "ddd01513", "metadata": {}, "source": [ "# Expectation Maximization for Mixture Densities" ] }, { "cell_type": "markdown", "id": "50242a18", "metadata": {}, "source": [ "The EM algorithm is a common algorithm used for estimating mixture\n", "densities.\n", "More generally, it is used in which estimating something requires\n", "some \"hidden\" variables that we can't observe.\n", "The idea is that we start with a guess for the hidden variables,\n", "perform our estimate, and then use the estimate itself to update\n", "the hidden variables." ] }, { "cell_type": "markdown", "id": "1b403f67", "metadata": {}, "source": [ "Let's start by generating a mixture distribution. \n", "\n", "We have two component mixtures, one with a mean of 0 and variance of 1, the other with a mean of 5 and variance of 4. The first mixture component is chosen with probability 0.3, the second with probability 0.7.\n", "\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "a7611520", "metadata": { "collapsed": true }, "outputs": [], "source": [ "data1 = random.normal(loc=0.0,scale=1.0,size=3000)\n", "data2 = random.normal(loc=5.0,scale=2.0,size=7000)\n", "data = concatenate((data1,data2))\n", "random.shuffle(data)\n", "n = len(data)" ] }, { "cell_type": "markdown", "id": "5c4f001d", "metadata": {}, "source": [ "The mixture density has the form:\n", "\n", "$$ p(x) = p_0 {\\cal N}(x;\\mu_0,\\sigma^2_0) + p_1 {\\cal N}(x;\\mu_1,\\sigma^2_1) $$" ] }, { "cell_type": "code", "execution_count": 4, "id": "d01eef5c", "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXgAAAD9CAYAAAC2l2x5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAGYNJREFUeJzt3W9Mlff9//HXceCNDqm4jmPHsTtuQOAgCtOhWeaCUXB2\nldhqWLFTonY3NEvm9FvX9pe2sGSC25pN25k0BjeaNUXvFNiiBLfsNF2bBjdx60pXWAMb/1OLqJQp\nCtfvBvVwkH/nHM451znXeT4SksM513Wut8h58Tmf87nel80wDEMAAMtZYHYBAIDQIOABwKIIeACw\nKAIeACyKgAcAiyLgAcCiZg34mzdvau3atcrJyZHL5dIzzzwjSRoYGFBBQYHS09NVWFiowcFBzz4V\nFRVKS0tTRkaGGhsbQ1s9AGBGtrnWwQ8PD+u+++7TnTt39M1vflO/+MUvVF9frwceeEBHjhzRsWPH\ndPXqVVVWVqqlpUU7d+7UxYsX1d3drU2bNqm1tVULFvBGAQDCbc7kve+++yRJIyMjGh0dVVJSkurr\n61VaWipJKi0tVW1trSSprq5OJSUlio+Pl9PpVGpqqpqamkJYPgBgJnMG/NjYmHJycmS327VhwwZl\nZWWpv79fdrtdkmS329Xf3y9J6unpkcPh8OzrcDjU3d0dotIBALOJm2uDBQsW6PLly7p27Zo2b96s\nP//5z5Met9lsstlsM+4/3WOzbQ8AmJk/3WV8nhy///779Z3vfEd/+9vfZLfb1dfXJ0nq7e1VcnKy\nJCklJUWdnZ2efbq6upSSkjJjkZH09cILL5heAzVZqy5qoqZgf/lr1oC/cuWKZ4XM//73P124cEG5\nubkqKipSdXW1JKm6ulrbtm2TJBUVFammpkYjIyNqb29XW1ub8vLy/C4KADB/s07R9Pb2qrS0VGNj\nYxobG9OuXbu0ceNG5ebmqri4WFVVVXI6nTp79qwkyeVyqbi4WC6XS3FxcTp58iTTMQBgklkDPjs7\nW5cuXZpy/5IlS/THP/5x2n2effZZPfvss8GpLozy8/PNLmEKavJdJNZFTb6hptCZcx18SA5qswU0\nnwQAsczf7OQMJACwKAIeACyKgAcAiyLgAcCiCHgAsCgCHgAsioAHAIsi4AHAogh4ALAoAh4ALIqA\nBwCLIuABwKIIeACwKAIeACyKgAcAiyLgAcCiCHgAsCgCHgAsioAHAIsi4AHAogh4ALAoAh4ALIqA\nBwCLIuABwKIIeACwKAIeACyKgAcAi5o14Ds7O7VhwwZlZWVpxYoVOnHihCSprKxMDodDubm5ys3N\n1fnz5z37VFRUKC0tTRkZGWpsbAxt9UAIJSYukc1mk81mU2LiErPLAfxmMwzDmOnBvr4+9fX1KScn\nR0NDQ1q9erVqa2t19uxZLVq0SIcOHZq0fUtLi3bu3KmLFy+qu7tbmzZtUmtrqxYsmPx3xGazaZbD\nAhHBZrNJuvt7yu8szOdvds46gl+6dKlycnIkSQkJCcrMzFR3d7ckTXuQuro6lZSUKD4+Xk6nU6mp\nqWpqavKnfgBAkPg8B9/R0aHm5matW7dOkvTSSy9p1apV2rdvnwYHByVJPT09cjgcnn0cDofnDwIA\nILzifNloaGhIO3bs0PHjx5WQkKD9+/fr+eeflyQ999xzOnz4sKqqqqbdd/xt7lRlZWWe2/n5+crP\nz/evcgCwOLfbLbfbHfD+s87BS9Lt27f1yCOPaMuWLTp48OCUxzs6OrR161a99957qqyslCQ9/fTT\nkqRvf/vbKi8v19q1aycflDl4RAHm4BFpgjoHbxiG9u3bJ5fLNSnce3t7PbffeOMNZWdnS5KKiopU\nU1OjkZERtbe3q62tTXl5ef7+GwAAQTDrFM3bb7+t3/3ud1q5cqVyc3MlSUePHtXrr7+uy5cvy2az\nafny5XrllVckSS6XS8XFxXK5XIqLi9PJkydnnKIBwikxcYlu3Ljq+X7RoiRdvz5gYkVA6M05RROS\ngzJFgzCbPN0i+TLlwhQNIk1Qp2gAANGLgAcAiyLgAcCiCHgAsCgCHgAsioAHAIsi4AHAogh4xKg4\ner3D8nxqNgZYzx3dPYnpxg3OtoY1MYIHfMKIH9GHgAe8eF+mb7K7I35jUk8bIJLRiwYxYbpeNNP1\nmbm3/8xMt/n9hRnoRQMAkETAA4BlEfCA3/jAFdGBZZLAZ4HtO5ZYIjowgodlzbwi5l4TK2QAK2EV\nDSxr5hUx934/v9v8LiNcWEUDAJBEwEcE76kEPrQDECxM0UQALu4cGkzRwGqYogFMwjsxRBpG8BGA\nEXzwJCYuuadXTPhG8Pw/ItQYwSOmjYc7Sx4BiYAHAMsi4AHAogh4ALAoAh4ALIpmY8C8+NuoDAif\nWUfwnZ2d2rBhg7KysrRixQqdOHFCkjQwMKCCggKlp6ersLBQg4ODnn0qKiqUlpamjIwMNTY2hrZ6\nwHQ0KkPkmnUdfF9fn/r6+pSTk6OhoSGtXr1atbW1+s1vfqMHHnhAR44c0bFjx3T16lVVVlaqpaVF\nO3fu1MWLF9Xd3a1NmzaptbVVCxZM/jvCOvjJWD8dPL5dcm+2xzjDFZErqOvgly5dqpycHElSQkKC\nMjMz1d3drfr6epWWlkqSSktLVVtbK0mqq6tTSUmJ4uPj5XQ6lZqaqqampkD/LQCAefB5Dr6jo0PN\nzc1au3at+vv7ZbfbJUl2u139/f2SpJ6eHq1bt86zj8PhUHd397TPV1ZW5rmdn5+v/Pz8AMoHAOty\nu91yu90B7+9TwA8NDWn79u06fvy4Fi1aNOmxuS6oMNNj3gEfK7xPo1+0KEnXrw9Ms9XEh3YzbwMg\nFtw7+C0vL/dr/zmXSd6+fVvbt2/Xrl27tG3bNknjo/a+vj5JUm9vr5KTkyVJKSkp6uzs9Ozb1dWl\nlJQUvwqyMu/T6Cf3S/F2Z9ptaGQVbbhuK8w3a8AbhqF9+/bJ5XLp4MGDnvuLiopUXV0tSaqurvYE\nf1FRkWpqajQyMqL29na1tbUpLy8vhOXHDt/+OCByTP+HGginWVfR/OUvf9G3vvUtrVy50jNtUFFR\noby8PBUXF+u///2vnE6nzp49q8WLF0uSjh49qtOnTysuLk7Hjx/X5s2bpx40RlfRzLRaZraVH3Qq\nnN5M012RtIqG/y8Em7/ZSbvgMApFwPs2rx+9fA1yX36WBDyiHQEfwUIR8FYf2fv77ybgYWX0gwcA\nSKIXjYnoYQIgtBjBm4YeJgBCixE8LIB3Q8B0GMHDAng3BEyHEXxEY2QKIHCM4CPaTCPTidPgEX28\n207YbAtpaYCQYQQfle4GvzS+3hrRZKLthOS9Xv7GDf4vEVyM4AHAogh4IGLQgRLBxRQNEDEmpt6Y\nrkEwMIIHAIsi4AHAogh4ALAo5uARRTjxC/AHI3hEEVoSAP4g4BERuKg4EHxM0SAieJ/dyRJBIDgY\nwQOARRHwAGBRTNEAIcfqH5iDEXwITW4LG94XeHR/aGm1
dsjeq39YAYTwYQQfQpPbwkqhb+1770gx\nWj+0pB0yEAyM4C2FdeIAJhDwAGBRBDwAWBQBDwAWNWfA7927V3a7XdnZ2Z77ysrK5HA4lJubq9zc\nXJ0/f97zWEVFhdLS0pSRkaHGxsbQVA0AmNOcAb9nzx41NDRMus9ms+nQoUNqbm5Wc3OztmzZIklq\naWnRmTNn1NLSooaGBh04cEBjY2OhqRx+iJu0XDP6lk0CCMScAb9+/XolJSVNud8wpq7UqKurU0lJ\nieLj4+V0OpWamqqmpqbgVIp5mLwOe3z5JgCrC3gd/EsvvaRXX31Va9as0YsvvqjFixerp6dH69at\n82zjcDjU3d097f5lZWWe2/n5+crPzw+0FACwJLfbLbfbHfD+AQX8/v379fzzz0uSnnvuOR0+fFhV\nVVXTbjvT2YjeAQ8AmOrewW95eblf+we0iiY5Odkzn/vkk096pmFSUlLU2dnp2a6rq0spKSmBHAIA\nME8BBXxvb6/n9htvvOFZYVNUVKSamhqNjIyovb1dbW1tysvLC06lAAC/zDlFU1JSojfffFNXrlzR\nsmXLVF5eLrfbrcuXL8tms2n58uV65ZVXJEkul0vFxcVyuVyKi4vTyZMnLdQwCgCii82YbjlMqA9q\ns027Csdqxv+43dtszDD59vj3kfbzn/yzCs/PwMyf/9zbxWt89ZO0aFGSrl8fEOBvdnImK0zj3dIY\n95pY2sqyVgSKgIdpJtopR9a7CcAqCPiYFMdZrUAM4IIfMWnighrRdzEQAL5iBA8AFkXAI+ii+3qw\ngHUwRYOg874WLVNAgHkYwSPEaFUMmIURPEJs4gNdiRE9EE6M4IEowucb8AcjeCCK8PkG/EHAI8zi\naE0AhAkBjzDznpMn6H3DH0UEhjl4IOJ5X1MX8B0Bj4DxgZ/Z6CmE2TFFg4DxgZ/Z6CmE2TGCDzJ6\nnAOIFAR8kNHjHECkIODhwZw6YC0EfBBYZVrG+90Hl4kDoh8BHwRMywCIRAQ8AFgUAQ8AFkXAA4BF\nEfAAYFEEfMyLm2EFEKfBRxf+vzAVrQpi3kzdHTkNPrrw/4WpCHgECS1tgUgz5xTN3r17ZbfblZ2d\n7blvYGBABQUFSk9PV2FhoQYHBz2PVVRUKC0tTRkZGWpsbAxN1QgzX97+09IWiDRzBvyePXvU0NAw\n6b7KykoVFBSotbVVGzduVGVlpSSppaVFZ86cUUtLixoaGnTgwAGNjY2FpnKE0UR4c4YrED3mDPj1\n69crKSlp0n319fUqLS2VJJWWlqq2tlaSVFdXp5KSEsXHx8vpdCo1NVVNTU0hKBsAMJeA5uD7+/tl\nt9slSXa7Xf39/ZKknp4erVu3zrOdw+FQd3f3tM9RVlbmuZ2fn6/8/PxASgEAy3K73XK73QHvP+8P\nWedqsjXTY94BDwCY6t7Bb3l5uV/7B7QO3m63q6+vT5LU29ur5ORkSVJKSoo6Ozs923V1dSklJSWQ\nQwAA5imggC8qKlJ1dbUkqbq6Wtu2bfPcX1NTo5GREbW3t6utrU15eXnBqxYA4LM5p2hKSkr05ptv\n6sqVK1q2bJl+8pOf6Omnn1ZxcbGqqqrkdDp19uxZSZLL5VJxcbFcLpfi4uJ08uRJ1kZbDuvdgWhh\nMwwj7AuXbTabTDhsyIwHnvfZoNPdnu2xcN6O1WNHSh3hOZ6VXl+Y4G920osGACyKgAcAiyLgAcCi\nCHgAsCgCHgAsioAHAIsi4AHAogh4ALAoAh6wsMTEJVyrNYYR8AHyfuEAkWr8Ai1crCVWEfB+8A51\n7xcOEB0mLr3IiD42cNFtP0yEujTe+wOIJncvvTjuxg1+h62OETwAWBQBDwAWRcADgEUxBw9YDhdl\nwThG8IDl3P0wlRVesY6ABwCLIuABwKIIeACwKAIeiFlxnNVqcayiAWLWxJmtnNVqTYzgAcCiCPg5\n0DUSQLQi4OdA10gA0YqABwCLIuABwKIIeACwqHktk3Q6nUpMTNTnPvc5xcfHq6mpSQMDA/rud7+r\n//znP3I6nTp79qwWL14crHoBAD6a1wjeZrPJ7XarublZTU1NkqTKykoVFBSotbVVGzduVGVlZVAK\nBQD4Z95TNIYxeXVJfX29SktLJUmlpaWqra2d7yEAAAGY9wh+06ZNWrNmjU6dOiVJ6u/vl91ulyTZ\n7Xb19/fPv0oAgN/mNQf/9ttv68EHH9THH3+sgoICZWRkTHp8thOEysrKPLfz8/OVn58/n1KCKjFx\nyWfr34FYMXGRkEWLknT9+oDJ9UCS3G633G53wPvbjHvnWAJUXl6uhIQEnTp1Sm63W0uXLlVvb682\nbNigf/3rX5MParNNmdqJJOO/6HfrC9btYD5XtNZh5rEjpQ4zj+1rHfEa71ND2Ecaf7Mz4Cma4eFh\n3bhxQ5L06aefqrGxUdnZ2SoqKlJ1dbUkqbq6Wtu2bQv0EABMMXFFKN7JRreAp2j6+/v16KOPSpLu\n3LmjJ554QoWFhVqzZo2Ki4tVVVXlWSYJAAi/oE3R+HVQpmhMvB2rx46UOsw8dmB1RPJrNdaEbYoG\nABDZCHgAfvNuo83VoCIXV3QC4LeJNtpcDSqSMYIHMAuu2xrNYjrgeZsJzIUlk9EspqdoeJsJwMpi\negQPAFZGwAOARRHwAGBRMT0HD8AfcTN2h0VkYgTvETdre2MAEytqJmMpZaRiBO9x95dXGu/FAcA3\nE68dVqNFFkbwAGBRBDwAWBQBDwAWFXMB792eAACsLOYCfqI9ARcxAGBtMRfwAMKDZn7mY5kkgCC6\n92Qolk+aiRE8gCCa6WQomIGABwCLIuABwKIIeACwKMsHvPcn+ax9BxBLLL+KxvuyfOMIeQCxwfIj\neACRIG7SO2nWxYeH5UfwACKBdztu1sWHiyVH8PSbASIdFwkJh5AEfENDgzIyMpSWlqZjx46F4hBT\neIc6/WZikdvsAqbhNruAabjNLuAzEydE3bhx3Wv6ZmHAwR/M1ghut3te+0eKoAf86OiofvCDH6ih\noUEtLS16/fXX9cEHH/j1HIODg7p8+bLn69NPP51zH0I91rnNLmAabrMLmIbb7AKmMaqJ1+5tTQT/\nVb+exTsD/N33XlYJ+KDPwTc1NSk1NVVOp1OS9Pjjj6uurk6ZmZk+P8f//d//02uv1Wnhwgd061av\npE9169bdkI/X+C/BvbcBWIt3X5uJ1/qiRUm6fn1A0viofb5hbmVBH8F3d3dr2bJlnu8dDoe6u7v9\neo47d0YlLZHNtkxS4mfhPvUv/OTbAKzFu6+N98j+hg/TsXFBmfaRfJv6mWkbsztqBn0E7+sHm75s\nd/Pme957BOl2MJ8rmo4dKXWE6nnLP/syuw5fagrHsX2pycw6grG/r9tMvMu/ceOqzxlVXl4+5T5f\n9p9pG3+OHSxBD/iUlBR1dnZ6vu/s7JTD4Zi0jWEw4gaAUAv6FM2aNWvU1tamjo4OjYyM6MyZMyoq\nKgr2YQAAcwj
6CD4uLk4vv/yyNm/erNHRUe3bt8+vD1gBAMERknXwW7Zs0Ycffqh///vfeuaZZ2bc\n7sUXX9SCBQs0MDAQijL89tRTTykzM1OrVq3SY489pmvXrplWixnnEsyms7NTGzZsUFZWllasWKET\nJ06YXZLH6OiocnNztXXrVrNLkTS+zHfHjh3KzMyUy+XSu+++a3ZJkqSKigplZWUpOztbO3fu1K1b\nt8Jew969e2W325Wdne25b2BgQAUFBUpPT1dhYaEGBwdNr8nsLJiuprv8yU3TzmTt7OzUhQsX9OUv\nf9msEqYoLCzU+++/r7///e9KT09XRUWFKXUE41yCYIuPj9cvf/lLvf/++3r33Xf161//2vSa7jp+\n/LhcLlfEnLn8wx/+UA8//LA++OAD/eMf/4iId7AdHR06deqULl26pPfee0+jo6OqqakJex179uxR\nQ0PDpPsqKytVUFCg1tZWbdy4UZWVlabXZHYWTFeT5H9umhbwhw4d0s9+9jOzDj+tgoICLVgw/iNZ\nu3aturq6TKnD+1yC+Ph4z7kEZlq6dKlycnIkSQkJCcrMzFRPT4+pNUlSV1eXzp07pyeffDIiPry/\ndu2a3nrrLe3du1fS+JTl/fffb3JVUmJiouLj4zU8PKw7d+5oeHhYKSkpYa9j/fr1SkpKmnRffX29\nSktLJUmlpaWqra01vSazs2C6miT/c9OUgK+rq5PD4dDKlSvNOLxPTp8+rYcfftiUYwfjXIJQ6ujo\nUHNzs9auXWt2KfrRj36kn//8554Xo9na29v1xS9+UXv27NHXvvY1ff/739fw8LDZZWnJkiU6fPiw\nHnroIX3pS1/S4sWLtWnTJrPLkiT19/fLbrdLkux2u/r7+02uaDIzs8BbILkZsldFQUGBsrOzp3zV\n19eroqJi0hrTcI68Zqrr97//vWebn/70p1q4cKF27twZtrq8RcpUw3SGhoa0Y8cOHT9+XAkJCabW\n8oc//EHJycnKzc2NiNG7JN25c0eXLl3SgQMHdOnSJX3+858P+5TDdD766CP96le/UkdHh3p6ejQ0\nNKTXXnvN7LKmiLQmgWZnwV3Dw8M6evSo37kZsnbBFy5cmPb+f/7zn2pvb9eqVaskjb/FXr16tZqa\nmpScnByqcuas667f/va3OnfunP70pz+FvJaZ+HIugRlu376t7du363vf+562bdtmdjl65513VF9f\nr3PnzunmzZu6fv26du/erVdffdW0mhwOhxwOh77+9a9Lknbs2BERAf/Xv/5V3/jGN/SFL3xBkvTY\nY4/pnXfe0RNPPGFyZeOj9r6+Pi1dulS9vb1hyQFfREIW3PXRRx+po6PD/9w0TOZ0Oo1PPvnE7DIM\nwzCM8+fPGy6Xy/j4449NreP27dvGV77yFaO9vd24deuWsWrVKqOlpcXUmsbGxoxdu3YZBw8eNLWO\nmbjdbuORRx4xuwzDMAxj/fr1xocffmgYhmG88MILxpEjR0yuyDAuX75sZGVlGcPDw8bY2Jixe/du\n4+WXXzallvb2dmPFihWe75966imjsrLSMAzDqKioMH784x+bXlMkZMG9NXnzNTdND/jly5dHTMCn\npqYaDz30kJGTk2Pk5OQY+/fvN62Wc+fOGenp6cZXv/pV4+jRo6bVcddbb71l2Gw2Y9WqVZ6fz/nz\n580uy8Ptdhtbt241uwzDMMbDdM2aNcbKlSuNRx991BgcHDS7JMMwDOPYsWOGy+UyVqxYYezevdsY\nGRkJew2PP/648eCDDxrx8fGGw+EwTp8+bXzyySfGxo0bjbS0NKOgoMC4evWqqTVVVVWZngV3a1q4\ncKHn5+TN19y0GUaETF4CAIIqMpYeAACCjoAHAIsi4AHAogh4ALAoAh4ALIqABwCL+v+COyIQD+45\ntQAAAABJRU5ErkJggg==\n" }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_=hist(data,bins=100)" ] }, { "cell_type": "markdown", "id": "021b5e81", "metadata": {}, "source": [ "We need the formula for the normal density.\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "ee7915fc", "metadata": { "collapsed": true }, "outputs": [], "source": [ "def normal(x,params):\n", " mu = params[0]\n", " var = params[1]\n", " return exp(-(x-mu)**2/(2*var))/sqrt(2*pi*var)" ] }, { "cell_type": "markdown", "id": "a652418b", "metadata": {}, "source": [ "Let's pretend we already know the solution parameters $\\theta$ approximately.\n", "These parameters are the means and variances of each mixture component.\n", "We represent them as a matrix $\\theta = {{\\mu_0 \\sigma^2_0}\\choose{\\mu_1 \\sigma^2_1}}$.\n", "Let's pretend we also know the probabilities of the mixture components $a = (p_0,p_1)$.\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "7c719248", "metadata": { "collapsed": true }, "outputs": [], "source": [ "thetas = array([[0.0,1.0],[1.0,1.0]])\n", "mixes = array([0.5,0.5])" ] }, { "cell_type": "markdown", "id": "781ba24f", "metadata": {}, "source": [ "Assuming these parameters are correct, we can compute the probabilities that each of our data samples comes from mixture component 0 or mixture component 1. 
This is a simple application of Bayes formula:\n", "\n", "$$p(c|x) = \\frac{p(x|c) p(c)}{p(x)}$$\n", "\n", "Of course, we don't actually know the precise distributions, so we just estimate these:\n", "\n", "$$ \\hat{y}_{i,j} = p(c=i|x_j) $$\n", "\n", "Concretely, plugging the current parameter estimates into Bayes formula gives\n", "\n", "$$ \\hat{y}_{i,j} = \\frac{p_i {\\cal N}(x_j;\\mu_i,\\sigma^2_i)}{\\sum_k p_k {\\cal N}(x_j;\\mu_k,\\sigma^2_k)} $$\n", "\n", "which is exactly what the code below computes.\n", "\n", "This is called the expectation step of the EM algorithm.\n", "\n", "The reason is that the true membership variables $y_{i,j}$ are unknown to us and unobservable, and we compute their expectation based on the current parameter estimates." ] }, { "cell_type": "code", "execution_count": 7, "id": "7019392d", "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([[ 0.01558392,  0.00131541,  0.04342912,  0.00378512,  0.07296195],\n", "       [ 0.98441608,  0.99868459,  0.95657088,  0.99621488,  0.92703805]])" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ys = zeros((2,n))\n", "for i in range(2):\n", "    for j in range(n):\n", "        ys[i,j] = mixes[i]*normal(data[j],thetas[i])\n", "ys /= sum(ys,axis=0)[newaxis,:]\n", "ys[:,:5]" ] }, { "cell_type": "markdown", "id": "e430b065", "metadata": {}, "source": [ "Based on this, we can compute new estimates of the prior probability of each mixture component.\n", "\n", "If we knew the precise densities, this would be:\n", "\n", "$$ p(c) = \\int p(c|x) p(x) dx $$\n", "\n", "However, since we don't, we just estimate it as:\n", "\n", "$$ \\hat{p}_i = \\frac{1}{n} \\sum_j p(c=i|x_j) $$" ] }, { "cell_type": "code", "execution_count": 8, "id": "7d0a6791", "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "array([ 0.21702,  0.78298])" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mixes = sum(ys,axis=1)/n\n", "mixes" ] }, { "cell_type": "markdown", "id": "d947263a", "metadata": {}, "source": [ "Now we compute estimates of the means and variances.\n", "Here, we use an apportionment variable $\\hat{y}_{i,j}$ to weight each sample's contribution to each mixture component.\n", "Generally, $\\hat{y}_{i,j} = \\hat{p}(c=i|x_j)$, but other choices are possible.\n", "\n", "$$\\hat{\\mu}_i = \\frac{\\sum_j \\hat{y}_{i,j} x_j}{\\sum_j \\hat{y}_{i,j}}$$\n", "\n", "$$\\hat{\\sigma}^2_i = \\frac{\\sum_j \\hat{y}_{i,j} x_j^2}{\\sum_j \\hat{y}_{i,j}} - \\hat{\\mu}^2_i$$\n", "\n", "This step is called the maximization step of the EM algorithm.\n", "\n", "It may not look like much of a \"maximization\", but what we are doing is computing maximum likelihood estimates of $\\mu_i$ and $\\sigma^2_i$. We just happen to have a closed-form expression for the maximization step." ] }, { "cell_type": "code", "execution_count": 9, "id": "de640e0c", "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-0.0068905665568 4.42414254987\n" ] } ], "source": [ "mu0 = sum(data*ys[0])/sum(ys[0])\n", "mu1 = sum(data*ys[1])/sum(ys[1])\n", "print mu0,mu1" ] }, { "cell_type": "markdown", "id": "e14a3819", "metadata": {}, "source": [ "As we can see, the estimates for the mean have already improved significantly (from the initial guesses of 0 and 1 to about 0 and 4.4)." ] }, { "cell_type": "code", "execution_count": 10, "id": "8e11f081", "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.89860516759 5.89413266158\n" ] } ], "source": [ "var0 = sum(data**2 * ys[0])/sum(ys[0]) - mu0**2\n", "var1 = sum(data**2 * ys[1])/sum(ys[1]) - mu1**2\n", "print var0,var1" ] }, { "cell_type": "markdown", "id": "0a02eacf", "metadata": {}, "source": [ "The variance estimates have also improved significantly."
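, "\n", "\n", "As an aside, the E-step and M-step above can also be written without explicit Python loops. The cell below is a minimal vectorized sketch of one EM update using numpy directly; the function name `em_update` is only illustrative and is not used elsewhere in this notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "4f2a9c1e", "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Vectorized version of one EM update (E-step + M-step); equivalent to the loops above.\n", "import numpy as np\n", "\n", "def em_update(data, mixes, thetas):\n", "    mus = thetas[:, 0:1]        # column of component means, shape (2,1)\n", "    variances = thetas[:, 1:2]  # column of component variances, shape (2,1)\n", "    # E-step: responsibilities ys[i,j] = p_i * N(x_j; mu_i, var_i), normalized over i\n", "    dens = np.exp(-(data[None, :] - mus)**2 / (2 * variances)) / np.sqrt(2 * np.pi * variances)\n", "    ys = mixes[:, None] * dens\n", "    ys = ys / ys.sum(axis=0, keepdims=True)\n", "    # M-step: updated mixing weights, means, and variances\n", "    weights = ys.sum(axis=1)\n", "    new_mixes = weights / len(data)\n", "    new_mus = (ys * data[None, :]).sum(axis=1) / weights\n", "    new_vars = (ys * data[None, :]**2).sum(axis=1) / weights - new_mus**2\n", "    return new_mixes, np.column_stack([new_mus, new_vars])\n", "\n", "# Example: one vectorized update from the same starting guesses as above\n", "em_update(data, np.array([0.5, 0.5]), np.array([[0.0, 1.0], [1.0, 1.0]]))" ] }, { "cell_type": "markdown", "id": "4f2a9c1f", "metadata": {}, "source": [ "Starting from the same initial guesses, a single call to `em_update` should reproduce the hand-computed iteration above up to floating-point error, and it avoids the slow double loop over components and samples."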
] }, { "cell_type": "code", "execution_count": 11, "id": "3c7f251d", "metadata": { "collapsed": true }, "outputs": [], "source": [ "thetas = array([[mu0,var0],[mu1,var1]])" ] }, { "cell_type": "markdown", "id": "f04982f1", "metadata": {}, "source": [ "Let's just wrap up the entire procedure and iterate it multiple times." ] }, { "cell_type": "code", "execution_count": 12, "id": "b36a6748", "metadata": { "collapsed": true }, "outputs": [], "source": [ "def expectation_maximization(data):\n", " thetas = array([[0.0,1.0],[1.0,1.0]])\n", " mixes = array([0.5,0.5])\n", " for iter in range(100):\n", " ys = zeros((2,n))\n", " for i in range(2):\n", " for j in range(n):\n", " ys[i,j] = mixes[i]*normal(data[j],thetas[i])\n", " ys /= sum(ys,axis=0)[newaxis,:]\n", " mixes = sum(ys,axis=1)/n\n", " mu0 = sum(data*ys[0])/sum(ys[0])\n", " mu1 = sum(data*ys[1])/sum(ys[1])\n", " var0 = sum(data**2 * ys[0])/sum(ys[0]) - mu0**2\n", " var1 = sum(data**2 * ys[1])/sum(ys[1]) - mu1**2\n", " thetas = array([[mu0,var0],[mu1,var1]])\n", " if iter%10==0:\n", " print \"===\",iter,\"===\"\n", " print mixes\n", " print thetas\n", " return mixes,thetas" ] }, { "cell_type": "code", "execution_count": 18, "id": "a22602e7", "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "=== 0 ===\n", "[ 0.21702 0.78298]\n", "[[-0.00689057 1.89860517]\n", " [ 4.42414255 5.89413266]]\n", "=== 10 ===\n", "[ 0.28830675 0.71169325]\n", "[[-0.09149176 0.98637995]\n", " [ 4.90224883 4.16203697]]\n", "=== 20 ===\n", "[ 0.30129653 0.69870347]\n", "[[-0.03876513 1.05305054]\n", " [ 4.97235189 3.94967901]]\n", "=== 30 ===\n", "[ 0.30435459 0.69564541]\n", "[[-0.02544894 1.07067247]\n", " [ 4.98855477 3.90236296]]\n", "=== 40 ===\n", "[ 0.30507611 0.69492389]\n", "[[-0.02225813 1.07494779]\n", " [ 4.99235993 3.89135167]]\n", "=== 50 ===\n", "[ 0.30524673 0.69475327]\n", "[[-0.02150089 1.07596548]\n", " [ 4.99325871 3.88875649]]\n", "=== 60 ===\n", "[ 0.3052871 0.6947129]\n", "[[-0.02132156 1.07620666]\n", " [ 4.99347131 3.88814292]]\n", "=== 70 ===\n", "[ 0.30529665 0.69470335]\n", "[[-0.02127911 1.07626376]\n", " [ 4.99352162 3.88799774]]\n", "=== 80 ===\n", "[ 0.30529892 0.69470108]\n", "[[-0.02126907 1.07627727]\n", " [ 4.99353353 3.88796339]]\n", "=== 90 ===\n", "[ 0.30529945 0.69470055]\n", "[[-0.02126669 1.07628047]\n", " [ 4.99353635 3.88795526]]\n" ] } ], "source": [ "# running the EM steps multiple times\n", "(p0,p1),((m0,s0),(m1,s1)) = expectation_maximization(data)" ] }, { "cell_type": "code", "execution_count": 17, "id": "91981e60", "metadata": { "collapsed": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(0.30529957105536304, 0.6947004289446379)\n", "(-0.021266154871759004, 1.0762811900455098)\n", "(4.9935369813528254, 3.8879534250709256)\n" ] } ], "source": [ "# final estimates\n", "print (p0,p1)\n", "print (m0,s0)\n", "print (m1,s1)\n", "\n" ] }, { "cell_type": "markdown", "id": "5e2f0d33", "metadata": {}, "source": [ "Compare these to the original values:\n", "\n", "probabilities: 0.3 0.7\n", "\n", "parameters 1: $\\mu=0$, $\\sigma^2=1$\n", "\n", "parameters 2: $\\mu=5$, $\\sigma^2=4$\n", "\n", "So the estimates are pretty close to the true values.\n", "Of course, we cannot get exact values because the sample itself was\n", "random and finite." ] }, { "cell_type": "code", "execution_count": 14, "id": "cdbf3c68", "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": {}, "nbformat": 4, "nbformat_minor": 5 }