{"cells":[{"cell_type":"markdown","metadata":{"id":"Pa8a2g5mK-N9"},"source":["# Résumé (PyTorch)"]},{"cell_type":"markdown","metadata":{"id":"fJvNiRTdK-N_"},"source":["Installez les bibliothèques 🤗 *Datasets* et 🤗 *Transformers* pour exécuter ce *notebook*."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"hNHv1f6IK-OB"},"outputs":[],"source":["!pip install datasets transformers[sentencepiece]\n","!pip install accelerate\n","# Pour exécuter l'entraînement sur TPU, vous devez décommenter la ligne suivante :\n","# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl\n","!apt install git-lfs"]},{"cell_type":"markdown","metadata":{"id":"Xx-gfTuTK-OD"},"source":["Vous aurez besoin de configurer git, adaptez votre email et votre nom dans la cellule suivante."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"f9B3GDbhK-OE"},"outputs":[],"source":["!git config --global user.email \"you@example.com\"\n","!git config --global user.name \"Your Name\""]},{"cell_type":"markdown","metadata":{"id":"ur5BwnyKK-OF"},"source":["Vous devrez également être connecté au Hub d'Hugging Face. Exécutez ce qui suit et entrez vos informations d'identification."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1WQHJadxK-OG"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"V10eZAq-K-OH"},"outputs":[],"source":["from datasets import load_dataset\n","\n","french_dataset = load_dataset(\"amazon_reviews_multi\", \"fr\")\n","english_dataset = load_dataset(\"amazon_reviews_multi\", \"en\")\n","french_dataset"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"gVHoJ0aNK-OJ"},"outputs":[],"source":["def show_samples(dataset, num_samples=3, seed=42):\n"," sample = dataset[\"train\"].shuffle(seed=seed).select(range(num_samples))\n"," for example in sample:\n"," print(f\"\\n'>> Title: {example['review_title']}'\")\n"," print(f\"'>> Review: {example['review_body']}'\")\n","\n","show_samples(french_dataset)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1NPG9ebhK-OL"},"outputs":[],"source":["french_dataset.set_format(\"pandas\")\n","french_df = french_dataset[\"train\"][:]\n","# Afficher les comptes des 20 premiers produits\n","french_df[\"product_category\"].value_counts()[:20]"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"BkFJ3vxdK-OM"},"outputs":[],"source":["def filter_books(example):\n"," return (\n"," example[\"product_category\"] == \"book\"\n"," or example[\"product_category\"] == \"digital_ebook_purchase\"\n"," )"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Wq7ako5PK-ON"},"outputs":[],"source":["french_dataset.reset_format()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"XQ3xI0fCK-OO"},"outputs":[],"source":["french_books = french_dataset.filter(filter_books)\n","english_books = english_dataset.filter(filter_books)\n","show_samples(french_dataset)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4fdfxmRjK-OP"},"outputs":[],"source":["from datasets import concatenate_datasets, DatasetDict\n","\n","books_dataset = DatasetDict()\n","\n","for split in english_books.keys():\n"," books_dataset[split] = concatenate_datasets(\n"," [english_books[split], french_books[split]]\n"," )\n"," books_dataset[split] = books_dataset[split].shuffle(seed=42)\n","\n","# Quelques 
exemples\n","show_samples(books_dataset)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Cd-FUPHMK-OP"},"outputs":[],"source":["books_dataset = books_dataset.filter(lambda x: len(x[\"review_title\"].split()) > 2)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"93NzoKWPK-OQ"},"outputs":[],"source":["from transformers import AutoTokenizer\n","\n","model_checkpoint = \"google/mt5-small\"\n","tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3u9KwyFNK-OR"},"outputs":[],"source":["inputs = tokenizer(\"J'ai adoré lire les Hunger Games !\")\n","inputs"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"bZLmFI1vK-OS"},"outputs":[],"source":["tokenizer.convert_ids_to_tokens(inputs.input_ids)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"HCboaKIeK-OT"},"outputs":[],"source":["max_input_length = 512\n","max_target_length = 30\n","\n","\n","def preprocess_function(examples):\n"," model_inputs = tokenizer(\n"," examples[\"review_body\"], max_length=max_input_length, truncation=True\n"," )\n"," # Configurer le tokenizer pour les cibles\n"," with tokenizer.as_target_tokenizer():\n"," labels = tokenizer(\n"," examples[\"review_title\"], max_length=max_target_length, truncation=True\n"," )\n","\n"," model_inputs[\"labels\"] = labels[\"input_ids\"]\n"," return model_inputs"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"E2YvimJQK-OT"},"outputs":[],"source":["tokenized_datasets = books_dataset.map(preprocess_function, batched=True)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DMtI8TCaK-OU"},"outputs":[],"source":["generated_summary = \"J'ai absolument adoré lire les Hunger Games\"\n","reference_summary = \"J'ai adoré lire les Hunger Games\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4V2wHKLKK-OV"},"outputs":[],"source":["!pip install rouge_score"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"4kubvqkdK-OW"},"outputs":[],"source":["from datasets import load_metric\n","\n","rouge_score = load_metric(\"rouge\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"EEzmHqVlK-OW"},"outputs":[],"source":["scores = rouge_score.compute(\n"," predictions=[generated_summary], references=[reference_summary]\n",")\n","scores"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"n0J84AULK-OX"},"outputs":[],"source":["scores[\"rouge1\"].mid"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"e4kQhIjXK-OY"},"outputs":[],"source":["!pip install nltk"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"jPdNK5qoK-OY"},"outputs":[],"source":["import nltk\n","\n","nltk.download(\"punkt\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3gFUAtOhK-OY"},"outputs":[],"source":["from nltk.tokenize import sent_tokenize\n","\n","\n","def three_sentence_summary(text):\n"," return \"\\n\".join(sent_tokenize(text)[:3])\n","\n","\n","print(three_sentence_summary(books_dataset[\"train\"][1][\"review_body\"]))"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"rtLaA9GrK-OZ"},"outputs":[],"source":["def evaluate_baseline(dataset, metric):\n"," summaries = [three_sentence_summary(text) for text in dataset[\"review_body\"]]\n"," return metric.compute(predictions=summaries, references=dataset[\"review_title\"])"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"XdtbG9A_K-OZ"},"outputs":[],"source":["import pandas as pd\n","\n","score = 
evaluate_baseline(books_dataset[\"validation\"], rouge_score)\n","rouge_names = [\"rouge1\", \"rouge2\", \"rougeL\", \"rougeLsum\"]\n","rouge_dict = dict((rn, round(score[rn].mid.fmeasure * 100, 2)) for rn in rouge_names)\n","rouge_dict"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"cIQlTsaoK-Oa"},"outputs":[],"source":["from transformers import AutoModelForSeq2SeqLM\n","\n","model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"SeVaDGwoK-Ob"},"outputs":[],"source":["from transformers import Seq2SeqTrainingArguments\n","\n","batch_size = 8\n","num_train_epochs = 8\n","# Montre la perte d'entraînement à chaque époque\n","logging_steps = len(tokenized_datasets[\"train\"]) // batch_size\n","model_name = model_checkpoint.split(\"/\")[-1]\n","\n","args = Seq2SeqTrainingArguments(\n"," output_dir=f\"{model_name}-finetuned-amazon-en-fr\",\n"," evaluation_strategy=\"epoch\",\n"," learning_rate=5.6e-5,\n"," per_device_train_batch_size=batch_size,\n"," per_device_eval_batch_size=batch_size,\n"," weight_decay=0.01,\n"," save_total_limit=3,\n"," num_train_epochs=num_train_epochs,\n"," predict_with_generate=True,\n"," logging_steps=logging_steps,\n"," push_to_hub=True,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"H_ddkeuUK-Oc"},"outputs":[],"source":["import numpy as np\n","\n","\n","def compute_metrics(eval_pred):\n"," predictions, labels = eval_pred\n"," # Décoder les résumés générés en texte\n"," decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n"," # Remplacer -100 dans les étiquettes car nous ne pouvons pas les décoder\n"," labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n"," # Décoder les résumés de référence en texte\n"," decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n"," # ROUGE attend une nouvelle ligne après chaque phrase\n"," decoded_preds = [\"\\n\".join(sent_tokenize(pred.strip())) for pred in decoded_preds]\n"," decoded_labels = [\"\\n\".join(sent_tokenize(label.strip())) for label in decoded_labels]\n"," # Calculer les scores ROUGE\n"," result = rouge_score.compute(\n"," predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n"," )\n"," # Extraire les scores médians\n"," result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n"," return {k: round(v, 4) for k, v in result.items()}"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Of9WBMiIK-Od"},"outputs":[],"source":["from transformers import DataCollatorForSeq2Seq\n","\n","data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"qmhWCSNcK-Od"},"outputs":[],"source":["tokenized_datasets = tokenized_datasets.remove_columns(\n"," books_dataset[\"train\"].column_names\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"aPxkhz01K-Od"},"outputs":[],"source":["features = [tokenized_datasets[\"train\"][i] for i in range(2)]\n","data_collator(features)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"RVemQHl4K-Oe"},"outputs":[],"source":["from transformers import Seq2SeqTrainer\n","\n","trainer = Seq2SeqTrainer(\n"," model,\n"," args,\n"," train_dataset=tokenized_datasets[\"train\"],\n"," eval_dataset=tokenized_datasets[\"validation\"],\n"," data_collator=data_collator,\n"," tokenizer=tokenizer,\n"," 
{"cell_type":"code","execution_count":null,"metadata":{"id":"LeovMinHK-Oe"},"outputs":[],"source":["trainer.train()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Qif2SSS1K-Oe"},"outputs":[],"source":["trainer.evaluate()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"IvF9-7WiK-Of"},"outputs":[],"source":["tokenized_datasets.set_format(\"torch\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"lDXXUDUSK-Of"},"outputs":[],"source":["model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"vr2uehAzK-Og"},"outputs":[],"source":["from torch.utils.data import DataLoader\n","\n","batch_size = 8\n","train_dataloader = DataLoader(\n","    tokenized_datasets[\"train\"],\n","    shuffle=True,\n","    collate_fn=data_collator,\n","    batch_size=batch_size,\n",")\n","eval_dataloader = DataLoader(\n","    tokenized_datasets[\"validation\"], collate_fn=data_collator, batch_size=batch_size\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Co3Kh6KkK-Og"},"outputs":[],"source":["from torch.optim import AdamW\n","\n","optimizer = AdamW(model.parameters(), lr=2e-5)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wB8Vb3O4K-Og"},"outputs":[],"source":["from accelerate import Accelerator\n","\n","accelerator = Accelerator()\n","model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n","    model, optimizer, train_dataloader, eval_dataloader\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"rc-HG8orK-Oh"},"outputs":[],"source":["from transformers import get_scheduler\n","\n","num_train_epochs = 10\n","num_update_steps_per_epoch = len(train_dataloader)\n","num_training_steps = num_train_epochs * num_update_steps_per_epoch\n","\n","lr_scheduler = get_scheduler(\n","    \"linear\",\n","    optimizer=optimizer,\n","    num_warmup_steps=0,\n","    num_training_steps=num_training_steps,\n",")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"lcZUY4u1K-Oh"},"outputs":[],"source":["def postprocess_text(preds, labels):\n","    preds = [pred.strip() for pred in preds]\n","    labels = [label.strip() for label in labels]\n","\n","    # ROUGE expects a newline after each sentence\n","    preds = [\"\\n\".join(nltk.sent_tokenize(pred)) for pred in preds]\n","    labels = [\"\\n\".join(nltk.sent_tokenize(label)) for label in labels]\n","\n","    return preds, labels"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"MBsEsY2CK-Oi"},"outputs":[],"source":["from huggingface_hub import get_full_repo_name\n","\n","model_name = \"mt5-small-finetuned-amazon-en-fr-accelerate\"\n","repo_name = get_full_repo_name(model_name)\n","repo_name"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"lBr2mNruK-Oi"},"outputs":[],"source":["from huggingface_hub import Repository\n","\n","output_dir = \"results-mt5-finetuned-amazon-en-fr-accelerate\"\n","repo = Repository(output_dir, clone_from=repo_name)"]},
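{"cell_type":"markdown","metadata":{},"source":["As a quick smoke test before launching the full training loop, we can push a single batch through the prepared model and check that the loss is finite (a minimal added sketch; it reuses the first batch of `train_dataloader`):"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch\n","\n","# Forward one batch through the prepared model and inspect the loss\n","batch = next(iter(train_dataloader))\n","with torch.no_grad():\n","    outputs = model(**batch)\n","print(outputs.loss)"]},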
{"cell_type":"code","execution_count":null,"metadata":{"id":"ZbkZgFlPK-Oj"},"outputs":[],"source":["from tqdm.auto import tqdm\n","import torch\n","import numpy as np\n","\n","progress_bar = tqdm(range(num_training_steps))\n","\n","for epoch in range(num_train_epochs):\n","    # Training\n","    model.train()\n","    for step, batch in enumerate(train_dataloader):\n","        outputs = model(**batch)\n","        loss = outputs.loss\n","        accelerator.backward(loss)\n","\n","        optimizer.step()\n","        lr_scheduler.step()\n","        optimizer.zero_grad()\n","        progress_bar.update(1)\n","\n","    # Evaluation\n","    model.eval()\n","    for step, batch in enumerate(eval_dataloader):\n","        with torch.no_grad():\n","            generated_tokens = accelerator.unwrap_model(model).generate(\n","                batch[\"input_ids\"],\n","                attention_mask=batch[\"attention_mask\"],\n","            )\n","\n","            generated_tokens = accelerator.pad_across_processes(\n","                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id\n","            )\n","            labels = batch[\"labels\"]\n","\n","            # If we did not pad to max length, we need to pad the labels too\n","            labels = accelerator.pad_across_processes(\n","                batch[\"labels\"], dim=1, pad_index=tokenizer.pad_token_id\n","            )\n","\n","            generated_tokens = accelerator.gather(generated_tokens).cpu().numpy()\n","            labels = accelerator.gather(labels).cpu().numpy()\n","\n","            # Replace -100 in the labels as we can't decode them\n","            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n","            if isinstance(generated_tokens, tuple):\n","                generated_tokens = generated_tokens[0]\n","            decoded_preds = tokenizer.batch_decode(\n","                generated_tokens, skip_special_tokens=True\n","            )\n","            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n","\n","            decoded_preds, decoded_labels = postprocess_text(\n","                decoded_preds, decoded_labels\n","            )\n","\n","            rouge_score.add_batch(predictions=decoded_preds, references=decoded_labels)\n","\n","    # Compute the metrics\n","    result = rouge_score.compute()\n","    # Extract the median ROUGE scores\n","    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n","    result = {k: round(v, 4) for k, v in result.items()}\n","    print(f\"Epoch {epoch}:\", result)\n","\n","    # Save and upload\n","    accelerator.wait_for_everyone()\n","    unwrapped_model = accelerator.unwrap_model(model)\n","    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)\n","    if accelerator.is_main_process:\n","        tokenizer.save_pretrained(output_dir)\n","        repo.push_to_hub(\n","            commit_message=f\"Training in progress epoch {epoch}\", blocking=False\n","        )"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Sh71Ed8SK-Ok"},"outputs":[],"source":["from transformers import pipeline\n","\n","hub_model_id = \"huggingface-course/mt5-small-finetuned-amazon-en-fr\"\n","summarizer = pipeline(\"summarization\", model=hub_model_id)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oobImt2RK-Ok"},"outputs":[],"source":["def print_summary(idx):\n","    review = books_dataset[\"test\"][idx][\"review_body\"]\n","    title = books_dataset[\"test\"][idx][\"review_title\"]\n","    summary = summarizer(review)[0][\"summary_text\"]\n","    print(f\"'>>> Review: {review}'\")\n","    print(f\"\\n'>>> Title: {title}'\")\n","    print(f\"\\n'>>> Summary: {summary}'\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"dDCLY4B-K-Ol"},"outputs":[],"source":["print_summary(100)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"wFLsEqzhK-Ol"},"outputs":[],"source":["print_summary(0)"]}],"metadata":{"colab":{"provenance":[],"collapsed_sections":[]},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.5"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0}