{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Humans and animals integrate multisensory cues near-optimally\n", "## An intuition for how populations of neurons can perform Bayesian inference" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [], "source": [ "from __future__ import division\n", "import numpy as np\n", "from scipy.special import factorial\n", "import scipy.stats as stats\n", "import pylab\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import seaborn as sns\n", "sns.set_style(\"darkgrid\")\n", "import ipywidgets\n", "from IPython.display import display\n", "from matplotlib.font_manager import FontProperties\n", "fontP = FontProperties()\n", "fontP.set_size('medium')\n", "%config InlineBackend.figure_format = 'svg'\n", "\n", "\n", "def mean_firing_rate(gain, stimulus, preferred_stimulus, std_tc, baseline):\n", " # Gaussian tuning curve that determines the mean firing rate (Poisson rate parameter) for a given stimulus\n", " return baseline + gain*stats.norm.pdf(preferred_stimulus, loc = stimulus, scale = std_tc)\n", "\n", "def get_spikes(gain, stimulus, preferred_stimuli, std_tc, baseline):\n", " # draw one Poisson spike count per neuron in a population, given a stimulus\n", " lambdas = mean_firing_rate(gain, stimulus, preferred_stimuli, std_tc, baseline)\n", " return np.random.poisson(lambdas)\n", " \n", "def likelihood(stimulus, r, gain, preferred_stimuli, std_tc, baseline):\n", " # returns a value proportional to p(r|s) as a function of s: the Poisson factors\n", " # 1/r! (independent of s) and exp(-sum(lambdas)) (approximately independent of s for a\n", " # dense, evenly spaced population) are dropped; callers normalize over s afterwards\n", " lambdas = mean_firing_rate(gain, stimulus, preferred_stimuli, std_tc, baseline)\n", " return np.prod(lambdas**r)\n", "\n", "def spikes_and_inference(r_V = True,\n", " r_A = True,\n", " show_tuning_curves = False,\n", " show_spike_count = False,\n", " show_likelihoods = True,\n", " true_stimulus = 10,\n", " number_of_neurons = 40,\n", " r_V_gain = 15,\n", " r_A_gain = 75,\n", " r_V_tuning_curve_sigma = 10,\n", " r_A_tuning_curve_sigma = 10,\n", "
tuning_curve_baseline = 0,\n", " joint_likelihood = True,\n", " r_V_plus_r_A = True,\n", " cue = False):\n", " np.random.seed(7)\n", " max_s = 40\n", " preferred_stimuli = np.linspace(-max_s*2, max_s*2, number_of_neurons)\n", " n_hypothesized_s = 250\n", " hypothesized_s = np.linspace(-max_s, max_s, n_hypothesized_s)\n", " gains = {'r1': r_V_gain,\n", " 'r2': r_A_gain,\n", " 'r1+r2': r_V_gain + r_A_gain}\n", " # r1+r2 is treated as one Gaussian-tuned Poisson population with the summed gain; this is exact\n", " # only when both tuning-curve widths are equal and the baseline is zero (otherwise an approximation)\n", " sigma_TCs = {'r1': r_V_tuning_curve_sigma,\n", " 'r2': r_A_tuning_curve_sigma,\n", " 'r1+r2': (r_V_tuning_curve_sigma + r_A_tuning_curve_sigma)/2}\n", " spikes = {'r1': get_spikes(gains['r1'], true_stimulus, preferred_stimuli, sigma_TCs['r1'], tuning_curve_baseline),\n", " 'r2': get_spikes(gains['r2'], true_stimulus, preferred_stimuli, sigma_TCs['r2'], tuning_curve_baseline)}\n", " spikes['r1+r2'] = spikes['r1'] + spikes['r2']\n", " active_pops = []\n", " if r_V: active_pops.append('r1')\n", " if r_A: active_pops.append('r2')\n", " if r_V_plus_r_A: active_pops.append('r1+r2')\n", "\n", " colors = {'r1': sns.xkcd_rgb['light purple'],\n", " 'r2': sns.xkcd_rgb['dark pink'],\n", " 'r1+r2': sns.xkcd_rgb['royal blue'],\n", " 'joint': sns.xkcd_rgb['gold']}\n", " nSubplots = show_spike_count + show_tuning_curves + show_likelihoods\n", " fig, axes = plt.subplots(nSubplots, figsize = (7, 1.5*nSubplots)) # number of subplots according to what's been requested\n", " if not isinstance(axes, np.ndarray): axes = [axes] # makes axes into a list even if it's just one subplot\n", " subplot_idx = 0\n", " \n", " def plot_true_stimulus_and_legend(subplot_idx):\n", " axes[subplot_idx].plot(true_stimulus, 0, 'k^', markersize = 12, clip_on = False, label = 'true rattlesnake location')\n", " axes[subplot_idx].legend(loc = 'center left', bbox_to_anchor = (1, 0.5), prop = fontP)\n", " \n", " if show_tuning_curves:\n", " for neuron in range(number_of_neurons):\n", " if r_V:\n", " axes[subplot_idx].plot(hypothesized_s,\n", " mean_firing_rate(gains['r1'],\n", " hypothesized_s,\n", "
preferred_stimuli[neuron],\n", " sigma_TCs['r1'],\n", " tuning_curve_baseline),\n", " color = colors['r1'])\n", " if r_A:\n", " axes[subplot_idx].plot(hypothesized_s,\n", " mean_firing_rate(gains['r2'],\n", " hypothesized_s,\n", " preferred_stimuli[neuron],\n", " sigma_TCs['r2'],\n", " tuning_curve_baseline),\n", " color = colors['r2'])\n", " axes[subplot_idx].set_xlabel('location $s$')\n", " axes[subplot_idx].set_ylabel('mean firing rate\\n(spikes/s)')\n", " axes[subplot_idx].set_ylim((0, 4))\n", " axes[subplot_idx].set_xlim((-40, 40))\n", " axes[subplot_idx].set_yticks(np.linspace(0, 4, 5))\n", " subplot_idx += 1\n", "\n", " if show_spike_count:\n", " idx = abs(preferred_stimuli) < max_s\n", " if r_V:\n", " axes[subplot_idx].plot(preferred_stimuli[idx], spikes['r1'][idx], 'o', color = colors['r1'],\n", " clip_on = False, label = r'$\\mathbf{r}_\\mathrm{V}$',\n", " markersize=4)\n", " if r_A:\n", " axes[subplot_idx].plot(preferred_stimuli[idx], spikes['r2'][idx], 'o', color = colors['r2'],\n", " clip_on = False, label = r'$\\mathbf{r}_\\mathrm{A}$',\n", " markersize=4)\n", " if r_V_plus_r_A:\n", " axes[subplot_idx].plot(preferred_stimuli[idx], spikes['r1+r2'][idx], 'o', color = colors['r1+r2'],\n", " clip_on = False, label = r'$\\mathbf{r}_\\mathrm{V}+\\mathbf{r}_\\mathrm{A}$',\n", " markersize=8, zorder=1)\n", " axes[subplot_idx].set_xlabel('preferred location')\n", " axes[subplot_idx].set_ylabel('spike count')\n", " axes[subplot_idx].set_ylim((0, 10))\n", " axes[subplot_idx].set_xlim((-40, 40))\n", " plot_true_stimulus_and_legend(subplot_idx)\n", " subplot_idx += 1\n", "\n", " if show_likelihoods:\n", " if cue:\n", " var = 'c'\n", " else:\n", " var = r'\\mathbf{r}'\n", " likelihoods = {}\n", " \n", " for population in active_pops:\n", " likelihoods[population] = np.zeros_like(hypothesized_s)\n", " for idx, ort in enumerate(hypothesized_s):\n", " likelihoods[population][idx] = likelihood(ort, spikes[population], gains[population],\n", " preferred_stimuli,
sigma_TCs[population], tuning_curve_baseline)\n", " likelihoods[population] /= np.sum(likelihoods[population]) # normalize over hypothesized s\n", "\n", " if r_V:\n", " axes[subplot_idx].plot(hypothesized_s, likelihoods['r1'], color = colors['r1'],\n", " linewidth = 2, label = r'$p({}_\\mathrm{{V}}|s)$'.format(var))\n", " if r_A:\n", " axes[subplot_idx].plot(hypothesized_s, likelihoods['r2'], color = colors['r2'],\n", " linewidth = 2, label = r'$p({}_\\mathrm{{A}}|s)$'.format(var))\n", " if r_V_plus_r_A:\n", " axes[subplot_idx].plot(hypothesized_s, likelihoods['r1+r2'], color = colors['r1+r2'],\n", " linewidth = 2, label = r'$p({}_\\mathrm{{V}}+{}_\\mathrm{{A}}|s)$'.format(var, var))\n", " if joint_likelihood and r_V and r_A: # the product needs both unimodal likelihoods\n", " product = likelihoods['r1']*likelihoods['r2']\n", " product /= np.sum(product)\n", " axes[subplot_idx].plot(hypothesized_s, product, color = colors['joint'], linewidth = 7,\n", " label = r'$p({}_\\mathrm{{V}}|s)\\ p({}_\\mathrm{{A}}|s)$'.format(var, var), zorder = 1)\n", "\n", " axes[subplot_idx].set_xlabel('location $s$')\n", " axes[subplot_idx].set_ylabel('probability')\n", " axes[subplot_idx].set_xlim((-40, 40))\n", " axes[subplot_idx].set_yticks([])\n", " \n", " plot_true_stimulus_and_legend(subplot_idx) # also draws the legend\n", " subplot_idx += 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
We live in a complex environment and must constantly integrate sensory information to interact with the world around us. Inputs from different modalities are not always congruent with one another, yet correctly inferring the true nature of a stimulus from them can be a matter of life or death for an organism.
\n", "\n", "You hear and see evidence of a rattlesnake in tall grass near you, so you receive both an auditory and a visual cue to the snake's location $s$. Each cue is associated with a likelihood function, which gives the probability of observing that cue for every possible location of the snake. Because the tall grass partly hides the snake, the likelihood function for the visual cue, $p(c_\\mathrm{V}|s)$, is broad, reflecting high uncertainty. The sound is easier to localize, so the likelihood function for the auditory cue, $p(c_\\mathrm{A}|s)$, is sharper. By Bayes' rule, assuming a flat prior over the snake's location and that the two cues are conditionally independent given $s$, the posterior over the location is proportional to the product of the two likelihoods, $p(c_\\mathrm{V}|s)\\,p(c_\\mathrm{A}|s)$, and an optimal estimate of the snake's location can be read off from it. When both likelihoods are roughly Gaussian, this joint likelihood peaks between the two cues but closer to the more reliable (less uncertain) one, and it is narrower than either unimodal likelihood function.
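As a rough numerical check of these claims, here is a minimal Python sketch. It multiplies two Gaussian likelihoods on a grid of candidate locations and compares the result with the closed-form product of two Gaussians. In that closed form the combined mean is a precision-weighted mean, and the combined variance is the inverse of the summed precisions. The Gaussian shape and the values of `mu_V`, `sigma_V`, `mu_A` and `sigma_A` are illustrative assumptions, not parameters of the population-coding simulation in this notebook. The flat prior and the conditional independence of the cues are assumed as in the text above.

```python
import numpy as np

# Illustrative (assumed) cue parameters: the visual cue is broad because of
# the tall grass, the auditory cue is sharper.
mu_V, sigma_V = 14.0, 8.0
mu_A, sigma_A = 8.0, 3.0

# Grid of hypothesized snake locations s
s = np.linspace(-40, 40, 2001)

def gaussian_likelihood(x, mu, sigma):
    # Unnormalized Gaussian; normalizing on the grid below makes the constant irrelevant
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2)

like_V = gaussian_likelihood(s, mu_V, sigma_V)
like_A = gaussian_likelihood(s, mu_A, sigma_A)

# Flat prior and conditionally independent cues: posterior is proportional to the product
posterior = like_V * like_A
posterior /= posterior.sum()

# Mean and standard deviation of the grid-based posterior
post_mean = np.sum(s * posterior)
post_sd = np.sqrt(np.sum((s - post_mean) ** 2 * posterior))

# Closed form for a product of two Gaussians: weights are the precisions 1/sigma^2
w_V = 1.0 / sigma_V ** 2
w_A = 1.0 / sigma_A ** 2
mu_combined = (w_V * mu_V + w_A * mu_A) / (w_V + w_A)
sigma_combined = np.sqrt(1.0 / (w_V + w_A))

print(post_mean, mu_combined)   # both about 8.74
print(post_sd, sigma_combined)  # both about 2.81
```

The grid-based posterior and the closed form should agree to several decimal places. The combined mean is 638/73, about 8.74, and the combined standard deviation is 24/sqrt(73), about 2.81. The combined estimate lies closer to the sharper auditory cue, and its width is smaller than both `sigma_V` and `sigma_A`.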
" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n" ], "text/plain": [ "