{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/EugOT/compneuro/blob/main/notebooks/2023-11-26-BiologicalNeuronModels-2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a> &nbsp; <a href=\"https://kaggle.com/kernels/welcome?src=https://raw.githubusercontent.com/EugOT/compneuro/main/notebooks/2023-11-26-BiologicalNeuronModels-2.ipynb\" target=\"_parent\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" alt=\"Open in Kaggle\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "_KpEXqXDlvMG"
      },
      "source": [
        "# Tutorial 2: Effects of Input Correlation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "GPTkzqA7lvMG"
      },
      "source": [
        "---\n",
        "# Tutorial Objectives\n",
        "\n",
        "In this tutorial, we will use the leaky integrate-and-fire (LIF) neuron model (see Tutorial 1) to study how they transform input correlations to output properties (transfer of correlations). In particular, we are going to write a few lines of code to:\n",
        "\n",
        "- inject correlated GWN in a pair of neurons\n",
        "\n",
        "- measure correlations between the spiking activity of the two neurons\n",
        "\n",
        "- study how the transfer of correlation depends on the statistics of the input, i.e. mean and standard deviation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "AJwIy4h4lvMG"
      },
      "source": [
        "---\n",
        "# Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "execution": {},
        "id": "bnorJ1O3lvMG"
      },
      "outputs": [],
      "source": [
        "# Imports\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import time"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "85gQjgd_lvMH"
      },
      "outputs": [],
      "source": [
        "# @title Figure Settings\n",
        "import logging\n",
        "logging.getLogger('matplotlib.font_manager').disabled = True\n",
        "\n",
        "import ipywidgets as widgets  # interactive display\n",
        "%config InlineBackend.figure_format = 'retina'\n",
        "# use NMA plot style\n",
        "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle\")\n",
        "my_layout = widgets.Layout()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "OsIBHS7jlvMH"
      },
      "outputs": [],
      "source": [
        "# @title Plotting Functions\n",
        "\n",
        "def example_plot_myCC():\n",
        "  pars = default_pars(T=50000, dt=.1)\n",
        "\n",
        "  c = np.arange(10) * 0.1\n",
        "  r12 = np.zeros(10)\n",
        "  for i in range(10):\n",
        "    I1gL, I2gL = correlate_input(pars, mu=20.0, sig=7.5, c=c[i])\n",
        "    r12[i] = my_CC(I1gL, I2gL)\n",
        "\n",
        "  plt.figure()\n",
        "  plt.plot(c, r12, 'bo', alpha=0.7, label='Simulation', zorder=2)\n",
        "  plt.plot([-0.05, 0.95], [-0.05, 0.95], 'k--', label='y=x',\n",
        "           dashes=(2, 2), zorder=1)\n",
        "  plt.xlabel('True CC')\n",
        "  plt.ylabel('Sample CC')\n",
        "  plt.legend(loc='best')\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def my_raster_Poisson(range_t, spike_train, n):\n",
        "  \"\"\"\n",
        "  Ffunction generates and plots the raster of the Poisson spike train\n",
        "\n",
        "  Args:\n",
        "    range_t     : time sequence\n",
        "    spike_train : binary spike trains, with shape (N, Lt)\n",
        "    n           : number of Poisson trains plot\n",
        "\n",
        "  Returns:\n",
        "    Raster plot of the spike train\n",
        "  \"\"\"\n",
        "\n",
        "  # find the number of all the spike trains\n",
        "  N = spike_train.shape[0]\n",
        "\n",
        "  # n should smaller than N:\n",
        "  if n > N:\n",
        "    print('The number n exceeds the size of spike trains')\n",
        "    print('The number n is set to be the size of spike trains')\n",
        "    n = N\n",
        "\n",
        "  # plot rater\n",
        "  plt.figure()\n",
        "  i = 0\n",
        "  while i < n:\n",
        "    if spike_train[i, :].sum() > 0.:\n",
        "      t_sp = range_t[spike_train[i, :] > 0.5]  # spike times\n",
        "      plt.plot(t_sp, i * np.ones(len(t_sp)), 'k|', ms=10, markeredgewidth=2)\n",
        "    i += 1\n",
        "  plt.xlim([range_t[0], range_t[-1]])\n",
        "  plt.ylim([-0.5, n + 0.5])\n",
        "  plt.xlabel('Time (ms)', fontsize=12)\n",
        "  plt.ylabel('Neuron ID', fontsize=12)\n",
        "  plt.show()\n",
        "\n",
        "def plot_c_r_LIF(c, r, mycolor, mylabel):\n",
        "  z = np.polyfit(c, r, deg=1)\n",
        "  c_range = np.array([c.min() - 0.05, c.max() + 0.05])\n",
        "  plt.plot(c, r, 'o', color=mycolor, alpha=0.7, label=mylabel, zorder=2)\n",
        "  plt.plot(c_range, z[0] * c_range + z[1], color=mycolor, zorder=1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "tAhn69xilvMH"
      },
      "outputs": [],
      "source": [
        "# @title Helper Functions\n",
        "def default_pars(**kwargs):\n",
        "  pars = {}\n",
        "\n",
        "  ### typical neuron parameters###\n",
        "  pars['V_th'] = -55.     # spike threshold [mV]\n",
        "  pars['V_reset'] = -75.  # reset potential [mV]\n",
        "  pars['tau_m'] = 10.     # membrane time constant [ms]\n",
        "  pars['g_L'] = 10.       # leak conductance [nS]\n",
        "  pars['V_init'] = -75.   # initial potential [mV]\n",
        "  pars['V_L'] = -75.      # leak reversal potential [mV]\n",
        "  pars['tref'] = 2.       # refractory time (ms)\n",
        "\n",
        "  ### simulation parameters ###\n",
        "  pars['T'] = 400. # Total duration of simulation [ms]\n",
        "  pars['dt'] = .1  # Simulation time step [ms]\n",
        "\n",
        "  ### external parameters if any ###\n",
        "  for k in kwargs:\n",
        "    pars[k] = kwargs[k]\n",
        "\n",
        "  pars['range_t'] = np.arange(0, pars['T'], pars['dt'])  # Vector of discretized\n",
        "                                                         # time points [ms]\n",
        "  return pars\n",
        "\n",
        "\n",
        "def run_LIF(pars, Iinj):\n",
        "  \"\"\"\n",
        "  Simulate the LIF dynamics with external input current\n",
        "\n",
        "  Args:\n",
        "    pars       : parameter dictionary\n",
        "    Iinj       : input current [pA]. The injected current here can be a value or an array\n",
        "\n",
        "  Returns:\n",
        "    rec_spikes : spike times\n",
        "    rec_v      : mebrane potential\n",
        "  \"\"\"\n",
        "\n",
        "  # Set parameters\n",
        "  V_th, V_reset = pars['V_th'], pars['V_reset']\n",
        "  tau_m, g_L = pars['tau_m'], pars['g_L']\n",
        "  V_init, V_L = pars['V_init'], pars['V_L']\n",
        "  dt, range_t = pars['dt'], pars['range_t']\n",
        "  Lt = range_t.size\n",
        "  tref = pars['tref']\n",
        "\n",
        "  # Initialize voltage and current\n",
        "  v = np.zeros(Lt)\n",
        "  v[0] = V_init\n",
        "  Iinj = Iinj * np.ones(Lt)\n",
        "  tr = 0.\n",
        "\n",
        "  # simulate the LIF dynamics\n",
        "  rec_spikes = []   # record spike times\n",
        "  for it in range(Lt - 1):\n",
        "    if tr > 0:\n",
        "      v[it] = V_reset\n",
        "      tr = tr - 1\n",
        "    elif v[it] >= V_th:  # reset voltage and record spike event\n",
        "      rec_spikes.append(it)\n",
        "      v[it] = V_reset\n",
        "      tr = tref / dt\n",
        "\n",
        "    # calculate the increment of the membrane potential\n",
        "    dv = (-(v[it] - V_L) + Iinj[it] / g_L) * (dt / tau_m)\n",
        "\n",
        "    # update the membrane potential\n",
        "    v[it + 1] = v[it] + dv\n",
        "\n",
        "  rec_spikes = np.array(rec_spikes) * dt\n",
        "\n",
        "  return v, rec_spikes\n",
        "\n",
        "\n",
        "def my_GWN(pars, sig, myseed=False):\n",
        "  \"\"\"\n",
        "  Function that calculates Gaussian white noise inputs\n",
        "\n",
        "  Args:\n",
        "    pars       : parameter dictionary\n",
        "    mu         : noise baseline (mean)\n",
        "    sig        : noise amplitute (standard deviation)\n",
        "    myseed     : random seed. int or boolean\n",
        "                 the same seed will give the same random number sequence\n",
        "\n",
        "  Returns:\n",
        "    I          : Gaussian white noise input\n",
        "  \"\"\"\n",
        "\n",
        "  # Retrieve simulation parameters\n",
        "  dt, range_t = pars['dt'], pars['range_t']\n",
        "  Lt = range_t.size\n",
        "\n",
        "  # Set random seed. You can fix the seed of the random number generator so\n",
        "  # that the results are reliable however, when you want to generate multiple\n",
        "  # realization make sure that you change the seed for each new realization\n",
        "  if myseed:\n",
        "      np.random.seed(seed=myseed)\n",
        "  else:\n",
        "      np.random.seed()\n",
        "\n",
        "  # generate GWN\n",
        "  # we divide here by 1000 to convert units to sec.\n",
        "  I_GWN = sig * np.random.randn(Lt) * np.sqrt(pars['tau_m'] / dt)\n",
        "\n",
        "  return I_GWN\n",
        "\n",
        "\n",
        "def LIF_output_cc(pars, mu, sig, c, bin_size, n_trials=20):\n",
        "  \"\"\" Simulates two LIF neurons with correlated input and computes output correlation\n",
        "\n",
        "  Args:\n",
        "  pars       : parameter dictionary\n",
        "  mu         : noise baseline (mean)\n",
        "  sig        : noise amplitute (standard deviation)\n",
        "  c          : correlation coefficient ~[0, 1]\n",
        "  bin_size   : bin size used for time series\n",
        "  n_trials   : total simulation trials\n",
        "\n",
        "  Returns:\n",
        "  r          : output corr. coe.\n",
        "  sp_rate    : spike rate\n",
        "  sp1        : spike times of neuron 1 in the last trial\n",
        "  sp2        : spike times of neuron 2 in the last trial\n",
        "  \"\"\"\n",
        "\n",
        "  r12 = np.zeros(n_trials)\n",
        "  sp_rate = np.zeros(n_trials)\n",
        "  for i_trial in range(n_trials):\n",
        "    I1gL, I2gL = correlate_input(pars, mu, sig, c)\n",
        "    _, sp1 = run_LIF(pars, pars['g_L'] * I1gL)\n",
        "    _, sp2 = run_LIF(pars, pars['g_L'] * I2gL)\n",
        "\n",
        "    my_bin = np.arange(0, pars['T'], bin_size)\n",
        "\n",
        "    sp1_count, _ = np.histogram(sp1, bins=my_bin)\n",
        "    sp2_count, _ = np.histogram(sp2, bins=my_bin)\n",
        "\n",
        "    r12[i_trial] = my_CC(sp1_count[::20], sp2_count[::20])\n",
        "    sp_rate[i_trial] = len(sp1) / pars['T'] * 1000.\n",
        "\n",
        "  return r12.mean(), sp_rate.mean(), sp1, sp2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "2hMtX64jlvMI"
      },
      "source": [
        "The helper function contains the:\n",
        "\n",
        "- Parameter dictionary: `default_pars( **kwargs)` from Tutorial 1\n",
        "- LIF simulator: `run_LIF` from Tutorial 1\n",
        "- Gaussian white noise generator: `my_GWN(pars, sig, myseed=False)` from Tutorial 1\n",
        "- Poisson type spike train generator: `Poisson_generator(pars, rate, n, myseed=False)`\n",
        "- Two LIF neurons with correlated inputs simulator: `LIF_output_cc(pars, mu, sig, c, bin_size, n_trials=20)`\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "3xB0MTXhlvMI"
      },
      "source": [
        "---\n",
        "# Section 1: Correlations (Synchrony)\n",
        "Correlation or synchrony in neuronal activity can be described for any readout of brain activity. Here, we are concerned with the spiking activity of neurons.\n",
        "\n",
        "In the simplest way, correlation/synchrony refers to coincident spiking of neurons, i.e., when two neurons spike together, they are firing in **synchrony** or are **correlated**. Neurons can be synchronous in their instantaneous activity, i.e., they spike together with some probability. However, it is also possible that spiking of a neuron at time $t$ is correlated with the spikes of another neuron with a delay (time-delayed synchrony).\n",
        "\n",
        "## Origin of synchronous neuronal activity:\n",
        "- Common inputs, i.e., two neurons are receiving input from the same sources. The degree of correlation of the shared inputs is proportional to their output correlation.\n",
        "- Pooling from the same sources. Neurons do not share the same input neurons but are receiving inputs from neurons which themselves are correlated.\n",
        "- Neurons are connected to each other (uni- or bi-directionally): This will only give rise to time-delayed synchrony. Neurons could also be connected via gap-junctions.\n",
        "- Neurons have similar parameters and initial conditions.\n",
        "\n",
        "## Implications of synchrony\n",
        "When neurons spike together, they can have a stronger impact on downstream neurons. Synapses in the brain are sensitive to the temporal correlations (i.e., delay) between pre- and postsynaptic activity, and this, in turn, can lead to the formation of functional neuronal networks - the basis of unsupervised learning (we will study some of these concepts in a forthcoming tutorial).\n",
        "\n",
        "Synchrony implies a reduction in the dimensionality of the system. In addition, correlations, in many cases, can impair the decoding of neuronal activity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "wfZKLtKZlvMI"
      },
      "outputs": [],
      "source": [
        "# @title Video 1: Input & output correlations\n",
        "from ipywidgets import widgets\n",
        "from IPython.display import YouTubeVideo\n",
        "from IPython.display import IFrame\n",
        "from IPython.display import display\n",
        "\n",
        "\n",
        "class PlayVideo(IFrame):\n",
        "  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
        "    self.id = id\n",
        "    if source == 'Bilibili':\n",
        "      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
        "    elif source == 'Osf':\n",
        "      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
        "    super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
        "\n",
        "\n",
        "def display_videos(video_ids, W=400, H=300, fs=1):\n",
        "  tab_contents = []\n",
        "  for i, video_id in enumerate(video_ids):\n",
        "    out = widgets.Output()\n",
        "    with out:\n",
        "      if video_ids[i][0] == 'Youtube':\n",
        "        video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
        "                             height=H, fs=fs, rel=0)\n",
        "        print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
        "      else:\n",
        "        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
        "                          height=H, fs=fs, autoplay=False)\n",
        "        if video_ids[i][0] == 'Bilibili':\n",
        "          print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
        "        elif video_ids[i][0] == 'Osf':\n",
        "          print(f'Video available at https://osf.io/{video.id}')\n",
        "      display(video)\n",
        "    tab_contents.append(out)\n",
        "  return tab_contents\n",
        "\n",
        "\n",
        "video_ids = [('Youtube', 'nsAYFBcAkes'), ('Bilibili', 'BV1Bh411o7eV')]\n",
        "tab_contents = display_videos(video_ids, W=854, H=480)\n",
        "tabs = widgets.Tab()\n",
        "tabs.children = tab_contents\n",
        "for i in range(len(tab_contents)):\n",
        "  tabs.set_title(i, video_ids[i][0])\n",
        "display(tabs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "RYltMoyDlvMI"
      },
      "source": [
        "A simple model to study the emergence of correlations is to inject common inputs to a pair of neurons and measure the output correlation as a function of the fraction of common inputs.\n",
        "\n",
        "Here, we are going to investigate the transfer of correlations by computing the correlation coefficient of spike trains recorded from two unconnected LIF neurons, which received correlated inputs.\n",
        "\n",
        "\n",
        "The input current to LIF neuron $i$ $(i=1,2)$ is:\n",
        "\n",
        "\\begin{equation}\n",
        "\\frac{I_i}{g_L} = \\mu_i + \\sigma_i (\\sqrt{1-c}\\xi_i + \\sqrt{c}\\xi_c) \\quad (1)\n",
        "\\end{equation}\n",
        "\n",
        "where $\\mu_i$ is the temporal average of the current. The Gaussian white noise $\\xi_i$ is independent for each neuron, while $\\xi_c$ is common to all neurons. The variable $c$ ($0\\le c\\le1$) controls the fraction of common and independent inputs. $\\sigma_i$ shows the variance of the total input.\n",
        "\n",
        "So, first, we will generate correlated inputs."
      ]
    },
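    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {}
      },
      "source": [
        "Why does Equation (1) yield inputs with correlation coefficient $c$? A brief sketch (treating $\\xi_1$, $\\xi_2$, and $\\xi_c$ as independent, zero-mean, unit-variance noise terms):\n",
        "\n",
        "\\begin{align}\n",
        "\\text{Var}\\left(\\frac{I_i}{g_L}\\right) &= \\sigma_i^2\\left[(1-c) + c\\right] = \\sigma_i^2 \\\\\n",
        "\\text{Cov}\\left(\\frac{I_1}{g_L}, \\frac{I_2}{g_L}\\right) &= \\sigma_1 \\sigma_2\\, c \\\\\n",
        "r_{12} &= \\frac{\\sigma_1 \\sigma_2\\, c}{\\sigma_1 \\sigma_2} = c\n",
        "\\end{align}\n",
        "\n",
        "So the mixing weights $\\sqrt{1-c}$ and $\\sqrt{c}$ keep the total input variance fixed while setting the input correlation to exactly $c$."
      ]
    },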
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "gI6qRA3QlvMI"
      },
      "outputs": [],
      "source": [
        "# @markdown Execute this cell to get a function `correlate_input` for generating correlated GWN inputs\n",
        "def correlate_input(pars, mu=20., sig=7.5, c=0.3):\n",
        "  \"\"\"\n",
        "  Args:\n",
        "    pars       : parameter dictionary\n",
        "    mu         : noise baseline (mean)\n",
        "    sig        : noise amplitute (standard deviation)\n",
        "    c.         : correlation coefficient ~[0, 1]\n",
        "\n",
        "  Returns:\n",
        "    I1gL, I2gL : two correlated inputs with corr. coe. c\n",
        "  \"\"\"\n",
        "\n",
        "  # generate Gaussian whute noise xi_1, xi_2, xi_c\n",
        "  xi_1 = my_GWN(pars, sig)\n",
        "  xi_2 = my_GWN(pars, sig)\n",
        "  xi_c = my_GWN(pars, sig)\n",
        "\n",
        "  # Generate two correlated inputs by Equation. (1)\n",
        "  I1gL = mu + np.sqrt(1. - c) * xi_1 + np.sqrt(c) * xi_c\n",
        "  I2gL = mu + np.sqrt(1. - c) * xi_2 + np.sqrt(c) * xi_c\n",
        "\n",
        "  return I1gL, I2gL\n",
        "\n",
        "help(correlate_input)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "tmASEvkllvMI"
      },
      "source": [
        "## Coding Exercise 1A: Compute the correlation\n",
        "\n",
        "The _sample correlation coefficient_ between two input currents $I_i$ and $I_j$ is defined as the sample covariance of $I_i$ and $I_j$ divided by the square root of the sample variance of $I_i$ multiplied with the square root of the sample variance of $I_j$. In equation form:\n",
        "\n",
        "\\begin{align}\n",
        "r_{ij} &= \\frac{cov(I_i, I_j)}{\\sqrt{var(I_i)} \\sqrt{var(I_j)}}\\\\\n",
        "cov(I_i, I_j) &= \\sum_{k=1}^L (I_i^k -\\bar{I_i})(I_j^k -\\bar{I_j}) \\\\\n",
        "var(I_i) &= \\sum_{k=1}^L (I_i^k -\\bar{I}_i)^2\n",
        "\\end{align}\n",
        "\n",
        "where $\\bar{I_i}$ is the sample mean, $k$ is the time bin, and $L$ is the length of $I$. This means that $I_i^k$ is current $i$ at time $k\\cdot dt$.\n",
        "\n",
        "<br>\n",
        "\n",
        "**Important note:** The equations above are not accurate for sample covariances and variances as they should be additionally divided by $L-1$. We have dropped this term because it cancels out in the sample correlation coefficient formula.\n",
        "\n",
        "<br>\n",
        "\n",
        "The _sample correlation coefficient_ may also be referred to as the _sample Pearson correlation coefficient_. Here, is a beautiful paper that explains multiple ways to calculate and understand correlations [Rodgers and Nicewander 1988](https://www.stat.berkeley.edu/~rabbee/correlation.pdf).\n",
        "\n",
        "In this exercise, we will create a function, `my_CC` to compute the sample correlation coefficient between two time series. Note that while we introduced this computation here in the context of input currents, the sample correlation coefficient is used to compute the correlation between any two time series - we will use it later on binned spike trains.\n",
        "\n",
        "We then check our method is accurate by generating currents with a certain correlation (using `correlate_input`), computing the correlation coefficient using `my_CC`, and plotting the true vs sample correlation coefficients."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {},
        "id": "2aTOgyL2lvMJ"
      },
      "outputs": [],
      "source": [
        "def my_CC(i, j):\n",
        "  \"\"\"\n",
        "  Args:\n",
        "    i, j  : two time series with the same length\n",
        "\n",
        "  Returns:\n",
        "    rij   : correlation coefficient\n",
        "  \"\"\"\n",
        "  ########################################################################\n",
        "  ## TODO for students: compute rxy, then remove the NotImplementedError #\n",
        "  # Tip1: array([a1, a2, a3])*array([b1, b2, b3]) = array([a1*b1, a2*b2, a3*b3])\n",
        "  # Tip2: np.sum(array([a1, a2, a3])) = a1+a2+a3\n",
        "  # Tip3: square root, np.sqrt()\n",
        "  # Fill out function and remove\n",
        "  raise NotImplementedError(\"Student exercise: compute the sample correlation coefficient\")\n",
        "  ########################################################################\n",
        "\n",
        "  # Calculate the covariance of i and j\n",
        "  cov = ...\n",
        "\n",
        "  # Calculate the variance of i\n",
        "  var_i = ...\n",
        "\n",
        "  # Calculate the variance of j\n",
        "  var_j = ...\n",
        "\n",
        "  # Calculate the correlation coefficient\n",
        "  rij = ...\n",
        "\n",
        "  return rij\n",
        "\n",
        "\n",
        "example_plot_myCC()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "v29xwxRtlvMJ"
      },
      "source": [
        "The sample correlation coefficients (computed using `my_CC`) match the ground truth correlation coefficient!"
      ]
    },
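    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {}
      },
      "source": [
        "If you would like an independent sanity check of your `my_CC` implementation, the optional sketch below (an addition to the tutorial, assuming you have completed `my_CC` above) compares it against NumPy's built-in `np.corrcoef` on one pair of correlated inputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {}
      },
      "outputs": [],
      "source": [
        "# Optional sanity check: compare my_CC against np.corrcoef\n",
        "# (assumes my_CC has been completed in the exercise above)\n",
        "pars_check = default_pars(T=10000, dt=.1)\n",
        "I1gL, I2gL = correlate_input(pars_check, mu=20.0, sig=7.5, c=0.4)\n",
        "\n",
        "r_mine = my_CC(I1gL, I2gL)\n",
        "r_numpy = np.corrcoef(I1gL, I2gL)[0, 1]  # off-diagonal entry of the 2x2 matrix\n",
        "\n",
        "print(f'my_CC      : {r_mine:.4f}')\n",
        "print(f'np.corrcoef: {r_numpy:.4f}')"
      ]
    },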
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "WWFWZusXlvMJ"
      },
      "source": [
        "In the next exercise, we will use the Poisson distribution to model spike trains. Remember that you have seen the Poisson distribution used in this way in the [pre-reqs math day on Statistics](https://compneuro.neuromatch.io/tutorials/W0D5_Statistics/student/W0D5_Tutorial1.html#section-2-2-poisson-distribution). Remember that a Poisson spike train has the following properties:\n",
        "- The ratio of the mean and variance of spike count is 1\n",
        "- Inter-spike-intervals are exponentially distributed\n",
        "- Spike times are irregular i.e. š¶š‘‰ISI=1\n",
        "- Adjacent spike intervals are independent of each other.\n",
        "\n",
        "In the following cell, we provide a helper function `Poisson_generator` and then use it to produce a Poisson spike train."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "lTY683XplvMJ"
      },
      "outputs": [],
      "source": [
        "# @markdown Execute this cell to get helper function `Poisson_generator`\n",
        "def Poisson_generator(pars, rate, n, myseed=False):\n",
        "  \"\"\"\n",
        "  Generates poisson trains\n",
        "\n",
        "  Args:\n",
        "    pars       : parameter dictionary\n",
        "    rate       : noise amplitute [Hz]\n",
        "    n          : number of Poisson trains\n",
        "    myseed     : random seed. int or boolean\n",
        "\n",
        "  Returns:\n",
        "    pre_spike_train : spike train matrix, ith row represents whether\n",
        "                      there is a spike in ith spike train over time\n",
        "                      (1 if spike, 0 otherwise)\n",
        "  \"\"\"\n",
        "\n",
        "  # Retrieve simulation parameters\n",
        "  dt, range_t = pars['dt'], pars['range_t']\n",
        "  Lt = range_t.size\n",
        "\n",
        "  # set random seed\n",
        "  if myseed:\n",
        "      np.random.seed(seed=myseed)\n",
        "  else:\n",
        "      np.random.seed()\n",
        "\n",
        "  # generate uniformly distributed random variables\n",
        "  u_rand = np.random.rand(n, Lt)\n",
        "\n",
        "  # generate Poisson train\n",
        "  poisson_train = 1. * (u_rand < rate * (dt / 1000.))\n",
        "\n",
        "  return poisson_train\n",
        "\n",
        "help(Poisson_generator)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "7pSDW03elvMJ"
      },
      "outputs": [],
      "source": [
        "# @markdown Execute this cell to visualize Poisson spike train\n",
        "\n",
        "pars = default_pars()\n",
        "pre_spike_train = Poisson_generator(pars, rate=10, n=100, myseed=2020)\n",
        "my_raster_Poisson(pars['range_t'], pre_spike_train, 100)"
      ]
    },
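    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {}
      },
      "source": [
        "As an optional check of the Poisson properties listed above (this sketch is an addition to the tutorial), we can verify that the spike-count Fano factor (variance/mean) and the coefficient of variation of the inter-spike intervals are both close to 1 for a long Poisson train."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {}
      },
      "outputs": [],
      "source": [
        "# Optional sketch: check Fano factor and CV_ISI of a Poisson spike train\n",
        "pars_poi = default_pars(T=100000)  # long simulation for better statistics\n",
        "spike_train = Poisson_generator(pars_poi, rate=10, n=1, myseed=2020)[0]\n",
        "spike_times = pars_poi['range_t'][spike_train > 0.5]\n",
        "\n",
        "# Fano factor: variance / mean of spike counts in fixed time bins\n",
        "bin_size = 100.  # ms\n",
        "bins = np.arange(0, pars_poi['T'] + bin_size, bin_size)\n",
        "counts, _ = np.histogram(spike_times, bins=bins)\n",
        "fano = counts.var() / counts.mean()\n",
        "\n",
        "# CV of inter-spike intervals: std / mean of the ISIs\n",
        "isi = np.diff(spike_times)\n",
        "cv_isi = isi.std() / isi.mean()\n",
        "\n",
        "print(f'Fano factor = {fano:.3f} (expected ~1)')\n",
        "print(f'CV_ISI      = {cv_isi:.3f} (expected ~1)')"
      ]
    },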
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "85rvvIfulvMJ"
      },
      "source": [
        "## Coding Exercise 1B: Measure the correlation between spike trains\n",
        "\n",
        "After recording the spike times of the two neurons, how can we estimate their correlation coefficient?\n",
        "\n",
        "In order to find this, we need to bin the spike times and obtain two time series. Each data point in the time series is the number of spikes in the corresponding time bin. You can use `np.histogram()` to bin the spike times.\n",
        "\n",
        "Complete the code below to bin the spike times and calculate the correlation coefficient for two Poisson spike trains. Note that `c` here is the ground-truth correlation coefficient that we define.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "KU5u1MnnlvMJ"
      },
      "outputs": [],
      "source": [
        "# @markdown Execute this cell to get a function for generating correlated Poisson inputs (`generate_corr_Poisson`)\n",
        "\n",
        "\n",
        "def generate_corr_Poisson(pars, poi_rate, c, myseed=False):\n",
        "  \"\"\"\n",
        "  function to generate correlated Poisson type spike trains\n",
        "  Args:\n",
        "    pars       : parameter dictionary\n",
        "    poi_rate   : rate of the Poisson train\n",
        "    c.         : correlation coefficient ~[0, 1]\n",
        "\n",
        "  Returns:\n",
        "    sp1, sp2   : two correlated spike time trains with corr. coe. c\n",
        "  \"\"\"\n",
        "\n",
        "  range_t = pars['range_t']\n",
        "\n",
        "  mother_rate = poi_rate / c\n",
        "  mother_spike_train = Poisson_generator(pars, rate=mother_rate,\n",
        "                                         n=1, myseed=myseed)[0]\n",
        "  sp_mother = range_t[mother_spike_train > 0]\n",
        "\n",
        "  L_sp_mother = len(sp_mother)\n",
        "  sp_mother_id = np.arange(L_sp_mother)\n",
        "  L_sp_corr = int(L_sp_mother * c)\n",
        "\n",
        "  np.random.shuffle(sp_mother_id)\n",
        "  sp1 = np.sort(sp_mother[sp_mother_id[:L_sp_corr]])\n",
        "\n",
        "  np.random.shuffle(sp_mother_id)\n",
        "  sp2 = np.sort(sp_mother[sp_mother_id[:L_sp_corr]])\n",
        "\n",
        "  return sp1, sp2\n",
        "\n",
        "print(help(generate_corr_Poisson))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {},
        "id": "Nyr_WdDGlvMJ"
      },
      "outputs": [],
      "source": [
        "def corr_coeff_pairs(pars, rate, c, trials, bins):\n",
        "  \"\"\"\n",
        "  Calculate the correlation coefficient of two spike trains, for different\n",
        "  realizations\n",
        "\n",
        "  Args:\n",
        "      pars   : parameter dictionary\n",
        "      rate   : rate of poisson inputs\n",
        "      c      : correlation coefficient ~ [0, 1]\n",
        "      trials : number of realizations\n",
        "      bins   : vector with bins for time discretization\n",
        "\n",
        "  Returns:\n",
        "    r12      : correlation coefficient of a pair of inputs\n",
        "  \"\"\"\n",
        "\n",
        "  r12 = np.zeros(trials)\n",
        "\n",
        "  for i in range(trials):\n",
        "    ##############################################################\n",
        "    ## TODO for students\n",
        "    # Note that you can run multiple realizations and compute their r_12(diff_trials)\n",
        "    # with the defined function above. The average r_12 over trials can get close to c.\n",
        "    # Note: change seed to generate different input per trial\n",
        "    # Fill out function and remove\n",
        "    raise NotImplementedError(\"Student exercise: compute the correlation coefficient\")\n",
        "    ##############################################################\n",
        "\n",
        "    # Generate correlated Poisson inputs\n",
        "    sp1, sp2 = generate_corr_Poisson(pars, ..., ..., myseed=2020+i)\n",
        "\n",
        "    # Bin the spike times of the first input\n",
        "    sp1_count, _ = np.histogram(..., bins=...)\n",
        "\n",
        "    # Bin the spike times of the second input\n",
        "    sp2_count, _ = np.histogram(..., bins=...)\n",
        "\n",
        "    # Calculate the correlation coefficient\n",
        "    r12[i] = my_CC(..., ...)\n",
        "\n",
        "  return r12\n",
        "\n",
        "\n",
        "poi_rate = 20.\n",
        "c = 0.2  # set true correlation\n",
        "pars = default_pars(T=10000)\n",
        "\n",
        "# bin the spike time\n",
        "bin_size = 20  # [ms]\n",
        "my_bin = np.arange(0, pars['T'], bin_size)\n",
        "n_trials = 100  # 100 realizations\n",
        "\n",
        "r12 = corr_coeff_pairs(pars, rate=poi_rate, c=c, trials=n_trials, bins=my_bin)\n",
        "print(f'True corr coe = {c:.3f}')\n",
        "print(f'Simu corr coe = {r12.mean():.3f}')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "lDoQsLOYlvMJ"
      },
      "source": [
        "Sample output\n",
        "\n",
        "```\n",
        "True corr coe = 0.200\n",
        "Simu corr coe = 0.197\n",
        "```"
      ]
    },
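    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {}
      },
      "source": [
        "As a side note, here is a minimal, self-contained illustration of the binning step used in the exercise above (and again in the next section). The spike times below are made up purely for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {}
      },
      "outputs": [],
      "source": [
        "# Minimal illustration of binning spike times with np.histogram\n",
        "# (the spike times here are made up for illustration)\n",
        "toy_spike_times = np.array([3., 12., 14., 27., 45., 48., 49.])  # ms\n",
        "toy_bins = np.arange(0, 60, 10)  # bin edges at 0, 10, ..., 50 ms\n",
        "\n",
        "toy_counts, toy_edges = np.histogram(toy_spike_times, bins=toy_bins)\n",
        "print('bin edges  :', toy_edges)\n",
        "print('spike count:', toy_counts)  # number of spikes in each 10-ms bin"
      ]
    },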
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "PGoi_ApPlvMJ"
      },
      "source": [
        "---\n",
        "# Section 2: Investigate the effect of input correlation on the output correlation\n",
        "\n",
        "\n",
        "Now let's combine the aforementioned two procedures. We first generate the correlated inputs. Then we inject the correlated inputs $I_1, I_2$ into a pair of neurons and record their output spike times. We continue measuring the correlation between the output and\n",
        "investigate the relationship between the input correlation and the output correlation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "QRd26JgalvMJ"
      },
      "source": [
        "\n",
        "In the following, you will inject correlated GWN in two neurons. You need to define the mean (`gwn_mean`), standard deviation (`gwn_std`), and input correlations (`c_in`).\n",
        "\n",
        "We will simulate $10$ trials to get a better estimate of the output correlation. Change the values in the following cell for the above variables (and then run the next cell) to explore how they impact the output correlation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {},
        "id": "5Lg-QjwQlvMJ"
      },
      "outputs": [],
      "source": [
        "# Play around with these parameters\n",
        "\n",
        "pars = default_pars(T=80000, dt=1.)  # get the parameters\n",
        "c_in = 0.3  # set input correlation value\n",
        "gwn_mean = 10.\n",
        "gwn_std = 10."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "sxaB9hCrlvMJ"
      },
      "outputs": [],
      "source": [
        "# @markdown Do not forget to execute this cell to simulate the LIF\n",
        "\n",
        "bin_size = 10.  # ms\n",
        "starttime = time.perf_counter()  # time clock\n",
        "r12_ss, sp_ss, sp1, sp2 = LIF_output_cc(pars, mu=gwn_mean, sig=gwn_std, c=c_in,\n",
        "                                        bin_size=bin_size, n_trials=10)\n",
        "\n",
        "# just the time counter\n",
        "endtime = time.perf_counter()\n",
        "timecost = (endtime - starttime) / 60.\n",
        "print(f\"Simulation time = {timecost:.2f} min\")\n",
        "\n",
        "print(f\"Input correlation = {c_in}\")\n",
        "print(f\"Output correlation = {r12_ss}\")\n",
        "\n",
        "plt.figure(figsize=(12, 6))\n",
        "plt.plot(sp1, np.ones(len(sp1)) * 1, '|', ms=20, label='neuron 1')\n",
        "plt.plot(sp2, np.ones(len(sp2)) * 1.1, '|', ms=20, label='neuron 2')\n",
        "plt.xlabel('time (ms)')\n",
        "plt.ylabel('neuron id.')\n",
        "plt.xlim(1000, 8000)\n",
        "plt.ylim(0.9, 1.2)\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "5tJnEZxdlvMK"
      },
      "source": [
        "## Think! 2: Input and Output Correlations\n",
        "- Is the output correlation always smaller than the input correlation? If yes, why?\n",
        "- Should there be a systematic relationship between input and output correlations?\n",
        "\n",
        "You will explore these questions in the next figure but try to develop your own intuitions first!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "fz2sZ9BrlvMK"
      },
      "source": [
        "Let's vary `c_in` and plot the relationship between the `c_in` and output correlation. This might take some time depending on the number of trials."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "a0Kd98w1lvMK"
      },
      "outputs": [],
      "source": [
        "# @markdown Don't forget to execute this cell!\n",
        "\n",
        "pars = default_pars(T=80000, dt=1.)  # get the parameters\n",
        "bin_size = 10.\n",
        "c_in = np.arange(0, 1.0, 0.1)  # set the range for input CC\n",
        "r12_ss = np.zeros(len(c_in))  # small mu, small sigma\n",
        "\n",
        "starttime = time.perf_counter() # time clock\n",
        "for ic in range(len(c_in)):\n",
        "  r12_ss[ic], sp_ss, sp1, sp2 = LIF_output_cc(pars, mu=10.0, sig=10.,\n",
        "                                              c=c_in[ic], bin_size=bin_size,\n",
        "                                              n_trials=10)\n",
        "\n",
        "endtime = time.perf_counter()\n",
        "timecost = (endtime - starttime) / 60.\n",
        "print(f\"Simulation time = {timecost:.2f} min\")\n",
        "\n",
        "plot_c_r_LIF(c_in, r12_ss, mycolor='b', mylabel='Output CC')\n",
        "plt.plot([c_in.min() - 0.05, c_in.max() + 0.05],\n",
        "         [c_in.min() - 0.05, c_in.max() + 0.05],\n",
        "         'k--', dashes=(2, 2), label='y=x')\n",
        "\n",
        "plt.xlabel('Input CC')\n",
        "plt.ylabel('Output CC')\n",
        "plt.legend(loc='best', fontsize=16)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "eHAB_q6IlvMK"
      },
      "source": [
        "---\n",
        "# Section 3: Correlation transfer function\n",
        "\n",
        "The above plot of input correlation vs. output correlation is called the __correlation transfer function__ of the neurons."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "WYEt0b0ElvMK"
      },
      "source": [
        "## Section 3.1: How do the mean and standard deviation of the Gaussian white noise (GWN) affect the correlation transfer function?\n",
        "\n",
        "The correlations transfer function appears to be linear. The above can be taken as the input/output transfer function of LIF neurons for correlations, instead of the transfer function for input/output firing rates as we had discussed in the previous tutorial (i.e., F-I curve).\n",
        "\n",
        "What would you expect to happen to the slope of the correlation transfer function if you vary the mean and/or the standard deviation of the GWN of the inputs ?"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "execution": {},
        "id": "y10WlmiylvMK"
      },
      "outputs": [],
      "source": [
        "# @markdown Execute this cell to visualize correlation transfer functions\n",
        "\n",
        "pars = default_pars(T=80000, dt=1.) # get the parameters\n",
        "n_trials = 10\n",
        "bin_size = 10.\n",
        "c_in = np.arange(0., 1., 0.2)  # set the range for input CC\n",
        "r12_ss = np.zeros(len(c_in))   # small mu, small sigma\n",
        "r12_ls = np.zeros(len(c_in))   # large mu, small sigma\n",
        "r12_sl = np.zeros(len(c_in))   # small mu, large sigma\n",
        "\n",
        "starttime = time.perf_counter()  # time clock\n",
        "for ic in range(len(c_in)):\n",
        "  r12_ss[ic], sp_ss, sp1, sp2 = LIF_output_cc(pars, mu=10.0, sig=10.,\n",
        "                                              c=c_in[ic], bin_size=bin_size,\n",
        "                                              n_trials=n_trials)\n",
        "  r12_ls[ic], sp_ls, sp1, sp2 = LIF_output_cc(pars, mu=18.0, sig=10.,\n",
        "                                              c=c_in[ic], bin_size=bin_size,\n",
        "                                              n_trials=n_trials)\n",
        "  r12_sl[ic], sp_sl, sp1, sp2 = LIF_output_cc(pars, mu=10.0, sig=20.,\n",
        "                                              c=c_in[ic], bin_size=bin_size,\n",
        "                                              n_trials=n_trials)\n",
        "endtime = time.perf_counter()\n",
        "timecost = (endtime - starttime) / 60.\n",
        "print(f\"Simulation time = {timecost:.2f} min\")\n",
        "\n",
        "\n",
        "plot_c_r_LIF(c_in, r12_ss, mycolor='b', mylabel=r'Small $\\mu$, small $\\sigma$')\n",
        "plot_c_r_LIF(c_in, r12_ls, mycolor='y', mylabel=r'Large $\\mu$, small $\\sigma$')\n",
        "plot_c_r_LIF(c_in, r12_sl, mycolor='r', mylabel=r'Small $\\mu$, large $\\sigma$')\n",
        "plt.plot([c_in.min() - 0.05, c_in.max() + 0.05],\n",
        "         [c_in.min() - 0.05, c_in.max() + 0.05],\n",
        "         'k--', dashes=(2, 2), label='y=x')\n",
        "plt.xlabel('Input CC')\n",
        "plt.ylabel('Output CC')\n",
        "plt.legend(loc='best', fontsize=14)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "ja1Lc6c_lvMK"
      },
      "source": [
        "### Think! 3.1: GWN and the Correlation Transfer Function\n",
        "Why do both the mean and the standard deviation of the GWN affect the slope of the correlation transfer function?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "8v4hV_nblvMK"
      },
      "source": [
        "## Section 3.2: What is the rationale behind varying $\\mu$ and $\\sigma$?\n",
        "\n",
        "The mean and the variance of the synaptic current depends on the spike rate of a Poisson process. We can use something called [Campbell's theorem](https://en.wikipedia.org/wiki/Campbell%27s_theorem_(probability)) to estimate the mean and the variance of the synaptic current:\n",
        "\n",
        "\\begin{align}\n",
        "\\mu_{\\rm syn} = \\lambda J \\int P(t) dt \\\\\n",
        "\\sigma_{\\rm syn} = \\lambda J \\int P(t)^2 dt\n",
        "\\end{align}\n",
        "\n",
        "where $\\lambda$ is the firing rate of the Poisson input, $J$ the amplitude of the postsynaptic current and $P(t)$ is the shape of the postsynaptic current as a function of time.\n",
        "\n",
        "Therefore, when we varied $\\mu$ and/or $\\sigma$ of the GWN, we mimicked a change in the input firing rate. Note that, if we change the firing rate, both $\\mu$ and $\\sigma$ will change simultaneously, not independently.\n",
        "\n",
        "Here, since we observe an effect of $\\mu$ and $\\sigma$ on correlation transfer, this implies that the input rate has an impact on the correlation transfer function."
      ]
    },
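    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {}
      },
      "source": [
        "To make this concrete, the optional sketch below (an addition to the tutorial, assuming an exponential postsynaptic current shape) numerically checks Campbell's theorem: it builds a shot-noise current from a Poisson train and compares its empirical mean and variance with the formulas above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "execution": {}
      },
      "outputs": [],
      "source": [
        "# Optional sketch: numerical check of Campbell's theorem\n",
        "# Assumption: an exponential postsynaptic current shape P(t) = exp(-t / tau_s)\n",
        "pars_cb = default_pars(T=20000, dt=0.1)\n",
        "dt_cb = pars_cb['dt']\n",
        "lam = 100.   # Poisson input rate [Hz]\n",
        "J = 2.       # PSC amplitude [pA]\n",
        "tau_s = 5.   # PSC decay time constant [ms]\n",
        "\n",
        "# Poisson input train and PSC kernel\n",
        "spikes = Poisson_generator(pars_cb, rate=lam, n=1, myseed=1)[0]\n",
        "t_kernel = np.arange(0, 10 * tau_s, dt_cb)\n",
        "P = np.exp(-t_kernel / tau_s)\n",
        "\n",
        "# Shot-noise synaptic current: sum of PSCs triggered by each input spike\n",
        "I_syn = J * np.convolve(spikes, P)[:spikes.size]  # [pA]\n",
        "\n",
        "# Campbell's theorem (integrals converted to seconds, since lambda is in Hz)\n",
        "mu_theory = lam * J * np.sum(P) * (dt_cb / 1000.)\n",
        "var_theory = lam * J**2 * np.sum(P**2) * (dt_cb / 1000.)\n",
        "\n",
        "burn = t_kernel.size  # discard the initial transient (one kernel length)\n",
        "print(f'mean    : simulated {I_syn[burn:].mean():.2f} pA, theory {mu_theory:.2f} pA')\n",
        "print(f'variance: simulated {I_syn[burn:].var():.2f} pA^2, theory {var_theory:.2f} pA^2')"
      ]
    },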
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "lkJdvGe0lvMK"
      },
      "source": [
        "# Think!: Correlations and Network Activity\n",
        "\n",
        "- What are the factors that would make output correlations smaller than input correlations? (Notice that the colored lines are below the black dashed line)\n",
        "- What does the fact that output correlations are smaller mean for the correlations throughout a network?\n",
        "- Here we have studied the transfer of correlations by injecting GWN. But in the previous tutorial, we mentioned that GWN is unphysiological. Indeed, neurons receive colored noise (i.e., Shot noise or OU process). How do these results obtained from injection of GWN apply to the case where correlated spiking inputs are injected in the two LIFs? Will the results be the same or different?\n",
        "\n",
        "<br>\n",
        "\n",
        "References:\n",
        "\n",
        "- de la Rocha J, Doiron B, Shea-Brown E, Josić K, Reyes A (2007). Correlation between neural spike trains increases with firing rate. Nature 448:802-806. doi: [10.1038/nature06028](https://doi.org/10.1038/nature06028)\n",
        "\n",
        "- Bujan AF, Aertsen A, Kumar A (2015). Role of input correlations in shaping the variability and noise correlations of evoked activity in the neocortex, Journal of Neuroscience 35(22):8611-25. doi: [10.1523/JNEUROSCI.4536-14.2015](https://doi.org/10.1523/JNEUROSCI.4536-14.2015)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "execution": {},
        "id": "KCPQ03ftlvMK"
      },
      "source": [
        "---\n",
        "# Summary\n",
        "\n",
        "\n",
        "In this tutorial, we studied how the input correlation of two LIF neurons is mapped to their output correlation. Specifically, we:\n",
        "\n",
        "- injected correlated GWN in a pair of neurons,\n",
        "\n",
        "- measured correlations between the spiking activity of the two neurons, and\n",
        "\n",
        "- studied how the transfer of correlation depends on the statistics of the input, i.e., mean and standard deviation.\n",
        "\n",
        "Here, we were concerned with zero time lag correlation. For this reason, we restricted estimation of correlation to instantaneous correlations. If you are interested in time-lagged correlation, then we should estimate the cross-correlogram of the spike trains and find out the dominant peak and area under the peak to get an estimate of output correlations.\n",
        "\n",
        "We leave this as a future to-do for you if you are interested.\n",
        "\n",
        "If you have time, check out the bonus video to think about responses of ensembles of neurons to time-varying input."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernel": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.17"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}