{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install datasets transformers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then execute the following cell and input your username and password:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS. 
Uncomment the following instructions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"multiple_choice_notebook\", framework=\"pytorch\")" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a multiple choice task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a multiple choice task, the task of selecting the most plausible continuation among several given options. The dataset used here is [SWAG](https://www.aclweb.org/anthology/D18-1009/), but you can adapt the pre-processing to any other multiple choice dataset you like, or to your own data. SWAG is a dataset about commonsense reasoning, where each example describes a situation and then proposes four options that could follow it. 
\n", "\n", "This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a mutiple choice head. Depending on you model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those two parameters, then the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "model_checkpoint = \"bert-base-uncased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data. This can be easily done with the functions `load_dataset`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "`load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", 
"31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset swag (/home/sgugger/.cache/huggingface/datasets/swag/regular/0.0.0/f9784740e0964a3c799d68cec0d992cc267d3fe94f3e048175eca69d739b980d)\n" ] } ], "source": [ "datasets = load_dataset(\"swag\", \"regular\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `dataset` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set (with more keys for the mismatched validation and test set in the special case of `mnli`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 73546\n", " })\n", " validation: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 20006\n", " })\n", " test: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 20005\n", " })\n", "})" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'ending0': 'passes by walking down the street playing their instruments.',\n", " 'ending1': 'has heard approaching them.',\n", " 'ending2': \"arrives and they're outside dancing and asleep.\",\n", " 'ending3': 'turns the lead singer watches the performance.',\n", " 'fold-ind': '3416',\n", " 'gold-source': 'gold',\n", " 'label': 0,\n", " 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',\n", " 'sent2': 'A drum line',\n", " 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',\n", " 'video-id': 'anetv_jkn6uvmqwh4'}" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "from datasets import ClassLabel\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset)-1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset)-1)\n", " picks.append(pick)\n", " \n", " df = pd.DataFrame(dataset[picks])\n", " for column, typ in dataset.features.items():\n", " if isinstance(typ, ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13" }, "outputs": [ { "data": { "text/html": [ "
\n", " | ending0 | \n", "ending1 | \n", "ending2 | \n", "ending3 | \n", "fold-ind | \n", "gold-source | \n", "label | \n", "sent1 | \n", "sent2 | \n", "startphrase | \n", "video-id | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "are seated on a field. | \n", "are skiing down the slope. | \n", "are in a lift. | \n", "are pouring out in a man. | \n", "16668 | \n", "gold | \n", "1 | \n", "A man is wiping the skiboard. | \n", "Group of people | \n", "A man is wiping the skiboard. Group of people | \n", "anetv_JmL6BiuXr_g | \n", "
1 | \n", "performs stunts inside a gym. | \n", "shows several shopping in the water. | \n", "continues his skateboard while talking. | \n", "is putting a black bike close. | \n", "11424 | \n", "gold | \n", "0 | \n", "The credits of the video are shown. | \n", "A lady | \n", "The credits of the video are shown. A lady | \n", "anetv_dWyE0o2NetQ | \n", "
2 | \n", "is emerging into the hospital. | \n", "are strewn under water at some wreckage. | \n", "tosses the wand together and saunters into the marketplace. | \n", "swats him upside down. | \n", "15023 | \n", "gen | \n", "1 | \n", "Through his binoculars, someone watches a handful of surfers being rolled up into the wave. | \n", "Someone | \n", "Through his binoculars, someone watches a handful of surfers being rolled up into the wave. Someone | \n", "lsmdc3016_CHASING_MAVERICKS-6791 | \n", "
3 | \n", "spies someone sitting below. | \n", "opens the fridge and checks out the photo. | \n", "puts a little sheepishly. | \n", "staggers up to him. | \n", "5475 | \n", "gold | \n", "3 | \n", "He tips it upside down, and its little umbrella falls to the floor. | \n", "Back inside, someone | \n", "He tips it upside down, and its little umbrella falls to the floor. Back inside, someone | \n", "lsmdc1008_Spider-Man2-75503 | \n", "
4 | \n", "carries her to the grave. | \n", "laughs as someone styles her hair. | \n", "sets down his glass. | \n", "stares after her then trudges back up into the street. | \n", "6904 | \n", "gen | \n", "1 | \n", "Someone kisses her smiling daughter on the cheek and beams back at the camera. | \n", "Someone | \n", "Someone kisses her smiling daughter on the cheek and beams back at the camera. Someone | \n", "lsmdc1028_No_Reservations-83242 | \n", "
5 | \n", "stops someone and sweeps all the way back from the lower deck to join them. | \n", "is being dragged towards the monstrous animation. | \n", "beats out many events at the touch of the sword, crawling it. | \n", "reaches into a pocket and yanks open the door. | \n", "14089 | \n", "gen | \n", "1 | \n", "But before he can use his wand, he accidentally rams it up the troll's nostril. | \n", "The angry troll | \n", "But before he can use his wand, he accidentally rams it up the troll's nostril. The angry troll | \n", "lsmdc1053_Harry_Potter_and_the_philosophers_stone-95867 | \n", "
6 | \n", "sees someone's name in the photo. | \n", "gives a surprised look. | \n", "kneels down and touches his ripped specs. | \n", "spies on someone's clock. | \n", "8407 | \n", "gen | \n", "1 | \n", "Someone keeps his tired eyes on the road. | \n", "Glancing over, he | \n", "Someone keeps his tired eyes on the road. Glancing over, he | \n", "lsmdc1024_Identity_Thief-82693 | \n", "
7 | \n", "stops as someone speaks into the camera. | \n", "notices how blue his eyes are. | \n", "is flung out of the door and knocks the boy over. | \n", "flies through the air, its a fireball. | \n", "4523 | \n", "gold | \n", "1 | \n", "Both people are knocked back a few steps from the force of the collision. | \n", "She | \n", "Both people are knocked back a few steps from the force of the collision. She | \n", "lsmdc0043_Thelma_and_Luise-68271 | \n", "
8 | \n", "sits close to the river. | \n", "have pet's supplies and pets. | \n", "pops parked outside the dirt facility, sending up a car highway to catch control. | \n", "displays all kinds of power tools and website. | \n", "8112 | \n", "gold | \n", "1 | \n", "A guy waits in the waiting room with his pet. | \n", "A pet store and its van | \n", "A guy waits in the waiting room with his pet. A pet store and its van | \n", "anetv_9VWoQpg9wqE | \n", "
9 | \n", "the slender someone, someone turns on the light. | \n", ", someone gives them to her boss then dumps some alcohol into dough. | \n", "liquids from a bowl, she slams them drunk. | \n", "wags his tail as someone returns to the hotel room. | \n", "10867 | \n", "gold | \n", "3 | \n", "Inside a convenience store, she opens a freezer case. | \n", "Dolce | \n", "Inside a convenience store, she opens a freezer case. Dolce | \n", "lsmdc3090_YOUNG_ADULT-43871 | \n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Accuracy | \n", "
---|---|---|---|
1 | \n", "0.154598 | \n", "0.828017 | \n", "0.766520 | \n", "
2 | \n", "0.296633 | \n", "0.667454 | \n", "0.786814 | \n", "
3 | \n", "0.111786 | \n", "0.994927 | \n", "0.789363 | \n", "
"
],
"text/plain": [
"