{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# CoNLL_4.ipynb\n", "\n", "This notebook contains the fourth part of the model training and analysis code from our CoNLL-2020 paper, [\"Identifying Incorrect Labels in the CoNLL-2003 Corpus\"](https://www.aclweb.org/anthology/2020.conll-1.16/).\n", "\n", "If you're new to the Text Extensions for Pandas library, we recommend that you start\n", "by reading through the notebook [`Analyze_Model_Outputs.ipynb`](https://github.com/CODAIT/text-extensions-for-pandas/blob/master/notebooks/Analyze_Model_Outputs.ipynb), which explains the \n", "portions of the library that we use in the notebooks in this directory.\n", "\n", "### Summary\n", "\n", "This notebook repeats the model training process from `CoNLL_3.ipynb`, but performs a 10-fold cross-validation. This process involves training a total of 170 models -- 10 groups of 17. Next, this notebook evaluates each group of models over the holdout set from the associated fold of the cross-validation. Then it aggregates together these outputs and uses the same techniques used in `CoNLL_2.ipynb` to flag potentially-incorrect labels. Finally, the notebook writes out CSV files containing ranked lists of potentially-incorrect labels.\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Libraries and constants" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "# Libraries\n", "import numpy as np\n", "import pandas as pd\n", "import os\n", "import sys\n", "import time\n", "import torch\n", "import transformers\n", "from typing import *\n", "import sklearn.model_selection\n", "import sklearn.pipeline\n", "import matplotlib.pyplot as plt\n", "import multiprocessing\n", "import gc\n", "\n", "# And of course we need the text_extensions_for_pandas library itself.\n", "try:\n", " import text_extensions_for_pandas as tp\n", "except ModuleNotFoundError as e:\n", " raise Exception(\"text_extensions_for_pandas package not found on the Jupyter \"\n", " \"kernel's path. Please either run:\\n\"\n", " \" ln -s ../../text_extensions_for_pandas .\\n\"\n", " \"from the directory containing this notebook, or use a Python \"\n", " \"environment on which you have used `pip` to install the package.\")\n", "\n", "from text_extensions_for_pandas import cleaning\n", " \n", "# BERT Configuration\n", "# Keep this in sync with `CoNLL_3.ipynb`.\n", "#bert_model_name = \"bert-base-uncased\"\n", "#bert_model_name = \"bert-large-uncased\"\n", "bert_model_name = \"dslim/bert-base-NER\"\n", "tokenizer = transformers.BertTokenizerFast.from_pretrained(bert_model_name, \n", " add_special_tokens=True)\n", "bert = transformers.BertModel.from_pretrained(bert_model_name)\n", "\n", "# If False, use cached values, provided those values are present on disk\n", "_REGENERATE_EMBEDDINGS = True\n", "_REGENERATE_MODELS = True\n", "\n", "# Number of dimensions that we reduce the BERT embeddings down to when\n", "# training reduced-quality models.\n", "#_REDUCED_DIMS = [8, 16, 32, 64, 128, 256]\n", "_REDUCED_DIMS = [32, 64, 128, 256]\n", "\n", "# How many models we train at each level of dimensionality reduction\n", "_MODELS_AT_DIM = [4] * len(_REDUCED_DIMS)\n", "\n", "# Consistent set of random seeds to use when generating dimension-reduced\n", "# models. Index is [index into _REDUCED_DIMS, model number], and there are\n", "# lots of extra entries so we don't need to resize this matrix.\n", "from numpy.random import default_rng\n", "_MASTER_SEED = 42\n", "rng = default_rng(_MASTER_SEED)\n", "_MODEL_RANDOM_SEEDS = rng.integers(0, 1e6, size=(8, 8))\n", "\n", "# Create a Pandas categorical type for consistent encoding of categories\n", "# across all documents.\n", "_ENTITY_TYPES = [\"LOC\", \"MISC\", \"ORG\", \"PER\"]\n", "token_class_dtype, int_to_label, label_to_int = tp.io.conll.make_iob_tag_categories(_ENTITY_TYPES)\n", "\n", "# Parameters for splitting the corpus into folds\n", "_KFOLD_RANDOM_SEED = _MASTER_SEED\n", "_KFOLD_NUM_FOLDS = 10\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Read inputs\n", "\n", "Read in the corpus, retokenize it with the BERT tokenizer, add BERT embeddings, and convert\n", "to a single dataframe." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': 'outputs/eng.train',\n", " 'dev': 'outputs/eng.testa',\n", " 'test': 'outputs/eng.testb'}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Download and cache the data set.\n", "# NOTE: This data set is licensed for research use only. Be sure to adhere\n", "# to the terms of the license when using this data set!\n", "data_set_info = tp.io.conll.maybe_download_conll_data(\"outputs\")\n", "data_set_info" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# The raw dataset in its original tokenization\n", "corpus_raw = {}\n", "for fold_name, file_name in data_set_info.items():\n", " df_list = tp.io.conll.conll_2003_to_dataframes(file_name, \n", " [\"pos\", \"phrase\", \"ent\"],\n", " [False, True, True])\n", " corpus_raw[fold_name] = [\n", " df.drop(columns=[\"pos\", \"phrase_iob\", \"phrase_type\"])\n", " for df in df_list\n", " ]" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing fold train\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3e030cae2b3b4c898bfc57ba945740ab", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=946, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (559 > 512). Running this sequence through the model will result in indexing errors\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing fold dev\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "55361e25bc62408082d864b813e9ab61", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=216, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "preprocessing fold test\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f9753feeace94e65a24961292d4420db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=231, style=ProgressStyle(desc…" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Retokenize with the BERT tokenizer and regenerate embeddings.\n", "corpus_df,token_class_dtype, int_to_label, label_to_int = cleaning.preprocess.preprocess_documents(corpus_raw,'ent_type',True,carry_cols=['line_num'],iob_col='ent_iob')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Prepare folds for a 10-fold cross-validation\n", "\n", "We divide the documents of the corpus into 10 random samples." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>train</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>train</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>train</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>train</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>train</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>1388</th>\n", " <td>test</td>\n", " <td>226</td>\n", " </tr>\n", " <tr>\n", " <th>1389</th>\n", " <td>test</td>\n", " <td>227</td>\n", " </tr>\n", " <tr>\n", " <th>1390</th>\n", " <td>test</td>\n", " <td>228</td>\n", " </tr>\n", " <tr>\n", " <th>1391</th>\n", " <td>test</td>\n", " <td>229</td>\n", " </tr>\n", " <tr>\n", " <th>1392</th>\n", " <td>test</td>\n", " <td>230</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>1393 rows × 2 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num\n", "0 train 0\n", "1 train 1\n", "2 train 2\n", "3 train 3\n", "4 train 4\n", "... ... ...\n", "1388 test 226\n", "1389 test 227\n", "1390 test 228\n", "1391 test 229\n", "1392 test 230\n", "\n", "[1393 rows x 2 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# IDs for each of the keys\n", "doc_keys = corpus_df[[\"fold\", \"doc_num\"]].drop_duplicates().reset_index(drop=True)\n", "doc_keys" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>146</th>\n", " <td>train</td>\n", " <td>146</td>\n", " </tr>\n", " <tr>\n", " <th>1164</th>\n", " <td>test</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>483</th>\n", " <td>train</td>\n", " <td>483</td>\n", " </tr>\n", " <tr>\n", " <th>1190</th>\n", " <td>test</td>\n", " <td>28</td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>train</td>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>237</th>\n", " <td>train</td>\n", " <td>237</td>\n", " </tr>\n", " <tr>\n", " <th>86</th>\n", " <td>train</td>\n", " <td>86</td>\n", " </tr>\n", " <tr>\n", " <th>408</th>\n", " <td>train</td>\n", " <td>408</td>\n", " </tr>\n", " <tr>\n", " <th>1252</th>\n", " <td>test</td>\n", " <td>90</td>\n", " </tr>\n", " <tr>\n", " <th>1213</th>\n", " <td>test</td>\n", " <td>51</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " fold doc_num\n", "146 train 146\n", "1164 test 2\n", "483 train 483\n", "1190 test 28\n", "20 train 20\n", "237 train 237\n", "86 train 86\n", "408 train 408\n", "1252 test 90\n", "1213 test 51" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# We want to split the documents randomly into _NUM_FOLDS sets, then\n", "# for each stage of cross-validation train a model on the union of\n", "# (_NUM_FOLDS - 1) of them while testing on the remaining fold.\n", "# sklearn.model_selection doesn't implement this approach directly,\n", "# but we can piece it together with some help from Numpy.\n", "#from numpy.random import default_rng\n", "rng = np.random.default_rng(seed=_KFOLD_RANDOM_SEED)\n", "iloc_order = rng.permutation(len(doc_keys.index))\n", "kf = sklearn.model_selection.KFold(n_splits=_KFOLD_NUM_FOLDS)\n", "\n", "train_keys = []\n", "test_keys = []\n", "for train_ix, test_ix in kf.split(iloc_order):\n", " # sklearn.model_selection.KFold gives us a partitioning of the\n", " # numbers from 0 to len(iloc_order). Use that partitioning to \n", " # choose elements from iloc_order, then use those elements to \n", " # index into doc_keys.\n", " train_iloc = iloc_order[train_ix]\n", " test_iloc = iloc_order[test_ix]\n", " train_keys.append(doc_keys.iloc[train_iloc])\n", " test_keys.append(doc_keys.iloc[test_iloc])\n", "\n", "train_keys[1].head(10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Dry run: Train and evaluate models on the first fold\n", "\n", "Train models on the first of our 10 folds and manually examine some of the \n", "model outputs." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>token_id</th>\n", " <th>span</th>\n", " <th>input_id</th>\n", " <th>token_type_id</th>\n", " <th>attention_mask</th>\n", " <th>special_tokens_mask</th>\n", " <th>raw_span</th>\n", " <th>line_num</th>\n", " <th>raw_span_id</th>\n", " <th>ent_iob</th>\n", " <th>ent_type</th>\n", " <th>embedding</th>\n", " <th>token_class</th>\n", " <th>token_class_id</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>train</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>[0, 0): ''</td>\n", " <td>101</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>True</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.098505184, -0.4050192, 0.7428884...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>train</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>[0, 1): '-'</td>\n", " <td>118</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.057021223, -0.48112097, 0.989868...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>train</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>[1, 2): 'D'</td>\n", " <td>141</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.04824195, -0.25330004, 1.167191...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>train</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>[2, 4): 'OC'</td>\n", " <td>9244</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.26682988, -0.31008753, 1.007472...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>train</td>\n", " <td>0</td>\n", " <td>4</td>\n", " <td>[4, 6): 'ST'</td>\n", " <td>9272</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>0.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.22296889, -0.21308492, 0.9331016...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>371472</th>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>314</td>\n", " <td>[1386, 1393): 'brother'</td>\n", " <td>1711</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[1386, 1393): 'brother'</td>\n", " <td>50345.0</td>\n", " <td>267.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.028172785, -0.08062388, 0.9804888...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>371473</th>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>315</td>\n", " <td>[1393, 1394): ','</td>\n", " <td>117</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[1393, 1394): ','</td>\n", " <td>50346.0</td>\n", " <td>268.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ 0.11817408, -0.07008513, 0.865484...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>371474</th>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>316</td>\n", " <td>[1395, 1400): 'Bobby'</td>\n", " <td>5545</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[1395, 1400): 'Bobby'</td>\n", " <td>50347.0</td>\n", " <td>269.0</td>\n", " <td>B</td>\n", " <td>PER</td>\n", " <td>[ -0.35689482, 0.31400457, 1.573853...</td>\n", " <td>B-PER</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>371475</th>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>317</td>\n", " <td>[1400, 1401): '.'</td>\n", " <td>119</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[1400, 1401): '.'</td>\n", " <td>50348.0</td>\n", " <td>270.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.18957126, -0.24581163, 0.66257...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>371476</th>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>318</td>\n", " <td>[0, 0): ''</td>\n", " <td>102</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>True</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.44689128, -0.31665266, 0.779688...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>371477 rows × 16 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num token_id span input_id \\\n", "0 train 0 0 [0, 0): '' 101 \n", "1 train 0 1 [0, 1): '-' 118 \n", "2 train 0 2 [1, 2): 'D' 141 \n", "3 train 0 3 [2, 4): 'OC' 9244 \n", "4 train 0 4 [4, 6): 'ST' 9272 \n", "... ... ... ... ... ... \n", "371472 test 230 314 [1386, 1393): 'brother' 1711 \n", "371473 test 230 315 [1393, 1394): ',' 117 \n", "371474 test 230 316 [1395, 1400): 'Bobby' 5545 \n", "371475 test 230 317 [1400, 1401): '.' 119 \n", "371476 test 230 318 [0, 0): '' 102 \n", "\n", " token_type_id attention_mask special_tokens_mask \\\n", "0 0 1 True \n", "1 0 1 False \n", "2 0 1 False \n", "3 0 1 False \n", "4 0 1 False \n", "... ... ... ... \n", "371472 0 1 False \n", "371473 0 1 False \n", "371474 0 1 False \n", "371475 0 1 False \n", "371476 0 1 True \n", "\n", " raw_span line_num raw_span_id ent_iob ent_type \\\n", "0 NaN NaN NaN O <NA> \n", "1 [0, 10): '-DOCSTART-' 0.0 0.0 O <NA> \n", "2 [0, 10): '-DOCSTART-' 0.0 0.0 O <NA> \n", "3 [0, 10): '-DOCSTART-' 0.0 0.0 O <NA> \n", "4 [0, 10): '-DOCSTART-' 0.0 0.0 O <NA> \n", "... ... ... ... ... ... \n", "371472 [1386, 1393): 'brother' 50345.0 267.0 O <NA> \n", "371473 [1393, 1394): ',' 50346.0 268.0 O <NA> \n", "371474 [1395, 1400): 'Bobby' 50347.0 269.0 B PER \n", "371475 [1400, 1401): '.' 50348.0 270.0 O <NA> \n", "371476 NaN NaN NaN O <NA> \n", "\n", " embedding token_class \\\n", "0 [ -0.098505184, -0.4050192, 0.7428884... O \n", "1 [ -0.057021223, -0.48112097, 0.989868... O \n", "2 [ -0.04824195, -0.25330004, 1.167191... O \n", "3 [ -0.26682988, -0.31008753, 1.007472... O \n", "4 [ -0.22296889, -0.21308492, 0.9331016... O \n", "... ... ... \n", "371472 [ -0.028172785, -0.08062388, 0.9804888... O \n", "371473 [ 0.11817408, -0.07008513, 0.865484... O \n", "371474 [ -0.35689482, 0.31400457, 1.573853... B-PER \n", "371475 [ -0.18957126, -0.24581163, 0.66257... O \n", "371476 [ -0.44689128, -0.31665266, 0.779688... O \n", "\n", " token_class_id \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", "... ... \n", "371472 0 \n", "371473 0 \n", "371474 3 \n", "371475 0 \n", "371476 0 \n", "\n", "[371477 rows x 16 columns]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Gather the training set together by joining our list of documents\n", "# with the entire corpus on the composite key <fold, doc_num>\n", "train_inputs_df = corpus_df.merge(train_keys[0])\n", "train_inputs_df" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>token_id</th>\n", " <th>span</th>\n", " <th>input_id</th>\n", " <th>token_type_id</th>\n", " <th>attention_mask</th>\n", " <th>special_tokens_mask</th>\n", " <th>raw_span</th>\n", " <th>line_num</th>\n", " <th>raw_span_id</th>\n", " <th>ent_iob</th>\n", " <th>ent_type</th>\n", " <th>embedding</th>\n", " <th>token_class</th>\n", " <th>token_class_id</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>train</td>\n", " <td>12</td>\n", " <td>0</td>\n", " <td>[0, 0): ''</td>\n", " <td>101</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>True</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.101977676, -0.42442498, 0.8440171...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>train</td>\n", " <td>12</td>\n", " <td>1</td>\n", " <td>[0, 1): '-'</td>\n", " <td>118</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>2664.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.09124618, -0.47710702, 1.120292...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>train</td>\n", " <td>12</td>\n", " <td>2</td>\n", " <td>[1, 2): 'D'</td>\n", " <td>141</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>2664.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.1695277, -0.27063507, 1.209566...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>train</td>\n", " <td>12</td>\n", " <td>3</td>\n", " <td>[2, 4): 'OC'</td>\n", " <td>9244</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>2664.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.27648172, -0.3675844, 1.092024...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>train</td>\n", " <td>12</td>\n", " <td>4</td>\n", " <td>[4, 6): 'ST'</td>\n", " <td>9272</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>2664.0</td>\n", " <td>0.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.24050614, -0.24247544, 1.07511...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>45059</th>\n", " <td>test</td>\n", " <td>225</td>\n", " <td>75</td>\n", " <td>[208, 213): 'fight'</td>\n", " <td>2147</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[208, 213): 'fight'</td>\n", " <td>49418.0</td>\n", " <td>29.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.09621397, -0.48016888, 0.510937...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>45060</th>\n", " <td>test</td>\n", " <td>225</td>\n", " <td>76</td>\n", " <td>[214, 216): 'on'</td>\n", " <td>1113</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[214, 216): 'on'</td>\n", " <td>49419.0</td>\n", " <td>30.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.0858628, -0.2341724, 0.832928...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>45061</th>\n", " <td>test</td>\n", " <td>225</td>\n", " <td>77</td>\n", " <td>[217, 225): 'Saturday'</td>\n", " <td>4306</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[217, 225): 'Saturday'</td>\n", " <td>49420.0</td>\n", " <td>31.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.012238501, -0.4282664, 0.619483...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>45062</th>\n", " <td>test</td>\n", " <td>225</td>\n", " <td>78</td>\n", " <td>[225, 226): '.'</td>\n", " <td>119</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>False</td>\n", " <td>[225, 226): '.'</td>\n", " <td>49421.0</td>\n", " <td>32.0</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.042955935, -0.36315423, 0.660203...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>45063</th>\n", " <td>test</td>\n", " <td>225</td>\n", " <td>79</td>\n", " <td>[0, 0): ''</td>\n", " <td>102</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>True</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>NaN</td>\n", " <td>O</td>\n", " <td><NA></td>\n", " <td>[ -0.9504192, 0.012983555, 0.7374987...</td>\n", " <td>O</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>45064 rows × 16 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num token_id span input_id \\\n", "0 train 12 0 [0, 0): '' 101 \n", "1 train 12 1 [0, 1): '-' 118 \n", "2 train 12 2 [1, 2): 'D' 141 \n", "3 train 12 3 [2, 4): 'OC' 9244 \n", "4 train 12 4 [4, 6): 'ST' 9272 \n", "... ... ... ... ... ... \n", "45059 test 225 75 [208, 213): 'fight' 2147 \n", "45060 test 225 76 [214, 216): 'on' 1113 \n", "45061 test 225 77 [217, 225): 'Saturday' 4306 \n", "45062 test 225 78 [225, 226): '.' 119 \n", "45063 test 225 79 [0, 0): '' 102 \n", "\n", " token_type_id attention_mask special_tokens_mask \\\n", "0 0 1 True \n", "1 0 1 False \n", "2 0 1 False \n", "3 0 1 False \n", "4 0 1 False \n", "... ... ... ... \n", "45059 0 1 False \n", "45060 0 1 False \n", "45061 0 1 False \n", "45062 0 1 False \n", "45063 0 1 True \n", "\n", " raw_span line_num raw_span_id ent_iob ent_type \\\n", "0 NaN NaN NaN O <NA> \n", "1 [0, 10): '-DOCSTART-' 2664.0 0.0 O <NA> \n", "2 [0, 10): '-DOCSTART-' 2664.0 0.0 O <NA> \n", "3 [0, 10): '-DOCSTART-' 2664.0 0.0 O <NA> \n", "4 [0, 10): '-DOCSTART-' 2664.0 0.0 O <NA> \n", "... ... ... ... ... ... \n", "45059 [208, 213): 'fight' 49418.0 29.0 O <NA> \n", "45060 [214, 216): 'on' 49419.0 30.0 O <NA> \n", "45061 [217, 225): 'Saturday' 49420.0 31.0 O <NA> \n", "45062 [225, 226): '.' 49421.0 32.0 O <NA> \n", "45063 NaN NaN NaN O <NA> \n", "\n", " embedding token_class \\\n", "0 [ -0.101977676, -0.42442498, 0.8440171... O \n", "1 [ -0.09124618, -0.47710702, 1.120292... O \n", "2 [ -0.1695277, -0.27063507, 1.209566... O \n", "3 [ -0.27648172, -0.3675844, 1.092024... O \n", "4 [ -0.24050614, -0.24247544, 1.07511... O \n", "... ... ... \n", "45059 [ -0.09621397, -0.48016888, 0.510937... O \n", "45060 [ -0.0858628, -0.2341724, 0.832928... O \n", "45061 [ -0.012238501, -0.4282664, 0.619483... O \n", "45062 [ -0.042955935, -0.36315423, 0.660203... O \n", "45063 [ -0.9504192, 0.012983555, 0.7374987... O \n", "\n", " token_class_id \n", "0 0 \n", "1 0 \n", "2 0 \n", "3 0 \n", "4 0 \n", "... ... \n", "45059 0 \n", "45060 0 \n", "45061 0 \n", "45062 0 \n", "45063 0 \n", "\n", "[45064 rows x 16 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Repeat the same process for the test set\n", "test_inputs_df = corpus_df.merge(test_keys[0])\n", "test_inputs_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train an ensemble of models" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2021-07-12 18:16:33,117\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n", "Model names after loading or training: 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64_2, 64_3, 64_4, 128_1, 128_2, 128_3, 128_4, 256_1, 256_2, 256_3, 256_4\n" ] } ], "source": [ "import importlib\n", "import sklearn.linear_model\n", "import ray\n", "ray.init()\n", "\n", "# Wrap train_reduced_model in a Ray task\n", "@ray.remote\n", "def train_reduced_model_task(\n", " x_values: np.ndarray, y_values: np.ndarray, n_components: int,\n", " seed: int, max_iter: int = 10000) -> sklearn.base.BaseEstimator:\n", " return cleaning.ensemble.train_reduced_model(x_values, y_values, n_components, seed, max_iter)\n", "\n", "# Ray task that trains a model using the entire embedding\n", "@ray.remote\n", "def train_full_model_task(x_values: np.ndarray, y_values: np.ndarray, \n", " max_iter: int = 10000) -> sklearn.base.BaseEstimator:\n", " return (\n", " sklearn.linear_model.LogisticRegression(\n", " multi_class=\"multinomial\", max_iter=max_iter\n", " )\n", " .fit(x_values, y_values)\n", " )\n", "\n", "def train_models(train_df: pd.DataFrame) \\\n", " -> Dict[str, sklearn.base.BaseEstimator]:\n", " \"\"\"\n", " Train an ensemble of models with different levels of noise.\n", " \n", " :param train_df: DataFrame of labeled training documents, with one\n", " row per token. Must contain the columns \"embedding\" (precomputed \n", " BERT embeddings) and \"token_class_id\" (integer ID of token type)\n", " \n", " :returns: A mapping from mnemonic model name to trained model\n", " \"\"\"\n", " X = train_df[\"embedding\"].values\n", " Y = train_df[\"token_class_id\"]\n", " \n", "\n", " # Push the X and Y values to Plasma so that our tasks can share them.\n", " X_id = ray.put(X.to_numpy().copy())\n", " Y_id = ray.put(Y.to_numpy().copy())\n", " \n", " names_list = []\n", " futures_list = []\n", " \n", " print(f\"Training model using all of \"\n", " f\"{X._tensor.shape[1]}-dimension embeddings.\")\n", " names_list.append(f\"{X._tensor.shape[1]}_1\")\n", " futures_list.append(train_full_model_task.remote(X_id, Y_id)) \n", " \n", " for i in range(len(_REDUCED_DIMS)):\n", " num_dims = _REDUCED_DIMS[i]\n", " num_models = _MODELS_AT_DIM[i]\n", " for j in range(num_models):\n", " model_name = f\"{num_dims}_{j + 1}\"\n", " seed = _MODEL_RANDOM_SEEDS[i, j]\n", " print(f\"Training model '{model_name}' (#{j + 1} \"\n", " f\"at {num_dims} dimensions) with seed {seed}\")\n", " names_list.append(model_name)\n", " futures_list.append(train_reduced_model_task.remote(X_id, Y_id, \n", " num_dims, seed))\n", " \n", " # Block until all training tasks have completed and fetch the resulting models.\n", " models_list = ray.get(futures_list)\n", " models = {\n", " n: m for n, m in zip(names_list, models_list)\n", " }\n", " return models\n", "\n", "def maybe_train_models(train_df: pd.DataFrame, fold_num: int):\n", " import pickle\n", " _CACHED_MODELS_FILE = f\"outputs/fold_{fold_num}_models.pickle\"\n", " if _REGENERATE_MODELS or not os.path.exists(_CACHED_MODELS_FILE):\n", " m = train_models(train_df)\n", " print(f\"Trained {len(m)} models.\")\n", " with open(_CACHED_MODELS_FILE, \"wb\") as f:\n", " pickle.dump(m, f)\n", " else:\n", " # Use a cached model when using cached embeddings\n", " with open(_CACHED_MODELS_FILE, \"rb\") as f:\n", " m = pickle.load(f)\n", " print(f\"Loaded {len(m)} models from {_CACHED_MODELS_FILE}.\")\n", " return m\n", "\n", "models = maybe_train_models(train_inputs_df, 0)\n", "print(f\"Model names after loading or training: {', '.join(models.keys())}\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# Uncomment this code if you need to have the cells that follow ignore\n", "# some of the models saved to disk.\n", "# _MODEL_SIZES_TO_KEEP = [32, 64, 128, 256]\n", "# _RUNS_TO_KEEP = [4] * len(_MODEL_SIZES_TO_KEEP)\n", "# _OTHER_MODELS_TO_KEEP = [\"768_1\"]\n", "\n", "# to_keep = _OTHER_MODELS_TO_KEEP.copy()\n", "# for size in _MODEL_SIZES_TO_KEEP:\n", "# for num_runs in _RUNS_TO_KEEP:\n", "# for i in range(num_runs):\n", "# to_keep.append(f\"{size}_{i+1}\")\n", "\n", "# models = {k: v for k, v in models.items() if k in to_keep}\n", "\n", "# print(f\"Model names after filtering: {', '.join(models.keys())}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluate the models on this fold's test set" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6eb2c9cdd5f242ffb017fc96bab994b5", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>[11, 16): 'Saudi'</td>\n", " <td>MISC</td>\n", " <td>train</td>\n", " <td>12</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>[59, 65): 'MANAMA'</td>\n", " <td>LOC</td>\n", " <td>train</td>\n", " <td>12</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>[86, 91): 'Saudi'</td>\n", " <td>MISC</td>\n", " <td>train</td>\n", " <td>12</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>[259, 264): 'Saudi'</td>\n", " <td>MISC</td>\n", " <td>train</td>\n", " <td>12</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>[55, 65): 'MONTGOMERY'</td>\n", " <td>LOC</td>\n", " <td>train</td>\n", " <td>20</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " span ent_type fold doc_num\n", "0 [11, 16): 'Saudi' MISC train 12\n", "1 [59, 65): 'MANAMA' LOC train 12\n", "2 [86, 91): 'Saudi' MISC train 12\n", "3 [259, 264): 'Saudi' MISC train 12\n", "0 [55, 65): 'MONTGOMERY' LOC train 20" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def eval_models(models: Dict[str, sklearn.base.BaseEstimator],\n", " test_df: pd.DataFrame):\n", " \"\"\"\n", " Bulk-evaluate an ensemble of models generated by :func:`train_models`.\n", " \n", " :param models: Output of :func:`train_models`\n", " :param test_df: DataFrame of labeled test documents, with one\n", " row per token. Must contain the columns \"embedding\" (precomputed \n", " BERT embeddings) and \"token_class_id\" (integer ID of token type)\n", " \n", " :returns: A dictionary from model name to results of \n", " :func:`util.analyze_model`\n", " \"\"\"\n", " todo = [(name, model) for name, model in models.items()]\n", " results = tp.jupyter.run_with_progress_bar(\n", " len(todo),\n", " lambda i: cleaning.infer_and_extract_entities_iob(test_df,corpus_raw, int_to_label, todo[i][1]),\n", " \"model\"\n", " )\n", " return {t[0]: result for t, result in zip(todo, results)}\n", "\n", "evals = eval_models(models, test_inputs_df)\n", "# display one of the results\n", "evals[list(evals.keys())[0]].head()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>precision</th>\n", " <th>recall</th>\n", " <th>f1-score</th>\n", " <th>dims</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>768_1</th>\n", " <td>0.947149</td>\n", " <td>0.938839</td>\n", " <td>0.942976</td>\n", " <td>768</td>\n", " </tr>\n", " <tr>\n", " <th>32_1</th>\n", " <td>0.924075</td>\n", " <td>0.863742</td>\n", " <td>0.892890</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>32_2</th>\n", " <td>0.924755</td>\n", " <td>0.875355</td>\n", " <td>0.899377</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>32_3</th>\n", " <td>0.925028</td>\n", " <td>0.866065</td>\n", " <td>0.894576</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>32_4</th>\n", " <td>0.932949</td>\n", " <td>0.876129</td>\n", " <td>0.903647</td>\n", " <td>32</td>\n", " </tr>\n", " <tr>\n", " <th>64_1</th>\n", " <td>0.940086</td>\n", " <td>0.902968</td>\n", " <td>0.921153</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>64_2</th>\n", " <td>0.938321</td>\n", " <td>0.902968</td>\n", " <td>0.920305</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>64_3</th>\n", " <td>0.936808</td>\n", " <td>0.895226</td>\n", " <td>0.915545</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>64_4</th>\n", " <td>0.940828</td>\n", " <td>0.902710</td>\n", " <td>0.921375</td>\n", " <td>64</td>\n", " </tr>\n", " <tr>\n", " <th>128_1</th>\n", " <td>0.944401</td>\n", " <td>0.924903</td>\n", " <td>0.934550</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>128_2</th>\n", " <td>0.947577</td>\n", " <td>0.923613</td>\n", " <td>0.935442</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>128_3</th>\n", " <td>0.943212</td>\n", " <td>0.921548</td>\n", " <td>0.932254</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>128_4</th>\n", " <td>0.940991</td>\n", " <td>0.921806</td>\n", " <td>0.931300</td>\n", " <td>128</td>\n", " </tr>\n", " <tr>\n", " <th>256_1</th>\n", " <td>0.949201</td>\n", " <td>0.935484</td>\n", " <td>0.942293</td>\n", " <td>256</td>\n", " </tr>\n", " <tr>\n", " <th>256_2</th>\n", " <td>0.943396</td>\n", " <td>0.929032</td>\n", " <td>0.936159</td>\n", " <td>256</td>\n", " </tr>\n", " <tr>\n", " <th>256_3</th>\n", " <td>0.945478</td>\n", " <td>0.930839</td>\n", " <td>0.938101</td>\n", " <td>256</td>\n", " </tr>\n", " <tr>\n", " <th>256_4</th>\n", " <td>0.945055</td>\n", " <td>0.932129</td>\n", " <td>0.938547</td>\n", " <td>256</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " precision recall f1-score dims\n", "768_1 0.947149 0.938839 0.942976 768\n", "32_1 0.924075 0.863742 0.892890 32\n", "32_2 0.924755 0.875355 0.899377 32\n", "32_3 0.925028 0.866065 0.894576 32\n", "32_4 0.932949 0.876129 0.903647 32\n", "64_1 0.940086 0.902968 0.921153 64\n", "64_2 0.938321 0.902968 0.920305 64\n", "64_3 0.936808 0.895226 0.915545 64\n", "64_4 0.940828 0.902710 0.921375 64\n", "128_1 0.944401 0.924903 0.934550 128\n", "128_2 0.947577 0.923613 0.935442 128\n", "128_3 0.943212 0.921548 0.932254 128\n", "128_4 0.940991 0.921806 0.931300 128\n", "256_1 0.949201 0.935484 0.942293 256\n", "256_2 0.943396 0.929032 0.936159 256\n", "256_3 0.945478 0.930839 0.938101 256\n", "256_4 0.945055 0.932129 0.938547 256" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Summarize how each of the models does on the test set.\n", "gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(corpus_raw,evals[list(evals.keys())[0]],label_col = 'ent_type')\n", "def make_summary_df(evals_df: pd.DataFrame) -> pd.DataFrame:\n", " gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(corpus_raw, evals['256_4'], label_col = 'ent_type')\n", " summary_df= cleaning.analysis.create_f1_report_ensemble_iob(evals,gold_elts)\n", " summary_df['dims'] = [int(name.split('_')[0]) for name in evals.keys()]\n", " return summary_df\n", "\n", "summary_df = make_summary_df(evals)\n", "summary_df" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAAEGCAYAAADFdkirAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYQ0lEQVR4nO3df5RdZX3v8feHZIKjgIMkZZGEmwBiaGzB0DFCpU2uWhK6bPihbUlZFaxXrldRrE1uybILNV0s6g22/ijXVargpbqkiGmM2DpgDNbSipkwhBBxQoxaMqF1LAy2OkJ+fO8f+znJyZCZnJk5+zxnzvm81pqVvZ+99znfPSf55Nl7n/1sRQRmZjkcl7sAM2tfDiAzy8YBZGbZOIDMLBsHkJllMz13AfUyc+bMmD9/fu4yzNrS1q1bfxwRs8a7XcsE0Pz58+nt7c1dhllbkvTDiWznQzAzy8YBZGbZOIDMLBsHkJll4wAys2xa5iqYmdXPhr4B1vX0s3domNldnaxetoDLFs2p+/s4gMzsCBv6BlizfjvD+w4AMDA0zJr12wHqHkI+BDOzI6zr6T8UPhXD+w6wrqe/7u/lHlCLa1RX2lrH3qHhcbVPhntALazSlR4YGiY43JXe0DeQuzRrYrO7OsfVPhkOoBbWyK60tY7VyxbQ2THtiLbOjmmsXrag7u/lQ7AW1siutLWOyiG6r4LZpMzu6mTgKGFTRlfaWstli+Y05FyhD8FaWCO70mYT4R5QC2tkV9psIhxALa5RXWmzifAhmJllU2oASVouqV/SLkk3HGX5PEmbJD0q6QFJc0csP0nSHkl/WWadZpZHaQEkaRpwK3AJsBBYKWnhiNVuAe6MiHOBtcDNI5b/KfCPZdVoZnmV2QNaDOyKiN0R8TxwF3DpiHUWAl9P05url0v6FeBU4L4SazSzjMoMoDnAk1Xze1JbtW3AFWn6cuBESadIOg74CLBqrDeQdK2kXkm9g4ODdSrbzBol90noVcASSX3AEmAAOAC8E/j7iNgz1sYRcVtEdEdE96xZ434iiJllVuZl+AHg9Kr5uantkIjYS+oBSToBeFNEDEm6EPg1Se8ETgBmSPqviHjBiWwzm7rKDKAtwNmSzqAIniuB36teQdJM4OmIOAisAW4HiIirqta5Buh2+EyMh+OwZlbaIVhE7AeuA3qAx4G7I2KHpLWSVqTVlgL9knZSnHC+qax62tGGvgFW37PtiOE4Vt+zzcNxWNNQROSuoS66u7vDT0Y90qK19/HMz/a9oP3kF3fQd+PFGSqyViVpa0R0j3e73CehrURHC5+x2s0azQFkZtk4gFpYV2fHuNrNGs0B1MI+uOKVdBynI9o6jhMfXPHKTBWZHcnDcbQwjwdkzc4B1OI8HpA1MwdQi/MXEa2ZOYBaWCMfsWs2ET4J3cL8XDBrdu4BlaBZDnv8XDBrdu4B1VkzPQ65kY/YNZsIB1CdNdNhj58LZs3Oh2B11kyHPf4ekDU7B1Cddb2446g3e3a9OM/tD/4ekDUzH4LV2Wijm7TIqCdmdeUAqrNnh48+1MVo7WbtzAFUZ77yZFY7B1Cd+cqTWe18ErrOfOXJrHYOoBL4ypNZbXwIZmbZOIDMLBsfgpWgWW5GNWt2DqA68xg8ZrXzIVidNdPNqGbNzgFUZ810M6pZs3MA1dlLR3nm1mjtZu3MAVRn0vjazdqZA6jOhkZ57vpo7WbtzAFUZ74Z1ax2pQaQpOWS+iXtknTDUZbPk7RJ0qOSHpA0t6r9YUmPSNoh6R1l1llPvhnVrHalfQ9I0jTgVuA3gD3AFkkbI+I7VavdAtwZEf9P0uuAm4HfB54CLoyI5ySdADyWtt1bVr314ptRzWpX5hcRFwO7ImI3gKS7gEuB6gBaCLwvTW8GNgBExPNV6xzPFDtU9M2oZrUp8x/2HODJqvk9qa3aNuCKNH05cKKkUwAknS7p0fQaHz5a70fStZJ6JfUODg7WfQfMrFy5exargCWS+oAlwABwACAinoyIc4GXA1dLOnXkxhFxW0R0R0T3rFmzGlm3mdVBmQE0AJxeNT83tR0SEXsj4oqIWAS8P7UNjVwHeAz4tRJrNbMMygygLcDZks6QNAO4EthYvYKkmZIqNawBbk/tcyV1pumTgYsA30xl1mJKC6CI2A9cB/QAjwN3R8QOSWslrUirLQX6Je0ETgVuSu2/CDwkaRvwDeCWiNheVq1mloeiRR5Y1d3dHb29vbnLMGtLkrZGRPd4t8t9EtrM2pgDyMyycQCZWTYOIDPLxgFkZtk4gMwsGweQmWXjADKzbBxAZpaNA8jMsnEAmVk2DiAzy8YBZGbZOIDMLBsHkJll4wAys2wcQGaWjQPIzLJxAJlZNg4gM8vGAWRm2TiAzCybmgJI0kWS3pqmZ0k6o9yyzKwdHDOAJH0A+GOKJ5cCdACfLbMoM2sPtfSALgdWAD+FQ89qP7HMosysPUyvYZ3nIyIkBYCkl5RcUzYb+gZY19PP3qFhZnd1snrZAi5bNCd3WWYtq5YAulvSXwFdkt4O/AHw1+WW1Xgb+gZYfc829h0oHlU9MDTM6nu2ATiEzEoy5iGYJAF/C9wDfBFYANwYEZ9oQG0N9aEv7zgUPhX7DgQf+vKOTBWZtb4xe0Dp0OvvI+KXgfsbVFMWz/xs37jazWzyajkJ/bCkV5deiZm1nVrOAb0GuErSDymuhImic3RuqZU1WFdnB0PDL+ztdHV2ZKjGrD3U0gNaBpwFvA74LeCN6c9jkrRcUr+kXZJuOMryeZI2SXpU0gOS5qb2V0n6F0k70rLfrX2XJuaN5502rnYzm7xjBlBE/BDoogid3wK6UtuYJE0DbgUuARYCKyUtHLHaLcCdqTe1Frg5tf8MeEtEvBJYDnxUUlctOzRRm787OK52M5u8Wr4JfT3wOeAX0s9nJb27htdeDOyKiN0R8TxwF3DpiHUWAl9P05sryyNiZ0Q8kab3Aj8CZtXwnhO2d2h4XO1mNnm1HIK9DXhNRNwYETcCFwBvr2G7OcCTVfN7Ulu1bcAVafpy4ERJp1SvIGkxMAP43sg3kHStpF5JvYODk+upzO7qHFe7mU1eLQEk4EDV/IHUVg+rgCWS+oAlwED1e0k6Dfgb4K0RcXDkxhFxW0R0R0T3rFmT6yCtXraAzo5pR7R1dkxj9bIFk3pdMxtdLVfB7gAekvR3af4y4NM1bDcAnF41Pze1HZIOr64AkHQC8KaIGErzJwFfAd4fEd+q4f0mpfJtZ9+KYdY4iohjrySdD1yUZr8ZEX01bDMd2Am8niJ4tgC/FxE7qtaZCTwdEQcl3QQciIgbJc0A/gH4ckR8tJYd6e7ujt7e3lpWNbM6k7Q1IrrHu90xe0CSLgB2RMTDaf4kSa+JiIfG2i4i9ku6DugBpgG3R8QOSWuB3ojYCCwFbk43uv4j8K60+e8Avw6cIuma1HZNRDwy3h00s+Z1zB5QOj9zfqQVJR1HESDnN6C+mrkHZJbPRHtANZ2EjqqUSieDazl3ZGY2ploCaLek90jqSD/XA7vLLszMWl8tAfQO4FcpTiQPUNwbdm2ZRZlZezjmoVRE/Ai4sgG1mFmbGbUHJOntks5O05J0u6Rn082hTXUC2symprEOwa4HfpCmVwLnAWcC7wM+Vm5ZZtYOxgqg/RFRGSDnjRR3rf9HRHwNaNmB6c2sccYKoIOSTpP0IopvM3+tapnv0DSzSRvrJPSNQC/Ft5g3Vm6hkLQEX4Y3szoYNYAi4l5J84ATI+KZqkW9QOkjFObg54KZNdaxnoqxH3hmRNtPS60okw19A6xZv53hfcVoIANDw6xZvx3wc8HMylLLFxHbwrqe/kPhUzG87wDrevozVWTW+hxAiYdkNWu8CQWQpHPqXUhuHpLVrPEm2gO6r65VNAEPyWrWeKOehJb08dEWUTymp6V4SFazxhvrKthbgT8CnjvKspXllJPXZYvmOHDMGmisANoCPBYR/zxygaQPllaRmbWNsQLozcDPj7YgIs4opxwzaydjnYQ+ISJ+1rBKzKztjBVAGyoTkr5Yfilm1m7GCqDqp5+eWXYhZtZ+xgqgGGXazKwuxjoJfZ6kn1D0hDrTNGk+IuKk0qszs5Y21nAc00ZbZmZWD74Z1cyycQCZWTYOIDPLxgFkZtk4gMwsm1IDSNJySf2Sdkm64SjL50nalJ62+oCkuVXLvippSNK9ZdZoZvmUFkCSpgG3ApcAC4GVkhaOWO0WigcengusBW6uWrYO+P2y6jOz/MrsAS0GdkXE7oh4HrgLuHTEOguBr6fpzdXLI2IT8J8l1mdmmZUZQHOAJ6vm96S2atuAK9L05cCJkk6p9Q0kXSupV1Lv4ODgpIo1s8bLfRJ6FbBEUh+wBBgADoy9yWERcVtEdEdE96xZs8qq0cxKMuaDCSdpADi9an5uajskIvaSekCSTgDeFBFDJdZkZk2kzB7QFuBsSWdImgFcCWysXkHSTEmVGtYAt5dYj5k1mdICKD3W+TqgB3gcuDsidkhaK2lFWm0p0C9pJ3AqcFNle0nfBL4AvF7SHknLyqrVzPJQRGsM9dPd3R29vb25yzBrS5K2RkT3eLfLfRLazNqYA8jMsnEAmVk2DiAzy8YBZGbZOIDMLBsHkJll4wAys2wcQGaWjQPIzLJxAJlZNg4gM8vGAWRm2TiAzCwbB5CZZeMAMrNsyhwTuuls6BtgXU8/e4eGmd3VyeplC7hs0cgHdZhZo7RNAG3oG2DN+u0M7yseujEwNMya9dsBHEJmmbTNIdi6nv5D4VMxvO8A63r6M1VkZm0TQHuHhsfVbmbla5sAmt3VOa52Mytf2wTQ6mUL6OyYdkRbZ8c0Vi9bkKkiM2ubk9CVE82+CmbWPNomgKAIIQeOWfNom0MwM2s+DiAzy8YBZGbZOIDMLBsHkJll4wAys2wcQGaWTakBJGm5pH5JuyTdcJTl8yRtkvSopAckza1adrWkJ9LP1WXWaWZ5lBZAkqYBtwKXAAuBlZIWjljtFuDOiDgXWAvcnLZ9GfAB4DXAYuADkk4uq1Yzy6PMb0IvBnZFxG4ASXcBlwLfqVpnIfC+NL0Z2JCmlwH3R8TTadv7geXA5ydTkAckM2suZR6CzQGerJrfk9qqbQOuSNOXAydKOqXGbZF0raReSb2Dg4NjFlMZkGxgaJjg8IBkG/oGxrVTZlY/uU9CrwKWSOoDlgADwIGxNzksIm6LiO6I6J41a9aY63pAMrPmU+Yh2ABwetX83NR2SETsJfWAJJ0AvCkihiQNAEtHbPvAZIrxgGRmzafMHtAW4GxJZ0iaAVwJbKxeQdJMSZUa1gC3p+ke4GJJJ6eTzxentgnzgGRmzae0AIqI/cB1FMHxOHB3ROyQtFbSirTaUqBf0k7gVOCmtO3TwJ9ShNgWYG3lhPREeUAys+ajiMhdQ110d3dHb2/vmOv4KphZOSRtjYju8W7nAcnMLJvcV8HMrI05gMwsGweQmWXjADKzbBxAZpaNA8jMsnEAmVk2DiAzy8YBZGbZOIDMLBsHkJll01b3gvlmVLPm0jYBVBmStTIqYmVIVsAhZJZJ2xyCeUhWs+bTNgHkIVnNmk/bBJCHZDVrPm0TQB6S1az5tM1J6MqJZl8FM2sebdMDMrPm0zY9IF+GN2s+bdMD8mV4s+bTNgHky/BmzadtAsiX4c2aT9sEkC/DmzWftjkJ7cvwZs2nbQII/GRUs2bTNodgZtZ8HEBmlo0DyMyycQCZWTYOIDPLRhGRu4a6kDQI/BCYCfw4czll8b5NTe2wb/MiYtZ4N26ZAKqQ1BsR3bnrKIP3bWryvo3Oh2Bmlo0DyMyyacUAui13ASXyvk1N3rdRtNw5IDObOlqxB2RmU4QDyMyyaZkAkrRcUr+kXZJuyF3PeEk6XdJmSd+RtEPS9an9ZZLul/RE+vPk1C5JH0/7+6ik8/PuwbFJmiapT9K9af4MSQ+lffhbSTNS+/FpfldaPj9r4TWQ1CXpHknflfS4pAtb5bOT9Ifp7+Rjkj4v6UX1+uxaIoAkTQNuBS4BFgIrJS3MW9W47Qf+KCIWAhcA70r7cAOwKSLOBjaleSj29ez0cy3wycaXPG7XA49XzX8Y+IuIeDnwDPC21P424JnU/hdpvWb3MeCrEXEOcB7Ffk75z07SHOA9QHdE/BIwDbiSen12ETHlf4ALgZ6q+TXAmtx1TXKfvgT8BtAPnJbaTgP60/RfASur1j+0XjP+AHMp/hG+DrgXEMU3aKeP/AyBHuDCND09rafc+zDGvr0U+P7IGlvhswPmAE8CL0ufxb3Asnp9di3RA+LwL6liT2qbklK3dRHwEHBqRDyVFv0bcGqanmr7/FHgfwMH0/wpwFBE7E/z1fUf2re0/Nm0frM6AxgE7kiHmJ+S9BJa4LOLiAHgFuBfgacoPout1Omza5UAahmSTgC+CLw3In5SvSyK/1am3PcmJL0R+FFEbM1dS0mmA+cDn4yIRcBPOXy4BUzpz+5k4FKKkJ0NvARYXq/Xb5UAGgBOr5qfm9qmFEkdFOHzuYhYn5r/XdJpaflpwI9S+1Ta59cCKyT9ALiL4jDsY0CXpMqwwNX1H9q3tPylwH80suBx2gPsiYiH0vw9FIHUCp/dG4DvR8RgROwD1lN8nnX57FolgLYAZ6cz8zMoTpJtzFzTuEgS8Gng8Yj486pFG4Gr0/TVFOeGKu1vSVdULgCereruN5WIWBMRcyNiPsVn8/WIuArYDLw5rTZy3yr7/Oa0ftP2HiLi34AnJVUesfJ64Du0wGdHceh1gaQXp7+jlX2rz2eX+yRXHU+W/SawE/ge8P7c9Uyg/osouuiPAo+kn9+kOH7eBDwBfA14WVpfFFf+vgdsp7hKkX0/atjPpcC9afpM4NvALuALwPGp/UVpfldafmbuumvYr1cBvenz2wCc3CqfHfAh4LvAY8DfAMfX67PzrRhmlk2rHIKZ2RTkADKzbBxAZpaNA8jMsnEAmVk2DqApQlJI+kjV/CpJH6zTa39G0puPveak3+e3053im0e0z5c0nG5jeFzStyVdU7V8RY4RDiTNlnRPo9+3nUw/9irWJJ4DrpB0c0Q0zSNeJE2Pw/cEHcvbgLdHxD8dZdn3oriNAUlnAuslKSLuiIiNZPhiaUTs5fCX7awE7gFNHfspxt/9w5ELRvZgJP1X+nOppG9I+pKk3ZL+TNJVqYexXdJZVS/zBkm9kname7cq4/esk7QljVvzP6te95uSNlJ8K3ZkPSvT6z8m6cOp7UaKL1t+WtK6sXY0InYD76MYBgJJ10j6y6p9/aSkb6V9Wirp9tRz+kxVDRdL+hdJD0v6QrrHDkk/kPSh1L5d0jmpfYmkR9JPn6QTU8/ssbT8RZLuSNv0SfrvVbWtl/RVFeP+/J+q391n0u9gu6QXfG7mHtBUcyvwaOUveY3OA34ReBrYDXwqIharGPDs3cB703rzgcXAWcBmSS8H3kJxm8CrJR0PPCjpvrT++cAvRcT3q99M0myKMWB+hWKcmPskXRYRayW9DlgVEb011P0wcM4oy06mGAJiBUXP6LXA/wC2SHoVxb1ZfwK8ISJ+KumPKQJtbdr+xxFxvqR3AqvStquAd0XEgymsfj7iPd9FcU/pL6fQuk/SK9KyV1GMXvAc0C/pE8AvAHOiGEMHSV017HPbcQ9oConi7vg7ST2DGm2JiKci4jmKr/5XAmQ7RehU3B0RByPiCYqgOge4mOKepUcohgY5hWIQLYBvjwyf5NXAA1HcvLgf+Bzw6+Oot0JjLPtyFF/h3w78e0Rsj4iDwI60TxdQDEz3YKr9amBe1faVG323cvh38CDw55LeA3Qd5bDyIuCzABHxXYqn8FYCaFNEPBsRP6foEc6j+B2eKekTkpYDP8FewD2gqeejFL2DO6ra9pP+M5F0HDCjatlzVdMHq+YPcuTnP/KenKAIgXdHRE/1AklLKYacKNMijhw9sVr1Pozcv+nAAeD+iFh5jO0PpPWJiD+T9BWK++8elLSMF/aCRlNdwwGKgbqekXQexeBd7wB+B/iDGl+vbbgHNMVExNPA3RweAhPgBxSHPFAclnRM4KV/W9Jx6bzQmRSj9PUA/0vFMCFIeoWKgbbG8m1giaSZKobKXQl8YzyFqBiQ7RbgE+Pch4pvAa9Nh5FIeknV4dJo73lW6kl9mGJ0hZGHf98ErkrrvgL4bxS/o9FebyZwXER8keJwsKnHfc7FPaCp6SPAdVXzfw18SdI24KtMrHfyrxThcRLwjoj4uaRPURyiPCxJFKP+XTbWi0TEU+mS+WaKHtRXIuJLY22TnCWpj+Ju6v8EPh4Rn5nAfhARgyou438+nbuCIgR2jrHZe9OJ5cqh3D9QDKNa8X+BT0raTtHjvCYinit+LUc1h2KExMp/8msmsi+tznfDm1k2PgQzs2wcQGaWjQPIzLJxAJlZNg4gM8vGAWRm2TiAzCyb/w/cTs4P76EQxgAAAABJRU5ErkJggg==\n", "text/plain": [ "<Figure size 288x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Plot the tradeoff between dimensionality and F1 score\n", "x = summary_df[\"dims\"]\n", "y = summary_df[\"f1-score\"]\n", "\n", "plt.figure(figsize=(4,4))\n", "plt.scatter(x, y)\n", "#plt.yscale(\"log\")\n", "#plt.xscale(\"log\")\n", "plt.xlabel(\"Number of Dimensions\")\n", "plt.ylabel(\"F1 Score\")\n", "\n", "# Also dump the raw data to a local file.\n", "pd.DataFrame({\"num_dims\": x, \"f1_score\": y}).to_csv(\"outputs/dims_vs_f1_score_xval.csv\",\n", " index=False)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Aggregate the model results and compare with the gold standard" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>in_gold</th>\n", " <th>count</th>\n", " <th>models</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>4927</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[590, 598): 'Gorleben'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n", " </tr>\n", " <tr>\n", " <th>4925</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[63, 67): 'BONN'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n", " </tr>\n", " <tr>\n", " <th>4924</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[11, 17): 'German'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n", " </tr>\n", " <tr>\n", " <th>4923</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[523, 528): 'China'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n", " </tr>\n", " <tr>\n", " <th>4922</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[512, 518): 'Mexico'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>374</th>\n", " <td>dev</td>\n", " <td>149</td>\n", " <td>[81, 93): 'Major League'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " <td>[GOLD]</td>\n", " </tr>\n", " <tr>\n", " <th>246</th>\n", " <td>dev</td>\n", " <td>120</td>\n", " <td>[63, 70): 'English'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " <td>[GOLD]</td>\n", " </tr>\n", " <tr>\n", " <th>78</th>\n", " <td>dev</td>\n", " <td>64</td>\n", " <td>[2571, 2575): 'AIDS'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " <td>[GOLD]</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[86, 90): 'UEFA'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " <td>[GOLD]</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[25, 39): 'STANDARD LIEGE'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " <td>[GOLD]</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>4928 rows × 7 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num span ent_type in_gold count \\\n", "4927 train 907 [590, 598): 'Gorleben' LOC True 17 \n", "4925 train 907 [63, 67): 'BONN' LOC True 17 \n", "4924 train 907 [11, 17): 'German' MISC True 17 \n", "4923 train 896 [523, 528): 'China' LOC True 17 \n", "4922 train 896 [512, 518): 'Mexico' LOC True 17 \n", "... ... ... ... ... ... ... \n", "374 dev 149 [81, 93): 'Major League' MISC True 0 \n", "246 dev 120 [63, 70): 'English' MISC True 0 \n", "78 dev 64 [2571, 2575): 'AIDS' MISC True 0 \n", "3 dev 21 [86, 90): 'UEFA' ORG True 0 \n", "0 dev 21 [25, 39): 'STANDARD LIEGE' ORG True 0 \n", "\n", " models \n", "4927 [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64... \n", "4925 [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64... \n", "4924 [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64... \n", "4923 [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64... \n", "4922 [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64... \n", "... ... \n", "374 [GOLD] \n", "246 [GOLD] \n", "78 [GOLD] \n", "3 [GOLD] \n", "0 [GOLD] \n", "\n", "[4928 rows x 7 columns]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "full_results = cleaning.flag_suspicious_labels(evals,'ent_type','ent_type',label_name='ent_type',gold_feats=gold_elts,align_over_cols=['fold','doc_num','span'],keep_cols=[],split_doc=False)\n", "full_results" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>in_gold</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>4927</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[590, 598): 'Gorleben'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4925</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[63, 67): 'BONN'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4924</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[11, 17): 'German'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4923</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[523, 528): 'China'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4922</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[512, 518): 'Mexico'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>374</th>\n", " <td>dev</td>\n", " <td>149</td>\n", " <td>[81, 93): 'Major League'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>246</th>\n", " <td>dev</td>\n", " <td>120</td>\n", " <td>[63, 70): 'English'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>78</th>\n", " <td>dev</td>\n", " <td>64</td>\n", " <td>[2571, 2575): 'AIDS'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[86, 90): 'UEFA'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[25, 39): 'STANDARD LIEGE'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>4928 rows × 6 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num span ent_type in_gold count\n", "4927 train 907 [590, 598): 'Gorleben' LOC True 17\n", "4925 train 907 [63, 67): 'BONN' LOC True 17\n", "4924 train 907 [11, 17): 'German' MISC True 17\n", "4923 train 896 [523, 528): 'China' LOC True 17\n", "4922 train 896 [512, 518): 'Mexico' LOC True 17\n", "... ... ... ... ... ... ...\n", "374 dev 149 [81, 93): 'Major League' MISC True 0\n", "246 dev 120 [63, 70): 'English' MISC True 0\n", "78 dev 64 [2571, 2575): 'AIDS' MISC True 0\n", "3 dev 21 [86, 90): 'UEFA' ORG True 0\n", "0 dev 21 [25, 39): 'STANDARD LIEGE' ORG True 0\n", "\n", "[4928 rows x 6 columns]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Drop Boolean columns for now\n", "results = full_results[[\"fold\", \"doc_num\", \"span\", \"ent_type\", \"in_gold\", \"count\"]]\n", "results" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>num_ents</th>\n", " </tr>\n", " <tr>\n", " <th>count</th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>115</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>31</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>23</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>20</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>18</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>23</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>23</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>19</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>29</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>28</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>41</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>48</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>62</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>75</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>115</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>248</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>2940</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " num_ents\n", "count \n", "0 115\n", "1 31\n", "2 23\n", "3 20\n", "4 17\n", "5 18\n", "6 23\n", "7 23\n", "8 19\n", "9 29\n", "10 28\n", "11 41\n", "12 48\n", "13 62\n", "14 75\n", "15 115\n", "16 248\n", "17 2940" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(results[results[\"in_gold\"] == True][[\"count\", \"span\"]]\n", " .groupby(\"count\").count()\n", " .rename(columns={\"span\": \"num_ents\"}))" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>num_ents</th>\n", " </tr>\n", " <tr>\n", " <th>count</th>\n", " <th></th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>1</th>\n", " <td>468</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>174</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>94</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>61</td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>52</td>\n", " </tr>\n", " <tr>\n", " <th>6</th>\n", " <td>26</td>\n", " </tr>\n", " <tr>\n", " <th>7</th>\n", " <td>36</td>\n", " </tr>\n", " <tr>\n", " <th>8</th>\n", " <td>16</td>\n", " </tr>\n", " <tr>\n", " <th>9</th>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>10</th>\n", " <td>12</td>\n", " </tr>\n", " <tr>\n", " <th>11</th>\n", " <td>9</td>\n", " </tr>\n", " <tr>\n", " <th>12</th>\n", " <td>9</td>\n", " </tr>\n", " <tr>\n", " <th>13</th>\n", " <td>8</td>\n", " </tr>\n", " <tr>\n", " <th>14</th>\n", " <td>11</td>\n", " </tr>\n", " <tr>\n", " <th>15</th>\n", " <td>14</td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>15</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>31</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " num_ents\n", "count \n", "1 468\n", "2 174\n", "3 94\n", "4 61\n", "5 52\n", "6 26\n", "7 36\n", "8 16\n", "9 17\n", "10 12\n", "11 9\n", "12 9\n", "13 8\n", "14 11\n", "15 14\n", "16 15\n", "17 31" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(results[results[\"in_gold\"] == False][[\"count\", \"span\"]]\n", " .groupby(\"count\").count()\n", " .rename(columns={\"span\": \"num_ents\"}))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>in_gold</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>3</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[86, 90): 'UEFA'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>dev</td>\n", " <td>21</td>\n", " <td>[25, 39): 'STANDARD LIEGE'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>78</th>\n", " <td>dev</td>\n", " <td>64</td>\n", " <td>[2571, 2575): 'AIDS'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>246</th>\n", " <td>dev</td>\n", " <td>120</td>\n", " <td>[63, 70): 'English'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>374</th>\n", " <td>dev</td>\n", " <td>149</td>\n", " <td>[81, 93): 'Major League'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>498</th>\n", " <td>dev</td>\n", " <td>182</td>\n", " <td>[2173, 2177): 'Ruch'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>462</th>\n", " <td>dev</td>\n", " <td>182</td>\n", " <td>[662, 670): 'division'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>512</th>\n", " <td>dev</td>\n", " <td>203</td>\n", " <td>[879, 881): '90'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>622</th>\n", " <td>dev</td>\n", " <td>214</td>\n", " <td>[1689, 1705): 'Schindler's List'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>621</th>\n", " <td>dev</td>\n", " <td>214</td>\n", " <td>[1643, 1648): 'Oscar'</td>\n", " <td>PER</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>583</th>\n", " <td>dev</td>\n", " <td>214</td>\n", " <td>[285, 305): 'Venice Film Festival'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>569</th>\n", " <td>dev</td>\n", " <td>214</td>\n", " <td>[187, 202): 'Michael Collins'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>802</th>\n", " <td>test</td>\n", " <td>15</td>\n", " <td>[44, 56): 'WORLD SERIES'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>801</th>\n", " <td>test</td>\n", " <td>15</td>\n", " <td>[32, 43): 'WEST INDIES'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>942</th>\n", " <td>test</td>\n", " <td>21</td>\n", " <td>[719, 725): 'Wijaya'</td>\n", " <td>PER</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>896</th>\n", " <td>test</td>\n", " <td>21</td>\n", " <td>[22, 38): 'WORLD GRAND PRIX'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1057</th>\n", " <td>test</td>\n", " <td>23</td>\n", " <td>[1117, 1127): 'NY RANGERS'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1052</th>\n", " <td>test</td>\n", " <td>23</td>\n", " <td>[1106, 1113): 'TORONTO'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1025</th>\n", " <td>test</td>\n", " <td>23</td>\n", " <td>[673, 689): 'CENTRAL DIVISION'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1016</th>\n", " <td>test</td>\n", " <td>23</td>\n", " <td>[599, 611): 'NY ISLANDERS'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " fold doc_num span ent_type in_gold \\\n", "3 dev 21 [86, 90): 'UEFA' ORG True \n", "0 dev 21 [25, 39): 'STANDARD LIEGE' ORG True \n", "78 dev 64 [2571, 2575): 'AIDS' MISC True \n", "246 dev 120 [63, 70): 'English' MISC True \n", "374 dev 149 [81, 93): 'Major League' MISC True \n", "498 dev 182 [2173, 2177): 'Ruch' ORG True \n", "462 dev 182 [662, 670): 'division' MISC True \n", "512 dev 203 [879, 881): '90' LOC True \n", "622 dev 214 [1689, 1705): 'Schindler's List' MISC True \n", "621 dev 214 [1643, 1648): 'Oscar' PER True \n", "583 dev 214 [285, 305): 'Venice Film Festival' MISC True \n", "569 dev 214 [187, 202): 'Michael Collins' MISC True \n", "802 test 15 [44, 56): 'WORLD SERIES' MISC True \n", "801 test 15 [32, 43): 'WEST INDIES' LOC True \n", "942 test 21 [719, 725): 'Wijaya' PER True \n", "896 test 21 [22, 38): 'WORLD GRAND PRIX' MISC True \n", "1057 test 23 [1117, 1127): 'NY RANGERS' ORG True \n", "1052 test 23 [1106, 1113): 'TORONTO' ORG True \n", "1025 test 23 [673, 689): 'CENTRAL DIVISION' MISC True \n", "1016 test 23 [599, 611): 'NY ISLANDERS' ORG True \n", "\n", " count \n", "3 0 \n", "0 0 \n", "78 0 \n", "246 0 \n", "374 0 \n", "498 0 \n", "462 0 \n", "512 0 \n", "622 0 \n", "621 0 \n", "583 0 \n", "569 0 \n", "802 0 \n", "801 0 \n", "942 0 \n", "896 0 \n", "1057 0 \n", "1052 0 \n", "1025 0 \n", "1016 0 " ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Pull out some hard-to-find examples, sorting by document to make labeling easier\n", "hard_to_get = results[results[\"in_gold\"]].sort_values([\"count\", \"fold\", \"doc_num\"]).head(20)\n", "hard_to_get" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TODO: Relabel the above 20 examples with a Markdown table (copy from CSV)\n" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>in_gold</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>373</th>\n", " <td>dev</td>\n", " <td>149</td>\n", " <td>[81, 102): 'Major League Baseball'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>570</th>\n", " <td>dev</td>\n", " <td>214</td>\n", " <td>[187, 202): 'Michael Collins'</td>\n", " <td>PER</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>983</th>\n", " <td>test</td>\n", " <td>23</td>\n", " <td>[94, 116): 'National Hockey League'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1110</th>\n", " <td>test</td>\n", " <td>25</td>\n", " <td>[856, 864): 'NFC East'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1109</th>\n", " <td>test</td>\n", " <td>25</td>\n", " <td>[823, 835): 'Philadelphia'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1184</th>\n", " <td>test</td>\n", " <td>41</td>\n", " <td>[674, 688): 'Sporting Gijon'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1323</th>\n", " <td>test</td>\n", " <td>114</td>\n", " <td>[51, 61): 'sales-USDA'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1367</th>\n", " <td>test</td>\n", " <td>118</td>\n", " <td>[776, 791): 'mid-Mississippi'</td>\n", " <td>LOC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1362</th>\n", " <td>test</td>\n", " <td>118</td>\n", " <td>[535, 550): 'mid-Mississippi'</td>\n", " <td>LOC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1509</th>\n", " <td>test</td>\n", " <td>178</td>\n", " <td>[1787, 1800): 'Uruguay Round'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1560</th>\n", " <td>test</td>\n", " <td>180</td>\n", " <td>[588, 592): 'BILO'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1558</th>\n", " <td>test</td>\n", " <td>180</td>\n", " <td>[579, 583): 'TOPS'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1550</th>\n", " <td>test</td>\n", " <td>180</td>\n", " <td>[395, 399): 'BILO'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1544</th>\n", " <td>test</td>\n", " <td>180</td>\n", " <td>[286, 293): 'Malysia'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1542</th>\n", " <td>test</td>\n", " <td>180</td>\n", " <td>[259, 263): 'BILO'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1649</th>\n", " <td>test</td>\n", " <td>207</td>\n", " <td>[1041, 1047): 'Oxford'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1786</th>\n", " <td>test</td>\n", " <td>219</td>\n", " <td>[368, 381): 'Koo Jeon Woon'</td>\n", " <td>PER</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1807</th>\n", " <td>test</td>\n", " <td>222</td>\n", " <td>[218, 225): 'EASTERN'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>1805</th>\n", " <td>test</td>\n", " <td>222</td>\n", " <td>[92, 114): 'National Hockey League'</td>\n", " <td>MISC</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>2054</th>\n", " <td>train</td>\n", " <td>48</td>\n", " <td>[885, 899): 'Sjeng Schalken'</td>\n", " <td>ORG</td>\n", " <td>False</td>\n", " <td>17</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " fold doc_num span ent_type in_gold \\\n", "373 dev 149 [81, 102): 'Major League Baseball' MISC False \n", "570 dev 214 [187, 202): 'Michael Collins' PER False \n", "983 test 23 [94, 116): 'National Hockey League' MISC False \n", "1110 test 25 [856, 864): 'NFC East' MISC False \n", "1109 test 25 [823, 835): 'Philadelphia' ORG False \n", "1184 test 41 [674, 688): 'Sporting Gijon' ORG False \n", "1323 test 114 [51, 61): 'sales-USDA' ORG False \n", "1367 test 118 [776, 791): 'mid-Mississippi' LOC False \n", "1362 test 118 [535, 550): 'mid-Mississippi' LOC False \n", "1509 test 178 [1787, 1800): 'Uruguay Round' MISC False \n", "1560 test 180 [588, 592): 'BILO' ORG False \n", "1558 test 180 [579, 583): 'TOPS' ORG False \n", "1550 test 180 [395, 399): 'BILO' ORG False \n", "1544 test 180 [286, 293): 'Malysia' ORG False \n", "1542 test 180 [259, 263): 'BILO' ORG False \n", "1649 test 207 [1041, 1047): 'Oxford' ORG False \n", "1786 test 219 [368, 381): 'Koo Jeon Woon' PER False \n", "1807 test 222 [218, 225): 'EASTERN' MISC False \n", "1805 test 222 [92, 114): 'National Hockey League' MISC False \n", "2054 train 48 [885, 899): 'Sjeng Schalken' ORG False \n", "\n", " count \n", "373 17 \n", "570 17 \n", "983 17 \n", "1110 17 \n", "1109 17 \n", "1184 17 \n", "1323 17 \n", "1367 17 \n", "1362 17 \n", "1509 17 \n", "1560 17 \n", "1558 17 \n", "1550 17 \n", "1544 17 \n", "1542 17 \n", "1649 17 \n", "1786 17 \n", "1807 17 \n", "1805 17 \n", "2054 17 " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Hardest results not in the gold standard for models to avoid\n", "hard_to_avoid = results[~results[\"in_gold\"]].sort_values([\"count\", \"fold\", \"doc_num\"], ascending=[False, True, True]).head(20)\n", "hard_to_avoid" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TODO: Relabel the above 20 examples (copy from CSV)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Remainder of Experiment\n", "\n", "For each of the 10 folds, train a model on the fold's training set and run\n", "analysis on the fold's test set." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Starting fold 1.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e6ab8836b6246edb1d3e635fda3ad7f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 1.\n", "Starting fold 2.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "51af4ff8c89d4dca999d510eccb7e575", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 2.\n", "Starting fold 3.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d7b85c9d3d9f4c7dbad63f5a7b00ce03", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 3.\n", "Starting fold 4.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "01b0a7edf3714a63b0237cf286fd041d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 4.\n", "Starting fold 5.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c003ad7644b041b18315f34e8bf1f85c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 5.\n", "Starting fold 6.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "79e940db371644e9bf2e8c945e0806bf", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 6.\n", "Starting fold 7.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "04a013f28dcd46d1831ff6ca1be8266a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 7.\n", "Starting fold 8.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4a17afd5ec08439a8956146853fd9679", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 8.\n", "Starting fold 9.\n", "Training model using all of 768-dimension embeddings.\n", "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n", "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n", "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n", "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n", "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n", "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n", "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n", "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n", "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n", "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n", "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n", "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n", "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n", "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n", "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n", "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n", "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=773956.\n", "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n", "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n", "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n", "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=526478.\n", "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=513226.\n", "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=94177.\n", "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n", "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=128113.\n", "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=839748.\n", "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=450385.\n", "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=781567.\n", "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=643865.\n", "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=402414.\n", "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n", "Trained 17 models.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "038d5be9065f4ac6837021045d963b06", "version_major": 2, "version_minor": 0 }, "text/plain": [ "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Done with fold 9.\n" ] } ], "source": [ "def handle_fold(fold_ix: int) -> Dict[str, Any]:\n", " \"\"\"\n", " The per-fold processing of the previous section's cells, collapsed into \n", " a single function.\n", " \n", " :param fold_ix: 0-based index of fold\n", " \n", " :returns: a dictionary that maps data structure name to data structure\n", " \"\"\"\n", " # To avoid accidentally picking up leftover data from a previous cell,\n", " # variables local to this function are named with a leading underscore\n", " _train_inputs_df = corpus_df.merge(train_keys[fold_ix])\n", " _test_inputs_df = corpus_df.merge(test_keys[fold_ix])\n", " _models = maybe_train_models(_train_inputs_df, fold_ix)\n", " _evals = eval_models(_models, _test_inputs_df)\n", " _summary_df = make_summary_df(_evals)\n", " _gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(corpus_raw,_evals[list(evals.keys())[0]])\n", " _full_results = cleaning.flag_suspicious_labels(_evals,'ent_type','ent_type',\n", " label_name='ent_type',\n", " gold_feats=_gold_elts,\n", " align_over_cols=['fold','doc_num','span'],\n", " keep_cols=[],split_doc=False)\n", " _results = _full_results[[\"fold\", \"doc_num\", \"span\", \n", " \"ent_type\", \"in_gold\", \"count\"]]\n", " return {\n", " \"models\": _models,\n", " \"summary_df\": _summary_df,\n", " \"full_results\": _full_results,\n", " \"results\": _results\n", " }\n", "\n", "# Start with the (already computed) results for fold 0\n", "results_by_fold = [\n", " {\n", " \"models\": models,\n", " \"summary_df\": summary_df,\n", " \"full_results\": full_results,\n", " \"results\": results\n", " }\n", "]\n", "\n", "for fold in range(1, _KFOLD_NUM_FOLDS):\n", " print(f\"Starting fold {fold}.\")\n", " results_by_fold.append(handle_fold(fold))\n", " print(f\"Done with fold {fold}.\")\n", " " ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>fold</th>\n", " <th>doc_num</th>\n", " <th>span</th>\n", " <th>ent_type</th>\n", " <th>in_gold</th>\n", " <th>count</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>4927</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[590, 598): 'Gorleben'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4925</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[63, 67): 'BONN'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4924</th>\n", " <td>train</td>\n", " <td>907</td>\n", " <td>[11, 17): 'German'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4923</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[523, 528): 'China'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>4922</th>\n", " <td>train</td>\n", " <td>896</td>\n", " <td>[512, 518): 'Mexico'</td>\n", " <td>LOC</td>\n", " <td>True</td>\n", " <td>17</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>271</th>\n", " <td>dev</td>\n", " <td>93</td>\n", " <td>[469, 481): 'JAKARTA POST'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>183</th>\n", " <td>dev</td>\n", " <td>76</td>\n", " <td>[1285, 1312): 'Chicago Purchasing Managers'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>126</th>\n", " <td>dev</td>\n", " <td>49</td>\n", " <td>[1920, 1925): 'Tajik'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>dev</td>\n", " <td>15</td>\n", " <td>[109, 133): 'National Football League'</td>\n", " <td>ORG</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>17</th>\n", " <td>dev</td>\n", " <td>15</td>\n", " <td>[15, 40): 'AMERICAN FOOTBALL-RANDALL'</td>\n", " <td>MISC</td>\n", " <td>True</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>44802 rows × 6 columns</p>\n", "</div>" ], "text/plain": [ " fold doc_num span ent_type \\\n", "4927 train 907 [590, 598): 'Gorleben' LOC \n", "4925 train 907 [63, 67): 'BONN' LOC \n", "4924 train 907 [11, 17): 'German' MISC \n", "4923 train 896 [523, 528): 'China' LOC \n", "4922 train 896 [512, 518): 'Mexico' LOC \n", "... ... ... ... ... \n", "271 dev 93 [469, 481): 'JAKARTA POST' ORG \n", "183 dev 76 [1285, 1312): 'Chicago Purchasing Managers' ORG \n", "126 dev 49 [1920, 1925): 'Tajik' MISC \n", "25 dev 15 [109, 133): 'National Football League' ORG \n", "17 dev 15 [15, 40): 'AMERICAN FOOTBALL-RANDALL' MISC \n", "\n", " in_gold count \n", "4927 True 17 \n", "4925 True 17 \n", "4924 True 17 \n", "4923 True 17 \n", "4922 True 17 \n", "... ... ... \n", "271 True 0 \n", "183 True 0 \n", "126 True 0 \n", "25 True 0 \n", "17 True 0 \n", "\n", "[44802 rows x 6 columns]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Combine all the results into a single dataframe for the entire corpus\n", "all_results = pd.concat([r[\"results\"] for r in results_by_fold])\n", "all_results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Generate CSV files for manual labeling" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>count</th>\n", " <th>fold</th>\n", " <th>doc_offset</th>\n", " <th>corpus_span</th>\n", " <th>corpus_ent_type</th>\n", " <th>error_type</th>\n", " <th>correct_span</th>\n", " <th>correct_ent_type</th>\n", " <th>notes</th>\n", " <th>time_started</th>\n", " <th>time_stopped</th>\n", " <th>time_elapsed</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>30</th>\n", " <td>0</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[760, 765): 'Leeds'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>21</th>\n", " <td>0</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[614, 634): 'Duke of Norfolk's XI'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>5</th>\n", " <td>0</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[189, 218): 'Test and County Cricket Board'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[87, 92): 'Ashes'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>0</th>\n", " <td>0</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[25, 30): 'ASHES'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>1738</th>\n", " <td>17</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[230, 238): 'Charlton'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1737</th>\n", " <td>17</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[177, 187): 'Englishman'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1736</th>\n", " <td>17</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[135, 142): 'Ireland'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1735</th>\n", " <td>17</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[87, 100): 'Jack Charlton'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1734</th>\n", " <td>17</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[69, 75): 'DUBLIN'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>11590 rows × 12 columns</p>\n", "</div>" ], "text/plain": [ " count fold doc_offset corpus_span \\\n", "30 0 dev 2 [760, 765): 'Leeds' \n", "21 0 dev 2 [614, 634): 'Duke of Norfolk's XI' \n", "5 0 dev 2 [189, 218): 'Test and County Cricket Board' \n", "3 0 dev 2 [87, 92): 'Ashes' \n", "0 0 dev 2 [25, 30): 'ASHES' \n", "... ... ... ... ... \n", "1738 17 test 230 [230, 238): 'Charlton' \n", "1737 17 test 230 [177, 187): 'Englishman' \n", "1736 17 test 230 [135, 142): 'Ireland' \n", "1735 17 test 230 [87, 100): 'Jack Charlton' \n", "1734 17 test 230 [69, 75): 'DUBLIN' \n", "\n", " corpus_ent_type error_type correct_span correct_ent_type notes \\\n", "30 ORG \n", "21 ORG \n", "5 ORG \n", "3 MISC \n", "0 MISC \n", "... ... ... ... ... ... \n", "1738 PER \n", "1737 MISC \n", "1736 LOC \n", "1735 PER \n", "1734 LOC \n", "\n", " time_started time_stopped time_elapsed \n", "30 \n", "21 \n", "5 \n", "3 \n", "0 \n", "... ... ... ... \n", "1738 \n", "1737 \n", "1736 \n", "1735 \n", "1734 \n", "\n", "[11590 rows x 12 columns]" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Reformat for output\n", "dev_and_test_results = all_results[all_results[\"fold\"].isin([\"dev\", \"test\"])]\n", "in_gold_to_write, not_in_gold_to_write = cleaning.analysis.csv_prep(dev_and_test_results, \"count\")\n", "in_gold_to_write" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>count</th>\n", " <th>fold</th>\n", " <th>doc_offset</th>\n", " <th>model_span</th>\n", " <th>model_ent_type</th>\n", " <th>error_type</th>\n", " <th>corpus_span</th>\n", " <th>corpus_ent_type</th>\n", " <th>correct_span</th>\n", " <th>correct_ent_type</th>\n", " <th>notes</th>\n", " <th>time_started</th>\n", " <th>time_stopped</th>\n", " <th>time_elapsed</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>29</th>\n", " <td>17</td>\n", " <td>dev</td>\n", " <td>2</td>\n", " <td>[760, 765): 'Leeds'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>25</th>\n", " <td>17</td>\n", " <td>dev</td>\n", " <td>6</td>\n", " <td>[567, 572): 'Rotor'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>20</th>\n", " <td>17</td>\n", " <td>dev</td>\n", " <td>6</td>\n", " <td>[399, 404): 'Rotor'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>16</th>\n", " <td>17</td>\n", " <td>dev</td>\n", " <td>6</td>\n", " <td>[262, 267): 'Rotor'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>142</th>\n", " <td>17</td>\n", " <td>dev</td>\n", " <td>11</td>\n", " <td>[1961, 1975): 'Czech Republic'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>1708</th>\n", " <td>1</td>\n", " <td>test</td>\n", " <td>228</td>\n", " <td>[771, 784): 'De Graafschap'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1690</th>\n", " <td>1</td>\n", " <td>test</td>\n", " <td>228</td>\n", " <td>[269, 287): 'Brazilian defender'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1679</th>\n", " <td>1</td>\n", " <td>test</td>\n", " <td>228</td>\n", " <td>[40, 43): 'SIX'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1724</th>\n", " <td>1</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[19, 29): 'ENGLISHMAN'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1727</th>\n", " <td>1</td>\n", " <td>test</td>\n", " <td>230</td>\n", " <td>[19, 38): 'ENGLISHMAN CHARLTON'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>4366 rows × 14 columns</p>\n", "</div>" ], "text/plain": [ " count fold doc_offset model_span \\\n", "29 17 dev 2 [760, 765): 'Leeds' \n", "25 17 dev 6 [567, 572): 'Rotor' \n", "20 17 dev 6 [399, 404): 'Rotor' \n", "16 17 dev 6 [262, 267): 'Rotor' \n", "142 17 dev 11 [1961, 1975): 'Czech Republic' \n", "... ... ... ... ... \n", "1708 1 test 228 [771, 784): 'De Graafschap' \n", "1690 1 test 228 [269, 287): 'Brazilian defender' \n", "1679 1 test 228 [40, 43): 'SIX' \n", "1724 1 test 230 [19, 29): 'ENGLISHMAN' \n", "1727 1 test 230 [19, 38): 'ENGLISHMAN CHARLTON' \n", "\n", " model_ent_type error_type corpus_span corpus_ent_type correct_span \\\n", "29 LOC \n", "25 PER \n", "20 PER \n", "16 PER \n", "142 LOC \n", "... ... ... ... ... ... \n", "1708 ORG \n", "1690 MISC \n", "1679 ORG \n", "1724 LOC \n", "1727 PER \n", "\n", " correct_ent_type notes time_started time_stopped time_elapsed \n", "29 \n", "25 \n", "20 \n", "16 \n", "142 \n", "... ... ... ... ... ... \n", "1708 \n", "1690 \n", "1679 \n", "1724 \n", "1727 \n", "\n", "[4366 rows x 14 columns]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "not_in_gold_to_write" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "in_gold_to_write.to_csv(\"outputs/CoNLL_4_in_gold.csv\", index=False)\n", "not_in_gold_to_write.to_csv(\"outputs/CoNLL_4_not_in_gold.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>count</th>\n", " <th>fold</th>\n", " <th>doc_offset</th>\n", " <th>corpus_span</th>\n", " <th>corpus_ent_type</th>\n", " <th>error_type</th>\n", " <th>correct_span</th>\n", " <th>correct_ent_type</th>\n", " <th>notes</th>\n", " <th>time_started</th>\n", " <th>time_stopped</th>\n", " <th>time_elapsed</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>1486</th>\n", " <td>0</td>\n", " <td>train</td>\n", " <td>6</td>\n", " <td>[121, 137): 'Toronto Dominion'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1358</th>\n", " <td>0</td>\n", " <td>train</td>\n", " <td>24</td>\n", " <td>[384, 388): 'FLNC'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1355</th>\n", " <td>0</td>\n", " <td>train</td>\n", " <td>24</td>\n", " <td>[161, 169): 'Africans'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1965</th>\n", " <td>0</td>\n", " <td>train</td>\n", " <td>25</td>\n", " <td>[141, 151): 'mid-Norway'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1383</th>\n", " <td>0</td>\n", " <td>train</td>\n", " <td>28</td>\n", " <td>[1133, 1135): 'EU'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>4132</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>945</td>\n", " <td>[130, 137): 'Preston'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4131</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>945</td>\n", " <td>[119, 127): 'Plymouth'</td>\n", " <td>ORG</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4130</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>945</td>\n", " <td>[72, 79): 'English'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4129</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>945</td>\n", " <td>[43, 49): 'LONDON'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4128</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>945</td>\n", " <td>[19, 26): 'ENGLISH'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>23499 rows × 12 columns</p>\n", "</div>" ], "text/plain": [ " count fold doc_offset corpus_span \\\n", "1486 0 train 6 [121, 137): 'Toronto Dominion' \n", "1358 0 train 24 [384, 388): 'FLNC' \n", "1355 0 train 24 [161, 169): 'Africans' \n", "1965 0 train 25 [141, 151): 'mid-Norway' \n", "1383 0 train 28 [1133, 1135): 'EU' \n", "... ... ... ... ... \n", "4132 17 train 945 [130, 137): 'Preston' \n", "4131 17 train 945 [119, 127): 'Plymouth' \n", "4130 17 train 945 [72, 79): 'English' \n", "4129 17 train 945 [43, 49): 'LONDON' \n", "4128 17 train 945 [19, 26): 'ENGLISH' \n", "\n", " corpus_ent_type error_type correct_span correct_ent_type notes \\\n", "1486 PER \n", "1358 ORG \n", "1355 MISC \n", "1965 MISC \n", "1383 ORG \n", "... ... ... ... ... ... \n", "4132 ORG \n", "4131 ORG \n", "4130 MISC \n", "4129 LOC \n", "4128 MISC \n", "\n", " time_started time_stopped time_elapsed \n", "1486 \n", "1358 \n", "1355 \n", "1965 \n", "1383 \n", "... ... ... ... \n", "4132 \n", "4131 \n", "4130 \n", "4129 \n", "4128 \n", "\n", "[23499 rows x 12 columns]" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Repeat for the contents of the original training set\n", "train_results = all_results[all_results[\"fold\"] == \"train\"]\n", "in_gold_to_write, not_in_gold_to_write = cleaning.analysis.csv_prep(train_results, \"count\")\n", "in_gold_to_write" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>count</th>\n", " <th>fold</th>\n", " <th>doc_offset</th>\n", " <th>model_span</th>\n", " <th>model_ent_type</th>\n", " <th>error_type</th>\n", " <th>corpus_span</th>\n", " <th>corpus_ent_type</th>\n", " <th>correct_span</th>\n", " <th>correct_ent_type</th>\n", " <th>notes</th>\n", " <th>time_started</th>\n", " <th>time_stopped</th>\n", " <th>time_elapsed</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>1738</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>3</td>\n", " <td>[0, 10): '-DOCSTART-'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1485</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>6</td>\n", " <td>[121, 137): 'Toronto Dominion'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1964</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>25</td>\n", " <td>[141, 151): 'mid-Norway'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>2022</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>29</td>\n", " <td>[762, 774): 'Mark O'Meara'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>1996</th>\n", " <td>17</td>\n", " <td>train</td>\n", " <td>29</td>\n", " <td>[454, 468): 'Phil Mickelson'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>4416</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>943</td>\n", " <td>[25, 46): 'SAN MARINO GRAND PRIX'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4461</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>944</td>\n", " <td>[25, 32): 'MASTERS'</td>\n", " <td>MISC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4462</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>944</td>\n", " <td>[25, 32): 'MASTERS'</td>\n", " <td>PER</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4463</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>944</td>\n", " <td>[17, 32): 'BRITISH MASTERS'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " <tr>\n", " <th>4458</th>\n", " <td>1</td>\n", " <td>train</td>\n", " <td>944</td>\n", " <td>[11, 15): 'GOLF'</td>\n", " <td>LOC</td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " <td></td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>5347 rows × 14 columns</p>\n", "</div>" ], "text/plain": [ " count fold doc_offset model_span \\\n", "1738 17 train 3 [0, 10): '-DOCSTART-' \n", "1485 17 train 6 [121, 137): 'Toronto Dominion' \n", "1964 17 train 25 [141, 151): 'mid-Norway' \n", "2022 17 train 29 [762, 774): 'Mark O'Meara' \n", "1996 17 train 29 [454, 468): 'Phil Mickelson' \n", "... ... ... ... ... \n", "4416 1 train 943 [25, 46): 'SAN MARINO GRAND PRIX' \n", "4461 1 train 944 [25, 32): 'MASTERS' \n", "4462 1 train 944 [25, 32): 'MASTERS' \n", "4463 1 train 944 [17, 32): 'BRITISH MASTERS' \n", "4458 1 train 944 [11, 15): 'GOLF' \n", "\n", " model_ent_type error_type corpus_span corpus_ent_type correct_span \\\n", "1738 LOC \n", "1485 LOC \n", "1964 LOC \n", "2022 PER \n", "1996 PER \n", "... ... ... ... ... ... \n", "4416 PER \n", "4461 MISC \n", "4462 PER \n", "4463 LOC \n", "4458 LOC \n", "\n", " correct_ent_type notes time_started time_stopped time_elapsed \n", "1738 \n", "1485 \n", "1964 \n", "2022 \n", "1996 \n", "... ... ... ... ... ... \n", "4416 \n", "4461 \n", "4462 \n", "4463 \n", "4458 \n", "\n", "[5347 rows x 14 columns]" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "not_in_gold_to_write" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "in_gold_to_write.to_csv(\"outputs/CoNLL_4_train_in_gold.csv\", index=False)\n", "not_in_gold_to_write.to_csv(\"outputs/CoNLL_4_train_not_in_gold.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 4 }