{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example script requires a number of extra packages to be installed:\n",
    "\n",
    "<ul>\n",
    "    <li><a href=\"https://github.com/SeanNaren/deepspeech.pytorch\">deepspeech_pytorch</a></li>\n",
    "    <li><a href=\"https://github.com/omry/omegaconf\">omegaconf</a></li>\n",
    "    <li><a href=\"https://github.com/PyTorchLightning/pytorch-lightning\">pytorch_lightning</a></li>\n",
    "</ul>\n",
    "\n",
    "The $\\mathrm{DeepSpeech}$ class defined below is a slightly modified version of the class of the same name in the <a href=\"https://github.com/SeanNaren/deepspeech.pytorch/blob/master/deepspeech_pytorch/model.py\">model.py</a> module of the $\\mathrm{deepspeech\\_pytorch}$ package.\n",
    "\n",
    "To replicate the results in this notebook, the following additional resources are provided:\n",
    "\n",
    "<ul>\n",
    "    <li>Pretrained DeepSpeech model: $\\text{deepspeech2-pretrained.ckpt}$</li>\n",
    "    <li>Sample 4-second, 16 kHz audio clips extracted from the <em>test-clean</em> and <em>dev-clean</em> subsets of the LibriSpeech corpus: $\\text{segments-librispeech-*}$</li>\n",
    "</ul>\n",
    "\n",
    "To use them, extract $\\text{Examples/resources.tar}$ and place its contents in a folder named $\\text{resources/}$ in the notebook's working directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchaudio\n",
    "from torch.cuda.amp import autocast\n",
    "torchaudio.set_audio_backend(\"sox_io\")\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "from deepspeech_pytorch.configs.train_config import SpectConfig, BiDirectionalConfig, AdamConfig, SGDConfig\n",
    "from deepspeech_pytorch.decoder import GreedyDecoder\n",
    "from deepspeech_pytorch.utils import load_model\n",
    "from deepspeech_pytorch.validation import CharErrorRate, WordErrorRate\n",
    "from deepspeech_pytorch.model import SequenceWise, MaskConv, InferenceBatchSoftmax, Lookahead\n",
    "from deepspeech_pytorch.enums import RNNType, SpectrogramWindow\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "import PyTCI.tci as tci\n",
    "import PyTCI.audio as fx\n",
    "\n",
    "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# Output alphabet of the model; '_' is used as the CTC blank symbol.\n",
    "LABELS = list(\"_'ABCDEFGHIJKLMNOPQRSTUVWXYZ \")\n",
    "# Network / optimiser / spectrogram settings. The network and spectrogram settings\n",
    "# determine the layer shapes, so they must match the pretrained checkpoint.\n",
    "MODEL_CFG = BiDirectionalConfig(rnn_type=RNNType.lstm, hidden_size=1024, hidden_layers=7)\n",
    "OPTIM_CFG = AdamConfig(\n",
    "    learning_rate=0.00015,\n",
    "    learning_anneal=0.99,\n",
    "    weight_decay=1e-05,\n",
    "    eps=1e-08,\n",
    "    betas=[0.9, 0.999],\n",
    ")\n",
    "SPECT_CFG = SpectConfig(\n",
    "    sample_rate=16000,\n",
    "    window_size=0.02,\n",
    "    window_stride=0.01,\n",
    "    window=SpectrogramWindow.hamming,\n",
    ")\n",
    "PRECISION = 16\n",
    "\n",
    "\n",
    "class SpectrogramParser(nn.Module):\n",
    "    \"\"\"Convert a raw waveform into a log1p-compressed magnitude spectrogram.\n",
    "\n",
    "    :param audio_conf: SpectConfig holding the sample rate, window type and the\n",
    "        window length/stride in seconds\n",
    "    :param normalize: (default False) standardise the spectrogram to zero mean and\n",
    "        unit standard deviation, using statistics over the whole spectrogram\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, audio_conf: SpectConfig, normalize: bool = False):\n",
    "        super().__init__()\n",
    "        self.window_stride = audio_conf.window_stride\n",
    "        self.window_size = audio_conf.window_size\n",
    "        self.sample_rate = audio_conf.sample_rate\n",
    "        self.window = audio_conf.window.value\n",
    "        self.normalize = normalize\n",
    "\n",
    "        n_fft = int(self.sample_rate * self.window_size)\n",
    "        win_length = n_fft\n",
    "        hop_length = int(self.sample_rate * self.window_stride)\n",
    "        if self.window == 'hamming':\n",
    "            window = torch.hamming_window\n",
    "        else:\n",
    "            raise NotImplementedError(f'Unsupported spectrogram window: {self.window!r}')\n",
    "\n",
    "        self.transform = torchaudio.transforms.Spectrogram(\n",
    "            n_fft, win_length, hop_length, window_fn=window, power=1, normalized=False)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def forward(self, audio):\n",
    "        \"\"\"Compute the spectrogram of a waveform.\n",
    "\n",
    "        :param audio: waveform tensor, either 1-D ``[time]`` or 2-D ``[time x channels]``\n",
    "        :return: spectrogram tensor of shape ``[time x frequency]``\n",
    "        \"\"\"\n",
    "        # Mix down to mono. A 1-D tensor is already mono and must not be averaged\n",
    "        # over its only axis (time), which would collapse it to a scalar.\n",
    "        if audio.dim() > 1:\n",
    "            if audio.shape[-1] == 1:\n",
    "                audio = audio.squeeze(dim=-1)  # single channel\n",
    "            else:\n",
    "                audio = audio.mean(dim=-1)  # multiple channels, average\n",
    "\n",
    "        # Trim trailing samples so the spectrogram has len(audio) // hop_length frames\n",
    "        # (assuming torchaudio's default centered STFT, which yields N // hop + 1 frames).\n",
    "        n_trim = len(audio) % self.transform.hop_length + 1\n",
    "        audio = audio[:-n_trim]\n",
    "\n",
    "        spect = self.transform(audio)\n",
    "        spect = torch.log1p(spect)\n",
    "\n",
    "        if self.normalize:\n",
    "            mean = spect.mean()\n",
    "            std = spect.std()\n",
    "            spect.add_(-mean)\n",
    "            spect.div_(std)\n",
    "\n",
    "        # reshape to [time x frequency]\n",
    "        spect = spect.T.contiguous()\n",
    "\n",
    "        return spect\n",
    "\n",
    "\n",
    "class BatchRNN(nn.Module):\n",
    "    \"\"\"A single recurrent layer with optional sequence-wise batch norm on its input.\n",
    "\n",
    "    For bidirectional RNNs the two directions are summed, so the output feature\n",
    "    size is ``hidden_size`` either way.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        self.bidirectional = bidirectional\n",
    "        # `batch_norm` and `rnn` name the state_dict entries; keep them stable so\n",
    "        # pretrained checkpoints still load.\n",
    "        if batch_norm:\n",
    "            self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size))\n",
    "        else:\n",
    "            self.batch_norm = None\n",
    "        self.rnn = rnn_type(\n",
    "            input_size=input_size,\n",
    "            hidden_size=hidden_size,\n",
    "            bidirectional=bidirectional,\n",
    "            bias=True,\n",
    "        )\n",
    "        self.num_directions = 2 if bidirectional else 1\n",
    "\n",
    "    def flatten_parameters(self):\n",
    "        self.rnn.flatten_parameters()\n",
    "\n",
    "    def forward(self, x, output_lengths):\n",
    "        \"\"\"Run the layer on a padded ``[time x batch x features]`` sequence.\"\"\"\n",
    "        if self.batch_norm is not None:\n",
    "            x = self.batch_norm(x)\n",
    "        packed_in = nn.utils.rnn.pack_padded_sequence(x, output_lengths, enforce_sorted=False)\n",
    "        packed_out, _ = self.rnn(packed_in)\n",
    "        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out)\n",
    "        if not self.bidirectional:\n",
    "            return out\n",
    "        n_steps, n_batch = out.size(0), out.size(1)\n",
    "        # (T x N x 2H) -> (T x N x H): sum forward and backward directions\n",
    "        return out.view(n_steps, n_batch, 2, -1).sum(2).view(n_steps, n_batch, -1)\n",
    "\n",
    "\n",
    "class DeepSpeech(pl.LightningModule):\n",
    "    \"\"\"DeepSpeech2-style model: conv front end, stacked RNNs and a batch-norm + linear\n",
    "    output layer trained with CTC loss.\n",
    "\n",
    "    Slightly modified from ``deepspeech_pytorch.model.DeepSpeech``; provides\n",
    "    :meth:`activation_fx` for extracting per-layer activations from raw audio.\n",
    "    \"\"\"\n",
    "    def __init__(self, labels=LABELS, model_cfg=MODEL_CFG, precision=PRECISION, optim_cfg=OPTIM_CFG, spect_cfg=SPECT_CFG):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "        self.model_cfg = model_cfg\n",
    "        self.precision = precision\n",
    "        self.optim_cfg = optim_cfg\n",
    "        self.spect_cfg = spect_cfg\n",
    "        self.bidirectional = True if OmegaConf.get_type(model_cfg) is BiDirectionalConfig else False\n",
    "\n",
    "        self.labels = labels\n",
    "        num_classes = len(self.labels)\n",
    "\n",
    "        self.conv = MaskConv(nn.Sequential(\n",
    "            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.Hardtanh(0, 20, inplace=True),\n",
    "            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),\n",
    "            nn.BatchNorm2d(32),\n",
    "            nn.Hardtanh(0, 20, inplace=True)\n",
    "        ))\n",
    "        # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1\n",
    "        rnn_input_size = int(math.floor((self.spect_cfg.sample_rate * self.spect_cfg.window_size) / 2) + 1)\n",
    "        rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1)\n",
    "        rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1)\n",
    "        rnn_input_size *= 32\n",
    "\n",
    "        self.rnns = nn.Sequential(\n",
    "            BatchRNN(\n",
    "                input_size=rnn_input_size,\n",
    "                hidden_size=self.model_cfg.hidden_size,\n",
    "                rnn_type=self.model_cfg.rnn_type.value,\n",
    "                bidirectional=self.bidirectional,\n",
    "                batch_norm=False\n",
    "            ),\n",
    "            *(\n",
    "                BatchRNN(\n",
    "                    input_size=self.model_cfg.hidden_size,\n",
    "                    hidden_size=self.model_cfg.hidden_size,\n",
    "                    rnn_type=self.model_cfg.rnn_type.value,\n",
    "                    bidirectional=self.bidirectional\n",
    "                ) for x in range(self.model_cfg.hidden_layers - 3)\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.lookahead = nn.Sequential(\n",
    "            # consider adding batch norm?\n",
    "            Lookahead(self.model_cfg.hidden_size, context=self.model_cfg.lookahead_context),\n",
    "            nn.Hardtanh(0, 20, inplace=True)\n",
    "        ) if not self.bidirectional else None\n",
    "\n",
    "        fully_connected = nn.Sequential(\n",
    "            nn.BatchNorm1d(self.model_cfg.hidden_size),\n",
    "            nn.Linear(self.model_cfg.hidden_size, num_classes, bias=False)\n",
    "        )\n",
    "        self.fc = nn.Sequential(\n",
    "            SequenceWise(fully_connected),\n",
    "        )\n",
    "        self.inference_softmax = InferenceBatchSoftmax()\n",
    "        self.criterion = nn.CTCLoss(blank=self.labels.index('_'), reduction='sum', zero_infinity=True)\n",
    "        self.evaluation_decoder = GreedyDecoder(self.labels)  # Decoder used for validation\n",
    "        self.wer = WordErrorRate(\n",
    "            decoder=self.evaluation_decoder,\n",
    "            target_decoder=self.evaluation_decoder\n",
    "        )\n",
    "        self.cer = CharErrorRate(\n",
    "            decoder=self.evaluation_decoder,\n",
    "            target_decoder=self.evaluation_decoder\n",
    "        )\n",
    "\n",
    "    def forward(self, x, lengths):\n",
    "        lengths = lengths.cpu().int()\n",
    "        output_lengths = self.get_seq_lens(lengths)\n",
    "        x, _ = self.conv(x.transpose(1,2).unsqueeze(1).contiguous(), output_lengths)\n",
    "\n",
    "        sizes = x.size()\n",
    "        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # Collapse feature dimension\n",
    "        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH\n",
    "\n",
    "        for rnn in self.rnns:\n",
    "            x = rnn(x, output_lengths)\n",
    "\n",
    "        if not self.bidirectional:  # no need for lookahead layer in bidirectional\n",
    "            x = self.lookahead(x)\n",
    "\n",
    "        x = self.fc(x)\n",
    "        x = x.transpose(0, 1)\n",
    "        # identity in training mode, softmax in eval mode\n",
    "        x = self.inference_softmax(x)\n",
    "        return x, output_lengths\n",
    "    \n",
    "    def unpack_batch(self, batch):\n",
    "        inputs = batch.get('inputs', None)\n",
    "        input_lengths = batch.get('input_lengths', None)\n",
    "        labels = batch.get('labels', None)\n",
    "        label_lengths = batch.get('label_lengths', None)\n",
    "        \n",
    "        return inputs, labels, input_lengths, label_lengths\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        inputs, targets, input_sizes, target_sizes = self.unpack_batch(batch)\n",
    "        if inputs is None: # skip step\n",
    "            return None\n",
    "        \n",
    "        out, output_sizes = self(inputs, input_sizes)\n",
    "        out = out.transpose(0, 1)  # TxNxH\n",
    "        out = out.log_softmax(-1)\n",
    "\n",
    "        loss = self.criterion(out, targets, output_sizes, target_sizes)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        inputs, targets, input_sizes, target_sizes = self.unpack_batch(batch)\n",
    "        if inputs is None: # skip step\n",
    "            return\n",
    "        \n",
    "        inputs = inputs.to(self.device)\n",
    "        with autocast(enabled=self.precision == 16):\n",
    "            out, output_sizes = self(inputs, input_sizes)\n",
    "        decoded_output, _ = self.evaluation_decoder.decode(out, output_sizes)\n",
    "        \n",
    "        self.wer(preds=out, preds_sizes=output_sizes, targets=targets, target_sizes=target_sizes)\n",
    "        self.cer(preds=out, preds_sizes=output_sizes, targets=targets, target_sizes=target_sizes)\n",
    "        self.log('wer', self.wer.compute(), prog_bar=True, on_epoch=True)\n",
    "        self.log('cer', self.cer.compute(), prog_bar=True, on_epoch=True)\n",
    "    \n",
    "    def test_step(self, *args):\n",
    "        return self.validation_step(*args)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        if OmegaConf.get_type(self.optim_cfg) is SGDConfig:\n",
    "            optimizer = torch.optim.SGD(\n",
    "                params=self.parameters(),\n",
    "                lr=self.optim_cfg.learning_rate,\n",
    "                momentum=self.optim_cfg.momentum,\n",
    "                nesterov=True,\n",
    "                weight_decay=self.optim_cfg.weight_decay\n",
    "            )\n",
    "        elif OmegaConf.get_type(self.optim_cfg) is AdamConfig:\n",
    "            optimizer = torch.optim.AdamW(\n",
    "                params=self.parameters(),\n",
    "                lr=self.optim_cfg.learning_rate,\n",
    "                betas=self.optim_cfg.betas,\n",
    "                eps=self.optim_cfg.eps,\n",
    "                weight_decay=self.optim_cfg.weight_decay\n",
    "            )\n",
    "        else:\n",
    "            raise ValueError(\"Optimizer has not been specified correctly.\")\n",
    "\n",
    "        scheduler = torch.optim.lr_scheduler.ExponentialLR(\n",
    "            optimizer=optimizer,\n",
    "            gamma=self.optim_cfg.learning_anneal\n",
    "        )\n",
    "        return [optimizer], [scheduler]\n",
    "\n",
    "    def get_seq_lens(self, input_length):\n",
    "        \"\"\"\n",
    "        Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable\n",
    "        containing the size sequences that will be output by the network.\n",
    "        :param input_length: 1D Tensor\n",
    "        :return: 1D Tensor scaled by model\n",
    "        \"\"\"\n",
    "        seq_len = input_length\n",
    "        for m in self.conv.modules():\n",
    "            if type(m) == nn.modules.conv.Conv2d:\n",
    "                seq_len = torch.div(\n",
    "                    seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1,\n",
    "                    m.stride[1], rounding_mode='floor'\n",
    "                ) + 1\n",
    "        return seq_len.int()\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def activation_fx(self, layer, log=True):\n",
    "        \"\"\"Build a function that maps a waveform to the activations of one layer.\n",
    "\n",
    "        Layers are indexed in forward order: first one index per Hardtanh stage of\n",
    "        ``self.conv`` (0 = Conv1, 1 = Conv2), then one per recurrent layer in\n",
    "        ``self.rnns``, and finally the fully-connected output layer.\n",
    "\n",
    "        :param layer: index of the layer whose activations are returned\n",
    "        :param log: for the output layer only, return log-softmax (True) or softmax (False)\n",
    "        :return: callable ``activation(x)`` taking a waveform tensor ``x`` (in the layout\n",
    "            accepted by ``SpectrogramParser``) and returning the layer's activations\n",
    "            as a CPU tensor of shape ``[time x channels]``\n",
    "        :raises ValueError: if ``layer`` is not a valid layer index\n",
    "        \"\"\"\n",
    "        n_conv = sum(isinstance(m, nn.Hardtanh) for m in self.conv.seq_module)\n",
    "        n_layers = n_conv + len(self.rnns) + 1  # conv stages + RNN layers + output layer\n",
    "        if not 0 <= layer < n_layers:\n",
    "            raise ValueError(f'layer must be in [0, {n_layers - 1}], got {layer}')\n",
    "\n",
    "        # waveform -> spectrogram parser, using the model's own front-end config\n",
    "        spect_parser = SpectrogramParser(audio_conf=self.spect_cfg, normalize=True).to(self.device)\n",
    "\n",
    "        # no_grad must wrap the returned closure itself: the decorator on\n",
    "        # activation_fx only covers building it, not the later calls.\n",
    "        @torch.no_grad()\n",
    "        def activation(x, /, layer=layer):\n",
    "            # convert to spectrogram [time x frequency] on the model's device\n",
    "            x = spect_parser(x.to(self.device))\n",
    "            lengths = torch.tensor([x.shape[0]], dtype=int)\n",
    "            output_lengths = self.get_seq_lens(lengths)\n",
    "\n",
    "            # make into 4D tensor of [batch x channel x frequency x time]\n",
    "            x = x.T[np.newaxis, np.newaxis, ...].contiguous().to(self.device)\n",
    "\n",
    "            for module in self.conv.seq_module:\n",
    "                x = module(x)\n",
    "                # zero out time steps beyond the valid output length\n",
    "                mask = torch.zeros(x.size(), dtype=torch.bool, device=x.device)\n",
    "                for i, length in enumerate(output_lengths):\n",
    "                    length = length.item()\n",
    "                    if (mask[i].size(2) - length) > 0:\n",
    "                        mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1)\n",
    "                x = x.masked_fill(mask, 0)\n",
    "\n",
    "                if isinstance(module, nn.Hardtanh):\n",
    "                    layer -= 1\n",
    "                    if layer < 0:\n",
    "                        break\n",
    "\n",
    "            sizes = x.size()\n",
    "            x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # Collapse feature dimension\n",
    "            x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH\n",
    "            if layer < 0:\n",
    "                return x.squeeze(dim=1).cpu()\n",
    "\n",
    "            for rnn in self.rnns:\n",
    "                x = rnn(x, output_lengths)\n",
    "                layer -= 1\n",
    "                if layer < 0:\n",
    "                    return x.squeeze(dim=1).cpu()\n",
    "\n",
    "            if not self.bidirectional:  # no need for lookahead layer in bidirectional\n",
    "                x = self.lookahead(x)\n",
    "\n",
    "            x = self.fc(x)\n",
    "\n",
    "            # output layer: per-class log-probabilities or probabilities\n",
    "            if log:\n",
    "                x = torch.nn.functional.log_softmax(x, dim=-1)\n",
    "            else:\n",
    "                x = torch.nn.functional.softmax(x, dim=-1)\n",
    "            layer -= 1\n",
    "            if layer < 0:\n",
    "                return x.squeeze(dim=1).cpu()\n",
    "            # only reachable if `layer` is overridden with an out-of-range value at call time\n",
    "            raise ValueError('layer index out of range')\n",
    "\n",
    "        return activation\n",
    "\n",
    "\n",
    "# Pretrained DeepSpeech2 model\n",
    "model = DeepSpeech().to(device).eval()\n",
    "# map_location lets a checkpoint saved on a GPU load on a CPU-only machine\n",
    "checkpoint = torch.load('resources/deepspeech2-pretrained.ckpt', map_location=device)\n",
    "model.load_state_dict(checkpoint['state_dict'])\n",
    "\n",
    "# Model output sampling rate (Hz): the 10 ms spectrogram hop gives 100 frames/s,\n",
    "# and the first convolution's time stride of 2 halves that to 50 frames/s.\n",
    "out_sr = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Measuring integration windows"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section demonstrates how to measure integration windows for a more complex deep neural network with multiple layers. To do this, we need to define a function that takes in input $\\mathrm{x}$ and returns output $\\mathrm{y}$, which in this case is the activations at a specific layer $\\mathrm{layer}$. The shape of both $\\mathrm{x}$ and $\\mathrm{y}$ should be of form $[ time \\times channels ]$.\n",
    "\n",
    "We first define a wrapper function to estimate integration windows of all nodes in a layer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrapper function to estimate integration windows of a layer\n",
    "def estimate_layer_intwin(layer, stimuli, segment_durs, in_sr, out_sr,\n",
    "                          block_size=48.0, context_size=8.0, batch_size=100, threshold=0.75):\n",
    "    \"\"\"Estimate the integration windows of all units in one layer of ``model``.\n",
    "\n",
    "    Runs the TCI pipeline end to end: generate sequence pairs from the stimuli,\n",
    "    run them through ``model.activation_fx(layer)``, rearrange the responses,\n",
    "    compute cross-context correlations and estimate the integration windows.\n",
    "    Uses the notebook-level ``model`` and ``device``.\n",
    "\n",
    "    :param layer: layer index passed to ``model.activation_fx``\n",
    "    :param stimuli: stimuli as returned by ``tci.load_stimuli``\n",
    "    :param segment_durs: segment durations to test (e.g. ``tci.SEGMENT_DURS``)\n",
    "    :param in_sr: sampling rate of the stimuli\n",
    "    :param out_sr: sampling rate of the layer's output\n",
    "    :param block_size: passed to ``tci.infer_sequence_pair`` (default 48.0)\n",
    "    :param context_size: passed to ``tci.infer_sequence_pair`` (default 8.0)\n",
    "    :param batch_size: passed to ``tci.cross_context_corrs`` (default 100)\n",
    "    :param threshold: passed to ``tci.estimate_integration_window`` (default 0.75)\n",
    "    :return: the integration windows returned by ``tci.estimate_integration_window``\n",
    "    \"\"\"\n",
    "    sequence_pairs = tci.generate_sequence_pair(\n",
    "        stimuli, in_sr, segment_durs\n",
    "    )\n",
    "\n",
    "    response_pairs = tci.infer_sequence_pair(\n",
    "        model.activation_fx(layer), sequence_pairs, segment_durs,\n",
    "        in_sr=in_sr, out_sr=out_sr, block_size=block_size, context_size=context_size, device=device\n",
    "    )\n",
    "\n",
    "    SAR_pairs = tci.rearrange_sequence_pair(\n",
    "        response_pairs, out_sr, segment_durs\n",
    "    )\n",
    "\n",
    "    cross_context_corrs = tci.cross_context_corrs(\n",
    "        SAR_pairs, batch_size=batch_size\n",
    "    )\n",
    "\n",
    "    integration_windows = tci.estimate_integration_window(\n",
    "        cross_context_corrs, segment_durs, threshold=threshold\n",
    "    )\n",
    "\n",
    "    return integration_windows"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
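    "We now analyze three layers of the model: the first convolutional layer (Conv1, layer index 0) and the second and fourth LSTM layers (LSTM2 and LSTM4, layer indices 3 and 5), i.e., the first, fourth and sixth layers overall:"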
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Conv1... Done.\n",
      "> LSTM2... Done.\n",
      "> LSTM4... Done.\n"
     ]
    }
   ],
   "source": [
    "stimuli, in_sr = tci.load_stimuli('resources/segments-librispeech-1k/')\n",
    "segment_durs = tci.SEGMENT_DURS\n",
    "\n",
    "# (display name, layer index passed to model.activation_fx)\n",
    "intwin_norm = {}\n",
    "for name, layer_idx in [('Conv1', 0), ('LSTM2', 3), ('LSTM4', 5)]:\n",
    "    print(f'> {name}... ', flush=True, end='')\n",
    "    intwin_norm[name] = estimate_layer_intwin(layer_idx, stimuli, segment_durs, in_sr, out_sr)\n",
    "    print('Done.')\n",
    "\n",
    "intwin_conv1_norm = intwin_norm['Conv1']\n",
    "intwin_lstm2_norm = intwin_norm['LSTM2']\n",
    "intwin_lstm4_norm = intwin_norm['LSTM4']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The results show increasing integration windows as we look deeper in the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(6, 4))\n",
    "for intwin, name in [(intwin_conv1_norm, 'Conv1'), (intwin_lstm2_norm, 'LSTM2'), (intwin_lstm4_norm, 'LSTM4')]:\n",
    "    plt.hist(np.log10(intwin), bins=50, alpha=0.8, label=name)\n",
    "# windows are in seconds on a log10 axis; ticks are labelled in ms\n",
    "plt.xticks(np.log10([0.02, 0.1, 0.5, 2.5]), [20, 100, 500, 2500], fontsize=16)\n",
    "plt.yticks([0, 200], fontsize=16)\n",
    "plt.xlabel('Integration window (ms)', fontsize=16)\n",
    "# plt.hist plots raw counts (density=False), so label the axis accordingly\n",
    "plt.ylabel('Number of units', fontsize=16)\n",
    "plt.legend(loc='upper left', fontsize=14)\n",
    "plt.xlim(np.log10([0.02, 2.5]))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Measuring adaptation indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we study whether the model adapts its integration window to the input speech rate, i.e., whether it integrates over abstract structures in speech (whose duration stretches when speech is slowed down) rather than over a fixed duration of audio.\n",
    "\n",
    "To do this, we can either preprocess the stimuli independently or use the `process` parameter of `load_stimuli` and `process_stimuli`. There are a number of basic audio processing functions based on the SoX backend of the torchaudio library in the `PyTCI.audio` module, including audio stretching while preserving the audio pitch, named `tempo_fx`, which we use in this example.\n",
    "\n",
    "We then compute the adaptation index, as defined in the paper:\n",
    "\n",
    "\\begin{equation*}\n",
    "\\mathrm{adaptation\\;index} = \\frac{I_{mod}/I_{ref} - 1}{D_{mod}/D_{ref} - 1}\n",
    "\\end{equation*}\n",
    "\n",
    "where $I_{mod}$ and $I_{ref}$ are integration windows for modified (i.e., slowed-down) and reference (i.e., natural) stimuli, and $D_{mod}/D_{ref}$ is the ratio of the duration of the stimuli between modified and reference conditions. Given the definition, a non-adaptive model with a fixed window would have an adaptation index of ~0, and a fully-adaptive model with flexible integration would have an adaptation index of ~1.\n",
    "\n",
    "<strong>NOTE</strong>: the scaling factor passed to the `tempo_fx` and `speed_fx` functions scales the rate, not the duration, so to increase the stimulus duration by 60% we have to set the rate scaling factor to $\\frac{1}{1.6}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Conv1 (slowed-down)... Done.\n",
      "> LSTM2 (slowed-down)... Done.\n",
      "> LSTM4 (slowed-down)... Done.\n"
     ]
    }
   ],
   "source": [
    "# Estimate integration windows for 60% longer (slowed-down) stimuli\n",
    "time_stretch_factor = 1.6\n",
    "# tempo_fx takes a *rate* scale factor, hence 1 / time_stretch_factor.\n",
    "# Separate names keep the natural-rate `stimuli` / `in_sr` from being overwritten.\n",
    "stimuli_slow, in_sr_slow = tci.load_stimuli(\n",
    "    'resources/segments-librispeech-1k/',\n",
    "    process=fx.tempo_fx(scale_factor=1/time_stretch_factor)\n",
    ")\n",
    "\n",
    "# (display name, layer index passed to model.activation_fx)\n",
    "intwin_slow = {}\n",
    "for name, layer_idx in [('Conv1', 0), ('LSTM2', 3), ('LSTM4', 5)]:\n",
    "    print(f'> {name} (slowed-down)... ', flush=True, end='')\n",
    "    intwin_slow[name] = estimate_layer_intwin(layer_idx, stimuli_slow, segment_durs, in_sr_slow, out_sr)\n",
    "    print('Done.')\n",
    "\n",
    "intwin_conv1_slow = intwin_slow['Conv1']\n",
    "intwin_lstm2_slow = intwin_slow['LSTM2']\n",
    "intwin_lstm4_slow = intwin_slow['LSTM4']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We calculate the adaptation index as described in the paper, based on the two conditions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# adaptation index = (I_mod / I_ref - 1) / (D_mod / D_ref - 1), with D_mod / D_ref = time_stretch_factor\n",
    "duration_excess = time_stretch_factor - 1\n",
    "adaptation_conv1 = (intwin_conv1_slow / intwin_conv1_norm - 1) / duration_excess\n",
    "adaptation_lstm2 = (intwin_lstm2_slow / intwin_lstm2_norm - 1) / duration_excess\n",
    "adaptation_lstm4 = (intwin_lstm4_slow / intwin_lstm4_norm - 1) / duration_excess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Comparing the integration windows between the two conditions demonstrates how the network layers transition from short, fixed integration windows (close to the black dashed line, where both windows are equal) to long, flexible integration windows (close to the red dashed line, where the window stretches by the same factor as the stimuli):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(6, 6))\n",
    "diagonal = np.log10([0.018, 1.28])\n",
    "# Reference lines: fixed window (black, y = x) and fully adaptive window\n",
    "# (red, shifted up by the duration stretch factor)\n",
    "plt.plot(diagonal, diagonal, 'k--', linewidth=1)\n",
    "plt.plot(diagonal, diagonal + np.log10(time_stretch_factor), 'r--', linewidth=1)\n",
    "for intwin_n, intwin_s, name in [\n",
    "    (intwin_conv1_norm, intwin_conv1_slow, 'Conv1'),\n",
    "    (intwin_lstm2_norm, intwin_lstm2_slow, 'LSTM2'),\n",
    "    (intwin_lstm4_norm, intwin_lstm4_slow, 'LSTM4'),\n",
    "]:\n",
    "    plt.scatter(np.log10(intwin_n), np.log10(intwin_s), 2, alpha=0.3, label=name)\n",
    "ticks = np.log10([0.02, 0.08, 0.32, 1.28])\n",
    "tick_labels = [20, 80, 320, 1280]\n",
    "plt.xticks(ticks, tick_labels, fontsize=16)\n",
    "plt.yticks(ticks, tick_labels, fontsize=16)\n",
    "plt.xlabel('Integration window for\\nnatural speech (ms)', fontsize=16)\n",
    "plt.ylabel('Integration window for\\nslowed-down speech (ms)', fontsize=16)\n",
    "plt.xlim(np.log10([0.02, 1.28]))\n",
    "plt.ylim(np.log10([0.02, 1.28]))\n",
    "legend = plt.legend(loc='lower right', fontsize=14)\n",
    "# Enlarge and un-fade the legend markers. `legendHandles` was renamed `legend_handles`\n",
    "# in Matplotlib 3.7 and removed in 3.9, so support both spellings.\n",
    "handles = legend.legend_handles if hasattr(legend, 'legend_handles') else legend.legendHandles\n",
    "for h in handles:\n",
    "    h.set_sizes([25])\n",
    "    h.set_alpha(1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The adaptation indices quantify this transition in a clearer way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(6, 4))\n",
    "index_range = (-0.24, 1.24)\n",
    "for adaptation, name, n_bins in [\n",
    "    (adaptation_conv1, 'Conv1', 100),\n",
    "    (adaptation_lstm2, 'LSTM2', 50),\n",
    "    (adaptation_lstm4, 'LSTM4', 50),\n",
    "]:\n",
    "    plt.hist(adaptation, bins=n_bins, alpha=0.6, range=index_range, label=name)\n",
    "plt.legend(fontsize=14)\n",
    "plt.xticks([0, 0.25, 0.5, 0.75, 1.0], fontsize=16)\n",
    "plt.yticks([0, 100], fontsize=16)\n",
    "# Reference lines: 0 = fixed (non-adaptive) window, 1 = fully adaptive window\n",
    "plt.plot([0, 0], [0, 150], 'k--')\n",
    "plt.plot([1, 1], [0, 150], 'r--')\n",
    "plt.xlabel('Adaptation index', fontsize=16)\n",
    "# plt.hist plots raw counts (density=False), so label the axis accordingly\n",
    "plt.ylabel('Number of units', fontsize=16)\n",
    "plt.xlim(index_range)\n",
    "plt.ylim([0, 150])\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-deepspeech]",
   "language": "python",
   "name": "conda-env-.conda-deepspeech-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}