{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "9OrmVqQMBWOb" }, "source": [ "# Multilingual Question Answering w/ Transformers\n", "\n", "[Link to the lab](https://colab.research.google.com/drive/10b26Jxho7EsWevWnItaumPoiSMKAXEXa?usp=sharing)\n", "\n", "This lab will focus on how to train and evaluate a model for multilingual question answering using the HuggingFace transformers library \n", "\n", "For this lab, we will use the multilingual [XLM RoBERTa model](https://huggingface.co/xlm-roberta-base). \n", "\n", "The task is extractive question answering. In this, the data consists of a question, and answer, and the span in the context which contains the correct answer. To model this, we will train our model to simply predict the start and end tokens of the answer.\n", "\n", "![](https://miro.medium.com/max/680/1*gwu3JjZ3hM08dIUziSJ3yg.png)\n", "\n", "Much of the code for this lab is cribbed from [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb#scrollTo=brBgQe9uAM3F)\n", "\n", "\n", "\n", "**It must be noted** that the raw output of the model we are going to train is not just 2 single numbers like the misleading diagram above, but rather 2 numbers for each token in the input. That is, we end up with a distribution of logits for a `start of the answer` and `end of the answer` tokens akin to the following:\n", "\n", "\n", "![logits_qa_example](https://user-images.githubusercontent.com/8036160/195606786-019c88d1-5e06-4434-b2f7-621815362f58.png)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-CdMFPLVV-Q0" }, "outputs": [], "source": [ "!pip install update transformers\n", "!pip install datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "ccDnA6JJCOBB" }, "source": [ "The usual housekeeping to ensure reproducible results" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "Nm4J_0FOWE9g" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from datasets import load_metric\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForQuestionAnswering\n", "from transformers import AutoConfig\n", "from functools import partial\n", "import torch\n", "import random\n", "import numpy as np\n", "from tqdm import tqdm\n", "from transformers import AdamW\n", "from transformers import get_linear_schedule_with_warmup\n", "from torch.optim.lr_scheduler import LambdaLR\n", "from torch import nn\n", "from collections import defaultdict, OrderedDict\n", "MODEL_NAME = 'xlm-roberta-base'\n", "#MODEL_NAME = 'bert-base-uncased'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "56EivecrpXmv" }, "outputs": [], "source": [ "def enforce_reproducibility(seed=42):\n", " # Sets seed manually for both CPU and CUDA\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " # For atomic operations there is currently \n", " # no simple way to enforce determinism, as\n", " # the order of parallel operations is not known.\n", " # CUDNN\n", " torch.backends.cudnn.deterministic = True\n", " torch.backends.cudnn.benchmark = False\n", " # System based\n", " random.seed(seed)\n", " np.random.seed(seed)\n", "\n", "device = torch.device(\"cpu\")\n", "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "\n", "enforce_reproducibility()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "n7_yg1MicosB" }, 
"outputs": [], "source": [ "\"\"\" Official evaluation script for v1.1 of the SQuAD dataset. \"\"\"\n", "from __future__ import print_function\n", "from collections import Counter\n", "import string\n", "import re\n", "import argparse\n", "import json\n", "import sys\n", "\n", "\n", "def normalize_answer(s):\n", " \"\"\"Lower text and remove punctuation, articles and extra whitespace.\"\"\"\n", " def remove_articles(text):\n", " return re.sub(r'\\b(a|an|the)\\b', ' ', text)\n", "\n", " def white_space_fix(text):\n", " return ' '.join(text.split())\n", "\n", " def remove_punc(text):\n", " exclude = set(string.punctuation)\n", " return ''.join(ch for ch in text if ch not in exclude)\n", "\n", " def lower(text):\n", " return text.lower()\n", "\n", " return white_space_fix(remove_articles(remove_punc(lower(s))))\n", "\n", "\n", "def f1_score(prediction, ground_truth):\n", " prediction_tokens = normalize_answer(prediction).split()\n", " ground_truth_tokens = normalize_answer(ground_truth).split()\n", " common = Counter(prediction_tokens) & Counter(ground_truth_tokens)\n", " num_same = sum(common.values())\n", " if num_same == 0:\n", " return 0\n", " precision = 1.0 * num_same / len(prediction_tokens)\n", " recall = 1.0 * num_same / len(ground_truth_tokens)\n", " f1 = (2 * precision * recall) / (precision + recall)\n", " return f1\n", "\n", "\n", "def exact_match_score(prediction, ground_truth):\n", " return (normalize_answer(prediction) == normalize_answer(ground_truth))\n", "\n", "\n", "def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):\n", " scores_for_ground_truths = []\n", " for ground_truth in ground_truths:\n", " score = metric_fn(prediction, ground_truth)\n", " scores_for_ground_truths.append(score)\n", " return max(scores_for_ground_truths)\n", "\n", "\n", "def evaluate_squad(dataset, predictions):\n", " f1 = exact_match = total = 0\n", " for article in dataset:\n", " for paragraph in article['paragraphs']:\n", " for qa in paragraph['qas']:\n", " total += 1\n", " if qa['id'] not in predictions:\n", " message = 'Unanswered question ' + qa['id'] + \\\n", " ' will receive score 0.'\n", " print(message, file=sys.stderr)\n", " continue\n", " ground_truths = list(map(lambda x: x['text'], qa['answers']))\n", " prediction = predictions[qa['id']]\n", " exact_match += metric_max_over_ground_truths(\n", " exact_match_score, prediction, ground_truths)\n", " f1 += metric_max_over_ground_truths(\n", " f1_score, prediction, ground_truths)\n", "\n", " exact_match = 100.0 * exact_match / total\n", " f1 = 100.0 * f1 / total\n", "\n", " return {'exact_match': exact_match, 'f1': f1}\n", "\n", "def compute_squad(predictions, references):\n", " pred_dict = {prediction[\"id\"]: prediction[\"prediction_text\"] for prediction in predictions}\n", " dataset = [\n", " {\n", " \"paragraphs\": [\n", " {\n", " \"qas\": [\n", " {\n", " \"answers\": [{\"text\": answer_text} for answer_text in ref[\"answers\"][\"text\"]],\n", " \"id\": ref[\"id\"],\n", " }\n", " for ref in references\n", " ]\n", " }\n", " ]\n", " }\n", " ]\n", " score = evaluate_squad(dataset=dataset, predictions=pred_dict)\n", " return score" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# this is also equivalent to those 2 lines. 
I recommend going with this simpler version, unless you want more control over the evaluation code\n", "from datasets import load_metric\n", "compute_squad = load_metric(\"squad\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**For your project, use load_metric(\"squad_v2\")**. SQuAD v2, like the TyDiQA dataset, contains unanswerable questions, and its evaluation script supports them." ] }, { "cell_type": "markdown", "metadata": { "id": "RRJ2QQN7CSpl" }, "source": [ "Here we use the HuggingFace datasets library to load the [MLQA dataset](https://github.com/facebookresearch/MLQA). MLQA contains QA data in SQuAD format for 7 different languages. To start, we will load the English-only data to train and test our model." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 199 }, "id": "IXT1MELFXSov", "outputId": "e950d0f4-46d2-4ce2-fec4-6a5f87890672" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c589d765fbf64405bce8c3f11122f099", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/2.29k [00:00 0 and 
(end_char >= len(samples['context'][sample_idx]) or samples['context'][sample_idx][end_char] == ' '):\n", " # end_char -= 1\n", "\n", " # Start from the first token in the context, which can be found by going to the \n", " # first token where sequence_ids is 1\n", " start_token = 0\n", " while sequence_ids[start_token] != 1:\n", " start_token += 1\n", "\n", " end_token = len(offsets) - 1\n", " while sequence_ids[end_token] != 1:\n", " end_token -= 1\n", "\n", " # By default set it to the CLS token if the answer isn't in this input\n", " if start_char < offsets[start_token][0] or end_char > offsets[end_token][1]:\n", " start_token = 0\n", " end_token = 0\n", " # Otherwise find the correct token indices\n", " else:\n", " # Advance the start token index until we have passed the start character index \n", " while start_token < len(offsets) and offsets[start_token][0] <= start_char:\n", " start_token += 1\n", " start_token -= 1\n", " \n", " # Decrease the end token index until we have passed the end character index\n", " while end_token >= 0 and offsets[end_token][1] >= end_char:\n", " end_token -= 1\n", " end_token += 1\n", "\n", " batch['start_tokens'].append(start_token)\n", " batch['end_tokens'].append(end_token)\n", "\n", " #batch['start_tokens'] = np.array(batch['start_tokens'])\n", " #batch['end_tokens'] = np.array(batch['end_tokens'])\n", "\n", " return batch\n", "\n", "def collate_fn(inputs):\n", " '''\n", " Defines how to combine different samples in a batch\n", " '''\n", " input_ids = torch.tensor([i['input_ids'] for i in inputs])\n", " attention_mask = torch.tensor([i['attention_mask'] for i in inputs])\n", " start_tokens = torch.tensor([i['start_tokens'] for i in inputs])\n", " end_tokens = torch.tensor([i['end_tokens'] for i in inputs])\n", "\n", " # Truncate to max length\n", " max_len = max(attention_mask.sum(-1))\n", " input_ids = input_ids[:,:max_len]\n", " attention_mask = attention_mask[:,:max_len]\n", " \n", " return {'input_ids': input_ids, 'attention_mask': attention_mask, 'start_tokens': start_tokens, 'end_tokens': end_tokens}" ] }, { "cell_type": "markdown", "metadata": { "id": "Tx2KcgYuOWjM" }, "source": [ "We can easily tokenize the whole dataset by calling the \"map\" function on the dataset." 
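, "\n", "\n", "For example, something along these lines (a sketch: the preprocessing function name and the `tokenizer` variable are assumed to match the ones defined above):\n", "\n", "```python\n", "train_dataset = mlqa['train'].map(\n", "    partial(process_samples, tokenizer=tokenizer),  # assumed name of the preprocessing function\n", "    batched=True,\n", "    remove_columns=mlqa['train'].column_names,  # drop the raw columns the model doesn't consume\n", ")\n", "```"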
] }, { "cell_type": "code", "execution_count": 15, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 49, "referenced_widgets": [ "91e962c8df9c43fca98ded3c08ddeea1", "80180151c6c44e98adde5ebab977d682", "ba1d04afa1b044d3b9c0779e6672972e", "2c3801f345ba45488c4bb18e12009b79", "730afce7ca22420685dbce7632150346", "659aa2b1dc90412ab1506d8f25edce52", "8170c51a861547ce863786b210504402", "b5667c0d697544d0ae353752817e0b35", "8cacc1c885824cafb48f249ce19fe9c1", "407a069c40e343e3a65bd2e7b0c73944", "1df0135be0d84c78a85d428403b977da" ] }, "id": "TBWSklJXf_q-", "outputId": "c3455ed8-c7ed-4d21-f114-6df3ff8115a3" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "91e962c8df9c43fca98ded3c08ddeea1", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/12 [00:00= len(offset_mapping) or end_index >= len(offset_mapping) or offset_mapping[start_index] is None or offset_mapping[end_index] is None:\n", " continue\n", "\n", " # Also ignore if the start index is greater than the end index of the number of tokens\n", " # is greater than some specified threshold\n", " if start_index > end_index or end_index - start_index + 1 > max_answer_length:\n", " continue\n", "\n", " ans_text = context[offset_mapping[start_index][0]:offset_mapping[end_index][1]]\n", " preds.append({\n", " 'score': start_logits[start_index] + end_logits[end_index],\n", " 'text': ans_text\n", " })\n", "\n", " if len(preds) > 0:\n", " # Sort by score to get the top answer\n", " answer = sorted(preds, key=lambda x: x['score'], reverse=True)[0]\n", " else:\n", " answer = {'score': 0.0, 'text': \"\"}\n", " \n", " predictions[sample['id']] = answer['text']\n", " return predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "0upXUFTGkaxb" }, "source": [ "Create the DataLoader and run prediction!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qsGK-W8hz2Yt" }, "outputs": [], "source": [ "val_dl = DataLoader(validation_dataset, collate_fn=val_collate_fn, batch_size=32)\n", "logits = predict(model, val_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hY_Rf0BAz479" }, "outputs": [], "source": [ "predictions = post_process_predictions(mlqa['validation'], validation_dataset, logits)\n", "formatted_predictions = [{'id': k, 'prediction_text': v} for k,v in predictions.items()]\n", "gold = [{'id': example['id'], 'answers': example['answers']} for example in mlqa['validation']]" ] }, { "cell_type": "markdown", "metadata": { "id": "y9sZFzHrkek-" }, "source": [ "We're using the official SQuAD evaluation metric which measure exact span match as well as token-level F1 score" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Fa7xO83d9Du1", "outputId": "aac00fca-9f8b-4a61-84df-8990906a847e" }, "outputs": [ { "data": { "text/plain": [ "{'exact_match': 57.055749128919864, 'f1': 72.38330713792288}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "text/plain": [ "{'exact_match': 57.055749128919864, 'f1': 72.38330713792288}" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compute_squad(references=gold, predictions=formatted_predictions)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multilingual transformers\n", " \n", "While usually Transformer models are trained on a dataset made of a single language (e.g. 
the now classic BERT model), it is just as easy to train them on a dataset that contains texts in more than one language. For example, XLM-RoBERTa was trained on texts in more than 100 languages!\n", "\n", "**Question:** What are the possible advantages of using a multilingual model?\n", "\n", "(Generalization, shared representations, cross-lingual training and inference, and stronger performance on low-resource languages.)\n", "\n", "Are those models any good, though?\n", "Let's test one of them: an XLM-RoBERTa that was later fine-tuned on an **English-only** QA dataset.\n", "\n", "It can correctly answer questions in English; nothing remarkable here.\n", "\n", "![example1](../../img/eng-eng.jpeg)\n", "\n", "It can also answer questions in Danish! Not bad.\n", "\n", "![example2](../../img/eng-eng.jpeg)\n", "\n", "What about a question asked in English with a Danish context?\n", "\n", "![example3](../../img/en-dk.jpeg)\n", "\n", "Easy peasy. Let's take it to the extreme: what about this chimeric monstrosity?\n", "\n", "![example4](../../img/chimera.jpeg)\n", "\n", "Quite remarkable!\n", "\n", "![example4](../../img/chimera-ans.jpeg)\n" ] },
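{ "cell_type": "markdown", "metadata": {}, "source": [ "We can reproduce this behaviour ourselves with a short sketch (not part of the original lab). The checkpoint name below is an assumption: any publicly available XLM-R model fine-tuned on an English-only QA dataset should behave similarly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: query an XLM-R checkpoint fine-tuned on English-only QA data in Danish.\n", "# The checkpoint name is an assumption; substitute any English-finetuned XLM-R QA model.\n", "from transformers import pipeline\n", "\n", "qa_pipe = pipeline('question-answering', model='deepset/xlm-roberta-base-squad2')\n", "result = qa_pipe(\n", "    question='Hvem skrev Hamlet?',  # 'Who wrote Hamlet?' in Danish\n", "    context='Hamlet er en tragedie skrevet af William Shakespeare.',\n", ")\n", "print(result['answer'])  # a well-trained checkpoint should extract 'William Shakespeare'" ] }, { "cell_type": "markdown", "metadata": { "id": "yvLqvMH5koyw" }, "source": [ "## Training on other languages\n", "\n", "Let's test those capabilities ourselves using the MLQA dataset. Let's see how a model trained on German performs on English." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 135 }, "id": "7RP9ZOK5qnSQ", "outputId": "1d5ea027-d19c-4572-c8ff-51bb2a6f551a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Downloading and preparing dataset mlqa/mlqa-translate-train.de (download: 60.43 MiB, generated: 84.23 MiB, post-processed: Unknown size, total: 144.66 MiB) to 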
/root/.cache/huggingface/datasets/mlqa/mlqa-translate-train.de/1.0.0/1a1ae267d8d9e8e9ff25bd8811a27c5f8752ee58c5d75cf6c6451cbaba777c87...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "592da8b21ea14f879a6695dfbed4efcb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Downloading: 0%| | 0.00/63.4M [00:00