{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# EM as a Message Passing Algorithm\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Preliminaries\n", "\n", "- Goals\n", " - Describe Expectation-Maximization (EM) as a message passing algorithm on a Forney-style factor graph\n", "- Materials\n", " - Madatory\n", " - These lecture notes\n", " - Optional\n", " - [Dauwels et al., 2009](./files/Dauwels-2009-Expectation-Maximization-as-Message-Passing.pdf), pp. 1-5 (sections I,II and III)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### A Problem for the Multiplier Node\n", "\n", "- Consider the multiplier factor $f(x,y,\\theta) = \\delta(y-\\theta x)$ with incoming Gaussian messages $\\overrightarrow{\\mu}_X(x) = \\mathcal{N}(x|m_x,v_x)$ and $\\overleftarrow{\\mu}_Y(y) = \\mathcal{N}(y|m_y,v_y)$. For simplicity's sake, we assume all variables are scalar. \n", "\n", "\n", "\n", "- In a system identification setting, we are interested in computing the outgoing message $\\overleftarrow{\\mu}_\\Theta(\\theta)$. \n", "\n", "- Let's compute the sum-product message:\n", "\n", "$$\\begin{align*}\n", "\\overleftarrow{\\mu}_\\Theta(\\theta) &= \\int \\overrightarrow{\\mu}_X(x) \\, \\overleftarrow{\\mu}_Y(y) \\, f(x,y,\\theta) \\, \\mathrm{d}x \\mathrm{d}y \\\\\n", " &= \\int \\mathcal{N}(x\\,|\\,m_x,v_x) \\, \\mathcal{N}(y\\,|\\,m_y,v_y) \\, \\delta(y-\\theta x)\\, \\, \\mathrm{d}x \\mathrm{d}y \\\\\n", " &= \\int \\mathcal{N}(x\\,|\\,m_x,v_x) \\,\\mathcal{N}(\\theta x\\,|\\,m_y,v_y) \\, \\mathrm{d}x \\\\\n", " &= \\int \\mathcal{N}(x\\,|\\,m_x,v_x) \\,\\mathcal{N}\\left(x \\,\\bigg|\\, \\frac{m_y}{\\theta},\\frac{v_y}{\\theta^2}\\right) \\, \\mathrm{d}x \\\\\n", " &= \\mathcal{N}\\left(\\frac{m_y}{\\theta} \\,\\bigg|\\, m_x, v_x + \\frac{v_y}{\\theta^2}\\right) \\cdot \\int \\mathcal{N}(x\\,|\\,m_*,v_*)\\, \\mathrm{d}x \\tag{SRG-6} \\\\\n", " &= \\mathcal{N}\\left(\\frac{m_y}{\\theta} \\,\\bigg|\\, m_x, v_x + \\frac{v_y}{\\theta^2}\\right)\n", "\\end{align*}$$\n", "\n", "- This is **not** a Gaussian message for $\\Theta$! Passing this message into the graph leads to very serious problems when trying to compute sum-product messages for other factors in the graph.\n", " - (We have seen before in the lesson on [Working with Gaussians](http://nbviewer.ipython.org/github/bertdv/AIP-5SSB0/blob/master/lessons/04_working_with_Gaussians/Working-with-Gaussians.ipynb) that multiplication of two Gaussian-distributed variables does _not_ produce a Gaussian distributed variable.) \n", "\n", "- The same problem occurs in a forward message passing schedule when we try to compute a message for $Y$ from incoming Gaussian messages for both $X$ and $\\Theta$. \n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Limitations of Sum-Product Messages\n", "\n", "- The foregoing example shows that the sum-product (SP) message update rule will sometimes not do the job. For example:\n", " - On large-dimensional **discrete** domains, the SP update rule maybe computationally intractable.\n", " - On **continuous** domains, the SP update rule may not have a closed-form solution or the rule may lead to a function that is incompatible with Gaussian message passing. \n", "\n", "- There are various ways to cope with 'intractable' SP update rules. In this lesson, we discuss how the EM-algorithm can be written as a message passing algorithm on factor graphs. 
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Limitations of Sum-Product Messages\n",
    "\n",
    "- The foregoing example shows that the sum-product (SP) message update rule will sometimes not do the job. For example:\n",
    "  - On large-dimensional **discrete** domains, the SP update rule may be computationally intractable.\n",
    "  - On **continuous** domains, the SP update rule may not have a closed-form solution, or the rule may lead to a function that is incompatible with Gaussian message passing.\n",
    "\n",
    "- There are various ways to cope with 'intractable' SP update rules. In this lesson, we discuss how the EM-algorithm can be written as a message passing algorithm on factor graphs. Then, we will solve the 'multiplier node problem' with EM messages (rather than with SP messages).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EM as Message Passing\n",
    "\n",
    "- Consider first a general setting with likelihood function $f(x,\\theta)$, hidden variables $x$ and tuning parameters $\\theta$. Assume that we are interested in the maximum likelihood estimate\n",
    "\n",
    "$$\\begin{align*}\n",
    "\\hat{\\theta} &= \\arg\\max_\\theta \\int f(x,\\theta) \\mathrm{d}x\\,.\n",
    "\\end{align*}$$\n",
    "\n",
    "- If $\\int f(x,\\theta) \\mathrm{d}x$ is intractable, we can try to apply the EM-algorithm to estimate $\\hat{\\theta}$, which leads to the following iterations (_cf._ [lesson on the EM algorithm](http://nbviewer.ipython.org/github/bertdv/AIP-5SSB0/blob/master/lessons/10_the_EM_algorithm/The-General-EM-Algorithm.ipynb)):\n",
    "\n",
    "$$\n",
    "\\hat{\\theta}^{(k+1)} = \\underbrace{\\arg\\max_\\theta}_{\\text{M-step}} \\left( \\underbrace{\\int_x f(x,\\hat{\\theta}^{(k)})\\,\\log f(x,\\theta)\\,\\mathrm{d}x}_{\\text{E-step}} \\right)\n",
    "$$\n",
    "\n",
    "- It turns out that _for factorized functions_ $f(x,\\theta)$, the EM-algorithm can be executed as a message passing algorithm on the factor graph.\n",
    "\n",
    "- As a simple example, we consider the factorization\n",
    "\n",
    "$$\n",
    "f(x,\\theta) = f_a(\\theta)f_b(x,\\theta)\n",
    "$$\n",
    "\n",
    "- Applying the EM-algorithm to this graph leads to the following forward and backward messages over the $\\theta$ edge:\n",
    "\n",
    "$$\\begin{align*}\n",
    "\\textbf{E-step}&: \\quad \\eta(\\theta) = \\int p_b(x|\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x \\\\\n",
    "\\textbf{M-step}&: \\quad \\hat{\\theta}^{(k+1)} = \\arg\\max_\\theta \\left( f_a(\\theta)\\, e^{\\eta(\\theta)}\\right)\n",
    "\\end{align*}$$\n",
    "\n",
    "where $p_b(x|\\hat{\\theta}^{(k)}) \\triangleq \\frac{f_b(x,\\hat{\\theta}^{(k)})}{\\int f_b(x^\\prime,\\hat{\\theta}^{(k)}) \\,\\mathrm{d}x^\\prime}$.\n",
    "\n",
    "Proof:\n",
\n", "$$\\begin{align*}\n", "\\hat{\\theta}^{(k+1)} &= \\arg\\max_\\theta \\, \\int_x f(x,\\hat{\\theta}^{(k)}) \\,\\log f(x,\\theta)\\,\\mathrm{d}x \\\\\n", " &= \\arg\\max_\\theta \\, \\int_x f_a(\\theta)f_b(x,\\hat{\\theta}^{(k)}) \\,\\log \\left( f_a(\\theta)f_b(x,\\theta) \\right) \\,\\mathrm{d}x \\\\\n", " &= \\arg\\max_\\theta \\, \\int_x f_b(x,\\hat{\\theta}^{(k)}) \\cdot \\left( \\log f_a(\\theta) + \\log f_b(x,\\theta) \\right) \\,\\mathrm{d}x \\\\\n", " &= \\arg\\max_\\theta \\left( \\log f_a(\\theta) + \\frac{\\int f_b(x,\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x }{\\int f_b(x^\\prime,\\hat{\\theta}^{(k)}) \\,\\mathrm{d}x^\\prime} \\right) \\\\\n", " &= \\arg\\max_\\theta \\left( \\log f_a(\\theta) + \\underbrace{\\int p_b(x|\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x}_{\\eta(\\theta)} \\right) \\\\\n", " &= \\underbrace{\\arg\\max_\\theta}_{\\text{M-step}} \\left( f_a(\\theta)\\,\\underbrace{e^{\\eta(\\theta)}}_{\\text{E-step}} \\right) \n", "\\end{align*}$$\n", "
\n", "\n", "- The messages represent the 'E' and 'M' steps, respectively:\n", "\n", "\n", "\n", "- The quantity $\\eta(\\theta)$ (a.k.a. the **E-log message**) may be interpreted as a log-domain summary of $f_b$. The message $e^{\\eta(\\theta)}$ is the corresponding 'probability domain' message that is consistent with the semantics of messages as summaries of factors. In a software implementation, you can use either domain, as long as a consistent method is chosen.\n", "\n", "- Note that the denominator $\\int f_b(x^\\prime,\\hat{\\theta}^{(k)}) \\,\\mathrm{d}x^\\prime$ in $p_b$ is just a scaling factor that can usually be ignored, leading to a simpler E-log message $$\\eta(\\theta) = \\int f_b(x,\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x \\,.$$\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### EM vs SP and MP Message Passing\n", "\n", "- Consifer again the likelihood model $f(x,\\theta)$ with $x$ a set of hidden variables. We are interested in the ML estimate\n", "\n", "$$\n", "\\hat{\\theta} = \\arg\\max_\\theta \\int f(x,\\theta) \\mathrm{d}x\\,.\n", "$$\n", "\n", "- Recall that in a 'regular' (_not_ message passing) setting, the EM-algorithm is particularly useful when the _expectation_ (E-step)\n", "$$\n", "\\eta(\\theta) = \\int_x f(x,\\hat{\\theta}^{(k)})\\,\\log f(x,\\theta)\\,\\mathrm{d}x\n", "$$\n", "leads to easier expressions than the _marginalization_ (which is what we really want)\n", "$$\n", "\\bar f(\\theta) = \\int f(x,\\theta) \\mathrm{d}x .\n", "$$\n", "\n", "- Similarly, in a message passing framework with connected nodes $f_a$ and $f_b$, EM messages are particularly useful when the _expectation_ (represented by the **E-log message**)\n", "$$\n", "\\eta(\\theta) = \\int f_b(x|\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x\n", "$$\n", "leads to easier expressions than the _marginalization_ (represented by the **sum-product message**, which is also what we really want)\n", "$$\n", "\\mu(\\theta) = \\int f_b(x,\\theta) \\mathrm{d}x .\n", "$$\n", "\n", "- Just as for the sum-product (SP) and max-product (MP) messages, we can work out the outgoing E-log message on the $Y$ edge for a _general_ node $f(x_1,\\ldots,x_M,y)$ with given message inputs $\\overrightarrow{\\mu}_{X_m}(x_m)$ (see also [Dauwels et al. 
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### EM vs SP and MP Message Passing\n",
    "\n",
    "- Consider again the likelihood model $f(x,\\theta)$ with $x$ a set of hidden variables. We are interested in the ML estimate\n",
    "\n",
    "$$\n",
    "\\hat{\\theta} = \\arg\\max_\\theta \\int f(x,\\theta) \\mathrm{d}x\\,.\n",
    "$$\n",
    "\n",
    "- Recall that in a 'regular' (_not_ message passing) setting, the EM-algorithm is particularly useful when the _expectation_ (E-step)\n",
    "$$\n",
    "\\eta(\\theta) = \\int_x f(x,\\hat{\\theta}^{(k)})\\,\\log f(x,\\theta)\\,\\mathrm{d}x\n",
    "$$\n",
    "leads to easier expressions than the _marginalization_ (which is what we really want)\n",
    "$$\n",
    "\\bar f(\\theta) = \\int f(x,\\theta) \\mathrm{d}x .\n",
    "$$\n",
    "\n",
    "- Similarly, in a message passing framework with connected nodes $f_a$ and $f_b$, EM messages are particularly useful when the _expectation_ (represented by the **E-log message**)\n",
    "$$\n",
    "\\eta(\\theta) = \\int p_b(x|\\hat{\\theta}^{(k)}) \\log f_b(x,\\theta) \\,\\mathrm{d}x\n",
    "$$\n",
    "leads to easier expressions than the _marginalization_ (represented by the **sum-product message**, which is also what we really want)\n",
    "$$\n",
    "\\mu(\\theta) = \\int f_b(x,\\theta) \\mathrm{d}x .\n",
    "$$\n",
    "\n",
    "- Just as for the sum-product (SP) and max-product (MP) messages, we can work out the outgoing E-log message on the $Y$ edge for a _general_ node $f(x_1,\\ldots,x_M,y)$ with given message inputs $\\overrightarrow{\\mu}_{X_m}(x_m)$ (see also [Dauwels et al. (2009)](./files/Dauwels-2009-Expectation-Maximization-as-Message-Passing.pdf), Table 1, pg. 4):\n",
    "\n",
    "$$\\begin{align*}\n",
    "\\textbf{SP}:&\\;\\;\\overrightarrow{\\mu}(y) = \\int \\overrightarrow{\\mu}_{X_1}(x_1) \\cdots \\overrightarrow{\\mu}_{X_M}(x_M)\\, f(x_1,\\ldots,x_M,y) \\, \\mathrm{d}x_1 \\ldots \\mathrm{d}x_M \\\\\n",
    "\\textbf{MP}:&\\;\\;\\overrightarrow{\\mu}(y) = \\max_{x_1,\\ldots,x_M} \\overrightarrow{\\mu}_{X_1}(x_1) \\cdots \\overrightarrow{\\mu}_{X_M}(x_M)\\, f(x_1,\\ldots,x_M,y) \\\\\n",
    "\\textbf{E-log}:&\\;\\;\\overrightarrow{\\eta}(y) = \\int p(x_1,\\ldots,x_M \\mid \\hat{y}^{(k)})\\,\\log f(x_1,\\ldots,x_M,y) \\, \\mathrm{d}x_1 \\ldots \\mathrm{d}x_M\n",
    "\\end{align*}$$\n",
    "\n",
    "where $p(x_1,\\ldots,x_M \\mid \\hat{y}^{(k)}) \\triangleq \\frac{\\overrightarrow{\\mu}_{X_1}(x_1) \\cdots \\overrightarrow{\\mu}_{X_M}(x_M)\\, f(x_1,\\ldots,x_M,\\hat{y}^{(k)})}{\\int \\overrightarrow{\\mu}_{X_1}(x_1) \\cdots \\overrightarrow{\\mu}_{X_M}(x_M)\\, f(x_1,\\ldots,x_M,\\hat{y}^{(k)}) \\, \\mathrm{d}x_1 \\ldots \\mathrm{d}x_M}$.\n",
    "\n",
    "- **Exercise**: prove the generic E-log message update rule.\n"
   ]
  },
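  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- For intuition, the sketch below evaluates all three update rules on a small discrete node with a $2 \\times 2$ factor table. (The table $f$, the incoming message $\\mu_X$, and the current estimate $\\hat{y}^{(k)}$ are made-up illustrative numbers.)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SP, MP and E-log update rules on a tiny discrete node f(x,y)\n",
    "# (made-up numbers, purely for illustration)\n",
    "f  = [0.9 0.1;\n",
    "      0.2 0.8]       # factor table, rows index x, columns index y\n",
    "μx = [0.7, 0.3]      # incoming message for X\n",
    "\n",
    "μ_sp = vec(sum(μx .* f, 1))      # SP: sum over x      (Julia ≥ 0.7: sum(μx .* f, dims=1))\n",
    "μ_mp = vec(maximum(μx .* f, 1))  # MP: maximize over x (Julia ≥ 0.7: maximum(μx .* f, dims=1))\n",
    "\n",
    "ŷ = 1                             # current estimate of y\n",
    "p = μx .* f[:, ŷ]; p = p / sum(p) # p(x | current estimate), normalized\n",
    "η = vec(sum(p .* log.(f), 1))     # E-log: expectation of log f(x,y) under p\n",
    "\n",
    "println(\"SP message:    \", μ_sp)\n",
    "println(\"MP message:    \", μ_mp)\n",
    "println(\"E-log message: \", η)\n"
   ]
  },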
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Snag for EM Message Passing on Deterministic Nodes\n",
    "\n",
    "- The factors for deterministic nodes are (Dirac) delta functions, e.g., $\\delta(y-\\theta x)$ for the multiplier.\n",
    "\n",
    "- Note that the outgoing E-log message for a deterministic node will also be a delta function, since the expectation of $\\log \\delta(\\cdot)$ is again a delta function. For details, consult [Dauwels et al. (2009)](./files/Dauwels-2009-Expectation-Maximization-as-Message-Passing.pdf), pg. 5, section F.\n",
    "\n",
    "- This would stall the iterative estimation process at the current estimate, since the outgoing E-log message would express complete certainty about the estimate.\n",
    "\n",
    "- This issue can be resolved by closing a box around a subgraph that includes the deterministic node $f$ and _at least one non-deterministic factor_. EM message passing can then proceed with the newly created composite node.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A Solution for the Multiplier Node with Unknown Coefficient\n",
    "\n",
    "- We now return to the original problem of this lesson. Consider again the (scalar) multiplier with unknown coefficient $f(x,y,\\theta) = \\delta(y-\\theta x)$ and incoming messages $\\overrightarrow{\\mu_X}(x) = \\mathcal{N}(x|m_x,v_x)$ and $\\overleftarrow{\\mu_Y}(y) = \\mathcal{N}(y|m_y,v_y)$.\n",
    "\n",
    "We will now compute the outgoing E-log message for $\\Theta$.\n",
    "\n",
    "- Since $f(x,y,\\theta)$ is deterministic, we first group $f$ with the (non-deterministic) node that sends the message $\\overleftarrow{\\mu_Y}(y) = \\mathcal{N}(y|m_y,v_y)$, leading (through the sum-product rule) to\n",
    "$$\\begin{align*}\n",
    "g(x,\\theta) &\\triangleq \\int \\overleftarrow{\\mu_Y}(y)\\, f(x,y,\\theta) \\,\\mathrm{d}y \\\\\n",
    "  &= \\int \\mathcal{N}(y|m_y,v_y)\\, \\delta(y-\\theta x) \\,\\mathrm{d}y \\\\\n",
    "  &= \\mathcal{N}(\\theta x\\mid m_y,v_y)\\,.\n",
    "\\end{align*}$$\n",
    "\n",
    "- The problem now is to pass an E-log message out of $g(x,\\theta)$. Assume that $g$ has received an estimate $\\hat{\\theta}$ from the incoming message over the $\\Theta$ edge. The E-log update rule then prescribes\n",
    "$$\\begin{align*}\n",
    "\\eta(\\theta) &= \\mathbb{E}\\left[ \\log g(X,\\theta) \\right] \\\\\n",
    "  &= \\mathbb{E}\\left[ \\log \\mathcal{N}(\\theta X\\,|\\,m_y,v_y) \\right] \\\\\n",
    "  &= \\text{const.} - \\frac{1}{2v_y}\\, \\left( \\mathbb{E}[X^2]\\, \\theta^2 - 2 m_y \\mathbb{E}[X]\\, \\theta + m_y^2\\right)\n",
    "\\end{align*}$$\n",
    "so that the corresponding probability-domain message is Gaussian:\n",
    "$$\n",
    "e^{\\eta(\\theta)} \\propto \\mathcal{N}_{\\xi} \\left( \\theta \\,\\bigg|\\, \\frac{m_y \\mathbb{E}\\left[X\\right]}{v_y}, \\frac{\\mathbb{E}\\left[X^2\\right]}{v_y} \\right),\n",
    "$$\n",
    "where we used the 'canonical' parametrization of the Gaussian, $\\mathcal{N}_{\\xi}(\\theta \\mid \\xi,w) \\propto \\exp \\left( \\xi \\theta - \\frac{1}{2} w \\theta^2\\right)$.\n",
    "\n",
    "- In the E-log message update rule, the expectations $\\mathbb{E}\\left[X\\right]$ and $\\mathbb{E}\\left[X^2\\right]$ have to be taken w.r.t. $p(x|\\hat{\\theta}) \\propto \\overrightarrow{\\mu_X}(x)\\,g(x,\\hat{\\theta})$ (consult the generic E-log update rule). A straightforward (but rather painful) derivation leads to\n",
    "$$\\begin{align*}\n",
    "p(x \\mid \\hat{\\theta}) &\\propto \\overrightarrow{\\mu_X}(x)\\,g(x,\\hat{\\theta}) \\\\\n",
    "  &= \\mathcal{N}(x \\mid m_x,v_x)\\cdot \\mathcal{N}(\\hat{\\theta} x \\mid m_y,v_y) \\\\\n",
    "  &= \\mathcal{N}(x \\mid m_x,v_x)\\cdot \\mathcal{N}\\left(x \\,\\bigg| \\,\\frac{m_y}{\\hat{\\theta}},\\frac{v_y}{\\hat{\\theta}^2} \\right) \\\\\n",
    "  &\\propto \\mathcal{N}_{\\xi}( x \\mid \\xi_g , w_g)\n",
    "\\end{align*}$$\n",
    "where $w_g = \\frac{1}{v_x} + \\frac{\\hat{\\theta}^2}{v_y}$ and $\\xi_g \\triangleq w_g m_g = \\frac{m_x}{v_x}+\\frac{\\hat{\\theta}m_y}{v_y}$. It follows that\n",
    "$$\\begin{align*}\n",
    "\\mathbb{E}\\left[X\\right] &= m_g \\\\\n",
    "\\mathbb{E}\\left[X^2\\right] &= m_g^2 + w_g^{-1}\n",
    "\\end{align*}$$\n",
    "\n",
    "- $\\Rightarrow$ The E-log update formula may not be fun to derive, but the result is very pleasing: the **E-log message for the multiplier with unknown coefficient is a Gaussian message** with closed-form expressions for its parameters! See also [Dauwels et al. (2009)](./files/Dauwels-2009-Expectation-Maximization-as-Message-Passing.pdf), Table 2, pg. 6. A runnable sketch of this update follows below.\n"
   ]
  },
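  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a sanity check, the code sketch below iterates this closed-form message update. The incoming message parameters are arbitrary illustrative numbers, and we assume a flat prior for $\\Theta$ (no informative $f_a$), so the M-step simply picks the mode $\\xi_\\theta / w_\\theta$ of the E-log message.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EM message passing for the multiplier node with unknown coefficient,\n",
    "# following the closed-form expressions derived above (a sketch; the\n",
    "# numbers are illustrative and a flat prior for Θ is assumed)\n",
    "m_x, v_x = 1.0, 0.5   # incoming message N(x | m_x, v_x)\n",
    "m_y, v_y = 2.0, 0.2   # incoming message N(y | m_y, v_y)\n",
    "\n",
    "θ_hat = 1.0           # initial estimate\n",
    "for k = 1:20\n",
    "    # p(x | θ_hat) ∝ N(x | m_x, v_x) · N(θ_hat·x | m_y, v_y), in canonical form\n",
    "    w_g = 1/v_x + θ_hat^2/v_y\n",
    "    ξ_g = m_x/v_x + θ_hat*m_y/v_y\n",
    "    m_g = ξ_g / w_g\n",
    "    EX, EX2 = m_g, m_g^2 + 1/w_g   # E[X] and E[X²]\n",
    "    # E-log message for Θ in canonical parametrization (Gaussian!)\n",
    "    ξ_θ = m_y * EX / v_y\n",
    "    w_θ = EX2 / v_y\n",
    "    # M-step with a flat prior: take the mode ξ_θ/w_θ of the E-log message\n",
    "    θ_hat = ξ_θ / w_θ\n",
    "    println(\"iteration \", k, \": θ_hat = \", θ_hat)\n",
    "end\n"
   ]
  },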
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Automating Inference\n",
    "\n",
    "- It follows that, for a dynamical system with unknown coefficients, both state estimation and parameter learning can be achieved through Gaussian message passing, based on SP and EM message update rules.\n",
    "\n",
    "- These (SP and EM) message update rules can be tabulated and implemented in software for a large set of factors that are common in probabilistic models. (See the tables in [Loeliger et al. (2007)](./files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf) and [Dauwels et al. (2009)](./files/Dauwels-2009-Expectation-Maximization-as-Message-Passing.pdf).)\n",
    "\n",
    "- Tabulated SP and EM messages for frequently occurring factors facilitate the automated derivation of nontrivial inference algorithms.\n",
    "\n",
    "- This makes it possible to automate inference for state and parameter estimation in very complex probabilistic models. Here (in the SPS group at TU/e), we are developing such a factor graph toolbox in [Julia](http://julialang.org/).\n",
    "\n",
    "- There is lots more to say about factor graphs. This is a very exciting area of research that promises both\n",
    "  1. to consolidate a wide range of signal processing and machine learning algorithms in one elegant framework, and\n",
    "  2. to automate inference and learning in new models that have previously been intractable for existing machine learning methods.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example: Linear Dynamical Systems\n",
    "\n",
    "As before, let us consider the linear dynamical system (LDS)\n",
    "\n",
    "$$\\begin{align*}\n",
    "  z_n &= A z_{n-1} + w_n \\\\\n",
    "  x_n &= C z_n + v_n \\\\\n",
    "  w_n &\\sim \\mathcal{N}(0,\\Sigma_w) \\\\\n",
    "  v_n &\\sim \\mathcal{N}(0,\\Sigma_v)\n",
    "\\end{align*}$$\n",
    "\n",
    "Again, we will consider the case where $x_n$ is observed and $z_n$ is a hidden state. $C$, $\\Sigma_w$ and $\\Sigma_v$ are given parameters, but in contrast to the previous section, we will assume that the value of the parameter $A$ is unknown.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "-----\n",
    "_The cell below loads the style file_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "open(\"../../styles/aipstyle.html\") do f\n",
    "    display(\"text/html\", readstring(f))\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.6.2",
   "language": "julia",
   "name": "julia-0.6"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}