{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "pycharm": { "name": "#%%\n" }, "slideshow": { "slide_type": "skip" } }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "%%html\n", "\n", "
\n", "" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "pycharm": { "name": "#%%\n" }, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "%%capture\n", "import sys\n", "sys.path.append(\"..\")\n", "import statnlpbook.util as util\n", "import matplotlib\n", "matplotlib.rcParams['figure.figsize'] = (10.0, 6.0)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "skip" } }, "source": [ "\n", "$$\n", "\\newcommand{\\Xs}{\\mathcal{X}}\n", "\\newcommand{\\Ys}{\\mathcal{Y}}\n", "\\newcommand{\\y}{\\mathbf{y}}\n", "\\newcommand{\\balpha}{\\boldsymbol{\\alpha}}\n", "\\newcommand{\\bbeta}{\\boldsymbol{\\beta}}\n", "\\newcommand{\\aligns}{\\mathbf{a}}\n", "\\newcommand{\\align}{a}\n", "\\newcommand{\\source}{\\mathbf{s}}\n", "\\newcommand{\\target}{\\mathbf{t}}\n", "\\newcommand{\\ssource}{s}\n", "\\newcommand{\\starget}{t}\n", "\\newcommand{\\repr}{\\mathbf{f}}\n", "\\newcommand{\\repry}{\\mathbf{g}}\n", "\\newcommand{\\x}{\\mathbf{x}}\n", "\\newcommand{\\prob}{p}\n", "\\newcommand{\\bar}{\\,|\\,}\n", "\\newcommand{\\vocab}{V}\n", "\\newcommand{\\params}{\\boldsymbol{\\theta}}\n", "\\newcommand{\\param}{\\theta}\n", "\\DeclareMathOperator{\\perplexity}{PP}\n", "\\DeclareMathOperator{\\argmax}{argmax}\n", "\\DeclareMathOperator{\\argmin}{argmin}\n", "\\newcommand{\\train}{\\mathcal{D}}\n", "\\newcommand{\\counts}[2]{\\#_{#1}(#2) }\n", "\\newcommand{\\length}[1]{\\text{length}(#1) }\n", "\\newcommand{\\indi}{\\mathbb{I}}\n", "$$" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "pycharm": { "name": "#%%\n" }, "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "%load_ext tikzmagic" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "skip" } }, "outputs": [], "source": [ "from IPython.display import Image\n", "import random" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Attention" ] }, { 
"cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "+ Natural language inference\n", "+ Attention mechansim\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "
\n", "\n", "
\n", "\n", "**Given:**\n", "There are six bears. Three brown bears, a black bear and a pink bear run along the grass.\n", "\n", "Which of the following is correct?\n", "1. Some bears run\n", "2. All bears sit\n", "3. One bear sits" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Task: Natural Language Inference\n", "\n", "Determining the logical relationship between two sentences, a **premise** and a **hypothesis**.\n", "\n", "Also known as *Recognising Textual Entailment* ([Dagan et al., 2005](http://u.cs.biu.ac.il/~nlp/downloads/publications/RTEChallenge.pdf)).\n", "\n", "We define entailment as:\n", "P entails H if a human reading P would typically infer that H is most likely true." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "- (Pairwise) sequence classification task\n", "- Requires commonsense and world knowledge\n", "- Requires general natural language understanding\n", "- Requires fine-grained reasoning" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "> **P:** “Google files for its long awaited IPO.”\n", "> **H:** “Google goes public.”" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Positive ($\\Rightarrow$, entails)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Stanford Natural Language Inference (SNLI) dataset\n", "\n", "Crowdsourced annotations for 570K sentence pairs using image captions ([Bowman et al., 2015](https://www.aclweb.org/anthology/D15-1075.pdf)).\n", "\n", "**P**: A wedding party taking pictures\n", "- **H:** There is a funeral\t\t\t\t\t: **Contradiction** ($\\Rightarrow\\neg$)\n", "- **H:** They are outside\t\t\t\t\t : **Neutral** (?)\n", "- **H:** Someone got married\t\t\t\t : **Entailment** ($\\Rightarrow$)\n", "\n", " " ] }, { "cell_type": "markdown", 
"metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Representing sentences as vectors\n", "\n", "1. Encode premise and hypothesis\n", "2. Concatenate the representations\n", "3. Classify with MLP\n", "\n", "
\n", "\n", "
\n", "\n", "([Image source: Bowman et al., 2015](https://www.aclweb.org/anthology/D15-1075))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "How to represent a sentence with a vector?" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "The same LSTM encodes the premise and hypothesis.\n", "\n", " " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "Use the last hidden vectors of the LSTM as sentence representations.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### SNLI results\n", "\n", "| Model | Accuracy |\n", "|---|---|\n", "| LSTM | 77.6 |" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Problem 1: \n", "\n", "Asymmetry of premise and hypothesis." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "
\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Conditional encoding\n", "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### SNLI results\n", "\n", "| Model | Accuracy |\n", "|---|---|\n", "| LSTM | 77.6 |\n", "| LSTMs with conditional encoding | 80.9 |" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Problem 2: global memory\n", "\n", "Some words are more important to focus on." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "### Attention\n", "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Attention mechanism\n", "\n", "+ Original motivation: machine translation ([Bahdanau et al., 2014](https://arxiv.org/abs/1409.0473)); see [later lecture in the course](nmt_slides_active.ipynb)\n", "\n", "#### Idea\n", "\n", "+ A **weighted sum** of encoder hidden states is a differentiable function and has a fixed dimension\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "### What is happening here?\n", "\n", "For the final prediction,\n", "+ Attention 
takes all premise hidden vectors $(\\mathbf{h}_1, \\ldots, \\mathbf{h}_n)$ as well as the final hypothesis hidden vector ($\\mathbf{h}_N$) as input\n", "+ Calculates probability distribution $\\alpha$ over premise hidden vectors using a softmax\n", "+ Combines $\\mathbf{h}_N$ with an $\\alpha$-weighted average of all premise hidden vectors" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "More formally:\n", "\n", "
\n", "\\begin{align}\n", " \\mathbf{M} &= \\tanh(\\mathbf{W}^y\\mathbf{Y}+\\mathbf{W}^h\\mathbf{h}_N \\otimes \\mathbf{e}_L) & \\mathbf{M} &\\in\\mathbb{R}^{k\\times L}\\\\\n", " \\alpha &= \\text{softmax}(\\mathbf{w}^T\\mathbf{M})&\\alpha&\\in\\mathbb{R}^L\\\\\n", " \\mathbf{r} &= \\mathbf{Y}\\alpha^T &\\mathbf{r}&\\in\\mathbb{R}^k\\\\\n", " \\mathbf{h}^* &= \\tanh(\\mathbf{W}^p\\mathbf{r} + \\mathbf{W}^x\\mathbf{h}_N) & \\mathbf{h}^* &\\in\\mathbb{R}^{k}\n", "\\end{align}\n", "
\n", "\n", "where\n", "\n", "* $\\mathbf{Y}\\in\\mathbb{R}^{k\\times L}$ is the concatenation of all premise hidden vectors\n", "* $\\mathbf{W}^y$, $\\mathbf{W}^h$, $\\mathbf{W}^p$, $\\mathbf{W}^x$ $\\in\\mathbb{R}^{k\\times k}$ are trained projection matrices\n", "* $\\mathbf{w}\\in\\mathbb{R}^k$ is a trained parameter vector\n", "* $\\mathbf{e}_L\\in\\mathbb{R}^L$ is a vector of ones, so the outer product $\\otimes\\,\\mathbf{e}_L$ repeats $\\mathbf{W}^h\\mathbf{h}_N$ across all $L$ columns\n", "* $\\alpha\\in\\mathbb{R}^L$ is the attention probability distribution\n", "* $\\mathbf{r}\\in\\mathbb{R}^k$ is the weighted representation of the premise" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### SNLI results\n", "\n", "| Model | Accuracy |\n", "|---|---|\n", "| LSTM | 77.6 |\n", "| LSTMs with conditional encoding | 80.9 |\n", "| LSTMs with conditional encoding + attention | 82.3 |" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Problem 3: representation bottleneck\n", "\n", "> You can’t cram the meaning of a whole\n", "`%&!$#` sentence into a single `$&!#*` vector!\n", ">\n", "> -- Raymond J. 
Mooney" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Alignment\n", "\n", "+ Non-neural models often use **alignment** between sequences\n", "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Word-by-word Attention\n", "\n", "+ Computing attention for each hypothesis token can give us a **soft alignment**\n", "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "### What is happening here?\n", "\n", "**For each hypothesis token $x_t$,**\n", "+ Attention takes all premise hidden vectors $(\\mathbf{h}_1, \\ldots, \\mathbf{h}_n)$ as well as the current hypothesis hidden vector ($\\mathbf{h}_t$) as input\n", "+ Generates probability distribution $\\alpha_t$ over all premise hidden vectors\n", "+ Uses a weighted average (by $\\alpha_t$) of all premise hidden vectors as input for the next layer" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "#### SNLI results\n", "\n", "| Model | Accuracy |\n", "|---|---|\n", "| LSTM | 77.6 |\n", "| LSTMs with conditional encoding | 80.9 |\n", "| LSTMs with conditional encoding + attention | 82.3 |\n", "| LSTMs with word-by-word attention | 83.5 |\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "More formally:\n", "\n", "
\n", "\\begin{align}\n", " \\mathbf{M}_t &= \\tanh(\\mathbf{W}^y\\mathbf{Y}+(\\mathbf{W}^h\\mathbf{h}_t+\\mathbf{W}^r\\mathbf{r}_{t-1})\\mathbf{1}^T_L) & \\mathbf{M}_t &\\in\\mathbb{R}^{k\\times L}\\\\\n", " \\alpha_t &= \\text{softmax}(\\mathbf{w}^T\\mathbf{M}_t)&\\alpha_t&\\in\\mathbb{R}^L\\\\\n", " \\mathbf{r}_t &= \\mathbf{Y}\\alpha^T_t + \\tanh(\\mathbf{W}^t\\mathbf{r}_{t-1})&\\mathbf{r}_t&\\in\\mathbb{R}^k\n", "\\end{align}\n", "
\n", "\n", "where\n", "\n", "* $\\mathbf{Y}\\in\\mathbb{R}^{k\\times L}$ is the concatenation of all premise hidden vectors\n", "* $\\mathbf{W}^y$, $\\mathbf{W}^h$, $\\mathbf{W}^r$, $\\mathbf{W}^t \\in\\mathbb{R}^{k\\times k}$ are trained projection matrices\n", "* $\\mathbf{w}\\in\\mathbb{R}^k$ is a trained parameter vector\n", "* $\\alpha_t\\in\\mathbb{R}^L$ is the attention probability distribution\n", "* $\\mathbf{r}_t\\in\\mathbb{R}^k$ is the weighted representation of the premise (dependent on $\\mathbf{r}_{t-1}$ to inform the model about what was attended over in the previous step)\n", "* Multiplying by $\\mathbf{1}^T_L$ is the same as repeating a column vector $L$ times, once per premise position" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "fragment" } }, "source": [ "Final pairwise sentence representation:\n", "\n", "
\n", "\\begin{align}\n", " \\mathbf{h}^{*} &= \\text{tanh} (\\mathbf{W}^p\\mathbf{r}_N + \\mathbf{W}^x\\mathbf{h}_N)\n", "\\end{align}\n", "
\n", "\n", "Non-linear combination of the attention-weighted representation $\\mathbf{r}_N$ and the last output vector $\\mathbf{h}_N$, where $\\mathbf{h}^{*} \\in\\mathbb{R}^{k}$ " ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Attention matrix\n", "$\\alpha_{ij}$ for all premise tokens $j$ and hypothesis tokens $i$:\n", "\n", "\n", "\n", "([Image source: Rocktäschel et al., 2015](https://arxiv.org/abs/1509.06664))" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### An important caveat\n", "\n", "+ The attention mechanism was motivated by the idea of aligning inputs & outputs\n", "+ Attention matrices often correspond to human intuitions about alignment\n", "+ But ***producing a sensible alignment is not a training objective!***\n", "\n", "In other words:\n", "\n", "+ Do not expect that attention weights will *necessarily* correspond to sensible alignments!" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "### Problem 4: attention only in one direction\n", "\n", "Hypothesis tokens attend to premise tokens.\n", "\n", "Why don't hypothesis tokens also attend to other **hypothesis** tokens?\n", "\n", "Why don't premise tokens also attend to **hypothesis** tokens?\n", "\n", "Why don't premise tokens attend to other **premise** tokens?" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" }, "slideshow": { "slide_type": "slide" } }, "source": [ "## Summary\n", "\n", "+ The **attention mechanism** alleviates the encoding bottleneck in encoder-decoder architectures\n" ] }, { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" }, "slideshow": { "slide_type": "subslide" } }, "source": [ "## Further reading\n", "\n", "+ [Jurafsky & Martin Chapter 8, section 8.8](https://web.stanford.edu/~jurafsky/slp3/8.pdf)\n", "+ Lilian Weng's blog post [Attention? 
Attention!](https://lilianweng.github.io/lil-log/2018/06/24/attention-attention.html)\n", "+ Jay Alammar's blog post [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)\n", "\n", "\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 4 }