{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers, 🤗 Datasets and SacreBLEU. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install datasets transformers[sentencepiece] sacrebleu" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.\n", "\n", "To be able to share your model with the community and generate results via the inference API, there are a few more steps to follow.\n", "\n", "First, get your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then execute the following cell and enter your credentials when prompted:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS. Uncomment the following command and run it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.11.0, since some of the functionality used here was introduced in that version:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq)." ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a translation task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a translation task. We will use the [WMT dataset](http://www.statmt.org/wmt16/), a machine translation dataset composed of a collection of various sources, including news commentaries and parliament proceedings.\n", "\n", "We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using the `Trainer` API." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_checkpoint = \"Helsinki-NLP/opus-mt-en-ro\"" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models), as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`Helsinki-NLP/opus-mt-en-ro`](https://huggingface.co/Helsinki-NLP/opus-mt-en-ro) checkpoint." ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need for evaluation (to compare our model to the benchmark). This can easily be done with the functions `load_dataset` and `load_metric`. We use the English/Romanian part of the WMT dataset here." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset wmt16 (/home/sgugger/.cache/huggingface/datasets/wmt16/ro-en/1.0.0/9dc00622c30446e99c4c63d12a484ea4fb653f2f37c867d6edcec839d7eae50f)\n" ] } ], "source": [ "from datasets import load_dataset, load_metric\n", "\n", "raw_datasets = load_dataset(\"wmt16\", \"ro-en\")\n", "metric = load_metric(\"sacrebleu\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", "    train: Dataset({\n", "        features: ['translation'],\n", "        num_rows: 610320\n", "    })\n", "    validation: Dataset({\n", "        features: ['translation'],\n", "        num_rows: 1999\n", "    })\n", "    test: Dataset({\n", "        features: ['translation'],\n", "        num_rows: 1999\n", "    })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets" ] }, { "cell_type": 
"markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'translation': {'en': 'Membership of Parliament: see Minutes',\n", " 'ro': 'Componenţa Parlamentului: a se vedea procesul-verbal'}}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "raw_datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked randomly from the dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "import datasets\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "def show_random_elements(dataset, num_examples=5):\n", "    assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", "    # Draw num_examples distinct random row indices\n", "    picks = []\n", "    for _ in range(num_examples):\n", "        pick = random.randint(0, len(dataset)-1)\n", "        while pick in picks:\n", "            pick = random.randint(0, len(dataset)-1)\n", "        picks.append(pick)\n", "\n", "    df = pd.DataFrame(dataset[picks])\n", "    # Replace ClassLabel integer ids with their human-readable names\n", "    for column, typ in dataset.features.items():\n", "        if isinstance(typ, datasets.ClassLabel):\n", "            df[column] = df[column].transform(lambda i: typ.names[i])\n", "    display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13" }, "outputs": [ { "data": { "text/html": [ "
| | translation |
|---|---|
| 0 | {'en': 'However, we must not forget that the law has not yet entered into force.', 'ro': 'Cu toate acestea, nu trebuie să uităm că legea respectivă nu a intrat încă în vigoare.'} |
| 1 | {'en': 'UniCredit Zagrebacka Banka, based in Bosnia and Herzegovina (BiH) has for the second time won Euromoney magazine's annual award for excellence.', 'ro': 'UniCredit Zagrebacka Banka cu sediul în Bosnia şi Herţegovina (BiH) a câştigat pentru a doua oară premiul anual pentru excelenţă al revistei Euromoney.'} |
| 2 | {'en': 'Measuring instruments for cold water meters for non-clean water, alcohol meters, certain weights, tyre pressure gauges and equipment to measure the standard mass of grain or the size of ship tanks have been replaced, in practice, by more modern digital equipment.', 'ro': 'Instrumentele de măsură pentru contoarele de apă rece pentru apa murdară, alcoolmetrele, anumite greutăți, manometrele pentru presiunea din pneuri și echipamentele de măsură pentru masa standard de cereale sau pentru dimensiunea rezervoarelor de nave au fost înlocuite, în practică, de echipamente digitale mai moderne.'} |
| 3 | {'en': 'The citizens are our most important allies in achieving our joint objectives.', 'ro': 'Cetăţenii sunt aliaţii noştri cei mai importanţi pentru atingerea obiectivelor noastre comune.'} |
| 4 | {'en': 'Nobody can ignore the farmers and our villages because we are not that small."', 'ro': 'Nimeni nu ne poate ignora, pe noi agricultorii şi satele noastre, pentru că nu suntem atât de mici”.'} |
| Epoch | Training Loss | Validation Loss | Bleu | Gen Len | Runtime | Samples Per Second |
|---|---|---|---|---|---|---|
| 1 | 0.740100 | 1.290665 | 28.059300 | 34.051500 | 135.611700 | 14.741000 |
"
],
"text/plain": [
"