{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Natural language inference: Models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "__author__ = \"Christopher Potts\"\n", "__version__ = \"CS224u, Stanford, Spring 2019\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Contents\n", "\n", "1. [Contents](#Contents)\n", "1. [Overview](#Overview)\n", "1. [Set-up](#Set-up)\n", "1. [Sparse feature representations](#Sparse-feature-representations)\n", " 1. [Feature representations](#Feature-representations)\n", " 1. [Model wrapper](#Model-wrapper)\n", " 1. [Assessment](#Assessment)\n", "1. [Sentence-encoding models](#Sentence-encoding-models)\n", " 1. [Dense representations with a linear classifier](#Dense-representations-with-a-linear-classifier)\n", " 1. [Dense representations with a shallow neural network](#Dense-representations-with-a-shallow-neural-network)\n", " 1. [Sentence-encoding RNNs](#Sentence-encoding-RNNs)\n", " 1. [Other sentence-encoding model ideas](#Other-sentence-encoding-model-ideas)\n", "1. [Chained models](#Chained-models)\n", " 1. [Simple RNN](#Simple-RNN)\n", " 1. [Separate premise and hypothesis RNNs](#Separate-premise-and-hypothesis-RNNs)\n", "1. [Attention mechanisms](#Attention-mechanisms)\n", "1. [Error analysis with the MultiNLI annotations](#Error-analysis-with-the-MultiNLI-annotations)\n", "1. [Other findings](#Other-findings)\n", "1. [Exploratory exercises](#Exploratory-exercises)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Overview\n", "\n", "This notebook defines and explores a number of models for NLI. The general plot is familiar from [our work with the Stanford Sentiment Treebank](sst_01_overview.ipynb):\n", "\n", "1. Models based on sparse feature representations\n", "1. Linear classifiers and feed-forward neural classifiers using dense feature representations\n", "1. 
Recurrent and tree-structured neural networks\n", "\n", "The twist here is that, while NLI is another classification problem, the inputs have important high-level structure: __a premise__ and __a hypothesis__. This invites exploration of a host of neural model designs:\n", "\n", "* In __sentence-encoding__ models, the premise and hypothesis are analyzed separately, combined only for the final classification step.\n", "\n", "* In __chained__ models, the premise is processed first, then the hypothesis, giving a unified representation of the pair.\n", "\n", "NLI resembles sequence-to-sequence problems like __machine translation__ and __language modeling__. The central modeling difference is that NLI doesn't produce an output sequence, but rather consumes two sequences to produce a label. Still, there are enough affinities that many ideas have been shared among these fields." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set-up\n", "\n", "See [the previous notebook](nli_01_task_and_data.ipynb#Set-up) for set-up instructions for this unit. 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "from itertools import product\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.linear_model import LogisticRegression\n", "import torch\n", "import torch.nn as nn\n", "import torch.utils.data\n", "from torch_model_base import TorchModelBase\n", "from torch_rnn_classifier import TorchRNNClassifier, TorchRNNClassifierModel\n", "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n", "import nli\n", "import os\n", "import utils" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "GLOVE_HOME = os.path.join('data', 'glove.6B')\n", "\n", "DATA_HOME = os.path.join(\"data\", \"nlidata\")\n", "\n", "SNLI_HOME = os.path.join(DATA_HOME, \"snli_1.0\")\n", "\n", "MULTINLI_HOME = os.path.join(DATA_HOME, \"multinli_1.0\")\n", "\n", "ANNOTATIONS_HOME = os.path.join(DATA_HOME, \"multinli_1.0_annotations\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sparse feature representations\n", "\n", "We begin by looking at models based on sparse, hand-built feature representations. As in earlier units of the course, we will see that __these models are competitive__: easy to design, fast to optimize, and highly effective." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Feature representations\n", "\n", "The guiding idea for NLI sparse features is that one wants to knit together the premise and hypothesis, so that the model can learn about their relationships rather than just about each part separately." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `word_overlap_phi`, we just get the set of words that occur in both the premise and hypothesis." 
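] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To make the two feature functions defined below concrete, here is a small sketch on plain token lists. It assumes that `t.leaves()` yields a sentence's tokens in order, so the hand-written lists `prem` and `hyp` (hypothetical examples, not SNLI data) stand in for the trees. The overlap features keep each shared word once, while the cross-product features count every (premise word, hypothesis word) pair, repetitions included." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "from itertools import product\n", "\n", "# Hypothetical token lists standing in for `t1.leaves()` and `t2.leaves()`.\n", "prem = ['a', 'dog', 'runs', 'in', 'a', 'park']\n", "hyp = ['a', 'dog', 'moves']\n", "\n", "# Overlap: each word found in both sentences, mapped to 1.\n", "overlap = Counter(set(prem) & set(hyp))\n", "\n", "# Cross-product: every (premise word, hypothesis word) pair with its count.\n", "cross = Counter(product(prem, hyp))\n", "\n", "print(sorted(overlap.items()))  # [('a', 1), ('dog', 1)]\n", "# 15 distinct pairs, 18 pairs in total, and ('a', 'a') occurs twice:\n", "print(len(cross), sum(cross.values()), cross[('a', 'a')])  # 15 18 2"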
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def word_overlap_phi(t1, t2):\n", "    \"\"\"Basis for features for the words in both the premise and hypothesis.\n", "    This tends to produce very sparse representations.\n", "\n", "    Parameters\n", "    ----------\n", "    t1, t2 : `nltk.tree.Tree`\n", "        As given by `str2tree`.\n", "\n", "    Returns\n", "    -------\n", "    collections.Counter\n", "        Maps each word in both `t1` and `t2` to 1.\n", "\n", "    \"\"\"\n", "    overlap = set(t1.leaves()) & set(t2.leaves())\n", "    return Counter(overlap)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `word_cross_product_phi`, we count all the pairs $(w_{1}, w_{2})$ where $w_{1}$ is a word from the premise and $w_{2}$ is a word from the hypothesis. This creates a very large feature space. These models are very strong right out of the box, and they can be supplemented with more fine-grained features." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def word_cross_product_phi(t1, t2):\n", "    \"\"\"Basis for cross-product features. This tends to produce pretty\n", "    dense representations.\n", "\n", "    Parameters\n", "    ----------\n", "    t1, t2 : `nltk.tree.Tree`\n", "        As given by `str2tree`.\n", "\n", "    Returns\n", "    -------\n", "    collections.Counter\n", "        Maps each (w1, w2) in the cross-product of `t1.leaves()` and\n", "        `t2.leaves()` to its count. This is a multi-set cross-product\n", "        (repetitions matter).\n", "\n", "    \"\"\"\n", "    return Counter([(w1, w2) for w1, w2 in product(t1.leaves(), t2.leaves())])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model wrapper\n", "\n", "Our experiment framework is basically the same as the one we used for the Stanford Sentiment Treebank. Here, I use `utils.fit_classifier_with_crossvalidation` to create a wrapper around `LogisticRegression` for cross-validation of hyperparameters. 
At this point, I am not sure what parameters will be good for our NLI datasets, so this hyperparameter search is vital." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def fit_softmax_with_crossvalidation(X, y):\n", "    \"\"\"A MaxEnt model of the dataset with hyperparameter cross-validation.\n", "\n", "    Parameters\n", "    ----------\n", "    X : 2d np.array\n", "        The matrix of features, one example per row.\n", "\n", "    y : list\n", "        The list of labels for rows in `X`.\n", "\n", "    Returns\n", "    -------\n", "    sklearn.linear_model.LogisticRegression\n", "        A trained model instance, the best model found.\n", "\n", "    \"\"\"\n", "    basemod = LogisticRegression(\n", "        fit_intercept=True,\n", "        solver='liblinear',\n", "        multi_class='auto')\n", "    cv = 3\n", "    param_grid = {'C': [0.4, 0.6, 0.8, 1.0],\n", "                  'penalty': ['l1', 'l2']}\n", "    best_mod = utils.fit_classifier_with_crossvalidation(\n", "        X, y, basemod, cv, param_grid)\n", "    return best_mod" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Assessment\n", "\n", "Because SNLI and MultiNLI are huge, we can't afford to do experiments on the full datasets all the time. Thus, we will mainly work within the training sets, using the train readers to sample smaller datasets that can then be divided for training and assessment.\n", "\n", "Here, we sample 10% of the training examples. I set the random seed (`random_state=42`) so that we get consistency across the samples; setting `random_state=None` will give new random samples each time." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "train_reader = nli.SNLITrainReader(\n", " SNLI_HOME, samp_percentage=0.10, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An experimental dataset can be built directly from the reader and a feature function:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "dataset = nli.build_dataset(train_reader, word_overlap_phi)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['X', 'y', 'vectorizer', 'raw_examples'])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset.keys()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, it's more efficient to use `nli.experiment` to bring all these pieces together. This wrapper will work for all the models we consider." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best params: {'C': 0.8, 'penalty': 'l2'}\n", "Best score: 0.415\n", "Accuracy: 0.437\n", " precision recall f1-score support\n", "\n", "contradiction 0.445 0.673 0.536 5509\n", " entailment 0.462 0.382 0.418 5545\n", " neutral 0.387 0.256 0.309 5482\n", "\n", " micro avg 0.437 0.437 0.437 16536\n", " macro avg 0.432 0.437 0.421 16536\n", " weighted avg 0.432 0.437 0.421 16536\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", " train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10), \n", " phi=word_overlap_phi,\n", " train_func=fit_softmax_with_crossvalidation,\n", " assess_reader=None,\n", " random_state=42)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best params: {'C': 0.4, 'penalty': 'l1'}\n", "Best score: 0.614\n", "Accuracy: 0.626\n", " precision recall f1-score support\n", 
"\n", "contradiction 0.669 0.640 0.654 5473\n", " entailment 0.611 0.685 0.646 5356\n", " neutral 0.601 0.557 0.578 5558\n", "\n", " micro avg 0.626 0.626 0.626 16387\n", " macro avg 0.627 0.627 0.626 16387\n", " weighted avg 0.627 0.626 0.626 16387\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", "    train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10),\n", "    phi=word_cross_product_phi,\n", "    train_func=fit_softmax_with_crossvalidation,\n", "    assess_reader=None,\n", "    random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, `word_cross_product_phi` is very strong. Let's take the hyperparameters chosen there and use them for an experiment in which we train on the entire training set and evaluate on the dev set; this seems like a good way to balance responsible search over hyperparameters with our resource limitations." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "def fit_softmax_classifier_with_preselected_params(X, y):\n", "    mod = LogisticRegression(\n", "        fit_intercept=True,\n", "        penalty='l1',\n", "        solver='saga',  # The saga solver supports penalty='l1'.\n", "        multi_class='ovr',\n", "        C=0.4)\n", "    mod.fit(X, y)\n", "    return mod" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "contradiction 0.762 0.729 0.745 3278\n", " entailment 0.708 0.795 0.749 3329\n", " neutral 0.716 0.657 0.685 3235\n", "\n", " avg / total 0.729 0.728 0.727 9842\n", "\n", "CPU times: user 19min 17s, sys: 9.56 s, total: 19min 26s\n", "Wall time: 19min 26s\n" ] } ], "source": [ "%%time\n", "_ = nli.experiment(\n", "    train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=1.0),\n", "    assess_reader=nli.SNLIDevReader(SNLI_HOME, samp_percentage=1.0),\n", "    phi=word_cross_product_phi,\n", "    train_func=fit_softmax_classifier_with_preselected_params,\n", 
random_state=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This baseline is very similar to the one established in [the original SNLI paper by Bowman et al.](https://aclanthology.info/papers/D15-1075/d15-1075) for models like this one." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sentence-encoding models\n", "\n", "We turn now to sentence-encoding models. The hallmark of these is that the premise and hypothesis get their own representation in some sense, and then those representations are combined to predict the label. [Bowman et al. 2015](http://aclweb.org/anthology/D/D15/D15-1075.pdf) explore models of this form as part of introducing SNLI.\n", "\n", "The feed-forward networks we used in [the word-level bake-off](nli_wordentail_bakeoff.ipynb) are members of this family of models: each word was represented separately, and the concatenation of those representations was used as the input to the model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dense representations with a linear classifier\n", "\n", "Perhaps the simplest sentence-encoding model sums (or averages, etc.) the word representations for the premise, does the same for the hypothesis, and concatenates those two representations for use as the input to a linear classifier. \n", "\n", "Here's a diagram that is meant to suggest the full space of models of this form:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's an implementation of this model where \n", "\n", "* The embedding is GloVe.\n", "* The word representations are summed.\n", "* The premise and hypothesis vectors are concatenated.\n", "* A softmax classifier is used at the top." 
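] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a toy illustration of this recipe, here is a sketch that uses a made-up 3-dimensional lookup instead of GloVe. The vectors in `toy_lookup` and the helper `toy_sentence_vec` are hypothetical; like the `_get_tree_vecs` helper defined below, the helper skips words missing from the lookup and falls back to a zero vector:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Made-up 3-d vectors for illustration only (not real GloVe values).\n", "toy_lookup = {\n", "    'dog': np.array([1.0, 0.0, 2.0]),\n", "    'runs': np.array([0.0, 1.0, 1.0]),\n", "    'animal': np.array([1.0, 1.0, 0.0])}\n", "\n", "def toy_sentence_vec(words, lookup):\n", "    # Sum the vectors of the in-vocabulary words; zeros if none are known.\n", "    vecs = [lookup[w] for w in words if w in lookup]\n", "    if not vecs:\n", "        return np.zeros(len(next(iter(lookup.values()))))\n", "    return np.sum(vecs, axis=0)\n", "\n", "prem_vec = toy_sentence_vec(['the', 'dog', 'runs'], toy_lookup)  # [1. 1. 3.]\n", "hyp_vec = toy_sentence_vec(['an', 'animal'], toy_lookup)  # [1. 1. 0.]\n", "\n", "# Concatenate: premise dimensions first, then hypothesis dimensions.\n", "features = np.concatenate((prem_vec, hyp_vec))\n", "print(features)  # [1. 1. 3. 1. 1. 0.]"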
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "glove_lookup = utils.glove2dict(\n", "    os.path.join(GLOVE_HOME, 'glove.6B.50d.txt'))" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def glove_leaves_phi(t1, t2, np_func=np.sum):\n", "    \"\"\"Represent the premise `t1` and hypothesis `t2` as the concatenation\n", "    of the combined GloVe vectors of their words.\n", "\n", "    Parameters\n", "    ----------\n", "    t1 : nltk.Tree\n", "    t2 : nltk.Tree\n", "    np_func : function (default: np.sum)\n", "        A numpy matrix operation that can be applied columnwise,\n", "        like `np.mean`, `np.sum`, or `np.prod`. The requirement is that\n", "        the function take `axis=0` as one of its arguments (to ensure\n", "        columnwise combination) and that it return a vector of a\n", "        fixed length, no matter what the size of the tree is.\n", "\n", "    Returns\n", "    -------\n", "    np.array\n", "        The premise vector followed by the hypothesis vector.\n", "\n", "    \"\"\"\n", "    prem_vecs = _get_tree_vecs(t1, glove_lookup, np_func)\n", "    hyp_vecs = _get_tree_vecs(t2, glove_lookup, np_func)\n", "    return np.concatenate((prem_vecs, hyp_vecs))\n", "\n", "\n", "def _get_tree_vecs(tree, lookup, np_func):\n", "    # Vectors for the leaves that have a GloVe entry; unknown words are skipped.\n", "    allvecs = np.array([lookup[w] for w in tree.leaves() if w in lookup])\n", "    if len(allvecs) == 0:\n", "        # No known words: fall back to a zero vector of the embedding dimension.\n", "        dim = len(next(iter(lookup.values())))\n", "        feats = np.zeros(dim)\n", "    else:\n", "        feats = np_func(allvecs, axis=0)\n", "    return feats" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Best params: {'C': 0.4, 'penalty': 'l1'}\n", "Best score: 0.504\n", "Accuracy: 0.502\n", " precision recall f1-score support\n", "\n", "contradiction 0.495 0.464 0.479 5499\n", " entailment 0.484 0.563 0.521 5536\n", " neutral 0.533 0.478 0.504 5525\n", "\n", " micro avg 0.502 0.502 0.502 16560\n", " macro avg 0.504 0.502 0.501 16560\n", " weighted avg 0.504 0.502 0.501 16560\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", 
train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10), \n", " phi=glove_leaves_phi,\n", " train_func=fit_softmax_with_crossvalidation,\n", " assess_reader=None,\n", " random_state=42,\n", " vectorize=False) # Ask `experiment` not to featurize; we did it already." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Dense representations with a shallow neural network\n", "\n", "A small tweak to the above is to use a neural network instead of a softmax classifier at the top:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "def fit_shallow_neural_classifier_with_crossvalidation(X, y): \n", " basemod = TorchShallowNeuralClassifier(max_iter=1000)\n", " cv = 3\n", " param_grid = {'hidden_dim': [25, 50, 100]}\n", " best_mod = utils.fit_classifier_with_crossvalidation(\n", " X, y, basemod, cv, param_grid)\n", " return best_mod" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Finished epoch 1000 of 1000; error is 32.213042318820959" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Best params: {'hidden_dim': 25}\n", "Best score: 0.514\n", "Accuracy: 0.535\n", " precision recall f1-score support\n", "\n", "contradiction 0.541 0.476 0.506 5364\n", " entailment 0.522 0.646 0.577 5608\n", " neutral 0.547 0.478 0.510 5471\n", "\n", " micro avg 0.535 0.535 0.535 16443\n", " macro avg 0.537 0.533 0.531 16443\n", " weighted avg 0.537 0.535 0.532 16443\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", " train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10), \n", " phi=glove_leaves_phi,\n", " train_func=fit_shallow_neural_classifier_with_crossvalidation,\n", " assess_reader=None,\n", " random_state=42,\n", " vectorize=False) # Ask `experiment` not to featurize; we did it already." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sentence-encoding RNNs\n", "\n", "A more sophisticated sentence-encoding model processes the premise and hypothesis with separate RNNs and uses the concatenation of their final states as the basis for the classification decision at the top:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is relatively straightforward to extend `torch_rnn_classifier` so that it can handle this architecture:" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### A sentence-encoding dataset\n", "\n", "Whereas `torch_rnn_classifier.TorchRNNDataset` creates batches that consist of `(sequence, sequence_length, label)` triples, the sentence-encoding model requires us to double the first two components. The most important feature of this class is `collate_fn`, which determines what the batches look like:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "class TorchRNNSentenceEncoderDataset(torch.utils.data.Dataset):\n", "    def __init__(self, sequences, seq_lengths, y):\n", "        self.prem_seqs, self.hyp_seqs = sequences\n", "        self.prem_lengths, self.hyp_lengths = seq_lengths\n", "        self.y = y\n", "        assert len(self.prem_seqs) == len(self.y)\n", "\n", "    @staticmethod\n", "    def collate_fn(batch):\n", "        # Regroup a list of 5-tuples into premise/hypothesis pairs plus labels.\n", "        X_prem, X_hyp, prem_lengths, hyp_lengths, y = zip(*batch)\n", "        prem_lengths = torch.LongTensor(prem_lengths)\n", "        hyp_lengths = torch.LongTensor(hyp_lengths)\n", "        y = torch.LongTensor(y)\n", "        return (X_prem, X_hyp), (prem_lengths, hyp_lengths), y\n", "\n", "    def __len__(self):\n", "        return len(self.prem_seqs)\n", "\n", "    def __getitem__(self, idx):\n", "        return (self.prem_seqs[idx], self.hyp_seqs[idx],\n", "                self.prem_lengths[idx], self.hyp_lengths[idx],\n", "                self.y[idx])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### A sentence-encoding model\n", "\n", "With `TorchRNNSentenceEncoderClassifierModel`, we subclass 
`torch_rnn_classifier.TorchRNNClassifierModel` and make use of many of its parameters. The changes:\n", "\n", "* We add an attribute `self.hypothesis_rnn` for encoding the hypothesis. (The superclass has `self.rnn`, which we use for the premise.)\n", "* The `forward` method concatenates the final states from the premise and hypothesis, and this concatenation is the input to the classifier layer, which is the same kind of linear layer as before but now accepts inputs that are double the size." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class TorchRNNSentenceEncoderClassifierModel(TorchRNNClassifierModel):\n", "    def __init__(self, vocab_size, embed_dim, embedding, use_embedding,\n", "                 hidden_dim, output_dim, bidirectional, device):\n", "        super(TorchRNNSentenceEncoderClassifierModel, self).__init__(\n", "            vocab_size, embed_dim, embedding, use_embedding,\n", "            hidden_dim, output_dim, bidirectional, device)\n", "        self.hypothesis_rnn = nn.LSTM(\n", "            input_size=self.embed_dim,\n", "            hidden_size=hidden_dim,\n", "            batch_first=True,\n", "            bidirectional=self.bidirectional)\n", "        # Premise and hypothesis states are concatenated, doubling the input size.\n", "        if bidirectional:\n", "            classifier_dim = hidden_dim * 2 * 2\n", "        else:\n", "            classifier_dim = hidden_dim * 2\n", "        self.classifier_layer = nn.Linear(\n", "            classifier_dim, output_dim)\n", "\n", "    def forward(self, X, seq_lengths):\n", "        X_prem, X_hyp = X\n", "        prem_lengths, hyp_lengths = seq_lengths\n", "        prem_state = self.rnn_forward(X_prem, prem_lengths, self.rnn)\n", "        hyp_state = self.rnn_forward(X_hyp, hyp_lengths, self.hypothesis_rnn)\n", "        state = torch.cat((prem_state, hyp_state), dim=1)\n", "        logits = self.classifier_layer(state)\n", "        return logits" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### A sentence-encoding model interface\n", "\n", "Finally, we subclass `TorchRNNClassifier`. 
Here, we just need to redefine three methods: `build_dataset` and `build_graph` to make use of the new components above, and `predict_proba` so that it deals with the premise/hypothesis shape of new inputs." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class TorchRNNSentenceEncoderClassifier(TorchRNNClassifier):\n", "\n", "    def build_dataset(self, X, y):\n", "        X_prem, X_hyp = zip(*X)\n", "        X_prem, prem_lengths = self._prepare_dataset(X_prem)\n", "        X_hyp, hyp_lengths = self._prepare_dataset(X_hyp)\n", "        return TorchRNNSentenceEncoderDataset(\n", "            (X_prem, X_hyp), (prem_lengths, hyp_lengths), y)\n", "\n", "    def build_graph(self):\n", "        return TorchRNNSentenceEncoderClassifierModel(\n", "            len(self.vocab),\n", "            embedding=self.embedding,\n", "            embed_dim=self.embed_dim,\n", "            use_embedding=self.use_embedding,\n", "            hidden_dim=self.hidden_dim,\n", "            output_dim=self.n_classes_,\n", "            bidirectional=self.bidirectional,\n", "            device=self.device)\n", "\n", "    def predict_proba(self, X):\n", "        with torch.no_grad():\n", "            X_prem, X_hyp = zip(*X)\n", "            X_prem, prem_lengths = self._prepare_dataset(X_prem)\n", "            X_hyp, hyp_lengths = self._prepare_dataset(X_hyp)\n", "            preds = self.model((X_prem, X_hyp), (prem_lengths, hyp_lengths))\n", "            preds = torch.softmax(preds, dim=1).cpu().numpy()\n", "            return preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Simple example\n", "\n", "This toy problem illustrates how this works in detail:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "def simple_example():\n", "    vocab = ['a', 'b', '$UNK']\n", "\n", "    # Reversals are good, and other pairs are bad:\n", "    train = [\n", "        [(list('ab'), list('ba')), 'good'],\n", "        [(list('aab'), list('baa')), 'good'],\n", "        [(list('abb'), list('bba')), 'good'],\n", "        [(list('aabb'), list('bbaa')), 'good'],\n", "        [(list('ba'), list('ba')), 'bad'],\n", "        [(list('baa'), list('baa')), 'bad'],\n", 
[(list('bba'), list('bab')), 'bad'],\n", "        [(list('bbaa'), list('bbab')), 'bad'],\n", "        [(list('aba'), list('bab')), 'bad']]\n", "\n", "    test = [\n", "        [(list('baaa'), list('aabb')), 'bad'],\n", "        [(list('abaa'), list('baaa')), 'bad'],\n", "        [(list('bbaa'), list('bbaa')), 'bad'],\n", "        [(list('aaab'), list('baaa')), 'good'],\n", "        [(list('aaabb'), list('bbaaa')), 'good']]\n", "\n", "    mod = TorchRNNSentenceEncoderClassifier(\n", "        vocab,\n", "        max_iter=100,\n", "        embed_dim=50,\n", "        hidden_dim=50)\n", "\n", "    X, y = zip(*train)\n", "    mod.fit(X, y)\n", "\n", "    X_test, y_test = zip(*test)\n", "    preds = mod.predict(X_test)\n", "\n", "    print(\"\\nPredictions:\")\n", "    for ex, pred, gold in zip(X_test, preds, y_test):\n", "        score = \"correct\" if pred == gold else \"incorrect\"\n", "        print(\"{0:>6} {1:>6} - predicted: {2:>4}; actual: {3:>4} - {4}\".format(\n", "            \"\".join(ex[0]), \"\".join(ex[1]), pred, gold, score))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Finished epoch 100 of 100; error is 2.7179718017578125e-05" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "Predictions:\n", " baaa aabb - predicted: bad; actual: bad - correct\n", " abaa baaa - predicted: bad; actual: bad - correct\n", " bbaa bbaa - predicted: bad; actual: bad - correct\n", " aaab baaa - predicted: good; actual: good - correct\n", " aaabb bbaaa - predicted: good; actual: good - correct\n" ] } ], "source": [ "simple_example()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Example SNLI run" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def sentence_encoding_rnn_phi(t1, t2):\n", "    \"\"\"Map `t1` and `t2` to a pair of lists of leaf nodes.\"\"\"\n", "    return (t1.leaves(), t2.leaves())" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "def get_sentence_encoding_vocab(X, 
n_words=None):\n", "    wc = Counter([w for pair in X for ex in pair for w in ex])\n", "    wc = wc.most_common(n_words) if n_words else wc.items()\n", "    vocab = {w for w, c in wc}\n", "    vocab.add(\"$UNK\")\n", "    return sorted(vocab)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "def fit_sentence_encoding_rnn(X, y):\n", "    vocab = get_sentence_encoding_vocab(X, n_words=10000)\n", "    mod = TorchRNNSentenceEncoderClassifier(\n", "        vocab, hidden_dim=50, max_iter=50)\n", "    mod.fit(X, y)\n", "    return mod" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Finished epoch 50 of 50; error is 0.009828846727032214" ] }, { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", "contradiction 0.542 0.540 0.541 5458\n", " entailment 0.567 0.569 0.568 5425\n", " neutral 0.544 0.545 0.545 5528\n", "\n", " micro avg 0.551 0.551 0.551 16411\n", " macro avg 0.551 0.551 0.551 16411\n", " weighted avg 0.551 0.551 0.551 16411\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", "    train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10),\n", "    phi=sentence_encoding_rnn_phi,\n", "    train_func=fit_sentence_encoding_rnn,\n", "    assess_reader=None,\n", "    random_state=42,\n", "    vectorize=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Other sentence-encoding model ideas\n", "\n", "Given that [we already explored tree-structured neural networks (TreeNNs)](sst_03_neural_networks.ipynb#Tree-structured-neural-networks), it's natural to consider these as the basis for sentence-encoding NLI models:\n", "\n", "\n", "\n", "And this is just the beginning: any model used to represent sentences is presumably a candidate for use in sentence-encoding NLI!" 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chained models\n", "\n", "The final major class of NLI designs we look at are those in which the premise and hypothesis are processed sequentially, as a pair. These don't deliver representations of the premise or hypothesis separately. They bear the strongest resemblance to classic sequence-to-sequence models." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Simple RNN\n", "\n", "In the simplest version of this model, we just concatenate the premise and hypothesis. The model itself is identical to the one we used for the Stanford Sentiment Treebank:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To implement this, we can use `TorchRNNClassifier` out of the box. We just need to concatenate the leaves of the premise and hypothesis trees:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "def simple_chained_rep_rnn_phi(t1, t2):\n", " \"\"\"Map `t1` and `t2` to a single list of leaf nodes.\n", " \n", " A slight variant might insert a designated boundary symbol between \n", " the premise leaves and the hypothesis leaves. 
Be sure to add it to \n", " the vocab in that case, else it will be $UNK.\n", " \"\"\"\n", " return t1.leaves() + t2.leaves()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's a quick evaluation, just to get a feel for this model:" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "def fit_simple_chained_rnn(X, y): \n", " vocab = utils.get_vocab(X, n_words=10000)\n", " mod = TorchRNNClassifier(vocab, hidden_dim=50, max_iter=50)\n", " mod.fit(X, y)\n", " return mod" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Finished epoch 50 of 50; error is 1.9316425714641814" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.561\n", " precision recall f1-score support\n", "\n", "contradiction 0.564 0.544 0.554 5517\n", " entailment 0.581 0.566 0.573 5441\n", " neutral 0.541 0.573 0.557 5534\n", "\n", " micro avg 0.561 0.561 0.561 16492\n", " macro avg 0.562 0.561 0.561 16492\n", " weighted avg 0.562 0.561 0.561 16492\n", "\n" ] } ], "source": [ "_ = nli.experiment(\n", " train_reader=nli.SNLITrainReader(SNLI_HOME, samp_percentage=0.10), \n", " phi=simple_chained_rep_rnn_phi,\n", " train_func=fit_simple_chained_rnn,\n", " assess_reader=None,\n", " random_state=42,\n", " vectorize=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Separate premise and hypothesis RNNs\n", "\n", "A natural variation on the above is to give the premise and hypothesis each their own RNN:\n", "\n", "\n", "\n", "This greatly increases the number of parameters, but it gives the model more chances to learn that appearing in the premise is different from appearing in the hypothesis. One could even push this idea further by giving the premise and hypothesis their own embeddings as well. One could implement this easily by modifying [the sentence-encoder version defined above](#Sentence-encoding-RNNs)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attention mechanisms\n", "\n", "Many of the best-performing systems in [the SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) use __attention mechanisms__ to help the model learn important associations between words in the premise and words in the hypothesis. I believe [Rocktäschel et al. (2015)](https://arxiv.org/pdf/1509.06664v1.pdf) were the first to explore such models for NLI.\n", "\n", "For instance, if _puppy_ appears in the premise and _dog_ in the hypothesis, then that might be a high-precision indicator that the correct relationship is entailment.\n", "\n", "This diagram is a high-level schematic for adding attention mechanisms to a chained RNN model for NLI:\n", "\n", "\n", "\n", "Since PyTorch will handle the details of backpropagation, implementing these models is largely reduced to figuring out how to wrangle the states of the model in the desired way." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Error analysis with the MultiNLI annotations\n", "\n", "The annotations included with the MultiNLI corpus create some powerful yet easy opportunities for error analysis right out of the box. 
This section illustrates how to use them with models you've trained.\n", "\n", "First, we train a sentence-encoding model on a sample of the MultiNLI data, just for illustrative purposes:" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Finished epoch 50 of 50; error is 0.024059262155788027" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.426\n", " precision recall f1-score support\n", "\n", "contradiction 0.476 0.467 0.471 3888\n", " entailment 0.400 0.398 0.399 3991\n", " neutral 0.406 0.415 0.410 3929\n", "\n", " micro avg 0.426 0.426 0.426 11808\n", " macro avg 0.427 0.427 0.427 11808\n", " weighted avg 0.427 0.426 0.427 11808\n", "\n" ] } ], "source": [ "rnn_multinli_experiment = nli.experiment(\n", "    train_reader=nli.MultiNLITrainReader(MULTINLI_HOME, samp_percentage=0.10),\n", "    phi=sentence_encoding_rnn_phi,\n", "    train_func=fit_sentence_encoding_rnn,\n", "    assess_reader=None,\n", "    random_state=42,\n", "    vectorize=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The return value of `nli.experiment` is a dictionary containing the information we need to make predictions on new examples. In particular, the trained model is stored under `'model'` and the feature function under `'phi'`. 
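The sketch below shows how such a results dictionary can be used to predict on a new pair. It uses hypothetical stand-ins: a toy majority-class model and a toy feature function over token lists, rather than the real trained model and parse trees. Only the `'model'` and `'phi'` keys mirror what `predict_annotated_example` reads below.

```python
class MajorityClassModel:
    # Hypothetical toy model with a scikit-learn-style predict method.
    def __init__(self, label):
        self.label = label

    def predict(self, X):
        return [self.label for _ in X]


def toy_phi(premise_tokens, hypothesis_tokens):
    # Hypothetical featurizer: concatenate the two token lists,
    # in the spirit of simple_chained_rep_rnn_phi.
    return premise_tokens + hypothesis_tokens


def predict_pair(premise_tokens, hypothesis_tokens, results):
    # Featurize with the stored phi, then predict on a one-example batch.
    feats = results['phi'](premise_tokens, hypothesis_tokens)
    return results['model'].predict([feats])[0]


experiment_results = {'model': MajorityClassModel('entailment'), 'phi': toy_phi}
pred = predict_pair(['a', 'puppy', 'sleeps'], ['a', 'dog', 'sleeps'], experiment_results)
print(pred)  # entailment
```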
\n", "\n", "Next, we load in the 'matched' condition annotations ('mismatched' would work as well):" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "matched_ann_filename = os.path.join(\n", "    ANNOTATIONS_HOME,\n", "    \"multinli_1.0_matched_annotations.txt\")" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "matched_ann = nli.read_annotated_subset(\n", "    matched_ann_filename, MULTINLI_HOME)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function uses the results of an experiment like `rnn_multinli_experiment` to make predictions on annotated examples. It also harvests other information that is useful for error analysis:" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "def predict_annotated_example(ann, experiment_results):\n", "    model = experiment_results['model']\n", "    phi = experiment_results['phi']\n", "    ex = ann['example']\n", "    # Featurize the parsed premise and hypothesis as in training.\n", "    prem = ex.sentence1_parse\n", "    hyp = ex.sentence2_parse\n", "    feats = phi(prem, hyp)\n", "    pred = model.predict([feats])[0]\n", "    gold = ex.gold_label\n", "    # One True-valued entry per annotation category (e.g., '#MODAL').\n", "    data = {cat: True for cat in ann['annotations']}\n", "    data.update({'gold': gold, 'prediction': pred, 'correct': gold == pred})\n", "    return data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, this function applies `predict_annotated_example` to a collection of annotated examples and puts the results in a `pd.DataFrame` for flexible analysis:" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "def get_predictions_for_annotated_data(anns, experiment_results):\n", "    data = []\n", "    for ann in anns.values():\n", "        results = predict_annotated_example(ann, experiment_results)\n", "        data.append(results)\n", "    # Categories an example lacks become NaN in the DataFrame.\n", "    return pd.DataFrame(data)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "ann_analysis_df = 
get_predictions_for_annotated_data(\n", "    matched_ann, rnn_multinli_experiment)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With `ann_analysis_df`, we can see how the model does on individual annotation categories:" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "#MODAL True\n", "correct \n", "False 75\n", "True 69" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.crosstab(ann_analysis_df['correct'], ann_analysis_df['#MODAL'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other findings\n", "\n", "1. A high-level lesson of [the SNLI leaderboard](https://nlp.stanford.edu/projects/snli/) is that one can do __extremely well__ with simple neural models whose hyperparameters are selected via extensive cross-validation. This is mathematically interesting but might be dispiriting to those of us without vast resources to devote to these computations! (On the flip side, cleverly designed linear models or ensembles with sparse feature representations might beat all of these entrants with a fraction of the computational budget.)\n", "\n", "1. In an outstanding project for this course in 2016, [Leonid Keselman](https://leonidk.com) observed that [one can do much better than chance on SNLI by processing only the hypothesis](https://leonidk.com/stanford/cs224u.html). This relates to [observations we made in the word-level homework/bake-off](hw4_wordentail.ipynb) about how certain terms tend to appear more on the right in entailment pairs than on the left. In 2018, a number of groups independently (re-)discovered this fact and published analyses: [Poliak et al. 2018](https://aclanthology.info/papers/S18-2023/s18-2023), [Tsuchiya 2018](https://aclanthology.info/papers/L18-1239/l18-1239), [Gururangan et al. 2018](https://aclanthology.info/papers/N18-2017/n18-2017).\n", "\n", "1. As we pointed out at the start of this unit, [Dagan et al. (2006) pitched NLI as a general-purpose NLU task](nli_01_task_and_data.ipynb#Overview). We might then hope that the representations we learn on this task will transfer to others. So far, the evidence for this is decidedly mixed. 
I suspect the core scientific idea is sound, but that __we still lack the needed methods for doing transfer learning__.\n", "\n", "1. For SNLI, we seem to have entered the inevitable phase in machine learning problems where __ensembles do best__." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exploratory exercises\n", "\n", "These are largely meant to give you a feel for the material, but some of them could lead to projects and help you with future work for the course. These are not for credit.\n", "\n", "1. When we [feed dense representations to a simple classifier](#Dense-representations-with-a-linear-classifier), what is the effect of changing the combination functions (e.g., changing `sum` to `mean`; changing `concatenate` to `difference`)? What happens if we swap out `LogisticRegression` for, say, an [sklearn.ensemble.RandomForestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) instance?\n", "\n", "1. Implement the [Separate premise and hypothesis RNN](#Separate-premise-and-hypothesis-RNNs) and evaluate it, comparing in particular against [the version that simply concatenates the premise and hypothesis](#Simple-RNN). Does having all these additional parameters pay off? Do you need more training examples to start to see the value of this idea?\n", "\n", "1. Most of the illustrations above use SNLI. It is worth experimenting with MultiNLI as well; it has both __matched__ and __mismatched__ dev sets. It's also interesting to think about combining SNLI and MultiNLI to get additional training instances, to push the models to generalize more, and to assess transfer learning hypotheses." 
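, "\n", "\n", "For exercise 1, here is a minimal NumPy sketch of the combination choices it mentions. It uses a random toy lookup table in place of GloVe. The names `reduce_leaves` and `combine_reps` are hypothetical and do not come from the notebook.

```python
import numpy as np

# Toy word-vector lookup standing in for GloVe (random 4-dimensional vectors).
rng = np.random.default_rng(0)
dim = 4
lookup = {w: rng.normal(size=dim) for w in ['a', 'puppy', 'dog', 'sleeps']}


def reduce_leaves(tokens, np_func=np.sum):
    # Stack the known words' vectors and reduce them to one vector
    # with np_func (np.sum or np.mean); unknown words are skipped.
    vecs = [lookup[w] for w in tokens if w in lookup]
    if not vecs:
        return np.zeros(dim)
    return np_func(np.array(vecs), axis=0)


def combine_reps(prem_vec, hyp_vec, method='concatenate'):
    # Combine the premise and hypothesis vectors into one feature vector.
    if method == 'concatenate':
        return np.concatenate((prem_vec, hyp_vec))
    if method == 'difference':
        return prem_vec - hyp_vec
    raise ValueError('unknown method: ' + method)


prem = ['a', 'puppy', 'sleeps']
hyp = ['a', 'dog', 'sleeps']
for np_func in (np.sum, np.mean):
    for method in ('concatenate', 'difference'):
        feats = combine_reps(reduce_leaves(prem, np_func), reduce_leaves(hyp, np_func), method)
        # concatenate gives 2 * dim = 8 features; difference gives dim = 4.
        print(np_func.__name__, method, feats.shape)
```

Note that `difference` halves the feature dimensionality. With `sum`, it also cancels out the vectors of words shared by both sentences (here the result is just puppy minus dog). This is one reason it may behave quite differently from `concatenate`.
"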
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }