{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", " Model_Training_with_BERT.ipynb:\n", "

Use Text Extensions for Pandas to integrate BERT tokenization with model training for named entity recognition in Pandas DataFrames.

\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Introduction\n", "\n", "This notebook shows how to use the open source library [Text Extensions for Pandas](https://github.com/CODAIT/text-extensions-for-pandas) to seamlessly integrate BERT tokenization and embeddings with model training for named entity recognition using [Pandas](https://pandas.pydata.org/) DataFrames.\n", "\n", "This example builds on the analysis of the [CoNLL-2003](https://www.clips.uantwerpen.be/conll2003/ner/) corpus done in [Analyze_Model_Outputs](./Analyze_Model_Outputs.ipynb) to train a new model for named entity recognition (NER) using BERT tokenization and embeddings, a state-of-the-art approach to natural language understanding. The model itself is rather simple and achieves only modest scores; the purpose is to demonstrate how Text Extensions for Pandas integrates BERT from [Huggingface Transformers](https://huggingface.co/transformers/index.html) with the `TensorArray` extension for model training and scoring, all within Pandas DataFrames. See [Text_Extension_for_Pandas_Overview](./Text_Extension_for_Pandas_Overview.ipynb) for the `TensorArray` specification and more usage examples.\n", "\n", "The notebook is divided into the following steps:\n", "\n", "1. Retokenize the entire corpus using a \"BERT-compatible\" tokenizer, and map the token/entity labels from the original corpus onto the new tokenization.\n", "1. Generate BERT embeddings for every token in the entire corpus in one pass, and store those embeddings in a DataFrame column (of type `TensorDtype`) alongside the tokens and labels.\n", "1. Persist the DataFrame with computed BERT embeddings to disk as a checkpoint.\n", "1. Use the embeddings to train a multinomial logistic regression model to perform named entity recognition.\n", "1. Compute precision/recall for the model predictions on a test set.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment Setup\n", "\n", "This notebook requires a Python 3.7 or later environment with NumPy, Pandas, scikit-learn, PyTorch, and Huggingface `transformers`.\n", "\n", "The notebook also requires the `text_extensions_for_pandas` library. You can satisfy this dependency in two ways:\n", "\n", "* Run `pip install text_extensions_for_pandas` before running this notebook. This command adds the library to your Python environment.\n", "* Run this notebook from your local copy of the Text Extensions for Pandas project's [source tree](https://github.com/CODAIT/text-extensions-for-pandas). In this case, the notebook will use the version of Text Extensions for Pandas in your local source tree **if the package is not installed in your Python environment**." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gc\n", "import os\n", "import sys\n", "from typing import *\n", "import numpy as np\n", "import pandas as pd\n", "import sklearn.pipeline\n", "import sklearn.linear_model\n", "import torch\n", "import transformers\n", "\n", "# And of course we need the text_extensions_for_pandas library itself.\n", "try:\n", "    import text_extensions_for_pandas as tp\n", "except ModuleNotFoundError as e:\n", "    # If we're running from within the project source tree and the parent Python\n", "    # environment doesn't have the text_extensions_for_pandas package, use the\n", "    # version in the local source tree.\n", "    if not os.getcwd().endswith(\"notebooks\"):\n", "        raise e\n", "    if \"..\" not in sys.path:\n", "        sys.path.insert(0, \"..\")\n", "    import text_extensions_for_pandas as tp" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Named Entity Recognition with BERT on CoNLL-2003\n", "\n", "[CoNLL](https://www.conll.org/), the SIGNLL Conference on Computational Natural Language Learning, is an annual academic conference for natural language processing researchers. Each year's conference features a competition involving a challenging NLP task. The task for the 2003 competition involved identifying mentions of [named entities](https://en.wikipedia.org/wiki/Named-entity_recognition) in English and German news articles from the late 1990s. The corpus for this 2003 competition is one of the most widely used benchmarks for the performance of named entity recognition models. Current [state-of-the-art results](https://paperswithcode.com/sota/named-entity-recognition-ner-on-conll-2003) on this corpus produce an F1 score (harmonic mean of precision and recall) of 0.93. The best F1 score in the original competition was 0.89.\n", "\n", "For more information about this data set, we recommend reading the conference paper about the competition results, [\"Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition,\"](https://www.aclweb.org/anthology/W03-0419/).\n", "\n", "**Note that the data set is licensed for research use only. Be sure to adhere to the terms of the license when using this data set!**\n", "\n", "The developers of the CoNLL-2003 corpus defined a file format for the corpus, based on the file format used in the earlier [Message Understanding Conference](https://en.wikipedia.org/wiki/Message_Understanding_Conference) competition. This format is generally known as \"CoNLL format\" or \"CoNLL-2003 format\".\n", "\n", "In the following cells, we use the facilities of Text Extensions for Pandas to download a copy of the CoNLL-2003 data set. Then we read the CoNLL-2003-format files for the `train`, `dev`, and `test` folds of the corpus and translate the data into a collection of Pandas [DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) objects, one DataFrame per document. Finally, we display the DataFrame for a single example document from the `test` fold of the corpus."
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': 'outputs/eng.train',\n", " 'dev': 'outputs/eng.testa',\n", " 'test': 'outputs/eng.testb'}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download and cache the data set.\n", "# NOTE: This data set is licensed for research use only. Be sure to adhere\n", "# to the terms of the license when using this data set!\n", "data_set_info = tp.io.conll.maybe_download_conll_data(\"outputs\")\n", "data_set_info" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Show how to retokenize with a BERT tokenizer\n", "\n", "The BERT model comes from the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. The model is pre-trained with masked language modeling and next sentence prediction objectives, which make it effective for masked token prediction and natural language understanding (NLU).\n", "\n", "Now that the CoNLL-2003 corpus is loaded, we need to retokenize it using a \"BERT-compatible\" tokenizer. Then we can map the token/entity labels from the original corpus onto the new tokenization.\n", "\n", "We will start by showing the retokenization process for a single document before doing the same for the entire corpus." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " span ent_iob ent_type \\\n", "0 [0, 10): '-DOCSTART-' O None \n", "1 [11, 18): 'CRICKET' O None \n", "2 [18, 19): '-' O None \n", "3 [20, 28): 'PAKISTAN' B LOC \n", "4 [29, 30): 'V' O None \n", ".. ... ... ... \n", "350 [1620, 1621): '8' O None \n", "351 [1621, 1622): ',' O None \n", "352 [1623, 1625): 'in' O None \n", "353 [1626, 1633): 'Karachi' B LOC \n", "354 [1633, 1634): '.' O None \n", "\n", " sentence line_num \n", "0 [0, 10): '-DOCSTART-' 1469 \n", "1 [11, 62): 'CRICKET- PAKISTAN V NEW ZEALAND ONE... 1471 \n", "2 [11, 62): 'CRICKET- PAKISTAN V NEW ZEALAND ONE... 1472 \n", "3 [11, 62): 'CRICKET- PAKISTAN V NEW ZEALAND ONE... 1473 \n", "4 [11, 62): 'CRICKET- PAKISTAN V NEW ZEALAND ONE... 1474 \n", ".. ... ... \n", "350 [1590, 1634): 'Third one-day match: December 8... 1865 \n", "351 [1590, 1634): 'Third one-day match: December 8... 1866 \n", "352 [1590, 1634): 'Third one-day match: December 8... 1867 \n", "353 [1590, 1634): 'Third one-day match: December 8... 1868 \n", "354 [1590, 1634): 'Third one-day match: December 8... 1869 \n", "\n", "[355 rows x 5 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read in the corpus in its original tokenization.\n", "corpus_raw = {}\n", "for fold_name, file_name in data_set_info.items():\n", " df_list = tp.io.conll.conll_2003_to_dataframes(file_name, \n", " [\"pos\", \"phrase\", \"ent\"],\n", " [False, True, True])\n", " corpus_raw[fold_name] = [\n", " df.drop(columns=[\"pos\", \"phrase_iob\", \"phrase_type\"])\n", " for df in df_list\n", " ]\n", "\n", "test_raw = corpus_raw[\"test\"]\n", "\n", "# Pick out the dataframe for a single example document.\n", "example_df = test_raw[5]\n", "example_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `example_df` contains columns `span` and `sentence` of dtypes `SpanDtype` and `TokenSpanDtype`. 
These columns represent spans over the target text; here `span` holds each token and `sentence` holds the sentence containing that token. See the notebook [Text_Extension_for_Pandas_Overview](./Text_Extension_for_Pandas_Overview.ipynb) for more on `SpanArray` and `TokenSpanArray`." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "span SpanDtype\n", "ent_iob object\n", "ent_type object\n", "sentence TokenSpanDtype\n", "line_num int64\n", "dtype: object" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "example_df.dtypes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Convert IOB-Tagged Data to Lists of Entity Mentions\n", "\n", "The data we've looked at so far has been in [IOB2 format](https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging)). \n", "Each row of our DataFrame represents a token, and each token is tagged with an entity type (`ent_type`) and an IOB tag (`ent_iob`). The first token of each named entity mention is tagged `B`, while subsequent tokens are tagged `I`. Tokens that aren't part of any named entity are tagged `O`.\n", "\n", "IOB2 format is a convenient way to represent a corpus, but it is a less useful representation for analyzing the result quality of named entity recognition models. Because most tokens in a typical NER corpus are tagged `O`, any token-level error rate will be dominated by the tokens that are not part of any entity. Token-level error rate also implicitly assigns higher weight to named entity mentions that consist of multiple tokens, further unbalancing error metrics. Most crucially, a naive comparison of IOB tags can mark an incorrect answer as correct. Consider a case where the correct sequence of labels is `B, B, I` but the model has output `B, I, I`; in this case, the last two tokens of the model output are both incorrect (the model has assigned them to the same entity as the first token), but a naive token-level comparison will consider the last token to be correct.\n", "\n", "The CoNLL 2003 competition used the number of errors in extracting *entire* entity mentions to measure the result quality of the entries. We will use the same metric in this notebook. To compute entity-level errors, we convert the IOB-tagged tokens into `(span, entity type)` pairs. \n", "Text Extensions for Pandas includes a function `iob_to_spans()` that will handle this conversion for you." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [20, 28): 'PAKISTAN' LOC\n", "1 [31, 42): 'NEW ZEALAND' LOC\n", "2 [80, 83): 'GMT' MISC\n", "3 [85, 92): 'SIALKOT' LOC\n", "4 [94, 102): 'Pakistan' LOC\n", ".. ... ...\n", "69 [1488, 1501): 'Shahid Afridi' PER\n", "70 [1512, 1523): 'Salim Malik' PER\n", "71 [1535, 1545): 'Ijaz Ahmad' PER\n", "72 [1565, 1573): 'Pakistan' LOC\n", "73 [1626, 1633): 'Karachi' LOC\n", "\n", "[74 rows x 2 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Convert the IOB2-tagged corpus DataFrame to one with entity span and type columns.\n", "spans_df = tp.io.conll.iob_to_spans(example_df)\n", "spans_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initialize our BERT Tokenizer and Model\n", "\n", "Here we configure and initialize the [Huggingface transformers BERT tokenizer and model](https://huggingface.co/transformers/model_doc/bert.html). Text Extensions for Pandas provides a `make_bert_tokens()` function that uses the tokenizer to create BERT tokens as a span column in a DataFrame, ready to be used for computing BERT embeddings." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id \\\n", "0 0 [0, 0): '' 101 0 \n", "1 1 [0, 1): '-' 118 0 \n", "2 2 [1, 2): 'D' 141 0 \n", "3 3 [2, 4): 'OC' 9244 0 \n", "4 4 [4, 6): 'ST' 9272 0 \n", ".. ... ... ... ... \n", "684 684 [1621, 1622): ',' 117 0 \n", "685 685 [1623, 1625): 'in' 1107 0 \n", "686 686 [1626, 1633): 'Karachi' 16237 0 \n", "687 687 [1633, 1634): '.' 119 0 \n", "688 688 [0, 0): '' 102 0 \n", "\n", " attention_mask special_tokens_mask \n", "0 1 True \n", "1 1 False \n", "2 1 False \n", "3 1 False \n", "4 1 False \n", ".. ... ... \n", "684 1 False \n", "685 1 False \n", "686 1 False \n", "687 1 False \n", "688 1 True \n", "\n", "[689 rows x 6 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Huggingface transformers BERT Configuration.\n", "bert_model_name = \"dslim/bert-base-NER\"\n", "\n", "tokenizer = transformers.BertTokenizerFast.from_pretrained(bert_model_name, \n", " add_special_tokens=True)\n", "\n", "# Disable the warning about long sequences. We know what we're doing.\n", "# Different versions of transformers disable this warning differently,\n", "# so we need to do this twice.\n", "tokenizer.deprecation_warnings[\n", " \"sequence-length-is-longer-than-the-specified-maximum\"] = True\n", "tokenizer.model_max_length = 16384\n", "\n", "# Retokenize the document's text with the BERT tokenizer as a DataFrame \n", "# with a span column.\n", "bert_toks_df = tp.io.bert.make_bert_tokens(example_df[\"span\"].values[0].target_text, tokenizer)\n", "bert_toks_df" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id attention_mask \\\n", "0 0 [0, 0): '' 101 0 1 \n", "688 688 [0, 0): '' 102 0 1 \n", "\n", " special_tokens_mask \n", "0 True \n", "688 True " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# BERT tokenization includes special zero-length tokens.\n", "bert_toks_df[bert_toks_df[\"special_tokens_mask\"]]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " original_span bert_spans ent_type\n", "0 [20, 28): 'PAKISTAN' [20, 28): 'PAKISTAN' LOC\n", "1 [31, 42): 'NEW ZEALAND' [31, 42): 'NEW ZEALAND' LOC\n", "2 [80, 83): 'GMT' [80, 83): 'GMT' MISC\n", "3 [85, 92): 'SIALKOT' [85, 92): 'SIALKOT' LOC\n", "4 [94, 102): 'Pakistan' [94, 102): 'Pakistan' LOC\n", ".. ... ... ...\n", "69 [1488, 1501): 'Shahid Afridi' [1488, 1501): 'Shahid Afridi' PER\n", "70 [1512, 1523): 'Salim Malik' [1512, 1523): 'Salim Malik' PER\n", "71 [1535, 1545): 'Ijaz Ahmad' [1535, 1545): 'Ijaz Ahmad' PER\n", "72 [1565, 1573): 'Pakistan' [1565, 1573): 'Pakistan' LOC\n", "73 [1626, 1633): 'Karachi' [1626, 1633): 'Karachi' LOC\n", "\n", "[74 rows x 3 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Align the BERT tokens with the original tokenization.\n", "bert_spans = tp.TokenSpanArray.align_to_tokens(bert_toks_df[\"span\"],\n", " spans_df[\"span\"])\n", "pd.DataFrame({\n", " \"original_span\": spans_df[\"span\"],\n", " \"bert_spans\": bert_spans,\n", " \"ent_type\": spans_df[\"ent_type\"]\n", "})" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id attention_mask \\\n", "10 10 [15, 17): 'KE' 22441 0 1 \n", "11 11 [17, 18): 'T' 1942 0 1 \n", "12 12 [18, 19): '-' 118 0 1 \n", "13 13 [20, 22): 'PA' 8544 0 1 \n", "14 14 [22, 23): 'K' 2428 0 1 \n", "15 15 [23, 25): 'IS' 6258 0 1 \n", "16 16 [25, 27): 'TA' 9159 0 1 \n", "17 17 [27, 28): 'N' 2249 0 1 \n", "18 18 [29, 30): 'V' 159 0 1 \n", "19 19 [31, 33): 'NE' 26546 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type \n", "10 False O \n", "11 False O \n", "12 False O \n", "13 False B LOC \n", "14 False I LOC \n", "15 False I LOC \n", "16 False I LOC \n", "17 False I LOC \n", "18 False O \n", "19 False B LOC " ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Generate IOB2 tags and entity labels that align with the BERT tokens.\n", "# See https://en.wikipedia.org/wiki/Inside%E2%80%93outside%E2%80%93beginning_(tagging)\n", "bert_toks_df[[\"ent_iob\", \"ent_type\"]] = tp.io.conll.spans_to_iob(bert_spans, \n", " spans_df[\"ent_type\"])\n", "bert_toks_df[10:20]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "CategoricalDtype(categories=['O', 'B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC',\n", " 'I-ORG', 'I-PER'],\n", ", ordered=False)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a Pandas categorical type for consistent encoding of categories\n", "# across all documents.\n", "ENTITY_TYPES = [\"LOC\", \"MISC\", \"ORG\", \"PER\"]\n", "token_class_dtype, int_to_label, label_to_int = tp.io.conll.make_iob_tag_categories(ENTITY_TYPES)\n", "token_class_dtype" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id \\\n", "0 0 [0, 0): '' 101 0 \n", "1 1 [0, 1): '-' 118 0 \n", "2 2 [1, 2): 'D' 141 0 \n", "3 3 [2, 4): 'OC' 9244 0 \n", "4 4 [4, 6): 'ST' 9272 0 \n", ".. ... ... ... ... \n", "684 684 [1621, 1622): ',' 117 0 \n", "685 685 [1623, 1625): 'in' 1107 0 \n", "686 686 [1626, 1633): 'Karachi' 16237 0 \n", "687 687 [1633, 1634): '.' 119 0 \n", "688 688 [0, 0): '' 102 0 \n", "\n", " attention_mask special_tokens_mask ent_iob ent_type token_class \\\n", "0 1 True O O \n", "1 1 False O O \n", "2 1 False O O \n", "3 1 False O O \n", "4 1 False O O \n", ".. ... ... ... ... ... \n", "684 1 False O O \n", "685 1 False O O \n", "686 1 False B LOC B-LOC \n", "687 1 False O O \n", "688 1 True O O \n", "\n", " token_class_id \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", ".. ... \n", "684 0 \n", "685 0 \n", "686 1 \n", "687 0 \n", "688 0 \n", "\n", "[689 rows x 10 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The traditional way to transform NER to token classification is to \n", "# treat each combination of {I,O,B} X {entity type} as a different\n", "# class. Generate class labels in that format.\n", "classes_df = tp.io.conll.add_token_classes(bert_toks_df, token_class_dtype)\n", "classes_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Show how to compute BERT embeddings\n", "\n", "We are going to use the BERT embeddings as feature vectors to train our model. First, we will show how they are computed for the example document." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "data": { "text/html": [ "
\n", "
" ], "text/plain": [ " token_id span input_id ent_iob ent_type token_class \\\n", "10 10 [15, 17): 'KE' 22441 O O \n", "11 11 [17, 18): 'T' 1942 O O \n", "12 12 [18, 19): '-' 118 O O \n", "13 13 [20, 22): 'PA' 8544 B LOC B-LOC \n", "14 14 [22, 23): 'K' 2428 I LOC I-LOC \n", "15 15 [23, 25): 'IS' 6258 I LOC I-LOC \n", "16 16 [25, 27): 'TA' 9159 I LOC I-LOC \n", "17 17 [27, 28): 'N' 2249 I LOC I-LOC \n", "18 18 [29, 30): 'V' 159 O O \n", "19 19 [31, 33): 'NE' 26546 B LOC B-LOC \n", "\n", " embedding \n", "10 [ -0.19854169, -0.46898514, 0.7755601... \n", "11 [ -0.24190396, -0.42399377, 0.9554063... \n", "12 [ -0.20076752, -0.7481933, 1.302213... \n", "13 [ 0.20202553, -0.26199815, 0.3297633... \n", "14 [ -0.5462168, -0.90924424, -0.0583674... \n", "15 [ -0.37400252, -0.6890734, -0.1446257... \n", "16 [ -0.46548516, -0.8717417, 0.3557479... \n", "17 [ -0.18682763, -0.90081865, 0.3601499... \n", "18 [ -0.16640103, -0.8363804, 0.8740610... \n", "19 [ -0.30241105, -0.83826715, 1.105809... " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Initialize the BERT model that will be used to generate embeddings.\n", "bert = transformers.BertModel.from_pretrained(bert_model_name)\n", "\n", "# Force garbage collection in case this notebook is running on a low-RAM environment.\n", "gc.collect()\n", "\n", "# Compute BERT embeddings with the BERT model and add result to our example DataFrame.\n", "embeddings_df = tp.io.bert.add_embeddings(classes_df, bert)\n", "embeddings_df[[\"token_id\", \"span\", \"input_id\", \"ent_iob\", \"ent_type\", \"token_class\", \"embedding\"]].iloc[10:20]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " span ent_iob ent_type \\\n", "70 [155, 168): 'international' O \n", "71 [169, 176): 'between' O \n", "72 [177, 185): 'Pakistan' B LOC \n", "73 [186, 189): 'and' O \n", "74 [190, 193): 'New' B LOC \n", "\n", " embedding \n", "70 [ 0.23404993, -0.5534872, 0.9083986, ... \n", "71 [ 0.27793035, -0.68538034, 1.1050361, ... \n", "72 [ 0.1971882, -0.4634109, 0.5182331, ... \n", "73 [ 0.20423535, -0.63758826, 0.82874435, ... \n", "74 [ 0.2874066, -0.47174183, 0.7771955, ... " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "embeddings_df[[\"span\", \"ent_iob\", \"ent_type\", \"embedding\"]].iloc[70:75]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# The `embedding` column has the extension dtype `TensorDtype` and is backed by a\n", "# `TensorArray` provided by Text Extensions for Pandas.\n", "embeddings_df[\"embedding\"].dtype" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A `TensorArray` can be constructed from a NumPy array of arbitrary dimensions, added to a DataFrame, and then used with standard Pandas functionality. See the notebook [Text_Extension_for_Pandas_Overview](./Text_Extension_for_Pandas_Overview.ipynb) for more on `TensorArray`."
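, "\n", "\n", "The following rough sketch shows why memory can become a constraint. Assume the `embedding` column is held as a dense `float32` NumPy array, which matches the dtype and shape that the zero-copy conversion in the next cell reports. Its size is then `rows * dims * 4` bytes. The helper `tensor_nbytes` is hypothetical, and the token counts come from this notebook's outputs.\n", "\n", "
```python
import numpy as np


def tensor_nbytes(num_rows, num_dims, dtype=np.float32):
    # Bytes needed for a dense (num_rows, num_dims) array of the given dtype.
    return num_rows * num_dims * np.dtype(dtype).itemsize


# One example document: 689 BERT tokens with 768-dimensional embeddings.
example = np.zeros((689, 768), dtype=np.float32)
assert example.nbytes == tensor_nbytes(689, 768)

# Whole corpus (416541 BERT tokens), with and without projection to 256 dims.
full_gb = tensor_nbytes(416541, 768) / 1e9
shrunk_gb = tensor_nbytes(416541, 256) / 1e9
print(f'full: {full_gb:.2f} GB, projected: {shrunk_gb:.2f} GB')
# prints: full: 1.28 GB, projected: 0.43 GB
```
", "\n", "This counts only the embedding values themselves; the other columns and any intermediate copies need additional memory."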
] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(dtype('float32'), (689, 768))" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Zero-copy conversion to NumPy: first extract the underlying\n", "# `TensorArray` with `.array`, then call its `to_numpy()` method.\n", "embeddings_arr = embeddings_df[\"embedding\"].array.to_numpy()\n", "embeddings_arr.dtype, embeddings_arr.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate BERT tokens and BERT embeddings for the entire corpus\n", "\n", "Text Extensions for Pandas provides a convenience function, `tp.io.bert.conll_to_bert()`, that combines the steps in the cells above to create BERT tokens and embeddings. We will use it to add embeddings to the entire corpus." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id \\\n", "0 0 [0, 0): '' 101 0 \n", "1 1 [0, 1): '-' 118 0 \n", "2 2 [1, 2): 'D' 141 0 \n", "3 3 [2, 4): 'OC' 9244 0 \n", "4 4 [4, 6): 'ST' 9272 0 \n", ".. ... ... ... ... \n", "684 684 [1621, 1622): ',' 117 0 \n", "685 685 [1623, 1625): 'in' 1107 0 \n", "686 686 [1626, 1633): 'Karachi' 16237 0 \n", "687 687 [1633, 1634): '.' 119 0 \n", "688 688 [0, 0): '' 102 0 \n", "\n", " attention_mask special_tokens_mask ent_iob ent_type token_class \\\n", "0 1 True O O \n", "1 1 False O O \n", "2 1 False O O \n", "3 1 False O O \n", "4 1 False O O \n", ".. ... ... ... ... ... \n", "684 1 False O O \n", "685 1 False O O \n", "686 1 False B LOC B-LOC \n", "687 1 False O O \n", "688 1 True O O \n", "\n", " token_class_id embedding \n", "0 0 [ -0.08307081, -0.35959032, 1.015068... \n", "1 0 [ -0.22862603, -0.49313632, 1.28423... \n", "2 0 [ 0.028480662, -0.17874284, 1.54320... \n", "3 0 [ -0.4651753, -0.29836023, 1.073767... \n", "4 0 [ -0.10730811, -0.33720982, 1.226979... \n", ".. ... ... \n", "684 0 [ -0.1280663, -0.0023243837, 0.678132... \n", "685 0 [ 0.3053407, -0.52625775, 0.8281702... \n", "686 1 [ -0.048738778, -0.33797324, -0.0583509... \n", "687 0 [ -0.005289644, -0.29743072, 0.716173... \n", "688 0 [ -0.50302404, 0.36253828, 0.7314933... 
\n", "\n", "[689 rows x 11 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Example usage of the convenience function to create BERT tokens and embeddings.\n", "tp.io.bert.conll_to_bert(example_df, tokenizer, bert, token_class_dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When this notebook runs in a resource-constrained environment like [Binder](https://mybinder.org/),\n", "there may not be enough RAM to hold all the embeddings in memory.\n", "In that case, we can use [Gaussian random projection](https://scikit-learn.org/stable/modules/random_projection.html#gaussian-random-projection) to reduce the size of the embeddings.\n", "Projecting the 768-dimensional BERT embeddings down to 256 dimensions shrinks them by a factor of 3,\n", "at the expense of a small decrease in model accuracy.\n", "\n", "The constant `SHRINK_EMBEDDINGS` in the following cell controls this behavior. It is set to `False` below; change it to `True` to enable the projection." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "import sklearn.random_projection\n", "\n", "SHRINK_EMBEDDINGS = False\n", "PROJECTION_DIMS = 256\n", "RANDOM_SEED = 42\n", "\n", "projection = sklearn.random_projection.GaussianRandomProjection(\n", "    n_components=PROJECTION_DIMS, random_state=RANDOM_SEED)\n", "\n", "def maybe_shrink_embeddings(df):\n", "    # Replace the embeddings with their random projection if SHRINK_EMBEDDINGS is set.\n", "    # Because random_state is fixed, refitting on each document yields the same projection matrix.\n", "    if SHRINK_EMBEDDINGS:\n", "        df[\"embedding\"] = tp.TensorArray(projection.fit_transform(df[\"embedding\"]))\n", "    return df" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "Processing fold 'dev'...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2e47f857a39a44d7ac5a3f34f60494cb", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=216, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Processing fold 'test'...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "31f95cf00b394566a13997459a76db17", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=231, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " token_id span input_id token_type_id attention_mask \\\n", "0 0 [0, 0): '' 101 0 1 \n", "1 1 [0, 1): '-' 118 0 1 \n", "2 2 [1, 2): 'D' 141 0 1 \n", "3 3 [2, 4): 'OC' 9244 0 1 \n", "4 4 [4, 6): 'ST' 9272 0 1 \n", "... ... ... ... ... ... \n", "2154 2154 [5704, 5705): ')' 114 0 1 \n", "2155 2155 [5706, 5708): '39' 3614 0 1 \n", "2156 2156 [5708, 5709): '.' 119 0 1 \n", "2157 2157 [5709, 5711): '93' 5429 0 1 \n", "2158 2158 [0, 0): '' 102 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type token_class token_class_id \\\n", "0 True O O 0 \n", "1 False O O 0 \n", "2 False O O 0 \n", "3 False O O 0 \n", "4 False O O 0 \n", "... ... ... ... ... ... \n", "2154 False O O 0 \n", "2155 False O O 0 \n", "2156 False O O 0 \n", "2157 False O O 0 \n", "2158 True O O 0 \n", "\n", " embedding \n", "0 [ -0.17669655, -0.3989963, 0.908887... \n", "1 [ -0.3855382, -0.50232756, 1.173232... \n", "2 [ -0.11718995, -0.12701154, 1.38969... \n", "3 [ -0.39025685, -0.25043246, 1.074507... \n", "4 [ -0.27732754, -0.26160136, 1.078761... \n", "... ... \n", "2154 [ 0.015393024, -0.040650737, 1.001185... \n", "2155 [ 0.075038865, 0.014400693, 1.043231... \n", "2156 [ -0.085796565, 0.05905571, 1.114640... \n", "2157 [ 0.0113782445, -0.26387203, 0.881803... \n", "2158 [ 0.48513305, 1.5709875, 0.592935... 
\n", "\n", "[2159 rows x 11 columns]" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run the entire corpus through our processing pipeline.\n", "bert_toks_by_fold = {}\n", "for fold_name in corpus_raw.keys():\n", "    print(f\"Processing fold '{fold_name}'...\")\n", "    raw = corpus_raw[fold_name]\n", "    with torch.inference_mode():  # This line cuts CPU usage by ~50%\n", "        bert_toks_by_fold[fold_name] = tp.jupyter.run_with_progress_bar(\n", "            len(raw), lambda i: maybe_shrink_embeddings(tp.io.bert.conll_to_bert(\n", "                raw[i], tokenizer, bert, token_class_dtype)))\n", "\n", "bert_toks_by_fold[\"dev\"][20]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collate the data structures we've generated so far" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " fold doc_num token_id span input_id \\\n", "0 train 0 0 [0, 0): '' 101 \n", "1 train 0 1 [0, 1): '-' 118 \n", "2 train 0 2 [1, 2): 'D' 141 \n", "3 train 0 3 [2, 4): 'OC' 9244 \n", "4 train 0 4 [4, 6): 'ST' 9272 \n", "... ... ... ... ... ... \n", "416536 test 230 314 [1386, 1393): 'brother' 1711 \n", "416537 test 230 315 [1393, 1394): ',' 117 \n", "416538 test 230 316 [1395, 1400): 'Bobby' 5545 \n", "416539 test 230 317 [1400, 1401): '.' 119 \n", "416540 test 230 318 [0, 0): '' 102 \n", "\n", " token_type_id attention_mask special_tokens_mask ent_iob ent_type \\\n", "0 0 1 True O \n", "1 0 1 False O \n", "2 0 1 False O \n", "3 0 1 False O \n", "4 0 1 False O \n", "... ... ... ... ... ... \n", "416536 0 1 False O \n", "416537 0 1 False O \n", "416538 0 1 False B PER \n", "416539 0 1 False O \n", "416540 0 1 True O \n", "\n", " token_class token_class_id \\\n", "0 O 0 \n", "1 O 0 \n", "2 O 0 \n", "3 O 0 \n", "4 O 0 \n", "... ... ... \n", "416536 O 0 \n", "416537 O 0 \n", "416538 B-PER 4 \n", "416539 O 0 \n", "416540 O 0 \n", "\n", " embedding \n", "0 [ -0.098505504, -0.4050192, 0.742888... \n", "1 [ -0.057021566, -0.48112106, 0.989868... \n", "2 [ -0.04824192, -0.2532998, 1.16719... \n", "3 [ -0.26682985, -0.31008705, 1.00747... \n", "4 [ -0.22296886, -0.21308525, 0.933102... \n", "... ... \n", "416536 [ -0.02817309, -0.08062352, 0.9804888... \n", "416537 [ 0.118173525, -0.07008511, 0.865484... \n", "416538 [ -0.35689434, 0.31400475, 1.573854... \n", "416539 [ -0.18957116, -0.2458116, 0.66257... \n", "416540 [ -0.4468915, -0.31665248, 0.779688... 
\n", "\n", "[416541 rows x 13 columns]" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Create a single DataFrame with the entire corpus's embeddings.\n", "corpus_df = tp.io.conll.combine_folds(bert_toks_by_fold)\n", "corpus_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Checkpoint\n", "\n", "Because the embeddings are stored in a `TensorArray` from Text Extensions for Pandas, they can be persisted as a tensor along with the rest of the DataFrame using standard Pandas input/output methods. Computing the embeddings is costly and the results are deterministic, so checkpointing the data to disk at this point can save a lot of time. We can then continue working on model training without recomputing the BERT embeddings.\n", "\n", "### Save DataFrame with Embeddings Tensor" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# Write the tokenized corpus with embeddings to a Feather file.\n", "# We can't currently serialize span columns that cover multiple documents (see issue #73 https://github.com/CODAIT/text-extensions-for-pandas/issues/73),\n", "# so drop span columns from the contents we write to the Feather file.\n", "cols_to_drop = [c for c in corpus_df.columns if \"span\" in c]\n", "corpus_df.drop(columns=cols_to_drop).to_feather(\"outputs/corpus.feather\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load DataFrame with Previously Computed Embeddings" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " fold doc_num token_id input_id token_type_id attention_mask \\\n", "0 train 0 0 101 0 1 \n", "1 train 0 1 118 0 1 \n", "2 train 0 2 141 0 1 \n", "3 train 0 3 9244 0 1 \n", "4 train 0 4 9272 0 1 \n", "... ... ... ... ... ... ... \n", "416536 test 230 314 1711 0 1 \n", "416537 test 230 315 117 0 1 \n", "416538 test 230 316 5545 0 1 \n", "416539 test 230 317 119 0 1 \n", "416540 test 230 318 102 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type token_class token_class_id \\\n", "0 True O O 0 \n", "1 False O O 0 \n", "2 False O O 0 \n", "3 False O O 0 \n", "4 False O O 0 \n", "... ... ... ... ... ... \n", "416536 False O O 0 \n", "416537 False O O 0 \n", "416538 False B PER B-PER 4 \n", "416539 False O O 0 \n", "416540 True O O 0 \n", "\n", " embedding \n", "0 [ -0.098505504, -0.4050192, 0.742888... \n", "1 [ -0.057021566, -0.48112106, 0.989868... \n", "2 [ -0.04824192, -0.2532998, 1.16719... \n", "3 [ -0.26682985, -0.31008705, 1.00747... \n", "4 [ -0.22296886, -0.21308525, 0.933102... \n", "... ... \n", "416536 [ -0.02817309, -0.08062352, 0.9804888... \n", "416537 [ 0.118173525, -0.07008511, 0.865484... \n", "416538 [ -0.35689434, 0.31400475, 1.573854... \n", "416539 [ -0.18957116, -0.2458116, 0.66257... \n", "416540 [ -0.4468915, -0.31665248, 0.779688... \n", "\n", "[416541 rows x 12 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Read the serialized embeddings back in so that you can rerun the model \n", "# training parts of this notebook (the cells from here onward) without \n", "# regenerating the embeddings.\n", "corpus_df = pd.read_feather(\"outputs/corpus.feather\")\n", "corpus_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a model on the BERT embeddings\n", "\n", "Now we will use the loaded BERT embeddings to train a multinomial logistic regression model that predicts each token's class from its embedding tensor."
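, "\n", "\n", "This is a minimal NumPy sketch of what the model computes. It is not scikit-learn's implementation, which fits the weights with L-BFGS. Multinomial logistic regression keeps one weight vector and one intercept per token class, then applies a softmax to the resulting scores to get class probabilities. Assume 768-dimensional embeddings and 9 token classes (`O`, plus `B-` and `I-` tags for the four CoNLL-2003 entity types). That gives `9 * (768 + 1) = 6921` parameters, which matches `N = 6921` in the optimizer log of the training cell. The weights below are random placeholders, not trained values.\n", "\n", "
```python
import numpy as np

NUM_DIMS = 768      # BERT embedding size (no random projection)
NUM_CLASSES = 9     # O, plus B-/I- for LOC, MISC, ORG, PER


def softmax(scores):
    # Subtract the row-wise max for numerical stability.
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_proba(embeddings, weights, intercepts):
    # embeddings: (num_tokens, NUM_DIMS); weights: (NUM_CLASSES, NUM_DIMS)
    return softmax(embeddings @ weights.T + intercepts)


rng = np.random.default_rng(42)
weights = rng.normal(scale=0.01, size=(NUM_CLASSES, NUM_DIMS))  # placeholder
intercepts = np.zeros(NUM_CLASSES)
embeddings = rng.normal(size=(5, NUM_DIMS)).astype(np.float32)

probs = predict_proba(embeddings, weights, intercepts)
predicted_ids = probs.argmax(axis=1)
num_params = weights.size + intercepts.size
print(probs.shape, num_params)  # prints: (5, 9) 6921
```
"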
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
\n", "
" ], "text/plain": [ " fold doc_num token_id input_id token_type_id attention_mask \\\n", "0 train 0 0 101 0 1 \n", "1 train 0 1 118 0 1 \n", "2 train 0 2 141 0 1 \n", "3 train 0 3 9244 0 1 \n", "4 train 0 4 9272 0 1 \n", "... ... ... ... ... ... ... \n", "281104 train 945 53 17057 0 1 \n", "281105 train 945 54 122 0 1 \n", "281106 train 945 55 4617 0 1 \n", "281107 train 945 56 123 0 1 \n", "281108 train 945 57 102 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type token_class token_class_id \\\n", "0 True O O 0 \n", "1 False O O 0 \n", "2 False O O 0 \n", "3 False O O 0 \n", "4 False O O 0 \n", "... ... ... ... ... ... \n", "281104 False B ORG B-ORG 3 \n", "281105 False O O 0 \n", "281106 False B ORG B-ORG 3 \n", "281107 False O O 0 \n", "281108 True O O 0 \n", "\n", " embedding \n", "0 [ -0.098505504, -0.4050192, 0.742888... \n", "1 [ -0.057021566, -0.48112106, 0.989868... \n", "2 [ -0.04824192, -0.2532998, 1.16719... \n", "3 [ -0.26682985, -0.31008705, 1.00747... \n", "4 [ -0.22296886, -0.21308525, 0.933102... \n", "... ... \n", "281104 [ 0.7556371, -0.91891253, -0.1403036... \n", "281105 [ -0.11528473, -0.44492027, 0.4715562... \n", "281106 [ 0.45602208, -0.8970848, 0.0678616... \n", "281107 [ -0.19713743, -0.5427194, 0.294020... \n", "281108 [ -0.57650733, -0.42160645, 0.994703... 
\n", "\n", "[281109 rows x 12 columns]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Extract the training set DataFrame.\n", "train_df = corpus_df[corpus_df[\"fold\"] == \"train\"]\n", "train_df" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RUNNING THE L-BFGS-B CODE\n", "\n", " * * *\n", "\n", "Machine precision = 2.220D-16\n", " N = 6921 M = 10\n", "\n", "At X0 0 variables are exactly at the bounds\n", "\n", "At iterate 0 f= 6.17660D+05 |proj g|= 4.23293D+05\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " This problem is unconstrained.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "At iterate 50 f= 1.22005D+04 |proj g|= 2.48275D+02\n", "\n", "At iterate 100 f= 8.87639D+03 |proj g|= 1.72205D+02\n", "\n", "At iterate 150 f= 8.07946D+03 |proj g|= 1.28633D+02\n", "\n", "At iterate 200 f= 7.87840D+03 |proj g|= 6.20068D+01\n", "\n", "At iterate 250 f= 7.81730D+03 |proj g|= 9.11741D+00\n", "\n", "At iterate 300 f= 7.80144D+03 |proj g|= 6.86435D+00\n", "\n", "At iterate 350 f= 7.79623D+03 |proj g|= 7.21843D+00\n", "\n", "At iterate 400 f= 7.79451D+03 |proj g|= 5.64213D+00\n", "\n", "At iterate 450 f= 7.79356D+03 |proj g|= 2.47884D+00\n", "\n", "At iterate 500 f= 7.79273D+03 |proj g|= 2.32130D+00\n", "\n", "At iterate 550 f= 7.79141D+03 |proj g|= 1.03513D+01\n", "\n", "At iterate 600 f= 7.78944D+03 |proj g|= 4.39763D+00\n", "\n", "At iterate 650 f= 7.78798D+03 |proj g|= 2.72198D+00\n", "\n", "At iterate 700 f= 7.78721D+03 |proj g|= 2.49312D+00\n", "\n", "At iterate 750 f= 7.78691D+03 |proj g|= 2.09049D+00\n", "\n", "At iterate 800 f= 7.78678D+03 |proj g|= 1.56225D+00\n", "\n", "At iterate 850 f= 7.78669D+03 |proj g|= 9.61272D-01\n", "\n", "At iterate 900 f= 7.78660D+03 |proj g|= 1.88970D+00\n", "\n", "At iterate 950 f= 7.78644D+03 |proj g|= 1.39468D+00\n", "\n", "At iterate 1000 f= 7.78615D+03 
|proj g|= 1.56165D+00\n", "\n", "At iterate 1050 f= 7.78593D+03 |proj g|= 1.81700D+00\n", "\n", "At iterate 1100 f= 7.78581D+03 |proj g|= 1.11273D+00\n", "\n", "At iterate 1150 f= 7.78577D+03 |proj g|= 4.10524D-01\n", "\n", "At iterate 1200 f= 7.78575D+03 |proj g|= 3.49336D-01\n", "\n", "At iterate 1250 f= 7.78574D+03 |proj g|= 8.20185D-01\n", "\n", "At iterate 1300 f= 7.78571D+03 |proj g|= 9.94495D-01\n", "\n", "At iterate 1350 f= 7.78567D+03 |proj g|= 7.14421D-01\n", "\n", "At iterate 1400 f= 7.78563D+03 |proj g|= 3.46513D-01\n", "\n", "At iterate 1450 f= 7.78561D+03 |proj g|= 1.15784D+00\n", "\n", "At iterate 1500 f= 7.78559D+03 |proj g|= 5.66811D-01\n", "\n", "At iterate 1550 f= 7.78559D+03 |proj g|= 1.43156D-01\n", "\n", "At iterate 1600 f= 7.78558D+03 |proj g|= 1.60595D-01\n", "\n", " * * *\n", "\n", "Tit = total number of iterations\n", "Tnf = total number of function evaluations\n", "Tnint = total number of segments explored during Cauchy searches\n", "Skip = number of BFGS updates skipped\n", "Nact = number of active bounds at final generalized Cauchy point\n", "Projg = norm of the final projected gradient\n", "F = final function value\n", "\n", " * * *\n", "\n", " N Tit Tnf Tnint Skip Nact Projg F\n", " 6921 1604 1694 1 0 0 4.829D-01 7.786D+03\n", " F = 7785.5829997825367 \n", "\n", "CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH \n", "CPU times: user 1h 34min 15s, sys: 6min 41s, total: 1h 40min 56s\n", "Wall time: 12min 44s\n" ] }, { "data": { "text/html": [ "
Pipeline(steps=[('mlogreg',\n",
       "                 LogisticRegression(C=0.1, max_iter=10000,\n",
       "                                    multi_class='multinomial', verbose=1))])
" ], "text/plain": [ "Pipeline(steps=[('mlogreg',\n", "                 LogisticRegression(C=0.1, max_iter=10000,\n", "                                    multi_class='multinomial', verbose=1))])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "\n", "# Train a multinomial logistic regression model on the training set.\n", "MULTI_CLASS = \"multinomial\"\n", "\n", "# How many iterations to run the L-BFGS optimizer when fitting logistic\n", "# regression models. 100 ==> fast; 10000 ==> full convergence.\n", "LBGFS_ITERATIONS = 10000\n", "_REGULARIZATION_COEFF = 1e-1  # Smaller values ==> more regularization\n", "\n", "base_pipeline = sklearn.pipeline.Pipeline([\n", "    # Optional standard scaler (disabled). It only makes a difference for\n", "    # certain types of embeddings.\n", "    #(\"scaler\", sklearn.preprocessing.StandardScaler()),\n", "    (\"mlogreg\", sklearn.linear_model.LogisticRegression(\n", "        multi_class=MULTI_CLASS,\n", "        verbose=1,\n", "        max_iter=LBGFS_ITERATIONS,\n", "        C=_REGULARIZATION_COEFF\n", "    ))\n", "])\n", "\n", "X_train = train_df[\"embedding\"].values\n", "Y_train = train_df[\"token_class_id\"]\n", "base_model = base_pipeline.fit(X_train, Y_train)\n", "base_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Predictions on Token Class from BERT Embeddings\n", "\n", "Using our model, we can now predict the token classes in the test set from the computed embeddings."
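, "\n", "\n", "The next cell wraps this step in a function. The rough, library-free sketch below covers the last part of that process: take the most likely class ID for each token, look up its class name, and split the name into an IOB tag and an entity type. The class ordering and the helper `split_class_label` are hypothetical stand-ins for the notebook's own mapping and for `tp.io.conll.decode_class_labels`, whose exact output format may differ.\n", "\n", "
```python
import numpy as np

# Hypothetical class ordering; the notebook builds its own mapping elsewhere.
ID_TO_CLASS = ['O', 'B-LOC', 'B-MISC', 'B-ORG', 'B-PER',
               'I-LOC', 'I-MISC', 'I-ORG', 'I-PER']


def split_class_label(label):
    # 'B-LOC' -> ('B', 'LOC'); 'O' -> ('O', None)
    if label == 'O':
        return 'O', None
    iob, ent_type = label.split('-', 1)
    return iob, ent_type


# Toy class probabilities for three tokens (each row sums to 1).
class_pr = np.array([
    [0.90, 0.02, 0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01],
    [0.05, 0.80, 0.02, 0.03, 0.02, 0.05, 0.01, 0.01, 0.01],
    [0.10, 0.05, 0.02, 0.03, 0.02, 0.70, 0.03, 0.03, 0.02],
])
predicted_id = class_pr.argmax(axis=1)
predicted_class = [ID_TO_CLASS[i] for i in predicted_id]
decoded = [split_class_label(c) for c in predicted_class]
print(predicted_class)  # ['O', 'B-LOC', 'I-LOC']
print(decoded)          # [('O', None), ('B', 'LOC'), ('I', 'LOC')]
```
"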
] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# Define a function that will let us make predictions on a fold of the corpus.\n", "def predict_on_df(df: pd.DataFrame, id_to_class: Dict[int, str], predictor):\n", " \"\"\"\n", " Run a trained model on a DataFrame of tokens with embeddings.\n", "\n", " :param df: DataFrame of tokens (for example, one fold of the corpus), containing\n", " a TensorArray column called \"embedding\" that holds one embedding per token.\n", " :param id_to_class: Mapping from class ID to class name, as returned by\n", " :func:`text_extensions_for_pandas.make_iob_tag_categories`\n", " :param predictor: Python object with a `predict_proba` method that accepts\n", " a numpy array of embeddings.\n", " :returns: A copy of `df`, with the following additional columns:\n", " `predicted_id`, `predicted_class`, `predicted_iob`, `predicted_type`\n", " and `predicted_class_pr`.\n", " \"\"\"\n", " result_df = df.copy()\n", " embeddings = result_df[\"embedding\"].to_numpy()\n", " class_pr = tp.TensorArray(predictor.predict_proba(embeddings))\n", " result_df[\"predicted_id\"] = np.argmax(class_pr, axis=1)\n", " result_df[\"predicted_class\"] = [id_to_class[i]\n", " for i in result_df[\"predicted_id\"].values]\n", " iobs, types = tp.io.conll.decode_class_labels(result_df[\"predicted_class\"].values)\n", " result_df[\"predicted_iob\"] = iobs\n", " result_df[\"predicted_type\"] = types\n", " result_df[\"predicted_class_pr\"] = class_pr\n", " return result_df" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
folddoc_numtoken_idinput_idtoken_type_idattention_maskspecial_tokens_maskent_iobent_typetoken_classtoken_class_idembeddingpredicted_idpredicted_classpredicted_iobpredicted_typepredicted_class_pr
351001test0010101TrueO<NA>O0[ -0.19626583, -0.450937, 0.6775361...0OONone[ 0.9994774788863705, 1.9985127298723906e-0...
351002test0111801FalseO<NA>O0[ -0.3187211, -0.5074784, 1.046454...0OONone[ 0.9992964240340214, 3.7581023374440964e-0...
351003test0214101FalseO<NA>O0[ -0.080538824, -0.2477481, 1.356255...0OONone[ 0.998973288221842, 0.0004299715907382311...
351004test03924401FalseO<NA>O0[ -0.6878579, -0.30290246, 0.8842714...0OONone[ 0.9983217119367633, 4.888114850946988e-0...
351005test04927201FalseO<NA>O0[ -0.2963228, -0.23313177, 0.93988...0OONone[ 0.9999185106741023, 8.938753477308423e-0...
\n", "
" ], "text/plain": [ " fold doc_num token_id input_id token_type_id attention_mask \\\n", "351001 test 0 0 101 0 1 \n", "351002 test 0 1 118 0 1 \n", "351003 test 0 2 141 0 1 \n", "351004 test 0 3 9244 0 1 \n", "351005 test 0 4 9272 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type token_class token_class_id \\\n", "351001 True O O 0 \n", "351002 False O O 0 \n", "351003 False O O 0 \n", "351004 False O O 0 \n", "351005 False O O 0 \n", "\n", " embedding predicted_id \\\n", "351001 [ -0.19626583, -0.450937, 0.6775361... 0 \n", "351002 [ -0.3187211, -0.5074784, 1.046454... 0 \n", "351003 [ -0.080538824, -0.2477481, 1.356255... 0 \n", "351004 [ -0.6878579, -0.30290246, 0.8842714... 0 \n", "351005 [ -0.2963228, -0.23313177, 0.93988... 0 \n", "\n", " predicted_class predicted_iob predicted_type \\\n", "351001 O O None \n", "351002 O O None \n", "351003 O O None \n", "351004 O O None \n", "351005 O O None \n", "\n", " predicted_class_pr \n", "351001 [ 0.9994774788863705, 1.9985127298723906e-0... \n", "351002 [ 0.9992964240340214, 3.7581023374440964e-0... \n", "351003 [ 0.998973288221842, 0.0004299715907382311... \n", "351004 [ 0.9983217119367633, 4.888114850946988e-0... \n", "351005 [ 0.9999185106741023, 8.938753477308423e-0... " ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Make predictions on the test set.\n", "test_results_df = predict_on_df(corpus_df[corpus_df[\"fold\"] == \"test\"], int_to_label, base_model)\n", "test_results_df.head()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
folddoc_numtoken_idinput_idtoken_type_idattention_maskspecial_tokens_maskent_iobent_typetoken_classtoken_class_idembeddingpredicted_idpredicted_classpredicted_iobpredicted_typepredicted_class_pr
351041test040330901FalseIPERI-PER8[ -0.21029201, -0.8535674, 0.0002756594...6I-MISCIMISC[ 0.0010111308810159478, 1.6209660863726316e-0...
351042test041130601FalseIPERI-PER8[ -0.23205486, -0.9290767, 0.3889118...6I-MISCIMISC[ 0.012755027203264928, 0.00554094580945546...
351043test042200101FalseIPERI-PER8[ 0.36844134, -0.68091154, -0.1059106...5I-LOCILOC[ 0.008349822538261149, 0.180904633782168...
351044test043118101FalseIPERI-PER8[ -0.30131084, -0.6546019, -0.1726912...8I-PERIPER[ 0.013398092974719904, 0.000889872066127380...
351045test044229301FalseIPERI-PER8[ -0.1611614, -0.69891113, 0.2342468...5I-LOCILOC[ 0.014927046511081343, 0.0209250472885050...
351046test0451858901FalseBLOCB-LOC1[ -0.058567554, -0.79558676, 0.3360603...1B-LOCBLOC[ 0.027281135850703336, 0.532249166723370...
351047test04611801FalseILOCI-LOC5[ 0.2037595, -0.73730904, -0.0888521...5I-LOCILOC[ 0.22512840995098554, 0.00379439656874946...
351048test0471901601FalseILOCI-LOC5[ -0.10341229, -0.33681834, 0.1738456...5I-LOCILOC[ 0.04472568023866835, 0.436126151622446...
351049test048224901FalseILOCI-LOC5[ -0.4054268, -0.6516522, 0.2469...5I-LOCILOC[ 0.0009405393288526446, 0.00244544190700176...
351050test04911701FalseO<NA>O0[ -0.16829254, -0.6475861, 0.8149025...0OONone[ 0.9999736550716568, 5.7005018158771435e-0...
\n", "
" ], "text/plain": [ " fold doc_num token_id input_id token_type_id attention_mask \\\n", "351041 test 0 40 3309 0 1 \n", "351042 test 0 41 1306 0 1 \n", "351043 test 0 42 2001 0 1 \n", "351044 test 0 43 1181 0 1 \n", "351045 test 0 44 2293 0 1 \n", "351046 test 0 45 18589 0 1 \n", "351047 test 0 46 118 0 1 \n", "351048 test 0 47 19016 0 1 \n", "351049 test 0 48 2249 0 1 \n", "351050 test 0 49 117 0 1 \n", "\n", " special_tokens_mask ent_iob ent_type token_class token_class_id \\\n", "351041 False I PER I-PER 8 \n", "351042 False I PER I-PER 8 \n", "351043 False I PER I-PER 8 \n", "351044 False I PER I-PER 8 \n", "351045 False I PER I-PER 8 \n", "351046 False B LOC B-LOC 1 \n", "351047 False I LOC I-LOC 5 \n", "351048 False I LOC I-LOC 5 \n", "351049 False I LOC I-LOC 5 \n", "351050 False O O 0 \n", "\n", " embedding predicted_id \\\n", "351041 [ -0.21029201, -0.8535674, 0.0002756594... 6 \n", "351042 [ -0.23205486, -0.9290767, 0.3889118... 6 \n", "351043 [ 0.36844134, -0.68091154, -0.1059106... 5 \n", "351044 [ -0.30131084, -0.6546019, -0.1726912... 8 \n", "351045 [ -0.1611614, -0.69891113, 0.2342468... 5 \n", "351046 [ -0.058567554, -0.79558676, 0.3360603... 1 \n", "351047 [ 0.2037595, -0.73730904, -0.0888521... 5 \n", "351048 [ -0.10341229, -0.33681834, 0.1738456... 5 \n", "351049 [ -0.4054268, -0.6516522, 0.2469... 5 \n", "351050 [ -0.16829254, -0.6475861, 0.8149025... 0 \n", "\n", " predicted_class predicted_iob predicted_type \\\n", "351041 I-MISC I MISC \n", "351042 I-MISC I MISC \n", "351043 I-LOC I LOC \n", "351044 I-PER I PER \n", "351045 I-LOC I LOC \n", "351046 B-LOC B LOC \n", "351047 I-LOC I LOC \n", "351048 I-LOC I LOC \n", "351049 I-LOC I LOC \n", "351050 O O None \n", "\n", " predicted_class_pr \n", "351041 [ 0.0010111308810159478, 1.6209660863726316e-0... \n", "351042 [ 0.012755027203264928, 0.00554094580945546... \n", "351043 [ 0.008349822538261149, 0.180904633782168... \n", "351044 [ 0.013398092974719904, 0.000889872066127380... 
\n", "351045 [ 0.014927046511081343, 0.0209250472885050... \n", "351046 [ 0.027281135850703336, 0.532249166723370... \n", "351047 [ 0.22512840995098554, 0.00379439656874946... \n", "351048 [ 0.04472568023866835, 0.436126151622446... \n", "351049 [ 0.0009405393288526446, 0.00244544190700176... \n", "351050 [ 0.9999736550716568, 5.7005018158771435e-0... " ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Take a slice to show a region with more entities.\n", "test_results_df.iloc[40:50]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute Precision and Recall\n", "\n", "With the model's predictions on the test set in hand, we can now compute precision and recall. We will use the following steps:\n", "\n", "1. Split the test set predictions by document, so that we can work at the document level.\n", "1. Join the test predictions with token information, producing one DataFrame per document.\n", "1. Convert each DataFrame from IOB2 format to (span, entity type) pairs, as done before.\n", "1. Compute precision, recall, and F1 for each document, collecting the results in a DataFrame.\n", "1. Aggregate the per-document counts to get overall precision, recall, and F1." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
token_idspanent_iobent_typepredicted_iobpredicted_type
4040[68, 70): 'di'IPERIMISC
4141[70, 71): 'm'IPERIMISC
4242[72, 74): 'La'IPERILOC
4343[74, 75): 'd'IPERIPER
4444[75, 77): 'ki'IPERILOC
4545[78, 80): 'AL'BLOCBLOC
4646[80, 81): '-'ILOCILOC
4747[81, 83): 'AI'ILOCILOC
4848[83, 84): 'N'ILOCILOC
4949[84, 85): ','O<NA>ONone
5050[86, 92): 'United'BLOCBLOC
5151[93, 97): 'Arab'ILOCILOC
5252[98, 106): 'Emirates'ILOCILOC
5353[107, 111): '1996'O<NA>ONone
5454[111, 112): '-'O<NA>ONone
5555[112, 114): '12'O<NA>ONone
5656[114, 115): '-'O<NA>ONone
5757[115, 117): '06'O<NA>ONone
5858[118, 123): 'Japan'BLOCBLOC
5959[124, 129): 'began'O<NA>ONone
\n", "
" ], "text/plain": [ " token_id span ent_iob ent_type predicted_iob \\\n", "40 40 [68, 70): 'di' I PER I \n", "41 41 [70, 71): 'm' I PER I \n", "42 42 [72, 74): 'La' I PER I \n", "43 43 [74, 75): 'd' I PER I \n", "44 44 [75, 77): 'ki' I PER I \n", "45 45 [78, 80): 'AL' B LOC B \n", "46 46 [80, 81): '-' I LOC I \n", "47 47 [81, 83): 'AI' I LOC I \n", "48 48 [83, 84): 'N' I LOC I \n", "49 49 [84, 85): ',' O O \n", "50 50 [86, 92): 'United' B LOC B \n", "51 51 [93, 97): 'Arab' I LOC I \n", "52 52 [98, 106): 'Emirates' I LOC I \n", "53 53 [107, 111): '1996' O O \n", "54 54 [111, 112): '-' O O \n", "55 55 [112, 114): '12' O O \n", "56 56 [114, 115): '-' O O \n", "57 57 [115, 117): '06' O O \n", "58 58 [118, 123): 'Japan' B LOC B \n", "59 59 [124, 129): 'began' O O \n", "\n", " predicted_type \n", "40 MISC \n", "41 MISC \n", "42 LOC \n", "43 PER \n", "44 LOC \n", "45 LOC \n", "46 LOC \n", "47 LOC \n", "48 LOC \n", "49 None \n", "50 LOC \n", "51 LOC \n", "52 LOC \n", "53 None \n", "54 None \n", "55 None \n", "56 None \n", "57 None \n", "58 LOC \n", "59 None " ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Split model outputs for an entire fold back into documents and add\n", "# token information.\n", "\n", "# Get unique documents per fold.\n", "fold_and_doc = test_results_df[[\"fold\", \"doc_num\"]] \\\n", " .drop_duplicates() \\\n", " .to_records(index=False)\n", "\n", "# Index by fold, doc and token id, then make sure sorted.\n", "indexed_df = test_results_df \\\n", " .set_index([\"fold\", \"doc_num\", \"token_id\"], verify_integrity=True) \\\n", " .sort_index()\n", "\n", "# Join predictions with token information, for each document.\n", "test_results_by_doc = {}\n", "for collection, doc_num in fold_and_doc:\n", " doc_slice = indexed_df.loc[collection, doc_num].reset_index()\n", " doc_toks = bert_toks_by_fold[collection][doc_num][\n", " [\"token_id\", \"span\", \"ent_iob\", \"ent_type\"]\n", " ].rename(columns={\"id\": 
\"token_id\"})\n", " joined_df = doc_toks.copy().merge(\n", " doc_slice[[\"token_id\", \"predicted_iob\", \"predicted_type\"]])\n", " test_results_by_doc[(collection, doc_num)] = joined_df\n", " \n", "# Test results are now in one DataFrame per document.\n", "test_results_by_doc[(\"test\", 0)].iloc[40:60]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spanent_type
0[19, 24): 'JAPAN'PER
1[29, 34): 'LUCKY'LOC
2[40, 45): 'CHINA'ORG
3[66, 77): 'Nadim Ladki'LOC
4[78, 84): 'AL-AIN'LOC
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [19, 24): 'JAPAN' PER\n", "1 [29, 34): 'LUCKY' LOC\n", "2 [40, 45): 'CHINA' ORG\n", "3 [66, 77): 'Nadim Ladki' LOC\n", "4 [78, 84): 'AL-AIN' LOC" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Convert IOB2 tags to (span, entity type) pairs with `tp.io.conll.iob_to_spans()`.\n", "test_actual_spans = {k: tp.io.conll.iob_to_spans(v) for k, v in test_results_by_doc.items()}\n", "test_model_spans = {k:\n", " tp.io.conll.iob_to_spans(v, iob_col_name = \"predicted_iob\",\n", " entity_type_col_name = \"predicted_type\")\n", " .rename(columns={\"predicted_type\": \"ent_type\"})\n", " for k, v in test_results_by_doc.items()}\n", "\n", "test_model_spans[(\"test\", 0)].head()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
folddoc_numnum_true_positivesnum_extractednum_entitiesprecisionrecallF1
0test04147450.8723400.9111110.891304
1test14142440.9761900.9318180.953488
2test25254540.9629630.9629630.962963
3test34244440.9545450.9545450.954545
4test41819190.9473680.9473680.947368
...........................
226test2266770.8571430.8571430.857143
227test2271819210.9473680.8571430.900000
228test2282428270.8571430.8888890.872727
229test2292527270.9259260.9259260.925926
230test2302527280.9259260.8928570.909091
\n", "

231 rows × 8 columns

\n", "
" ], "text/plain": [ " fold doc_num num_true_positives num_extracted num_entities \\\n", "0 test 0 41 47 45 \n", "1 test 1 41 42 44 \n", "2 test 2 52 54 54 \n", "3 test 3 42 44 44 \n", "4 test 4 18 19 19 \n", ".. ... ... ... ... ... \n", "226 test 226 6 7 7 \n", "227 test 227 18 19 21 \n", "228 test 228 24 28 27 \n", "229 test 229 25 27 27 \n", "230 test 230 25 27 28 \n", "\n", " precision recall F1 \n", "0 0.872340 0.911111 0.891304 \n", "1 0.976190 0.931818 0.953488 \n", "2 0.962963 0.962963 0.962963 \n", "3 0.954545 0.954545 0.954545 \n", "4 0.947368 0.947368 0.947368 \n", ".. ... ... ... \n", "226 0.857143 0.857143 0.857143 \n", "227 0.947368 0.857143 0.900000 \n", "228 0.857143 0.888889 0.872727 \n", "229 0.925926 0.925926 0.925926 \n", "230 0.925926 0.892857 0.909091 \n", "\n", "[231 rows x 8 columns]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute per-document statistics into a single DataFrame.\n", "test_stats_by_doc = tp.io.conll.compute_accuracy_by_document(test_actual_spans, test_model_spans)\n", "test_stats_by_doc" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'num_true_positives': 4881,\n", " 'num_entities': 5648,\n", " 'num_extracted': 5620,\n", " 'precision': 0.8685053380782918,\n", " 'recall': 0.8641997167138811,\n", " 'F1': 0.8663471778487754}" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Collection-wide precision and recall can be computed by aggregating\n", "# our DataFrame.\n", "tp.io.conll.compute_global_accuracy(test_stats_by_doc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Adjusting the BERT Model Output\n", "\n", "The above results aren't bad for a first shot, but a look at some of the predictions shows that a single token from the original corpus is sometimes split across multiple predicted entities. 
This is because the BERT tokenizer uses WordPiece to break words into subword tokens; see https://huggingface.co/transformers/tokenizer_summary.html and https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/37842.pdf for more information.\n", "\n", "This causes a problem when computing precision and recall, because we compare exact spans: if an entity is split, it is counted as a false negative _and_ possibly as one or more false positives. Luckily, we can fix this with Text Extensions for Pandas.\n", "\n", "Let's drill down to see an example of the issue and how to correct it." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spanent_type
0[11, 22): 'RUGBY UNION'ORG
1[24, 31): 'BRITISH'MISC
2[41, 47): 'LONDON'LOC
3[70, 77): 'British'MISC
4[111, 125): 'Pilkington Cup'MISC
5[139, 146): 'Reading'ORG
6[150, 151): 'W'ORG
7[151, 156): 'idnes'ORG
8[159, 166): 'English'MISC
9[180, 184): 'Bath'ORG
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [11, 22): 'RUGBY UNION' ORG\n", "1 [24, 31): 'BRITISH' MISC\n", "2 [41, 47): 'LONDON' LOC\n", "3 [70, 77): 'British' MISC\n", "4 [111, 125): 'Pilkington Cup' MISC\n", "5 [139, 146): 'Reading' ORG\n", "6 [150, 151): 'W' ORG\n", "7 [151, 156): 'idnes' ORG\n", "8 [159, 166): 'English' MISC\n", "9 [180, 184): 'Bath' ORG" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Every once in a while, the BERT model will split a token in the original data\n", "# set into multiple entities. For example, look at document 202 of the test set:\n", "test_model_spans[(\"test\", 202)].head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice `[150, 151): 'W'` and `[151, 156): 'idnes'`. These outputs are part\n", "of the same original token, but have been split by the model." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spancorpus_tokenent_type
0[11, 22): 'RUGBY UNION'[11, 16): 'RUGBY'ORG
1[11, 22): 'RUGBY UNION'[17, 22): 'UNION'ORG
2[24, 31): 'BRITISH'[24, 31): 'BRITISH'MISC
3[41, 47): 'LONDON'[41, 47): 'LONDON'LOC
4[70, 77): 'British'[70, 77): 'British'MISC
5[111, 125): 'Pilkington Cup'[111, 121): 'Pilkington'MISC
6[111, 125): 'Pilkington Cup'[122, 125): 'Cup'MISC
7[139, 146): 'Reading'[139, 146): 'Reading'ORG
8[150, 151): 'W'[150, 156): 'Widnes'ORG
9[151, 156): 'idnes'[150, 156): 'Widnes'ORG
\n", "
" ], "text/plain": [ " span corpus_token ent_type\n", "0 [11, 22): 'RUGBY UNION' [11, 16): 'RUGBY' ORG\n", "1 [11, 22): 'RUGBY UNION' [17, 22): 'UNION' ORG\n", "2 [24, 31): 'BRITISH' [24, 31): 'BRITISH' MISC\n", "3 [41, 47): 'LONDON' [41, 47): 'LONDON' LOC\n", "4 [70, 77): 'British' [70, 77): 'British' MISC\n", "5 [111, 125): 'Pilkington Cup' [111, 121): 'Pilkington' MISC\n", "6 [111, 125): 'Pilkington Cup' [122, 125): 'Cup' MISC\n", "7 [139, 146): 'Reading' [139, 146): 'Reading' ORG\n", "8 [150, 151): 'W' [150, 156): 'Widnes' ORG\n", "9 [151, 156): 'idnes' [150, 156): 'Widnes' ORG" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We can use spanner algebra in `tp.spanner.overlap_join()`\n", "# to fix up these outputs.\n", "spans_df = test_model_spans[(\"test\", 202)]\n", "toks_df = test_raw[202]\n", "\n", "# First, find which tokens the spans overlap with:\n", "overlaps_df = (\n", " tp.spanner.overlap_join(spans_df[\"span\"], toks_df[\"span\"],\n", " \"span\", \"corpus_token\")\n", " .merge(spans_df)\n", ")\n", "overlaps_df.head(10)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spancorpus_tokenent_type
0[11, 22): 'RUGBY UNION'[11, 22): 'RUGBY UNION'ORG
1[24, 31): 'BRITISH'[24, 31): 'BRITISH'MISC
2[41, 47): 'LONDON'[41, 47): 'LONDON'LOC
3[70, 77): 'British'[70, 77): 'British'MISC
4[111, 125): 'Pilkington Cup'[111, 125): 'Pilkington Cup'MISC
5[139, 146): 'Reading'[139, 146): 'Reading'ORG
6[150, 151): 'W'[150, 156): 'Widnes'ORG
7[151, 156): 'idnes'[150, 156): 'Widnes'ORG
8[159, 166): 'English'[159, 166): 'English'MISC
9[180, 184): 'Bath'[180, 184): 'Bath'ORG
\n", "
" ], "text/plain": [ " span corpus_token ent_type\n", "0 [11, 22): 'RUGBY UNION' [11, 22): 'RUGBY UNION' ORG\n", "1 [24, 31): 'BRITISH' [24, 31): 'BRITISH' MISC\n", "2 [41, 47): 'LONDON' [41, 47): 'LONDON' LOC\n", "3 [70, 77): 'British' [70, 77): 'British' MISC\n", "4 [111, 125): 'Pilkington Cup' [111, 125): 'Pilkington Cup' MISC\n", "5 [139, 146): 'Reading' [139, 146): 'Reading' ORG\n", "6 [150, 151): 'W' [150, 156): 'Widnes' ORG\n", "7 [151, 156): 'idnes' [150, 156): 'Widnes' ORG\n", "8 [159, 166): 'English' [159, 166): 'English' MISC\n", "9 [180, 184): 'Bath' [180, 184): 'Bath' ORG" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Next, compute the minimum span that covers all the corpus tokens\n", "# that overlap with each entity span.\n", "agg_df = (\n", " overlaps_df\n", " .groupby(\"span\")\n", " .aggregate({\"corpus_token\": \"sum\", \"ent_type\": \"first\"})\n", " .reset_index()\n", ")\n", "agg_df.head(10)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spanent_type
0[11, 22): 'RUGBY UNION'ORG
1[24, 31): 'BRITISH'MISC
2[41, 47): 'LONDON'LOC
3[70, 77): 'British'MISC
4[111, 125): 'Pilkington Cup'MISC
5[139, 146): 'Reading'ORG
6[150, 156): 'Widnes'ORG
8[159, 166): 'English'MISC
9[180, 184): 'Bath'ORG
10[188, 198): 'Harlequins'ORG
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [11, 22): 'RUGBY UNION' ORG\n", "1 [24, 31): 'BRITISH' MISC\n", "2 [41, 47): 'LONDON' LOC\n", "3 [70, 77): 'British' MISC\n", "4 [111, 125): 'Pilkington Cup' MISC\n", "5 [139, 146): 'Reading' ORG\n", "6 [150, 156): 'Widnes' ORG\n", "8 [159, 166): 'English' MISC\n", "9 [180, 184): 'Bath' ORG\n", "10 [188, 198): 'Harlequins' ORG" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Finally, take unique values and convert character-based spans to token\n", "# spans in the corpus tokenization (since the new offsets might not match a\n", "# BERT tokenizer token boundary).\n", "cons_df = (\n", " tp.spanner.consolidate(agg_df, \"corpus_token\")[[\"corpus_token\", \"ent_type\"]]\n", " .rename(columns={\"corpus_token\": \"span\"})\n", ")\n", "cons_df[\"span\"] = tp.TokenSpanArray.align_to_tokens(toks_df[\"span\"],\n", " cons_df[\"span\"])\n", "cons_df.head(10)" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spanent_type
0[11, 22): 'RUGBY UNION'ORG
1[24, 31): 'BRITISH'MISC
2[41, 47): 'LONDON'LOC
3[70, 77): 'British'MISC
4[111, 125): 'Pilkington Cup'MISC
5[139, 146): 'Reading'ORG
6[150, 156): 'Widnes'ORG
8[159, 166): 'English'MISC
9[180, 184): 'Bath'ORG
10[188, 198): 'Harlequins'ORG
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [11, 22): 'RUGBY UNION' ORG\n", "1 [24, 31): 'BRITISH' MISC\n", "2 [41, 47): 'LONDON' LOC\n", "3 [70, 77): 'British' MISC\n", "4 [111, 125): 'Pilkington Cup' MISC\n", "5 [139, 146): 'Reading' ORG\n", "6 [150, 156): 'Widnes' ORG\n", "8 [159, 166): 'English' MISC\n", "9 [180, 184): 'Bath' ORG\n", "10 [188, 198): 'Harlequins' ORG" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Text Extensions for Pandas contains a single function that repeats the actions of the \n", "# previous 3 cells.\n", "tp.io.bert.align_bert_tokens_to_corpus_tokens(test_model_spans[(\"test\", 202)], test_raw[202]).head(10)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "68bcf2a129584acf9b7ce27ccd66302a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=231, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
spanent_type
0[11, 22): 'RUGBY UNION'ORG
1[24, 31): 'BRITISH'MISC
2[41, 47): 'LONDON'LOC
3[70, 77): 'British'MISC
4[111, 125): 'Pilkington Cup'MISC
5[139, 146): 'Reading'ORG
6[150, 156): 'Widnes'ORG
8[159, 166): 'English'MISC
9[180, 184): 'Bath'ORG
10[188, 198): 'Harlequins'ORG
\n", "
" ], "text/plain": [ " span ent_type\n", "0 [11, 22): 'RUGBY UNION' ORG\n", "1 [24, 31): 'BRITISH' MISC\n", "2 [41, 47): 'LONDON' LOC\n", "3 [70, 77): 'British' MISC\n", "4 [111, 125): 'Pilkington Cup' MISC\n", "5 [139, 146): 'Reading' ORG\n", "6 [150, 156): 'Widnes' ORG\n", "8 [159, 166): 'English' MISC\n", "9 [180, 184): 'Bath' ORG\n", "10 [188, 198): 'Harlequins' ORG" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Run all of our DataFrames through `align_bert_tokens_to_corpus_tokens()`.\n", "keys = list(test_model_spans.keys())\n", "new_values = tp.jupyter.run_with_progress_bar(\n", " len(keys), \n", " lambda i: tp.io.bert.align_bert_tokens_to_corpus_tokens(test_model_spans[keys[i]], test_raw[keys[i][1]]))\n", "test_model_spans = {k: v for k, v in zip(keys, new_values)}\n", "test_model_spans[(\"test\", 202)].head(10)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
folddoc_numnum_true_positivesnum_extractednum_entitiesprecisionrecallF1
0test04247450.8936170.9333330.913043
1test14142440.9761900.9318180.953488
2test25254540.9629630.9629630.962963
3test34244440.9545450.9545450.954545
4test41819190.9473680.9473680.947368
...........................
226test2267771.0000001.0000001.000000
227test2271819210.9473680.8571430.900000
228test2282427270.8888890.8888890.888889
229test2292627270.9629630.9629630.962963
230test2302627280.9629630.9285710.945455
\n", "

231 rows × 8 columns

\n", "
" ], "text/plain": [ " fold doc_num num_true_positives num_extracted num_entities \\\n", "0 test 0 42 47 45 \n", "1 test 1 41 42 44 \n", "2 test 2 52 54 54 \n", "3 test 3 42 44 44 \n", "4 test 4 18 19 19 \n", ".. ... ... ... ... ... \n", "226 test 226 7 7 7 \n", "227 test 227 18 19 21 \n", "228 test 228 24 27 27 \n", "229 test 229 26 27 27 \n", "230 test 230 26 27 28 \n", "\n", " precision recall F1 \n", "0 0.893617 0.933333 0.913043 \n", "1 0.976190 0.931818 0.953488 \n", "2 0.962963 0.962963 0.962963 \n", "3 0.954545 0.954545 0.954545 \n", "4 0.947368 0.947368 0.947368 \n", ".. ... ... ... \n", "226 1.000000 1.000000 1.000000 \n", "227 0.947368 0.857143 0.900000 \n", "228 0.888889 0.888889 0.888889 \n", "229 0.962963 0.962963 0.962963 \n", "230 0.962963 0.928571 0.945455 \n", "\n", "[231 rows x 8 columns]" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Compute per-document statistics into a single DataFrame.\n", "test_stats_by_doc = tp.io.conll.compute_accuracy_by_document(test_actual_spans, test_model_spans)\n", "test_stats_by_doc" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'num_true_positives': 4971,\n", " 'num_entities': 5648,\n", " 'num_extracted': 5587,\n", " 'precision': 0.889744048684446,\n", " 'recall': 0.8801345609065155,\n", " 'F1': 0.8849132176234981}" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Collection-wide precision and recall can be computed by aggregating\n", "# our DataFrame.\n", "tp.io.conll.compute_global_accuracy(test_stats_by_doc)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These results are a bit better than before, and while the F1 score is not high compared to todays standards, it is decent enough for a simplistic model. 
More importantly, we showed that it is fairly easy to create a model for named entity recognition and analyze its output by combining Pandas DataFrames with the `SpanArray` and `TensorArray` types from [Text Extensions for Pandas](https://github.com/CODAIT/text-extensions-for-pandas) and its integration with BERT from [Huggingface Transformers](https://huggingface.co/transformers/index.html)." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 4 }