{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Traduction (PyTorch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install the Transformers, Datasets, and Evaluate libraries to run this notebook." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install datasets evaluate transformers[sentencepiece]\n", "!pip install accelerate\n", "# To run the training on TPU, you will need to uncomment the following line:\n", "# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n", "!apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will need to set up git; adapt your email and name in the following cell." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!git config --global user.email \"you@example.com\"\n", "!git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"fr\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'translation'],\n", " num_rows: 210173\n", " })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'translation'],\n", " num_rows: 189155\n", " })\n", " test: Dataset({\n", " features: ['id', 'translation'],\n", " num_rows: 21018\n", " })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=20)\n", "split_datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "split_datasets[\"validation\"] = split_datasets.pop(\"test\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'en': 'Default to expanded threads',\n", " 'fr': 'Par défaut, développer les fils de discussion'}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "split_datasets[\"train\"][1][\"translation\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'translation_text': 'Par défaut pour les threads élargis'}]" ] }, "execution_count": null, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "from transformers import pipeline\n", "\n", "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n", "translator = pipeline(\"translation\", model=model_checkpoint)\n", "translator(\"Default to expanded threads\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'en': 'Unable to import %1 using the OFX importer plugin. This file is not the correct format.',\n", " 'fr': \"Impossible d'importer %1 en utilisant le module d'extension d'importation OFX. Ce fichier n'a pas un format correct.\"}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "split_datasets[\"train\"][172][\"translation\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'translation_text': \"Impossible d'importer %1 en utilisant le plugin d'importateur OFX. Ce fichier n'est pas le bon format.\"}]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translator(\n", " \"Unable to import %1 using the OFX importer plugin. 
This file is not the correct format.\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, return_tensors=\"pt\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "en_sentence = split_datasets[\"train\"][1][\"translation\"][\"en\"]\n", "fr_sentence = split_datasets[\"train\"][1][\"translation\"][\"fr\"]\n", "\n", "inputs = tokenizer(en_sentence)\n", "with tokenizer.as_target_tokenizer():\n", "    targets = tokenizer(fr_sentence)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['▁Par', '▁dé', 'f', 'aut', ',', '▁dé', 've', 'lop', 'per', '▁les', '▁fil', 's', '▁de', '▁discussion', '']\n", "['▁Par', '▁défaut', ',', '▁développer', '▁les', '▁fils', '▁de', '▁discussion', '']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "wrong_targets = tokenizer(fr_sentence)\n", "print(tokenizer.convert_ids_to_tokens(wrong_targets[\"input_ids\"]))\n", "print(tokenizer.convert_ids_to_tokens(targets[\"input_ids\"]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "max_input_length = 128\n", "max_target_length = 128\n", "\n", "\n", "def preprocess_function(examples):\n", "    \"\"\"Tokenize a batch of translation pairs for seq2seq training.\n", "\n", "    Reads the English (\"en\") and French (\"fr\") texts of every entry in\n", "    `examples[\"translation\"]`, truncates them to `max_input_length` and\n", "    `max_target_length` tokens, and returns the tokenized inputs with the\n", "    target token ids stored under the \"labels\" key.\n", "    \"\"\"\n", "    inputs = [ex[\"en\"] for ex in examples[\"translation\"]]\n", "    targets = [ex[\"fr\"] for ex in examples[\"translation\"]]\n", "    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n", "\n", "    # Set up the tokenizer for the targets.\n", "    with tokenizer.as_target_tokenizer():\n", "        labels = tokenizer(targets, max_length=max_target_length, truncation=True)\n", "\n", "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n", "    return model_inputs" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "tokenized_datasets = split_datasets.map(\n", " preprocess_function,\n", " batched=True,\n", " remove_columns=split_datasets[\"train\"].column_names,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForSeq2SeqLM\n", "\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorForSeq2Seq\n", "\n", "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['attention_mask', 'input_ids', 'labels', 'decoder_input_ids'])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(1, 3)])\n", "batch.keys()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[ 577, 5891, 2, 3184, 16, 2542, 5, 1710, 0, -100,\n", " -100, -100, -100, -100, -100, -100],\n", " [ 1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817,\n", " 550, 7032, 5821, 7907, 12649, 0]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"labels\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[59513, 577, 5891, 2, 3184, 16, 2542, 5, 1710, 0,\n", " 59513, 59513, 59513, 59513, 59513, 59513],\n", " [59513, 1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124,\n", " 817, 550, 7032, 5821, 7907, 12649]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch[\"decoder_input_ids\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [ { "data": { "text/plain": [ "[577, 5891, 2, 3184, 16, 2542, 5, 1710, 0]\n", "[1211, 3, 49, 9409, 1211, 3, 29140, 817, 3124, 817, 550, 7032, 5821, 7907, 12649, 0]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "for i in range(1, 3):\n", " print(tokenized_datasets[\"train\"][i][\"labels\"])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install sacrebleu" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"sacrebleu\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 46.750469682990165,\n", " 'counts': [11, 6, 4, 3],\n", " 'totals': [12, 11, 10, 9],\n", " 'precisions': [91.67, 54.54, 40.0, 33.33],\n", " 'bp': 0.9200444146293233,\n", " 'sys_len': 12,\n", " 'ref_len': 13}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = [\n", " \"This plugin lets you translate web pages between several languages automatically.\"\n", "]\n", "references = [\n", " [\n", " \"This plugin allows you to automatically translate web pages between several languages.\"\n", " ]\n", "]\n", "metric.compute(predictions=predictions, references=references)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 1.683602693167689,\n", " 'counts': [1, 0, 0, 0],\n", " 'totals': [4, 3, 2, 1],\n", " 'precisions': [25.0, 16.67, 12.5, 12.5],\n", " 'bp': 0.10539922456186433,\n", " 'sys_len': 4,\n", " 'ref_len': 13}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = [\"This This This This\"]\n", "references = [\n", " [\n", " \"This plugin allows you to automatically translate web pages between several languages.\"\n", " ]\n", "]\n", 
"metric.compute(predictions=predictions, references=references)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'score': 0.0,\n", " 'counts': [2, 1, 0, 0],\n", " 'totals': [2, 1, 0, 0],\n", " 'precisions': [100.0, 100.0, 0.0, 0.0],\n", " 'bp': 0.004086771438464067,\n", " 'sys_len': 2,\n", " 'ref_len': 13}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predictions = [\"This plugin\"]\n", "references = [\n", "    [\n", "        \"This plugin allows you to automatically translate web pages between several languages.\"\n", "    ]\n", "]\n", "metric.compute(predictions=predictions, references=references)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def compute_metrics(eval_preds):\n", "    \"\"\"Decode predictions and labels and return the SacreBLEU score.\n", "\n", "    `eval_preds` is unpacked as a `(preds, labels)` pair of token-id arrays.\n", "    Returns a dict of the form {\"bleu\": score}.\n", "    \"\"\"\n", "    preds, labels = eval_preds\n", "    # In case the model returns more than the prediction logits\n", "    if isinstance(preds, tuple):\n", "        preds = preds[0]\n", "\n", "    # Replace any -100 in the predictions as we can't decode them\n", "    # (the Trainer may pad gathered predictions with -100)\n", "    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)\n", "    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", "\n", "    # Replace the -100s in the labels as we can't decode them\n", "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", "    # Some simple post-processing: sacrebleu expects a list of references per prediction\n", "    decoded_preds = [pred.strip() for pred in decoded_preds]\n", "    decoded_labels = [[label.strip()] for label in decoded_labels]\n", "\n", "    result = metric.compute(predictions=decoded_preds, references=decoded_labels)\n", "    return {\"bleu\": result[\"score\"]}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import 
Seq2SeqTrainingArguments\n", "\n", "args = Seq2SeqTrainingArguments(\n", " f\"marian-finetuned-kde4-en-to-fr\",\n", " evaluation_strategy=\"no\",\n", " save_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=32,\n", " per_device_eval_batch_size=64,\n", " weight_decay=0.01,\n", " save_total_limit=3,\n", " num_train_epochs=3,\n", " predict_with_generate=True,\n", " fp16=True,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " model,\n", " args,\n", " train_dataset=tokenized_datasets[\"train\"],\n", " eval_dataset=tokenized_datasets[\"validation\"],\n", " data_collator=data_collator,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 1.6964408159255981,\n", " 'eval_bleu': 39.26865061007616,\n", " 'eval_runtime': 965.8884,\n", " 'eval_samples_per_second': 21.76,\n", " 'eval_steps_per_second': 0.341}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(max_length=max_target_length)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'eval_loss': 0.8558505773544312,\n", " 'eval_bleu': 52.94161337775576,\n", " 'eval_runtime': 714.2576,\n", " 'eval_samples_per_second': 29.426,\n", " 'eval_steps_per_second': 0.461,\n", " 'epoch': 3.0}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.evaluate(max_length=max_target_length)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"'https://huggingface.co/sgugger/marian-finetuned-kde4-en-to-fr/commit/3601d621e3baae2bc63d3311452535f8f58f6ef3'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer.push_to_hub(tags=\"translation\", commit_message=\"Training complete\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "tokenized_datasets.set_format(\"torch\")\n", "train_dataloader = DataLoader(\n", "    tokenized_datasets[\"train\"],\n", "    shuffle=True,\n", "    collate_fn=data_collator,\n", "    batch_size=8,\n", ")\n", "eval_dataloader = DataLoader(\n", "    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=8\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# transformers.AdamW is deprecated (removed in recent releases); use the PyTorch optimizer.\n", "from torch.optim import AdamW\n", "\n", "# weight_decay=0.0 keeps the default of the former transformers.AdamW\n", "# (torch.optim.AdamW defaults to 0.01).\n", "optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from accelerate import Accelerator\n", "\n", "accelerator = Accelerator()\n", "model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n", "    model, optimizer, train_dataloader, eval_dataloader\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers import get_scheduler\n", "\n", "num_train_epochs = 3\n", "num_update_steps_per_epoch = len(train_dataloader)\n", "num_training_steps = num_train_epochs * num_update_steps_per_epoch\n", "\n", "lr_scheduler = get_scheduler(\n", "    \"linear\",\n", "    optimizer=optimizer,\n", "    num_warmup_steps=0,\n", "    num_training_steps=num_training_steps,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ 
"'sgugger/marian-finetuned-kde4-en-to-fr-accelerate'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from huggingface_hub import Repository, get_full_repo_name\n", "\n", "model_name = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n", "repo_name = get_full_repo_name(model_name)\n", "repo_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "output_dir = \"marian-finetuned-kde4-en-to-fr-accelerate\"\n", "repo = Repository(output_dir, clone_from=repo_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def postprocess(predictions, labels):\n", "    \"\"\"Decode prediction and label token-id tensors into stripped strings.\n", "\n", "    Both tensors are moved to CPU and converted to NumPy arrays first.\n", "    Returns `(decoded_preds, decoded_labels)`, where each decoded label is\n", "    wrapped in a one-element list.\n", "    \"\"\"\n", "    predictions = predictions.cpu().numpy()\n", "    labels = labels.cpu().numpy()\n", "\n", "    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n", "\n", "    # Replace -100 in the labels as we can't decode them\n", "    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", "    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", "    # Some simple post-processing\n", "    decoded_preds = [pred.strip() for pred in decoded_preds]\n", "    decoded_labels = [[label.strip()] for label in decoded_labels]\n", "    return decoded_preds, decoded_labels" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "epoch 0, BLEU score: 53.47\n", "epoch 1, BLEU score: 54.24\n", "epoch 2, BLEU score: 54.44" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from tqdm.auto import tqdm\n", "import torch\n", "\n", "progress_bar = tqdm(range(num_training_steps))\n", "\n", "for epoch in range(num_train_epochs):\n", "    # Training\n", "    model.train()\n", "    for batch in train_dataloader:\n", "        outputs = model(**batch)\n", "        loss = outputs.loss\n", "        accelerator.backward(loss)\n", "\n", "        optimizer.step()\n",
 "        lr_scheduler.step()\n", "        optimizer.zero_grad()\n", "        progress_bar.update(1)\n", "\n", "    # Evaluation\n", "    model.eval()\n", "    for batch in tqdm(eval_dataloader):\n", "        with torch.no_grad():\n", "            generated_tokens = accelerator.unwrap_model(model).generate(\n", "                batch[\"input_ids\"],\n", "                attention_mask=batch[\"attention_mask\"],\n", "                max_length=128,\n", "            )\n", "        labels = batch[\"labels\"]\n", "\n", "        # Necessary to pad predictions and labels so they can be gathered\n", "        generated_tokens = accelerator.pad_across_processes(\n", "            generated_tokens, dim=1, pad_index=tokenizer.pad_token_id\n", "        )\n", "        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)\n", "\n", "        predictions_gathered = accelerator.gather(generated_tokens)\n", "        labels_gathered = accelerator.gather(labels)\n", "\n", "        decoded_preds, decoded_labels = postprocess(predictions_gathered, labels_gathered)\n", "        metric.add_batch(predictions=decoded_preds, references=decoded_labels)\n", "\n", "    results = metric.compute()\n", "    print(f\"epoch {epoch}, BLEU score: {results['score']:.2f}\")\n", "\n", "    # Save and upload\n", "    accelerator.wait_for_everyone()\n", "    unwrapped_model = accelerator.unwrap_model(model)\n", "    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n", "    if accelerator.is_main_process:\n", "        tokenizer.save_pretrained(output_dir)\n", "        repo.push_to_hub(\n", "            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n", "        )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'translation_text': 'Par défaut, développer les fils de discussion'}]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from transformers import pipeline\n", "\n", "# Replace this with your own checkpoint\n", "model_checkpoint = \"huggingface-course/marian-finetuned-kde4-en-to-fr\"\n", "translator = pipeline(\"translation\", 
model=model_checkpoint)\n", "translator(\"Default to expanded threads\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[{'translation_text': \"Impossible d'importer %1 en utilisant le module externe d'importation OFX. Ce fichier n'est pas le bon format.\"}]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "translator(\n", " \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n", ")" ] } ], "metadata": { "colab": { "name": "Traduction (PyTorch)", "provenance": [] } }, "nbformat": 4, "nbformat_minor": 4 }