{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# S1 - Sample-based implementation of Blahut-Arimoto iteration\n", "This notebook is part of the supplementary material for: \n", "Genewein T., Leibfried F., Grau-Moya J., Braun D.A. (2015) *Bounded rationality, abstraction and hierarchical decision-making: an information-theoretic optimality principle*, Frontiers in Robotics and AI. \n", "\n", "More information on how to run the notebook on the accompanying [github repsitory](https://github.com/tgenewein/BoundedRationalityAbstractionAndHierarchicalDecisionMaking) where you can also find updated versions of the code and notebooks.\n", "\n", "This notebook in mentioned in Section 2.3. Due to time- and space-limitations the results of this notebook are not in the paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Disclaimer\n", "This notebook provides a proof-of-concept implementation of a naive sample-based Blahut-Arimoto iteration scheme. Neither the code nor the notebook have been particularly polished or tested. There is a short theory-bit in the beginning of the notebook but most of the explanations are brief and mixed into the code as comments." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Free energy rejection sampling\n", "The solution to a free energy variational problem (Section 2.1 in the paper) has the form of a Boltzmann distribution\n", "$$p(y) = \\frac{1}{Z}p_0(y)e^{\\beta U(y)},$$\n", "where $Z=\\sum_y p_0(y)e^{\\beta U(y)}$ denotes the partition sum, $p_0(y)$ is a prior distribution and $U(y)$ is the utility function. 
The inverse temperature $\\beta$ can be interpreted as a resource parameter: it governs how far the posterior $p(y)$ can deviate from the prior (measured as a KL-divergence) - see Section 2.1 of the paper for details.\n", "\n", "For a decision-maker, it suffices to obtain a sample from $p(y)$ and act according to that sample, rather than computing the full distribution $p(y)$. A simple scheme to sample from $p(y)$ is given by rejection sampling.\n", "\n", "**Rejection sampling** \n", "Goal: get a sample $y$ from the distribution $f(y)$. Draw from a uniform distribution $u\\sim \\mathcal{U}(0,1)$ and from a proposal distribution $y\\sim g(y)$. If $u < \\frac{f(y)}{M g(y)}$, accept the sample as a sample from $f(y)$; otherwise reject the sample and repeat. $M$ is a constant that ensures that $M g(y) \\geq f(y)~\\forall y$. Note that rejection sampling also works for sampling from an unnormalized distribution as long as $M$ is chosen accordingly.\n", "\n", "For the free-energy problem, we want a sample from $p(y)\\propto f(y) = p_0(y)e^{\\beta U(y)}$. We choose $g(y)=p_0(y)$ and set $M=e^{\\beta U_{max}}$, where $U_{max}=\\underset{y}{\\max}~U(y)$, so the acceptance ratio simplifies to $\\frac{f(y)}{M g(y)}=e^{\\beta (U(y)-U_{max})}$.\n", "\n", "**This yields the following rejection sampling scheme:** \n", "* draw from a uniform distribution $u\\sim \\mathcal{U}(0,1)$ \n", "* draw from the proposal distribution $y\\sim p_0(y)$ (*the prior*)\n", " * if $u < \\frac{\\exp(\\beta U(y))}{\\exp(\\beta U_\\mathrm{max})}$, accept the sample as a sample from the posterior $p(y)$\n", " * otherwise reject the sample (and re-sample). \n", "\n", "\n", "## Rate distortion rejection sampling\n", "The solution to the rate distortion problem looks very similar to the Boltzmann distribution in the free-energy case. 
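The free-energy rejection sampling scheme above can be sketched as follows for a discrete prior. This is an illustrative Python sketch under our own naming, not the notebook's implementation:

```python
import numpy as np

def sample_boltzmann(p0, U, beta, rng):
    """Rejection sampling from p(y) proportional to p0(y) exp(beta U(y)).

    Proposal g(y) = p0(y) and M = exp(beta * U_max), so the acceptance
    probability f(y) / (M g(y)) reduces to exp(beta * (U(y) - U_max)).
    """
    U_max = U.max()
    while True:
        y = rng.choice(len(p0), p=p0)          # draw y ~ p0(y) (the prior)
        u = rng.uniform()                       # draw u ~ U(0, 1)
        if u < np.exp(beta * (U[y] - U_max)):   # accept with prob f/(M g)
            return y                            # a sample from p(y)

# Toy example: uniform prior over three options with utilities 1, 2, 3.
rng = np.random.default_rng(0)
p0 = np.ones(3) / 3
U = np.array([1.0, 2.0, 3.0])
samples = [sample_boltzmann(p0, U, 2.0, rng) for _ in range(5000)]
```

The empirical frequencies of `samples` approach the exact Boltzmann posterior; note that larger $\beta$ lowers the acceptance rate, since most proposals from the prior then fall far below $U_{max}$.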
However, there is one crucial difference: in the free-energy case, the prior is an arbitrary distribution, whereas in the rate distortion case the prior is replaced by the marginal distribution, which leads to a set of self-consistent equations\n", "$$\\begin{align}\n", "p^*(a|w)&=\\frac{1}{Z(w)}p(a)e^{\\beta U(a,w)} \\\\\n", "p(a)&=\\sum_w p(w)p(a|w)\n", "\\end{align}$$\n", "*After* convergence of the Blahut-Arimoto iterations, the marginal $p(a)$ can simply be treated like a prior and the rejection sampling scheme described above can straightforwardly be used. However, when initializing with an arbitrary marginal distribution $\\hat{p}(a)$, the iterations must be performed in a sample-based manner until convergence.\n", "\n", "Here, we do this in a naive and very straightforward way: we represent $\\hat{p}(a)$ simply through counters (a categorical distribution). Then we do the following:\n", "1. Draw a number of samples (a batch) from $\\hat{p}^*(a|w)=\\frac{1}{Z(w)}\\hat{p}(a)e^{\\beta U(a,w)}$ using the rejection sampling scheme.\n", "2. Update $\\hat{p}(a)$ with the accepted samples obtained in step 1. There are different possibilities for the update step:\n", " 1. Simply increase the counters for each accepted $a$ and re-normalize (no-forgetting).\n", " 2. Reset the counters for $a$ and use only the last batch of accepted samples to empirically estimate $p(a)$ (full-forgetting).\n", " 3. Use an exponentially decaying window over the last samples to update the empirical estimate of $p(a)$ (not implemented in this notebook).\n", " 4. Use a parametric model $p_\\theta(a)$ and then perform some moment-matching or use a gradient-based update rule to adjust the parameters $\\theta$ (not implemented in this notebook).\n", "3. Repeat until convergence (or here, for simplicity: for a fixed number of steps).\n", "\n", "Additionally, this notebook allows for some burn-in time, where after a certain number of iterations of 1. and 2. (i.e. 
after the \"burn-in\") the counters for $\\hat{p}(a)$ are reset. This naive scheme seems to work but it is unclear how to choose the batch-size (number of samples from $\\hat{p}^*(a|w)$ to take before performing an update step on $\\hat{p}(a)$), how to set the burn-in phase, etc.\n", "\n", "\n", "In the notebook below, you can try different batch-sizes and different burn-in times and you can compare full-forgetting against no forgetting (i.e. no resetting of the counters). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the taxonomy example as a testbed" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
 | β | I_aw | H_a | E_U | Forgetting |
---|---|---|---|---|---|
1 | 0.1 | 0.007623395192178633 | 4.223654705902165 | 0.4867369465550719 | false |
2 | 0.1 | 0.00826347325446443 | 4.19727841236797 | 0.4796636647188409 | false |
3 | 0.1 | 0.00910805899670292 | 4.218818549791943 | 0.4898830914482514 | false |
4 | 0.1 | 0.006825290883835782 | 4.19600049474864 | 0.4961444698317179 | false |
5 | 0.1 | 0.007945666748098868 | 4.166014629929005 | 0.5204773431112447 | false |
6 | 0.1 | 0.008900134896971204 | 4.195620483026167 | 0.5086251383894396 | false |
7 | 0.1 | 0.008846556341656488 | 4.221868883250849 | 0.4909369684302811 | false |
8 | 0.1 | 0.008219831141223762 | 4.203405655499641 | 0.5056413173662162 | false |
9 | 0.1 | 0.007534138787838523 | 4.215062497412484 | 0.47660347809918846 | false |
10 | 0.1 | 0.008198512342354143 | 4.214992574021213 | 0.49480395197059424 | false |
11 | 0.25 | 0.04776349494752096 | 4.139986058450904 | 0.6930868103587762 | false |
12 | 0.25 | 0.04746597520836123 | 4.053905322186562 | 0.7182313413119743 | false |
13 | 0.25 | 0.04742306223611858 | 4.076594872569998 | 0.7172030851353952 | false |
14 | 0.25 | 0.048459731515190264 | 4.126895834854917 | 0.7034042728013368 | false |
15 | 0.25 | 0.04640081220285255 | 4.081785491085504 | 0.7146539689227286 | false |
16 | 0.25 | 0.04704407517123359 | 4.13692741149747 | 0.6986909719867275 | false |
17 | 0.25 | 0.04463977430014036 | 4.081749401977054 | 0.710043807218996 | false |
18 | 0.25 | 0.04783715404023331 | 4.104686759630567 | 0.7084194959050816 | false |
19 | 0.25 | 0.04793165361968161 | 4.137295313441465 | 0.6946415175575932 | false |
20 | 0.25 | 0.048251901414935676 | 4.133940275652188 | 0.7020453434164003 | false |
21 | 0.5 | 0.19854715761281114 | 3.84949595744276 | 1.0577196597741998 | false |
22 | 0.5 | 0.19668499813460685 | 3.854174909509068 | 1.053826008588814 | false |
23 | 0.5 | 0.20418639643669767 | 3.9300066607617707 | 1.0505271632756288 | false |
24 | 0.5 | 0.19666547134189127 | 3.7972243210941947 | 1.0619336641350583 | false |
25 | 0.5 | 0.20193607503880354 | 3.8481410087602503 | 1.0625370243284582 | false |
26 | 0.5 | 0.20063302485390475 | 3.8070466699731984 | 1.069440788536095 | false |
27 | 0.5 | 0.20175864060661242 | 3.8437202208954964 | 1.0647381638218636 | false |
28 | 0.5 | 0.19343478448829352 | 3.797051541355644 | 1.0607814895776928 | false |
29 | 0.5 | 0.2093736373042711 | 3.9152858888811894 | 1.060438448687588 | false |
30 | 0.5 | 0.2073709542372055 | 3.938184323661462 | 1.0491560829614597 | false |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |