{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "X4cRE8IbIrIV" }, "source": [ "If you're opening this Notebook on colab, you will probably need to install 🤗 Transformers and 🤗 Datasets. We will also use the `seqeval` library to compute some evaluation metrics. Uncomment the following cell and run it." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "MOsHUjgdIrIW", "outputId": "f84a093e-147f-470e-aad9-80fb51193c8e" }, "outputs": [], "source": [ "#! pip install transformers\n", "#! pip install datasets\n", "#! pip install seqeval\n", "#! pip install huggingface_hub" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you're opening this notebook locally, make sure your environment has an install from the last version of those libraries.\n", "\n", "To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.\n", "\n", "First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then uncomment the following cell and input it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import notebook_login\n", "\n", "notebook_login()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then you need to install Git-LFS and setup Git if you haven't already. Uncomment the following instructions and adapt with your name and email:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# !apt install git-lfs\n", "# !git config --global user.email \"you@example.com\"\n", "# !git config --global user.name \"Your Name\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.21.0.dev0\n" ] } ], "source": [ "import transformers\n", "\n", "print(transformers.__version__)" ] }, { "cell_type": "markdown", "metadata": { "id": "HFASsisvIrIb" }, "source": [ "You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/token-classification)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from transformers.utils import send_example_telemetry\n", "\n", "send_example_telemetry(\"token_classification_notebook\", framework=\"tensorflow\")" ] }, { "cell_type": "markdown", "metadata": { "id": "rEJBSTyZIrIb" }, "source": [ "# Fine-tuning a model on a token classification task" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) model to a token classification task, which is the task of predicting a label for each token.\n", "\n", "\n", "\n", "The most common token classification tasks are:\n", "\n", "- NER (Named-entity recognition) Classify the entities in the text (person, organization, location...).\n", "- POS (Part-of-speech tagging) Grammatically classify the tokens (noun, verb, adjective...)\n", "- Chunk (Chunking) Grammatically classify the tokens and group them into \"chunks\" that go together\n", "\n", "We will see how to easily load a dataset for these kinds of tasks and use Keras to fine-tune a model on it." ] }, { "cell_type": "markdown", "metadata": { "id": "4RRkXuteIrIh" }, "source": [ "This notebook is built to run on any token classification task, with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a version with a token classification head and a fast tokenizer (check on [this table](https://huggingface.co/transformers/index.html#bigtable) if this is the case). It might just need some small adjustments if you decide to use a different dataset than the one used here. Depending on you model and the GPU you are using, you might need to adjust the batch size to avoid out-of-memory errors. Set those three parameters, then the rest of the notebook should run smoothly:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "zVvslsfMIrIh" }, "outputs": [], "source": [ "task = \"ner\" # Should be one of \"ner\", \"pos\" or \"chunk\"\n", "model_checkpoint = \"distilbert-base-uncased\"\n", "batch_size = 16" ] }, { "cell_type": "markdown", "metadata": { "id": "whPRbBNbIrIl" }, "source": [ "## Loading the dataset" ] }, { "cell_type": "markdown", "metadata": { "id": "W7QYTpxXIrIl" }, "source": [ "We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "IreSlFmlIrIm" }, "outputs": [], "source": [ "from datasets import load_dataset, load_metric" ] }, { "cell_type": "markdown", "metadata": { "id": "CKx2zKs5IrIq" }, "source": [ "For our example here, we'll use the [CONLL 2003 dataset](https://www.aclweb.org/anthology/W03-0419.pdf). The notebook should work with any token classification dataset provided by the 🤗 Datasets library. If you're using your own dataset defined from a JSON or csv file (see the [Datasets documentation](https://huggingface.co/docs/datasets/loading_datasets.html#from-local-files) on how to load them), it might need some adjustments in the names of the columns used." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270, "referenced_widgets": [ "69caab03d6264fef9fc5649bffff5e20", "3f74532faa86412293d90d3952f38c4a", "50615aa59c7247c4804ca5cbc7945bd7", "fe962391292a413ca55dc932c4279fa7", "299f4b4c07654e53a25f8192bd1d7bbd", "ad04ed1038154081bbb0c1444784dcc2", "7c667ad22b5740d5a6319f1b1e3a8097", "46c2b043c0f84806978784a45a4e203b", "80e2943be35f46eeb24c8ab13faa6578", "de5956b5008d4fdba807bae57509c393", "931db1f7a42f4b46b7ff8c2e1262b994", "6c1db72efff5476e842c1386fadbbdba", "ccd2f37647c547abb4c719b75a26f2de", "d30a66df5c0145e79693e09789d96b81", "5fa26fc336274073abbd1d550542ee33", "2b34de08115d49d285def9269a53f484", "d426be871b424affb455aeb7db5e822e", "160bf88485f44f5cb6eaeecba5e0901f", "745c0d47d672477b9bb0dae77b926364", "d22ab78269cd4ccfbcf70c707057c31b", "d298eb19eeff453cba51c2804629d3f4", "a7204ade36314c86907c562e0a2158b8", "e35d42b2d352498ca3fc8530393786b2", "75103f83538d44abada79b51a1cec09e", "f6253931d90543e9b5fd0bb2d615f73a", "051aa783ff9e47e28d1f9584043815f5", "0984b2a14115454bbb009df71c1cf36f", "8ab9dfce29854049912178941ef1b289", "c9de740e007141958545e269372780a4", "cbea68b25d6d4ba09b2ce0f27b1726d5", "5781fc45cf8d486cb06ed68853b2c644", "d2a92143a08a4951b55bab9bc0a6d0d3", "a14c3e40e5254d61ba146f6ec88eae25", "c4ffe6f624ce4e978a0d9b864544941a", "1aca01c1d8c940dfadd3e7144bb35718", "9fbbaae50e6743f2aa19342152398186", "fea27ca6c9504fc896181bc1ff5730e5", "940d00556cb849b3a689d56e274041c2", "5cdf9ed939fb42d4bf77301c80b8afca", "94b39ccfef0b4b08bf2fb61bb0a657c1", "9a55087c85b74ea08b3e952ac1d73cbe", "2361ab124daf47cc885ff61f2899b2af", "1a65887eb37747ddb75dc4a40f7285f2", "3c946e2260704e6c98593136bd32d921", "50d325cdb9844f62a9ecc98e768cb5af", "aa781f0cfe454e9da5b53b93e9baabd8", "6bb68d3887ef43809eb23feb467f9723", "7e29a8b952cf4f4ea42833c8bf55342f", "dd5997d01d8947e4b1c211433969b89b", "2ace4dc78e2f4f1492a181bcd63304e7", "bbee008c2791443d8610371d1f16b62b", "31b1c8a2e3334b72b45b083688c1a20c", "7fb7c36adc624f7dbbcb4a831c1e4f63", "0b7c8f1939074794b3d9221244b1344d", "a71908883b064e1fbdddb547a8c41743", "2f5223f26c8541fc87e91d2205c39995" ] }, "id": "s_AY1ATSIrIq", "outputId": "fd0578d1-8895-443d-b56f-5908de9f1b6b" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset conll2003 (/home/matt/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c5d18992de2649609cc82185d781412d", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "datasets = load_dataset(\"conll2003\")" ] }, { "cell_type": "markdown", "metadata": { "id": "RzfPtOMoIrIu" }, "source": [ "The `datasets` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "GWiVUF0jIrIv", "outputId": "35e3ea43-f397-4a54-c90c-f2cf8d36873e" }, "outputs": [ { "data": { "text/plain": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", " num_rows: 14041\n", " })\n", " validation: Dataset({\n", " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", " num_rows: 3250\n", " })\n", " test: Dataset({\n", " features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n", " num_rows: 3453\n", " })\n", "})" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see the training, validation and test sets all have a column for the tokens (the input texts split into words) and one column of labels for each kind of task we introduced before." ] }, { "cell_type": "markdown", "metadata": { "id": "u3EtYfeHIrIz" }, "source": [ "To access an actual element, you need to select a split first, then give an index:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "X6HrpprwIrIz", "outputId": "d7670bc0-42e4-4c09-8a6a-5c018ded7d95" }, "outputs": [ { "data": { "text/plain": [ "{'id': '0',\n", " 'tokens': ['EU',\n", " 'rejects',\n", " 'German',\n", " 'call',\n", " 'to',\n", " 'boycott',\n", " 'British',\n", " 'lamb',\n", " '.'],\n", " 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],\n", " 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],\n", " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets[\"train\"][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The labels are already coded as integer ids to be easily usable by our model, but the correspondence with the actual categories is stored in the `features` of the dataset:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "datasets[\"train\"].features[f\"ner_tags\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So for the NER tags, 0 corresponds to 'O', 1 to 'B-PER' etc... 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the labels column is a `Sequence` of `ClassLabel`, the actual names of the labels are nested in the `feature` attribute of the object above:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label_list = datasets[\"train\"].features[f\"{task}_tags\"].feature.names\n", "label_list" ] }, { "cell_type": "markdown", "metadata": { "id": "WHUmphG3IrI3" }, "source": [ "To get a sense of what the data looks like, the following function will show some examples picked at random from the dataset (decoding the label ids into their names along the way)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "i3j8APAoIrI3" }, "outputs": [], "source": [ "from datasets import ClassLabel, Sequence\n", "import random\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", "\n", "def show_random_elements(dataset, num_examples=10):\n", "    assert num_examples <= len(\n", "        dataset\n", "    ), \"Can't pick more elements than there are in the dataset.\"\n", "    # Pick num_examples distinct random indices.\n", "    picks = []\n", "    for _ in range(num_examples):\n", "        pick = random.randint(0, len(dataset) - 1)\n", "        while pick in picks:\n", "            pick = random.randint(0, len(dataset) - 1)\n", "        picks.append(pick)\n", "\n", "    # Decode ClassLabel ids into their string names for display.\n", "    df = pd.DataFrame(dataset[picks])\n", "    for column, typ in dataset.features.items():\n", "        if isinstance(typ, ClassLabel):\n", "            df[column] = df[column].transform(lambda i: typ.names[i])\n", "        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):\n", "            df[column] = df[column].transform(\n", "                lambda x: [typ.feature.names[i] for i in x]\n", "            )\n", "    display(HTML(df.to_html()))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "id": "SZy5tRB_IrI7", "outputId": "ba8f2124-e485-488f-8c0c-254f34f24f13", "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
| \n", " | id | \n", "tokens | \n", "pos_tags | \n", "chunk_tags | \n", "ner_tags | \n", "
|---|---|---|---|---|---|
| 0 | \n", "6496 | \n", "[Total, (, for, one, wicket, ), 48] | \n", "[JJ, (, IN, CD, NN, ), CD] | \n", "[B-NP, O, B-PP, B-NP, I-NP, O, B-NP] | \n", "[O, O, O, O, O, O, O] | \n", "
| 1 | \n", "11665 | \n", "[The, BOJ, sought, to, put, the, best, face, on, the, data, which, defied, economists, ', predictions, of, improving, sentiment, and, was, the, first, decline, in, business, sentiment, in, a, year, .] | \n", "[DT, NNP, VBD, TO, VB, DT, JJS, NN, IN, DT, NNS, WDT, VBD, NNS, POS, NNS, IN, VBG, NN, CC, VBD, DT, JJ, NN, IN, NN, NN, IN, DT, NN, .] | \n", "[B-NP, I-NP, B-VP, I-VP, I-VP, B-NP, I-NP, I-NP, B-PP, B-NP, I-NP, B-NP, B-VP, B-NP, B-NP, I-NP, B-PP, B-NP, I-NP, O, B-VP, B-NP, I-NP, I-NP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, O] | \n", "[O, B-ORG, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] | \n", "
| 2 | \n", "9882 | \n", "[Palestinians, to, strike, over, Jerusalem, demolition, .] | \n", "[NNPS, TO, VB, RP, NNP, NN, .] | \n", "[B-NP, B-VP, I-VP, B-PRT, B-NP, I-NP, O] | \n", "[B-MISC, O, O, O, B-LOC, O, O] | \n", "
| 3 | \n", "7824 | \n", "[REPRICING, OF, THE, BALANCE, OF, THE, BONDS, IN, THE, ACCOUNT, .] | \n", "[VBG, IN, DT, NN, IN, DT, NNS, IN, DT, NN, .] | \n", "[B-VP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, O] | \n", "[O, O, O, O, O, O, O, O, O, O, O] | \n", "
| 4 | \n", "13690 | \n", "[Hong, Kong, Financial, Secretary, Donald, Tsang, said, on, Thursday, he, expected, the, territory, 's, economy, to, keep, growing, at, around, five, percent, but, with, some, fluctuations, from, year, to, year, .] | \n", "[NNP, NNP, NNP, NNP, NNP, NNP, VBD, IN, NNP, PRP, VBD, DT, NN, POS, NN, TO, VB, VBG, IN, IN, CD, NN, CC, IN, DT, NNS, IN, NN, TO, NN, .] | \n", "[B-NP, I-NP, I-NP, I-NP, I-NP, I-NP, B-VP, B-PP, B-NP, B-NP, B-VP, B-NP, I-NP, B-NP, I-NP, B-VP, I-VP, I-VP, B-PP, B-NP, I-NP, I-NP, O, B-PP, B-NP, I-NP, B-PP, B-NP, B-PP, B-NP, O] | \n", "[B-LOC, I-LOC, O, O, B-PER, I-PER, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O] | \n", "
| 5 | \n", "6749 | \n", "[Atlante, 1, Atlas, 1] | \n", "[NNP, CD, NNP, CD] | \n", "[B-NP, I-NP, I-NP, I-NP] | \n", "[B-ORG, O, B-ORG, O] | \n", "
| 6 | \n", "9964 | \n", "[1., Osmond, Ezinwa, (, Nigeria, ), 10.13, seconds] | \n", "[CD, NNP, NNP, (, NNP, ), CD, NNS] | \n", "[B-NP, I-NP, I-NP, O, B-NP, O, B-NP, I-NP] | \n", "[O, B-PER, I-PER, O, B-LOC, O, O, O] | \n", "
| 7 | \n", "13665 | \n", "[In, an, interview, following, its, first-half, results, ,, which, included, a, less, optimistic, forecast, for, the, second, half, of, this, year, than, it, had, made, in, the, past, ,, Sir, Colin, Hope, said, T&N, had, taken, defensive, action, to, protect, it, from, patchy, markets, .] | \n", "[IN, DT, NN, VBG, PRP$, JJ, NNS, ,, WDT, VBD, DT, RBR, JJ, NN, IN, DT, JJ, NN, IN, DT, NN, IN, PRP, VBD, VBN, IN, DT, NN, ,, NNP, NNP, NNP, VBD, NNP, VBD, VBN, JJ, NN, TO, VB, PRP, IN, JJ, NNS, .] | \n", "[B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, I-NP, O, B-NP, B-VP, B-NP, I-NP, I-NP, I-NP, B-PP, B-NP, I-NP, I-NP, B-PP, B-NP, I-NP, B-SBAR, B-NP, B-VP, I-VP, B-PP, B-NP, I-NP, O, B-NP, I-NP, I-NP, B-VP, B-NP, B-VP, I-VP, B-NP, I-NP, B-VP, I-VP, B-NP, B-PP, B-NP, I-NP, O] | \n", "[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-PER, I-PER, O, B-ORG, O, O, O, O, O, O, O, O, O, O, O] | \n", "
| 8 | \n", "6232 | \n", "[The, rand, was, last, bid, at, 4.5350, against, the, dollar, .] | \n", "[DT, NN, VBD, JJ, NN, IN, CD, IN, DT, NN, .] | \n", "[B-NP, I-NP, B-VP, B-NP, I-NP, B-PP, B-NP, B-PP, B-NP, I-NP, O] | \n", "[O, O, O, O, O, O, O, O, O, O, O] | \n", "
| 9 | \n", "13340 | \n", "[Liam, Gallagher, ,, singer, of, Britain, 's, top, rock, group, Oasis, ,, flew, out, on, Thursday, to, join, the, band, three, days, after, the, start, of, its, U.S., tour, .] | \n", "[NNP, NNP, ,, NN, IN, NNP, POS, JJ, NN, NN, NNP, ,, VBD, RP, IN, NNP, TO, VB, DT, NN, CD, NNS, IN, DT, NN, IN, PRP$, NNP, NN, .] | \n", "[B-NP, I-NP, O, B-NP, B-PP, B-NP, B-NP, I-NP, I-NP, I-NP, I-NP, O, B-VP, B-PRT, B-PP, B-NP, B-VP, I-VP, B-NP, I-NP, B-NP, I-NP, B-PP, B-NP, I-NP, B-PP, B-NP, I-NP, I-NP, O] | \n", "[B-PER, I-PER, O, O, O, B-LOC, O, O, O, O, B-ORG, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, B-LOC, O, O] | \n", "