{"cells":[
{"cell_type":"markdown","metadata":{"id":"lQwwCPeVK7lU"},"source":["# Traduction (TensorFlow)"]},
{"cell_type":"markdown","metadata":{"id":"jID8fKanK7lX"},"source":["Installez les bibliothèques 🤗 *Datasets* et 🤗 *Transformers* pour exécuter ce *notebook*."]},
{"cell_type":"code","execution_count":null,"metadata":{"colab":{"background_save":true,"base_uri":"https://localhost:8080/"},"id":"n_9ZNCn0K7lZ"},"outputs":[],"source":["!pip install datasets transformers[sentencepiece]\n","!apt install git-lfs"]},
{"cell_type":"markdown","metadata":{"id":"WxFtihYXK7lc"},"source":["Vous aurez besoin de configurer git : adaptez votre e-mail et votre nom dans la cellule suivante."]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"woP6GDVCK7le"},"outputs":[],"source":["!git config --global user.email \"you@example.com\"\n","!git config --global user.name \"Your Name\""]},
{"cell_type":"markdown","metadata":{"id":"MLMwx807K7lf"},"source":["Vous devrez également être connecté au Hub d'Hugging Face. Exécutez la cellule suivante et entrez vos informations d'identification."]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"3P9UXPJsK7lf"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"W0YziRB-K7lg"},"outputs":[],"source":["from datasets import load_dataset, load_metric\n","\n","raw_datasets = load_dataset(\"kde4\", lang1=\"en\", lang2=\"fr\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"tFCg2ntAK7li"},"outputs":[],"source":["raw_datasets"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"V09kRyLyK7ll"},"outputs":[],"source":["split_datasets = raw_datasets[\"train\"].train_test_split(train_size=0.9, seed=20)\n","split_datasets"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"JvvrBODbK7lm"},"outputs":[],"source":["split_datasets[\"validation\"] = 
split_datasets.pop(\"test\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"_GROMRExK7ln"},"outputs":[],"source":["split_datasets[\"train\"][1][\"translation\"]"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"mpAXVL-EK7lo"},"outputs":[],"source":["from transformers import pipeline\n","\n","model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n","translator = pipeline(\"translation\", model=model_checkpoint)\n","translator(\"Default to expanded threads\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"epLNmY_HK7lp"},"outputs":[],"source":["split_datasets[\"train\"][172][\"translation\"]"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"bY-XZVJCK7lp"},"outputs":[],"source":["translator(\n","    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"tGpi3GlmK7lq"},"outputs":[],"source":["from transformers import AutoTokenizer\n","\n","model_checkpoint = \"Helsinki-NLP/opus-mt-en-fr\"\n","# The tensor format is chosen per call (e.g. return_tensors=\"tf\" in the data collator), not at load time.\n","tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Wtw_yJF2K7lr"},"outputs":[],"source":["en_sentence = split_datasets[\"train\"][1][\"translation\"][\"en\"]\n","fr_sentence = split_datasets[\"train\"][1][\"translation\"][\"fr\"]\n","\n","inputs = tokenizer(en_sentence)\n","# text_target= tokenizes in target mode (replaces the deprecated as_target_tokenizer() context manager).\n","targets = tokenizer(text_target=fr_sentence)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"2ZV7SMUWK7ls"},"outputs":[],"source":["wrong_targets = tokenizer(fr_sentence)\n","print(tokenizer.convert_ids_to_tokens(wrong_targets[\"input_ids\"]))\n","print(tokenizer.convert_ids_to_tokens(targets[\"input_ids\"]))"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"j2YcgzcaK7lt"},"outputs":[],"source":["max_input_length = 128\n","max_target_length = 128\n","\n","\n","def preprocess_function(examples):\n","    \"\"\"Tokenize a batch of KDE4 pairs: English sentences as inputs, French sentences as labels.\"\"\"\n","    inputs = [ex[\"en\"] for ex in examples[\"translation\"]]\n","    targets = [ex[\"fr\"] for ex in examples[\"translation\"]]\n","    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)\n","\n","    # Tokenize the targets in target mode (text_target=) with their own length limit.\n","    labels = tokenizer(\n","        text_target=targets, max_length=max_target_length, truncation=True\n","    )\n","\n","    model_inputs[\"labels\"] = labels[\"input_ids\"]\n","    return model_inputs"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"HRpke7NdK7lt"},"outputs":[],"source":["tokenized_datasets = split_datasets.map(\n","    preprocess_function,\n","    batched=True,\n","    remove_columns=split_datasets[\"train\"].column_names,\n",")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Me0u8VVvK7lu"},"outputs":[],"source":["from transformers import TFAutoModelForSeq2SeqLM\n","\n","model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_pt=True)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"X6lZL_LWK7lu"},"outputs":[],"source":["from transformers import DataCollatorForSeq2Seq\n","\n","data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors=\"tf\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"uf9FgnBFK7lv"},"outputs":[],"source":["batch = data_collator([tokenized_datasets[\"train\"][i] for i in range(1, 3)])\n","batch.keys()"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Axr6br0pK7lv"},"outputs":[],"source":["batch[\"labels\"]"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Nw1IhmWRK7lw"},"outputs":[],"source":["batch[\"decoder_input_ids\"]"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"q0Q725XgK7lw"},"outputs":[],"source":["for i in range(1, 3):\n","    print(tokenized_datasets[\"train\"][i][\"labels\"])"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"l8uOxCo-K7ly"},"outputs":[],"source":["tf_train_dataset = model.prepare_tf_dataset(\n","    tokenized_datasets[\"train\"],\n","    collate_fn=data_collator,\n","    shuffle=True,\n","    batch_size=32,\n",")\n","\n","tf_eval_dataset = model.prepare_tf_dataset(\n","    tokenized_datasets[\"validation\"],\n","    collate_fn=data_collator,\n","    shuffle=False,\n","    batch_size=16,\n",")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"X6SFfWWHK7ly"},"outputs":[],"source":["!pip install sacrebleu"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"hDadfs1JK7ly"},"outputs":[],"source":["from datasets import load_metric\n","\n","metric = load_metric(\"sacrebleu\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"u9Moa_SuK7lz"},"outputs":[],"source":["predictions = [\n","    \"This plugin lets you translate web pages between several languages automatically.\"\n","]\n","references = [\n","    [\n","        \"This plugin allows you to automatically translate web pages between several languages.\"\n","    ]\n","]\n","metric.compute(predictions=predictions, references=references)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"ozUUl2TxK7lz"},"outputs":[],"source":["predictions = [\"This This This This\"]\n","references = [\n","    [\n","        \"This plugin allows you to automatically translate web pages between several languages.\"\n","    ]\n","]\n","metric.compute(predictions=predictions, references=references)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"dszazssQK7l0"},"outputs":[],"source":["predictions = [\"This plugin\"]\n","references = [\n","    [\n","        \"This plugin allows you to automatically translate web pages between several languages.\"\n","    ]\n","]\n","metric.compute(predictions=predictions, references=references)"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"ZmHOtSV4K7l0"},"outputs":[],"source":["import numpy as np\n","import tensorflow as tf\n","from tqdm import tqdm\n","\n","generation_data_collator = DataCollatorForSeq2Seq(\n","    tokenizer, model=model, return_tensors=\"tf\", 
pad_to_multiple_of=128\n",")\n","\n","# Inputs are truncated to 128 tokens and padded to a multiple of 128, so every batch has the same\n","# sequence length and the XLA-compiled generate below is not retraced for each new length.\n","tf_generate_dataset = model.prepare_tf_dataset(\n","    tokenized_datasets[\"validation\"],\n","    collate_fn=generation_data_collator,\n","    shuffle=False,\n","    batch_size=8,\n",")\n","\n","\n","@tf.function(jit_compile=True)\n","def generate_with_xla(batch):\n","    \"\"\"Generate translations for one batch with an XLA-compiled `model.generate`.\"\"\"\n","    return model.generate(\n","        input_ids=batch[\"input_ids\"],\n","        attention_mask=batch[\"attention_mask\"],\n","        max_new_tokens=128,\n","    )\n","\n","\n","def compute_metrics():\n","    \"\"\"Translate the whole validation set and return its SacreBLEU score as {\"bleu\": score}.\n","\n","    Every batch of `tf_generate_dataset` is translated with `generate_with_xla`; predictions and\n","    references are decoded, stripped and accumulated, then scored once at corpus level.\n","    \"\"\"\n","    all_preds = []\n","    all_labels = []\n","\n","    for batch, labels in tqdm(tf_generate_dataset):\n","        predictions = generate_with_xla(batch)\n","        decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n","\n","        # The collator pads labels with -100, which cannot be decoded: replace it with the pad token.\n","        labels = labels.numpy()\n","        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n","        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n","\n","        # SacreBLEU takes one list of reference translations per prediction.\n","        all_preds.extend(pred.strip() for pred in decoded_preds)\n","        all_labels.extend([label.strip()] for label in decoded_labels)\n","\n","    result = metric.compute(predictions=all_preds, references=all_labels)\n","    return {\"bleu\": result[\"score\"]}"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"X11nbLUtK7l1"},"outputs":[],"source":["from huggingface_hub import notebook_login\n","\n","notebook_login()"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"hmSjXUvBK7l2"},"outputs":[],"source":["print(compute_metrics())"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"-7ciXHiMK7l2"},"outputs":[],"source":["from transformers import create_optimizer\n","from transformers.keras_callbacks import PushToHubCallback\n","import tensorflow as tf\n","\n","# Le nombre d'étapes d'entraînement est le nombre d'échantillons dans le jeu de données, divisé par la taille du batch puis multiplié\n","# par le nombre total d'époques. 
Notez que tf_train_dataset est ici un tf.data.Dataset déjà découpé en batchs,\n","# et non le jeu de données Hugging Face d'origine, donc son len() vaut déjà num_samples // batch_size.\n","num_epochs = 3\n","num_train_steps = len(tf_train_dataset) * num_epochs\n","\n","optimizer, schedule = create_optimizer(\n","    init_lr=5e-5,\n","    num_warmup_steps=0,\n","    num_train_steps=num_train_steps,\n","    weight_decay_rate=0.01,\n",")\n","model.compile(optimizer=optimizer)\n","\n","# Entraîner en mixed-precision float16\n","# NOTE: a Keras dtype policy only applies to layers created after it is set. `model` was built earlier\n","# under the default float32 policy, so this call does not change how it trains; to really train in\n","# mixed precision, call set_global_policy before TFAutoModelForSeq2SeqLM.from_pretrained(...).\n","tf.keras.mixed_precision.set_global_policy(\"mixed_float16\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"Zpd9e2R1K7l3"},"outputs":[],"source":["from transformers.keras_callbacks import PushToHubCallback\n","\n","callback = PushToHubCallback(\n","    output_dir=\"marian-finetuned-kde4-en-to-fr\", tokenizer=tokenizer\n",")\n","\n","model.fit(\n","    tf_train_dataset,\n","    validation_data=tf_eval_dataset,\n","    callbacks=[callback],\n","    epochs=num_epochs,\n",")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"VGGvxYbUK7l3"},"outputs":[],"source":["print(compute_metrics())"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"XBfYvd5qK7l4"},"outputs":[],"source":["from transformers import pipeline\n","\n","# Remplacer par votre propre checkpoint\n","model_checkpoint = \"huggingface-course/marian-finetuned-kde4-en-to-fr\"\n","translator = pipeline(\"translation\", model=model_checkpoint)\n","translator(\"Default to expanded threads\")"]},
{"cell_type":"code","execution_count":null,"metadata":{"id":"18mI4K4PK7l5"},"outputs":[],"source":["translator(\n","    \"Unable to import %1 using the OFX importer plugin. This file is not the correct format.\"\n",")"]}],
"metadata":{"accelerator":"GPU","colab":{"collapsed_sections":[],"name":"","version":""},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.5"}},"nbformat":4,"nbformat_minor":0}