{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install transformers datasets huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.\n", "\n", "To be able to share your model with the community, there are a few more steps to follow.\n", "\n", "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then run the following cell and input your token:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS and set up Git if you haven't already. 
Uncomment the following instructions and adapt them with your name and email:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs\n", "# !git config --global user.email \"you@example.com\"\n", "# !git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.21.0.dev0\n" ] } ], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a multiple choice task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models on a multiple-choice task. In a multiple-choice task, multiple answers or continuations are provided for each input, and the model must guess which is most plausible. The dataset used here is [SWAG](https://www.aclweb.org/anthology/D18-1009/) but you can adapt the pre-processing to any other multiple-choice dataset you like, or your own data. SWAG is a dataset about commonsense reasoning, where each example describes a situation and proposes four continuations that could follow it. \n", "\n", "This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a multiple-choice head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set the two parameters below (the model checkpoint and the batch size), then the rest of the notebook should run smoothly." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "model_checkpoint = \"bert-base-cased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data. This can be easily done with the `load_dataset` function. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "`load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", 
"0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset swag (/home/matt/.cache/huggingface/datasets/swag/regular/0.0.0/9640de08cdba6a1469ed3834fcab4b8ad8e38caf5d1ba5e7436d8b1fd067ad4c)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f73b371cb74c4e399faea25052830c5f", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "datasets = load_dataset(\"swag\", \"regular\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 73546\n", " })\n", " validation: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 20006\n", " })\n", " test: Dataset({\n", " features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label'],\n", " num_rows: 20005\n", " })\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets" ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'video-id': 'anetv_jkn6uvmqwh4',\n", " 'fold-ind': '3416',\n", " 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. 
A drum line',\n", " 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',\n", " 'sent2': 'A drum line',\n", " 'gold-source': 'gold',\n", " 'ending0': 'passes by walking down the street playing their instruments.',\n", " 'ending1': 'has heard approaching them.',\n", " 'ending2': \"arrives and they're outside dancing and asleep.\",\n", " 'ending3': 'turns the lead singer watches the performance.',\n", " 'label': 0}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "from datasets import ClassLabel\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(\n", " dataset\n", " ), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset) - 1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset) - 1)\n", " picks.append(pick)\n", "\n", " df = pd.DataFrame(dataset[picks])\n", " for column, typ in dataset.features.items():\n", " if isinstance(typ, ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13" }, "outputs": [ { "data": { "text/html": [ "
| | video-id | fold-ind | startphrase | sent1 | sent2 | gold-source | ending0 | ending1 | ending2 | ending3 | label |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | lsmdc0005_Chinatown-48562 | 957 | In a moment Cross can be seen, looking toward camera. He | In a moment Cross can be seen, looking toward camera. | He | gen | closes the pantry door a little bit more. | crosses to a door to a thick dark case that closes clear. | looks up at an empty window in the club, then hangs up, embarrassed on the cold. | is carrying a trail, which wraps around his ankle. | 0 |
| 1 | anetv_W6y6Vmk5edg | 11587 | A gymnast is seen leaning across a long beam and begins performing a gymnastics routine. The girl | A gymnast is seen leaning across a long beam and begins performing a gymnastics routine. | The girl | gen | performs several kickboxing while moving her arms in and down. | does a gymnastic routine on the balance beam and herself on to the balance beam. | move through several flips and flips on the beam. | begins jumping up and down on the mat. | 2 |
| 2 | lsmdc3042_KARATE_KID-19795 | 12868 | He stares off, vacantly. Someone | He stares off, vacantly. | Someone | gold | puts a three - slim guard into his mouth and kisses him. | is in slow motion, his knees still scarred by the battle. | flips across the dark to the car. | paces toward him and glances at the wreck. | 3 |
| 3 | lsmdc0026_The_Big_Fish-62739 | 18560 | Someone walks into the river, up to his knees. He | Someone walks into the river, up to his knees. | He | gold | is pulled away, and his shoulders upset. | finds his hand seething, standing around the cab, raising his head. | turns back so his father can face the crowd. | moves vigorously he then turns and does a little push the bubbles into his mouth, leaving his mouth open, breathing ringed. | 2 |
| 4 | anetv_ZPVrC5185NM | 11852 | She pushes the baby back and forth on a swing. The baby | She pushes the baby back and forth on a swing. | The baby | gold | laughs and smiles as she swings. | comes back and takes a few flips. | hops backwards into the swing. | walks back inside and throws it out. | 0 |
| 5 | lsmdc3024_EASY_A-11617 | 11559 | She unbuckles her seat belt and they share a tight embrace. Someone | She unbuckles her seat belt and they share a tight embrace. | Someone | gen | pulls a knapsack from a shelf by someone's coat pocket as she opens one of the boxes. | sits in the limo. | slips his arm around her waist. | blinks, then leans forward and presses his lips to her chin. | 3 |
| 6 | lsmdc0050_Indiana_Jones_and_the_last_crusade-70715 | 3232 | Now he sees the pendulum has been guarding a small corridor which turns a corner to the left fifty yards ahead. Wooden wheels | Now he sees the pendulum has been guarding a small corridor which turns a corner to the left fifty yards ahead. | Wooden wheels | gold | turn - - the mechanism controlling the spinning blades. | follow through the rising foam. | cranks back in an electronic system. | press against a gold and metal gate. | 0 |
| 7 | anetv_fZQS02Ypca4 | 847 | A young girl is seen sitting and speaking to the camera while using a brush to powder her face. A color pallet is held up next to her and she | A young girl is seen sitting and speaking to the camera while using a brush to powder her face. | A color pallet is held up next to her and she | gold | begins rubbing the powder all over her eyes. | uses the brush on brush drys on her hair while continuing to speak to the camera. | applies the sides and pans far away. | begins brushing her teeth and showing an image with the brush in her long hair. | 0 |
| 8 | lsmdc0051_Men_in_black-70855 | 14113 | Someone barks a few orders to them. He | Someone barks a few orders to them. | He | gen | has a hair mustache. | pretends to sit back. | puts his slacks back on. | sends aiming a heft torch. | 1 |
| 9 | anetv_76RoR_LbIzQ | 16378 | Pictures of an office is shown. A woman | Pictures of an office is shown. | A woman | gold | is explaining how she is wiping the car drive. | is sitting on a long swing holding a stick. | is doing another woman's hair. | is seen playing a song on a stage. | 2 |
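
The columns in the table above fit together in a simple way, and the sketch below tries to make that explicit using the first training example shown earlier. It assumes, based only on the rows displayed here, that `startphrase` is `sent1` followed by `sent2`, and that each full candidate continuation is `sent2` followed by one of the `ending*` fields. `build_candidates` is a hypothetical helper written for this illustration, not a function from 🤗 Datasets or 🤗 Transformers.

```python
# First training example, copied from the output of datasets["train"][0] above
# (only the fields needed for this illustration are kept).
example = {
    "startphrase": "Members of the procession walk down the street holding small horn brass instruments. A drum line",
    "sent1": "Members of the procession walk down the street holding small horn brass instruments.",
    "sent2": "A drum line",
    "ending0": "passes by walking down the street playing their instruments.",
    "ending1": "has heard approaching them.",
    "ending2": "arrives and they're outside dancing and asleep.",
    "ending3": "turns the lead singer watches the performance.",
    "label": 0,
}

ENDING_KEYS = ["ending0", "ending1", "ending2", "ending3"]


def build_candidates(example):
    """Hypothetical helper: join sent2 with each ending to get four candidates."""
    return [f"{example['sent2']} {example[key]}" for key in ENDING_KEYS]


candidates = build_candidates(example)

# Print the context once, then each candidate, marking the one the label points to.
print("Context:", example["sent1"])
for i, candidate in enumerate(candidates):
    marker = "  <- label" if i == example["label"] else ""
    print(f"  {i}: {candidate}{marker}")
```

Hard-coding the example keeps the sketch runnable without downloading SWAG; with the dataset loaded, the same helper could be called on `datasets["train"][0]` instead.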