{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pl0KyQomBQ7l" }, "source": [ "# Evaluating QA: Metrics, Predictions, and the Null Response\n", "> A deep dive into computing QA predictions and when to tell BERT to zip it! \n", "\n", "- title: \"Evaluating QA: Metrics, Predictions, and the Null Response\"\n", "- toc: true \n", "- badges: true\n", "- comments: true\n", "- categories: [no answer, null threshold, bert, distilbert, exact match, F1, robust predictions]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8sU415V_BQ7n" }, "source": [ "![](my_icons/tomas-sobek-nVqNmnAWz3A-unsplash.jpg \"Sometimes BERT needs to zip it.\")\n", "\n", "\n", "In our last post, [Building a QA System with BERT on Wikipedia](https://qa.fastforwardlabs.com/pytorch/hugging%20face/wikipedia/bert/transformers/2020/05/19/Getting_Started_with_QA.html), we used the HuggingFace framework to train BERT on the SQuAD2.0 dataset and built a simple QA system on top of the Wikipedia search engine. This time, we'll look at how to assess the quality of a BERT-like model for Question Answering. We'll cover what metrics are used to quantify quality, how to evaluate a model using the Hugging Face framework, and the importance of the \"null response\" (questions that don't have answers) for both improved performance and more realistic QA output. By the end of this post, we'll have implemented a more robust answering method for our QA system. \n", "\n", "> Note: Throughout this post we'll be using a distilBERT model fine-tuned on SQuAD2.0 by a member of the NLP community; this model can be found [here](https://huggingface.co/twmkn9/distilbert-base-uncased-squad2) in the HF repository. Additionally, much of the code in this post is inspired by the HF `squad_metrics.py` [script](https://github.com/huggingface/transformers/blob/5856999a9f2926923f037ecd8d27b8058bcf9dae/src/transformers/data/metrics/squad_metrics.py). 
" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "np9SZdgqBQ7m" }, "source": [ "### Prerequisites\n", "* a basic understanding of Transformers and PyTorch\n", "* a basic understanding of Transformer outputs (logits) and softmax\n", "* a Transformer fine-tuned on SQuAD2.0\n", "* the SQuAD2.0 dev set" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "JT27HZ5NzAaE" }, "source": [ "# Answering questions is complicated\n", "Quantifying the success of question answering is a tricky task. When you or I ask a question, the correct answer could take multiple forms. For example, in our previous post, BERT answered the question, \"Why is the sky blue?\" with \"Rayleigh scattering,\" but another answer would be: \n", "\n", "> The Earth's atmosphere scatters short-wavelength light more efficiently than that of longer wavelengths. Because its wavelengths are shorter, blue light is more strongly scattered than the longer-wavelength lights, red or green. Hence the result that when looking at the sky away from the direct incident sunlight, the human eye perceives the sky to be blue. \n", "\n", "Both of these answers can be found in the Wikipedia article [Diffuse Sky Radiation](https://en.wikipedia.org/wiki/Diffuse_sky_radiation) and both are correct. However, we've also had a model answer the same question with \"because its wavelengths are shorter,\" which is close - but not really a correct answer; the sky itself doesn't have a wavelength. This answer is missing too much context to be useful.\n", "\n", "What if we'd asked a question that couldn't be answered by the Diffuse Sky Radiation page? For example: \"Could the sky ever be green?\" If you read that Wiki article you'll see there probably isn't a sure-fire answer to this question. What should the model do in this case? 
\n", "\n", "How should we judge a model’s success when there are multiple correct answers, even more incorrect answers, and potentially no answer available to it at all? To properly assess quality, we need a labeled set of questions and answers. Let's turn back to the SQuAD dataset." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "QcFk6SzNBQ7o" }, "source": [ "# The SQuAD2.0 dev set\n", "The [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/) comes in two flavors: SQuAD1.1 and SQuAD2.0. The latter contains the same questions and answers as the former, but also includes additional questions that cannot be answered by the accompanying passage. This is intended to create a more realistic question answering task. The ability to identify unanswerable questions is much more challenging for Transformer models, which is why we focused on the SQuAD2.0 dataset rather than SQuAD1.1. \n", "\n", "SQuAD2.0 consists of over 150k questions, of which more than 35% are unanswerable in relation to their associated passage. [For our last post](https://qa.fastforwardlabs.com/pytorch/hugging%20face/wikipedia/bert/transformers/2020/05/19/Getting_Started_with_QA.html), we fine-tuned on the train set (130k examples); now we'll focus on the dev set, which contains nearly 12k examples. Only about half of these examples are answerable questions. 
In the following section, we'll look at a couple of these examples to get a feel for them.\n", "\n", "(Use the hidden cells below to get set up, if needed.)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 717 }, "colab_type": "code", "id": "5Kd4H1SnBQ7o", "outputId": "cae5362f-e8dc-4bbc-9610-a30c7fb5aca1" }, "outputs": [], "source": [ "# collapse-hide\n", "\n", "# use this cell to install packages if needed\n", "!pip install torch torchvision -f https://download.pytorch.org/whl/torch_stable.html\n", "!pip install transformers" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": {}, "colab_type": "code", "id": "w0j4I6gjBQ7s" }, "outputs": [], "source": [ "# collapse-hide\n", "import json\n", "import collections\n", "from pprint import pprint\n", "import numpy as np\n", "import torch\n", "from transformers import AutoTokenizer, AutoModelForQuestionAnswering\n", "\n", "# This is the directory in which we'll store all evaluation output\n", "model_dir = \"models/distilbert/twmkn9_distilbert-base-uncased-squad2/\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "colab_type": "code", "id": "TedCgru5BQ7u", "outputId": "e5a976df-4c13-4fe1-d362-ccd7458d0fc8" }, "outputs": [], "source": [ "# collapse-hide\n", "\n", "# Download the SQuAD2.0 dev set\n", "!wget -P data/squad/ https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DWMxBldUBQ7x" }, "source": [ "### Load the dev set using HF data processors\n", "\n", "Hugging Face provides the [Processors](https://huggingface.co/transformers/main_classes/processors.html) library for facilitating basic processing tasks with some canonical NLP datasets. 
The processors can be used for loading datasets and converting their examples to features for direct use in the model. We'll be using the [SQuAD processors](https://huggingface.co/transformers/main_classes/processors.html#squad). " ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "pf5RyO5pBQ7x", "outputId": "314f069b-d106-4ab6-8a0b-29300a0e32cd" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 35/35 [00:05<00:00, 6.71it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "11873\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "from transformers.data.processors.squad import SquadV2Processor\n", "\n", "# this processor loads the SQuAD2.0 dev set examples\n", "processor = SquadV2Processor()\n", "examples = processor.get_dev_examples(\"./data/squad/\", filename=\"dev-v2.0.json\")\n", "print(len(examples))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "nVpWimMtBQ70" }, "source": [ "While `examples` is a list, most other tasks we'll work with use a unique identifier - one for each question in the dev set. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "6Z0OYHoeBQ71" }, "outputs": [], "source": [ "# generate some maps to help us identify examples of interest\n", "qid_to_example_index = {example.qas_id: i for i, example in enumerate(examples)}\n", "qid_to_has_answer = {example.qas_id: bool(example.answers) for example in examples}\n", "answer_qids = [qas_id for qas_id, has_answer in qid_to_has_answer.items() if has_answer]\n", "no_answer_qids = [qas_id for qas_id, has_answer in qid_to_has_answer.items() if not has_answer]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "KsKb1ExvBQ73" }, "outputs": [], "source": [ "def display_example(qid): \n", " from pprint import pprint\n", "\n", " idx = qid_to_example_index[qid]\n", " q = examples[idx].question_text\n", " c = examples[idx].context_text\n", " a = [answer['text'] for answer in examples[idx].answers]\n", " \n", " print(f'Example {idx} of {len(examples)}\\n---------------------')\n", " print(f\"Q: {q}\\n\")\n", " print(\"Context:\")\n", " pprint(c)\n", " print(f\"\\nTrue Answers:\\n{a}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "uqw_vVECBQ76" }, "source": [ "#### A positive example \n", "\n", "Approximately 50% of the examples in the dev set are questions that have answers contained within their corresponding passage. In these cases, up to five possible correct answers are provided (questions and answers were generated and identified by crowd-sourced workers). Answers must be direct excerpts from the passage, but we can see there are several ways to arrive at a \"correct\" answer. 
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 340 }, "colab_type": "code", "id": "tS6FLSNZBQ76", "outputId": "453b966f-5aed-49e6-a2ed-df2fba36f7b8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Example 2548 of 11873\n", "---------------------\n", "Q: Where on Earth is free oxygen found?\n", "\n", "Context:\n", "(\"Free oxygen also occurs in solution in the world's water bodies. The \"\n", " 'increased solubility of O\\n'\n", " '2 at lower temperatures (see Physical properties) has important implications '\n", " 'for ocean life, as polar oceans support a much higher density of life due to '\n", " 'their higher oxygen content. Water polluted with plant nutrients such as '\n", " 'nitrates or phosphates may stimulate growth of algae by a process called '\n", " 'eutrophication and the decay of these organisms and other biomaterials may '\n", " 'reduce amounts of O\\n'\n", " '2 in eutrophic water bodies. Scientists assess this aspect of water quality '\n", " \"by measuring the water's biochemical oxygen demand, or the amount of O\\n\"\n", " '2 needed to restore it to a normal concentration.')\n", "\n", "True Answers:\n", "['water', \"in solution in the world's water bodies\", \"the world's water bodies\"]\n" ] } ], "source": [ "display_example(answer_qids[1300])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "q28-IXiGBQ79" }, "source": [ "#### A negative example\n", "\n", "The other half of the questions in the dev set do not have an answer in the corresponding passage. These questions were generated by crowd-sourced workers to be related and relevant to the passage, but unanswerable by that passage. There are thus no True Answers associated with these questions, as we see in the example below. \n", "\n", "Note: In this case, the question is a trick -- the numbers are reoriented in a way that no longer holds true. 
Will the model pick up on that?" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 306 }, "colab_type": "code", "id": "As9Yw52PBQ79", "outputId": "2e455b9a-586f-4867-8c2a-8498557ddfc5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Example 2564 of 11873\n", "---------------------\n", "Q: What happened 3.7-2 billion years ago?\n", "\n", "Context:\n", "(\"Free oxygen gas was almost nonexistent in Earth's atmosphere before \"\n", " 'photosynthetic archaea and bacteria evolved, probably about 3.5 billion '\n", " 'years ago. Free oxygen first appeared in significant quantities during the '\n", " 'Paleoproterozoic eon (between 3.0 and 2.3 billion years ago). For the first '\n", " 'billion years, any free oxygen produced by these organisms combined with '\n", " 'dissolved iron in the oceans to form banded iron formations. When such '\n", " 'oxygen sinks became saturated, free oxygen began to outgas from the oceans '\n", " '3–2.7 billion years ago, reaching 10% of its present level around 1.7 '\n", " 'billion years ago.')\n", "\n", "True Answers:\n", "[]\n" ] } ], "source": [ "display_example(no_answer_qids[1254])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "A0nzI48yBQ7_" }, "source": [ "# Metrics for QA\n", "\n", "There are two dominant metrics used by many question answering datasets, including SQuAD: exact match (EM) and F1 score. These scores are computed on individual question+answer pairs. When multiple correct answers are possible for a given question, the maximum score over all possible correct answers is computed. Overall EM and F1 scores are computed for a model by averaging over the individual example scores. \n", "\n", "\n", "### Exact Match\n", "This metric is as simple as it sounds. 
For each question+answer pair, if the _characters_ of the model's prediction exactly match the characters of (one of) the True Answer(s), EM = 1, otherwise EM = 0. This is a strict all-or-nothing metric; being off by a single character results in a score of 0. When assessing against a negative example, if the model predicts any text at all, it automatically receives a 0 for that example. \n", "\n", "### F1 \n", "F1 score is a common metric for classification problems, and widely used in QA. It is appropriate when we care equally about precision and recall. In this case, it's computed over the individual _words_ in the prediction against those in the True Answer. The number of shared words between the prediction and the truth is the basis of the F1 score: precision is the ratio of the number of shared words to the total number of words in the _prediction_, and recall is the ratio of the number of shared words to the total number of words in the _ground truth_.\n", "\n", "\n", "![](my_icons/f1score.png \"Thanks Wikipedia\")\n", "\n", "Let's see how these metrics work in practice. We'll load up a fine-tuned model ([this one](https://huggingface.co/twmkn9/distilbert-base-uncased-squad2), to be precise) and its tokenizer, and compare our predictions against the True Answers." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "In0YUapuBQ8A" }, "source": [ "### Load a Transformer model fine-tuned on SQuAD 2.0" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 71 }, "colab_type": "code", "id": "ryQvHxDCBQ8A", "outputId": "9f7fd9e3-f3e7-40b4-e190-11526ebe09c9" }, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(\"twmkn9/distilbert-base-uncased-squad2\")\n", "model = AutoModelForQuestionAnswering.from_pretrained(\"twmkn9/distilbert-base-uncased-squad2\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "x6pUc1pNBQ8C" }, "source": [ "The following `get_prediction` method is essentially identical to what we used last time in our simple QA system." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": {}, "colab_type": "code", "id": "MzPlHgWEBQ8D" }, "outputs": [], "source": [ "def get_prediction(qid):\n", " # given a question id (qas_id or qid), load the example, get the model outputs and generate an answer\n", " question = examples[qid_to_example_index[qid]].question_text\n", " context = examples[qid_to_example_index[qid]].context_text\n", "\n", " inputs = tokenizer.encode_plus(question, context, return_tensors='pt')\n", "\n", " outputs = model(**inputs)\n", " answer_start = torch.argmax(outputs[0]) # get the most likely beginning of answer with the argmax of the score\n", " answer_end = torch.argmax(outputs[1]) + 1 \n", "\n", " answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))\n", "\n", " return answer" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EamwnbLzBQ8F" }, "source": [ "Below are some functions we'll need to compute our quality metrics. 
" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": {}, "colab_type": "code", "id": "5fCmCYttBQ8G" }, "outputs": [], "source": [ "# these functions are heavily influenced by the HF squad_metrics.py script\n", "def normalize_text(s):\n", " \"\"\"Removing articles and punctuation, and standardizing whitespace are all typical text processing steps.\"\"\"\n", " import string, re\n", "\n", " def remove_articles(text):\n", " regex = re.compile(r\"\\b(a|an|the)\\b\", re.UNICODE)\n", " return re.sub(regex, \" \", text)\n", "\n", " def white_space_fix(text):\n", " return \" \".join(text.split())\n", "\n", " def remove_punc(text):\n", " exclude = set(string.punctuation)\n", " return \"\".join(ch for ch in text if ch not in exclude)\n", "\n", " def lower(text):\n", " return text.lower()\n", "\n", " return white_space_fix(remove_articles(remove_punc(lower(s))))\n", "\n", "def compute_exact_match(prediction, truth):\n", " return int(normalize_text(prediction) == normalize_text(truth))\n", "\n", "def compute_f1(prediction, truth):\n", " pred_tokens = normalize_text(prediction).split()\n", " truth_tokens = normalize_text(truth).split()\n", " \n", " # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise\n", " if len(pred_tokens) == 0 or len(truth_tokens) == 0:\n", " return int(pred_tokens == truth_tokens)\n", " \n", " common_tokens = set(pred_tokens) & set(truth_tokens)\n", " \n", " # if there are no common tokens then f1 = 0\n", " if len(common_tokens) == 0:\n", " return 0\n", " \n", " prec = len(common_tokens) / len(pred_tokens)\n", " rec = len(common_tokens) / len(truth_tokens)\n", " \n", " return 2 * (prec * rec) / (prec + rec)\n", "\n", "def get_gold_answers(example):\n", " \"\"\"helper function that retrieves all possible true answers from a squad2.0 example\"\"\"\n", " \n", " gold_answers = [answer[\"text\"] for answer in example.answers if answer[\"text\"]]\n", "\n", " # if gold_answers doesn't exist it's because 
this is a negative example - \n", " # the only correct answer is an empty string\n", " if not gold_answers:\n", " gold_answers = [\"\"]\n", " \n", " return gold_answers" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "eo6lXxP5BQ8I" }, "source": [ "In the following cell, we start by computing EM and F1 for our first example - the one that has several True Answers associated with it." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "id": "NBpm1l47BQ8I", "outputId": "0c5ca9ad-72c8-4ba4-8b1a-4fa0fbacc28f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Question: Where on Earth is free oxygen found?\n", "Prediction: water bodies\n", "True Answers: ['water', \"in solution in the world's water bodies\", \"the world's water bodies\"]\n", "EM: 0 \t F1: 0.8\n" ] } ], "source": [ "prediction = get_prediction(answer_qids[1300])\n", "example = examples[qid_to_example_index[answer_qids[1300]]]\n", "\n", "gold_answers = get_gold_answers(example)\n", "\n", "em_score = max((compute_exact_match(prediction, answer)) for answer in gold_answers)\n", "f1_score = max((compute_f1(prediction, answer)) for answer in gold_answers)\n", "\n", "print(f\"Question: {example.question_text}\")\n", "print(f\"Prediction: {prediction}\")\n", "print(f\"True Answers: {gold_answers}\")\n", "print(f\"EM: {em_score} \\t F1: {f1_score}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_xQUCuM6BQ8K" }, "source": [ "We see that our prediction is actually quite close to some of the True Answers, resulting in a respectable F1 score. However, it does not exactly match any of them, so our EM score is 0. \n", "\n", "Let's try with our negative example now. 
" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 105 }, "colab_type": "code", "id": "pbtb9Nk0BQ8L", "outputId": "5c65ca9a-e3f7-4c51-8a49-888f1d638d72" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Question: What happened 3.7-2 billion years ago?\n", "Prediction: [CLS] what happened 3 . 7 - 2 billion years ago ? [SEP] free oxygen gas was almost nonexistent in earth ' s atmosphere before photosynthetic archaea and bacteria evolved , probably about 3 . 5 billion years ago . free oxygen first appeared in significant quantities during the paleoproterozoic eon ( between 3 . 0 and 2 . 3 billion years ago ) . for the first billion years , any free oxygen produced by these organisms combined with dissolved iron in the oceans to form banded iron formations . when such oxygen sinks became saturated , free oxygen began to outgas from the oceans\n", "True Answers: ['']\n", "EM: 0 \t F1: 0\n" ] } ], "source": [ "prediction = get_prediction(no_answer_qids[1254])\n", "example = examples[qid_to_example_index[no_answer_qids[1254]]]\n", "\n", "gold_answers = get_gold_answers(example)\n", "\n", "em_score = max((compute_exact_match(prediction, answer)) for answer in gold_answers)\n", "f1_score = max((compute_f1(prediction, answer)) for answer in gold_answers)\n", "\n", "print(f\"Question: {example.question_text}\")\n", "print(f\"Prediction: {prediction}\")\n", "print(f\"True Answers: {gold_answers}\")\n", "print(f\"EM: {em_score} \\t F1: {f1_score}\")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "XKTiDRRuBQ8O" }, "source": [ "Wow. Both our metrics are zero, because this model does not correctly assess that this question is unanswerable! Even worse, it seems to have catastrophically failed, including the entire question as part of the answer. 
In a later section, we'll explicitly dig into why this happens, but for now, it's important to note that we got this answer because we simply extracted start and end tokens associated with the maximum score (we took an `argmax` of the model output in `get_prediction`) and this led to some unintended consequences. \n", "\n", "Now that we’ve seen the basics of computing QA metrics on a couple of examples, we need to assess the model on the entire dev set. Luckily, there's a script for that." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Y2dTdvr1BQ8P" }, "source": [ "# Evaluating a model on the SQuAD2.0 dev set with HF\n", "\n", "The same `run_squad.py` script we used to fine-tune a Transformer for question answering can also be used to evaluate the model. (You can grab the script [here](https://github.com/huggingface/transformers/blob/master/examples/question-answering/run_squad.py) or run the hidden cell below.)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 68 }, "colab_type": "code", "id": "FiQAPvqukMi6", "outputId": "8463c633-2000-40c6-d556-29c3b35f1d8f" }, "outputs": [], "source": [ "# collapse-hide\n", "\n", "# Grab the run_squad.py script\n", "!curl -L -O https://raw.githubusercontent.com/huggingface/transformers/master/examples/question-answering/run_squad.py" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "H0s7n_0Pk-Qu" }, "source": [ "Below are the arguments needed to properly evaluate a fine-tuned model for question answering on the SQuAD dev set. Because we're using SQuAD2.0, it is **crucial** to include the `--version_2_with_negative` flag!"
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "CZmLUi7QBQ8P" }, "outputs": [], "source": [ "!python run_squad.py \\\n", " --model_type distilbert \\\n", " --model_name_or_path twmkn9/distilbert-base-uncased-squad2 \\\n", " --output_dir models/distilbert/twmkn9_distilbert-base-uncased-squad2 \\\n", " --data_dir data/squad \\\n", " --predict_file dev-v2.0.json \\\n", " --do_eval \\\n", " --version_2_with_negative \\\n", " --do_lower_case \\\n", " --per_gpu_eval_batch_size 12 \\\n", " --max_seq_length 384 \\\n", " --doc_stride 128" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "8elqhokWBQ8R" }, "source": [ "Refer to [our last post](https://qa.fastforwardlabs.com/pytorch/hugging%20face/wikipedia/bert/transformers/2020/05/19/Getting_Started_with_QA.html) for more details on what these arguments mean and what this script does. For our immediate purposes, running the cell above will produce the following output in the `--output_dir` directory:\n", "\n", "* `predictions_.json`\n", "* `nbest_predictions_.json`\n", "* `null_odds_.json`\n", "\n", "(We'll go over what these are later on.) Additionally, an overall `Results` dict will be displayed to the screen. 
If you run the above cell, the last line of output should display something like the following: " ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": {}, "colab_type": "code", "id": "1Ee4ckURBQ8S" }, "outputs": [], "source": [ "Results = {\n", " # a) scores averaged over all examples in the dev set\n", " 'exact': 66.25958056093658, \n", " 'f1': 69.66994428499025, \n", " 'total': 11873, # number of examples in the dev set\n", " \n", " # b) scores averaged over only positive examples (have answers)\n", " 'HasAns_exact': 68.91025641025641, \n", " 'HasAns_f1': 75.74076391627662, \n", " 'HasAns_total': 5928, # number of positive examples\n", " \n", " # c) scores averaged over only negative examples (no answers)\n", " 'NoAns_exact': 63.61648444070648, \n", " 'NoAns_f1': 63.61648444070648, \n", " 'NoAns_total': 5945, # number of negative examples\n", " \n", " # d) given probabilities of no-answer for each example, what would the best scores and thresholds be?\n", " 'best_exact': 66.25958056093658, \n", " 'best_exact_thresh': 0.0, \n", " 'best_f1': 69.66994428499046, \n", " 'best_f1_thresh': 0.0\n", "}" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pKVJVU5hBQ8U" }, "source": [ "The first three blocks of the `Results` output are pretty straightforward. EM and F1 scores are reported over a) the full dev set, b) the set of positive examples, and c) the set of negative examples. This can provide some insight into whether a model is performing adequately on both answer and no-answer questions. (This particular model is pretty bad at no-answer questions: it correctly abstains on only about 64% of them.) \n", "\n", "However, what's going on with the last block? This portion of the output is not useful unless we supply the evaluation method with additional information. 
For that, we'll need to dig deeper into the evaluation process - because it turns out that we need to compute more than just a prediction for an answer; we must also compute a prediction for NO answer and we must score both predictions!\n", "\n", "The following section will dive into the technical details of computing robust predictions on SQuAD2.0 examples, including how to score an answer and the null answer, as well as how to determine which one should be the \"correct\" prediction for a given example. Feel free to skip to the [next section](#Using-the-null-threshold) for the punchline. (For those of you considering building your own QA system, we found this information to be invaluable for understanding the inner workings of prediction and assessment.)\n", "\n", "\n", "### [Optional] Computing predictions\n", "\n", "> Note: The code in the following section is an under-the-hood dive into the HF `compute_predictions_logits` [method](https://github.com/huggingface/transformers/blob/5856999a9f2926923f037ecd8d27b8058bcf9dae/src/transformers/data/metrics/squad_metrics.py#L371-L573) in their `squad_metrics.py` script. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the tokenized question+context is passed to the model, the output consists of two sets of logits: one for the start of the answer span, the other for the end of the answer span. These logits represent the likelihood of any given token being the start or end of the answer. Every token passed to the model is assigned a logit, including special tokens (e.g., [CLS], [SEP]), and tokens corresponding to the question itself. \n", "\n", "Let's walk through the process using our last example (Q: What happened 3.7-2 billion years ago?). 
" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": {}, "colab_type": "code", "id": "vWCs3sDJBQ8U" }, "outputs": [], "source": [ "inputs = tokenizer.encode_plus(example.question_text, example.context_text, return_tensors='pt')\n", "start_logits, end_logits = model(**inputs)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 391 }, "colab_type": "code", "id": "Nqo4G2G5BQ8W", "outputId": "31b6a901-5dbf-4c03-dff0-881867b1fec8" }, "outputs": [ { "data": { "text/plain": [ "tensor([[ 6.4914, -9.1416, -8.4068, -7.5684, -9.9081, -9.4256, -10.1625,\n", " -9.2579, -10.0554, -9.9653, -9.2002, -8.8657, -9.1162, 0.6481,\n", " -2.5947, -4.5072, -8.1189, -6.5871, -5.8973, -10.8619, -11.0953,\n", " -10.2294, -9.3660, -7.6017, -10.8009, -10.8197, -6.1258, -8.3507,\n", " -4.2463, -10.0987, -10.2659, -8.8490, -6.7346, -8.6513, -9.7573,\n", " -5.7496, -5.5851, -8.9483, -7.0652, -6.1369, -5.7810, -9.4366,\n", " -8.7670, -9.6743, -9.7446, -7.7905, -7.4541, -1.5963, -3.8540,\n", " -7.3450, -8.1854, -9.5566, -8.3416, -8.9553, -8.3144, -6.4132,\n", " -4.2285, -9.4427, -9.5111, -9.2931, -8.9154, -9.3930, -8.2111,\n", " -8.9774, -9.0274, -7.2652, -7.4511, -9.8597, -9.5869, -9.9735,\n", " -7.0526, -9.7560, -8.7788, -9.5117, -9.6391, -8.6487, -9.5994,\n", " -7.8213, -5.1754, -4.3561, -4.3913, -7.8499, -7.7522, -8.9651,\n", " -3.5229, -0.8312, -2.7668, -7.9180, -10.0320, -8.7797, -4.5965,\n", " -5.9465, -9.9442, -3.2135, -5.0734, -8.3462, -7.5366, -3.7073,\n", " -7.0968, -4.3325, -1.3691, -4.1477, -5.3794, -7.6138, 1.3183,\n", " -3.4190, 3.1457, -3.0152, -0.4102, -2.4606, -3.5971, 6.4519,\n", " -0.5654, 0.9829, -1.6682, 3.3549, -4.7847, -2.8024, -3.3160,\n", " -0.5868, -0.9617, -8.1925, -4.3299, -7.3923, -5.0875, -5.3880,\n", " -5.3676, -3.0878, -4.3427, 4.3975, 1.8860, -5.4661, -9.1565,\n", " -3.6369, -3.5462, -4.1448, -2.0250, -2.4492, -8.7015, -7.3292,\n", " -7.7616, -7.0786, -4.6668, -4.4089, 
-9.1182]],\n", " grad_fn=)" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# look at how large the logit is in the [CLS] position (index 0)! \n", "# strong possibility that this question has no answer... but our prediction returned an answer anyway!\n", "start_logits" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "F9kYQXAYBQ8Y" }, "source": [ "In our simple QA system, we predicted the best answer by selecting the start and end tokens with the largest logits, but that's not very robust. In fact, the original [BERT paper](https://arxiv.org/abs/1810.04805) suggested considering any sensible start+end combination as a possible answer to the question. These combinations would then be scored, and the one with the highest score would be considered the best answer. A possible (candidate) answer is scored as the sum of its start and end logits. \n", "\n", "> Note: This reflects how a basic span extraction classifier works. The raw hidden layer from the model is passed through a `Linear` layer and then fed to a `CrossEntropyLoss` for each class. In span extraction, there are two classes: the beginning of the span and the end of the span. The span loss is computed as the sum of the `CrossEntropyLoss` for the start and end positions. The probability of an answer span is the probability of a given start token S and an end token E: P(S and E) = P(S)P(E), because the start and end tokens are treated as being independent. Since each softmax probability is proportional to the exponential of its logit, ranking candidate spans by the sum of their start and end logits is equivalent to ranking them by the product of their softmax probabilities. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To mimic this behavior, we'll start by taking the _n_ largest `start_logits` and the _n_ largest `end_logits` as candidates. Any sensible combination of these start + end tokens is considered a candidate answer; however, several consistency checks must first be performed. 
For example, an answer wherein the end token falls before the start token should be excluded, because that just doesn't make sense. Candidate answers wherein the start or end tokens are associated with question tokens are also excluded, because the answer to the question should obviously not be in the question itself! It is important to note that the [CLS] token and its corresponding logits are not removed, because this token indicates the null answer. " ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "OnllFnCeBQ8Y", "outputId": "bf71470a-de03-419e-9181-425db0739f76", "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[(0, 6.491387367248535), (111, 6.451895713806152), (129, 4.397505760192871), (115, 3.354909658432007), (106, 3.1457457542419434)]\n", "[(119, 6.33292293548584), (0, 6.084450721740723), (135, 4.417276382446289), (116, 4.3764214515686035), (112, 4.125303268432617)]\n" ] } ], "source": [ "def to_list(tensor):\n", "    return tensor.detach().cpu().tolist()\n", "\n", "# convert our start and end logit tensors to lists\n", "start_logits = to_list(start_logits)[0]\n", "end_logits = to_list(end_logits)[0]\n", "\n", "# sort our start and end logits from largest to smallest, keeping track of the index\n", "start_idx_and_logit = sorted(enumerate(start_logits), key=lambda x: x[1], reverse=True)\n", "end_idx_and_logit = sorted(enumerate(end_logits), key=lambda x: x[1], reverse=True)\n", "\n", "# select the top n (in this case, 5)\n", "print(start_idx_and_logit[:5])\n", "print(end_idx_and_logit[:5]) " ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bUYkA7MhBQ8a" }, "source": [ "The null answer token (index 0) is in the top five of both the start and end logit lists.\n", "\n", "In order to eventually predict a text answer (or empty string), we need to keep track of the indexes which will be used 
to pull the corresponding token ids later on. We'll also need to identify which indexes correspond to the question tokens, so we can ensure we don't allow a nonsensical prediction." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "7A3CPhLtBQ8a", "outputId": "882ffa34-18fc-4713-8834-2d088cd5ceaf" }, "outputs": [ { "data": { "text/plain": [ "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "start_indexes = [idx for idx, logit in start_idx_and_logit[:5]]\n", "end_indexes = [idx for idx, logit in end_idx_and_logit[:5]]\n", "\n", "# convert the token ids from a tensor to a list\n", "tokens = to_list(inputs['input_ids'])[0]\n", "\n", "# question tokens are defined as those between the CLS token (101, at position 0) and first SEP (102) token \n", "question_indexes = [i+1 for i, token in enumerate(tokens[1:tokens.index(102)])]\n", "question_indexes" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Gl6OpPGCBQ8c" }, "source": [ "Next, we'll generate a list of candidate predictions by looping through all combinations of the start and end token indexes, excluding nonsensical combinations. We'll save these to a list for the next step." 
] }, { "cell_type": "code", "execution_count": 17, "metadata": { "colab": {}, "colab_type": "code", "id": "6dxSmamFBQ8c" }, "outputs": [], "source": [ "import collections\n", "\n", "# keep track of all preliminary predictions\n", "PrelimPrediction = collections.namedtuple( \n", "    \"PrelimPrediction\", [\"start_index\", \"end_index\", \"start_logit\", \"end_logit\"]\n", ")\n", "\n", "prelim_preds = []\n", "for start_index in start_indexes:\n", "    for end_index in end_indexes:\n", "        # throw out invalid predictions\n", "        if start_index in question_indexes:\n", "            continue\n", "        if end_index in question_indexes:\n", "            continue\n", "        if end_index < start_index:\n", "            continue\n", "        prelim_preds.append(\n", "            PrelimPrediction(\n", "                start_index = start_index,\n", "                end_index = end_index,\n", "                start_logit = start_logits[start_index],\n", "                end_logit = end_logits[end_index]\n", "            )\n", "        )" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "d6idRrMkBQ8e" }, "source": [ "With a list of sensible candidate predictions, it's time to score them.\n", "\n", "For a candidate answer, score = `start_logit` + `end_logit`. Below, we sort our candidate predictions by their score." 
] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 306 }, "colab_type": "code", "id": "s7qlh7rsBQ8e", "outputId": "54fd57d1-035a-4a71-dd45-a7eee18f761d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[PrelimPrediction(start_index=0, end_index=119, start_logit=6.491387367248535, end_logit=6.33292293548584),\n", " PrelimPrediction(start_index=111, end_index=119, start_logit=6.451895713806152, end_logit=6.33292293548584),\n", " PrelimPrediction(start_index=0, end_index=0, start_logit=6.491387367248535, end_logit=6.084450721740723),\n", " PrelimPrediction(start_index=0, end_index=135, start_logit=6.491387367248535, end_logit=4.417276382446289),\n", " PrelimPrediction(start_index=111, end_index=135, start_logit=6.451895713806152, end_logit=4.417276382446289)]\n" ] } ], "source": [ "# sort preliminary predictions by their score\n", "prelim_preds = sorted(prelim_preds, key=lambda x: (x.start_logit + x.end_logit), reverse=True)\n", "pprint(prelim_preds[:5])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "CtHhWp42BQ8h" }, "source": [ "Next we need to convert our preliminary predictions into actual text (or the empty string, if null). We'll keep track of text predictions we've seen, because different token combinations can result in the same text prediction and we only want to keep the one with the highest score (we're looping in descending score order). Finally, we'll trim this list down to the best 5 predictions. " ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": {}, "colab_type": "code", "id": "2USLqoLhBQ8h" }, "outputs": [], "source": [ "# keep track of all best predictions\n", "BestPrediction = collections.namedtuple( # pylint: disable=invalid-name\n", "    \"BestPrediction\", [\"text\", \"start_logit\", \"end_logit\"]\n", ")\n", "\n", "nbest = []\n", "seen_predictions = []\n", "for pred in prelim_preds:\n", "    \n", "    # for now we only care about the top 5 best predictions\n", "    if len(nbest) >= 5: \n", "        break\n", "    \n", "    # skip candidates that start at [CLS]; the null answer is added separately below\n", "    if pred.start_index > 0: # non-null answers have start_index > 0\n", "\n", "        text = tokenizer.convert_tokens_to_string(\n", "            tokenizer.convert_ids_to_tokens(\n", "                tokens[pred.start_index:pred.end_index+1]\n", "            )\n", "        )\n", "        # clean whitespace\n", "        text = text.strip()\n", "        text = \" \".join(text.split())\n", "\n", "        if text in seen_predictions:\n", "            continue\n", "\n", "        # flag this text as being seen -- if we see it again, don't add it to the nbest list\n", "        seen_predictions.append(text) \n", "\n", "        # add this text prediction to a pruned list of the top 5 best predictions\n", "        nbest.append(BestPrediction(text=text, start_logit=pred.start_logit, end_logit=pred.end_logit))\n", "\n", "# and don't forget -- include the null answer!\n", "nbest.append(BestPrediction(text=\"\", start_logit=start_logits[0], end_logit=end_logits[0]))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qDIFEXy1BQ8i" }, "source": [ "The null answer is scored as the sum of the `start_logit` and `end_logit` associated with the [CLS] token.\n", " \n", "At this point, we have a neat list of the 5 best text predictions for this question, plus the null answer. The number of best predictions for each example is adjustable with the `--n_best_size` argument of the `run_squad.py` script. 
The `nbest` predictions for _every question_ in the dev set are saved to disk under `nbest_predictions_.json` in `--output_dir`. (This is a great resource for digging into how a model is behaving.) Let's take a look at our `nbest` predictions." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 139 }, "colab_type": "code", "id": "O-jHOrUABQ8j", "outputId": "1892fe60-0c2c-4cfc-975f-bc3bdf9d7bed" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[BestPrediction(text='free oxygen began to outgas from the oceans', start_logit=6.451895713806152, end_logit=6.33292293548584),\n", " BestPrediction(text='free oxygen began to outgas from the oceans 3 – 2 . 7 billion years ago , reaching 10 % of its present level', start_logit=6.451895713806152, end_logit=4.417276382446289),\n", " BestPrediction(text='free oxygen began to outgas', start_logit=6.451895713806152, end_logit=4.3764214515686035),\n", " BestPrediction(text='free oxygen', start_logit=6.451895713806152, end_logit=4.125303268432617),\n", " BestPrediction(text='outgas from the oceans', start_logit=3.354909658432007, end_logit=6.33292293548584),\n", " BestPrediction(text='', start_logit=6.491387367248535, end_logit=6.084450721740723)]\n" ] } ], "source": [ "pprint(nbest)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9uXJxXPwBQ8k" }, "source": [ "Our top prediction so far is \"free oxygen began to outgas from the oceans,\" which is already far better than what we originally predicted. This is because we have successfully excluded nonsensical predictions that would incorporate question tokens as part of the answer. However, we know it's still incorrect. Let's keep going. \n", "\n", "The last step is to compute the null score -- more specifically, the difference between the null score and the best non-null score as shown below." 
] }, { "cell_type": "code", "execution_count": 24, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "vjkpoZ1dBQ8k", "outputId": "4f837605-403d-46e9-94d5-495e3efc6941" }, "outputs": [ { "data": { "text/plain": [ "-0.20898056030273438" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# compute the null score as the sum of the [CLS] token logits\n", "score_null = start_logits[0] + end_logits[0]\n", "\n", "# compute the difference between the null score and the best non-null score\n", "score_diff = score_null - nbest[0].start_logit - nbest[0].end_logit\n", "\n", "score_diff" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "pf0lbxj5BQ8m" }, "source": [ "This `score_diff` is computed for every example in the dev set and these scores are saved to disk in the `null_odds_.json`. Let's pull up the score stored for the example we're using and see how we did!" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "TS74Yow4BQ8m", "outputId": "b3107661-d98e-4d9b-eea7-e0c6cede878a" }, "outputs": [ { "data": { "text/plain": [ "-0.2090005874633789" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "filename = model_dir + 'null_odds_.json'\n", "null_odds = json.load(open(filename, 'rb'))\n", "\n", "null_odds[example.qas_id]" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xRNzEiE0BQ8o" }, "source": [ "We basically nailed it! 
(The full HF version contains a few more checks and some additional subtleties that could account for the slight differences in our `score_diff`.)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "t4s4Bx3LBQ8p" }, "source": [ "### Using the null threshold\n", "\n", "In the previous section we covered:\n", "* how to generate more robust predictions (e.g., by excluding predictions that include question tokens in the answer),\n", "* how to score a prediction as the sum of its start and end logits,\n", "* and how to compute the score difference between the null prediction and the best text prediction.\n", "\n", "\n", "The `run_squad.py` script performs all of these tasks for us and saves the score differences for every example in the `null_odds_.json`. With that, we can now start to make sense of the fourth block of the results output!\n", "\n", "According to the original [BERT paper](https://arxiv.org/abs/1810.04805), \n", "\n", "> We predict a non-null answer when $\\hat{s}_{i,j} > s_{\\text{null}} + \\tau$, where the threshold $\\tau$ is selected on the dev set to maximize F1. \n", "\n", "In other words, the authors are saying that one should predict a null answer for a given example if that example's score difference is above a certain threshold. What should that threshold be? How should we compute it? They give us a recipe: select the threshold that maximizes F1. Rather than rerunning `run_squad.py`, we can import the aptly-named method that computes SQuAD evaluation: `squad_evaluate`. (You can take a look at the code for yourself [here](https://github.com/huggingface/transformers/blob/5856999a9f2926923f037ecd8d27b8058bcf9dae/src/transformers/data/metrics/squad_metrics.py#L211-L239).)\n", "\n", "To use `squad_evaluate` we'll need:\n", "\n", "* the original examples (because that's where the True Answers are stored),\n", "* `predictions_.json`,\n", "* `null_odds_.json`,\n", "* and a null threshold." 
] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": {}, "colab_type": "code", "id": "yQMK-gjSBQ8q" }, "outputs": [], "source": [ "# load the predictions we generated earlier\n", "filename = model_dir + 'predictions_.json'\n", "preds = json.load(open(filename, 'rb'))\n", "\n", "# load the null score differences we generated earlier\n", "filename = model_dir + 'null_odds_.json'\n", "null_odds = json.load(open(filename, 'rb'))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EDboU14MBQ8s" }, "source": [ "Let's re-evaluate our model on SQuAD2.0 using the `squad_evaluate` method. This method uses the score differences for each example in the dev set to determine thresholds that maximize either the EM score or the F1 score. It then recomputes the best possible EM score and F1 score associated with that null threshold. " ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": {}, "colab_type": "code", "id": "bZ36KXH3BQ8s" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('exact', 66.25958056093658),\n", " ('f1', 69.66994428499025),\n", " ('total', 11873),\n", " ('HasAns_exact', 68.91025641025641),\n", " ('HasAns_f1', 75.74076391627662),\n", " ('HasAns_total', 5928),\n", " ('NoAns_exact', 63.61648444070648),\n", " ('NoAns_f1', 63.61648444070648),\n", " ('NoAns_total', 5945),\n", " ('best_exact', 68.36519834919565),\n", " ('best_exact_thresh', -4.189256191253662),\n", " ('best_f1', 71.1144383018176),\n", " ('best_f1_thresh', -3.767639636993408)])\n" ] } ], "source": [ "from transformers.data.metrics.squad_metrics import squad_evaluate\n", "\n", "# the default threshold is set to 1.0 -- we'll leave it there for now\n", "results_default_thresh = squad_evaluate(examples, \n", "                                        preds, \n", "                                        no_answer_probs=null_odds, \n", "                                        no_answer_probability_threshold=1.0)\n", "\n", "pprint(results_default_thresh)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "2v0EOHM4BQ8v" }, "source": [ "The first three blocks have the same values as in our initial evaluation because they are based on the default threshold (which is currently 1.0). However, the values in the fourth block have been updated by taking into account the `null_odds` information. When a given example's `score_diff` is greater than the threshold, the prediction is flipped to a null answer, which affects the overall EM and F1 scores. \n", "\n", "Let's use the `best_f1_thresh` and run the evaluation once more to see a breakdown of our model's performance on `HasAns` and `NoAns` examples:" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "colab": {}, "colab_type": "code", "id": "BXoyssXkBQ8v" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "OrderedDict([('exact', 68.31466352227744),\n", " ('f1', 71.11106931335648),\n", " ('total', 11873),\n", " ('HasAns_exact', 61.53846153846154),\n", " ('HasAns_f1', 67.13929250294865),\n", " ('HasAns_total', 5928),\n", " ('NoAns_exact', 75.07148864592094),\n", " ('NoAns_f1', 75.07148864592094),\n", " ('NoAns_total', 5945),\n", " ('best_exact', 68.36519834919565),\n", " ('best_exact_thresh', -4.189256191253662),\n", " ('best_f1', 71.1144383018176),\n", " ('best_f1_thresh', -3.767639636993408)])\n" ] } ], "source": [ "best_f1_thresh = -3.7676548957824707\n", "results_f1_thresh = squad_evaluate(examples, \n", "                                   preds, \n", "                                   no_answer_probs=null_odds, \n", "                                   no_answer_probability_threshold=best_f1_thresh)\n", "\n", "pprint(results_f1_thresh)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tsLQYSE_BQ8y" }, "source": [ "When we used the default threshold of 1.0, we saw that our `NoAns_f1` score was a mere 63.6, but when we use the `best_f1_thresh`, we now get a `NoAns_f1` score of 75.1 - a jump of more than 11 points! The downside is that we lose some ground in how well our model correctly predicts `HasAns` examples. 
Overall, however, we see a net increase of roughly 1.5 to 2 points in both EM and F1 scores. This demonstrates that computing null scores and properly using a null threshold measurably improves QA performance on the SQuAD2.0 dev set with almost no additional work." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qBGB1LPFBQ8z" }, "source": [ "# Putting it all together\n", "\n", "Below we present a new method that will select more robust predictions, compute scores for the best text predictions (as well as for the null prediction), and use these scores along with a null threshold to determine whether the question should be answered. As a bonus, this method also computes and returns the probability of the answer, which is often easier to interpret than a logit score. Prediction probabilities depend on `nbest`, since they are computed with a softmax over the scores of the `nbest` most likely predictions (plus the null prediction). " ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "colab": {}, "colab_type": "code", "id": "mQheg5xqBQ8z" }, "outputs": [], "source": [ "def get_robust_prediction(example, tokenizer, nbest=10, null_threshold=1.0):\n", "    \n", "    inputs = get_qa_inputs(example, tokenizer)\n", "    start_logits, end_logits = model(**inputs)\n", "\n", "    # get sensible preliminary predictions, sorted by score\n", "    prelim_preds = preliminary_predictions(start_logits, \n", "                                           end_logits, \n", "                                           inputs['input_ids'],\n", "                                           nbest)\n", "    \n", "    # narrow that down to the top nbest predictions\n", "    # (pass this example's logits and token ids so nothing is read from stale globals)\n", "    nbest_preds = best_predictions(prelim_preds, nbest, tokenizer,\n", "                                   start_logits, end_logits, inputs['input_ids'])\n", "\n", "    # compute the probability of each prediction - nice but not necessary\n", "    probabilities = prediction_probabilities(nbest_preds)\n", "    \n", "    # compute score difference\n", "    score_difference = compute_score_difference(nbest_preds)\n", "\n", "    # if score difference > threshold, return the null answer\n", "    if score_difference > null_threshold:\n", "        return \"\", probabilities[-1]\n", "    else:\n", "        return nbest_preds[0].text, probabilities[0]" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": { "colab": {}, "colab_type": "code", "id": "CSVOOr3l7R2k" }, "outputs": [], "source": [ "# collapse-hide\n", "\n", "import collections\n", "import numpy as np\n", "\n", "# ----------------- Helper functions for get_robust_prediction ----------------- #\n", "def get_qa_inputs(example, tokenizer):\n", "    # load the example, convert to inputs, get model outputs\n", "    question = example.question_text\n", "    context = example.context_text\n", "    return tokenizer.encode_plus(question, context, return_tensors='pt')\n", "\n", "def get_clean_text(tokens, tokenizer):\n", "    text = tokenizer.convert_tokens_to_string(\n", "        tokenizer.convert_ids_to_tokens(tokens)\n", "    )\n", "    # Clean whitespace\n", "    text = text.strip()\n", "    text = \" \".join(text.split())\n", "    return text\n", "\n", "def prediction_probabilities(predictions):\n", "\n", "    def softmax(x):\n", "        \"\"\"Compute softmax values for each set of scores in x.\"\"\"\n", "        e_x = np.exp(x - np.max(x))\n", "        return e_x / e_x.sum()\n", "\n", "    all_scores = [pred.start_logit+pred.end_logit for pred in predictions] \n", "    return softmax(np.array(all_scores))\n", "\n", "def preliminary_predictions(start_logits, end_logits, input_ids, nbest):\n", "    # convert tensors to lists\n", "    start_logits = to_list(start_logits)[0]\n", "    end_logits = to_list(end_logits)[0]\n", "    tokens = to_list(input_ids)[0]\n", "\n", "    # sort our start and end logits from largest to smallest, keeping track of the index\n", "    start_idx_and_logit = sorted(enumerate(start_logits), key=lambda x: x[1], reverse=True)\n", "    end_idx_and_logit = sorted(enumerate(end_logits), key=lambda x: x[1], reverse=True)\n", "    \n", "    start_indexes = [idx for idx, logit in start_idx_and_logit[:nbest]]\n", "    end_indexes = [idx for idx, logit in end_idx_and_logit[:nbest]]\n", "\n", "    # question tokens are between the CLS token (101, at position 0) and first SEP (102) token \n", "    question_indexes = [i+1 for i, token in enumerate(tokens[1:tokens.index(102)])]\n", "\n", "    # keep track of all preliminary predictions\n", "    PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name\n", "        \"PrelimPrediction\", [\"start_index\", \"end_index\", \"start_logit\", \"end_logit\"]\n", "    )\n", "    prelim_preds = []\n", "    for start_index in start_indexes:\n", "        for end_index in end_indexes:\n", "            # throw out invalid predictions\n", "            if start_index in question_indexes:\n", "                continue\n", "            if end_index in question_indexes:\n", "                continue\n", "            if end_index < start_index:\n", "                continue\n", "            prelim_preds.append(\n", "                PrelimPrediction(\n", "                    start_index = start_index,\n", "                    end_index = end_index,\n", "                    start_logit = start_logits[start_index],\n", "                    end_logit = end_logits[end_index]\n", "                )\n", "            )\n", "    # sort prelim_preds in descending score order\n", "    prelim_preds = sorted(prelim_preds, key=lambda x: (x.start_logit + x.end_logit), reverse=True)\n", "    return prelim_preds\n", "\n", "def best_predictions(prelim_preds, nbest, tokenizer, start_logits, end_logits, input_ids):\n", "    # convert tensors to lists\n", "    start_logits = to_list(start_logits)[0]\n", "    end_logits = to_list(end_logits)[0]\n", "    tokens = to_list(input_ids)[0]\n", "\n", "    # keep track of all best predictions\n", "    # This will be the pool from which answer probabilities are computed \n", "    BestPrediction = collections.namedtuple(\n", "        \"BestPrediction\", [\"text\", \"start_logit\", \"end_logit\"]\n", "    )\n", "    nbest_predictions = []\n", "    seen_predictions = []\n", "    for pred in prelim_preds:\n", "        if len(nbest_predictions) >= nbest: \n", "            break\n", "        if pred.start_index > 0: # non-null answers have start_index > 0\n", "\n", "            toks = tokens[pred.start_index : pred.end_index+1]\n", "            text = get_clean_text(toks, tokenizer)\n", "\n", "            # if this text has been seen already - skip it\n", "            if text in seen_predictions:\n", "                continue\n", "\n", "            # flag text as being seen\n", "            seen_predictions.append(text) \n", "\n", "            # add this text to a pruned list of the top nbest predictions\n", "            nbest_predictions.append(\n", "                BestPrediction(\n", "                    text=text, \n", "                    start_logit=pred.start_logit,\n", "                    end_logit=pred.end_logit\n", "                )\n", "            )\n", "        \n", "    # Add the null prediction\n", "    nbest_predictions.append(\n", "        BestPrediction(\n", "            text=\"\", \n", "            start_logit=start_logits[0], \n", "            end_logit=end_logits[0]\n", "        )\n", "    )\n", "    return nbest_predictions\n", "\n", "def compute_score_difference(predictions):\n", "    \"\"\" Assumes that the null answer is always the last prediction \"\"\"\n", "    score_null = predictions[-1].start_logit + predictions[-1].end_logit\n", "    score_non_null = predictions[0].start_logit + predictions[0].end_logit\n", "    return score_null - score_non_null" ] },
{ "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KzdgHx_nBQ81" }, "source": [ "Will we now get the right answer (an empty string) for that tricky no-answer example we were working with?" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "lp5ZFiYvBQ81", "outputId": "09ba5c31-28f9-4e20-dc0e-1e95954234c1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "What happened 3.7-2 billion years ago?\n" ] }, { "data": { "text/plain": [ "('', 0.34412444013709165)" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(example.question_text)\n", "get_robust_prediction(example, tokenizer, nbest=10, null_threshold=best_f1_thresh)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "EMVkwrW7BQ83" }, "source": [ "Woohoo!! We got the right answer this time!! \n", "\n", "Even if we didn't have the best threshold in place, our additional checks still allow us to output more sensible looking answers, rejecting predictions that include the question tokens." 
] }, { "cell_type": "code", "execution_count": 32, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "SA2HTci7BQ83", "outputId": "bb4a0374-7397-43ce-d51d-6fd7c8e55046" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "What happened 3.7-2 billion years ago?\n" ] }, { "data": { "text/plain": [ "('free oxygen began to outgas from the oceans', 0.42410620054269993)" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(example.question_text)\n", "get_robust_prediction(example, tokenizer, nbest=10, null_threshold=1.0)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "GXQbjkQbBQ85" }, "source": [ "And if it hadn't been a trick question, this would be the correct answer! (Seems like distilBERT could use some improvement in number understanding.) \n", "\n", "# Final Thoughts \n", "Using a robust prediction method like the above will do more than allow a model to perform better on a curated dev set, though this is an important first step. It will also provide the model with a slightly better ability to refrain from answering questions that simply don't have an answer in the associated passage. This is a crucial feature for QA models, because it's not enough to get an answer if that answer doesn't make sense. We want our models to tell us something useful -- and sometimes that means telling us nothing at all. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [ "In0YUapuBQ8A", "pKVJVU5hBQ8U", "t4s4Bx3LBQ8p" ], "name": "2020-06-09-Evaluating_BERT_on_SQuAD.ipynb", "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 1 }