{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Factor Graphs" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "\n", "### Preliminaries\n", "\n", "- Goal \n", " - Introduction to Forney-style factor graphs and message passing algorithms\n", "- Materials \n", " - Mandatory\n", " - These lecture notes \n", " - [Loeliger, 2007](./files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf), pp. 1295-1300 (until section IV)\n", " - Optional\n", " - [Video lecture](https://www.youtube.com/watch?v=Fv2YbVg9Frc&t=31) by Frederico Wadehn (ETH Zurich) (**highly recommended**)\n", " \n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Why Factor Graphs?\n", "\n", "- The computational load of a probabilistic inference task comes mainly from marginalization (i.e., computing integrals). For example, for a generative model $p(x_1,x_2,x_3,x_4,x_5)$, the inference task $p(x_2|x_3)$ is given by \n", "\n", "$$\\begin{align*}\n", "p(x_2|x_3) = \\frac{\\int p(x_1,x_2,x_3,x_4,x_5) \\, \\mathrm{d}x_1 \\mathrm{d}x_4 \\mathrm{d}x_5}{\\int p(x_1,x_2,x_3,x_4,x_5) \\, \\mathrm{d}x_1 \\mathrm{d}x_2 \\mathrm{d}x_4 \\mathrm{d}x_5}\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Since these computations suffer from the \"curse of dimensionality\", we often need to solve a simpler problem in order to get an answer. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Factor graphs provide a computationally efficient approach to solving inference problems **if the generative distribution can be factorized**. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Factorization helps. 
For instance, if $p(x_1,x_2,x_3,x_4,x_5) = p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4)$, then\n", "$$\\begin{align*}\n", "p(x_2|x_3) &= \\frac{\\int p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4) \\, \\mathrm{d}x_1 \\mathrm{d}x_4 \\mathrm{d}x_5}{\\int p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4) \\, \\mathrm{d}x_1 \\mathrm{d}x_2 \\mathrm{d}x_4 \\mathrm{d}x_5} \n", " = \\frac{p(x_2,x_3)}{\\int p(x_2,x_3) \\mathrm{d}x_2}\n", "\\end{align*}$$\n", "which is computationally much cheaper than the general case above." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- In this lesson, we discuss how computationally efficient inference in factorized probability distributions can be automated." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Factor Graph Construction Rules\n", "\n", "- Consider a function \n", "$$\n", "f(x_1,x_2,x_3,x_4,x_5) = f_a(x_1,x_2,x_3) \\cdot f_b(x_3,x_4,x_5) \\cdot f_c(x_4)\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The factorization of this function can be graphically represented by a **Forney-style Factor Graph** (FFG):\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- An FFG is an **undirected** graph subject to the following construction rules ([Forney, 2001](http://ieeexplore.ieee.org/xpl/login.jsp?tp=&arnumber=910573&url=http%3A%2F%2Fieeexplore.ieee.org%2Fiel5%2F18%2F19638%2F00910573.pdf%3Farnumber%3D910573))\n", "\n", " 1. A **node** for every factor;\n", " 1. An **edge** (or **half-edge**) for every variable;\n", " 1. Node $g$ is connected to edge $x$ **iff** variable $x$ appears in factor $g$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Some FFG Terminology\n", "\n", "- $f$ is called the **global function** and $f_\\bullet$ are the **factors**. 
\n", "\n", "- A **configuration** is an assignment of values to all variables.\n", "\n", "- The **configuration space** is the set of all configurations, i.e., the domain of $f$.\n", "\n", "- A configuration $\\omega=(x_1,x_2,x_3,x_4,x_5)$ is said to be **valid** iff $f(\\omega) \\neq 0$.\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Equality Nodes for Branching Points\n", "\n", "\n", "- Note that a variable can appear in at most two factors in an FFG (since an edge has only two end points)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Consider the factorization (where $x_2$ appears in three factors) \n", "\n", "$$\n", " f(x_1,x_2,x_3,x_4) = f_a(x_1,x_2)\\cdot f_b(x_2,x_3) \\cdot f_c(x_2,x_4)\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For the factor graph representation, we will instead consider the function $g$, defined as\n", "$$\\begin{align*}\n", " g(x_1,x_2&,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \n", " = f_a(x_1,x_2)\\cdot f_b(x_2^\\prime,x_3) \\cdot f_c(x_2^{\\prime\\prime},x_4) \\cdot f_=(x_2,x_2^\\prime,x_2^{\\prime\\prime})\n", "\\end{align*}$$\n", " where \n", "$$\n", "f_=(x_2,x_2^\\prime,x_2^{\\prime\\prime}) \\triangleq \\delta(x_2-x_2^\\prime)\\, \\delta(x_2-x_2^{\\prime\\prime})\n", "$$\n", "\n", " \n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Equality Nodes for Branching Points, cont'd\n", "\n", "- Note that, through the introduction of auxiliary variables $X_2^\\prime$ and $X_2^{\\prime\\prime}$, each variable in $g$ appears in at most two factors. 
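As a quick sanity check of this construction, the sketch below assumes binary variables and replaces the Dirac deltas by Kronecker deltas, so that marginalization becomes a finite sum. The factor tables fa, fb and fc hold arbitrary illustrative values. The sketch confirms that summing g over the two auxiliary copies of x2 recovers f on every configuration.

```julia
# Illustrative factor tables over binary variables (arbitrary values)
fa = [1.0 2.0; 3.0 4.0]   # fa[x1, x2]
fb = [0.5 1.5; 2.5 0.1]   # fb[x2, x3]
fc = [0.2 0.7; 1.1 0.3]   # fc[x2, x4]

# Kronecker-delta stand-in for the equality factor f_=(x2, x2p, x2pp)
feq(x2, x2p, x2pp) = (x2 == x2p && x2 == x2pp) ? 1.0 : 0.0

# Original global function f and augmented function g with copies x2p, x2pp
f(x1, x2, x3, x4) = fa[x1, x2] * fb[x2, x3] * fc[x2, x4]
g(x1, x2, x2p, x2pp, x3, x4) = fa[x1, x2] * fb[x2p, x3] * fc[x2pp, x4] * feq(x2, x2p, x2pp)

# Marginalize g over the auxiliary copies of x2
g_marg(x1, x2, x3, x4) = sum(g(x1, x2, x2p, x2pp, x3, x4) for x2p = 1:2, x2pp = 1:2)

# Largest discrepancy between f and the marginal of g over all 16 configurations
max_err = maximum(abs(f(x1, x2, x3, x4) - g_marg(x1, x2, x3, x4)) for x1 = 1:2, x2 = 1:2, x3 = 1:2, x4 = 1:2)
@show max_err
```

Only the term in which both copies equal x2 survives the sum, which is exactly the constraint the equality node enforces.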
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The constraint $f_=(x,x^\\prime,x^{\\prime\\prime})$ enforces that $X=X^\\prime=X^{\\prime\\prime}$ **for every valid configuration**." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Since $f$ is a marginal of $g$, i.e., \n", "$$\n", "f(x_1,x_2,x_3,x_4) = \\int g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4)\\, \\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime}\n", "$$\n", "it follows that any inference problem on $f$ can be executed by a corresponding inference problem on $g$, e.g.,\n", "$$\\begin{align*}\n", "f(x_1 \\mid x_2) &\\triangleq \\frac{\\int f(x_1,x_2,x_3,x_4) \\,\\mathrm{d}x_3 \\mathrm{d}x_4 }{ \\int f(x_1,x_2,x_3,x_4) \\,\\mathrm{d}x_1 \\mathrm{d}x_3 \\mathrm{d}x_4} \\\\\n", " &= \\frac{\\int g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \\,\\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime} \\mathrm{d}x_3 \\mathrm{d}x_4 }{ \\int g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \\,\\mathrm{d}x_1 \\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime} \\mathrm{d}x_3 \\mathrm{d}x_4} \\\\\n", " &\\triangleq g(x_1 \\mid x_2)\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- $\\Rightarrow$ **Any factorization of a global function $f$ can be represented by a Forney-style Factor Graph**." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Probabilistic Models as Factor Graphs\n", "\n", "- FFGs can be used to express conditional independence (factorization) in probabilistic models. 
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For example, the (previously shown) graph for \n", "$f_a(x_1,x_2,x_3) \\cdot f_b(x_3,x_4,x_5) \\cdot f_c(x_4)$ \n", "could represent the probabilistic model\n", "$$\n", "p(x_1,x_2,x_3,x_4,x_5) = p(x_1,x_2|x_3) \\cdot p(x_3,x_5|x_4) \\cdot p(x_4)\n", "$$\n", "where we identify \n", "$$\\begin{align*}\n", "f_a(x_1,x_2,x_3) &= p(x_1,x_2|x_3) \\\\\n", "f_b(x_3,x_4,x_5) &= p(x_3,x_5|x_4) \\\\\n", "f_c(x_4) &= p(x_4)\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- This is the graph\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Inference by Closing Boxes\n", "\n", "- Factorizations provide opportunities to cut down on the amount of computation needed for inference. In what follows, we will use FFGs to exploit these opportunities automatically (i.e., by message passing). 
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Assume we wish to compute the marginal\n", "$$\n", "\\bar{f}(x_3) = \\sum_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7)\n", "$$\n", "where $f$ is factorized as given by the following FFG\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Due to the factorization, we can decompose this sum by the **distributive law** as\n", "$$\\begin{align*}\n", "\\bar{f}(x_3) = & \\underbrace{ \\left( \\sum_{x_1,x_2} f_a(x_1)\\,f_b(x_2)\\,f_c(x_1,x_2,x_3)\\right) }_{\\overrightarrow{\\mu}_{X_3}(x_3)} \\\\\n", " & \\underbrace{ \\cdot\\left( \\sum_{x_4,x_5} f_d(x_4)\\,f_e(x_3,x_4,x_5) \\cdot \\underbrace{ \\left( \\sum_{x_6,x_7} f_f(x_5,x_6,x_7)\\,f_g(x_7)\\right) }_{\\overleftarrow{\\mu}_{X_5}(x_5)} \\right) }_{\\overleftarrow{\\mu}_{X_3}(x_3)}\n", "\\end{align*}$$\n", "which is computationally (much) lighter than executing the full sum $\\sum_{x_1,\\ldots,x_7}f(x_1,x_2,\\ldots,x_7)$.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Messages may flow in both directions on any edge (here on $X_3$). We often draw _directed edges_ in an FFG to distinguish forward messages $\\overrightarrow{\\mu}_\\bullet(\\cdot)$ (in the same direction as the arrow of the edge) from backward messages $\\overleftarrow{\\mu}_\\bullet(\\cdot)$ (in the opposite direction). With directed edges, the FFG looks as follows: \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Crucially, drawing arrows on edges is only meant as a notational convenience. Technically, an FFG is an undirected graph. 
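The distributive-law decomposition can be verified numerically. This sketch assumes that all seven variables are binary and fills the factor tables with arbitrary positive numbers (illustrative values, not taken from the figure). It computes the marginal of x3 twice. The first pass uses brute force over all 2^7 configurations. The second pass takes the product of the forward and backward messages on edge X3, where the backward message reuses the backward message on X5.

```julia
# Arbitrary positive factor tables over binary variables (illustrative values)
fa = [0.3, 0.7]                                             # fa(x1)
fb = [0.6, 0.4]                                             # fb(x2)
fc = [0.1 * (i + 2j + k^2) for i = 1:2, j = 1:2, k = 1:2]   # fc(x1, x2, x3)
fd = [0.5, 1.5]                                             # fd(x4)
fe = [1.0 / (i + j + k) for i = 1:2, j = 1:2, k = 1:2]      # fe(x3, x4, x5)
ff = [1.0 * (i * j + k) for i = 1:2, j = 1:2, k = 1:2]      # ff(x5, x6, x7)
fg = [2.0, 1.0]                                             # fg(x7)

# Brute force: sum the full product over every variable except x3
function brute_force_marginal()
    fbar = zeros(2)
    for x1 = 1:2, x2 = 1:2, x3 = 1:2, x4 = 1:2, x5 = 1:2, x6 = 1:2, x7 = 1:2
        fbar[x3] += fa[x1] * fb[x2] * fc[x1, x2, x3] * fd[x4] *
                    fe[x3, x4, x5] * ff[x5, x6, x7] * fg[x7]
    end
    return fbar
end

# Message passing: close the boxes from the terminals towards edge X3
mu_fwd_X3 = [sum(fa[x1] * fb[x2] * fc[x1, x2, x3] for x1 = 1:2, x2 = 1:2) for x3 = 1:2]
mu_bwd_X5 = [sum(ff[x5, x6, x7] * fg[x7] for x6 = 1:2, x7 = 1:2) for x5 = 1:2]
mu_bwd_X3 = [sum(fd[x4] * fe[x3, x4, x5] * mu_bwd_X5[x5] for x4 = 1:2, x5 = 1:2) for x3 = 1:2]

fbar_brute = brute_force_marginal()
fbar_msg = mu_fwd_X3 .* mu_bwd_X3
@show fbar_brute fbar_msg
```

The brute-force loop multiplies seven factors for each of the 128 configurations. Every message, by contrast, sums over only two variables at a time, and this gap grows exponentially with the number of variables.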
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Note that $\\overleftarrow{\\mu}_{X_5}(x_5)$ is obtained by multiplying all factors enclosed by the green dashed box ($f_f$, $f_g$), followed by marginalization over all enclosed variables ($x_6$, $x_7$). " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- This is the **Closing-the-Box** rule, which is a general recipe for marginalizing out hidden variables and leads to a new factor with outgoing (sum-product) message \n", "$$ \\mu_{\\text{SP}} = \\sum_{ \\stackrel{ \\textrm{enclosed} }{ \\textrm{variables} } } \\;\\prod_{\\stackrel{ \\textrm{enclosed} }{ \\textrm{factors} }}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Evaluating the closing-the-box rule for individual nodes\n", " \n", "\n", "- First, closing a box around the **terminal nodes** leads to $\\overrightarrow{\\mu}_{X_1}(x_1) \\triangleq f_a(x_1)$, $\\overrightarrow{\\mu}_{X_2}(x_2) \\triangleq f_b(x_2)$ etc. \n", " - So, the message out of a terminal node is the factor itself.\n", " - (Exercise) Show that the message coming from the open end of a half-edge always equals $1$. 
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The messages from **internal nodes** evaluate to:\n", "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{X_3}(x_3) &= \\sum_{x_1,x_2} f_a(x_1) \\,f_b(x_2) \\,f_c(x_1,x_2,x_3) \\\\\n", " &= \\sum_{x_1,x_2} \\overrightarrow{\\mu}_{X_1}(x_1) \\overrightarrow{\\mu}_{X_2}(x_2) \\,f_c(x_1,x_2,x_3) \\\\\n", "\\overleftarrow{\\mu}_{X_5}(x_5) &= \\sum_{x_6,x_7} f_f(x_5,x_6,x_7)\\,f_g(x_7) \\\\\n", " &= \\sum_{x_6,x_7} \\overrightarrow{\\mu}_{X_7}(x_7)\\, f_f(x_5,x_6,x_7) \\\\\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Crucially, all message update rules can be computed from information that is **locally available** at each node." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Sum-Product Algorithm\n", "\n", "- (**Sum-Product update rule**). This recursive pattern for computing messages applies generally and is called the **Sum-Product update rule**, which is really just a special case of the closing-the-box rule: For any node, the outgoing message is obtained by taking the product of all incoming messages and the node function, followed by summing out (marginalization) all incoming variables. What is left (the outgoing message) is a function of the outgoing variable only: \n", "\n", "$$ \\boxed{\n", "\\overrightarrow{\\mu}_{Y}(y) = \\sum_{x_1,\\ldots,x_n} \\overrightarrow{\\mu}_{X_1}(x_1)\\cdots \\overrightarrow{\\mu}_{X_n}(x_n) \\,f(y,x_1,\\ldots,x_n) }\n", "$$\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- (**Sum-Product Theorem**). 
If the factor graph for a function $f$ has **no cycles**, then the marginal $\\bar{f}(x_3) = \\sum_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7)$ is given by the product of the forward and backward messages on edge $X_3$:\n", "\n", "$$ \\boxed{\n", "\\bar{f}(x_3) = \\overrightarrow{\\mu}_{X_3}(x_3)\\cdot \\overleftarrow{\\mu}_{X_3}(x_3)}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- (**Sum-Product Algorithm**). It follows that the marginal $\\bar{f}(x_3) = \\sum_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7)$ can be computed efficiently through sum-product messages. Executing inference through SP message passing is called the **Sum-Product Algorithm**." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- (Exercise) Verify for yourself that all marginals in a cycle-free graph (a tree) can be computed exactly by passing messages from the terminals towards the root of the tree and then back from the root to the terminals." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Processing Observations in a Factor Graph\n", "\n", " - Terminal nodes can be used to describe **observed variables**, e.g., use a factor $$f_Y(y) = \\delta(y-3)$$ to terminate the edge for variable $Y$ if $y=3$ is observed.\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " \n", " ##### Example\n", " \n", " - Consider a generative model \n", "$$p(x,y_1,y_2) = p(x)\\,p(y_1|x)\\,p(y_2|x) .$$ \n", " - This model expresses the assumption that $Y_1$ and $Y_2$ are conditionally independent measurements of $X$.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ " \n", "- Assume that we are interested in the posterior for $X$ after observing $Y_1= \\hat y_1$ and $Y_2= \\hat y_2$. 
The posterior for $X$ can be inferred by applying the sum-product algorithm to the following graph:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " - Note that we usually draw terminal nodes for observed variables as smaller solid black squares. This only aids visualization; the computational rules are the same as for other nodes. " ] }, { "cell_type": "markdown", "metadata": { "collapsed": true, "slideshow": { "slide_type": "fragment" } }, "source": [ "- (Exercise) Can you draw the messages that infer $p(x\\,|\\,y_1,y_2)$?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "We'll use ForneyLab, a factor graph toolbox for Julia, to build the above graph and perform sum-product message passing to infer the posterior $p(x|y_1,y_2)$. We assume a Gaussian prior $p(x) = \\mathcal{N}(x\\,|\\,m_x, v_x)$, and we assume $p(y_1|x)$ and $p(y_2|x)$ to be Gaussian likelihoods with known variances:\n", "$$\\begin{align*}\n", " p(y_1\\,|\\,x) &= \\mathcal{N}(y_1\\,|\\,x, v_{y1}) \\\\\n", " p(y_2\\,|\\,x) &= \\mathcal{N}(y_2\\,|\\,x, v_{y2})\n", "\\end{align*}$$\n", "Under this model, the posterior is given by:\n", "$$\\begin{align*}\n", " p(x\\,|\\,y_1,y_2) &\\propto \\overbrace{p(y_1\\,|\\,x)\\,p(y_2\\,|\\,x)}^{\\text{likelihood}}\\,\\overbrace{p(x)}^{\\text{prior}} \\\\\n", " &=\\mathcal{N}(x\\,|\\,\\hat{y}_1, v_{y1})\\, \\mathcal{N}(x\\,|\\,\\hat{y}_2, v_{y2}) \\, \\mathcal{N}(x\\,|\\,m_x, v_x) \n", "\\end{align*}$$\n", "so we can validate the answer by solving the Gaussian multiplication manually."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sum-product message passing result: p(x|y1,y2) = 𝒩(m=1.14, v=0.57)\n", "\n", "Manual result: p(x|y1,y2) = N(m=1.14, V=0.57)\n" ] } ], "source": [ "using ForneyLab # version 0.7.1\n", "\n", "# Data\n", "y1_hat = 1.0\n", "y2_hat = 2.0\n", "\n", "# Construct the factor graph\n", "fg = FactorGraph()\n", "@RV x ~ GaussianMeanVariance(constant(0.0), constant(4.0), id=:x) # Node p(x)\n", "@RV y1 ~ GaussianMeanVariance(x, constant(1.0)) # Node p(y1|x)\n", "@RV y2 ~ GaussianMeanVariance(x, constant(2.0)) # Node p(y2|x)\n", "Clamp(y1, y1_hat) # Terminal (clamp) node for y1\n", "Clamp(y2, y2_hat) # Terminal (clamp) node for y2\n", "# draw(fg) # draw the constructed factor graph\n", "\n", "# Perform sum-product message passing\n", "eval(parse(sumProductAlgorithm(x, name=\"_algo1\"))) # Automatically derives a message passing schedule\n", "x_marginal = step_algo1!(Dict())[:x] # Execute algorithm and collect marginal distribution of x\n", "println(\"Sum-product message passing result: p(x|y1,y2) = $(x_marginal)\")\n", "\n", "# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians (see lesson 4 for details)\n", "v = 1 / (1/4 + 1/1 + 1/2)\n", "m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)\n", "println(\"Manual result: p(x|y1,y2) = N(m=$(round(m,2)), V=$(round(v,2)))\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Example: SP Messages for the Equality Node\n", "\n", "- Let´s compute the SP messages for the **equality node** $f_=(x,y,z) = \\delta(z-x)\\delta(z-y)$: \n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Z}(z) &= \\int \\overrightarrow{\\mu}_{X}(x) \\overrightarrow{\\mu}_{Y}(y) \\,\\delta(z-x)\\delta(z-y) 
\\,\\mathrm{d}x \\mathrm{d}y \\\\\n", " &= \\overrightarrow{\\mu}_{X}(z) \\int \\overrightarrow{\\mu}_{Y}(y) \\,\\delta(z-y) \\,\\mathrm{d}y \\\\\n", " &= \\overrightarrow{\\mu}_{X}(z) \\overrightarrow{\\mu}_{Y}(z) \n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- By symmetry, this also implies (for the same equality node) that\n", "\n", "$$\\begin{align*}\n", "\\overleftarrow{\\mu}_{X}(x) &= \\overrightarrow{\\mu}_{Y}(x) \\overleftarrow{\\mu}_{Z}(x) \\quad \\text{and} \\\\\n", "\\overleftarrow{\\mu}_{Y}(y) &= \\overrightarrow{\\mu}_{X}(y) \\overleftarrow{\\mu}_{Z}(y)\\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Let us now consider the case of Gaussian messages $\\overrightarrow{\\mu}_{X}(x) = \\mathcal{N}(\\overrightarrow{m}_X,\\overrightarrow{V}_X)$, $\\overrightarrow{\\mu}_{Y}(y) = \\mathcal{N}(\\overrightarrow{m}_Y,\\overrightarrow{V}_Y)$ and $\\overrightarrow{\\mu}_{Z}(z) = \\mathcal{N}(\\overrightarrow{m}_Z,\\overrightarrow{V}_Z)$. Let's also define the precision matrices $\\overrightarrow{W}_X \\triangleq \\overrightarrow{V}_X^{-1}$ and similarly for $Y$ and $Z$. Then applying the SP update rule leads to multiplication of two Gaussian distributions, resulting in \n", "\n", "$$\\begin{align*}\n", "\\overrightarrow{W}_Z &= \\overrightarrow{W}_X + \\overrightarrow{W}_Y \\\\ \n", "\\overrightarrow{W}_Z \\overrightarrow{m}_Z &= \\overrightarrow{W}_X \\overrightarrow{m}_X + \\overrightarrow{W}_Y \\overrightarrow{m}_Y\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- It follows that **message passing through an equality node is similar to applying Bayes' rule**, i.e., fusion of two information sources. Does this make sense?
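For scalar messages, the Gaussian equality-node rule can be checked numerically. The sketch below uses illustrative message parameters and computes the outgoing precision and mean with the precision-weighted formulas above. It then verifies on a grid that the pointwise product of the two incoming densities is a scaled copy of the resulting Gaussian, i.e., that their ratio is constant.

```julia
# Scalar Gaussian density N(x | m, v)
gauss(x, m, v) = exp(-(x - m)^2 / (2v)) / sqrt(2pi * v)

# Incoming forward messages on X and Y (illustrative parameters)
mX, VX = 1.0, 2.0
mY, VY = 3.0, 0.5

# Equality-node update in precision form: precisions add,
# precision-weighted means add
WZ = 1 / VX + 1 / VY
mZ = (mX / VX + mY / VY) / WZ
VZ = 1 / WZ

# The product of the incoming densities should be a scaled version of N(z | mZ, VZ),
# so the ratio below must be the same at every grid point
grid = collect(-2.0:0.5:6.0)
ratios = [gauss(z, mX, VX) * gauss(z, mY, VY) / gauss(z, mZ, VZ) for z in grid]
@show mZ VZ extrema(ratios)
```

The constant ratio is just the normalization of the product. Sum-product messages only matter up to such a scale factor, so the ratio can be dropped.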
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### (OPTIONAL SLIDE) Example: SP Messages for the Addition Nodes\n", "\n", "- Next, consider an **addition node** $f_+(x,y,z) = \\delta(z-x-y)$: \n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Z}(z) &= \\int \\overrightarrow{\\mu}_{X}(x) \\overrightarrow{\\mu}_{Y}(y) \\,\\delta(z-x-y) \\,\\mathrm{d}x \\mathrm{d}y \\\\\n", " &= \\int \\overrightarrow{\\mu}_{X}(z) \\overrightarrow{\\mu}_{Y}(z-x) \\,\\mathrm{d}x \\,, \n", "\\end{align*}$$\n", "i.e., $\\overrightarrow{\\mu}_{Z}$ is the convolution of the messages $\\overrightarrow{\\mu}_{X}$ and $\\overrightarrow{\\mu}_{Y}$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Of course, for Gaussian messages, these update rules evaluate to\n", "\n", "$$\\begin{align*}\n", "\\overrightarrow{m}_Z = \\overrightarrow{m}_X + \\overrightarrow{m}_Y \\,,\\,\\text{and}\\,\\,\\overrightarrow{V}_z = \\overrightarrow{V}_X + \\overrightarrow{V}_Y \\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "-
Exercise: For the same addition node, work out the SP update rule for the *backward* message $\\overleftarrow{\\mu}_{X}(x)$ as a function of $\\overrightarrow{\\mu}_{Y}(y)$ and $\\overleftarrow{\\mu}_{Z}(z)$. Then refine the answer for Gaussian messages.
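The Gaussian forward rule for the addition node can also be checked by evaluating the convolution integral numerically. This sketch assumes scalar messages with illustrative parameters and approximates the integral over x by a Riemann sum on a fine grid. The result should match a Gaussian whose mean and variance are the sums of the incoming means and variances.

```julia
# Scalar Gaussian density N(x | m, v)
gauss(x, m, v) = exp(-(x - m)^2 / (2v)) / sqrt(2pi * v)

# Incoming forward messages on X and Y (illustrative parameters)
mX, VX = 1.0, 1.0
mY, VY = 2.0, 0.5

# Riemann-sum approximation of the convolution integral over x
dx = 0.01
xs = collect(-15.0:dx:15.0)
mu_Z_numeric(z) = sum(gauss(x, mX, VX) * gauss(z - x, mY, VY) for x in xs) * dx

# Closed-form Gaussian result: means and variances add
mu_Z_closed(z) = gauss(z, mX + mY, VX + VY)

zs = [0.0, 1.5, 3.0, 4.5, 6.0]
max_abs_err = maximum(abs(mu_Z_numeric(z) - mu_Z_closed(z)) for z in zs)
@show max_abs_err
```

Only the forward message is checked here. Once you derive the backward message asked for in the exercise, you can verify it in the same way.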
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### (OPTIONAL SLIDE) Example: SP Messages for Multiplication Nodes\n", "\n", "- Next, let us consider a **multiplication** by a fixed, invertible matrix gain $A$, i.e., $f_A(x,y) = \\delta(y-Ax)$\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Y}(y) = \\int \\overrightarrow{\\mu}_{X}(x) \\,\\delta(y-Ax) \\,\\mathrm{d}x = \\frac{1}{|\\det A|}\\,\\overrightarrow{\\mu}_{X}(A^{-1}y) \\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For a Gaussian input message $\\overrightarrow{\\mu}_{X}(x) = \\mathcal{N}(\\overrightarrow{m}_{X},\\overrightarrow{V}_{X})$, the output message is also Gaussian with \n", "$$\\begin{align*}\n", "\\overrightarrow{m}_{Y} = A\\overrightarrow{m}_{X} \\,,\\,\\text{and}\\,\\,\n", "\\overrightarrow{V}_{Y} = A\\overrightarrow{V}_{X}A^T\n", "\\end{align*}$$\n", "since \n", "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Y}(y) &\\propto \\overrightarrow{\\mu}_{X}(A^{-1}y) \\\\\n", " &\\propto \\exp \\left( -\\frac{1}{2} \\left( A^{-1}y - \\overrightarrow{m}_{X}\\right)^T \\overrightarrow{V}_{X}^{-1} \\left( A^{-1}y - \\overrightarrow{m}_{X}\\right)\\right) \\\\\n", " &= \\exp \\left( -\\frac{1}{2} \\left( y - A\\overrightarrow{m}_{X}\\right)^T A^{-T}\\overrightarrow{V}_{X}^{-1} A^{-1} \\left( y - A\\overrightarrow{m}_{X}\\right)\\right) \\\\\n", " &= \\exp \\left( -\\frac{1}{2} \\left( y - A\\overrightarrow{m}_{X}\\right)^T \\left( A\\overrightarrow{V}_{X}A^T\\right)^{-1} \\left( y - A\\overrightarrow{m}_{X}\\right)\\right) \\\\\n", " &\\propto \\mathcal{N}(A\\overrightarrow{m}_{X},A\\overrightarrow{V}_{X}A^T) \\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "-
Exercise: Prove that, for the same factor $\\delta(y-Ax)$ and Gaussian messages, the (backward) sum-product message $\\overleftarrow{\\mu}_{X}$ is given by \n", "$$\\begin{align*}\n", "\\overleftarrow{\\xi}_{X} &= A^T\\overleftarrow{\\xi}_{Y} \\\\\n", "\\overleftarrow{W}_{X} &= A^T\\overleftarrow{W}_{Y}A\n", "\\end{align*}$$\n", "where $\\overleftarrow{\\xi}_X \\triangleq \\overleftarrow{W}_X \\overleftarrow{m}_X$ and $\\overleftarrow{W}_{X} \\triangleq \\overleftarrow{V}_{X}^{-1}$ (and similarly for $Y$).
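The forward rule for the gain node can be checked in the same spirit. The sketch below assumes a 2x2 invertible gain A and an input message with illustrative mean and covariance. It evaluates the direct sum-product result (the input message evaluated at inv(A) times y) at a few test points and compares it with the Gaussian kernel with mean A mX and covariance A VX A'. Both sides are left unnormalized, so their ratio should be 1 up to rounding. The normalized densities would differ by the constant factor 1/|det A|.

```julia
# Illustrative 2x2 invertible gain and Gaussian input message on X
A  = [2.0 1.0; 0.5 3.0]
mX = [1.0, -1.0]
VX = [1.0 0.3; 0.3 2.0]

# Closed-form forward message on Y from the text
mY = A * mX
VY = A * VX * A'

# Unnormalized Gaussian kernel exp(-0.5 (v - m)' inv(V) (v - m))
kernel(v, m, V) = exp(-0.5 * sum((v - m) .* (inv(V) * (v - m))))

# Direct sum-product result mu_X(inv(A) y), also left unnormalized
mu_Y_direct(y) = kernel(inv(A) * y, mX, VX)

# The ratio should equal 1 at every test point if the closed form is right
test_points = [[0.0, 0.0], [1.0, 2.0], [-3.0, 0.5], [4.0, -2.0]]
ratios = [mu_Y_direct(y) / kernel(y, mY, VY) for y in test_points]
@show mY VY ratios
```

Working with unnormalized kernels reflects the fact that sum-product messages are only needed up to a scale factor.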
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "Let's calculate the Gaussian forward and backward messages for the addition node in ForneyLab. " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Forward message on Z: 𝒩(m=3.00, v=2.00)\n", "Backward message on X: 𝒩(m=1.00, v=2.00)\n" ] } ], "source": [ "# Forward message towards Z\n", "fg = FactorGraph()\n", "@RV x ~ GaussianMeanVariance(constant(1.0), constant(1.0), id=:x) \n", "@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)\n", "@RV z = x + y; z.id = :z\n", "\n", "eval(parse(sumProductAlgorithm(z, name=\"_z_fwd\")))\n", "msg_forward_Z = step_z_fwd!(Dict())[:z]\n", "print(\"Forward message on Z: $(msg_forward_Z)\")\n", "\n", "# Backward message towards X\n", "fg = FactorGraph()\n", "@RV x = Variable(id=:x)\n", "@RV y ~ GaussianMeanVariance(constant(2.0), constant(1.0), id=:y)\n", "@RV z = x + y\n", "GaussianMeanVariance(z, constant(3.0), constant(1.0), id=:z) \n", "\n", "eval(parse(sumProductAlgorithm(x, name=\"_x_bwd\")))\n", "msg_backward_X = step_x_bwd!(Dict())[:x]\n", "print(\"Backward message on X: $(msg_backward_X)\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "In the same way we can also investigate the forward and backward messages for the gain node " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Forward message on Y: 𝒩(m=4.00, v=16.00)\n" ] } ], "source": [ "# Forward message towards Y\n", "fg = FactorGraph()\n", "@RV x ~ GaussianMeanVariance(constant(1.0), constant(1.0), id=:x)\n", "@RV y = constant(4.0) * x; y.id = :y\n", "\n", "eval(parse(sumProductAlgorithm(y, 
name=\"_y_fwd\")))\n", "msg_forward_Y = step_y_fwd!(Dict())[:y]\n", "print(\"Forward message on Y: $(msg_forward_Y)\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Backward message on X: 𝒩(m=0.50, v=0.06)\n" ] } ], "source": [ "# Backward message towards X\n", "fg = FactorGraph()\n", "x = Variable(id=:x)\n", "@RV y = constant(4.0) * x\n", "GaussianMeanVariance(y, constant(2.0), constant(1.0))\n", "\n", "eval(parse(sumProductAlgorithm(x, name=\"_x_bwd2\")))\n", "msg_backward_X = step_x_bwd2!(Dict())[:x]\n", "print(\"Backward message on X: $(msg_backward_X)\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Example: Bayesian Linear Regression\n", "\n", "- Recall: the goal of regression is to estimate an unknown function from a set of (noisy) function values." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Assume we want to estimate some function $f: \\mathbb{R}^D \\rightarrow \\mathbb{R}$ from a data set $D = \\{(x_1,y_1), \\ldots, (x_N,y_N)\\}$, where $y_i = f(x_i) + \\epsilon_i$.
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- We will assume a linear model with white Gaussian noise adn a Gaussian prior on the coefficients $w$:\n", "$$\\begin{align*}\n", " y_i &= w^T x_i + \\epsilon_i \\\\\n", " \\epsilon_i &\\sim \\mathcal{N}(0, \\sigma^2) \\\\ \n", " w &\\sim \\mathcal{N}(0,\\Sigma)\n", "\\end{align*}$$\n", "or equivalently\n", "$$\\begin{align*}\n", "p(D,w) &= \\overbrace{p(w)}^{\\text{weight prior}} \\prod_{i=1}^N \\overbrace{p(y_i\\,|\\,x_i,w,\\epsilon_i)}^{\\text{regression model}} \\overbrace{p(\\epsilon_i)}^{\\text{noise model}} \\\\\n", " &= \\mathcal{N}(w\\,|\\,0,\\Sigma) \\prod_{i=1}^N \\delta(y_i - w^T x_i - \\epsilon_i) \\mathcal{N}(\\epsilon_i\\,|\\,0,\\sigma^2) \n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- We are interested in inferring the posterior $p(w|D)$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Here's the factor graph for this model\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "Let's build the factor graph in Julia (with the FFG toolbox ForneyLab)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "PyPlot.Figure(PyObject )" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "PyObject " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "using PyPlot\n", "include(\"scripts/innerproduct_node.jl\");\n", "\n", "# Parameters\n", "Σ = 1e5 * eye(3) # Covariance matrix of prior on w\n", "σ2 = 2.0 # Noise variance\n", "\n", "# Generate data set\n", "w = [1.0; 2.0; 0.25]\n", "N = 30\n", "z = 10.0*rand(N)\n", "x_train = [[1.0; z; z^2] for z in z] # 
Feature vector x = [1.0; z; z^2]\n", "f(x) = (w'*x)[1]\n", "y_train = map(f, x_train) + sqrt(σ2)*randn(N) # y[i] = w' * x[i] + ϵ\n", "scatter(z, y_train); xlabel(L\"z\"); ylabel(L\"f([1.0, z, z^2]) + \\epsilon\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "Perform sum-product message passing and plot the result (regression functions for 10 samples drawn from the posterior over $w$)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "PyPlot.Figure(PyObject )" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Posterior distribution of w: 𝒩(m=[0.91, 2.26, 0.21], v=[[0.83, -0.32, 0.03][-0.32, 0.16, -0.02][0.03, -0.02, 1.59e-03]])\n", "\n" ] } ], "source": [ "# Build factor graph\n", "fg = FactorGraph()\n", "@RV w ~ GaussianMeanVariance(constant(zeros(3)), constant(Σ, id=:Σ), id=:w) # p(w)\n", "for t=1:N\n", " x_t = Variable(id=:x_*t)\n", " d_t = Variable(id=:d_*t) # d=w'*x\n", " DotProduct(d_t, x_t, w) # Dot-product node: d_t = w'*x_t\n", " @RV y_t ~ GaussianMeanVariance(d_t, constant(σ2, id=:σ2_*t), id=:y_*t) # p(y|d)\n", " placeholder(x_t, :x, index=t, dims=(3,))\n", " placeholder(y_t, :y, index=t);\n", "end\n", "\n", "# Build and run message passing algorithm\n", "eval(parse(sumProductAlgorithm(w)))\n", "data = Dict(:x => x_train, :y => y_train)\n", "w_posterior_dist = step!(data)[:w]\n", "\n", "# Plot result\n", "println(\"Posterior distribution of w: $(w_posterior_dist)\")\n", "scatter(z, y_train); xlabel(L\"z\"); ylabel(L\"f([1.0, z, z^2]) + \\epsilon\");\n", "z_test = collect(0:0.2:12)\n", "x_test = [[1.0; z; z^2] for z in z_test]\n", "for sample=1:10\n", " w = ForneyLab.sample(w_posterior_dist)\n", " f_est(x) = (w'*x)[1]\n", " plot(z_test, map(f_est, x_test), \"k-\", alpha=0.3);\n", "end" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, 
"source": [ "### Homework Exercises \n", "\n", "- (Ex.1) Reflect on the fact that we now have methods for both marginalization and processing observations in FFGs. In principle, we are sufficiently equipped to do inference in probabilistic models through message passing. Draw the graph for $$p(x_1,x_2,x_3)=f_a(x_1)\\cdot f_b(x_1,x_2)\\cdot f_c(x_2,x_3)$$ and show which boxes need to be closed for computing $p(x_1|x_2)$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- (Ex.2) Consider a variable $X$ with measurements $D=\\{x_1,x_2\\}$. We assume the following model for $X$:\n", "$$\\begin{align*}\n", "p(D,\\theta) &= p(\\theta)\\cdot \\prod_{n=1}^2 p(x_n|\\theta) \\\\\n", "p(\\theta) &= \\mathcal{N}(\\theta \\mid 0,1) \\\\\n", "p(x_n \\mid\\theta) &= \\mathcal{N}(x_n \\mid \\theta,1)\n", "\\end{align*}$$\n", " - Draw the factor graph and infer the posterior $p(\\theta|D)$ through the Sum-Product Algorithm. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Inference in Linear Gaussian Models by Sum-Product Message Passing\n", "\n", "- The foregoing message update rules can be extended to all scenarios involving additions, fixed-gain multiplications and branching (equality nodes), thus creating a completely **automatable inference framework** for factorized linear Gaussian models." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The update rules for elementary and important node types can be collected in tables (see **Tables 1 through 6** in [Loeliger, 2007](./files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf))." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- If the update rules for all node types in a graph have been tabulated, then inference by message passing comes down to a set of table-lookup operations. 
This also works for large graphs (where 'manual' inference becomes intractable)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- If the graph contains no cycles, the Sum-Product Algorithm computes **exact** marginals for all hidden variables." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- If the graph contains cycles, it can in principle be unrolled into an infinite tree without terminals. In that case, the SP algorithm is not guaranteed to find exact marginals, or even to converge. In practice, applying the SP algorithm for just a few iterations (known as loopy belief propagation) often yields satisfactory approximate marginals. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "The cell below loads the style file\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "open(\"../../styles/aipstyle.html\") do f display(\"text/html\", readstring(f)) end\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Julia 0.6.2", "language": "julia", "name": "julia-0.6" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "0.6.2" } }, "nbformat": 4, "nbformat_minor": 1 }