{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CoNLL_4.ipynb\n",
    "\n",
    "This notebook contains the fourth part of the model training and analysis code from our CoNLL-2020 paper, [\"Identifying Incorrect Labels in the CoNLL-2003 Corpus\"](https://www.aclweb.org/anthology/2020.conll-1.16/).\n",
    "\n",
    "If you're new to the Text Extensions for Pandas library, we recommend that you start\n",
    "by reading through the notebook [`Analyze_Model_Outputs.ipynb`](https://github.com/CODAIT/text-extensions-for-pandas/blob/master/notebooks/Analyze_Model_Outputs.ipynb), which explains the \n",
    "portions of the library that we use in the notebooks in this directory.\n",
    "\n",
    "### Summary\n",
    "\n",
    "This notebook repeats the model training process from `CoNLL_3.ipynb`, but performs a 10-fold cross-validation. This process involves training a total of 170 models (10 groups of 17). Next, the notebook evaluates each group of models over the holdout set from the associated fold of the cross-validation. Then it aggregates these outputs and uses the same techniques as `CoNLL_2.ipynb` to flag potentially incorrect labels. Finally, the notebook writes out CSV files containing ranked lists of potentially incorrect labels.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Libraries and constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "# Libraries\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import os\n",
    "import sys\n",
    "import time\n",
    "import torch\n",
    "import transformers\n",
    "from typing import *\n",
    "import sklearn.model_selection\n",
    "import sklearn.pipeline\n",
    "import matplotlib.pyplot as plt\n",
    "import multiprocessing\n",
    "import gc\n",
    "\n",
    "# And of course we need the text_extensions_for_pandas library itself.\n",
    "try:\n",
    "    import text_extensions_for_pandas as tp\n",
    "except ModuleNotFoundError as e:\n",
    "    raise Exception(\"text_extensions_for_pandas package not found on the Jupyter \"\n",
    "                    \"kernel's path. Please either run:\\n\"\n",
    "                    \"   ln -s ../../text_extensions_for_pandas .\\n\"\n",
    "                    \"from the directory containing this notebook, or use a Python \"\n",
    "                    \"environment on which you have used `pip` to install the package.\")\n",
    "\n",
    "from text_extensions_for_pandas import cleaning\n",
    "    \n",
    "# BERT Configuration\n",
    "# Keep this in sync with `CoNLL_3.ipynb`.\n",
    "#bert_model_name = \"bert-base-uncased\"\n",
    "#bert_model_name = \"bert-large-uncased\"\n",
    "bert_model_name = \"dslim/bert-base-NER\"\n",
    "tokenizer = transformers.BertTokenizerFast.from_pretrained(bert_model_name, \n",
    "                                                           add_special_tokens=True)\n",
    "bert = transformers.BertModel.from_pretrained(bert_model_name)\n",
    "\n",
    "# If False, use cached values, provided those values are present on disk\n",
    "_REGENERATE_EMBEDDINGS = True\n",
    "_REGENERATE_MODELS = True\n",
    "\n",
    "# Number of dimensions that we reduce the BERT embeddings down to when\n",
    "# training reduced-quality models.\n",
    "#_REDUCED_DIMS = [8, 16, 32, 64, 128, 256]\n",
    "_REDUCED_DIMS = [32, 64, 128, 256]\n",
    "\n",
    "# How many models we train at each level of dimensionality reduction\n",
    "_MODELS_AT_DIM = [4] * len(_REDUCED_DIMS)\n",
    "\n",
    "# Consistent set of random seeds to use when generating dimension-reduced\n",
    "# models. Index is [index into _REDUCED_DIMS, model number], and there are\n",
    "# lots of extra entries so we don't need to resize this matrix.\n",
    "from numpy.random import default_rng\n",
    "_MASTER_SEED = 42\n",
    "rng = default_rng(_MASTER_SEED)\n",
    "_MODEL_RANDOM_SEEDS = rng.integers(0, 1e6, size=(8, 8))\n",
    "\n",
    "# Create a Pandas categorical type for consistent encoding of categories\n",
    "# across all documents.\n",
    "_ENTITY_TYPES = [\"LOC\", \"MISC\", \"ORG\", \"PER\"]\n",
    "token_class_dtype, int_to_label, label_to_int = tp.io.conll.make_iob_tag_categories(_ENTITY_TYPES)\n",
    "\n",
    "# Parameters for splitting the corpus into folds\n",
    "_KFOLD_RANDOM_SEED = _MASTER_SEED\n",
    "_KFOLD_NUM_FOLDS = 10\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read inputs\n",
    "\n",
    "Read in the corpus, retokenize it with the BERT tokenizer, add BERT embeddings, and convert\n",
    "to a single dataframe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'train': 'outputs/eng.train',\n",
       " 'dev': 'outputs/eng.testa',\n",
       " 'test': 'outputs/eng.testb'}"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download and cache the data set.\n",
    "# NOTE: This data set is licensed for research use only. Be sure to adhere\n",
    "#  to the terms of the license when using this data set!\n",
    "data_set_info = tp.io.conll.maybe_download_conll_data(\"outputs\")\n",
    "data_set_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The raw dataset in its original tokenization\n",
    "corpus_raw = {}\n",
    "for fold_name, file_name in data_set_info.items():\n",
    "    # Read the POS, phrase, and entity columns; the last two use IOB tags.\n",
    "    df_list = tp.io.conll.conll_2003_to_dataframes(\n",
    "        file_name, [\"pos\", \"phrase\", \"ent\"], [False, True, True])\n",
    "    corpus_raw[fold_name] = [\n",
    "        df.drop(columns=[\"pos\", \"phrase_iob\", \"phrase_type\"])\n",
    "        for df in df_list\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertModel: ['classifier.weight', 'classifier.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing fold train\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e030cae2b3b4c898bfc57ba945740ab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=946, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (559 > 512). Running this sequence through the model will result in indexing errors\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing fold dev\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "55361e25bc62408082d864b813e9ab61",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=216, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "preprocessing fold test\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9753feeace94e65a24961292d4420db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=231, style=ProgressStyle(desc…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Retokenize with the BERT tokenizer and regenerate embeddings.\n",
    "corpus_df, token_class_dtype, int_to_label, label_to_int = (\n",
    "    cleaning.preprocess.preprocess_documents(\n",
    "        corpus_raw, 'ent_type', True,\n",
    "        carry_cols=['line_num'], iob_col='ent_iob'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare folds for a 10-fold cross-validation\n",
    "\n",
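    "We pool the documents from all three original folds of the corpus (train, dev, and test; 1393 documents in total) and divide them at random into 10 folds of roughly equal size."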
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1388</th>\n",
       "      <td>test</td>\n",
       "      <td>226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1389</th>\n",
       "      <td>test</td>\n",
       "      <td>227</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1390</th>\n",
       "      <td>test</td>\n",
       "      <td>228</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1391</th>\n",
       "      <td>test</td>\n",
       "      <td>229</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1392</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1393 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num\n",
       "0     train        0\n",
       "1     train        1\n",
       "2     train        2\n",
       "3     train        3\n",
       "4     train        4\n",
       "...     ...      ...\n",
       "1388   test      226\n",
       "1389   test      227\n",
       "1390   test      228\n",
       "1391   test      229\n",
       "1392   test      230\n",
       "\n",
       "[1393 rows x 2 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per document, identified by the composite key <fold, doc_num>.\n",
    "doc_keys = corpus_df[[\"fold\", \"doc_num\"]].drop_duplicates().reset_index(drop=True)\n",
    "doc_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>146</th>\n",
       "      <td>train</td>\n",
       "      <td>146</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1164</th>\n",
       "      <td>test</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>483</th>\n",
       "      <td>train</td>\n",
       "      <td>483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1190</th>\n",
       "      <td>test</td>\n",
       "      <td>28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>train</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>237</th>\n",
       "      <td>train</td>\n",
       "      <td>237</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>86</th>\n",
       "      <td>train</td>\n",
       "      <td>86</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>408</th>\n",
       "      <td>train</td>\n",
       "      <td>408</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1252</th>\n",
       "      <td>test</td>\n",
       "      <td>90</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1213</th>\n",
       "      <td>test</td>\n",
       "      <td>51</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num\n",
       "146   train      146\n",
       "1164   test        2\n",
       "483   train      483\n",
       "1190   test       28\n",
       "20    train       20\n",
       "237   train      237\n",
       "86    train       86\n",
       "408   train      408\n",
       "1252   test       90\n",
       "1213   test       51"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# We want to split the documents randomly into _KFOLD_NUM_FOLDS sets, then,\n",
    "# for each stage of cross-validation, train models on the union of\n",
    "# (_KFOLD_NUM_FOLDS - 1) of them while testing on the remaining fold.\n",
    "# We shuffle the documents with a seeded NumPy Generator, then let\n",
    "# sklearn.model_selection.KFold partition the shuffled order.\n",
    "rng = np.random.default_rng(seed=_KFOLD_RANDOM_SEED)\n",
    "iloc_order = rng.permutation(len(doc_keys.index))\n",
    "kf = sklearn.model_selection.KFold(n_splits=_KFOLD_NUM_FOLDS)\n",
    "\n",
    "train_keys = []\n",
    "test_keys = []\n",
    "for train_ix, test_ix in kf.split(iloc_order):\n",
    "    # sklearn.model_selection.KFold gives us a partitioning of the\n",
    "    # numbers from 0 to len(iloc_order) - 1. Use that partitioning to\n",
    "    # choose elements from iloc_order, then use those elements to \n",
    "    # index into doc_keys.\n",
    "    train_iloc = iloc_order[train_ix]\n",
    "    test_iloc = iloc_order[test_ix]\n",
    "    train_keys.append(doc_keys.iloc[train_iloc])\n",
    "    test_keys.append(doc_keys.iloc[test_iloc])\n",
    "\n",
    "train_keys[1].head(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dry run: Train and evaluate models on the first fold\n",
    "\n",
    "Train models on the first of our 10 folds and manually examine some of the \n",
    "model outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>token_id</th>\n",
       "      <th>span</th>\n",
       "      <th>input_id</th>\n",
       "      <th>token_type_id</th>\n",
       "      <th>attention_mask</th>\n",
       "      <th>special_tokens_mask</th>\n",
       "      <th>raw_span</th>\n",
       "      <th>line_num</th>\n",
       "      <th>raw_span_id</th>\n",
       "      <th>ent_iob</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>embedding</th>\n",
       "      <th>token_class</th>\n",
       "      <th>token_class_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0): ''</td>\n",
       "      <td>101</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.098505184,     -0.4050192,     0.7428884...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>[0, 1): '-'</td>\n",
       "      <td>118</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.057021223,    -0.48112097,      0.989868...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>[1, 2): 'D'</td>\n",
       "      <td>141</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.04824195,    -0.25330004,      1.167191...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>[2, 4): 'OC'</td>\n",
       "      <td>9244</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.26682988,    -0.31008753,      1.007472...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train</td>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>[4, 6): 'ST'</td>\n",
       "      <td>9272</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.22296889,    -0.21308492,     0.9331016...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>371472</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>314</td>\n",
       "      <td>[1386, 1393): 'brother'</td>\n",
       "      <td>1711</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[1386, 1393): 'brother'</td>\n",
       "      <td>50345.0</td>\n",
       "      <td>267.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.028172785,    -0.08062388,     0.9804888...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>371473</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>315</td>\n",
       "      <td>[1393, 1394): ','</td>\n",
       "      <td>117</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[1393, 1394): ','</td>\n",
       "      <td>50346.0</td>\n",
       "      <td>268.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[    0.11817408,    -0.07008513,      0.865484...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>371474</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>316</td>\n",
       "      <td>[1395, 1400): 'Bobby'</td>\n",
       "      <td>5545</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[1395, 1400): 'Bobby'</td>\n",
       "      <td>50347.0</td>\n",
       "      <td>269.0</td>\n",
       "      <td>B</td>\n",
       "      <td>PER</td>\n",
       "      <td>[   -0.35689482,     0.31400457,      1.573853...</td>\n",
       "      <td>B-PER</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>371475</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>317</td>\n",
       "      <td>[1400, 1401): '.'</td>\n",
       "      <td>119</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[1400, 1401): '.'</td>\n",
       "      <td>50348.0</td>\n",
       "      <td>270.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.18957126,    -0.24581163,       0.66257...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>371476</th>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>318</td>\n",
       "      <td>[0, 0): ''</td>\n",
       "      <td>102</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.44689128,    -0.31665266,      0.779688...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>371477 rows × 16 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         fold  doc_num  token_id                     span  input_id  \\\n",
       "0       train        0         0               [0, 0): ''       101   \n",
       "1       train        0         1              [0, 1): '-'       118   \n",
       "2       train        0         2              [1, 2): 'D'       141   \n",
       "3       train        0         3             [2, 4): 'OC'      9244   \n",
       "4       train        0         4             [4, 6): 'ST'      9272   \n",
       "...       ...      ...       ...                      ...       ...   \n",
       "371472   test      230       314  [1386, 1393): 'brother'      1711   \n",
       "371473   test      230       315        [1393, 1394): ','       117   \n",
       "371474   test      230       316    [1395, 1400): 'Bobby'      5545   \n",
       "371475   test      230       317        [1400, 1401): '.'       119   \n",
       "371476   test      230       318               [0, 0): ''       102   \n",
       "\n",
       "        token_type_id  attention_mask  special_tokens_mask  \\\n",
       "0                   0               1                 True   \n",
       "1                   0               1                False   \n",
       "2                   0               1                False   \n",
       "3                   0               1                False   \n",
       "4                   0               1                False   \n",
       "...               ...             ...                  ...   \n",
       "371472              0               1                False   \n",
       "371473              0               1                False   \n",
       "371474              0               1                False   \n",
       "371475              0               1                False   \n",
       "371476              0               1                 True   \n",
       "\n",
       "                       raw_span  line_num  raw_span_id ent_iob ent_type  \\\n",
       "0                           NaN       NaN          NaN       O     <NA>   \n",
       "1         [0, 10): '-DOCSTART-'       0.0          0.0       O     <NA>   \n",
       "2         [0, 10): '-DOCSTART-'       0.0          0.0       O     <NA>   \n",
       "3         [0, 10): '-DOCSTART-'       0.0          0.0       O     <NA>   \n",
       "4         [0, 10): '-DOCSTART-'       0.0          0.0       O     <NA>   \n",
       "...                         ...       ...          ...     ...      ...   \n",
       "371472  [1386, 1393): 'brother'   50345.0        267.0       O     <NA>   \n",
       "371473        [1393, 1394): ','   50346.0        268.0       O     <NA>   \n",
       "371474    [1395, 1400): 'Bobby'   50347.0        269.0       B      PER   \n",
       "371475        [1400, 1401): '.'   50348.0        270.0       O     <NA>   \n",
       "371476                      NaN       NaN          NaN       O     <NA>   \n",
       "\n",
       "                                                embedding token_class  \\\n",
       "0       [  -0.098505184,     -0.4050192,     0.7428884...           O   \n",
       "1       [  -0.057021223,    -0.48112097,      0.989868...           O   \n",
       "2       [   -0.04824195,    -0.25330004,      1.167191...           O   \n",
       "3       [   -0.26682988,    -0.31008753,      1.007472...           O   \n",
       "4       [   -0.22296889,    -0.21308492,     0.9331016...           O   \n",
       "...                                                   ...         ...   \n",
       "371472  [  -0.028172785,    -0.08062388,     0.9804888...           O   \n",
       "371473  [    0.11817408,    -0.07008513,      0.865484...           O   \n",
       "371474  [   -0.35689482,     0.31400457,      1.573853...       B-PER   \n",
       "371475  [   -0.18957126,    -0.24581163,       0.66257...           O   \n",
       "371476  [   -0.44689128,    -0.31665266,      0.779688...           O   \n",
       "\n",
       "        token_class_id  \n",
       "0                    0  \n",
       "1                    0  \n",
       "2                    0  \n",
       "3                    0  \n",
       "4                    0  \n",
       "...                ...  \n",
       "371472               0  \n",
       "371473               0  \n",
       "371474               3  \n",
       "371475               0  \n",
       "371476               0  \n",
       "\n",
       "[371477 rows x 16 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assemble the training set for the first fold by joining its list of\n",
    "# documents with the entire corpus on the composite key <fold, doc_num>.\n",
    "train_inputs_df = corpus_df.merge(train_keys[0], on=[\"fold\", \"doc_num\"])\n",
    "train_inputs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>token_id</th>\n",
       "      <th>span</th>\n",
       "      <th>input_id</th>\n",
       "      <th>token_type_id</th>\n",
       "      <th>attention_mask</th>\n",
       "      <th>special_tokens_mask</th>\n",
       "      <th>raw_span</th>\n",
       "      <th>line_num</th>\n",
       "      <th>raw_span_id</th>\n",
       "      <th>ent_iob</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>embedding</th>\n",
       "      <th>token_class</th>\n",
       "      <th>token_class_id</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "      <td>0</td>\n",
       "      <td>[0, 0): ''</td>\n",
       "      <td>101</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.101977676,    -0.42442498,     0.8440171...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "      <td>1</td>\n",
       "      <td>[0, 1): '-'</td>\n",
       "      <td>118</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>2664.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.09124618,    -0.47710702,      1.120292...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "      <td>2</td>\n",
       "      <td>[1, 2): 'D'</td>\n",
       "      <td>141</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>2664.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[    -0.1695277,    -0.27063507,      1.209566...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "      <td>3</td>\n",
       "      <td>[2, 4): 'OC'</td>\n",
       "      <td>9244</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>2664.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.27648172,     -0.3675844,      1.092024...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "      <td>4</td>\n",
       "      <td>[4, 6): 'ST'</td>\n",
       "      <td>9272</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>2664.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.24050614,    -0.24247544,       1.07511...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45059</th>\n",
       "      <td>test</td>\n",
       "      <td>225</td>\n",
       "      <td>75</td>\n",
       "      <td>[208, 213): 'fight'</td>\n",
       "      <td>2147</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[208, 213): 'fight'</td>\n",
       "      <td>49418.0</td>\n",
       "      <td>29.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[   -0.09621397,    -0.48016888,      0.510937...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45060</th>\n",
       "      <td>test</td>\n",
       "      <td>225</td>\n",
       "      <td>76</td>\n",
       "      <td>[214, 216): 'on'</td>\n",
       "      <td>1113</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[214, 216): 'on'</td>\n",
       "      <td>49419.0</td>\n",
       "      <td>30.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[    -0.0858628,     -0.2341724,      0.832928...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45061</th>\n",
       "      <td>test</td>\n",
       "      <td>225</td>\n",
       "      <td>77</td>\n",
       "      <td>[217, 225): 'Saturday'</td>\n",
       "      <td>4306</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[217, 225): 'Saturday'</td>\n",
       "      <td>49420.0</td>\n",
       "      <td>31.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.012238501,     -0.4282664,      0.619483...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45062</th>\n",
       "      <td>test</td>\n",
       "      <td>225</td>\n",
       "      <td>78</td>\n",
       "      <td>[225, 226): '.'</td>\n",
       "      <td>119</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>False</td>\n",
       "      <td>[225, 226): '.'</td>\n",
       "      <td>49421.0</td>\n",
       "      <td>32.0</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[  -0.042955935,    -0.36315423,      0.660203...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>45063</th>\n",
       "      <td>test</td>\n",
       "      <td>225</td>\n",
       "      <td>79</td>\n",
       "      <td>[0, 0): ''</td>\n",
       "      <td>102</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>True</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>O</td>\n",
       "      <td>&lt;NA&gt;</td>\n",
       "      <td>[    -0.9504192,    0.012983555,     0.7374987...</td>\n",
       "      <td>O</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>45064 rows × 16 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        fold  doc_num  token_id                    span  input_id  \\\n",
       "0      train       12         0              [0, 0): ''       101   \n",
       "1      train       12         1             [0, 1): '-'       118   \n",
       "2      train       12         2             [1, 2): 'D'       141   \n",
       "3      train       12         3            [2, 4): 'OC'      9244   \n",
       "4      train       12         4            [4, 6): 'ST'      9272   \n",
       "...      ...      ...       ...                     ...       ...   \n",
       "45059   test      225        75     [208, 213): 'fight'      2147   \n",
       "45060   test      225        76        [214, 216): 'on'      1113   \n",
       "45061   test      225        77  [217, 225): 'Saturday'      4306   \n",
       "45062   test      225        78         [225, 226): '.'       119   \n",
       "45063   test      225        79              [0, 0): ''       102   \n",
       "\n",
       "       token_type_id  attention_mask  special_tokens_mask  \\\n",
       "0                  0               1                 True   \n",
       "1                  0               1                False   \n",
       "2                  0               1                False   \n",
       "3                  0               1                False   \n",
       "4                  0               1                False   \n",
       "...              ...             ...                  ...   \n",
       "45059              0               1                False   \n",
       "45060              0               1                False   \n",
       "45061              0               1                False   \n",
       "45062              0               1                False   \n",
       "45063              0               1                 True   \n",
       "\n",
       "                     raw_span  line_num  raw_span_id ent_iob ent_type  \\\n",
       "0                         NaN       NaN          NaN       O     <NA>   \n",
       "1       [0, 10): '-DOCSTART-'    2664.0          0.0       O     <NA>   \n",
       "2       [0, 10): '-DOCSTART-'    2664.0          0.0       O     <NA>   \n",
       "3       [0, 10): '-DOCSTART-'    2664.0          0.0       O     <NA>   \n",
       "4       [0, 10): '-DOCSTART-'    2664.0          0.0       O     <NA>   \n",
       "...                       ...       ...          ...     ...      ...   \n",
       "45059     [208, 213): 'fight'   49418.0         29.0       O     <NA>   \n",
       "45060        [214, 216): 'on'   49419.0         30.0       O     <NA>   \n",
       "45061  [217, 225): 'Saturday'   49420.0         31.0       O     <NA>   \n",
       "45062         [225, 226): '.'   49421.0         32.0       O     <NA>   \n",
       "45063                     NaN       NaN          NaN       O     <NA>   \n",
       "\n",
       "                                               embedding token_class  \\\n",
       "0      [  -0.101977676,    -0.42442498,     0.8440171...           O   \n",
       "1      [   -0.09124618,    -0.47710702,      1.120292...           O   \n",
       "2      [    -0.1695277,    -0.27063507,      1.209566...           O   \n",
       "3      [   -0.27648172,     -0.3675844,      1.092024...           O   \n",
       "4      [   -0.24050614,    -0.24247544,       1.07511...           O   \n",
       "...                                                  ...         ...   \n",
       "45059  [   -0.09621397,    -0.48016888,      0.510937...           O   \n",
       "45060  [    -0.0858628,     -0.2341724,      0.832928...           O   \n",
       "45061  [  -0.012238501,     -0.4282664,      0.619483...           O   \n",
       "45062  [  -0.042955935,    -0.36315423,      0.660203...           O   \n",
       "45063  [    -0.9504192,    0.012983555,     0.7374987...           O   \n",
       "\n",
       "       token_class_id  \n",
       "0                   0  \n",
       "1                   0  \n",
       "2                   0  \n",
       "3                   0  \n",
       "4                   0  \n",
       "...               ...  \n",
       "45059               0  \n",
       "45060               0  \n",
       "45061               0  \n",
       "45062               0  \n",
       "45063               0  \n",
       "\n",
       "[45064 rows x 16 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeat the same process for the test set\n",
    "test_inputs_df = corpus_df.merge(test_keys[0])\n",
    "test_inputs_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train an ensemble of models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2021-07-12 18:16:33,117\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n",
      "Model names after loading or training: 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64_2, 64_3, 64_4, 128_1, 128_2, 128_3, 128_4, 256_1, 256_2, 256_3, 256_4\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import sklearn.linear_model\n",
    "import ray\n",
    "ray.init()\n",
    "\n",
    "# Wrap train_reduced_model in a Ray task\n",
    "@ray.remote\n",
    "def train_reduced_model_task(\n",
    "        x_values: np.ndarray, y_values: np.ndarray, n_components: int,\n",
    "        seed: int, max_iter: int = 10000) -> sklearn.base.BaseEstimator:\n",
    "    return cleaning.ensemble.train_reduced_model(x_values, y_values, n_components, seed, max_iter)\n",
    "\n",
    "# Ray task that trains a model using the entire embedding\n",
    "@ray.remote\n",
    "def train_full_model_task(x_values: np.ndarray, y_values: np.ndarray, \n",
    "                          max_iter: int = 10000) -> sklearn.base.BaseEstimator:\n",
    "    return (\n",
    "        sklearn.linear_model.LogisticRegression(\n",
    "            multi_class=\"multinomial\", max_iter=max_iter\n",
    "        )\n",
    "        .fit(x_values, y_values)\n",
    "    )\n",
    "\n",
    "def train_models(train_df: pd.DataFrame) \\\n",
    "        -> Dict[str, sklearn.base.BaseEstimator]:\n",
    "    \"\"\"\n",
    "    Train an ensemble of models with different levels of noise.\n",
    "    \n",
    "    :param train_df: DataFrame of labeled training documents, with one\n",
    "     row per token. Must contain the columns \"embedding\" (precomputed \n",
    "     BERT embeddings) and \"token_class_id\" (integer ID of token type)\n",
    "     \n",
    "    :returns: A mapping from mnemonic model name to trained model\n",
    "    \"\"\"\n",
    "    X = train_df[\"embedding\"].values\n",
    "    Y = train_df[\"token_class_id\"]\n",
    "    \n",
    "\n",
    "    # Push the X and Y values to Plasma so that our tasks can share them.\n",
    "    X_id = ray.put(X.to_numpy().copy())\n",
    "    Y_id = ray.put(Y.to_numpy().copy())\n",
    "    \n",
    "    names_list = []\n",
    "    futures_list = []\n",
    "    \n",
    "    print(f\"Training model using all of \"\n",
    "          f\"{X._tensor.shape[1]}-dimension embeddings.\")\n",
    "    names_list.append(f\"{X._tensor.shape[1]}_1\")\n",
    "    futures_list.append(train_full_model_task.remote(X_id, Y_id))   \n",
    "    \n",
    "    for i in range(len(_REDUCED_DIMS)):\n",
    "        num_dims = _REDUCED_DIMS[i]\n",
    "        num_models = _MODELS_AT_DIM[i]\n",
    "        for j in range(num_models):\n",
    "            model_name = f\"{num_dims}_{j + 1}\"\n",
    "            seed = _MODEL_RANDOM_SEEDS[i, j]\n",
    "            print(f\"Training model '{model_name}' (#{j + 1} \"\n",
    "                  f\"at {num_dims} dimensions) with seed {seed}\")\n",
    "            names_list.append(model_name)\n",
    "            futures_list.append(train_reduced_model_task.remote(X_id, Y_id, \n",
    "                                                                num_dims, seed))\n",
    "    \n",
    "    # Block until all training tasks have completed and fetch the resulting models.\n",
    "    models_list = ray.get(futures_list)\n",
    "    models = {\n",
    "        n: m for n, m in zip(names_list, models_list)\n",
    "    }\n",
    "    return models\n",
    "\n",
    "def maybe_train_models(train_df: pd.DataFrame, fold_num: int):\n",
    "    import pickle\n",
    "    _CACHED_MODELS_FILE = f\"outputs/fold_{fold_num}_models.pickle\"\n",
    "    if _REGENERATE_MODELS or not os.path.exists(_CACHED_MODELS_FILE):\n",
    "        m = train_models(train_df)\n",
    "        print(f\"Trained {len(m)} models.\")\n",
    "        with open(_CACHED_MODELS_FILE, \"wb\") as f:\n",
    "            pickle.dump(m, f)\n",
    "    else:\n",
    "        # Use a cached model when using cached embeddings\n",
    "        with open(_CACHED_MODELS_FILE, \"rb\") as f:\n",
    "            m = pickle.load(f)\n",
    "            print(f\"Loaded {len(m)} models from {_CACHED_MODELS_FILE}.\")\n",
    "    return m\n",
    "\n",
    "models = maybe_train_models(train_inputs_df, 0)\n",
    "print(f\"Model names after loading or training: {', '.join(models.keys())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment this code if you need to have the cells that follow ignore\n",
    "# some of the models saved to disk.\n",
    "# _MODEL_SIZES_TO_KEEP = [32, 64, 128, 256]\n",
    "# _RUNS_TO_KEEP = [4] * len(_MODEL_SIZES_TO_KEEP)\n",
    "# _OTHER_MODELS_TO_KEEP = [\"768_1\"]\n",
    "\n",
    "# to_keep = _OTHER_MODELS_TO_KEEP.copy()\n",
    "# for size in _MODEL_SIZES_TO_KEEP:\n",
    "#     for num_runs in _RUNS_TO_KEEP:\n",
    "#         for i in range(num_runs):\n",
    "#             to_keep.append(f\"{size}_{i+1}\")\n",
    "\n",
    "# models = {k: v for k, v in models.items() if k in to_keep}\n",
    "\n",
    "# print(f\"Model names after filtering: {', '.join(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the models on this fold's test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[11, 16): 'Saudi'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[59, 65): 'MANAMA'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[86, 91): 'Saudi'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[259, 264): 'Saudi'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>train</td>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[55, 65): 'MONTGOMERY'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>train</td>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                     span ent_type   fold  doc_num\n",
       "0       [11, 16): 'Saudi'     MISC  train       12\n",
       "1      [59, 65): 'MANAMA'      LOC  train       12\n",
       "2       [86, 91): 'Saudi'     MISC  train       12\n",
       "3     [259, 264): 'Saudi'     MISC  train       12\n",
       "0  [55, 65): 'MONTGOMERY'      LOC  train       20"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def eval_models(models: Dict[str, sklearn.base.BaseEstimator],\n",
    "                test_df: pd.DataFrame):\n",
    "    \"\"\"\n",
    "    Bulk-evaluate an ensemble of models generated by :func:`train_models`.\n",
    "    \n",
    "    :param models: Output of :func:`train_models`\n",
    "    :param test_df:  DataFrame of labeled test documents, with one\n",
    "     row per token. Must contain the columns \"embedding\" (precomputed \n",
    "     BERT embeddings) and \"token_class_id\" (integer ID of token type)\n",
    "    \n",
    "    :returns: A dictionary from model name to results of \n",
    "     :func:`util.analyze_model`\n",
    "    \"\"\"\n",
    "    todo = [(name, model) for name, model in models.items()]\n",
    "    results = tp.jupyter.run_with_progress_bar(\n",
    "        len(todo),\n",
    "        lambda i: cleaning.infer_and_extract_entities_iob(test_df,corpus_raw, int_to_label, todo[i][1]),\n",
    "        \"model\"\n",
    "    )\n",
    "    return {t[0]: result for t, result in zip(todo, results)}\n",
    "\n",
    "evals = eval_models(models, test_inputs_df)\n",
    "# display one of the results\n",
    "evals[list(evals.keys())[0]].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>precision</th>\n",
       "      <th>recall</th>\n",
       "      <th>f1-score</th>\n",
       "      <th>dims</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>768_1</th>\n",
       "      <td>0.947149</td>\n",
       "      <td>0.938839</td>\n",
       "      <td>0.942976</td>\n",
       "      <td>768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32_1</th>\n",
       "      <td>0.924075</td>\n",
       "      <td>0.863742</td>\n",
       "      <td>0.892890</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32_2</th>\n",
       "      <td>0.924755</td>\n",
       "      <td>0.875355</td>\n",
       "      <td>0.899377</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32_3</th>\n",
       "      <td>0.925028</td>\n",
       "      <td>0.866065</td>\n",
       "      <td>0.894576</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32_4</th>\n",
       "      <td>0.932949</td>\n",
       "      <td>0.876129</td>\n",
       "      <td>0.903647</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64_1</th>\n",
       "      <td>0.940086</td>\n",
       "      <td>0.902968</td>\n",
       "      <td>0.921153</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64_2</th>\n",
       "      <td>0.938321</td>\n",
       "      <td>0.902968</td>\n",
       "      <td>0.920305</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64_3</th>\n",
       "      <td>0.936808</td>\n",
       "      <td>0.895226</td>\n",
       "      <td>0.915545</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64_4</th>\n",
       "      <td>0.940828</td>\n",
       "      <td>0.902710</td>\n",
       "      <td>0.921375</td>\n",
       "      <td>64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128_1</th>\n",
       "      <td>0.944401</td>\n",
       "      <td>0.924903</td>\n",
       "      <td>0.934550</td>\n",
       "      <td>128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128_2</th>\n",
       "      <td>0.947577</td>\n",
       "      <td>0.923613</td>\n",
       "      <td>0.935442</td>\n",
       "      <td>128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128_3</th>\n",
       "      <td>0.943212</td>\n",
       "      <td>0.921548</td>\n",
       "      <td>0.932254</td>\n",
       "      <td>128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>128_4</th>\n",
       "      <td>0.940991</td>\n",
       "      <td>0.921806</td>\n",
       "      <td>0.931300</td>\n",
       "      <td>128</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256_1</th>\n",
       "      <td>0.949201</td>\n",
       "      <td>0.935484</td>\n",
       "      <td>0.942293</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256_2</th>\n",
       "      <td>0.943396</td>\n",
       "      <td>0.929032</td>\n",
       "      <td>0.936159</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256_3</th>\n",
       "      <td>0.945478</td>\n",
       "      <td>0.930839</td>\n",
       "      <td>0.938101</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256_4</th>\n",
       "      <td>0.945055</td>\n",
       "      <td>0.932129</td>\n",
       "      <td>0.938547</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       precision    recall  f1-score  dims\n",
       "768_1   0.947149  0.938839  0.942976   768\n",
       "32_1    0.924075  0.863742  0.892890    32\n",
       "32_2    0.924755  0.875355  0.899377    32\n",
       "32_3    0.925028  0.866065  0.894576    32\n",
       "32_4    0.932949  0.876129  0.903647    32\n",
       "64_1    0.940086  0.902968  0.921153    64\n",
       "64_2    0.938321  0.902968  0.920305    64\n",
       "64_3    0.936808  0.895226  0.915545    64\n",
       "64_4    0.940828  0.902710  0.921375    64\n",
       "128_1   0.944401  0.924903  0.934550   128\n",
       "128_2   0.947577  0.923613  0.935442   128\n",
       "128_3   0.943212  0.921548  0.932254   128\n",
       "128_4   0.940991  0.921806  0.931300   128\n",
       "256_1   0.949201  0.935484  0.942293   256\n",
       "256_2   0.943396  0.929032  0.936159   256\n",
       "256_3   0.945478  0.930839  0.938101   256\n",
       "256_4   0.945055  0.932129  0.938547   256"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarize how each of the models does on the test set.\n",
    "gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(corpus_raw,evals[list(evals.keys())[0]],label_col = 'ent_type')\n",
    "def make_summary_df(evals_df: pd.DataFrame) -> pd.DataFrame:\n",
    "    gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(corpus_raw, evals['256_4'], label_col = 'ent_type')\n",
    "    summary_df= cleaning.analysis.create_f1_report_ensemble_iob(evals,gold_elts)\n",
    "    summary_df['dims'] = [int(name.split('_')[0]) for name in evals.keys()]\n",
    "    return summary_df\n",
    "\n",
    "summary_df = make_summary_df(evals)\n",
    "summary_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "\n",
      "text/plain": [
       "<Figure size 288x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the tradeoff between dimensionality and F1 score\n",
    "x = summary_df[\"dims\"]\n",
    "y = summary_df[\"f1-score\"]\n",
    "\n",
    "plt.figure(figsize=(4,4))\n",
    "plt.scatter(x, y)\n",
    "#plt.yscale(\"log\")\n",
    "#plt.xscale(\"log\")\n",
    "plt.xlabel(\"Number of Dimensions\")\n",
    "plt.ylabel(\"F1 Score\")\n",
    "\n",
    "# Also dump the raw data to a local file.\n",
    "pd.DataFrame({\"num_dims\": x, \"f1_score\": y}).to_csv(\"outputs/dims_vs_f1_score_xval.csv\",\n",
    "                                                    index=False)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aggregate the model results and compare with the gold standard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>in_gold</th>\n",
       "      <th>count</th>\n",
       "      <th>models</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4927</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[590, 598): 'Gorleben'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "      <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4925</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[63, 67): 'BONN'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "      <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4924</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[11, 17): 'German'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "      <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4923</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[523, 528): 'China'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "      <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4922</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[512, 518): 'Mexico'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "      <td>[GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>374</th>\n",
       "      <td>dev</td>\n",
       "      <td>149</td>\n",
       "      <td>[81, 93): 'Major League'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "      <td>[GOLD]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>246</th>\n",
       "      <td>dev</td>\n",
       "      <td>120</td>\n",
       "      <td>[63, 70): 'English'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "      <td>[GOLD]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>dev</td>\n",
       "      <td>64</td>\n",
       "      <td>[2571, 2575): 'AIDS'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "      <td>[GOLD]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[86, 90): 'UEFA'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "      <td>[GOLD]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[25, 39): 'STANDARD LIEGE'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "      <td>[GOLD]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4928 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num                        span ent_type  in_gold  count  \\\n",
       "4927  train      907      [590, 598): 'Gorleben'      LOC     True     17   \n",
       "4925  train      907            [63, 67): 'BONN'      LOC     True     17   \n",
       "4924  train      907          [11, 17): 'German'     MISC     True     17   \n",
       "4923  train      896         [523, 528): 'China'      LOC     True     17   \n",
       "4922  train      896        [512, 518): 'Mexico'      LOC     True     17   \n",
       "...     ...      ...                         ...      ...      ...    ...   \n",
       "374     dev      149    [81, 93): 'Major League'     MISC     True      0   \n",
       "246     dev      120         [63, 70): 'English'     MISC     True      0   \n",
       "78      dev       64        [2571, 2575): 'AIDS'     MISC     True      0   \n",
       "3       dev       21            [86, 90): 'UEFA'      ORG     True      0   \n",
       "0       dev       21  [25, 39): 'STANDARD LIEGE'      ORG     True      0   \n",
       "\n",
       "                                                 models  \n",
       "4927  [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...  \n",
       "4925  [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...  \n",
       "4924  [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...  \n",
       "4923  [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...  \n",
       "4922  [GOLD, 768_1, 32_1, 32_2, 32_3, 32_4, 64_1, 64...  \n",
       "...                                                 ...  \n",
       "374                                              [GOLD]  \n",
       "246                                              [GOLD]  \n",
       "78                                               [GOLD]  \n",
       "3                                                [GOLD]  \n",
       "0                                                [GOLD]  \n",
       "\n",
       "[4928 rows x 7 columns]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
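    "# Align every model's predicted entities with the gold standard and count how many models produced each (span, entity type) pair\n",
    "full_results = cleaning.flag_suspicious_labels(\n",
    "    evals, 'ent_type', 'ent_type', label_name='ent_type', gold_feats=gold_elts,\n",
    "    align_over_cols=['fold', 'doc_num', 'span'], keep_cols=[], split_doc=False)\n",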
    "full_results"
   ]
  },
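  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the output above, `count` is the number of models (out of 17) that produced a given span with a given entity type, and `in_gold` says whether that (span, type) pair appears in the gold standard; gold-standard entities that no model found get a count of 0. The sketch below approximates this bookkeeping with plain pandas. It is our own illustration, not how `cleaning.flag_suspicious_labels` is implemented: it assumes each model's output is a DataFrame with `fold`, `doc_num`, `span`, and `ent_type` columns, represents spans as plain strings, and uses a hypothetical helper name, `count_model_votes`.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "KEYS = ['fold', 'doc_num', 'span', 'ent_type']\n",
    "\n",
    "def count_model_votes(model_outputs, gold):\n",
    "    # Stack every model's predictions, tagging each row with the model's name\n",
    "    preds = pd.concat(\n",
    "        [df[KEYS].assign(model=name) for name, df in model_outputs.items()],\n",
    "        ignore_index=True)\n",
    "    # Number of distinct models that produced each (span, ent_type) pair\n",
    "    counts = preds.groupby(KEYS)['model'].nunique().rename('count').reset_index()\n",
    "    # Outer join keeps gold entities that no model found; indicator records the source\n",
    "    merged = counts.merge(gold[KEYS].drop_duplicates(), on=KEYS,\n",
    "                          how='outer', indicator=True)\n",
    "    merged['in_gold'] = merged['_merge'] != 'left_only'\n",
    "    merged['count'] = merged['count'].fillna(0).astype(int)\n",
    "    return (merged.drop(columns='_merge')\n",
    "            .sort_values('count', ascending=False, kind='stable'))\n",
    "\n",
    "gold = pd.DataFrame({'fold': ['dev', 'dev'], 'doc_num': [21, 21],\n",
    "                     'span': ['UEFA', 'Liege'], 'ent_type': ['ORG', 'LOC']})\n",
    "model_outputs = {\n",
    "    '32_1': pd.DataFrame({'fold': ['dev'], 'doc_num': [21],\n",
    "                          'span': ['Liege'], 'ent_type': ['LOC']}),\n",
    "    '64_1': pd.DataFrame({'fold': ['dev', 'dev'], 'doc_num': [21, 21],\n",
    "                          'span': ['Liege', 'Standard'], 'ent_type': ['LOC', 'ORG']}),\n",
    "}\n",
    "votes = count_model_votes(model_outputs, gold)\n",
    "print(votes)\n",
    "```\n",
    "\n",
    "In this toy example, `Liege` is found by both models and is in the gold standard (count 2), `Standard` is predicted by one model but is not in the gold standard (count 1, `in_gold` False), and `UEFA` is a gold entity that neither model found (count 0)."
   ]
  },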
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>in_gold</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4927</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[590, 598): 'Gorleben'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4925</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[63, 67): 'BONN'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4924</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[11, 17): 'German'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4923</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[523, 528): 'China'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4922</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[512, 518): 'Mexico'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>374</th>\n",
       "      <td>dev</td>\n",
       "      <td>149</td>\n",
       "      <td>[81, 93): 'Major League'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>246</th>\n",
       "      <td>dev</td>\n",
       "      <td>120</td>\n",
       "      <td>[63, 70): 'English'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>dev</td>\n",
       "      <td>64</td>\n",
       "      <td>[2571, 2575): 'AIDS'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[86, 90): 'UEFA'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[25, 39): 'STANDARD LIEGE'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4928 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num                        span ent_type  in_gold  count\n",
       "4927  train      907      [590, 598): 'Gorleben'      LOC     True     17\n",
       "4925  train      907            [63, 67): 'BONN'      LOC     True     17\n",
       "4924  train      907          [11, 17): 'German'     MISC     True     17\n",
       "4923  train      896         [523, 528): 'China'      LOC     True     17\n",
       "4922  train      896        [512, 518): 'Mexico'      LOC     True     17\n",
       "...     ...      ...                         ...      ...      ...    ...\n",
       "374     dev      149    [81, 93): 'Major League'     MISC     True      0\n",
       "246     dev      120         [63, 70): 'English'     MISC     True      0\n",
       "78      dev       64        [2571, 2575): 'AIDS'     MISC     True      0\n",
       "3       dev       21            [86, 90): 'UEFA'      ORG     True      0\n",
       "0       dev       21  [25, 39): 'STANDARD LIEGE'      ORG     True      0\n",
       "\n",
       "[4928 rows x 6 columns]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Drop the list-valued `models` column for now and keep the per-entity summary columns\n",
    "results = full_results[[\"fold\", \"doc_num\", \"span\", \"ent_type\", \"in_gold\", \"count\"]]\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>num_ents</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>18</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>23</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>19</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>29</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>28</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>41</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>48</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>75</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>115</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>248</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>2940</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       num_ents\n",
       "count          \n",
       "0           115\n",
       "1            31\n",
       "2            23\n",
       "3            20\n",
       "4            17\n",
       "5            18\n",
       "6            23\n",
       "7            23\n",
       "8            19\n",
       "9            29\n",
       "10           28\n",
       "11           41\n",
       "12           48\n",
       "13           62\n",
       "14           75\n",
       "15          115\n",
       "16          248\n",
       "17         2940"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# How many gold-standard entities were found by exactly `count` of the 17 models?\n",
    "(results[results[\"in_gold\"] == True][[\"count\", \"span\"]]\n",
    " .groupby(\"count\").count()\n",
    " .rename(columns={\"span\": \"num_ents\"}))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>num_ents</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>468</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>174</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>94</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>61</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>52</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>26</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>36</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>9</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>14</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>15</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>31</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       num_ents\n",
       "count          \n",
       "1           468\n",
       "2           174\n",
       "3            94\n",
       "4            61\n",
       "5            52\n",
       "6            26\n",
       "7            36\n",
       "8            16\n",
       "9            17\n",
       "10           12\n",
       "11            9\n",
       "12            9\n",
       "13            8\n",
       "14           11\n",
       "15           14\n",
       "16           15\n",
       "17           31"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same breakdown for entities that the models produced but that are not in the gold standard\n",
    "(results[results[\"in_gold\"] == False][[\"count\", \"span\"]]\n",
    " .groupby(\"count\").count()\n",
    " .rename(columns={\"span\": \"num_ents\"}))"
   ]
  },
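  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two distributions above point at where labeling errors are likely to hide: gold-standard entities that few or no models reproduced (115 gold entities were found by none of the 17 models), and entities outside the gold standard that most models agree on (31 were produced by all 17). A minimal sketch of turning a `results`-style DataFrame into two ranked review lists follows. The helper name `rank_suspicious` and its cutoffs are hypothetical choices for illustration, not values from our paper or functions from the `cleaning` package.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def rank_suspicious(results, max_gold_count=2, min_pred_count=15):\n",
    "    in_gold = results['in_gold'].astype(bool)\n",
    "    # Gold-standard entities that few models reproduced: candidate incorrect labels\n",
    "    gold_ranked = (results[in_gold & (results['count'] <= max_gold_count)]\n",
    "                   .sort_values('count', kind='stable'))\n",
    "    # Non-gold entities that most models agree on: candidate missing or mistyped labels\n",
    "    pred_ranked = (results[~in_gold & (results['count'] >= min_pred_count)]\n",
    "                   .sort_values('count', ascending=False, kind='stable'))\n",
    "    return gold_ranked, pred_ranked\n",
    "\n",
    "# Toy values for illustration only\n",
    "toy_results = pd.DataFrame({\n",
    "    'span': ['A', 'B', 'C', 'D', 'E', 'F'],\n",
    "    'in_gold': [True, True, True, False, False, False],\n",
    "    'count': [0, 17, 2, 17, 3, 15],\n",
    "})\n",
    "gold_ranked, pred_ranked = rank_suspicious(toy_results)\n",
    "print(gold_ranked)\n",
    "print(pred_ranked)\n",
    "```\n",
    "\n",
    "Each list could then be saved for manual review with `DataFrame.to_csv`, for example `gold_ranked.to_csv('gold_suspects.csv', index=False)`; the file name is hypothetical."
   ]
  },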
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>in_gold</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[86, 90): 'UEFA'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>dev</td>\n",
       "      <td>21</td>\n",
       "      <td>[25, 39): 'STANDARD LIEGE'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>78</th>\n",
       "      <td>dev</td>\n",
       "      <td>64</td>\n",
       "      <td>[2571, 2575): 'AIDS'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>246</th>\n",
       "      <td>dev</td>\n",
       "      <td>120</td>\n",
       "      <td>[63, 70): 'English'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>374</th>\n",
       "      <td>dev</td>\n",
       "      <td>149</td>\n",
       "      <td>[81, 93): 'Major League'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>dev</td>\n",
       "      <td>182</td>\n",
       "      <td>[2173, 2177): 'Ruch'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>462</th>\n",
       "      <td>dev</td>\n",
       "      <td>182</td>\n",
       "      <td>[662, 670): 'division'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>512</th>\n",
       "      <td>dev</td>\n",
       "      <td>203</td>\n",
       "      <td>[879, 881): '90'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>622</th>\n",
       "      <td>dev</td>\n",
       "      <td>214</td>\n",
       "      <td>[1689, 1705): 'Schindler's List'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>621</th>\n",
       "      <td>dev</td>\n",
       "      <td>214</td>\n",
       "      <td>[1643, 1648): 'Oscar'</td>\n",
       "      <td>PER</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>583</th>\n",
       "      <td>dev</td>\n",
       "      <td>214</td>\n",
       "      <td>[285, 305): 'Venice Film Festival'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>569</th>\n",
       "      <td>dev</td>\n",
       "      <td>214</td>\n",
       "      <td>[187, 202): 'Michael Collins'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>802</th>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>[44, 56): 'WORLD SERIES'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>801</th>\n",
       "      <td>test</td>\n",
       "      <td>15</td>\n",
       "      <td>[32, 43): 'WEST INDIES'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>942</th>\n",
       "      <td>test</td>\n",
       "      <td>21</td>\n",
       "      <td>[719, 725): 'Wijaya'</td>\n",
       "      <td>PER</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>896</th>\n",
       "      <td>test</td>\n",
       "      <td>21</td>\n",
       "      <td>[22, 38): 'WORLD GRAND PRIX'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1057</th>\n",
       "      <td>test</td>\n",
       "      <td>23</td>\n",
       "      <td>[1117, 1127): 'NY RANGERS'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1052</th>\n",
       "      <td>test</td>\n",
       "      <td>23</td>\n",
       "      <td>[1106, 1113): 'TORONTO'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1025</th>\n",
       "      <td>test</td>\n",
       "      <td>23</td>\n",
       "      <td>[673, 689): 'CENTRAL DIVISION'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1016</th>\n",
       "      <td>test</td>\n",
       "      <td>23</td>\n",
       "      <td>[599, 611): 'NY ISLANDERS'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      fold  doc_num                                span ent_type  in_gold  \\\n",
       "3      dev       21                    [86, 90): 'UEFA'      ORG     True   \n",
       "0      dev       21          [25, 39): 'STANDARD LIEGE'      ORG     True   \n",
       "78     dev       64                [2571, 2575): 'AIDS'     MISC     True   \n",
       "246    dev      120                 [63, 70): 'English'     MISC     True   \n",
       "374    dev      149            [81, 93): 'Major League'     MISC     True   \n",
       "498    dev      182                [2173, 2177): 'Ruch'      ORG     True   \n",
       "462    dev      182              [662, 670): 'division'     MISC     True   \n",
       "512    dev      203                    [879, 881): '90'      LOC     True   \n",
       "622    dev      214    [1689, 1705): 'Schindler's List'     MISC     True   \n",
       "621    dev      214               [1643, 1648): 'Oscar'      PER     True   \n",
       "583    dev      214  [285, 305): 'Venice Film Festival'     MISC     True   \n",
       "569    dev      214       [187, 202): 'Michael Collins'     MISC     True   \n",
       "802   test       15            [44, 56): 'WORLD SERIES'     MISC     True   \n",
       "801   test       15             [32, 43): 'WEST INDIES'      LOC     True   \n",
       "942   test       21                [719, 725): 'Wijaya'      PER     True   \n",
       "896   test       21        [22, 38): 'WORLD GRAND PRIX'     MISC     True   \n",
       "1057  test       23          [1117, 1127): 'NY RANGERS'      ORG     True   \n",
       "1052  test       23             [1106, 1113): 'TORONTO'      ORG     True   \n",
       "1025  test       23      [673, 689): 'CENTRAL DIVISION'     MISC     True   \n",
       "1016  test       23          [599, 611): 'NY ISLANDERS'      ORG     True   \n",
       "\n",
       "      count  \n",
       "3         0  \n",
       "0         0  \n",
       "78        0  \n",
       "246       0  \n",
       "374       0  \n",
       "498       0  \n",
       "462       0  \n",
       "512       0  \n",
       "622       0  \n",
       "621       0  \n",
       "583       0  \n",
       "569       0  \n",
       "802       0  \n",
       "801       0  \n",
       "942       0  \n",
       "896       0  \n",
       "1057      0  \n",
       "1052      0  \n",
       "1025      0  \n",
       "1016      0  "
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pull out the 20 gold-standard entities that the fewest models found, sorting ties by fold and document to make relabeling easier\n",
    "hard_to_get = results[results[\"in_gold\"]].sort_values([\"count\", \"fold\", \"doc_num\"]).head(20)\n",
    "hard_to_get"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Relabel the above 20 examples with a Markdown table (copy from CSV)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>in_gold</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>373</th>\n",
       "      <td>dev</td>\n",
       "      <td>149</td>\n",
       "      <td>[81, 102): 'Major League Baseball'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>570</th>\n",
       "      <td>dev</td>\n",
       "      <td>214</td>\n",
       "      <td>[187, 202): 'Michael Collins'</td>\n",
       "      <td>PER</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>983</th>\n",
       "      <td>test</td>\n",
       "      <td>23</td>\n",
       "      <td>[94, 116): 'National Hockey League'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1110</th>\n",
       "      <td>test</td>\n",
       "      <td>25</td>\n",
       "      <td>[856, 864): 'NFC East'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1109</th>\n",
       "      <td>test</td>\n",
       "      <td>25</td>\n",
       "      <td>[823, 835): 'Philadelphia'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1184</th>\n",
       "      <td>test</td>\n",
       "      <td>41</td>\n",
       "      <td>[674, 688): 'Sporting Gijon'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1323</th>\n",
       "      <td>test</td>\n",
       "      <td>114</td>\n",
       "      <td>[51, 61): 'sales-USDA'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1367</th>\n",
       "      <td>test</td>\n",
       "      <td>118</td>\n",
       "      <td>[776, 791): 'mid-Mississippi'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1362</th>\n",
       "      <td>test</td>\n",
       "      <td>118</td>\n",
       "      <td>[535, 550): 'mid-Mississippi'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1509</th>\n",
       "      <td>test</td>\n",
       "      <td>178</td>\n",
       "      <td>[1787, 1800): 'Uruguay Round'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1560</th>\n",
       "      <td>test</td>\n",
       "      <td>180</td>\n",
       "      <td>[588, 592): 'BILO'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1558</th>\n",
       "      <td>test</td>\n",
       "      <td>180</td>\n",
       "      <td>[579, 583): 'TOPS'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1550</th>\n",
       "      <td>test</td>\n",
       "      <td>180</td>\n",
       "      <td>[395, 399): 'BILO'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1544</th>\n",
       "      <td>test</td>\n",
       "      <td>180</td>\n",
       "      <td>[286, 293): 'Malysia'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1542</th>\n",
       "      <td>test</td>\n",
       "      <td>180</td>\n",
       "      <td>[259, 263): 'BILO'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1649</th>\n",
       "      <td>test</td>\n",
       "      <td>207</td>\n",
       "      <td>[1041, 1047): 'Oxford'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1786</th>\n",
       "      <td>test</td>\n",
       "      <td>219</td>\n",
       "      <td>[368, 381): 'Koo Jeon Woon'</td>\n",
       "      <td>PER</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1807</th>\n",
       "      <td>test</td>\n",
       "      <td>222</td>\n",
       "      <td>[218, 225): 'EASTERN'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1805</th>\n",
       "      <td>test</td>\n",
       "      <td>222</td>\n",
       "      <td>[92, 114): 'National Hockey League'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2054</th>\n",
       "      <td>train</td>\n",
       "      <td>48</td>\n",
       "      <td>[885, 899): 'Sjeng Schalken'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>False</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num                                 span ent_type  in_gold  \\\n",
       "373     dev      149   [81, 102): 'Major League Baseball'     MISC    False   \n",
       "570     dev      214        [187, 202): 'Michael Collins'      PER    False   \n",
       "983    test       23  [94, 116): 'National Hockey League'     MISC    False   \n",
       "1110   test       25               [856, 864): 'NFC East'     MISC    False   \n",
       "1109   test       25           [823, 835): 'Philadelphia'      ORG    False   \n",
       "1184   test       41         [674, 688): 'Sporting Gijon'      ORG    False   \n",
       "1323   test      114               [51, 61): 'sales-USDA'      ORG    False   \n",
       "1367   test      118        [776, 791): 'mid-Mississippi'      LOC    False   \n",
       "1362   test      118        [535, 550): 'mid-Mississippi'      LOC    False   \n",
       "1509   test      178        [1787, 1800): 'Uruguay Round'     MISC    False   \n",
       "1560   test      180                   [588, 592): 'BILO'      ORG    False   \n",
       "1558   test      180                   [579, 583): 'TOPS'      ORG    False   \n",
       "1550   test      180                   [395, 399): 'BILO'      ORG    False   \n",
       "1544   test      180                [286, 293): 'Malysia'      ORG    False   \n",
       "1542   test      180                   [259, 263): 'BILO'      ORG    False   \n",
       "1649   test      207               [1041, 1047): 'Oxford'      ORG    False   \n",
       "1786   test      219          [368, 381): 'Koo Jeon Woon'      PER    False   \n",
       "1807   test      222                [218, 225): 'EASTERN'     MISC    False   \n",
       "1805   test      222  [92, 114): 'National Hockey League'     MISC    False   \n",
       "2054  train       48         [885, 899): 'Sjeng Schalken'      ORG    False   \n",
       "\n",
       "      count  \n",
       "373      17  \n",
       "570      17  \n",
       "983      17  \n",
       "1110     17  \n",
       "1109     17  \n",
       "1184     17  \n",
       "1323     17  \n",
       "1367     17  \n",
       "1362     17  \n",
       "1509     17  \n",
       "1560     17  \n",
       "1558     17  \n",
       "1550     17  \n",
       "1544     17  \n",
       "1542     17  \n",
       "1649     17  \n",
       "1786     17  \n",
       "1807     17  \n",
       "1805     17  \n",
       "2054     17  "
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The 20 spans absent from the gold standard that the most models predicted (the false positives hardest for models to avoid)\n",
    "hard_to_avoid = results[~results[\"in_gold\"]].sort_values([\"count\", \"fold\", \"doc_num\"], ascending=[False, True, True]).head(20)\n",
    "hard_to_avoid"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TODO: Relabel the above 20 examples (copy from CSV)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Remainder of Experiment\n",
    "\n",
    "For each of the 10 folds, train a group of 17 models on the fold's training set\n",
    "and run the analysis on the fold's held-out test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting fold 1.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8e6ab8836b6246edb1d3e635fda3ad7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 1.\n",
      "Starting fold 2.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "51af4ff8c89d4dca999d510eccb7e575",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 2.\n",
      "Starting fold 3.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7b85c9d3d9f4c7dbad63f5a7b00ce03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 3.\n",
      "Starting fold 4.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01b0a7edf3714a63b0237cf286fd041d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 4.\n",
      "Starting fold 5.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c003ad7644b041b18315f34e8bf1f85c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 5.\n",
      "Starting fold 6.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79e940db371644e9bf2e8c945e0806bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 6.\n",
      "Starting fold 7.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04a013f28dcd46d1831ff6ca1be8266a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 7.\n",
      "Starting fold 8.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4a17afd5ec08439a8956146853fd9679",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 8.\n",
      "Starting fold 9.\n",
      "Training model using all of 768-dimension embeddings.\n",
      "Training model '32_1' (#1 at 32 dimensions) with seed 89250\n",
      "Training model '32_2' (#2 at 32 dimensions) with seed 773956\n",
      "Training model '32_3' (#3 at 32 dimensions) with seed 654571\n",
      "Training model '32_4' (#4 at 32 dimensions) with seed 438878\n",
      "Training model '64_1' (#1 at 64 dimensions) with seed 201469\n",
      "Training model '64_2' (#2 at 64 dimensions) with seed 94177\n",
      "Training model '64_3' (#3 at 64 dimensions) with seed 526478\n",
      "Training model '64_4' (#4 at 64 dimensions) with seed 975622\n",
      "Training model '128_1' (#1 at 128 dimensions) with seed 513226\n",
      "Training model '128_2' (#2 at 128 dimensions) with seed 128113\n",
      "Training model '128_3' (#3 at 128 dimensions) with seed 839748\n",
      "Training model '128_4' (#4 at 128 dimensions) with seed 450385\n",
      "Training model '256_1' (#1 at 256 dimensions) with seed 781567\n",
      "Training model '256_2' (#2 at 256 dimensions) with seed 643865\n",
      "Training model '256_3' (#3 at 256 dimensions) with seed 402414\n",
      "Training model '256_4' (#4 at 256 dimensions) with seed 822761\n",
      "\u001b[2m\u001b[36m(pid=72363)\u001b[0m Training model with n_components=32 and seed=773956.\n",
      "\u001b[2m\u001b[36m(pid=72366)\u001b[0m Training model with n_components=64 and seed=975622.\n",
      "\u001b[2m\u001b[36m(pid=72365)\u001b[0m Training model with n_components=32 and seed=438878.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=32 and seed=89250.\n",
      "\u001b[2m\u001b[36m(pid=72375)\u001b[0m Training model with n_components=32 and seed=654571.\n",
      "\u001b[2m\u001b[36m(pid=72364)\u001b[0m Training model with n_components=64 and seed=526478.\n",
      "\u001b[2m\u001b[36m(pid=72370)\u001b[0m Training model with n_components=128 and seed=513226.\n",
      "\u001b[2m\u001b[36m(pid=72369)\u001b[0m Training model with n_components=64 and seed=94177.\n",
      "\u001b[2m\u001b[36m(pid=72367)\u001b[0m Training model with n_components=64 and seed=201469.\n",
      "\u001b[2m\u001b[36m(pid=72374)\u001b[0m Training model with n_components=128 and seed=128113.\n",
      "\u001b[2m\u001b[36m(pid=72376)\u001b[0m Training model with n_components=128 and seed=839748.\n",
      "\u001b[2m\u001b[36m(pid=72373)\u001b[0m Training model with n_components=128 and seed=450385.\n",
      "\u001b[2m\u001b[36m(pid=72372)\u001b[0m Training model with n_components=256 and seed=781567.\n",
      "\u001b[2m\u001b[36m(pid=72371)\u001b[0m Training model with n_components=256 and seed=643865.\n",
      "\u001b[2m\u001b[36m(pid=72362)\u001b[0m Training model with n_components=256 and seed=402414.\n",
      "\u001b[2m\u001b[36m(pid=72368)\u001b[0m Training model with n_components=256 and seed=822761.\n",
      "Trained 17 models.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "038d5be9065f4ac6837021045d963b06",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Starting...', layout=Layout(width='100%'), max=17, style=ProgressStyle(descr…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done with fold 9.\n"
     ]
    }
   ],
   "source": [
    "def handle_fold(fold_ix: int) -> Dict[str, Any]:\n",
    "    \"\"\"\n",
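    "    Run the per-fold processing from the previous section's cells (training or\n",
    "    loading models, evaluating them on the fold's holdout set, and flagging\n",
    "    suspicious labels) as a single function.\n",
    "\n",
    "    :param fold_ix: 0-based index of the fold\n",
    "\n",
    "    :returns: a dictionary that maps each data structure's name (\"models\",\n",
    "     \"summary_df\", \"full_results\", \"results\") to the data structure itself\n",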
    "    \"\"\"\n",
    "    # To avoid accidentally picking up leftover data from a previous cell,\n",
    "    # variables local to this function are named with a leading underscore\n",
    "    _train_inputs_df = corpus_df.merge(train_keys[fold_ix])\n",
    "    _test_inputs_df = corpus_df.merge(test_keys[fold_ix])\n",
    "    _models = maybe_train_models(_train_inputs_df, fold_ix)\n",
    "    _evals = eval_models(_models, _test_inputs_df)\n",
    "    _summary_df = make_summary_df(_evals)\n",
    "    # Use this fold's own evaluation results (_evals), not the global `evals`\n",
    "    # left over from the fold 0 cells.\n",
    "    _gold_elts = cleaning.preprocess.combine_raw_spans_docs_to_match(\n",
    "        corpus_raw, _evals[list(_evals.keys())[0]])\n",
    "    _full_results = cleaning.flag_suspicious_labels(_evals, 'ent_type', 'ent_type',\n",
    "                                                    label_name='ent_type',\n",
    "                                                    gold_feats=_gold_elts,\n",
    "                                                    align_over_cols=['fold', 'doc_num', 'span'],\n",
    "                                                    keep_cols=[], split_doc=False)\n",
    "    _results = _full_results[[\"fold\", \"doc_num\", \"span\", \n",
    "                              \"ent_type\", \"in_gold\", \"count\"]]\n",
    "    return {\n",
    "        \"models\": _models,\n",
    "        \"summary_df\": _summary_df,\n",
    "        \"full_results\": _full_results,\n",
    "        \"results\": _results\n",
    "    }\n",
    "\n",
    "# Start with the (already computed) results for fold 0\n",
    "results_by_fold = [\n",
    "    {\n",
    "        \"models\": models,\n",
    "        \"summary_df\": summary_df,\n",
    "        \"full_results\": full_results,\n",
    "        \"results\": results\n",
    "    }\n",
    "]\n",
    "\n",
    "for fold in range(1, _KFOLD_NUM_FOLDS):\n",
    "    print(f\"Starting fold {fold}.\")\n",
    "    results_by_fold.append(handle_fold(fold))\n",
    "    print(f\"Done with fold {fold}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_num</th>\n",
       "      <th>span</th>\n",
       "      <th>ent_type</th>\n",
       "      <th>in_gold</th>\n",
       "      <th>count</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4927</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[590, 598): 'Gorleben'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4925</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[63, 67): 'BONN'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4924</th>\n",
       "      <td>train</td>\n",
       "      <td>907</td>\n",
       "      <td>[11, 17): 'German'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4923</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[523, 528): 'China'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4922</th>\n",
       "      <td>train</td>\n",
       "      <td>896</td>\n",
       "      <td>[512, 518): 'Mexico'</td>\n",
       "      <td>LOC</td>\n",
       "      <td>True</td>\n",
       "      <td>17</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>271</th>\n",
       "      <td>dev</td>\n",
       "      <td>93</td>\n",
       "      <td>[469, 481): 'JAKARTA POST'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>183</th>\n",
       "      <td>dev</td>\n",
       "      <td>76</td>\n",
       "      <td>[1285, 1312): 'Chicago Purchasing Managers'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>126</th>\n",
       "      <td>dev</td>\n",
       "      <td>49</td>\n",
       "      <td>[1920, 1925): 'Tajik'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>dev</td>\n",
       "      <td>15</td>\n",
       "      <td>[109, 133): 'National Football League'</td>\n",
       "      <td>ORG</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>dev</td>\n",
       "      <td>15</td>\n",
       "      <td>[15, 40): 'AMERICAN FOOTBALL-RANDALL'</td>\n",
       "      <td>MISC</td>\n",
       "      <td>True</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>44802 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       fold  doc_num                                         span ent_type  \\\n",
       "4927  train      907                       [590, 598): 'Gorleben'      LOC   \n",
       "4925  train      907                             [63, 67): 'BONN'      LOC   \n",
       "4924  train      907                           [11, 17): 'German'     MISC   \n",
       "4923  train      896                          [523, 528): 'China'      LOC   \n",
       "4922  train      896                         [512, 518): 'Mexico'      LOC   \n",
       "...     ...      ...                                          ...      ...   \n",
       "271     dev       93                   [469, 481): 'JAKARTA POST'      ORG   \n",
       "183     dev       76  [1285, 1312): 'Chicago Purchasing Managers'      ORG   \n",
       "126     dev       49                        [1920, 1925): 'Tajik'     MISC   \n",
       "25      dev       15       [109, 133): 'National Football League'      ORG   \n",
       "17      dev       15        [15, 40): 'AMERICAN FOOTBALL-RANDALL'     MISC   \n",
       "\n",
       "      in_gold  count  \n",
       "4927     True     17  \n",
       "4925     True     17  \n",
       "4924     True     17  \n",
       "4923     True     17  \n",
       "4922     True     17  \n",
       "...       ...    ...  \n",
       "271      True      0  \n",
       "183      True      0  \n",
       "126      True      0  \n",
       "25       True      0  \n",
       "17       True      0  \n",
       "\n",
       "[44802 rows x 6 columns]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Combine all the results into a single dataframe for the entire corpus\n",
    "all_results = pd.concat([r[\"results\"] for r in results_by_fold])\n",
    "all_results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate CSV files for manual labeling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>count</th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_offset</th>\n",
       "      <th>corpus_span</th>\n",
       "      <th>corpus_ent_type</th>\n",
       "      <th>error_type</th>\n",
       "      <th>correct_span</th>\n",
       "      <th>correct_ent_type</th>\n",
       "      <th>notes</th>\n",
       "      <th>time_started</th>\n",
       "      <th>time_stopped</th>\n",
       "      <th>time_elapsed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[760, 765): 'Leeds'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>0</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[614, 634): 'Duke of Norfolk's XI'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[189, 218): 'Test and County Cricket Board'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[87, 92): 'Ashes'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[25, 30): 'ASHES'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1738</th>\n",
       "      <td>17</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[230, 238): 'Charlton'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1737</th>\n",
       "      <td>17</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[177, 187): 'Englishman'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1736</th>\n",
       "      <td>17</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[135, 142): 'Ireland'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1735</th>\n",
       "      <td>17</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[87, 100): 'Jack Charlton'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1734</th>\n",
       "      <td>17</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[69, 75): 'DUBLIN'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>11590 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      count  fold  doc_offset                                  corpus_span  \\\n",
       "30        0   dev           2                          [760, 765): 'Leeds'   \n",
       "21        0   dev           2           [614, 634): 'Duke of Norfolk's XI'   \n",
       "5         0   dev           2  [189, 218): 'Test and County Cricket Board'   \n",
       "3         0   dev           2                            [87, 92): 'Ashes'   \n",
       "0         0   dev           2                            [25, 30): 'ASHES'   \n",
       "...     ...   ...         ...                                          ...   \n",
       "1738     17  test         230                       [230, 238): 'Charlton'   \n",
       "1737     17  test         230                     [177, 187): 'Englishman'   \n",
       "1736     17  test         230                        [135, 142): 'Ireland'   \n",
       "1735     17  test         230                   [87, 100): 'Jack Charlton'   \n",
       "1734     17  test         230                           [69, 75): 'DUBLIN'   \n",
       "\n",
       "     corpus_ent_type error_type correct_span correct_ent_type notes  \\\n",
       "30               ORG                                                  \n",
       "21               ORG                                                  \n",
       "5                ORG                                                  \n",
       "3               MISC                                                  \n",
       "0               MISC                                                  \n",
       "...              ...        ...          ...              ...   ...   \n",
       "1738             PER                                                  \n",
       "1737            MISC                                                  \n",
       "1736             LOC                                                  \n",
       "1735             PER                                                  \n",
       "1734             LOC                                                  \n",
       "\n",
       "     time_started time_stopped time_elapsed  \n",
       "30                                           \n",
       "21                                           \n",
       "5                                            \n",
       "3                                            \n",
       "0                                            \n",
       "...           ...          ...          ...  \n",
       "1738                                         \n",
       "1737                                         \n",
       "1736                                         \n",
       "1735                                         \n",
       "1734                                         \n",
       "\n",
       "[11590 rows x 12 columns]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
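    "# Keep only the dev and test folds and reformat the results for manual labeling:\n",
    "# one table of corpus labels and one of model predictions absent from the corpus\n",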
    "dev_and_test_results = all_results[all_results[\"fold\"].isin([\"dev\", \"test\"])]\n",
    "in_gold_to_write, not_in_gold_to_write = cleaning.analysis.csv_prep(dev_and_test_results, \"count\")\n",
    "in_gold_to_write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>count</th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_offset</th>\n",
       "      <th>model_span</th>\n",
       "      <th>model_ent_type</th>\n",
       "      <th>error_type</th>\n",
       "      <th>corpus_span</th>\n",
       "      <th>corpus_ent_type</th>\n",
       "      <th>correct_span</th>\n",
       "      <th>correct_ent_type</th>\n",
       "      <th>notes</th>\n",
       "      <th>time_started</th>\n",
       "      <th>time_stopped</th>\n",
       "      <th>time_elapsed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>17</td>\n",
       "      <td>dev</td>\n",
       "      <td>2</td>\n",
       "      <td>[760, 765): 'Leeds'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>17</td>\n",
       "      <td>dev</td>\n",
       "      <td>6</td>\n",
       "      <td>[567, 572): 'Rotor'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>17</td>\n",
       "      <td>dev</td>\n",
       "      <td>6</td>\n",
       "      <td>[399, 404): 'Rotor'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>17</td>\n",
       "      <td>dev</td>\n",
       "      <td>6</td>\n",
       "      <td>[262, 267): 'Rotor'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>142</th>\n",
       "      <td>17</td>\n",
       "      <td>dev</td>\n",
       "      <td>11</td>\n",
       "      <td>[1961, 1975): 'Czech Republic'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1708</th>\n",
       "      <td>1</td>\n",
       "      <td>test</td>\n",
       "      <td>228</td>\n",
       "      <td>[771, 784): 'De Graafschap'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1690</th>\n",
       "      <td>1</td>\n",
       "      <td>test</td>\n",
       "      <td>228</td>\n",
       "      <td>[269, 287): 'Brazilian defender'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1679</th>\n",
       "      <td>1</td>\n",
       "      <td>test</td>\n",
       "      <td>228</td>\n",
       "      <td>[40, 43): 'SIX'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1724</th>\n",
       "      <td>1</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[19, 29): 'ENGLISHMAN'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1727</th>\n",
       "      <td>1</td>\n",
       "      <td>test</td>\n",
       "      <td>230</td>\n",
       "      <td>[19, 38): 'ENGLISHMAN CHARLTON'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4366 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      count  fold  doc_offset                        model_span  \\\n",
       "29       17   dev           2               [760, 765): 'Leeds'   \n",
       "25       17   dev           6               [567, 572): 'Rotor'   \n",
       "20       17   dev           6               [399, 404): 'Rotor'   \n",
       "16       17   dev           6               [262, 267): 'Rotor'   \n",
       "142      17   dev          11    [1961, 1975): 'Czech Republic'   \n",
       "...     ...   ...         ...                               ...   \n",
       "1708      1  test         228       [771, 784): 'De Graafschap'   \n",
       "1690      1  test         228  [269, 287): 'Brazilian defender'   \n",
       "1679      1  test         228                   [40, 43): 'SIX'   \n",
       "1724      1  test         230            [19, 29): 'ENGLISHMAN'   \n",
       "1727      1  test         230   [19, 38): 'ENGLISHMAN CHARLTON'   \n",
       "\n",
       "     model_ent_type error_type corpus_span corpus_ent_type correct_span  \\\n",
       "29              LOC                                                       \n",
       "25              PER                                                       \n",
       "20              PER                                                       \n",
       "16              PER                                                       \n",
       "142             LOC                                                       \n",
       "...             ...        ...         ...             ...          ...   \n",
       "1708            ORG                                                       \n",
       "1690           MISC                                                       \n",
       "1679            ORG                                                       \n",
       "1724            LOC                                                       \n",
       "1727            PER                                                       \n",
       "\n",
       "     correct_ent_type notes time_started time_stopped time_elapsed  \n",
       "29                                                                  \n",
       "25                                                                  \n",
       "20                                                                  \n",
       "16                                                                  \n",
       "142                                                                 \n",
       "...               ...   ...          ...          ...          ...  \n",
       "1708                                                                \n",
       "1690                                                                \n",
       "1679                                                                \n",
       "1724                                                                \n",
       "1727                                                                \n",
       "\n",
       "[4366 rows x 14 columns]"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "not_in_gold_to_write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
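    "# Write the flagged corpus labels (in_gold, keyed by corpus_span) and the flagged\n",
    "# model outputs that do not appear in the corpus labels (not_in_gold, keyed by model_span).\n",
    "in_gold_to_write.to_csv(\"outputs/CoNLL_4_in_gold.csv\", index=False)\n",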
    "not_in_gold_to_write.to_csv(\"outputs/CoNLL_4_not_in_gold.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>count</th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_offset</th>\n",
       "      <th>corpus_span</th>\n",
       "      <th>corpus_ent_type</th>\n",
       "      <th>error_type</th>\n",
       "      <th>correct_span</th>\n",
       "      <th>correct_ent_type</th>\n",
       "      <th>notes</th>\n",
       "      <th>time_started</th>\n",
       "      <th>time_stopped</th>\n",
       "      <th>time_elapsed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1486</th>\n",
       "      <td>0</td>\n",
       "      <td>train</td>\n",
       "      <td>6</td>\n",
       "      <td>[121, 137): 'Toronto Dominion'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1358</th>\n",
       "      <td>0</td>\n",
       "      <td>train</td>\n",
       "      <td>24</td>\n",
       "      <td>[384, 388): 'FLNC'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1355</th>\n",
       "      <td>0</td>\n",
       "      <td>train</td>\n",
       "      <td>24</td>\n",
       "      <td>[161, 169): 'Africans'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1965</th>\n",
       "      <td>0</td>\n",
       "      <td>train</td>\n",
       "      <td>25</td>\n",
       "      <td>[141, 151): 'mid-Norway'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1383</th>\n",
       "      <td>0</td>\n",
       "      <td>train</td>\n",
       "      <td>28</td>\n",
       "      <td>[1133, 1135): 'EU'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4132</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>945</td>\n",
       "      <td>[130, 137): 'Preston'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4131</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>945</td>\n",
       "      <td>[119, 127): 'Plymouth'</td>\n",
       "      <td>ORG</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4130</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>945</td>\n",
       "      <td>[72, 79): 'English'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4129</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>945</td>\n",
       "      <td>[43, 49): 'LONDON'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4128</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>945</td>\n",
       "      <td>[19, 26): 'ENGLISH'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>23499 rows × 12 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      count   fold  doc_offset                     corpus_span  \\\n",
       "1486      0  train           6  [121, 137): 'Toronto Dominion'   \n",
       "1358      0  train          24              [384, 388): 'FLNC'   \n",
       "1355      0  train          24          [161, 169): 'Africans'   \n",
       "1965      0  train          25        [141, 151): 'mid-Norway'   \n",
       "1383      0  train          28              [1133, 1135): 'EU'   \n",
       "...     ...    ...         ...                             ...   \n",
       "4132     17  train         945           [130, 137): 'Preston'   \n",
       "4131     17  train         945          [119, 127): 'Plymouth'   \n",
       "4130     17  train         945             [72, 79): 'English'   \n",
       "4129     17  train         945              [43, 49): 'LONDON'   \n",
       "4128     17  train         945             [19, 26): 'ENGLISH'   \n",
       "\n",
       "     corpus_ent_type error_type correct_span correct_ent_type notes  \\\n",
       "1486             PER                                                  \n",
       "1358             ORG                                                  \n",
       "1355            MISC                                                  \n",
       "1965            MISC                                                  \n",
       "1383             ORG                                                  \n",
       "...              ...        ...          ...              ...   ...   \n",
       "4132             ORG                                                  \n",
       "4131             ORG                                                  \n",
       "4130            MISC                                                  \n",
       "4129             LOC                                                  \n",
       "4128            MISC                                                  \n",
       "\n",
       "     time_started time_stopped time_elapsed  \n",
       "1486                                         \n",
       "1358                                         \n",
       "1355                                         \n",
       "1965                                         \n",
       "1383                                         \n",
       "...           ...          ...          ...  \n",
       "4132                                         \n",
       "4131                                         \n",
       "4130                                         \n",
       "4129                                         \n",
       "4128                                         \n",
       "\n",
       "[23499 rows x 12 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Repeat the CSV preparation for documents from the original CoNLL-2003 training set\n",
    "train_results = all_results[all_results[\"fold\"] == \"train\"]\n",
    "in_gold_to_write, not_in_gold_to_write = cleaning.analysis.csv_prep(train_results, \"count\")\n",
    "in_gold_to_write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>count</th>\n",
       "      <th>fold</th>\n",
       "      <th>doc_offset</th>\n",
       "      <th>model_span</th>\n",
       "      <th>model_ent_type</th>\n",
       "      <th>error_type</th>\n",
       "      <th>corpus_span</th>\n",
       "      <th>corpus_ent_type</th>\n",
       "      <th>correct_span</th>\n",
       "      <th>correct_ent_type</th>\n",
       "      <th>notes</th>\n",
       "      <th>time_started</th>\n",
       "      <th>time_stopped</th>\n",
       "      <th>time_elapsed</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1738</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>3</td>\n",
       "      <td>[0, 10): '-DOCSTART-'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1485</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>6</td>\n",
       "      <td>[121, 137): 'Toronto Dominion'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1964</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>25</td>\n",
       "      <td>[141, 151): 'mid-Norway'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2022</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>29</td>\n",
       "      <td>[762, 774): 'Mark O'Meara'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1996</th>\n",
       "      <td>17</td>\n",
       "      <td>train</td>\n",
       "      <td>29</td>\n",
       "      <td>[454, 468): 'Phil Mickelson'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4416</th>\n",
       "      <td>1</td>\n",
       "      <td>train</td>\n",
       "      <td>943</td>\n",
       "      <td>[25, 46): 'SAN MARINO GRAND PRIX'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4461</th>\n",
       "      <td>1</td>\n",
       "      <td>train</td>\n",
       "      <td>944</td>\n",
       "      <td>[25, 32): 'MASTERS'</td>\n",
       "      <td>MISC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4462</th>\n",
       "      <td>1</td>\n",
       "      <td>train</td>\n",
       "      <td>944</td>\n",
       "      <td>[25, 32): 'MASTERS'</td>\n",
       "      <td>PER</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4463</th>\n",
       "      <td>1</td>\n",
       "      <td>train</td>\n",
       "      <td>944</td>\n",
       "      <td>[17, 32): 'BRITISH MASTERS'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4458</th>\n",
       "      <td>1</td>\n",
       "      <td>train</td>\n",
       "      <td>944</td>\n",
       "      <td>[11, 15): 'GOLF'</td>\n",
       "      <td>LOC</td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "      <td></td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5347 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      count   fold  doc_offset                         model_span  \\\n",
       "1738     17  train           3              [0, 10): '-DOCSTART-'   \n",
       "1485     17  train           6     [121, 137): 'Toronto Dominion'   \n",
       "1964     17  train          25           [141, 151): 'mid-Norway'   \n",
       "2022     17  train          29         [762, 774): 'Mark O'Meara'   \n",
       "1996     17  train          29       [454, 468): 'Phil Mickelson'   \n",
       "...     ...    ...         ...                                ...   \n",
       "4416      1  train         943  [25, 46): 'SAN MARINO GRAND PRIX'   \n",
       "4461      1  train         944                [25, 32): 'MASTERS'   \n",
       "4462      1  train         944                [25, 32): 'MASTERS'   \n",
       "4463      1  train         944        [17, 32): 'BRITISH MASTERS'   \n",
       "4458      1  train         944                   [11, 15): 'GOLF'   \n",
       "\n",
       "     model_ent_type error_type corpus_span corpus_ent_type correct_span  \\\n",
       "1738            LOC                                                       \n",
       "1485            LOC                                                       \n",
       "1964            LOC                                                       \n",
       "2022            PER                                                       \n",
       "1996            PER                                                       \n",
       "...             ...        ...         ...             ...          ...   \n",
       "4416            PER                                                       \n",
       "4461           MISC                                                       \n",
       "4462            PER                                                       \n",
       "4463            LOC                                                       \n",
       "4458            LOC                                                       \n",
       "\n",
       "     correct_ent_type notes time_started time_stopped time_elapsed  \n",
       "1738                                                                \n",
       "1485                                                                \n",
       "1964                                                                \n",
       "2022                                                                \n",
       "1996                                                                \n",
       "...               ...   ...          ...          ...          ...  \n",
       "4416                                                                \n",
       "4461                                                                \n",
       "4462                                                                \n",
       "4463                                                                \n",
       "4458                                                                \n",
       "\n",
       "[5347 rows x 14 columns]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "not_in_gold_to_write"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the training-set results to separate CSV files\n",
    "in_gold_to_write.to_csv(\"outputs/CoNLL_4_train_in_gold.csv\", index=False)\n",
    "not_in_gold_to_write.to_csv(\"outputs/CoNLL_4_train_not_in_gold.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}