{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Réponse aux questions (PyTorch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install datasets evaluate transformers[sentencepiece]\n", "!pip install accelerate\n", "# To run the training on TPU, you will need to uncomment the followin line:\n", "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n", "!apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will need to setup git, adapt your email and name in the following cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!git config --global user.email \"you@example.com\"\n", "!git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "raw_datasets = load_dataset(\"squad\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'title', 'context', 'question', 'answers'],\n", " num_rows: 87599\n", " })\n", " validation: Dataset({\n", " features: ['id', 'title', 'context', 'question', 'answers'],\n", " num_rows: 10570\n", " })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Context: 'Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'\n", "# Sur le plan architectural, l'école a un caractère catholique. Au sommet du dôme doré du bâtiment principal se trouve une statue dorée de la Vierge Marie. Immédiatement devant le bâtiment principal et face à lui, se trouve une statue en cuivre du Christ, les bras levés, avec la légende \"Venite Ad Me Omnes\". À côté du bâtiment principal se trouve la basilique du Sacré-Cœur. Immédiatement derrière la basilique se trouve la Grotte, un lieu marial de prière et de réflexion. Il s'agit d'une réplique de la grotte de Lourdes, en France, où la Vierge Marie serait apparue à Sainte Bernadette Soubirous en 1858. 
Au bout de l'allée principale (et dans une ligne directe qui passe par 3 statues et le Dôme d'or), se trouve une statue de pierre simple et moderne de Marie'.\n", "Question: 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?' \n", "# A qui la Vierge Marie serait-elle apparue en 1858 à Lourdes, en France ?\n", "Answer: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(\"Context: \", raw_datasets[\"train\"][0][\"context\"])\n", "print(\"Question: \", raw_datasets[\"train\"][0][\"question\"])\n", "print(\"Answer: \", raw_datasets[\"train\"][0][\"answers\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Dataset({\n", " features: ['id', 'title', 'context', 'question', 'answers'],\n", " num_rows: 0\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets[\"train\"].filter(lambda x: len(x[\"answers\"][\"text\"]) != 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}\n", "{'text': ['Santa Clara, California', \"Levi's Stadium\", \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.\"], 'answer_start': [403, 355, 355]}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(raw_datasets[\"validation\"][0][\"answers\"])\n", "print(raw_datasets[\"validation\"][2][\"answers\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.'\n", "# Le Super Bowl 50 était un match de football américain visant à déterminer le champion de la National Football League (NFL) pour la saison 2015. Les Denver Broncos, champions de la Conférence de football américain (AFC), ont battu les Carolina Panthers, champions de la Conférence nationale de football (NFC), 24 à 10, pour remporter leur troisième titre de Super Bowl. Le match s'est déroulé le 7 février 2016 au Levi\\'s Stadium, dans la baie de San Francisco, à Santa Clara, en Californie. Comme il s'agissait du 50e Super Bowl, la ligue a mis l'accent sur l'\" anniversaire doré \" avec diverses initiatives sur le thème de l'or, ainsi qu'en suspendant temporairement la tradition de nommer chaque match du Super Bowl avec des chiffres romains (en vertu de laquelle le match aurait été appelé \" Super Bowl L \"), afin que le logo puisse mettre en évidence les chiffres arabes 50.''\n", "'Where did Super Bowl 50 take place?' \n", "# Où a eu lieu le Super Bowl 50 ?" 
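{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional sanity check (a sketch, not part of the original notebook), you can also count how many validation examples have more than one *distinct* answer text; this assumes `raw_datasets` is loaded as above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional sketch: count validation examples whose annotators gave\n", "# more than one distinct answer text\n", "multi_answer = raw_datasets[\"validation\"].filter(\n", "    lambda x: len(set(x[\"answers\"][\"text\"])) > 1\n", ")\n", "print(f\"{len(multi_answer)} of {len(raw_datasets['validation'])} validation examples have several distinct answers.\")" ] },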
] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(raw_datasets[\"validation\"][2][\"context\"])\n", "print(raw_datasets[\"validation\"][2][\"question\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "model_checkpoint = \"bert-base-cased\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.is_fast" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, '\n", "'the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin '\n", "'Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms '\n", "'upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred '\n", "'Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a '\n", "'replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette '\n", "'Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues '\n", "'and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'\n", "\n", "'[CLS] A qui la Vierge Marie serait-elle apparue en 1858 à Lourdes en France ? [SEP] Architecturalement, '\n", "'l école a un caractère catholique. Au sommet du dôme doré du bâtiment principal se trouve une statue dorée de la Vierge '\n", "'Marie. Immédiatement devant le bâtiment principal et face à lui, se trouve une statue en cuivre du Christ, les bras '\n", "'levés avec la légende \" Venite Ad Me Omnes \". A côté du bâtiment principal se trouve la basilique du Sacré '\n", "'Cœur. Immédiatement derrière la basilique se trouve la Grotte, un lieu marial de prière et de réflexion. Il s'agit d'une '\n", "'réplique de la grotte de Lourdes, en France, où la Vierge Marie serait apparue à Sainte Bernadette '\n", "'Soubirous en 1858. Au bout de l'allée principale ( et en ligne directe qui passe par 3 statues '\n", "'et le Dôme d'or), se trouve une statue de Marie en pierre, simple et moderne. [SEP]'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "context = raw_datasets[\"train\"][0][\"context\"]\n", "question = raw_datasets[\"train\"][0][\"question\"]\n", "\n", "inputs = tokenizer(question, context)\n", "tokenizer.decode(inputs[\"input_ids\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]'\n", "'[CLS] A qui la Vierge Marie serait-elle apparue en 1858 à Lourdes en France ? 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]'\n", "\n", "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]'\n", "\n", "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 [SEP]'\n", "\n", "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = tokenizer(\n", "    question,\n", "    context,\n", "    max_length=100,\n", "    truncation=\"only_second\",\n", "    stride=50,\n", "    return_overflowing_tokens=True,\n", ")\n", "\n", "for ids in inputs[\"input_ids\"]:\n", "    print(tokenizer.decode(ids))" ] },
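{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick check you can add here (a sketch, not in the original notebook): each overflowing feature is at most `max_length` tokens long, and consecutive features overlap by up to `stride` tokens of context." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: the question and context were split into four features of at most 100 tokens each\n", "print([len(ids) for ids in inputs[\"input_ids\"]])" ] },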
[SEP]'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = tokenizer(\n", " question,\n", " context,\n", " max_length=100,\n", " truncation=\"only_second\",\n", " stride=50,\n", " return_overflowing_tokens=True,\n", ")\n", "\n", "for ids in inputs[\"input_ids\"]:\n", " print(tokenizer.decode(ids))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = tokenizer(\n", " question,\n", " context,\n", " max_length=100,\n", " truncation=\"only_second\",\n", " stride=50,\n", " return_overflowing_tokens=True,\n", " return_offsets_mapping=True,\n", ")\n", "inputs.keys()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0, 0, 0, 0]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[\"overflow_to_sample_mapping\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'The 4 examples gave 19 features.'\n", "'Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = tokenizer(\n", " raw_datasets[\"train\"][2:6][\"question\"],\n", " raw_datasets[\"train\"][2:6][\"context\"],\n", " max_length=100,\n", " truncation=\"only_second\",\n", " stride=50,\n", " return_overflowing_tokens=True,\n", " return_offsets_mapping=True,\n", ")\n", "\n", "print(f\"The 4 examples gave {len(inputs['input_ids'])} features.\")\n", "print(f\"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "([83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0],\n", " [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "answers = raw_datasets[\"train\"][2:6][\"answers\"]\n", "start_positions = []\n", "end_positions = []\n", "\n", "for i, offset in enumerate(inputs[\"offset_mapping\"]):\n", " sample_idx = inputs[\"overflow_to_sample_mapping\"][i]\n", " answer = answers[sample_idx]\n", " start_char = answer[\"answer_start\"][0]\n", " end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n", " sequence_ids = inputs.sequence_ids(i)\n", "\n", " # Trouver le début et la fin du contexte\n", " idx = 0\n", " while sequence_ids[idx] != 1:\n", " idx += 1\n", " context_start = idx\n", " while sequence_ids[idx] == 1:\n", " idx += 1\n", " context_end = idx - 1\n", "\n", " # Si la réponse n'est pas entièrement dans le contexte, l'étiquette est (0, 0)\n", " if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n", " start_positions.append(0)\n", " end_positions.append(0)\n", " else:\n", " # Sinon, ce sont les positions de début et de fin du token\n", " idx = context_start\n", " while idx <= context_end and offset[idx][0] <= start_char:\n", " idx += 1\n", " start_positions.append(idx - 1)\n", "\n", " idx = context_end\n", " while idx >= context_start and offset[idx][1] >= end_char:\n", " idx -= 1\n", " end_positions.append(idx + 1)\n", "\n", 
"start_positions, end_positions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Theoretical answer: the Main Building, labels give: the Main Building'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx = 0\n", "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n", "answer = answers[sample_idx][\"text\"][0]\n", "\n", "start = start_positions[idx]\n", "end = end_positions[idx]\n", "labeled_answer = tokenizer.decode(inputs[\"input_ids\"][idx][start : end + 1])\n", "\n", "print(f\"Theoretical answer: {answer}, labels give: {labeled_answer}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Theoretical answer: a Marian place of prayer and reflection, decoded example: [CLS] What is the Grotto at Notre Dame? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grot [SEP]'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx = 4\n", "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n", "answer = answers[sample_idx][\"text\"][0]\n", "\n", "decoded_example = tokenizer.decode(inputs[\"input_ids\"][idx])\n", "print(f\"Theoretical answer: {answer}, decoded example: {decoded_example}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_length = 384\n", "stride = 128\n", "\n", "\n", "def preprocess_training_examples(examples):\n", " questions = [q.strip() for q in examples[\"question\"]]\n", " inputs = tokenizer(\n", " questions,\n", " examples[\"context\"],\n", " max_length=max_length,\n", " truncation=\"only_second\",\n", " stride=stride,\n", " return_overflowing_tokens=True,\n", " return_offsets_mapping=True,\n", " padding=\"max_length\",\n", " )\n", "\n", " offset_mapping = inputs.pop(\"offset_mapping\")\n", " sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n", " answers = examples[\"answers\"]\n", " start_positions = []\n", " end_positions = []\n", "\n", " for i, offset in enumerate(offset_mapping):\n", " sample_idx = sample_map[i]\n", " answer = answers[sample_idx]\n", " start_char = answer[\"answer_start\"][0]\n", " end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n", " sequence_ids = inputs.sequence_ids(i)\n", "\n", " # Trouver le début et la fin du contexte\n", " idx = 0\n", " while sequence_ids[idx] != 1:\n", " idx += 1\n", " context_start = idx\n", " while sequence_ids[idx] == 1:\n", " idx += 1\n", " context_end = idx - 1\n", "\n", " # Si la réponse n'est pas entièrement dans le contexte, l'étiquette est (0, 0)\n", " if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n", " start_positions.append(0)\n", " end_positions.append(0)\n", " else:\n", " # Sinon, ce sont les positions de début et de fin du token\n", " idx = context_start\n", " while idx <= context_end and offset[idx][0] <= start_char:\n", " idx += 1\n", " start_positions.append(idx - 1)\n", "\n", " idx = context_end\n", " while idx >= context_start and offset[idx][1] >= end_char:\n", " idx -= 1\n", " end_positions.append(idx + 1)\n", "\n", " inputs[\"start_positions\"] 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(87599, 88729)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_dataset = raw_datasets[\"train\"].map(\n", "    preprocess_training_examples,\n", "    batched=True,\n", "    remove_columns=raw_datasets[\"train\"].column_names,\n", ")\n", "len(raw_datasets[\"train\"]), len(train_dataset)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def preprocess_validation_examples(examples):\n", "    questions = [q.strip() for q in examples[\"question\"]]\n", "    inputs = tokenizer(\n", "        questions,\n", "        examples[\"context\"],\n", "        max_length=max_length,\n", "        truncation=\"only_second\",\n", "        stride=stride,\n", "        return_overflowing_tokens=True,\n", "        return_offsets_mapping=True,\n", "        padding=\"max_length\",\n", "    )\n", "\n", "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n", "    example_ids = []\n", "\n", "    for i in range(len(inputs[\"input_ids\"])):\n", "        sample_idx = sample_map[i]\n", "        example_ids.append(examples[\"id\"][sample_idx])\n", "\n", "        sequence_ids = inputs.sequence_ids(i)\n", "        offset = inputs[\"offset_mapping\"][i]\n", "        inputs[\"offset_mapping\"][i] = [\n", "            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)\n", "        ]\n", "\n", "    inputs[\"example_id\"] = example_ids\n", "    return inputs" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(10570, 10822)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "validation_dataset = raw_datasets[\"validation\"].map(\n", "    preprocess_validation_examples,\n", "    batched=True,\n", "    remove_columns=raw_datasets[\"validation\"].column_names,\n", ")\n", "len(raw_datasets[\"validation\"]), len(validation_dataset)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "small_eval_set = raw_datasets[\"validation\"].select(range(100))\n", "trained_checkpoint = \"distilbert-base-cased-distilled-squad\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)\n", "eval_set = small_eval_set.map(\n", "    preprocess_validation_examples,\n", "    batched=True,\n", "    remove_columns=raw_datasets[\"validation\"].column_names,\n", ")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import AutoModelForQuestionAnswering\n", "\n", "eval_set_for_model = eval_set.remove_columns([\"example_id\", \"offset_mapping\"])\n", "eval_set_for_model.set_format(\"torch\")\n", "\n", "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}\n", "trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(\n", "    device\n", ")\n", "\n", "with torch.no_grad():\n", "    outputs = trained_model(**batch)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "start_logits = outputs.start_logits.cpu().numpy()\n", "end_logits = outputs.end_logits.cpu().numpy()" ] },
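{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick shape check (a sketch, not in the original notebook): there is one row of logits per feature in `eval_set`, each of length `max_length`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Sketch: (number of features, max_length) for both start and end logits\n", "print(start_logits.shape, end_logits.shape)" ] },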
null, "metadata": {}, "outputs": [], "source": [ "import collections\n", "\n", "example_to_features = collections.defaultdict(list)\n", "for idx, feature in enumerate(eval_set):\n", " example_to_features[feature[\"example_id\"]].append(idx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "n_best = 20\n", "max_answer_length = 30\n", "predicted_answers = []\n", "\n", "for example in small_eval_set:\n", " example_id = example[\"id\"]\n", " context = example[\"context\"]\n", " answers = []\n", "\n", " for feature_index in example_to_features[example_id]:\n", " start_logit = start_logits[feature_index]\n", " end_logit = end_logits[feature_index]\n", " offsets = eval_set[\"offset_mapping\"][feature_index]\n", "\n", " start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n", " end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n", " for start_index in start_indexes:\n", " for end_index in end_indexes:\n", " # Ignore les réponses qui ne sont pas entièrement dans le contexte\n", " if offsets[start_index] is None or offsets[end_index] is None:\n", " continue\n", " # Ignore les réponses dont la longueur est soit < 0 soit > max_answer_length\n", " if (\n", " end_index < start_index\n", " or end_index - start_index + 1 > max_answer_length\n", " ):\n", " continue\n", "\n", " answers.append(\n", " {\n", " \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n", " \"logit_score\": start_logit[start_index] + end_logit[end_index],\n", " }\n", " )\n", "\n", " best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n", " predicted_answers.append({\"id\": example_id, \"prediction_text\": best_answer[\"text\"]})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"squad\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "theoretical_answers = [\n", " {\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in small_eval_set\n", "]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'id': '56be4db0acb8001400a502ec', 'prediction_text': 'Denver Broncos'}\n", "{'id': '56be4db0acb8001400a502ec', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "print(predicted_answers[0])\n", "print(theoretical_answers[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'exact_match': 83.0, 'f1': 88.25}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "metric.compute(predictions=predicted_answers, references=theoretical_answers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "\n", "\n", "def compute_metrics(start_logits, end_logits, features, examples):\n", " example_to_features = collections.defaultdict(list)\n", " for idx, feature in enumerate(features):\n", " example_to_features[feature[\"example_id\"]].append(idx)\n", "\n", " predicted_answers = []\n", " for example in tqdm(examples):\n", " example_id = example[\"id\"]\n", " context = example[\"context\"]\n", " answers = []\n", "\n", " # Parcourir en boucle toutes les fonctionnalités associées à cet 
exemple\n", " for feature_index in example_to_features[example_id]:\n", " start_logit = start_logits[feature_index]\n", " end_logit = end_logits[feature_index]\n", " offsets = features[feature_index][\"offset_mapping\"]\n", "\n", " start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n", " end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n", " for start_index in start_indexes:\n", " for end_index in end_indexes:\n", " # Ignore les réponses qui ne sont pas entièrement dans le contexte\n", " if offsets[start_index] is None or offsets[end_index] is None:\n", " continue\n", " # Ignore les réponses dont la longueur est soit < 0, soit > max_answer_length\n", " if (\n", " end_index < start_index\n", " or end_index - start_index + 1 > max_answer_length\n", " ):\n", " continue\n", "\n", " answer = {\n", " \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n", " \"logit_score\": start_logit[start_index] + end_logit[end_index],\n", " }\n", " answers.append(answer)\n", "\n", " # Sélectionne la réponse avec le meilleur score\n", " if len(answers) > 0:\n", " best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n", " predicted_answers.append(\n", " {\"id\": example_id, \"prediction_text\": best_answer[\"text\"]}\n", " )\n", " else:\n", " predicted_answers.append({\"id\": example_id, \"prediction_text\": \"\"})\n", "\n", " theoretical_answers = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in examples]\n", " return metric.compute(predictions=predicted_answers, references=theoretical_answers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'exact_match': 83.0, 'f1': 88.25}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "compute_metrics(start_logits, end_logits, eval_set, small_eval_set)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import TrainingArguments\n", "\n", "args = TrainingArguments(\n", " \"bert-finetuned-squad\",\n", " evaluation_strategy=\"no\",\n", " save_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " num_train_epochs=3,\n", " weight_decay=0.01,\n", " fp16=True,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import Trainer\n", "\n", "trainer = Trainer(\n", " model=model,\n", " args=args,\n", " train_dataset=train_dataset,\n", " eval_dataset=validation_dataset,\n", " tokenizer=tokenizer,\n", ")\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'exact_match': 81.18259224219489, 'f1': 88.67381321905516}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions, _ = trainer.predict(validation_dataset)\n", "start_logits, end_logits = predictions\n", "compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets[\"validation\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"'https://huggingface.co/sgugger/bert-finetuned-squad/commit/9dcee1fbc25946a6ed4bb32efb1bd71d5fa90b68'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.push_to_hub(commit_message=\"Training complete\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from transformers import default_data_collator\n", "\n", "train_dataset.set_format(\"torch\")\n", "validation_set = validation_dataset.remove_columns([\"example_id\", \"offset_mapping\"])\n", "validation_set.set_format(\"torch\")\n", "\n", "train_dataloader = DataLoader(\n", " train_dataset,\n", " shuffle=True,\n", " collate_fn=default_data_collator,\n", " batch_size=8,\n", ")\n", "eval_dataloader = DataLoader(\n", " validation_set, collate_fn=default_data_collator, batch_size=8\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.optim import AdamW\n", "\n", "optimizer = AdamW(model.parameters(), lr=2e-5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from accelerate import Accelerator\n", "\n", "accelerator = Accelerator(fp16=True)\n", "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n", " model, optimizer, train_dataloader, eval_dataloader\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import get_scheduler\n", "\n", "num_train_epochs = 3\n", "num_update_steps_per_epoch = len(train_dataloader)\n", "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n", "\n", "lr_scheduler = get_scheduler(\n", " \"linear\",\n", " optimizer=optimizer,\n", " num_warmup_steps=0,\n", " num_training_steps=num_training_steps,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'sgugger/bert-finetuned-squad-accelerate'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import Repository, get_full_repo_name\n", "\n", "model_name = \"bert-finetuned-squad-accelerate\"\n", "repo_name = get_full_repo_name(model_name)\n", "repo_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output_dir = \"bert-finetuned-squad-accelerate\"\n", "repo = Repository(output_dir, clone_from=repo_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "import torch\n", "\n", "progress_bar = tqdm(range(num_training_steps))\n", "\n", "for epoch in range(num_train_epochs):\n", " # Entraînement\n", " model.train()\n", " for step, batch in enumerate(train_dataloader):\n", " outputs = model(**batch)\n", " loss = outputs.loss\n", " accelerator.backward(loss)\n", "\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", " progress_bar.update(1)\n", "\n", " # Evaluation\n", " model.eval()\n", " start_logits = []\n", " end_logits = []\n", " accelerator.print(\"Evaluation!\")\n", " for batch in tqdm(eval_dataloader):\n", " with torch.no_grad():\n", " outputs = model(**batch)\n", "\n", " start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())\n", " 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'sgugger/bert-finetuned-squad-accelerate'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import Repository, get_full_repo_name\n", "\n", "model_name = \"bert-finetuned-squad-accelerate\"\n", "repo_name = get_full_repo_name(model_name)\n", "repo_name" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output_dir = \"bert-finetuned-squad-accelerate\"\n", "repo = Repository(output_dir, clone_from=repo_name)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "import torch\n", "\n", "progress_bar = tqdm(range(num_training_steps))\n", "\n", "for epoch in range(num_train_epochs):\n", "    # Training\n", "    model.train()\n", "    for step, batch in enumerate(train_dataloader):\n", "        outputs = model(**batch)\n", "        loss = outputs.loss\n", "        accelerator.backward(loss)\n", "\n", "        optimizer.step()\n", "        lr_scheduler.step()\n", "        optimizer.zero_grad()\n", "        progress_bar.update(1)\n", "\n", "    # Evaluation\n", "    model.eval()\n", "    start_logits = []\n", "    end_logits = []\n", "    accelerator.print(\"Evaluation!\")\n", "    for batch in tqdm(eval_dataloader):\n", "        with torch.no_grad():\n", "            outputs = model(**batch)\n", "\n", "        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())\n", "        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())\n", "\n", "    start_logits = np.concatenate(start_logits)\n", "    end_logits = np.concatenate(end_logits)\n", "    start_logits = start_logits[: len(validation_dataset)]\n", "    end_logits = end_logits[: len(validation_dataset)]\n", "\n", "    metrics = compute_metrics(\n", "        start_logits, end_logits, validation_dataset, raw_datasets[\"validation\"]\n", "    )\n", "    print(f\"epoch {epoch}:\", metrics)\n", "\n", "    # Save and upload\n", "    accelerator.wait_for_everyone()\n", "    unwrapped_model = accelerator.unwrap_model(model)\n", "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n", "    if accelerator.is_main_process:\n", "        tokenizer.save_pretrained(output_dir)\n", "        repo.push_to_hub(\n", "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n", "        )" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "accelerator.wait_for_everyone()\n", "unwrapped_model = accelerator.unwrap_model(model)\n", "unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 0.9979003071784973,\n", " 'start': 78,\n", " 'end': 105,\n", " 'answer': 'Jax, PyTorch and TensorFlow'}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import pipeline\n", "\n", "# Replace this with your own checkpoint\n", "model_checkpoint = \"huggingface-course/bert-finetuned-squad\"\n", "question_answerer = pipeline(\"question-answering\", model=model_checkpoint)\n", "\n", "context = \"\"\"\n", "🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration\n", "between them. It's straightforward to train your models with one before loading them for inference with the other.\n", "\"\"\"\n", "question = \"Which deep learning libraries back 🤗 Transformers?\"\n", "question_answerer(question=question, context=context)" ] } ], "metadata": { "colab": { "name": "Question answering (PyTorch)", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 4 }