{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Factor Graphs" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "\n", "### Preliminaries\n", "\n", "- Goal \n", " - Introduction to Forney-style factor graphs and message passing-based inference \n", "- Materials \n", " - Mandatory\n", " - These lecture notes \n", " - Loeliger (2007), [The factor graph approach to model based signal processing](https://github.com/bertdv/BMLIP/blob/master/lessons/notebooks/files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf), pp. 1295-1302 (until section V)\n", " - Optional\n", " - Frederico Wadehn (2015), [Probabilistic graphical models: Factor graphs and more](https://www.youtube.com/watch?v=Fv2YbVg9Frc&t=31) video lecture (**recommended**)\n", " - References\n", " - Forney (2001), [Codes on graphs: normal realizations](https://github.com/bertdv/BMLIP/blob/master/lessons/notebooks/files/Forney-2001-Codes-on-graphs-normal-realizations.pdf)\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Why Factor Graphs?\n", "\n", "- A probabilistic inference task gets its computational load mainly through the need for marginalization (i.e., computing integrals). E.g., for a model $p(x_1,x_2,x_3,x_4,x_5)$, the inference task $p(x_2|x_3)$ is given by \n", "$$\n", "p(x_2|x_3) = \\frac{p(x_2,x_3)}{p(x_3)} = \\frac{\\int \\cdots \\int p(x_1,x_2,x_3,x_4,x_5) \\, \\mathrm{d}x_1 \\mathrm{d}x_4 \\mathrm{d}x_5}{\\int \\cdots \\int p(x_1,x_2,x_3,x_4,x_5) \\, \\mathrm{d}x_1 \\mathrm{d}x_2 \\mathrm{d}x_4 \\mathrm{d}x_5}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Since these computations (integrals or sums) suffer from the \"curse of dimensionality\", we often need to solve a simpler problem in order to get an answer. 
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Factor graphs provide a computationally efficient approach to solving inference problems **if the probabilistic model can be factorized**. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Factorization helps. For instance, if $p(x_1,x_2,x_3,x_4,x_5) = p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4)$, then\n", "\n", "$$\n", "p(x_2|x_3) = \\frac{\\int \\cdots \\int p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4) \\, \\mathrm{d}x_1 \\mathrm{d}x_4 \\mathrm{d}x_5}{\\int \\cdots \\int p(x_1)p(x_2,x_3)p(x_4)p(x_5|x_4) \\, \\mathrm{d}x_1 \\mathrm{d}x_2 \\mathrm{d}x_4 \\mathrm{d}x_5} \n", " = \\frac{p(x_2,x_3)}{\\int p(x_2,x_3) \\mathrm{d}x_2}\n", "$$\n", "\n", "which is computationally much cheaper than the general case above." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- In this lesson, we discuss how computationally efficient inference in *factorized* probability distributions can be automated by message passing-based inference in factor graphs." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Factor Graph Construction Rules\n", "\n", "- Consider a function \n", "$$\n", "f(x_1,x_2,x_3,x_4,x_5) = f_a(x_1,x_2,x_3) \\cdot f_b(x_3,x_4,x_5) \\cdot f_c(x_4)\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The factorization of this function can be graphically represented by a **Forney-style Factor Graph** (FFG):\n", "\n", "

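As a rough illustration of how such a factorization maps onto a graph, the sketch below (plain Julia, no packages) derives the edges from the factor scopes: each variable becomes one edge, attached to every factor that mentions it. The names `scopes` and `edges` and the symbols `:fa`, `:fb`, `:fc` are illustrative choices, not part of any library.

```julia
# Scope of each factor in f = f_a(x1,x2,x3) * f_b(x3,x4,x5) * f_c(x4)
scopes = Dict(
    :fa => [:x1, :x2, :x3],
    :fb => [:x3, :x4, :x5],
    :fc => [:x4],
)

# One edge per variable: collect the nodes (factors) attached to it
edges = Dict{Symbol,Vector{Symbol}}()
for (factor, vars) in scopes
    for v in vars
        push!(get!(edges, v, Symbol[]), factor)
    end
end

# A variable attached to one node sits on a half-edge, otherwise on a full edge
for v in sort(collect(keys(edges)))
    nodes = sort(edges[v])
    kind = length(nodes) == 1 ? :half_edge : :edge
    println((v, nodes, kind))
end
```

In this example `x3` and `x4` each connect two nodes, while `x1`, `x2` and `x5` are attached to a single node.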
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- An FFG is an **undirected** graph subject to the following construction rules ([Forney, 2001](https://github.com/bertdv/BMLIP/blob/master/lessons/notebooks/files/Forney-2001-Codes-on-graphs-normal-realizations.pdf))\n", "\n", " 1. A **node** for every factor;\n", " 1. An **edge** (or **half-edge**) for every variable;\n", " 1. Node $f_\\bullet$ is connected to edge $x$ **iff** variable $x$ appears in factor $f_\\bullet$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- A **configuration** is an assigment of values to all variables.\n", "\n", "- A configuration $\\omega=(x_1,x_2,x_3,x_4,x_5)$ is said to be **valid** iff $f(\\omega) \\neq 0$\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Equality Nodes for Branching Points\n", "\n", "\n", "- Note that a variable can appear in maximally two factors in an FFG (since an edge has only two end points)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Consider the factorization (where $x_2$ appears in three factors) \n", "\n", "$$\n", " f(x_1,x_2,x_3,x_4) = f_a(x_1,x_2)\\cdot f_b(x_2,x_3) \\cdot f_c(x_2,x_4)\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For the factor graph representation, we will instead consider the function $g$, defined as\n", "$$\\begin{align*}\n", " g(x_1,x_2&,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \n", " = f_a(x_1,x_2)\\cdot f_b(x_2^\\prime,x_3) \\cdot f_c(x_2^{\\prime\\prime},x_4) \\cdot f_=(x_2,x_2^\\prime,x_2^{\\prime\\prime})\n", "\\end{align*}$$\n", " where \n", "$$\n", "f_=(x_2,x_2^\\prime,x_2^{\\prime\\prime}) \\triangleq \\delta(x_2-x_2^\\prime)\\, \\delta(x_2-x_2^{\\prime\\prime})\n", "$$\n", "\n", "

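The role of the equality factor can be checked numerically. The sketch below assumes discrete variables with $K$ values each, so that the delta functions become Kronecker deltas and integrals become sums. The tables `fa`, `fb`, `fc` are random stand-ins for the factors, and `feq` is a hypothetical helper name.

```julia
K = 3                                      # number of values per variable (assumed)
fa = rand(K, K)                            # f_a(x1, x2)
fb = rand(K, K)                            # f_b(x2, x3)
fc = rand(K, K)                            # f_c(x2, x4)
feq(x, xp, xpp) = (x == xp) * (x == xpp)   # Kronecker-delta version of f_=

# Original function, with x2 shared by three factors
f(x1, x2, x3, x4) = fa[x1, x2] * fb[x2, x3] * fc[x2, x4]

# Extended function g, in which every variable appears in at most two factors
g(x1, x2, xp, xpp, x3, x4) = fa[x1, x2] * fb[xp, x3] * fc[xpp, x4] * feq(x2, xp, xpp)

# Summing g over the two copies of x2 should give back f for every configuration
idx = Iterators.product(1:K, 1:K, 1:K, 1:K)
ok = all(idx) do (x1, x2, x3, x4)
    sum(g(x1, x2, xp, xpp, x3, x4) for xp in 1:K, xpp in 1:K) ≈ f(x1, x2, x3, x4)
end
@show ok
```

Only the terms in which both copies equal `x2` survive the sum, which is why the equality factor leaves `f` unchanged.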
\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Note that through introduction of auxiliary variables $X_2^{\\prime}$ and $X_2^{\\prime\\prime}$ and a factor $f_=(x_2,x_2^\\prime,x_2^{\\prime\\prime})$, each variable in $g$ appears in maximally two factors." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The constraint $f_=(x,x^\\prime,x^{\\prime\\prime})$ enforces that $X=X^\\prime=X^{\\prime\\prime}$ **for every valid configuration**." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Since $f$ is a marginal of $g$, i.e., \n", "$$\n", "f(x_1,x_2,x_3,x_4) = \\iint g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4)\\, \\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime}\n", "$$\n", "it follows that any inference problem on $f$ can be executed by a corresponding inference problem on $g$, e.g.,\n", "$$\\begin{align*}\n", "f(x_1 \\mid x_2) &\\triangleq \\frac{\\iint f(x_1,x_2,x_3,x_4) \\,\\mathrm{d}x_3 \\mathrm{d}x_4 }{ \\int\\cdots\\int f(x_1,x_2,x_3,x_4) \\,\\mathrm{d}x_1 \\mathrm{d}x_3 \\mathrm{d}x_4} \\\\\n", " &= \\frac{\\int\\cdots\\int g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \\,\\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime} \\mathrm{d}x_3 \\mathrm{d}x_4 }{ \\int\\cdots\\int g(x_1,x_2,x_2^\\prime,x_2^{\\prime\\prime},x_3,x_4) \\,\\mathrm{d}x_1 \\mathrm{d}x_2^\\prime \\mathrm{d}x_2^{\\prime\\prime} \\mathrm{d}x_3 \\mathrm{d}x_4} \\\\\n", " &= g(x_1 \\mid x_2)\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- $\\Rightarrow$ **Any factorization of a global function $f$ can be represented by a Forney-style Factor Graph**." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Probabilistic Models as Factor Graphs\n", "\n", "- FFGs can be used to express conditional independence (factorization) in probabilistic models. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For example, the (previously shown) graph for \n", "$f_a(x_1,x_2,x_3) \\cdot f_b(x_3,x_4,x_5) \\cdot f_c(x_4)$ \n", "could represent the probabilistic model\n", "$$\n", "p(x_1,x_2,x_3,x_4,x_5) = p(x_1,x_2|x_3) \\cdot p(x_3,x_5|x_4) \\cdot p(x_4)\n", "$$\n", "where we identify \n", "$$\\begin{align*}\n", "f_a(x_1,x_2,x_3) &= p(x_1,x_2|x_3) \\\\\n", "f_b(x_3,x_4,x_5) &= p(x_3,x_5|x_4) \\\\\n", "f_c(x_4) &= p(x_4)\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- This is the graph\n", "\n", "

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Inference by Closing Boxes\n", "\n", "- Factorizations provide opportunities to cut on the amount of needed computations when doing inference. In what follows, we will use FFGs to process these opportunities in an automatic way by message passing between the nodes of the graph. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Assume we wish to compute the marginal\n", "$$\n", "\\bar{f}(x_3) \\triangleq \\sum\\limits_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7) \n", "$$\n", "for a model $f$ with given factorization \n", "$$\n", "f(x_1,x_2,\\ldots,x_7) = f_a(x_1) f_b(x_2) f_c(x_1,x_2,x_3) f_d(x_4) f_e(x_3,x_4,x_5) f_f(x_5,x_6,x_7) f_g(x_7)\n", "$$ " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Note that, if each variable $x_i$ can take on $10$ values, then the computing the marginal $\\bar{f}(x_3)$ takes about $10^6$ (1 million) additions. 
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Due to the factorization and the [Generalized Distributive Law](https://en.wikipedia.org/wiki/Generalized_distributive_law), we can decompose this sum-of-products to the following product-of-sums:\n", "$$\\begin{align*}\\bar{f}&(x_3) = \\\\\n", " &\\underbrace{ \\Bigg( \\sum_{x_1,x_2} \\underbrace{f_a(x_1)}_{\\overrightarrow{\\mu}_{X_1}(x_1)}\\, \\underbrace{f_b(x_2)}_{\\overrightarrow{\\mu}_{X_2}(x_2)}\\,f_c(x_1,x_2,x_3)\\Bigg) }_{\\overrightarrow{\\mu}_{X_3}(x_3)} \n", " \\underbrace{ \\cdot\\Bigg( \\sum_{x_4,x_5} \\underbrace{f_d(x_4)}_{\\overrightarrow{\\mu}_{X_4}(x_4)}\\,f_e(x_3,x_4,x_5) \\cdot \\underbrace{ \\big( \\sum_{x_6,x_7} f_f(x_5,x_6,x_7)\\,\\underbrace{f_g(x_7)}_{\\overleftarrow{\\mu}_{X_7}(x_7)}\\big) }_{\\overleftarrow{\\mu}_{X_5}(x_5)} \\Bigg) }_{\\overleftarrow{\\mu}_{X_3}(x_3)}\n", "\\end{align*}$$\n", "which, in case $x_i$ has $10$ values, requires a few hundred additions and is therefore computationally (much!) lighter than executing the full sum $\\sum_{x_1,\\ldots,x_7}f(x_1,x_2,\\ldots,x_7)$\n", "\n", "

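The decomposition above can be checked numerically. The following is a minimal sketch in plain Julia, assuming random factor tables with $K$ values per variable. The function name `closing_boxes` and the variable names `fwd3`, `bwd5`, `bwd3` are illustrative only. It computes the marginal once by closing boxes and once by brute-force summation.

```julia
function closing_boxes(K)
    # Random stand-ins for the factors: fc[x1,x2,x3], fe[x3,x4,x5], ff[x5,x6,x7]
    fa, fb, fd, fg = rand(K), rand(K), rand(K), rand(K)
    fc, fe, ff = rand(K, K, K), rand(K, K, K), rand(K, K, K)

    # Forward message on X3: close the box around f_a, f_b, f_c
    fwd3 = [sum(fa[x1] * fb[x2] * fc[x1, x2, x3] for x1 in 1:K, x2 in 1:K) for x3 in 1:K]
    # Backward message on X5: close the box around f_f, f_g
    bwd5 = [sum(ff[x5, x6, x7] * fg[x7] for x6 in 1:K, x7 in 1:K) for x5 in 1:K]
    # Backward message on X3: close the box around f_d, f_e and the previous box
    bwd3 = [sum(fd[x4] * fe[x3, x4, x5] * bwd5[x5] for x4 in 1:K, x5 in 1:K) for x3 in 1:K]
    marginal = fwd3 .* bwd3

    # Brute force: sum the full product over the six other variables, for each x3
    brute = [sum(fa[x1] * fb[x2] * fc[x1, x2, x3] * fd[x4] * fe[x3, x4, x5] *
                 ff[x5, x6, x7] * fg[x7]
                 for x1 in 1:K, x2 in 1:K, x4 in 1:K, x5 in 1:K, x6 in 1:K, x7 in 1:K)
             for x3 in 1:K]
    return marginal, brute
end

marginal, brute = closing_boxes(10)
@show marginal ≈ brute
```

Each message needs only about $K^2$ terms per output value, whereas the brute-force sum needs $K^6$ terms per value of $x_3$.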
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Note that the auxiliary factor $\\overrightarrow{\\mu}_{X_3}(x_3)$ is obtained by multiplying all enclosed factors ($f_a$, $f_b, f_c$) by the red dashed box, followed by marginalization (summing) over all enclosed variables ($x_1$, $x_2$).\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- This is the **Closing the Box**-rule, which is a general recipe for marginalization of latent variables (inside the box) and leads to a new factor that has the variables (edges) that cross the box as arguments. For instance, the argument of the remaining factor $\\overrightarrow{\\mu}_{X_3}(x_3)$ is the variable on the edge that crosses the red box ($x_3$).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Hence, $\\overrightarrow{\\mu}_{X_3}(x_3)$ can be interpreted as a **message from the red box toward variable** $x_3$." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- We drew _directed edges_ in the FFG in order to distinguish forward messages $\\overrightarrow{\\mu}_\\bullet(\\cdot)$ (in the same direction as the arrow of the edge) from backward messages $\\overleftarrow{\\mu}_\\bullet(\\cdot)$ (in opposite direction). This is just a notational convenience since an FFG is computationally an undirected graph. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Sum-Product Algorithm\n", "\n", "- Closing-the-box can also be interpreted as a **message update rule** for an outgoing message from a node. 
For a node $f(y,x_1,\\ldots,x_n)$ with incoming messages $\\overrightarrow{\\mu}_{X_1}(x_1), \\overrightarrow{\\mu}_{X_2}(x_2), \\ldots,\\overrightarrow{\\mu}_{X_n}(x_n)$, the outgoing message is given by ([Loeliger (2007), p. 1299](https://github.com/bertdv/BMLIP/blob/master/lessons/notebooks/files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf)): \n", "\n", "$$ \\boxed{\n", "\\underbrace{\\overrightarrow{\\mu}_{Y}(y)}_{\\substack{ \\text{outgoing}\\\\ \\text{message}}} = \\sum_{x_1,\\ldots,x_n} \\underbrace{\\overrightarrow{\\mu}_{X_1}(x_1)\\cdots \\overrightarrow{\\mu}_{X_n}(x_n)}_{\\substack{\\text{incoming} \\\\ \\text{messages}}} \\cdot \\underbrace{f(y,x_1,\\ldots,x_n)}_{\\substack{\\text{node}\\\\ \\text{function}}} }\n", "$$\n", "\n", "

\n", "\n", "- This is called the **Sum-Product Message** (SPM) update rule. (Look at the formula to see why it's called the SPM update rule.)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Note that all SPM update rules can be computed from information that is **locally available** at each node." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- If the factor graph for a function $f$ has **no cycles** (i.e., the graph is a tree), then the marginal $\\bar{f}(x_3) = \\sum_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7)$ is given by multiplying the forward and backward messages on that edge:\n", "\n", "$$ \\boxed{\n", "\\bar{f}(x_3) = \\overrightarrow{\\mu}_{X_3}(x_3)\\cdot \\overleftarrow{\\mu}_{X_3}(x_3)}\n", "$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- It follows that the marginal $\\bar{f}(x_3) = \\sum_{x_1,x_2,x_4,x_5,x_6,x_7}f(x_1,x_2,\\ldots,x_7)$ can be efficiently computed through sum-product messages. Executing inference through SP message passing is called the **Sum-Product Algorithm** (or alternatively, the **belief propagation** algorithm)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Just as a final note, inference by sum-product message passing is much like replacing the sum-of-products\n", "$$\n", "ac + ad + bc + bd$$\n", "by the following product-of-sums:\n", "$$\n", "(a + b)(c + d) \\,.$$\n", "- Which of these two computations is cheaper to execute?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sum-Product Messages for the Equality Node\n", "\n", "- As an example, let's evaluate the SP messages for the **equality node** $f_=(x,y,z) = \\delta(z-x)\\delta(z-y)$: \n", "\n", "

\n", "\n", "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Z}(z) &= \\iint \\overrightarrow{\\mu}_{X}(x) \\overrightarrow{\\mu}_{Y}(y) \\,\\delta(z-x)\\delta(z-y) \\,\\mathrm{d}x \\mathrm{d}y \\\\\n", " &= \\overrightarrow{\\mu}_{X}(z) \\int \\overrightarrow{\\mu}_{Y}(y) \\,\\delta(z-y) \\,\\mathrm{d}y \\\\\n", " &= \\overrightarrow{\\mu}_{X}(z) \\overrightarrow{\\mu}_{Y}(z) \n", "\\end{align*}$$\n", "\n", "- By symmetry, this also implies (for the same equality node) that\n", "$$\\begin{align*}\n", "\\overleftarrow{\\mu}_{X}(x) &= \\overrightarrow{\\mu}_{Y}(x) \\overleftarrow{\\mu}_{Z}(x) \\quad \\text{and} \\\\\n", "\\overleftarrow{\\mu}_{Y}(y) &= \\overrightarrow{\\mu}_{X}(y) \\overleftarrow{\\mu}_{Z}(y)\\,.\n", "\\end{align*}$$\n", "\n", "- Let us now consider the case of Gaussian messages $\\overrightarrow{\\mu}_{X}(x) = \\mathcal{N}(x|\\overrightarrow{m}_X,\\overrightarrow{V}_X)$, $\\overrightarrow{\\mu}_{Y}(y) = \\mathcal{N}(y| \\overrightarrow{m}_Y,\\overrightarrow{V}_Y)$ and $\\overrightarrow{\\mu}_{Z}(z) = \\mathcal{N}(z|\\overrightarrow{m}_Z,\\overrightarrow{V}_Z)$. Let's also define the precision matrices $\\overrightarrow{W}_X \\triangleq \\overrightarrow{V}_X^{-1}$ and similarly for $Y$ and $Z$. Then applying the SP update rule leads to multiplication of two Gaussian distributions (see [Roweis notes](https://github.com/bertdv/BMLIP/blob/master/lessons/notebooks/files/Roweis-1999-gaussian-identities.pdf)), resulting in \n", "\n", "$$\\begin{align*}\n", "\\overrightarrow{W}_Z &= \\overrightarrow{W}_X + \\overrightarrow{W}_Y \\qquad &\\text{(precisions add)}\\\\ \n", "\\overrightarrow{W}_Z \\overrightarrow{m}_Z &= \\overrightarrow{W}_X \\overrightarrow{m}_X + \\overrightarrow{W}_Y \\overrightarrow{m}_Y \\qquad &\\text{(natural means add)}\n", "\\end{align*}$$\n", "\n", "- It follows that **message passing through an equality node is similar to applying Bayes' rule**, i.e., fusion of two information sources. Does this make sense?" 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Message Passing Schedules\n", "\n", "- In a non-cyclic (i.e., tree) graph, start with messages from the terminals and keep passing messages through the internal nodes towards the \"target\" variable ($x_3$ in the problem above) until you have both the forward and backward message for the target variable. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- In a tree graph, if you continue to pass messages throughout the graph, the Sum-Product Algorithm computes **exact** marginals for all hidden variables." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- If the graph contains cycles, we have in principle an infinite tree by \"unrolling\" the graph. In this case, the SP Algorithm is not guaranteed to find exact marginals. In practice, if we apply the SP algorithm for just a few iterations (\"unrolls\"), then we often find satisfactory approximate marginals. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Terminal Nodes and Processing Observations \n", "\n", "- We can use terminal nodes to represent observations, e.g., add a factor $f(y)=\\delta(y-3)$ to terminate the half-edge for variable $Y$ if $y=3$ is observed.\n", "\n", "

\n", "\n", "- Terminal nodes that carry observations are denoted by small black boxes." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "\n", " \n", "- The message out of a **terminal node** (attached to only 1 edge) is the factor itself. For instance, closing a box around terminal node $f_a(x_1)$ would lead to $$\\overrightarrow{\\mu}_{X_1}(x_1) \\triangleq \\sum_{ \\stackrel{ \\textrm{enclosed} }{ \\textrm{variables} } } \\;\\prod_{\\stackrel{ \\textrm{enclosed} }{ \\textrm{factors} }} f_a(x_1) = f_a(x_1)\\,$$\n", "since there are no enclosed variables. \n", "\n", "

\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- The message from a half-edge is $1$ (one). You can verify this by imagining that a half-edge $x$ can be terminated by a node function $f(x)=1$ without affecting any inference result." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Automating Bayesian Inference by Message Passing\n", "\n", "- The foregoing message update rules can be worked out in closed form and put into tables (e.g., see Tables 1 through 6 in [Loeliger (2007)](./files/Loeliger-2007-The-factor-graph-approach-to-model-based-signal-processing.pdf)) for many standard factors, such as essential probability distributions, and for operations such as additions, fixed-gain multiplications and branching (equality nodes).\n", "\n", "- In the optional slides below, we have worked out a few more update rules for the [addition node](#sp-for-addition-node) and the [multiplication node](#sp-for-multiplication-node).\n", "\n", "- If the update rules for all node types in a graph have been tabulated, then inference by message passing comes down to executing a set of table-lookup operations, thus creating a completely **automatable Bayesian inference framework**. \n", "\n", "- In our research lab [BIASlab](http://biaslab.org) (FLUX 7.060), we are developing [RxInfer](http://rxinfer.ml), which is a (Julia) toolbox for automating Bayesian inference by message passing in a factor graph.\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Example: Bayesian Linear Regression by Message Passing\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- Assume we want to estimate some function $f: \\mathbb{R}^D \\rightarrow \\mathbb{R}$ from a given data set $D = \\{(x_1,y_1), \\ldots, (x_N,y_N)\\}$, with model assumption $y_i = f(x_i) + \\epsilon_i$." 
] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### model specification \n", "\n", "- We will assume a linear model with white Gaussian noise and a Gaussian prior on the coefficients $w$:\n", "$$\\begin{align*}\n", " y_i &= w^T x_i + \\epsilon_i \\\\\n", " \\epsilon_i &\\sim \\mathcal{N}(0, \\sigma^2) \\\\ \n", " w &\\sim \\mathcal{N}(0,\\Sigma)\n", "\\end{align*}$$\n", "or equivalently\n", "$$\\begin{align*}\n", "p(w,\\epsilon,D) &= \\overbrace{p(w)}^{\\text{weight prior}} \\prod_{i=1}^N \\overbrace{p(y_i\\,|\\,x_i,w,\\epsilon_i)}^{\\text{regression model}} \\overbrace{p(\\epsilon_i)}^{\\text{noise model}} \\\\\n", " &= \\mathcal{N}(w\\,|\\,0,\\Sigma) \\prod_{i=1}^N \\delta(y_i - w^T x_i - \\epsilon_i) \\mathcal{N}(\\epsilon_i\\,|\\,0,\\sigma^2) \n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "#### Inference (parameter estimation)\n", "\n", "- We are interested in inferring the posterior $p(w|D)$. We will execute inference by message passing on the FFG for the model." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- The left figure shows the factor graph for this model. \n", "- The right figure shows the message passing scheme. \n", " \n", "
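For this linear-Gaussian model, the messages that arrive at $w$ can also be combined by hand. The sketch below is a plain-Julia illustration that does not use RxInfer and generates its own synthetic data with a hypothetical `w_true`. It assumes that every observation contributes a Gaussian message toward $w$ with precision `x*x'/σ2` and weighted mean `x*y/σ2`, and that the equality nodes add these to the prior.

```julia
using LinearAlgebra, Random

Random.seed!(1)
w_true = [1.0, 2.0, 0.25]                 # hypothetical ground-truth weights
σ2 = 2.0                                  # noise variance
Σ = 1e5 * Matrix(1.0I, 3, 3)              # prior covariance of w
z = 10.0 * rand(30)
xs = [[1.0, zi, zi^2] for zi in z]        # feature vectors
ys = [dot(w_true, x) + sqrt(σ2) * randn() for x in xs]

# Gaussian messages in (weighted mean xi, precision W) form.
# The prior has precision inv(Σ) and weighted mean 0;
# each observation adds x*x'/σ2 to W and x*y/σ2 to xi.
W = inv(Σ) + sum(x * x' for x in xs) / σ2
xi = sum(x * y for (x, y) in zip(xs, ys)) / σ2

m_post = inv(W) * xi                      # posterior mean of w
@show m_post
@show W[1, 1]
```

With 30 observations and a constant first feature, `W[1, 1]` comes out at about 15.00001 (30 divided by the noise variance, plus the prior precision). This is the same value as the first entry of `Λ` in the posterior printed by the RxInfer run in this notebook.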

\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "#### CODE EXAMPLE\n", "\n", "Let's solve this problem by message passing-based inference with Julia's FFG toolbox [RxInfer](https://biaslab.github.io/rxinfer-website/)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "using Pkg; Pkg.activate(\"../.\"); Pkg.instantiate();\n", "using IJulia; try IJulia.clear_output(); catch _ end" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "image/png": "", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "using Plots, LinearAlgebra, LaTeXStrings\n", "\n", "# Parameters\n", "Σ = 1e5 * Diagonal(I,3) # Covariance matrix of prior on w\n", "σ2 = 2.0 # Noise variance\n", "\n", "# Generate data set\n", "w = [1.0; 2.0; 0.25]\n", "N = 30\n", "z = 10.0*rand(N)\n", "x_train = [[1.0; z; z^2] for z in z] # Feature vector x = [1.0; z; z^2]\n", "f(x) = (w'*x)[1]\n", "y_train = map(f, x_train) + sqrt(σ2)*randn(N) # y[i] = 
w' * x[i] + ϵ\n", "scatter(z, y_train, label=\"data\", xlabel=L\"z\", ylabel=L\"f([1.0, z, z^2]) + \\epsilon\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "Now build the factor graph in RxInfer, perform sum-product message passing and plot results (mean of posterior)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "image/png": "", "image/svg+xml": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ], "text/html": [ "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", " \n", " \n", " \n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Posterior distribution of w: MvNormalWeightedMeanPrecision(\n", "xi: [307.7818564544745, 2208.395106150794, 17693.181354856355]\n", "Λ: [15.00001 77.73321536791191 538.9138576517117; 77.73321536791191 538.9138676517117 4221.088404002255; 538.9138576517117 4221.088404002255 35248.048308516314]\n", ")\n", "\n" ] } ], "source": [ 
"using RxInfer, Random\n", "# Build model\n", "@model function linear_regression(y,x, N, Σ, σ2)\n", "\n", " w ~ MvNormalMeanCovariance(zeros(3),Σ)\n", " \n", " for i in 1:N\n", " y[i] ~ NormalMeanVariance(dot(w , x[i]), σ2)\n", " end\n", "end\n", "# Run message passing algorithm \n", "results = infer(\n", " model = linear_regression(N=length(x_train), Σ=Σ, σ2=σ2),\n", " data = (y = y_train, x = x_train),\n", " returnvars = (w = KeepLast(),),\n", " iterations = 20,\n", ");\n", "# Plot result\n", "w = results.posteriors[:w]\n", "println(\"Posterior distribution of w: $(w)\")\n", "plt = scatter(z, y_train, label=\"data\", xlabel=L\"z\", ylabel=L\"f([1.0, z, z^2]) + \\epsilon\")\n", "z_test = collect(0:0.2:12)\n", "x_test = [[1.0; z; z^2] for z in z_test]\n", "for i=1:10\n", " w_sample = rand(results.posteriors[:w])\n", " f_est(x) = (w_sample'*x)[1]\n", " plt = plot!(z_test, map(f_est, x_test), alpha=0.3, label=\"\");\n", "end\n", "display(plt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Final thoughts: Modularity and Abstraction \n", "\n", "- The great Michael Jordan (no, not [this one](https://youtu.be/cuLprHh_BRg), but [this one](https://people.eecs.berkeley.edu/~jordan/)), wrote: \n", "\n", " > \"I basically know of two principles for treating complicated systems in simple ways: the first is the principle of **modularity** and the second is the principle of **abstraction**. I am an apologist for computational probability in machine learning because I believe that probability theory implements these two principles in deep and intriguing ways — namely through factorization and through averaging. 
Exploiting these two mechanisms as fully as possible seems to me to be the way forward in machine learning.\" (Michael Jordan, 1997, quoted in [Fre98](https://mitpress.mit.edu/9780262062022/)).\n", "\n", "- Factor graphs realize these ideas nicely, both visually and computationally.\n", "\n", "- Visually, the modularity of conditional independencies in the model is displayed by the graph structure. Each node hides internal complexity and by closing-the-box, we can hierarchically move on to higher levels of abstraction. \n", "\n", "- Computationally, message passing-based inference uses the Distributive Law to avoid unnecessary computations. \n", "\n", "- What is the relevance of this lesson? RxInfer is not yet a finished project. Still, my prediction is that in 5-10 years, this lesson on Factor Graphs will be the final lecture of part-A of this class, aimed at engineers who need to develop machine learning applications. In principle you have all the tools now to work out the 4-step machine learning recipe (1. model specification, 2. parameter learning, 3. model evaluation, 4. application) that was proposed in the [Bayesian machine learning lesson](https://nbviewer.org/github/bertdv/BMLIP/blob/master/lessons/notebooks/Bayesian-Machine-Learning.ipynb#Bayesian-design). You can propose any model and execute the (learning, evaluation, and application) stages by executing the corresponding inference task automatically in RxInfer. \n", "\n", "- Part-B of this class would be about advanced methods for improving automated inference by RxInfer or a similar probabilistic programming package. The Bayesian approach fully supports separating model specification from the inference task. \n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "##
OPTIONAL SLIDES
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Sum-Product Messages for Multiplication Nodes\n", "- Next, let us consider a **multiplication** by a fixed (invertible matrix) gain $f_A(x,y) = \\delta(y-Ax)$\n", "\n", "

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Y}(y) &= \\int \\overrightarrow{\\mu}_{X}(x) \\,\\delta(y-Ax) \\,\\mathrm{d}x \\\\\n", "&= \\int \\overrightarrow{\\mu}_{X}(x) \\,|A|^{-1}\\delta(x-A^{-1}y) \\,\\mathrm{d}x \\\\\n", "&= |A|^{-1}\\overrightarrow{\\mu}_{X}(A^{-1}y) \\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- For a Gaussian message input message $\\overrightarrow{\\mu}_{X}(x) = \\mathcal{N}(x|\\overrightarrow{m}_{X},\\overrightarrow{V}_{X})$, the output message is also Gaussian with \n", "$$\\begin{align*}\n", "\\overrightarrow{m}_{Y} = A\\overrightarrow{m}_{X} \\,,\\,\\text{and}\\,\\,\n", "\\overrightarrow{V}_{Y} = A\\overrightarrow{V}_{X}A^T\n", "\\end{align*}$$\n", "since \n", "$$\\begin{align*}\n", "\\overrightarrow{\\mu}_{Y}(y) &= |A|^{-1}\\overrightarrow{\\mu}_{X}(A^{-1}y) \\\\\n", " &\\propto \\exp \\left( -\\frac{1}{2} \\left( A^{-1}y - \\overrightarrow{m}_{X}\\right)^T \\overrightarrow{V}_{X}^{-1} \\left( A^{-1}y - \\overrightarrow{m}_{X}\\right)\\right) \\\\\n", " &= \\exp \\big( -\\frac{1}{2} \\left( y - A\\overrightarrow{m}_{X}\\right)^T \\underbrace{A^{-T}\\overrightarrow{V}_{X}^{-1} A^{-1}}_{(A \\overrightarrow{V}_{X} A^T)^{-1}} \\left( y - A\\overrightarrow{m}_{X}\\right)\\big) \\\\\n", " &\\propto \\mathcal{N}(y| A\\overrightarrow{m}_{X},A\\overrightarrow{V}_{X}A^T) \\,.\n", "\\end{align*}$$" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "- Exercise: Proof that, for the same factor $\\delta(y-Ax)$ and Gaussian messages, the (backward) sum-product message $\\overleftarrow{\\mu}_{X}$ is given by \n", "$$\\begin{align*}\n", "\\overleftarrow{\\xi}_{X} &= A^T\\overleftarrow{\\xi}_{Y} \\\\\n", "\\overleftarrow{W}_{X} &= A^T\\overleftarrow{W}_{Y}A\n", "\\end{align*}$$\n", "where 
$\\overleftarrow{\\xi}_X \\triangleq \\overleftarrow{W}_X \\overleftarrow{m}_X$ and $\\overleftarrow{W}_{X} \\triangleq \\overleftarrow{V}_{X}^{-1}$ (and similarly for $Y$)." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "\n", "### Code example: Gaussian forward and backward messages for the Addition node\n", "\n", "Let's calculate the Gaussian forward and backward messages for the addition node in RxInfer. \n", "
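Before calling the library rules, it can help to write the scalar Gaussian update rules for the addition node by hand. The sketch below assumes independent Gaussian incoming messages and encodes each message as a (mean, variance) named tuple. The helper names `add_forward` and `add_backward` are hypothetical and not part of RxInfer.

```julia
# Message toward Z for Z = X + Y: means add, variances add
add_forward(x, y) = (m = x.m + y.m, v = x.v + y.v)

# Message toward X for X = Z - Y: means subtract, variances still add
add_backward(z, y) = (m = z.m - y.m, v = z.v + y.v)

fwd = add_forward((m = 1.0, v = 1.0), (m = 2.0, v = 1.0))
bwd = add_backward((m = 3.0, v = 1.0), (m = 2.0, v = 1.0))
@show fwd
@show bwd
```

For these inputs the sketch gives mean 3.0 with variance 2.0 for the forward message and mean 1.0 with variance 2.0 for the backward message, the same values as the `NormalMeanVariance` results printed by the RxInfer rule calls in this notebook.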

" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Forward message on Z:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "WARNING: both ExponentialFamily and ReactiveMP export \"MvNormalMeanScalePrecision\"; uses of it in module RxInfer must be qualified\n" ] }, { "data": { "text/plain": [ "NormalMeanVariance{Float64}(μ=3.0, v=2.0)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "println(\"Forward message on Z:\")\n", "@call_rule typeof(+)(:out, Marginalisation) (m_in1 = NormalMeanVariance(1.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Backward message on X:\n" ] }, { "data": { "text/plain": [ "NormalMeanVariance{Float64}(μ=1.0, v=2.0)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "println(\"Backward message on X:\")\n", "@call_rule typeof(+)(:in1, Marginalisation) (m_out = NormalMeanVariance(3.0, 1.0), m_in2 = NormalMeanVariance(2.0, 1.0))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Code Example: forward and backward messages for the Matrix Multiplication node\n", "\n", "In the same way we can also investigate the forward and backward messages for the matrix multiplication (\"gain\") node \n", "

" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Forward message on Y:\n" ] }, { "data": { "text/plain": [ "NormalMeanVariance{Float64}(μ=4.0, v=16.0)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "println(\"Forward message on Y:\")\n", "@call_rule typeof(*)(:out, Marginalisation) (m_A = PointMass(4.0), m_in = NormalMeanVariance(1.0, 1.0))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Backward message on X:\n" ] }, { "data": { "text/plain": [ "NormalWeightedMeanPrecision{Float64}(xi=8.0, w=16.0)" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "println(\"Backward message on X:\")\n", "@call_rule typeof(*)(:in, Marginalisation) (m_out = NormalMeanVariance(2.0, 1.0), m_A = PointMass(4.0))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " \n", " ### Example: Sum-Product Algorithm to infer a posterior\n", " \n", " - Consider a generative model \n", "$$p(x,y_1,y_2) = p(x)\\,p(y_1|x)\\,p(y_2|x) .$$ \n", " - This model expresses the assumption that $Y_1$ and $Y_2$ are independent measurements of $X$.\n", "\n", "

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ " \n", "- Assume that we are interested in the posterior for $X$ after observing $Y_1= \\hat y_1$ and $Y_2= \\hat y_2$. The posterior for $X$ can be inferred by applying the sum-product algorithm to the following graph:\n", "\n", "

" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ " - Note that we usually draw terminal nodes for observed variables as smaller solid-black squares. This only aids the visualization of the graph, since the computational rules are the same as for other nodes. " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Code for Sum-Product Algorithm to infer a posterior\n", "\n", "We'll use RxInfer to build the above graph and perform sum-product message passing to infer the posterior $p(x|y_1,y_2)$. We assume a Gaussian prior $p(x) = \\mathcal{N}(x\\,|\\,m_x, v_x)$ and Gaussian likelihoods $p(y_1|x)$ and $p(y_2|x)$ with known variances:\n", "$$\\begin{align*}\n", " p(y_1\\,|\\,x) &= \\mathcal{N}(y_1\\,|\\,x, v_{y1}) \\\\\n", " p(y_2\\,|\\,x) &= \\mathcal{N}(y_2\\,|\\,x, v_{y2})\n", "\\end{align*}$$\n", "Under this model, the posterior is given by:\n", "$$\\begin{align*}\n", " p(x\\,|\\,y_1,y_2) &\\propto \\overbrace{p(y_1\\,|\\,x)\\,p(y_2\\,|\\,x)}^{\\text{likelihood}}\\,\\overbrace{p(x)}^{\\text{prior}} \\\\\n", " &=\\mathcal{N}(x\\,|\\,\\hat{y}_1, v_{y1})\\, \\mathcal{N}(x\\,|\\,\\hat{y}_2, v_{y2}) \\, \\mathcal{N}(x\\,|\\,m_x, v_x) \n", "\\end{align*}$$\n", "so we can validate the answer by solving the Gaussian multiplication manually." 
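, "\n", "\n", "As a sketch of that manual check (a standard result for products of Gaussians, not something specific to RxInfer): multiplying Gaussian factors in $x$ adds their precisions and their precision-weighted means, so the posterior $\\mathcal{N}(x\\,|\\,m, v)$ should satisfy\n", "$$\\begin{align*}\n", " v &= \\left( \\frac{1}{v_x} + \\frac{1}{v_{y1}} + \\frac{1}{v_{y2}} \\right)^{-1} \\\\\n", " m &= v \\left( \\frac{m_x}{v_x} + \\frac{\\hat{y}_1}{v_{y1}} + \\frac{\\hat{y}_2}{v_{y2}} \\right) \\,.\n", "\\end{align*}$$\n", "For the values used in the code cell that follows ($m_x = 0$, $v_x = 4$, $v_{y1} = 1$, $v_{y2} = 2$, $\\hat{y}_1 = 1$, $\\hat{y}_2 = 2$), this gives $v = 4/7 \\approx 0.571$ and $m = 8/7 \\approx 1.143$."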
] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# Data\n", "y1_hat = 1.0\n", "y2_hat = 2.0\n", "\n", "# Construct the factor graph\n", "@model function my_model(y1,y2)\n", "\n", " # `x` is the hidden state, with prior mean 0.0 and variance 4.0\n", " x ~ NormalMeanVariance(0.0, 4.0)\n", "\n", " # `y1` and `y2` are \"clamped\" observations with variances 1.0 and 2.0\n", " y1 ~ NormalMeanVariance(x, 1.0)\n", " y2 ~ NormalMeanVariance(x, 2.0)\n", " \n", " return x\n", "end\n", "\n", "# Run sum-product message passing with the observations clamped to the data\n", "result = infer(model=my_model(), data=(y1=y1_hat, y2=y2_hat,))\n", "println(\"Sum-product message passing result: p(x|y1,y2) = 𝒩($(mean(result.posteriors[:x])), $(var(result.posteriors[:x])))\")\n", "\n", "# Calculate mean and variance of p(x|y1,y2) manually by multiplying 3 Gaussians (see lesson 4 for details)\n", "v = 1 / (1/4 + 1/1 + 1/2)\n", "m = v * (0/4 + y1_hat/1.0 + y2_hat/2.0)\n", "println(\"Manual result: p(x|y1,y2) = 𝒩($(m), $(v))\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "open(\"../../styles/aipstyle.html\") do f display(\"text/html\", read(f, String)) end\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Julia 1.10.5", "language": "julia", "name": "julia-1.10" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "1.10.5" } }, "nbformat": 4, "nbformat_minor": 4 }