{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install the most recent versions of 🤗 Transformers and 🤗 Datasets. We will also need `scipy` and `scikit-learn` for some of the metrics. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install transformers\n", "#! pip install datasets\n", "#! pip install scipy sklearn\n", "#! pip install huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the latest version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow:\n", "\n", "First you have to create an access token on the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then uncomment the following cell and input your token." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS and setup Git if you haven't already. Uncomment the following instructions and adapt with your name and email:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs\n", "# !git config --global user.email \"you@example.com\"\n", "# !git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.16.0 since the functionality was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.22.0.dev0\n" ] } ], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/text-classification)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"text_classification_notebook\", framework=\"tensorflow\")" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a text classification task" ] }, { "cell_type": "markdown", "metadata": { "id": "kTCFado4IrIc" }, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) model to a text classification task of the [GLUE Benchmark](https://gluebenchmark.com/).\n", "\n", "\n", "\n", "The GLUE Benchmark is a group of nine classification tasks on sentences or pairs of sentences which are:\n", "\n", "- [CoLA](https://nyu-mll.github.io/CoLA/) (Corpus of Linguistic Acceptability) Determine if a sentence is grammatically correct or not.is a dataset containing sentences labeled grammatically correct or not.\n", "- [MNLI](https://arxiv.org/abs/1704.05426) (Multi-Genre Natural Language Inference) Determine if a sentence entails, contradicts or is unrelated to a given hypothesis. (This dataset has two versions, one with the validation and test set coming from the same distribution, another called mismatched where the validation and test use out-of-domain data.)\n", "- [MRPC](https://www.microsoft.com/en-us/download/details.aspx?id=52398) (Microsoft Research Paraphrase Corpus) Determine if two sentences are paraphrases from one another or not.\n", "- [QNLI](https://rajpurkar.github.io/SQuAD-explorer/) (Question-answering Natural Language Inference) Determine if the answer to a question is in the second sentence or not. (This dataset is built from the SQuAD dataset.)\n", "- [QQP](https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs) (Quora Question Pairs2) Determine if two questions are semantically equivalent or not.\n", "- [RTE](https://aclweb.org/aclwiki/Recognizing_Textual_Entailment) (Recognizing Textual Entailment) Determine if a sentence entails a given hypothesis or not.\n", "- [SST-2](https://nlp.stanford.edu/sentiment/index.html) (Stanford Sentiment Treebank) Determine if the sentence has a positive or negative sentiment.\n", "- [STS-B](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark) (Semantic Textual Similarity Benchmark) Determine the similarity of two sentences with a score from 1 to 5.\n", "- [WNLI](https://cs.nyu.edu/faculty/davise/papers/WinogradSchemas/WS.html) (Winograd Natural Language Inference) Determine if a sentence with an anonymous pronoun and a sentence with this pronoun replaced are entailed or not. (This dataset is built from the Winograd Schema Challenge dataset.)\n", "\n", "We will see how to easily load the dataset for each one of those tasks and use Keras to fine-tune a model on it. 
Each task is named by its acronym, with `mnli-mm` standing for the mismatched version of MNLI (a task with the same training set as `mnli` but different validation and test sets):" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "YZbiBDuGIrId" }, "outputs": [], "source": [ "GLUE_TASKS = [\n", " \"cola\",\n", " \"mnli\",\n", " \"mnli-mm\",\n", " \"mrpc\",\n", " \"qnli\",\n", " \"qqp\",\n", " \"rte\",\n", " \"sst2\",\n", " \"stsb\",\n", " \"wnli\",\n", "]" ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run on any of the tasks in the list above, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a classification head. Depending on your model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set these three parameters, then the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "task = \"cola\"\n", "model_checkpoint = \"distilbert-base-uncased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and the [🤗 Evaluate](https://github.com/huggingface/evaluate) library to get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the `load_dataset` function from `datasets` and the `load` function from `evaluate`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset\n", "from evaluate import load" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "With the exception of `mnli-mm`, we can directly pass our task name to those functions. `load_dataset` will cache the dataset to avoid downloading it again the next time you run this cell."
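, "\n", "\n", "If you are curious which configuration names the `glue` dataset actually accepts, you can optionally list them first (a quick check that is not required for the rest of the notebook; note that `mnli-mm` is just an alias we use in this notebook, so it will not appear in the list):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Optional: list the configurations available for the GLUE dataset.\n", "# `mnli-mm` is an alias defined by this notebook, so it does not appear here.\n", "from datasets import get_dataset_config_names\n", "\n", "print(get_dataset_config_names(\"glue\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's load the dataset and the metric for our chosen task:"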
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset glue (/home/matt/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e63b825755a14c68bc7d9c99ca39b31a", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "actual_task = \"mnli\" if task == \"mnli-mm\" else task\n", "dataset = load_dataset(\"glue\", actual_task)\n", "metric = load(\"glue\", actual_task)" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `dataset` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set (with more keys for the mismatched validation and test set in the special case of `mnli`)." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['sentence', 'label', 'idx'],\n", " num_rows: 8551\n", " })\n", " validation: Dataset({\n", " features: ['sentence', 'label', 'idx'],\n", " num_rows: 1043\n", " })\n", " test: Dataset({\n", " features: ['sentence', 'label', 'idx'],\n", " num_rows: 1063\n", " })\n", "})" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'sentence': \"Our friends won't buy this analysis, let alone the next one we propose.\",\n", " 'label': 1,\n", " 'idx': 0}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "import datasets\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", " assert num_examples <= len(\n", " dataset\n", " ), \"Can't pick more elements than there are in the dataset.\"\n", " picks = []\n", " for _ in range(num_examples):\n", " pick = random.randint(0, len(dataset) - 1)\n", " while pick in picks:\n", " pick = random.randint(0, len(dataset) - 1)\n", " picks.append(pick)\n", "\n", " df = pd.DataFrame(dataset[picks])\n", " for column, typ in dataset.features.items():\n", " if isinstance(typ, datasets.ClassLabel):\n", " df[column] = df[column].transform(lambda i: typ.names[i])\n", " display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13" }, "outputs": [ { "data": { "text/html": [ "
| \n", " | sentence | \n", "label | \n", "idx | \n", "
|---|---|---|---|
| 0 | \n", "I have Tom's feeling that the company will squander the money. | \n", "unacceptable | \n", "1245 | \n", "
| 1 | \n", "Who is it obvious that Plato loves. | \n", "acceptable | \n", "7749 | \n", "
| 2 | \n", "She her will put a picture of Bill on your desk before tomorrow. | \n", "unacceptable | \n", "7139 | \n", "
| 3 | \n", "Aphrodite wanted Hera to persuade Athena to leave. | \n", "acceptable | \n", "8446 | \n", "
| 4 | \n", "That Plato lived in the city of Athens was well-known. | \n", "acceptable | \n", "8101 | \n", "
| 5 | \n", "How many did you buy of those pies at the fair? | \n", "acceptable | \n", "6514 | \n", "
| 6 | \n", "Jason intended for PRO to learn magic. | \n", "unacceptable | \n", "8126 | \n", "
| 7 | \n", "Would they have been walking for hours? | \n", "acceptable | \n", "7153 | \n", "
| 8 | \n", "There tried to be riots in Seoul. | \n", "unacceptable | \n", "4283 | \n", "
| 9 | \n", "The men were arrested all. | \n", "unacceptable | \n", "715 | \n", "