{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<script>\n",
       "  function code_toggle() {\n",
       "    if (code_shown){\n",
       "      $('div.input').hide('500');\n",
       "      $('#toggleButton').val('Show Code')\n",
       "    } else {\n",
       "      $('div.input').show('500');\n",
       "      $('#toggleButton').val('Hide Code')\n",
       "    }\n",
       "    code_shown = !code_shown\n",
       "  }\n",
       "\n",
       "  $( document ).ready(function(){\n",
       "    code_shown=false;\n",
       "    $('div.input').hide()\n",
       "  });\n",
       "</script>\n",
       "<form action=\"javascript:code_toggle()\"><input type=\"submit\" id=\"toggleButton\" value=\"Show Code\"></form>\n",
       "<style>\n",
       ".rendered_html td {\n",
       "    font-size: xx-large;\n",
       "    text-align: left; !important\n",
       "}\n",
       ".rendered_html th {\n",
       "    font-size: xx-large;\n",
       "    text-align: left; !important\n",
       "}\n",
       "</style>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%html\n",
    "<script>\n",
    "  function code_toggle() {\n",
    "    if (code_shown){\n",
    "      $('div.input').hide('500');\n",
    "      $('#toggleButton').val('Show Code')\n",
    "    } else {\n",
    "      $('div.input').show('500');\n",
    "      $('#toggleButton').val('Hide Code')\n",
    "    }\n",
    "    code_shown = !code_shown\n",
    "  }\n",
    "\n",
    "  $( document ).ready(function(){\n",
    "    code_shown=false;\n",
    "    $('div.input').hide()\n",
    "  });\n",
    "</script>\n",
    "<form action=\"javascript:code_toggle()\"><input type=\"submit\" id=\"toggleButton\" value=\"Show Code\"></form>\n",
    "<style>\n",
    ".rendered_html td {\n",
    "    font-size: xx-large;\n",
    "    text-align: left; !important\n",
    "}\n",
    ".rendered_html th {\n",
    "    font-size: xx-large;\n",
    "    text-align: left; !important\n",
    "}\n",
    "</style>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "import statnlpbook.util as util\n",
    "import matplotlib\n",
    "matplotlib.rcParams['figure.figsize'] = (10.0, 6.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "<!---\n",
    "Latex Macros\n",
    "-->\n",
    "$$\n",
    "\\newcommand{\\Xs}{\\mathcal{X}}\n",
    "\\newcommand{\\Ys}{\\mathcal{Y}}\n",
    "\\newcommand{\\y}{\\mathbf{y}}\n",
    "\\newcommand{\\balpha}{\\boldsymbol{\\alpha}}\n",
    "\\newcommand{\\bbeta}{\\boldsymbol{\\beta}}\n",
    "\\newcommand{\\aligns}{\\mathbf{a}}\n",
    "\\newcommand{\\align}{a}\n",
    "\\newcommand{\\source}{\\mathbf{s}}\n",
    "\\newcommand{\\target}{\\mathbf{t}}\n",
    "\\newcommand{\\ssource}{s}\n",
    "\\newcommand{\\starget}{t}\n",
    "\\newcommand{\\repr}{\\mathbf{f}}\n",
    "\\newcommand{\\repry}{\\mathbf{g}}\n",
    "\\newcommand{\\x}{\\mathbf{x}}\n",
    "\\newcommand{\\prob}{p}\n",
    "\\newcommand{\\bar}{\\,|\\,}\n",
    "\\newcommand{\\vocab}{V}\n",
    "\\newcommand{\\params}{\\boldsymbol{\\theta}}\n",
    "\\newcommand{\\param}{\\theta}\n",
    "\\DeclareMathOperator{\\perplexity}{PP}\n",
    "\\DeclareMathOperator{\\argmax}{argmax}\n",
    "\\DeclareMathOperator{\\argmin}{argmin}\n",
    "\\newcommand{\\train}{\\mathcal{D}}\n",
    "\\newcommand{\\counts}[2]{\\#_{#1}(#2) }\n",
    "\\newcommand{\\length}[1]{\\text{length}(#1) }\n",
    "\\newcommand{\\indi}{\\mathbb{I}}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    },
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext tikzmagic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Schedule\n",
    "\n",
    "+ Background: neural MT (5 min.)\n",
    "\n",
    "+ Math: attention (10 min.)\n",
    "\n",
    "+ Math: self-attention (10 min.)\n",
    "\n",
    "+ Background: BERT (15 min.)\n",
    "\n",
    "+ Background: mBERT (5 min.)\n",
    "\n",
    "+ Quiz: mBERT (5 min.)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Encoder-decoder seq2seq\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/encdec_rnn2.svg\" width=\"70%\" />\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## More problems with our approach\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/encdec_rnn3.svg\" width=\"70%\" />\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "    You can't cram the meaning of a whole %&!$ing sentence into a single $&!*ing vector!\n",
    "\n",
    "— Ray Mooney\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## Idea\n",
    "\n",
    "+ Traditional (non-neural) MT models often perform **alignment** between sequences\n",
    "\n",
    "<center style=\"padding: 1em 0;\">\n",
    "    <img src=\"mt_figures/align.svg\" width=\"20%\" />\n",
    "</center>\n",
    "\n",
    "+ Can we learn something similar with our encoder–decoder model?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## Attention mechanism\n",
    "\n",
    "+ Original motivation: Bahdanau et al. 2014, [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)\n",
    "\n",
    "#### Idea\n",
    "\n",
    "+ Each encoder timestep gives us a **contextual representation** of the corresponding input token\n",
    "+ A **weighted combination** of those is a differentiable function\n",
    "+ Computing such a combination for each decoder timestep can give us a **soft alignment**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "<center>\n",
    "    <img src=\"mt_figures/encdec_att.svg\" width=\"40%\" />\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### What is happening here?\n",
    "\n",
    "**Attention model** takes as input:\n",
    "\n",
    "+ Hidden state of the decoder $\\mathbf{s}_t^{\\textrm{dec}}$\n",
    "+ All encoder hidden states $(\\mathbf{h}_1^{\\textrm{enc}}, \\ldots, \\mathbf{h}_n^{\\textrm{enc}})$\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "**Attention model** produces:\n",
    "\n",
    "+ An attention vector $\\mathbf{\\alpha}_t \\in \\mathbb{R}^n$ (where $n$ is the length of the source sequence)\n",
    "+ $\\mathbf{\\alpha}_t$ is computed as a softmax distribution:\n",
    "\n",
    "$$\n",
    "\\mathbf{\\alpha}_{t,j} = \\text{softmax}\\left(f_{\\mathrm{att}}(\\mathbf{s}_{t-1}^{\\textrm{dec}}, \\mathbf{h}_j^{\\textrm{enc}})\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "### How do we compute $f_\\mathrm{att}$?\n",
    "\n",
    "Usually with a very simple feedforward neural network.\n",
    "\n",
    "For example:\n",
    "\n",
    "$$\n",
    "f_{\\mathrm{att}}(\\mathbf{s}_{t-1}^{\\textrm{dec}}, \\mathbf{h}_j^{\\textrm{enc}}) =\n",
    "\\tanh\n",
    "\\left(\n",
    "\\mathbf{W}^s \\mathbf{s}_{t-1}^{\\textrm{dec}} +\n",
    "\\mathbf{W}^h \\mathbf{h}_j^{\\textrm{enc}}\n",
    "\\right)\n",
    "$$\n",
    "\n",
    "This is called **additive** attention."
   ]
  },
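  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "Below is a minimal NumPy sketch of this additive score, not taken from any particular implementation; the toy dimensions and the names `W_s`, `W_h`, `v` are illustrative assumptions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def additive_score(s_dec, h_enc, W_s, W_h, v):\n",
    "    # Additive (Bahdanau-style) attention score for one encoder state\n",
    "    return v @ np.tanh(W_s @ s_dec + W_h @ h_enc)\n",
    "\n",
    "# Toy dimensions and random parameters (illustrative only)\n",
    "d_dec, d_enc, d_att = 4, 3, 5\n",
    "rng = np.random.default_rng(0)\n",
    "W_s = rng.normal(size=(d_att, d_dec))\n",
    "W_h = rng.normal(size=(d_att, d_enc))\n",
    "v = rng.normal(size=d_att)\n",
    "\n",
    "s_dec = rng.normal(size=d_dec)       # previous decoder state\n",
    "H_enc = rng.normal(size=(6, d_enc))  # n = 6 encoder states\n",
    "\n",
    "# One unnormalized score per source position j\n",
    "scores = np.array([additive_score(s_dec, h, W_s, W_h, v) for h in H_enc])\n",
    "print(scores)"
   ]
  },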
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Another alternative:\n",
    "\n",
    "$$\n",
    "f_{\\mathrm{att}}(\\mathbf{s}_{t-1}^{\\textrm{dec}}, \\mathbf{h}_j^{\\textrm{enc}}) =\n",
    "\\frac{\\left(\\mathbf{s}_{t-1}^{\\textrm{dec}}\\right)^\\intercal\n",
    "\\mathbf{W} \\mathbf{h}_j^{\\textrm{enc}}}\n",
    "{\\sqrt{d_{\\mathbf{h}^{\\textrm{enc}}}}}\n",
    "$$\n",
    "\n",
    "This is called **scaled dot-product** attention.\n",
    "\n",
    "(But many alternatives have been proposed!)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "### What do we do with $\\mathbf{\\alpha}_t$?\n",
    "\n",
    "Computing a **context vector:**\n",
    "\n",
    "$$\n",
    "\\mathbf{c}_t = \\sum_{i=1}^n \\mathbf{\\alpha}_{t,i} \\mathbf{h}_i^\\mathrm{enc}\n",
    "$$\n",
    "\n",
    "This is the weighted combination of the input representations!"
   ]
  },
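  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "A small NumPy sketch of this step: a softmax over the scores from the attention function gives the attention weights, and the context vector is the weighted sum of the encoder states. The values and shapes below are toy assumptions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(x):\n",
    "    x = x - x.max()          # subtract max for numerical stability\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum()\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "H_enc = rng.normal(size=(6, 3))  # n = 6 encoder states of dimension 3\n",
    "scores = rng.normal(size=6)      # attention scores for one decoder timestep\n",
    "\n",
    "alpha = softmax(scores)          # attention weights over source positions, sum to 1\n",
    "context = alpha @ H_enc          # context vector c_t of dimension 3\n",
    "print(alpha.sum(), context)"
   ]
  },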
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "\n",
    "Include this context vector in the calculation of decoder's hidden state:\n",
    "\n",
    "$$\n",
    "\\mathbf{s}_t^{\\textrm{dec}} = f\\left(\\mathbf{s}_{t-1}^{\\textrm{dec}}, \\mathbf{y}_{t-1}^\\textrm{dec}, \\mathbf{c}_t\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "<center>\n",
    "    <img src=\"mt_figures/encdec_att.svg\" width=\"22%\" />\n",
    "</center>\n",
    "\n",
    "> Intuitively, this implements a mechanism of attention in the decoder.  The decoder **decides parts of the source sentence to pay attention to.**  By letting the decoder have an attention mechanism, we relieve the encoder from the burden of having to encode all information in the source sentence into a fixed-length vector.\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1409.0473.pdf\">Bahdanau et al., 2014</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "We can visualize what this model learns in an\n",
    "\n",
    "## Attention matrix\n",
    "\n",
    "&nbsp;\n",
    "\n",
    "$\\rightarrow$ Simply concatenate all $\\alpha_t$ for $1 \\leq t \\leq m$"
   ]
  },
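  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "A sketch of how such an attention matrix could be plotted with matplotlib; the sentence pair and the (random) weights below are stand-ins for a real model's output.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "src = ['Das', 'ist', 'ein', 'Haus', '.']   # toy source sentence\n",
    "tgt = ['This', 'is', 'a', 'house', '.']    # toy target sentence\n",
    "\n",
    "# Stand-in for the learned attention weights: one row alpha_t per target token\n",
    "rng = np.random.default_rng(0)\n",
    "A = rng.random((len(tgt), len(src)))\n",
    "A = A / A.sum(axis=1, keepdims=True)       # each row is a distribution over source tokens\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(A, cmap='gray_r')\n",
    "ax.set_xticks(range(len(src)))\n",
    "ax.set_xticklabels(src)\n",
    "ax.set_yticks(range(len(tgt)))\n",
    "ax.set_yticklabels(tgt)\n",
    "ax.set_xlabel('source')\n",
    "ax.set_ylabel('target')\n",
    "plt.show()"
   ]
  },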
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "<center>\n",
    "    <img src=\"mt_figures/att_matrix.png\" width=\"40%\" />\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1409.0473.pdf\">Bahdanau et al., 2014</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### An important caveat\n",
    "\n",
    "+ The attention mechanism was motivated by the idea of aligning inputs & outputs\n",
    "+ Attention matrices often correspond to human intuitions about alignment\n",
    "+ But ***producing a sensible alignment is not a training objective!***\n",
    "\n",
    "In other words:\n",
    "\n",
    "+ Do not expect that attention weights will *necessarily* correspond to sensible alignments!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Another way to think about attention\n",
    "\n",
    "1. $\\color{purple}{\\mathbf{s}_{t-1}^{\\textrm{dec}}}$ is the **query**\n",
    "2. Retrieve the best $\\mathbf{h}_j^{\\textrm{enc}}$ by taking\n",
    "$\\color{orange}{\\mathbf{W} \\mathbf{h}_j^{\\textrm{enc}}}$ as the **key**\n",
    "3. Softly select a $\\color{blue}{\\mathbf{h}_j^{\\textrm{enc}}}$ as **value**\n",
    "\n",
    "$$\n",
    "\\mathbf{\\alpha}_{t,j} = \\text{softmax}\\left(\n",
    "\\frac{\\left(\\color{purple}{\\mathbf{s}_{t-1}^{\\textrm{dec}}}\\right)^\\intercal\n",
    "\\color{orange}{\\mathbf{W} \\mathbf{h}_j^{\\textrm{enc}}}}\n",
    "{\\sqrt{d_{\\mathbf{h}^{\\textrm{enc}}}}}\n",
    "\\right) \\\\\n",
    "\\mathbf{c}_t = \\sum_{i=1}^n \\mathbf{\\alpha}_{t,i} \\color{blue}{\\mathbf{h}_i^\\mathrm{enc}} \\\\\n",
    "\\mathbf{s}_t^{\\textrm{dec}} = f\\left(\\mathbf{s}_{t-1}^{\\textrm{dec}}, \\mathbf{y}_{t-1}^\\textrm{dec}, \\mathbf{c}_t\\right)\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Used during **decoding** to attend to $\\mathbf{h}^{\\textrm{enc}}$, encoded by Bi-LSTM.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "More recent development:\n",
    "\n",
    "### Transformer models\n",
    "\n",
    "+ Described in Vaswani et al. (2017) paper famously titled [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)\n",
    "+ Gets rid of RNNs, uses attention calculations everywhere (also called **self-attention**)\n",
    "+ Used in most current state-of-the-art NMT models\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## Self-attention\n",
    "\n",
    "Forget about Bi-LSTMs, because Attention is All You Need even for **encoding**!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "All encoder tokens attend to each other:\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/t/transformer_self-attention_visualization.png\" width=30%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-transformer/\">The Illustrated Transformer</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Scaled Dot-Product Attention\n",
    "\n",
    "Use $\\mathbf{h}_i$ to create three vectors:\n",
    "$\\color{purple}{\\mathbf{q}_i}=W^q\\mathbf{h}_i,\n",
    "\\color{orange}{\\mathbf{k}_i}=W^k\\mathbf{h}_i,\n",
    "\\color{blue}{\\mathbf{v}_i}=W^v\\mathbf{h}_i$.\n",
    "\n",
    "$$\n",
    "\\mathbf{\\alpha}_{i,j} = \\text{softmax}\\left(\n",
    "\\frac{\\color{purple}{\\mathbf{q}_i}^\\intercal\n",
    "\\color{orange}{\\mathbf{k}_j}}\n",
    "{\\sqrt{d_{\\mathbf{h}}}}\n",
    "\\right) \\\\\n",
    "\\mathbf{h}_i^\\prime = \\sum_{j=1}^n \\mathbf{\\alpha}_{i,j} \\color{blue}{\\mathbf{v}_j}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "In matrix form:\n",
    "\n",
    "$$\n",
    "\\text{softmax}\\left(\n",
    "\\frac{\\color{purple}{Q}\n",
    "\\color{orange}{K}^\\intercal}\n",
    "{\\sqrt{d_{\\mathbf{h}}}}\n",
    "\\right) \\color{blue}{V}\n",
    "$$"
   ]
  },
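  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "A compact NumPy sketch of single-head scaled dot-product self-attention in this matrix form; all shapes and weight matrices are toy assumptions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(x, axis=-1):\n",
    "    x = x - x.max(axis=axis, keepdims=True)\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum(axis=axis, keepdims=True)\n",
    "\n",
    "def self_attention(H, W_q, W_k, W_v):\n",
    "    # H: (n, d) matrix with one row per input token\n",
    "    Q, K, V = H @ W_q, H @ W_k, H @ W_v\n",
    "    A = softmax(Q @ K.T / np.sqrt(K.shape[-1]))  # (n, n) attention weights\n",
    "    return A @ V                                 # new contextual representations\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n, d = 5, 8\n",
    "H = rng.normal(size=(n, d))\n",
    "W_q, W_k, W_v = [rng.normal(size=(d, d)) for _ in range(3)]\n",
    "print(self_attention(H, W_q, W_k, W_v).shape)  # (5, 8)"
   ]
  },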
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "<center>\n",
    "    <img src=\"mt_figures/self_att.png\" width=25%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1706.03762.pdf\">Vaswani et al., 2017</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Multi-head self-attention\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/multi_head_self_att.png\" width=30%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1706.03762.pdf\">Vaswani et al., 2017</a>)\n",
    "</div>"
   ]
  },
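  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "Multi-head attention runs several such heads in parallel on smaller projections and concatenates their outputs; a minimal self-contained sketch (the head count and dimensions are arbitrary choices).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def softmax(x, axis=-1):\n",
    "    x = x - x.max(axis=axis, keepdims=True)\n",
    "    e = np.exp(x)\n",
    "    return e / e.sum(axis=axis, keepdims=True)\n",
    "\n",
    "def attention(Q, K, V):\n",
    "    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V\n",
    "\n",
    "def multi_head_self_attention(H, heads, W_o):\n",
    "    # heads: list of (W_q, W_k, W_v) tuples, one per attention head\n",
    "    outputs = [attention(H @ W_q, H @ W_k, H @ W_v) for W_q, W_k, W_v in heads]\n",
    "    return np.concatenate(outputs, axis=-1) @ W_o  # concatenate heads, then project back\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n, d, n_heads, d_head = 5, 8, 2, 4\n",
    "H = rng.normal(size=(n, d))\n",
    "heads = [tuple(rng.normal(size=(d, d_head)) for _ in range(3)) for _ in range(n_heads)]\n",
    "W_o = rng.normal(size=(n_heads * d_head, d))\n",
    "print(multi_head_self_attention(H, heads, W_o).shape)  # (5, 8)"
   ]
  },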
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Transformer layer\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/transformer_layer.png\" width=30%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1706.03762.pdf\">Vaswani et al., 2017</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Transformer\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/transformer.png\" width=30%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1706.03762.pdf\">Vaswani et al., 2017</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Long-distance dependencies\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/ldd.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1706.03762.pdf\">Vaswani et al., 2017</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "Unlike RNNs, no inherent locality bias!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Transformers for decoding\n",
    "\n",
    "Attends to encoded input *and* to partial output.\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/xlnet/transformer-encoder-decoder.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-gpt2/\">The Illustrated GPT-2</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Can only attend to already-generated tokens.\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/gpt2/self-attention-and-masked-self-attention.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-gpt2/\">The Illustrated GPT-2</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "The encoder transformer is sometimes called \"bidirectional transformer\"."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## BERT\n",
    "\n",
    "[Devlin et al., 2019](https://www.aclweb.org/anthology/N19-1423.pdf):\n",
    "**B**idirectional **E**ncoder **R**epresentations from **T**ransformers.\n",
    "\n",
    "<center>\n",
    "    <img src=\"https://miro.medium.com/max/300/0*2XpE-VjhhLGkFDYg.jpg\" width=40%/>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### BERT architecture\n",
    "\n",
    "Transformer with $L$ layers of dimension $H$, and $A$ self-attention heads.\n",
    "\n",
    "* BERT$_\\mathrm{BASE}$: $L=12, H=768, A=12$\n",
    "* BERT$_\\mathrm{LARGE}$: $L=24, H=1024, A=16$\n",
    "\n",
    "Other pre-trained checkpoints: https://github.com/google-research/bert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "Trained on 16GB of text from Wikipedia + BookCorpus.\n",
    "\n",
    "* BERT$_\\mathrm{BASE}$: 4 TPUs for 4 days\n",
    "* BERT$_\\mathrm{LARGE}$: 16 TPUs for 4 days"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Training objective (1): masked language model\n",
    "\n",
    "Predict masked words given context on both sides:\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/BERT-language-modeling-masked-lm.png\" width=50%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-bert/\">The Illustrated BERT</a>)\n",
    "</div>"
   ]
  },
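  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "To get a feel for the masked-LM objective, one could query a pre-trained BERT through the Hugging Face `transformers` library (an external dependency not installed by this notebook; the model is downloaded on first use).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "# Assumes: pip install transformers torch  (external dependencies)\n",
    "from transformers import pipeline\n",
    "\n",
    "fill_mask = pipeline('fill-mask', model='bert-base-uncased')\n",
    "for prediction in fill_mask('Attention is all you [MASK].'):\n",
    "    print(prediction['token_str'], round(prediction['score'], 3))"
   ]
  },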
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "<center>\n",
    "<a href=\"slides/mlm.pdf\"><img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/0/08/Sesame_Street_logo.svg/500px-Sesame_Street_logo.svg.png\"></a>\n",
    "</center>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Training objective (2): next sentence prediction\n",
    "\n",
    "**Conditional encoding** of both sentences:\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/bert-next-sentence-prediction.png\" width=60%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-bert/\">The Illustrated BERT</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### How is that different from ELMo and GPT-$n$?\n",
    "\n",
    "<center>\n",
    "    <img src=\"mt_figures/bert_gpt_elmo.png\" width=100%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://www.aclweb.org/anthology/N19-1423.pdf\">Devlin et al., 2019</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Not words, but WordPieces\n",
    "\n",
    "<center>\n",
    "    <img src=\"https://vamvas.ch/assets/bert-for-ner/tokenizer.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://vamvas.ch/bert-for-ner\">BERT for NER</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "* 30,000 WordPiece vocabulary\n",
    "* No unknown words!"
   ]
  },
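  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "A quick way to inspect WordPiece segmentation is BERT's own tokenizer from the Hugging Face `transformers` library (an external dependency; `bert-base-uncased` is one of the standard checkpoints).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "# Assumes: pip install transformers  (external dependency)\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "# Rare words are split into subword pieces marked with '##', so nothing is out-of-vocabulary\n",
    "print(tokenizer.tokenize('Attention mechanisms revolutionized NLP'))"
   ]
  },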
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Using BERT\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/bert-tasks.png\" width=60%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://www.aclweb.org/anthology/N19-1423.pdf\">Devlin et al., 2019</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Feature extraction (❄️) vs. fine-tuning (🔥)\n",
    "\n",
    "<center>\n",
    "    <img src=\"https://d3i71xaburhd42.cloudfront.net/8659bf379ca8756755125a487c43cfe8611ce842/1-Table1-1.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://www.aclweb.org/anthology/W19-4302.pdf\">Peters et al. 2019</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Don't stop pretraining!\n",
    "\n",
    "<center>\n",
    "    <img src=\"https://d3i71xaburhd42.cloudfront.net/e816f788767eec6a8ef0ea9eddd0e902435d4271/1-Figure1-1.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://www.aclweb.org/anthology/2020.acl-main.740.pdf\">Gururangan et al. 2020</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Which layer to use?\n",
    "\n",
    "<center>\n",
    "    <img src=\"http://jalammar.github.io/images/bert-feature-extraction-contextualized-embeddings.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"http://jalammar.github.io/illustrated-bert/\">The Illustrated BERT</a>)\n",
    "</div>"
   ]
  },
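  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "A sketch of feature extraction with the `transformers` library: run BERT with `output_hidden_states=True` and pick or combine layers. Which layer works best depends on the task; the concatenation of the last four layers below is just one common choice.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "# Assumes: pip install transformers torch  (external dependencies)\n",
    "import torch\n",
    "from transformers import BertModel, BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)\n",
    "model.eval()\n",
    "\n",
    "inputs = tokenizer('A small example sentence', return_tensors='pt')\n",
    "with torch.no_grad():\n",
    "    outputs = model(**inputs)\n",
    "\n",
    "hidden_states = outputs.hidden_states  # embedding layer + one tensor per transformer layer\n",
    "print(len(hidden_states), hidden_states[-1].shape)\n",
    "\n",
    "# One common choice: concatenate the last four layers as token features\n",
    "features = torch.cat(hidden_states[-4:], dim=-1)\n",
    "print(features.shape)"
   ]
  },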
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### RoBERTa\n",
    "\n",
    "[Liu et al., 2019](https://arxiv.org/pdf/1907.11692.pdf): bigger is better.\n",
    "\n",
    "BERT with additionally\n",
    "\n",
    "- CC-News (76GB)\n",
    "- OpenWebText (38GB)\n",
    "- Stories (31GB)\n",
    "\n",
    "and **no** next-sentence-prediction task (only masked LM).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "Training: 1024 GPUs for one day."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Multilingual BERT\n",
    "\n",
    "* One model pre-trained on 104 languages with the largest Wikipedias\n",
    "* 110k *shared* WordPiece vocabulary\n",
    "* Same architecture as BERT$_\\mathrm{BASE}$: $L=12, H=768, A=12$\n",
    "* Same training objectives, **no cross-lingual signal**\n",
    "\n",
    "https://github.com/google-research/bert/blob/master/multilingual.md"
   ]
  },
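  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "source": [
    "Because mBERT uses a single shared WordPiece vocabulary, one tokenizer covers all 104 languages; a small illustration (external `transformers` dependency assumed, the sentences are toy examples).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [],
   "source": [
    "# Assumes: pip install transformers  (external dependency)\n",
    "from transformers import BertTokenizer\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')\n",
    "for sentence in ['This is a house.', 'Das ist ein Haus.', 'Dette er et hus.']:\n",
    "    print(tokenizer.tokenize(sentence))"
   ]
  },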
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "<center>\n",
    "    <img src=\"https://d3i71xaburhd42.cloudfront.net/5d8beeca1a2e3263b2796e74e2f57ffb579737ee/3-Figure1-1.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://arxiv.org/pdf/1911.03310.pdf\">Libovický et al., 2019</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Other multilingual transformers\n",
    "\n",
    "+ XLM ([Lample and Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf)) additionally uses an MT objective\n",
    "+ DistilmBERT ([Sanh et al., 2020](https://arxiv.org/pdf/1910.01108.pdf)) is a lighter version of mBERT\n",
    "+ Many monolingual BERTs for languages other than English\n",
    "([CamemBERT](https://arxiv.org/pdf/1911.03894.pdf),\n",
    "[BERTje](https://arxiv.org/pdf/1912.09582),\n",
    "[Nordic BERT](https://github.com/botxo/nordic_bert)...)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "### Zero-shot cross-lingual transfer\n",
    "\n",
    "1. Pre-train (or download) mBERT\n",
    "2. Fine-tune on a task in one language (e.g., English)\n",
    "3. Test on the same task in another language\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "mBERT is unreasonably effective at cross-lingual transfer!\n",
    "\n",
    "NER F1:\n",
    "<center>\n",
    "    <img src=\"https://d3i71xaburhd42.cloudfront.net/809cc93921e4698bde891475254ad6dfba33d03b/2-Table1-1.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "POS accuracy:\n",
    "<center>\n",
    "    <img src=\"https://d3i71xaburhd42.cloudfront.net/809cc93921e4698bde891475254ad6dfba33d03b/2-Table2-1.png\" width=80%/>\n",
    "</center>\n",
    "\n",
    "<div style=\"text-align: right;\">\n",
    "    (from <a href=\"https://www.aclweb.org/anthology/P19-1493.pdf\">Pires et al., 2019</a>)\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "Why? (poll)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "See also [K et al., 2020](https://arxiv.org/pdf/1912.07840.pdf);\n",
    "[Wu and Dredze., 2019](https://www.aclweb.org/anthology/D19-1077.pdf).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Summary\n",
    "\n",
    "+ The **attention mechanism** alleviates the encoding bottleneck in encoder-decoder architectures\n",
    "\n",
    "+ Attention can even replace (bi)-LSTMs, giving **self-attention**\n",
    "\n",
    "+ **Transformers** rely on self-attention for encoding and decoding\n",
    "\n",
    "+ **BERT**, GPT-$n$ and other transformers are powerful pre-trained contextualized representations\n",
    "\n",
    "+ **Multilingual** pre-trained transformers enable zero-shot cross-lingual transfer\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    },
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "## Further reading\n",
    "\n",
    "* Attention:\n",
    "  + Lilian Weng's blog post [Attention? Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)\n",
    "\n",
    "\n",
    "* Transformers\n",
    "  + Jay Alammar's blog posts:\n",
    "    + [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)\n",
    "    + [The Illustrated GPT-2](http://jalammar.github.io/illustrated-gpt2/)\n",
    "    + [The Illustrated BERT](http://jalammar.github.io/illustrated-bert/)\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}