{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# S1 - Sample-based implementation of Blahut-Arimoto iteration\n", "This notebook is part of the supplementary material for: \n", "Genewein T., Leibfried F., Grau-Moya J., Braun D.A. (2015) *Bounded rationality, abstraction and hierarchical decision-making: an information-theoretic optimality principle*, Frontiers in Robotics and AI. \n", "\n", "More information on how to run the notebook on the accompanying [github repsitory](https://github.com/tgenewein/BoundedRationalityAbstractionAndHierarchicalDecisionMaking) where you can also find updated versions of the code and notebooks.\n", "\n", "This notebook in mentioned in Section 2.3. Due to time- and space-limitations the results of this notebook are not in the paper." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Disclaimer\n", "This notebook provides a proof-of-concept implementation of a naive sample-based Blahut-Arimoto iteration scheme. Neither the code nor the notebook have been particularly polished or tested. There is a short theory-bit in the beginning of the notebook but most of the explanations are brief and mixed into the code as comments." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "## Free energy rejection sampling\n", "The solution to a free energy variational problem (Section 2.1 in the paper) has the form of a Boltzmann distribution\n", "$$p(y) = \\frac{1}{Z}p_0(y)e^{\\beta U(y)},$$\n", "where $Z=\\sum_y p_0(y)e^{\\beta U(y)}$ denotes the partition sum, $p_0(y)$ is a prior distribution and $U(y)$ is the utility function. 
The inverse temperature $\\beta$ can be interpreted as a resource parameter: it governs how far the posterior $p(y)$ can deviate from the prior (measured as a KL-divergence) - see Section 2.1 of the paper for details.\n", "\n", "For a decision-maker, it suffices to obtain a sample from $p(y)$ and act according to that sample, rather than computing the full distribution $p(y)$. A simple scheme to sample from $p(y)$ is given by rejection sampling.\n", "\n", "**Rejection sampling** \n", "Goal: get a sample $y$ from the distribution $f(y)$. Draw from a uniform distribution $u\\sim \\mathcal{U}(0,1)$ and from a proposal distribution $y\\sim g(y)$. If $u < \\frac{f(y)}{M g(y)}$, accept the sample as a sample from $f(y)$; otherwise reject the sample and repeat. $M$ is a constant that ensures that $M g(y) \\geq f(y)~\\forall y$. Note that rejection sampling also works for sampling from an unnormalized distribution as long as $M$ is chosen accordingly.\n", "\n", "For the free-energy problem, we want a sample from $p(y)\\propto f(y) = p_0(y)e^{\\beta U(y)}$. We choose $g(y)=p_0(y)$ and set $M=e^{\\beta U_{max}}$, where $U_{max}=\\max_y U(y)$.\n", "\n", "**Finally we get the following rejection sampling scheme:** \n", "* draw from a uniform distribution $u\\sim \\mathcal{U}(0,1)$ \n", "* draw from the proposal distribution $y\\sim p_0(y)$ (*the prior*)\n", " * if $u < \\frac{\\exp(\\beta U(y))}{\\exp(\\beta U_\\mathrm{max})}$ accept the sample as a sample from the posterior $p(y)$\n", " * otherwise reject the sample (and re-sample). \n", "\n", "\n", "## Rate distortion rejection sampling\n", "The solution to the rate distortion problem looks very similar to the Boltzmann distribution in the free-energy case. 
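The free-energy rejection sampling scheme above can be sketched end-to-end. The following is a minimal, self-contained Python sketch (not the notebook's Julia code); the prior `p0`, the utilities `U` and the inverse temperature `beta` are made-up stand-ins:

```python
import numpy as np

rng = np.random.default_rng(0)

# made-up toy problem: 4 options y, uniform prior p0(y), arbitrary utilities U(y)
p0 = np.array([0.25, 0.25, 0.25, 0.25])
U = np.array([1.0, 2.0, 0.5, 3.0])
beta = 1.5

def sample_posterior(p0, U, beta, rng):
    """Draw one sample from p(y) proportional to p0(y)*exp(beta*U(y)) by rejection."""
    Umax = U.max()
    while True:
        y = rng.choice(len(p0), p=p0)          # proposal draw y ~ p0(y)
        u = rng.uniform()                      # u ~ U(0,1)
        # accept with probability exp(beta*U(y)) / exp(beta*Umax)
        if u < np.exp(beta * (U[y] - Umax)):
            return y

# sanity check: the histogram of accepted samples approaches the Boltzmann posterior
samples = np.array([sample_posterior(p0, U, beta, rng) for _ in range(20000)])
empirical = np.bincount(samples, minlength=len(p0)) / len(samples)
exact = p0 * np.exp(beta * U)
exact /= exact.sum()
print(np.round(empirical, 3), np.round(exact, 3))
```

As `beta` approaches 0 every proposal is accepted and the posterior stays at the prior; for large `beta` the acceptance rate drops and the posterior concentrates on the maximal-utility option.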
However, there is one crucial difference: in the free-energy case, the prior is an arbitrary distribution - in the rate distortion case, the prior is replaced by the marginal distribution, which leads to a set of self-consistent equations\n", "$$\\begin{align}\n", "p^*(a|w)&=\\frac{1}{Z(w)}p(a)e^{\\beta U(a,w)} \\\\\n", "p(a)&=\\sum_w p(w)p^*(a|w)\n", "\\end{align}$$\n", "*After* convergence of the Blahut-Arimoto iterations, the marginal $p(a)$ can simply be treated like a prior and the rejection sampling scheme described above can straightforwardly be used. However, when initializing with an arbitrary marginal distribution $\\hat{p}(a)$, the iterations must be performed in a sample-based manner until convergence.\n", "\n", "Here, we do this in a naive and very straightforward way: we represent $\\hat{p}(a)$ simply through counters (a categorical distribution). Then we do the following:\n", "1. Draw a number of samples (a batch) from $\\hat{p}^*(a|w)=\\frac{1}{Z(w)}\\hat{p}(a)e^{\\beta U(a,w)}$ using the rejection sampling scheme.\n", "2. Update $\\hat{p}(a)$ with the accepted samples obtained in step 1. There are different possibilities for the update step\n", " 1. Simply increase the counters for each accepted $a$ and re-normalize (no-forgetting)\n", " 2. Reset the counters for $a$ and use only the last batch of accepted samples to empirically estimate $p(a)$ (full-forgetting)\n", " 3. Use an exponentially decaying window over the last samples to update the empirical estimate of $p(a)$ (not implemented in this notebook).\n", " 4. Use a parametric model for $p_\\theta(a)$ and then perform some moment-matching or use a gradient-based update rule to adjust the parameters $\\theta$ (not implemented in this notebook).\n", "3. Repeat until convergence (or here for simplicity: for a fixed number of steps)\n", "\n", "Additionally, this notebook allows for some burn-in time, where after a certain number of iterations of 1. and 2. (i.e. 
after the \"burn-in\") the counters for $\\hat{p}(a)$ are reset. This naive scheme seems to work but it is unclear how to choose the batch-size (number of samples from $\\hat{p}^*(a|w)$ to take before performing an update step on $\\hat{p}(a)$), how to set the burn-in phase, etc.\n", "\n", "\n", "In the notebook below, you can try different batch-sizes and different burn-in times and you can compare full-forgetting against no forgetting (i.e. no resetting of the counters). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load the taxonomy example as a testbed" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", " \n", " \n", "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#load taxonomy example\n", "using RateDistortionDecisionMaking, DataFrames, Gadfly, Distributions\n", "\n", "#set up taxonomy example\n", "include(\"TaxonomyExample.jl\")\n", "w_vec, w_strings, a_vec, a_strings, p_w, U = setuptaxonomy()\n", "\n", "#pre-compute utilities, find maxima\n", "U_pre, Umax = setuputilityarrays(a_vec,w_vec,U)\n", "\n", "#initialize p(a) uniformly\n", "num_acts = length(a_vec)\n", "pa_init = ones(num_acts)/num_acts;\n" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Set up functions for sampling and run on the example from above\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "rej_samp_const (generic function with 1 method)" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Performs rejection sampling with a constant (scaled uniform) envelope\n", "#using a softmax acceptance-rejection criterion.\n", "#prop_dist .. proposal distribution (must be an instance of type ::Distribution)\n", "#nsamps ..... desired number of samples (scalar)\n", "#maxsteps ... maximum number of acceptance-rejection steps (scalar, must be ≧ nsamps)\n", "#β .......... softmax parameter\n", "#lh ......... likelihood value (vector of length N)\n", "#maxlh ...... 
maximum value that the likelihood can take (scalar)\n", "function rej_samp_const(prop_dist::Distribution, nsamps::Integer, maxsteps::Integer, β::Number, lh::Vector, maxlh::Number) \n", " #initialize\n", " samps = zeros(nsamps)\n", " acc_cnt = 0 #acceptance-counter\n", " if(maxsteps < nsamps)\n", " maxsteps = nsamps\n", " end\n", " \n", " k=0 #use this to make sure that k is still available after the loop\n", " for k in 1:maxsteps\n", " u=rand(1) #sample from uniform between (0,1)\n", " index = rand(prop_dist) #sample from proposal\n", " \n", " ratio = exp(β*lh[index])/exp(β*maxlh)\n", " if u[1] < ratio #use u[1], since rand(1) returns an array and comparisons can not handle arrays\n", " #if we enter here, accept the sample \n", " acc_cnt = acc_cnt + 1 \n", " samps[acc_cnt] = index\n", " \n", " if(acc_cnt == nsamps)\n", " #we have enough samples, exit loop\n", " break\n", " end\n", " end\n", " end\n", " \n", " if(k==maxsteps)\n", " warn(\"[RejSampConst] Maximum number of steps reached - number of samples is potentially lower than nsamps!\\n\")\n", " end\n", " \n", " #store all accepted samples (this can be fewer than nsamps if maxsteps is too low or the acceptance-rate is low)\n", " samples = samps[1:acc_cnt]\n", " \n", " #compute acceptance ratio\n", " acc_ratio = acc_cnt/k\n", " \n", " return samples, acc_ratio\n", "end" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/plain": [ "BAsampling (generic function with 1 method)" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#marginal is simply represented by counters (i.e. by frequencies)\n", "function init_marginal_representation_ctrs(pa_init::Vector)\n", " return pa_init\n", "end\n", "\n", "#this updates the marginal over actions p(a) using a counter-representation\n", "#this function counts the number of times each action-index occurs in sampled_indices\n", "#these counts are then added to the current marginal_ctrs. 
Optionally, the counters are reset\n", "#before adding the new samples (=hard forgetting).\n", "function update_marginal_ctrs(sampled_indices::Vector, marginal_ctrs::Vector; reset_ctrs::Bool=false) \n", " #TODO: perhaps replace hard-resetting with an exponential decay?\n", " \n", " p_ctrs = marginal_ctrs\n", " card_p = length(p_ctrs)\n", " \n", " #reset counters for marginal? (make sure every entry is non-zero!)\n", " if reset_ctrs\n", " p_ctrs = ones(card_p)/card_p\n", " end\n", "\n", " #update marginal counters using a histogram to do the counting (bin-borders have to be set manually!)\n", " e,p_counts = hist(sampled_indices,0.5:1:(card_p+0.5)) \n", " p_ctrs = p_ctrs + p_counts\n", " \n", " #normalize to get the updated marginal\n", " p_sampled = p_ctrs / sum(p_ctrs)\n", " \n", " return p_sampled, p_ctrs #return the probability-vector, but also the representation of the marginal (as counts)\n", "end\n", "\n", "\n", "\n", "#function for BA sampling\n", "#burnin_ratio specifies the ratio of outer iterations that will not count\n", "#towards computation of the final marginal distribution (counters will be blocked)\n", "#reset_marginal_ctrs specifies whether the marginal is computed with the samples of the last\n", "#iteration only (=hard forgetting by resetting counters) or whether the marginal\n", "#is computed with all samples of all iterations (=no forgetting)\n", "function BAsampling(pa_init::Vector, β::Number, U_pre::Matrix, Umax::Vector, pw::Vector, \n", " nsteps_marginalupdate::Integer, nsteps_conditionalupdate::Integer;\n", " burnin_ratio::Real=0.7, max_rejsamp_steps::Integer=200,\n", " compute_performance::Bool=false, performance_as_dataframe::Bool=false,\n", " performance_per_iteration::Bool=false,\n", " init_marg_func::Function=init_marginal_representation_ctrs,\n", " update_marg_func::Function=update_marginal_ctrs, update_func_args...)\n", " \n", " #compute cardinality, check size of U_pre\n", " card_a = length(pa_init)\n", " card_w = length(pw)\n", " if 
size(U_pre) != (card_a, card_w)\n", " error(\"Size mismatch of U_pre and pa_init or pw!\")\n", " end\n", " \n", " #check that burnin_ratio is really a ratio\n", " if (burnin_ratio < 0) || (burnin_ratio > 1)\n", " error(\"burnin_ratio must be a number between 0 and 1.\")\n", " end\n", " \n", " #if performance measures don't need to be returned, don't compute them per iteration\n", " if compute_performance==false\n", " performance_per_iteration = false\n", " end \n", " #preallocate if necessary\n", " if performance_per_iteration \n", " I_i = zeros(nsteps_marginalupdate)\n", " Ha_i = zeros(nsteps_marginalupdate)\n", " Hagw_i = zeros(nsteps_marginalupdate)\n", " EU_i = zeros(nsteps_marginalupdate)\n", " RDobj_i = zeros(nsteps_marginalupdate)\n", " end\n", " \n", " #initialize sampling distributions\n", " pw_dist = Categorical(pw) #proposal distribution \n", " pagw_ctrs = ones(card_a, card_w) #counters for conditional distribution \n", " pa_sampled = pa_init #marginal distribution\n", " \n", " #initialize the marginal representation\n", " pa_ctrs = init_marg_func(pa_init)\n", "\n", " burnin_triggered=false\n", " #outer loop - in each iteration the marginal is updated\n", " iter=0\n", " for iter in 1:nsteps_marginalupdate \n", " a_samples = zeros(nsteps_conditionalupdate) #this will hold the samples from p(a|w) during the inner loop\n", " \n", " #inner loop - in each step a sample is drawn from the conditional and stored\n", " #for the batch-update of the marginal\n", " for j in 1:nsteps_conditionalupdate\n", " #draw a w sample\n", " w_samp = rand(pw_dist)\n", "\n", " #draw a sample from p(a|w) using the current estimate of p(a) as proposal distribution using rejection sampling\n", " agw_samp, acc_ratio = rej_samp_const(Categorical(pa_sampled), 1, max_rejsamp_steps, β, U_pre[:,w_samp], Umax[w_samp])\n", " a_samples[j] = agw_samp[1]\n", "\n", " #update conditional counters\n", " pagw_ctrs[agw_samp, w_samp] += 1\n", " end\n", " \n", " #very simple burn-in: simply reset counters\n", " if (iter >(nsteps_marginalupdate)*burnin_ratio) && 
(!burnin_triggered)\n", " burnin_triggered = true\n", " pagw_ctrs = ones(card_a, card_w) \n", " end\n", "\n", " #update marginal with samples drawn in inner loop\n", " pa_sampled, pa_ctrs = update_marg_func(a_samples, pa_ctrs; update_func_args...) \n", " \n", " \n", " #compute entropic quantities (if requested with additional parameter)\n", " if performance_per_iteration\n", " #compute sample-based conditional p(a|w)\n", " pagw_sampled = zeros(card_a, card_w)\n", " for i in 1:card_w\n", " pagw_sampled[:,i] = pagw_ctrs[:,i] / sum(pagw_ctrs[:,i])\n", " end\n", " I_i[iter], Ha_i[iter], Hagw_i[iter], EU_i[iter], RDobj_i[iter] = analyzeBAsolution(pw, pa_sampled, pagw_sampled, U_pre, β)\n", " end\n", " end\n", "\n", " #compute conditionals using the sample-counts of the previous inner loops\n", " #the burn-in parameter specifies how many of the inner loops are discarded\n", " pagw_sampled = zeros(card_a, card_w)\n", " for i in 1:card_w\n", " pagw_sampled[:,i] = pagw_ctrs[:,i] / sum(pagw_ctrs[:,i])\n", " end\n", "\n", " \n", " #return results\n", " if compute_performance == false\n", " return pagw_sampled, pa_sampled\n", " else \n", " if performance_per_iteration == false\n", " #compute performance measures for final solution\n", " I, Ha, Hagw, EU, RDobj = analyzeBAsolution(pw, pa_sampled, pagw_sampled, U_pre, β)\n", " else\n", " #\"cut\" valid results from preallocated vector\n", " I = I_i[1:iter]\n", " Ha = Ha_i[1:iter]\n", " Hagw = Hagw_i[1:iter]\n", " EU = EU_i[1:iter]\n", " RDobj = RDobj_i[1:iter]\n", " end\n", "\n", " #if needed, transform to data frame\n", " if performance_as_dataframe == false\n", " return pagw_sampled, pa_sampled, I, Ha, Hagw, EU, RDobj\n", " else\n", " performance_df = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj)\n", " return pagw_sampled, pa_sampled, performance_df \n", " end\n", " end\n", " \n", "end" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "\n", 
"\n", "\n", " \n", " β\n", " \n", " \n", " \n", " \n", " RU_obj\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " E[U]\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -600\n", " -500\n", " -400\n", " -300\n", " -200\n", " -100\n", " 0\n", " 100\n", " 200\n", " 300\n", " 400\n", " 500\n", " 600\n", " 700\n", " 800\n", " 900\n", " 1000\n", " 1100\n", " -500\n", " -480\n", " -460\n", " -440\n", " -420\n", " -400\n", " -380\n", " -360\n", " -340\n", " -320\n", " -300\n", " -280\n", " -260\n", " -240\n", " -220\n", " -200\n", " -180\n", " -160\n", " -140\n", " -120\n", " -100\n", " -80\n", " -60\n", " -40\n", " -20\n", " 0\n", " 20\n", " 40\n", " 60\n", " 80\n", " 100\n", " 120\n", " 140\n", " 160\n", " 180\n", " 200\n", " 220\n", " 240\n", " 260\n", " 280\n", " 300\n", " 320\n", " 340\n", " 360\n", " 380\n", " 400\n", " 420\n", " 440\n", " 460\n", " 480\n", " 500\n", " 520\n", " 540\n", " 560\n", " 580\n", " 600\n", " 620\n", " 640\n", " 660\n", " 680\n", " 700\n", " 720\n", " 740\n", " 760\n", " 780\n", " 800\n", " 820\n", " 840\n", " 860\n", " 880\n", " 900\n", " 920\n", " 940\n", " 960\n", " 980\n", " 1000\n", " -500\n", " 0\n", " 500\n", " 1000\n", " -500\n", " -450\n", " -400\n", " -350\n", " -300\n", " -250\n", " -200\n", " -150\n", " -100\n", " -50\n", " 0\n", " 50\n", " 100\n", " 150\n", " 200\n", " 250\n", " 300\n", " 350\n", " 400\n", " 450\n", " 500\n", " 550\n", " 600\n", " 650\n", " 700\n", " 750\n", " 800\n", " 850\n", " 900\n", " 950\n", " 1000\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -3.0\n", " -2.5\n", " -2.0\n", " -1.5\n", " -1.0\n", " -0.5\n", " 0.0\n", " 0.5\n", " 1.0\n", " 1.5\n", " 2.0\n", " 2.5\n", " 3.0\n", " 3.5\n", " 4.0\n", " 4.5\n", " 5.0\n", " 5.5\n", " -2.5\n", " -2.4\n", " -2.3\n", " -2.2\n", " -2.1\n", " -2.0\n", " -1.9\n", " -1.8\n", " -1.7\n", " -1.6\n", " -1.5\n", " -1.4\n", " -1.3\n", " -1.2\n", " -1.1\n", " -1.0\n", 
" -0.9\n", " -0.8\n", " -0.7\n", " -0.6\n", " -0.5\n", " -0.4\n", " -0.3\n", " -0.2\n", " -0.1\n", " 0.0\n", " 0.1\n", " 0.2\n", " 0.3\n", " 0.4\n", " 0.5\n", " 0.6\n", " 0.7\n", " 0.8\n", " 0.9\n", " 1.0\n", " 1.1\n", " 1.2\n", " 1.3\n", " 1.4\n", " 1.5\n", " 1.6\n", " 1.7\n", " 1.8\n", " 1.9\n", " 2.0\n", " 2.1\n", " 2.2\n", " 2.3\n", " 2.4\n", " 2.5\n", " 2.6\n", " 2.7\n", " 2.8\n", " 2.9\n", " 3.0\n", " 3.1\n", " 3.2\n", " 3.3\n", " 3.4\n", " 3.5\n", " 3.6\n", " 3.7\n", " 3.8\n", " 3.9\n", " 4.0\n", " 4.1\n", " 4.2\n", " 4.3\n", " 4.4\n", " 4.5\n", " 4.6\n", " 4.7\n", " 4.8\n", " 4.9\n", " 5.0\n", " -2.5\n", " 0.0\n", " 2.5\n", " 5.0\n", " -2.6\n", " -2.4\n", " -2.2\n", " -2.0\n", " -1.8\n", " -1.6\n", " -1.4\n", " -1.2\n", " -1.0\n", " -0.8\n", " -0.6\n", " -0.4\n", " -0.2\n", " 0.0\n", " 0.2\n", " 0.4\n", " 0.6\n", " 0.8\n", " 1.0\n", " 1.2\n", " 1.4\n", " 1.6\n", " 1.8\n", " 2.0\n", " 2.2\n", " 2.4\n", " 2.6\n", " 2.8\n", " 3.0\n", " 3.2\n", " 3.4\n", " 3.6\n", " 3.8\n", " 4.0\n", " 4.2\n", " 4.4\n", " 4.6\n", " 4.8\n", " 5.0\n", " \n", " \n", " [utils]\n", " \n", "\n", "\n", " \n", " β\n", " \n", " \n", " \n", " \n", " H(A|W)\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " H(A)\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " I(A;W)\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -600\n", " -500\n", " -400\n", " -300\n", " -200\n", " -100\n", " 0\n", " 100\n", " 200\n", " 300\n", " 400\n", " 500\n", " 600\n", " 700\n", " 800\n", " 900\n", " 1000\n", " 1100\n", " -500\n", " -480\n", " -460\n", " -440\n", " -420\n", " -400\n", " -380\n", " -360\n", " -340\n", " -320\n", " -300\n", " -280\n", " -260\n", " -240\n", " -220\n", " -200\n", " -180\n", " -160\n", " -140\n", " -120\n", " -100\n", " -80\n", " -60\n", " -40\n", " -20\n", " 0\n", " 20\n", " 40\n", " 60\n", " 80\n", " 100\n", " 120\n", " 140\n", " 160\n", " 180\n", " 200\n", " 220\n", " 240\n", " 260\n", " 280\n", " 300\n", " 320\n", " 340\n", " 360\n", " 380\n", 
" 400\n", " 420\n", " 440\n", " 460\n", " 480\n", " 500\n", " 520\n", " 540\n", " 560\n", " 580\n", " 600\n", " 620\n", " 640\n", " 660\n", " 680\n", " 700\n", " 720\n", " 740\n", " 760\n", " 780\n", " 800\n", " 820\n", " 840\n", " 860\n", " 880\n", " 900\n", " 920\n", " 940\n", " 960\n", " 980\n", " 1000\n", " -500\n", " 0\n", " 500\n", " 1000\n", " -500\n", " -450\n", " -400\n", " -350\n", " -300\n", " -250\n", " -200\n", " -150\n", " -100\n", " -50\n", " 0\n", " 50\n", " 100\n", " 150\n", " 200\n", " 250\n", " 300\n", " 350\n", " 400\n", " 450\n", " 500\n", " 550\n", " 600\n", " 650\n", " 700\n", " 750\n", " 800\n", " 850\n", " 900\n", " 950\n", " 1000\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " -6\n", " -5\n", " -4\n", " -3\n", " -2\n", " -1\n", " 0\n", " 1\n", " 2\n", " 3\n", " 4\n", " 5\n", " 6\n", " 7\n", " 8\n", " 9\n", " 10\n", " 11\n", " -5.0\n", " -4.8\n", " -4.6\n", " -4.4\n", " -4.2\n", " -4.0\n", " -3.8\n", " -3.6\n", " -3.4\n", " -3.2\n", " -3.0\n", " -2.8\n", " -2.6\n", " -2.4\n", " -2.2\n", " -2.0\n", " -1.8\n", " -1.6\n", " -1.4\n", " -1.2\n", " -1.0\n", " -0.8\n", " -0.6\n", " -0.4\n", " -0.2\n", " 0.0\n", " 0.2\n", " 0.4\n", " 0.6\n", " 0.8\n", " 1.0\n", " 1.2\n", " 1.4\n", " 1.6\n", " 1.8\n", " 2.0\n", " 2.2\n", " 2.4\n", " 2.6\n", " 2.8\n", " 3.0\n", " 3.2\n", " 3.4\n", " 3.6\n", " 3.8\n", " 4.0\n", " 4.2\n", " 4.4\n", " 4.6\n", " 4.8\n", " 5.0\n", " 5.2\n", " 5.4\n", " 5.6\n", " 5.8\n", " 6.0\n", " 6.2\n", " 6.4\n", " 6.6\n", " 6.8\n", " 7.0\n", " 7.2\n", " 7.4\n", " 7.6\n", " 7.8\n", " 8.0\n", " 8.2\n", " 8.4\n", " 8.6\n", " 8.8\n", " 9.0\n", " 9.2\n", " 9.4\n", " 9.6\n", " 9.8\n", " 10.0\n", " -5\n", " 0\n", " 5\n", " 10\n", " -5.0\n", " -4.5\n", " -4.0\n", " -3.5\n", " -3.0\n", " -2.5\n", " -2.0\n", " -1.5\n", " -1.0\n", " -0.5\n", " 0.0\n", " 0.5\n", " 1.0\n", " 1.5\n", " 2.0\n", " 2.5\n", " 3.0\n", " 3.5\n", " 4.0\n", " 4.5\n", " 5.0\n", " 5.5\n", " 6.0\n", " 6.5\n", " 7.0\n", " 7.5\n", " 8.0\n", " 
8.5\n", " 9.0\n", " 9.5\n", " 10.0\n", " \n", " \n", " [bits]\n", " \n", "\n", "\n", "\n", " \n", "\n", " \n", "\n", "\n", "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "\n", "\n", " \n", " World state w\n", " \n", " \n", " Laptop\n", " Monitor\n", " Gamepad\n", " Coffee machine\n", " Vacuum cleaner\n", " Electric toothbrush\n", " Grapes\n", " Strawberries\n", " Limes\n", " Pancake mix\n", " Baking soda\n", " Baker's yeast\n", " Muffin cups\n", " \n", " \n", " \n", " 0.0\n", " 0.5\n", " 1.0\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " p*(a|w)\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " Laptop sleeve\n", " Monitor cable\n", " Video game\n", " Coffee\n", " Vacuum cleaner bags\n", " Brush heads\n", " Cheese\n", " Cream\n", " Cane sugar\n", " Maple syrup\n", " Vinegar\n", " Flour\n", " Chocolate chips\n", " COMPUTERS\n", " APPLIANCES\n", " FRUIT\n", " BAKING\n", " Electronics\n", " Food\n", " \n", " \n", " Action a\n", " \n", "\n", "\n", "\n", " \n", "\n", "\n", "\n", "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#example call and also plot evolution of performance measueres\n", "maxiter = 10000\n", "β = 1.2\n", "nsteps_marg = 500\n", "nsteps_cond = 750\n", "pagw_s,pa_s,perf = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,\n", " burnin_ratio=0.7, max_rejsamp_steps=500, reset_ctrs=false,\n", " compute_performance=true, performance_as_dataframe=true, performance_per_iteration=true)\n", "\n", "plt_cond = 
visualizeBAconditional(pagw_s,a_vec,w_vec,a_strings,w_strings)\n", "\n", "#instead of using a range of β-values (as for the standard-performance plot), \n", "#use a vector indicating the iteration\n", "niter = size(perf,1)\n", "plt_perf_entropy, plt_perf_utility, plt_rateutility = plotperformancemeasures(perf,[1:niter],\n", " suppress_vis=true, xlabel_perf=\"Iteration\")\n", "\n", "#TODO: somehow the \"Iteration\" label above doesn't seem to work!\n", "\n", "display(vstack(plt_perf_entropy, plt_perf_utility))\n", "display(plt_cond)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [Interact] Change the parameters in the code-cell above to explore the sampling scheme and its solutions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare sampling-solutions against analytical results\n", "\n", "Below, we average over several sampling-runs at different inverse temperatures $\\beta$ to see the difference between the analytical solutions and the sample-based solutions (with and without forgetting)." 
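One way to put a number on the mismatch between a sample-based marginal and the analytical one is a KL-divergence in bits (the same units as the plots). This is a hypothetical Python sketch with made-up counts, not part of the original Julia notebook:

```python
import numpy as np

def kl_bits(p, q, eps=1e-12):
    """KL-divergence D(p||q) in bits between two discrete distributions."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # eps guards against log(0) for empty bins
    return float(np.sum(p * np.log2((p + eps) / (q + eps))))

# made-up stand-ins: an analytical marginal p(a) and action counts from a sampling run
p_analytical = np.array([0.5, 0.3, 0.2])
counts = np.array([520, 280, 200])
p_sampled = counts / counts.sum()

print(round(kl_bits(p_sampled, p_analytical), 4))
```

Averaging this quantity over several runs per β-value would give a scalar summary of how the forgetting variants compare, complementing the visual comparison.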
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "collapsed": false }, "outputs": [ { "data": { "text/html": [ "(figure: analytical rate-utility curve, I(A;W) [bits] vs E[U])\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#compute theoretical result for rate-utility curve\n", "ε = 0.0001 #convergence criterion for BAiterations\n", "maxiter = 10000\n", "β_sweep = [0.01:0.05:3]\n", "nβ = length(β_sweep)\n", "\n", "#preallocate\n", "I = zeros(nβ)\n", "Ha = zeros(nβ)\n", "Hagw = zeros(nβ)\n", "EU = zeros(nβ)\n", "RDobj = zeros(nβ)\n", "\n", "#sweep through β values and perform Blahut-Arimoto iterations for each value\n", "for i=1:nβ \n", " pagw, pa, I[i], Ha[i], Hagw[i], EU[i], RDobj[i] = BAiterations(pa_init, β_sweep[i], U_pre, p_w, ε, maxiter,compute_performance=true) \n", "end\n", "\n", "#show rate-utility curve (shaded region is theoretically infeasible)\n", "perf_res_analytical = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj); \n", "plot_perf_entropy, plot_perf_util, plot_rateutility = plotperformancemeasures(perf_res_analytical, β_sweep, suppress_vis=true);\n", "display(plot_rateutility)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "collapsed": false },
"outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "BA Sampling, run 1 of 160\n", "BA Sampling, run 2 of 160\n", " ⋮\n", "BA Sampling, run 160 of 160\n" ] }, { "data": { "text/html": [
(HTML rendering of the 160x5 results DataFrame: columns β, I_aw, H_a, E_U, Forgetting; the same data appears in the plain-text table below)
" ], "text/plain": [ "160x5 DataFrame\n", "| Row | β | I_aw | H_a | E_U | Forgetting |\n", "|-----|------|------------|---------|----------|------------|\n", "| 1 | 0.1 | 0.0076234 | 4.22365 | 0.486737 | false |\n", "| 2 | 0.1 | 0.00826347 | 4.19728 | 0.479664 | false |\n", "| 3 | 0.1 | 0.00910806 | 4.21882 | 0.489883 | false |\n", "| 4 | 0.1 | 0.00682529 | 4.196 | 0.496144 | false |\n", "| 5 | 0.1 | 0.00794567 | 4.16601 | 0.520477 | false |\n", "| 6 | 0.1 | 0.00890013 | 4.19562 | 0.508625 | false |\n", "| 7 | 0.1 | 0.00884656 | 4.22187 | 0.490937 | false |\n", "| 8 | 0.1 | 0.00821983 | 4.20341 | 0.505641 | false |\n", "| 9 | 0.1 | 0.00753414 | 4.21506 | 0.476603 | false |\n", "| 10 | 0.1 | 0.00819851 | 4.21499 | 0.494804 | false |\n", "| 11 | 0.25 | 0.0477635 | 4.13999 | 0.693087 | false |\n", "⋮\n", "| 149 | 1.6 | 2.3355 | 2.99266 | 2.53438 | true |\n", "| 150 | 1.6 | 2.33106 | 2.92917 | 2.53786 | true |\n", "| 151 | 2.0 | 3.26488 | 3.56937 | 2.90997 | true |\n", "| 152 | 2.0 | 3.27499 | 3.5577 | 2.90857 | true |\n", "| 153 | 2.0 | 3.25576 | 3.5279 | 2.9076 | true |\n", "| 154 | 2.0 | 3.26577 | 3.53308 | 2.90904 | true |\n", "| 155 | 2.0 | 3.25836 | 3.53558 | 2.90868 | true |\n", "| 156 | 2.0 | 3.26414 | 3.53648 | 2.91028 | true |\n", "| 157 | 2.0 | 3.26418 | 3.55415 | 2.90976 | true |\n", "| 158 | 2.0 | 3.27046 | 3.54517 | 2.91326 | true |\n", "| 159 | 2.0 | 3.26718 | 3.50027 | 2.90933 | true |\n", "| 160 | 2.0 | 3.26761 | 3.51412 | 2.91024 | true |" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#run the smapling for different temperatures and repeat each run n-times\n", "#then plot these results against the rate-utility curve (based on the closed-form solutions)\n", "βrange_samp = [0.1, 0.25, 0.5, 0.8, 1.2, 1.4, 1.6, 2]\n", "#βrange_samp = [1.2, 2]\n", "\n", "nruns = 10; #number of runs per β point\n", "\n", "nsteps_marg = 500\n", "nsteps_cond = 750\n", "burnin_ratio = 0.8\n", "max_rejsamp_steps=500 #maximum number 
of steps for sampling from the conditional\n", "\n", "\n", "nconditions = size(βrange_samp,1)*nruns\n", "I_sampled = zeros(2*nconditions)\n", "Ha_sampled = zeros(2*nconditions)\n", "EU_sampled = zeros(2*nconditions)\n", "βval = zeros(2*nconditions)\n", "ResetCtrs = falses(2*nconditions)\n", "\n", "#first run with reset_ctrs to false\n", "reset_ctrs = false #if true, the ctrs for the marginal are reset in each iteration (=\"hard\" forgetting)\n", "for b in 1:nconditions\n", " println(\"BA Sampling, run $b of $(2*nconditions)\")\n", " β = βrange_samp[ceil(b/nruns)]\n", " \n", " pagw_s, pa_s, I, Ha, Hagw, EU, RDobj = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,\n", " reset_ctrs=reset_ctrs, burnin_ratio=burnin_ratio,\n", " max_rejsamp_steps=max_rejsamp_steps, compute_performance=true)\n", " \n", " I_sampled[b] = I\n", " Ha_sampled[b] = Ha\n", " EU_sampled[b] = EU\n", " βval[b] = β\n", " ResetCtrs[b] = reset_ctrs\n", "end\n", "\n", "#second run with reset_ctrs to true\n", "reset_ctrs = true #if true, the ctrs for the marginal are reset in each iteration (=\"hard\" forgetting)\n", "for b in 1:nconditions\n", " println(\"BA Sampling, run $(nconditions+b) of $(2*nconditions)\")\n", " β = βrange_samp[ceil(b/nruns)]\n", " \n", " pagw_s, pa_s, I, Ha, Hagw, EU, RDobj = BAsampling(pa_init, β, U_pre, Umax, p_w, nsteps_marg, nsteps_cond,\n", " reset_ctrs=reset_ctrs, burnin_ratio=burnin_ratio,\n", " max_rejsamp_steps=max_rejsamp_steps, compute_performance=true)\n", " \n", " I_sampled[nconditions+b] = I\n", " Ha_sampled[nconditions+b] = Ha\n", " EU_sampled[nconditions+b] = EU\n", " βval[nconditions+b] = β\n", " ResetCtrs[nconditions+b] = reset_ctrs\n", "end\n", "\n", "#wrap data in DataFrame for convenient plotting\n", "res_sampled = DataFrame(β=βval, I_aw=I_sampled, H_a=Ha_sampled, E_U=EU_sampled, Forgetting=ResetCtrs)\n", "\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#compute 
theoretical result for same set of temperatures\n", "ε = 0.0001 #convergence criterion for BAiterations\n", "maxiter = 10000\n", "nβ = length(βrange_samp)\n", "\n", "#preallocate\n", "I = zeros(nβ)\n", "Ha = zeros(nβ)\n", "Hagw = zeros(nβ)\n", "EU = zeros(nβ)\n", "RDobj = zeros(nβ)\n", "\n", "#sweep through β values and perform Blahut-Arimoto iterations for each value\n", "for i=1:nβ \n", " pagw, pa, I[i], Ha[i], Hagw[i], EU[i], RDobj[i] = BAiterations(pa_init, βrange_samp[i], U_pre, p_w, ε, maxiter,compute_performance=true) \n", "end\n", "\n", "#wrap the performance measures in a DataFrame and add the β column for grouping\n", "perf_res_analytical_samp = performancemeasures2DataFrame(I, Ha, Hagw, EU, RDobj)\n", "perf_res_analytical_samp[:β] = βrange_samp;" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [ "(figure: rate-utility curve, I(A;W) [bits] vs E[U], sampling solutions as dots colored by Forgetting; title: Rate-Utility curve (dots show sampling solutions))\n" ], "text/plain": [ "Plot(...)" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#plot the solutions from the sampling runs into the (analytical) rate-utility plot\n", "plot(Guide.ylabel(\"E[U]\"), Guide.xlabel(\"I(A;W) [bits]\"),\n", " Guide.title(\"Rate-Utility curve (dots show sampling solutions)\"), BAtheme(), BAdiscretecolorscale(2),\n", " layer(res_sampled,y=\"E_U\",x=\"I_aw\",Geom.point,color=\"Forgetting\"),\n", " layer(perf_res_analytical_samp,y=\"E_U\",x=\"I_aw\",Geom.point),\n", " layer(perf_res_analytical,y=\"E_U\",x=\"I_aw\",Geom.line),\n", " layer(perf_res_analytical,y=\"E_U\",x=\"I_aw\",ymin=\"E_U\",ymax=ones(size(perf_res_analytical,1))*maximum(perf_res_analytical[:E_U]),\n", " Geom.ribbon,BAtheme(default_color=colorant\"green\"))\n", " )" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": false }, "outputs": [ { "data": { "image/svg+xml": [
"(figure: rate-utility curve, I(A;W) [bits] vs E[U], mean sampling solutions as dots colored by Forgetting; title: Rate-Utility curve (dots show mean sampling solutions))\n" ], "text/html": [ "(same figure rendered as HTML)\n" ],
"text/plain": [ "Plot(...)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#compute the mean for each group of points (that were produced with the same beta and same forgetting setting)\n", "res_samp_aggregated = aggregate(res_sampled, [:β,:Forgetting], mean)\n", "\n", "#plot the mean-solutions from the sampling runs into the (analytical) rate-utility plot\n", "plt_samp = plot(Guide.ylabel(\"E[U]\"), Guide.xlabel(\"I(A;W) [bits]\"),\n", " Guide.title(\"Rate-Utility curve (dots show mean sampling solutions)\"), BAtheme(), BAdiscretecolorscale(2),\n", " layer(res_samp_aggregated,y=\"E_U_mean\",x=\"I_aw_mean\",Geom.point,Geom.line(preserve_order=true),color=\"Forgetting\"),\n", " layer(perf_res_analytical_samp,y=\"E_U\",x=\"I_aw\",Geom.point),\n", " layer(perf_res_analytical,y=\"E_U\",x=\"I_aw\",Geom.line),\n", " layer(perf_res_analytical,y=\"E_U\",x=\"I_aw\",ymin=\"E_U\",ymax=ones(size(perf_res_analytical,1))*maximum(perf_res_analytical[:E_U]),\n", " Geom.ribbon,BAtheme(default_color=colorant\"green\"))\n", " )" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#store the plots\n", "#draw(SVG(\"Figures/RateUtilityCurve.svg\", 8.5cm, 7cm), plot_rateutility)\n", "\n", "#draw(SVG(\"Figures/SamplingCond.svg\", 8.5cm, 9cm), plt_cond)\n", "#draw(SVG(\"Figures/BASampling.svg\", 13cm, 11cm), plt_samp)\n", "\n", "#plot_samp = vstack(plt_samp, plt_cond)\n", "#draw(SVG(\"Figures/SamplingResults.svg\", 18cm,16cm),plot_samp)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "#try changing the number of inner-/outer-loop iterations\n", "#try changing the burn-in ratio\n", "#try soft-forgetting (with exponential-decay window)\n", "\n", "#forgetting seems to do better than no forgetting (in terms of being closer to the rate-utility curve),\n", "#but it also seems that the points with forgetting tend to have a 
lower I(A;W) (and also a lower E[U]) - even though\n", "the temperatures are the same." ] } ], "metadata": { "kernelspec": { "display_name": "Julia 0.3.10", "language": "julia", "name": "julia-0.3" }, "language_info": { "file_extension": ".jl", "mimetype": "application/julia", "name": "julia", "version": "0.3.10" } }, "nbformat": 4, "nbformat_minor": 0 }
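The rejection-sampling step at the heart of the notebook - draw $y \sim p_0(y)$, accept with probability $\exp(\beta(U(y) - U_{max}))$ - can be sketched in a few lines. The snippet below is a minimal Python re-implementation for illustration only (the notebook's Julia `BAsampling` routine is the reference); the function name `sample_posterior` and the toy utility are hypothetical.

```python
import math
import random

def sample_posterior(y_support, p0, U, beta, rng):
    """Draw one sample from p(y) ∝ p0(y) * exp(beta * U(y)) via rejection sampling.

    Proposal: g(y) = p0(y) (the prior).  Bound: M = exp(beta * U_max), so the
    test u < f(y) / (M * g(y)) reduces to u < exp(beta * (U(y) - U_max)).
    """
    weights = [p0(y) for y in y_support]  # proposal weights = prior
    U_max = max(U(y) for y in y_support)
    while True:
        y = rng.choices(y_support, weights=weights)[0]  # y ~ p0
        u = rng.random()                                # u ~ Uniform(0, 1)
        if u < math.exp(beta * (U(y) - U_max)):
            return y  # accepted as a sample from the posterior p(y)

# toy problem: three actions with utilities 0, 1, 2 and a uniform prior
rng = random.Random(0)
ys = [0, 1, 2]
samples = [sample_posterior(ys, lambda y: 1 / 3, lambda y: float(y), 2.0, rng)
           for _ in range(5000)]
# as beta grows, accepted samples concentrate on the maximum-utility action
print(samples.count(2) / len(samples))
```

Note that the acceptance probability $\exp(\beta(U(y) - U_{max}))$ shrinks for sub-optimal $y$ as $\beta$ grows, so the expected number of rejections increases with the inverse temperature - which is why the notebook caps the inner loop with `max_rejsamp_steps`.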