{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it. We also use the `sacrebleu` and `sentencepiece` libraries - you may need to install these even if you already have 🤗 Transformers!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install transformers[sentencepiece] datasets\n", "#! pip install sacrebleu sentencepiece\n", "#! pip install huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then uncomment the following cell and input your token:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS and setup Git if you haven't already. Uncomment the following instructions and adapt with your name and email:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs\n", "# !git config --global user.email \"you@example.com\"\n", "# !git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.21.0.dev0\n" ] } ], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"translation_notebook\", framework=\"tensorflow\")" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a translation task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) model for a translation task. We will use the [WMT dataset](http://www.statmt.org/wmt16/), a machine translation dataset composed from a collection of various sources, including news commentaries and parliament proceedings.\n", "\n", "\n", "\n", "We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using Keras." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "model_checkpoint = \"Helsinki-NLP/opus-mt-en-ROMANCE\"" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`Helsinki-NLP/opus-mt-en-romance`](https://huggingface.co/Helsinki-NLP/opus-mt-en-ROMANCE) checkpoint. " ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the `datasets` function `load_dataset` and the `evaluate` function `load`. We use the English/Romanian part of the WMT dataset here." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset wmt16 (/home/matt/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/28ebdf8cf22106c2f1e58b2083d4b103608acd7bfdb6b14313ccd9e5bc8c313a)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3cead726a3164d67847ade363d465233", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "from datasets import load_dataset\n", "from evaluate import load\n", "\n", "raw_datasets = load_dataset(\"wmt16\", \"ro-en\")\n", "metric = load(\"sacrebleu\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `dataset` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['translation'],\n", " num_rows: 610320\n", " })\n", " validation: Dataset({\n", " features: ['translation'],\n", " num_rows: 1999\n", " })\n", " test: Dataset({\n", " features: ['translation'],\n", " num_rows: 1999\n", " })\n", "})" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'translation': {'en': 'Membership of Parliament: see Minutes',\n", " 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "import datasets\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "\n", "def show_random_elements(dataset, num_examples=5):\n", " assert num_examples <= len(\n", " dataset\n", " ), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset) - 1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset) - 1)\n", " picks.append(pick)\n", "\n", " df = pd.DataFrame(dataset[picks])\n", " for column, typ in dataset.features.items():\n", " if isinstance(typ, datasets.ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13" }, "outputs": [ { "data": { "text/html": [ "
| \n", " | translation | \n", "
|---|---|
| 0 | \n", "{'en': '\"Kosovo does not have not a positive image mainly because (the media portrays) the Serbs living in ghettoes ... and NATO helping Albanians displace the Serbs from Kosovo,\" Chukov said.', 'ro': '\"Kosovo nu are o imagine pozitivă mai ales din cauza faptului că (presa arată) că sârbii trăiesc în ghetto-uri ... şi NATO îi ajută pe albanezi să strămute sârbii din Kosovo\", a spus Chukov.'} | \n", "
| 1 | \n", "{'en': 'They also signed a memorandum of understanding on diplomatic consultations.', 'ro': 'Aceştia au semnat de asemenea un protocol de acord cu privire la consultaţiile diplomatice.'} | \n", "
| 2 | \n", "{'en': 'EU Commissioner for Home Affairs Cecilia Malmstrom said on Monday (September 20th) that Albania has made significant progress in meeting requirements for visa-free travel.', 'ro': 'Comisarul UE pentru afaceri interne, Cecilia Malmstrom, a declarat luni (20 septembrie) că Albania a făcut progrese semnificative în întrunirea condiţiilor pentru liberalizarea vizelor.'} | \n", "
| 3 | \n", "{'en': '13.', 'ro': '13.'} | \n", "
| 4 | \n", "{'en': 'But, in principle, thank you very much for what was, for me, too, a very interesting debate, and all the best.', 'ro': 'Dar, în principiu, vă mulţumesc foarte mult pentru această dezbatere care a fost foarte interesantă pentru mine şi vă urez toate cele bune.'} | \n", "