{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ILsAojF_nXzT"
},
"source": [
"# Link to the lab\n",
"\n",
"https://tinyurl.com/inlplab5"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KVkKP3mNWP4c"
},
"source": [
"# Setup\n",
"\n",
"We'll use fastText embeddings pretrained on Wikipedia in our embedding layer."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"collapsed": true,
"id": "shI-n-rp8nU2",
"jupyter": {
"outputs_hidden": true
},
"outputId": "077dc499-179c-488b-fa2d-d79c881c5d56",
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: fasttext in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (0.9.2)\n",
"Requirement already satisfied: pybind11>=2.2 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from fasttext) (2.10.0)\n",
"Requirement already satisfied: numpy in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from fasttext) (1.23.0)\n",
"Requirement already satisfied: setuptools>=0.7.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from fasttext) (61.2.0)\n",
"Requirement already satisfied: pytorch-crf in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (0.7.2)\n",
"Requirement already satisfied: datasets in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (2.4.0)\n",
"Requirement already satisfied: responses<0.19 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (0.18.0)\n",
"Requirement already satisfied: dill<0.3.6 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (0.3.5.1)\n",
"Requirement already satisfied: numpy>=1.17 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (1.23.0)\n",
"Requirement already satisfied: pyarrow>=6.0.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (9.0.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (4.64.0)\n",
"Requirement already satisfied: aiohttp in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (3.8.1)\n",
"Requirement already satisfied: multiprocess in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (0.70.13)\n",
"Requirement already satisfied: requests>=2.19.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (2.28.1)\n",
"Requirement already satisfied: pandas in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (1.4.3)\n",
"Requirement already satisfied: xxhash in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (3.0.0)\n",
"Requirement already satisfied: fsspec[http]>=2021.11.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (2022.7.1)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.1.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (0.9.1)\n",
"Requirement already satisfied: packaging in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from datasets) (21.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (6.0)\n",
"Requirement already satisfied: filelock in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (3.8.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from huggingface-hub<1.0.0,>=0.1.0->datasets) (4.3.0)\n",
"Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from packaging->datasets) (3.0.9)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (1.26.9)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2022.6.15)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (2.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from requests>=2.19.0->datasets) (3.3)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (4.0.2)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (1.2.0)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (6.0.2)\n",
"Requirement already satisfied: attrs>=17.3.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (21.4.0)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (1.8.1)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: pytz>=2020.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from pandas->datasets) (2022.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: six>=1.5 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n",
"Collecting sklearn\n",
" Using cached sklearn-0.0-py2.py3-none-any.whl\n",
"Collecting scikit-learn\n",
" Using cached scikit_learn-1.1.2-cp310-cp310-macosx_10_9_x86_64.whl (8.7 MB)\n",
"Collecting threadpoolctl>=2.0.0\n",
" Using cached threadpoolctl-3.1.0-py3-none-any.whl (14 kB)\n",
"Requirement already satisfied: scipy>=1.3.2 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from scikit-learn->sklearn) (1.9.0)\n",
"Requirement already satisfied: numpy>=1.17.3 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from scikit-learn->sklearn) (1.23.0)\n",
"Requirement already satisfied: joblib>=1.0.0 in /Users/knf792/miniconda3/envs/nlp-course/lib/python3.10/site-packages (from scikit-learn->sklearn) (1.1.0)\n",
"Installing collected packages: threadpoolctl, scikit-learn, sklearn\n",
"Successfully installed scikit-learn-1.1.2 sklearn-0.0 threadpoolctl-3.1.0\n"
]
}
],
"source": [
"!pip install fasttext\n",
"!pip install datasets\n",
"!pip install scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "qTxj2GUD86Mt"
},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "K23XIfU19JC6"
},
"outputs": [],
"source": [
"import io\n",
"from math import log\n",
"from numpy import array\n",
"from numpy import argmax\n",
"import torch\n",
"import random\n",
"import numpy as np\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch import nn\n",
"from torch.optim import Adam\n",
"from torch.optim.lr_scheduler import ExponentialLR, CyclicLR\n",
"from typing import List, Tuple, AnyStr\n",
"from tqdm.notebook import tqdm\n",
"from sklearn.metrics import precision_recall_fscore_support\n",
"import matplotlib.pyplot as plt\n",
"from copy import deepcopy\n",
"from datasets import load_dataset\n",
"from sklearn.metrics import confusion_matrix\n",
"import torch.nn.functional as F\n",
"import heapq"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "1WG_TMG29Jkh"
},
"outputs": [],
"source": [
"def enforce_reproducibility(seed=42):\n",
" # Sets seed manually for both CPU and CUDA\n",
" torch.manual_seed(seed)\n",
" torch.cuda.manual_seed_all(seed)\n",
" # For atomic operations there is currently \n",
" # no simple way to enforce determinism, as\n",
" # the order of parallel operations is not known.\n",
" # CUDNN\n",
" torch.backends.cudnn.deterministic = True\n",
" torch.backends.cudnn.benchmark = False\n",
" # System based\n",
" random.seed(seed)\n",
" np.random.seed(seed)\n",
"\n",
"enforce_reproducibility()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y0-F6_Wb9Ams"
},
"source": [
"# Sequence Classification - recap\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kk7Nm4aD1Le_"
},
"source": [
"\n",
"Sequence classification is the task of\n",
"- predicting a class (e.g., a POS tag) for each token in a textual input\n",
"- labelling tokens as beginning (B), inside (I), or outside (O) of a span\n",
"- predicting which tokens from the input belong to a span, e.g.:\n",
" - which tokens from a document answer a given question (extractive QA)\n",
"\n",
" - which tokens in a news article contain propagandistic techniques\n",
"\n",
" - the spans can be of different types, e.g. type of a Named Entity (NE) -- Person, Location, Organisation\n",
" - ([More datasets for structured prediction](https://huggingface.co/datasets?languages=languages:en&task_categories=task_categories:structure-prediction&sort=downloads))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OHp8Z6Pc89h7"
},
"source": [
"## Named entity recognition"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "viwPhyqMaQhi"
},
"source": [
"\n",
"\n",
"- identify the **entities** that appear in a document and their types\n",
"- e.g., extract all names of people, locations, and organizations from the following sentence:\n",
"\n",
"\n",
"| Sundar | Pichai | is | the | CEO | of | Alphabet | , | located | in | Mountain | View | , | CA |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| PER | PER | O | O | O | O | ORG | O | O | O | LOC | LOC | LOC | LOC |\n",
"\n",
"- we have labelled each token with its entity type (PER: Person, ORG: Organization, LOC: Location, O: Outside). **Question: What are some issues that could arise as a result of this tagging?**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bG7fTfhRdulS"
},
"source": [
"In practice, we will also want to denote which tokens are the beginning of an entity, and which tokens are inside the full entity span, giving the following tagging:\n",
"\n",
"| Sundar | Pichai | is | the | CEO | of | Alphabet | , | located | in | Mountain | View | , | CA |\n",
"|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n",
"| B-PER | I-PER | O | O | O | O | B-ORG | O | O | O | B-LOC | I-LOC | I-LOC | I-LOC |\n",
"\n",
"**Question: What are some other tagging schemes that you think could be good?**\n",
"\n",
"Modeling the dependencies between the predictions can be useful: for example, knowing that the previous tag was `B-PER` influences whether the current tag will be `I-PER`, `O`, or `I-LOC`."
]
},
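{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustration (a sketch, not part of the lab's pipeline), the BIO tags above can be decoded back into typed spans with a helper like the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: decode a BIO tag sequence into (type, start, end) spans (end exclusive).\n",
"def bio_to_spans(tags):\n",
"    spans = []\n",
"    start, ent_type = None, None\n",
"    for i, tag in enumerate(tags):\n",
"        if tag.startswith('B-') or (tag.startswith('I-') and ent_type != tag[2:]):\n",
"            if ent_type is not None:  # close the previous entity first\n",
"                spans.append((ent_type, start, i))\n",
"            start, ent_type = i, tag[2:]\n",
"        elif tag == 'O' and ent_type is not None:\n",
"            spans.append((ent_type, start, i))\n",
"            start, ent_type = None, None\n",
"    if ent_type is not None:  # entity running to the end of the sequence\n",
"        spans.append((ent_type, start, len(tags)))\n",
"    return spans\n",
"\n",
"# Expected: [('PER', 0, 2), ('ORG', 6, 7), ('LOC', 10, 14)]\n",
"bio_to_spans(['B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC'])"
]
},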
{
"cell_type": "markdown",
"metadata": {
"id": "uOi3HWggedf9"
},
"source": [
"## Download and prepare the data\n",
"\n",
"We'll use a small set of Wikipedia data labelled with people, locations, organizations, and \"miscellaneous\" entities."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 569,
"referenced_widgets": [
"378fa5d2ca2d4005a0824d512c74bab9",
"054c24038d3e46559503463d23ddc389",
"d835d9ae913b4b8aa79195bbb65c67a2",
"100fee07db9a428ca681de6261a83220",
"878b7dbe92304f56be3f6cd519318522",
"d7fdc539d31c456498d6db7558c984fb",
"ab7ed4ca77bf4370820cdd932267885d",
"85b3a26f64e74fb299bea1ff292ec8c3",
"c537cce1751b48d482cfbcfc6611e64e",
"2b5d33ab96a746edbefa19cbbd3f28d8",
"7db97d142d2844fdb9b785cc7d9648f7",
"c7c6f6a5c86543f898e9598387d61437",
"8f4c1790ec104d208484a5e0b300d7bd",
"4373ceede55c499c9203d3fde6b31082",
"d12355d50fdc46f691212830c5510648",
"66fdff54c9fb4054b20461f69befb50a",
"e781457369704f12b9808104bcd8821f",
"f90f4b96b9c34ee79bf9db54d9376086",
"272da6e0838841a4a7107392d6e29f41",
"33640913b31b41a2a3e705cdee4e3324",
"f692ec41d40240508e620cc561355166",
"6ea12b0213b34770b91671d3ff7c90c9",
"8014b553c7214e85a760bdfdb56d20d5",
"bd909e355e85410683c23061f9e07518",
"ee98e3dec8e74ccda89c58ba82ad4f87",
"f01e0f0622a8403698d736e175cabd0f",
"a99fb612b1564315903fc6cac33259fb",
"520d60e185c7447f93e23c9accc27258",
"6b70e5760c6243248cb3c6d81723416c",
"58ee651ead9a408aae427ff74958ea4d",
"757d72ff8d77442687874f18efd8a31d",
"f3b08eaf8c754d5683a6442c29084d36",
"1c5b754a90354f739cb3e5660e1eeadf",
"be289fccdfa64228839b0b60774b29d9",
"aedc1a90c62f4385af8827f7203aaeba",
"eca645bca6f54b0580ad1d6fc3c02780",
"8f1fcc4069d34ac39fa705730cc62a3a",
"99f713f428d24c89a309b447bdc2cdf1",
"704acba671124ede9f4bcedfaa2217ed",
"9b55e7f888934a8cbb94e15946fd6653",
"d4052b6974ac4068a470650ce3863f94",
"4aa399acdf824ddc8f98d49ea633821e",
"d373c5ded26447bfb1a3ad0e86bb3720",
"76ce4bf8d41a4b51af901000791dbb76"
]
},
"id": "BoEPDmb6QTw5",
"outputId": "b4e2e71c-63f8-46f2-ae37-f4a3957850a1"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "90bc7b241ead44cb8b21bd98f95149dc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading builder script: 0%| | 0.00/2.58k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1267c3d89f2746778cad1cbc7d5548b8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading metadata: 0%| | 0.00/1.61k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading and preparing dataset conll2003/conll2003 (download: 959.94 KiB, generated: 9.78 MiB, post-processed: Unknown size, total: 10.72 MiB) to /Users/knf792/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6bbb7b119ba842e384b238c600a20fa8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading data: 0%| | 0.00/983k [00:00, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train split: 0%| | 0/14041 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating validation split: 0%| | 0/3250 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test split: 0%| | 0/3453 [00:00, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset conll2003 downloaded and prepared to /Users/knf792/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/9a4d16a94f8674ba3466315300359b0acd891b68b6c8743ddf60b9c702adce98. Subsequent calls will reuse this data.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "79ea6d948af040d6b4245ecc090af3f3",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/3 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 14041\n",
" })\n",
" validation: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 3250\n",
" })\n",
" test: Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 3453\n",
" })\n",
"})"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"datasets = load_dataset(\"conll2003\")\n",
"datasets"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qU6qEpOTQknb",
"outputId": "294acb1c-fab4-40e4-fbc9-5e6067f9e2e9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['id', 'tokens', 'pos_tags', 'chunk_tags', 'ner_tags'],\n",
" num_rows: 14041\n",
"})\n",
"{'id': '0', 'tokens': ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.'], 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7], 'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0], 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}\n",
"Sequence(feature=ClassLabel(num_classes=9, names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'], id=None), length=-1, id=None)\n",
"['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']\n"
]
}
],
"source": [
"print(datasets['train'])\n",
"print(datasets['train'][0])\n",
"print(datasets[\"train\"].features[\"ner_tags\"])\n",
"print(datasets[\"train\"].features[\"ner_tags\"].feature.names)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pIZdSxg6fOHb"
},
"source": [
"We'll create the word embedding space:\n",
"- with FastText pretrained embeddings\n",
"- using all of the *vocabulary from the train and dev splits*, plus the most frequent tokens from the pretrained word embeddings. This reduces the size of the embedding matrix!"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ABBIslsS9_1U",
"outputId": "d393b804-b108-4336-f916-b48bbc54e14b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-10-06 09:18:40-- https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip\n",
"Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 172.67.9.4, 104.22.74.142, 104.22.75.142\n",
"Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|172.67.9.4|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 681808098 (650M) [application/zip]\n",
"Saving to: ‘wiki-news-300d-1M.vec.zip’\n",
"\n",
"wiki-news-300d-1M.v 100%[===================>] 650,22M 11,9MB/s in 49s \n",
"\n",
"2022-10-06 09:19:30 (13,3 MB/s) - ‘wiki-news-300d-1M.vec.zip’ saved [681808098/681808098]\n",
"\n",
"Archive: wiki-news-300d-1M.vec.zip\n",
" inflating: wiki-news-300d-1M.vec \n"
]
}
],
"source": [
"!wget https://dl.fbaipublicfiles.com/fasttext/vectors-english/wiki-news-300d-1M.vec.zip\n",
"!unzip wiki-news-300d-1M.vec.zip"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "AU1ldp1VArxU"
},
"outputs": [],
"source": [
"# Reduce down to our vocabulary and word embeddings\n",
"def load_vectors(fname, vocabulary):\n",
" fin = io.open(fname, 'r', encoding='utf-8', newline='\\n', errors='ignore')\n",
" n, d = map(int, fin.readline().split())\n",
"    tag_names = datasets[\"train\"].features[\"ner_tags\"].feature.names\n",
" final_vocab = tag_names + ['[PAD]', '[UNK]', '[BOS]', '[EOS]']\n",
" final_vectors = [np.random.normal(size=(300,)) for _ in range(len(final_vocab))]\n",
" for j,line in enumerate(fin):\n",
" tokens = line.rstrip().split(' ')\n",
" if tokens[0] in vocabulary or len(final_vocab) < 30000:\n",
" final_vocab.append(tokens[0])\n",
" final_vectors.append(np.array(list(map(float, tokens[1:]))))\n",
" return final_vocab, np.vstack(final_vectors)\n",
"\n",
"class FasttextTokenizer:\n",
" def __init__(self, vocabulary):\n",
" self.vocab = {}\n",
" for j,l in enumerate(vocabulary):\n",
" self.vocab[l.strip()] = j\n",
"\n",
" def encode(self, text):\n",
" # Text is assumed to be tokenized\n",
" return [self.vocab[t] if t in self.vocab else self.vocab['[UNK]'] for t in text]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4OgHYnqV-CzF",
"outputId": "cead7528-df83-4a52-8261-ba9f26ff78c3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"size of vocabulary: 40630\n"
]
}
],
"source": [
"vocabulary = (set([t for s in datasets['train'] for t in s['tokens']]) | set([t for s in datasets['validation'] for t in s['tokens']]))\n",
"vocabulary, pretrained_embeddings = load_vectors('wiki-news-300d-1M.vec', vocabulary)\n",
"print('size of vocabulary: ', len(vocabulary))\n",
"tokenizer = FasttextTokenizer(vocabulary)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "03GlNGdjffvJ"
},
"source": [
"The main difference in the dataset reading and collation functions is that we now return a sequence of labels instead of a single label as in text classification."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "DDNdg8kNCYxa"
},
"outputs": [],
"source": [
"def collate_batch_bilstm(input_data: Tuple) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" input_ids = [tokenizer.encode(i['tokens']) for i in input_data]\n",
" seq_lens = [len(i) for i in input_ids]\n",
" labels = [i['ner_tags'] for i in input_data]\n",
"\n",
" max_length = max([len(i) for i in input_ids])\n",
"\n",
" input_ids = [(i + [0] * (max_length - len(i))) for i in input_ids]\n",
" labels = [(i + [0] * (max_length - len(i))) for i in labels] # 0 is the id of the O tag\n",
"\n",
" assert (all(len(i) == max_length for i in input_ids))\n",
" assert (all(len(i) == max_length for i in labels))\n",
" return torch.tensor(input_ids), torch.tensor(seq_lens), torch.tensor(labels)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "EAJsgXF_IZUQ",
"outputId": "b25fbeae-e7f8-4794-89a0-72b2f9bb8063"
},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[36231, 48, 10, 33561, 30770, 8120, 31121, 21803, 10, 36750,\n",
" 15]]),\n",
" tensor([11]),\n",
" tensor([[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0]]))"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dev_dl = DataLoader(datasets['validation'], batch_size=1, shuffle=False, collate_fn=collate_batch_bilstm, num_workers=0)\n",
"next(iter(dev_dl))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AIHamPlSIgS5",
"outputId": "549bf33b-8cc7-4720-c255-5a877992f3d3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'id': '0', 'tokens': ['CRICKET', '-', 'LEICESTERSHIRE', 'TAKE', 'OVER', 'AT', 'TOP', 'AFTER', 'INNINGS', 'VICTORY', '.'], 'pos_tags': [22, 8, 22, 22, 15, 22, 22, 22, 22, 21, 7], 'chunk_tags': [11, 0, 11, 12, 13, 11, 12, 12, 12, 12, 0], 'ner_tags': [0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0]}\n",
"(tensor([[36231, 48, 10, 33561, 30770, 8120, 31121, 21803, 10, 36750,\n",
" 15]]), tensor([11]), tensor([[0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0]]))\n"
]
}
],
"source": [
"print(datasets['validation'][0])\n",
"print(collate_batch_bilstm([datasets['validation'][0]]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oo6sp4It9Txz"
},
"source": [
"# Creating the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cQIYJ0Q_gILv"
},
"source": [
"## LSTM model for sequence classification\n",
"\n",
"You'll notice that the BiLSTM model is mostly the same from the text classification and language modeling labs. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "nVsVJgToVrdz"
},
"outputs": [],
"source": [
"# Define the model\n",
"class BiLSTM(nn.Module):\n",
" \"\"\"\n",
"    Basic BiLSTM network for sequence labelling\n",
" \"\"\"\n",
" def __init__(\n",
" self,\n",
" pretrained_embeddings: torch.tensor,\n",
" lstm_dim: int,\n",
" dropout_prob: float = 0.1,\n",
" n_classes: int = 2\n",
" ):\n",
" \"\"\"\n",
" Initializer for basic BiLSTM network\n",
"    :param pretrained_embeddings: A tensor containing the pretrained fastText embeddings\n",
" :param lstm_dim: The dimensionality of the BiLSTM network\n",
" :param dropout_prob: Dropout probability\n",
" :param n_classes: The number of output classes\n",
" \"\"\"\n",
"\n",
" # First thing is to call the superclass initializer\n",
" super(BiLSTM, self).__init__()\n",
"\n",
" # We'll define the network in a ModuleDict, which makes organizing the model a bit nicer\n",
" # The components are an embedding layer, a 2 layer BiLSTM, and a feed-forward output layer\n",
" self.model = nn.ModuleDict({\n",
" 'embeddings': nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1),\n",
" 'bilstm': nn.LSTM(\n",
" pretrained_embeddings.shape[1], # input size\n",
" lstm_dim, # hidden size\n",
" 2, # number of layers\n",
" batch_first=True,\n",
" dropout=dropout_prob,\n",
" bidirectional=True),\n",
" 'ff': nn.Linear(2*lstm_dim, n_classes),\n",
" })\n",
" self.n_classes = n_classes\n",
" self.loss = nn.CrossEntropyLoss()\n",
" # Initialize the weights of the model\n",
" self._init_weights()\n",
"\n",
" def _init_weights(self):\n",
" all_params = list(self.model['bilstm'].named_parameters()) + \\\n",
" list(self.model['ff'].named_parameters())\n",
" for n,p in all_params:\n",
" if 'weight' in n:\n",
" nn.init.xavier_normal_(p)\n",
" elif 'bias' in n:\n",
" nn.init.zeros_(p)\n",
"\n",
" def forward(self, inputs, input_lens, hidden_states = None, labels = None):\n",
" \"\"\"\n",
" Defines how tensors flow through the model\n",
" :param inputs: (b x sl) The IDs into the vocabulary of the input samples\n",
" :param input_lens: (b) The length of each input sequence\n",
"        :param labels: (b x sl) The per-token labels of each sample\n",
"        :return: (logits, lengths, loss) if `labels` is not None, otherwise (logits, lengths)\n",
" \"\"\"\n",
"\n",
" # Get embeddings (b x sl x edim)\n",
" embeds = self.model['embeddings'](inputs)\n",
"\n",
" # Pack padded: This is necessary for padded batches input to an RNN - https://stackoverflow.com/questions/51030782/why-do-we-pack-the-sequences-in-pytorch\n",
" lstm_in = nn.utils.rnn.pack_padded_sequence(\n",
" embeds,\n",
" input_lens.cpu(),\n",
" batch_first=True,\n",
" enforce_sorted=False\n",
" )\n",
"\n",
" # Pass the packed sequence through the BiLSTM\n",
" if hidden_states:\n",
" lstm_out, hidden = self.model['bilstm'](lstm_in, hidden_states)\n",
" else:\n",
" lstm_out, hidden = self.model['bilstm'](lstm_in)\n",
"\n",
" # Unpack the packed sequence --> (b x sl x 2*lstm_dim)\n",
" lstm_out, lengths = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)\n",
"\n",
" # Get logits (b x seq_len x n_classes)\n",
" logits = self.model['ff'](lstm_out)\n",
" outputs = (logits, lengths)\n",
" if labels is not None:\n",
" loss = self.loss(logits.reshape(-1, self.n_classes), labels.reshape(-1))\n",
" outputs = outputs + (loss,)\n",
"\n",
" return outputs"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"id": "oH_92rb8VvEd"
},
"outputs": [],
"source": [
"def train(\n",
" model: nn.Module, \n",
" train_dl: DataLoader, \n",
" valid_dl: DataLoader, \n",
" optimizer: torch.optim.Optimizer, \n",
" n_epochs: int, \n",
" device: torch.device,\n",
" scheduler=None,\n",
"):\n",
" \"\"\"\n",
" The main training loop which will optimize a given model on a given dataset\n",
" :param model: The model being optimized\n",
" :param train_dl: The training dataset\n",
" :param valid_dl: A validation dataset\n",
" :param optimizer: The optimizer used to update the model parameters\n",
" :param n_epochs: Number of epochs to train for\n",
"    :param device: The device to train on\n",
"    :param scheduler: An optional learning rate scheduler, stepped after each batch\n",
"    :return: (losses, learning_rates) The losses and learning rates per training step\n",
" \"\"\"\n",
"\n",
" # Keep track of the loss and best accuracy\n",
" losses = []\n",
" learning_rates = []\n",
" best_f1 = 0.0\n",
"\n",
" # Iterate through epochs\n",
" for ep in range(n_epochs):\n",
"\n",
" loss_epoch = []\n",
"\n",
" #Iterate through each batch in the dataloader\n",
" for batch in tqdm(train_dl):\n",
"            # VERY IMPORTANT: Make sure the model is in training mode, which turns on\n",
"            # things like dropout and batch-normalization updates\n",
" model.train()\n",
"\n",
"            # VERY IMPORTANT: zero out all of the gradients on each iteration -- PyTorch\n",
"            # accumulates gradients by default, so you need to explicitly\n",
"            # zero them out\n",
" optimizer.zero_grad()\n",
"\n",
" # Place each tensor on the GPU\n",
" batch = tuple(t.to(device) for t in batch)\n",
" input_ids = batch[0]\n",
" seq_lens = batch[1]\n",
" labels = batch[2]\n",
"\n",
" # Pass the inputs through the model, get the current loss and logits\n",
" logits, lengths, loss = model(input_ids, seq_lens, labels=labels)\n",
" losses.append(loss.item())\n",
" loss_epoch.append(loss.item())\n",
"\n",
" # Calculate all of the gradients and weight updates for the model\n",
" loss.backward()\n",
"\n",
" # Optional: clip gradients\n",
" #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" # Finally, update the weights of the model\n",
" optimizer.step()\n",
"            if scheduler is not None:\n",
" scheduler.step()\n",
" learning_rates.append(scheduler.get_last_lr()[0])\n",
"\n",
" # Perform inline evaluation at the end of the epoch\n",
" f1 = evaluate(model, valid_dl)\n",
" print(f'Validation F1: {f1}, train loss: {sum(loss_epoch) / len(loss_epoch)}')\n",
"\n",
"        # Keep track of the best model based on the validation F1\n",
" if f1 > best_f1:\n",
" torch.save(model.state_dict(), 'best_model')\n",
" best_f1 = f1\n",
"\n",
" return losses, learning_rates"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"id": "LQkyUeyhV1D3"
},
"outputs": [],
"source": [
"def evaluate(model: nn.Module, valid_dl: DataLoader):\n",
" \"\"\"\n",
" Evaluates the model on the given dataset\n",
" :param model: The model under evaluation\n",
" :param valid_dl: A `DataLoader` reading validation data\n",
"    :return: The macro-averaged F1 score of the model on the dataset\n",
" \"\"\"\n",
"    # VERY IMPORTANT: Put your model in \"eval\" mode -- this disables things like\n",
"    # dropout and switches batch normalization to its inference statistics\n",
" model.eval()\n",
" labels_all = []\n",
" preds_all = []\n",
"\n",
" # ALSO IMPORTANT: Don't accumulate gradients during this process\n",
" with torch.no_grad():\n",
" for batch in tqdm(valid_dl, desc='Evaluation'):\n",
" batch = tuple(t.to(device) for t in batch)\n",
" input_ids = batch[0]\n",
" seq_lens = batch[1]\n",
" labels = batch[2]\n",
" hidden_states = None\n",
"\n",
" logits, _, _ = model(input_ids, seq_lens, hidden_states=hidden_states, labels=labels)\n",
" preds_all.extend(torch.argmax(logits, dim=-1).reshape(-1).detach().cpu().numpy())\n",
" labels_all.extend(labels.reshape(-1).detach().cpu().numpy())\n",
"\n",
" P, R, F1, _ = precision_recall_fscore_support(labels_all, preds_all, average='macro')\n",
" print(confusion_matrix(labels_all, preds_all))\n",
" return F1"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "ycIjTfhBZGNJ"
},
"outputs": [],
"source": [
"lstm_dim = 128\n",
"dropout_prob = 0.1\n",
"batch_size = 8\n",
"lr = 1e-2\n",
"n_epochs = 10\n",
"n_workers = 0 # set to a larger number if you run your code in colab\n",
"\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"\n",
"# Create the model\n",
"model = BiLSTM(\n",
" pretrained_embeddings=torch.FloatTensor(pretrained_embeddings), \n",
" lstm_dim=lstm_dim, \n",
" dropout_prob=dropout_prob, \n",
"  n_classes=len(datasets[\"train\"].features[\"ner_tags\"].feature.names)\n",
" ).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"afc43fa359df4352ba07ba5e5d00054d",
"95164b73146c4484873e302a514db3cb",
"00fdcbd96cb94d3eb4ebce2dd65453df",
"69dac05df03a4505bf57ddee27c052c7",
"5230116d8f9d40fbb22be74eb6bddba8",
"04e7a1a51fc843a28998e9c233eb6b53",
"ed85ba97e4db4aa4870c97fb5dae9320",
"032c119a998a43efa68e584df3832d19",
"7afee097ce0d4179a71382e1a770347d",
"bba4178aa4d44e9d91190d821d05b7ee",
"2614645b079b49bea5d51d1f8d99a38f",
"f01071efe73b4151a3b75e73dd249c48",
"b314e1ca91b1434e951253e9bebf6b0b",
"0258bc4f1039499f8d27f95ea6881d4b",
"470e65e908714a76a9a1aaeb9b889b62",
"ca708c2a421b4dd9a744970e3d9ae610",
"70a96efaef804fd8ad5ce058b130a5d8",
"0dedc3af2f79435a8f866cd47b0258ff",
"47d6e470ee324f50aa0b2ff7539a5308",
"1d004bcd0131429296f108282d445751",
"5e341b1c7e5346caad69ba22a54ec4dc",
"dc5d256cae234801877fb9ce3ca792cf",
"8ce8e901bde44dd6bf7de4ca40f35a5b",
"e1c62c4d464b451ebf33b77b2e8dbba6",
"931f034c69de4c379d851d93eb1a9084",
"ddd1301e94394e62bacf916ad4880292",
"bfcbe3d471ad412ca4161ee4fbb4125d",
"485299f0fd8e4f5ba78bd3d109366717",
"a6d38f7eb3cf43b58c5335f3b5557ff6",
"969beef0e2184558a1c741abc45e021b"
]
},
"id": "xOdf3IBNV4hx",
"outputId": "5b93d989-a102-4195-d7ec-88af05f514a9"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5f809429d2864a5583df712280a659ae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "40b91a1be31a49e1908d48f429ed5b43",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345419 47 10 49 52 14 2 41 13]\n",
" [ 23 1789 5 1 0 20 0 3 1]\n",
" [ 25 16 1260 0 4 0 0 1 1]\n",
" [ 37 82 0 1135 27 44 0 16 0]\n",
" [ 32 5 11 8 678 2 10 1 4]\n",
" [ 25 12 0 34 8 1741 3 14 0]\n",
" [ 6 0 0 0 37 2 211 0 1]\n",
" [ 65 17 0 34 3 22 1 775 5]\n",
" [ 47 1 5 0 51 1 5 16 220]]\n",
"Validation F1: 0.8935076155280685, train loss: 0.14170506390146825\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9905795cb97e465db4761a99e8e9c8fa",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9c73b8e5df714dde91128829b89940bc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345436 62 2 45 22 11 2 33 34]\n",
" [ 31 1781 5 4 0 14 0 6 1]\n",
" [ 25 22 1249 2 1 0 2 1 5]\n",
" [ 23 56 1 1191 21 22 0 27 0]\n",
" [ 49 5 15 7 613 11 13 3 35]\n",
" [ 25 13 0 50 0 1727 1 21 0]\n",
" [ 11 0 3 1 12 7 211 1 11]\n",
" [ 43 17 0 35 1 6 0 813 7]\n",
" [ 27 2 5 0 16 2 0 18 276]]\n",
"Validation F1: 0.9018831260482847, train loss: 0.02657376601092108\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "db0a329c85724a348955cac7f3f17ea9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e451fda0ea3e4d54a344803b78fb0f82",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345456 29 14 48 20 20 0 37 23]\n",
" [ 27 1752 18 20 1 20 0 4 0]\n",
" [ 17 9 1273 1 0 0 1 1 5]\n",
" [ 38 23 1 1216 15 23 0 25 0]\n",
" [ 36 1 10 8 659 2 10 0 25]\n",
" [ 12 8 0 62 3 1727 4 20 1]\n",
" [ 8 0 0 1 23 3 214 2 6]\n",
" [ 41 12 0 47 3 8 0 793 18]\n",
" [ 27 1 6 1 19 2 0 7 283]]\n",
"Validation F1: 0.9104821085939905, train loss: 0.017478332217572713\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "168d2049bcfe4ccea1e2112caa756dae",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c55f08461a6c4e77aacecf939756973b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345508 31 3 37 10 10 1 23 24]\n",
" [ 23 1794 7 7 0 7 0 3 1]\n",
" [ 21 11 1272 0 0 0 0 1 2]\n",
" [ 50 50 0 1190 12 19 0 19 1]\n",
" [ 50 2 25 12 628 2 10 2 20]\n",
" [ 29 20 0 51 4 1695 1 35 2]\n",
" [ 11 0 6 0 14 3 212 1 10]\n",
" [ 75 12 0 28 1 6 0 795 5]\n",
" [ 32 1 5 0 12 2 2 15 277]]\n",
"Validation F1: 0.9107918478008801, train loss: 0.012492643622691443\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a813eafb1e7440f180c043c92b3abcf6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cafc279d8b75492a84aa58268cf67bc7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345519 20 2 40 25 10 0 21 10]\n",
" [ 29 1776 6 13 2 13 0 2 1]\n",
" [ 25 11 1268 0 1 0 0 1 1]\n",
" [ 37 19 1 1222 12 30 0 20 0]\n",
" [ 49 0 9 11 661 6 5 1 9]\n",
" [ 14 7 0 36 3 1761 1 14 1]\n",
" [ 11 0 0 0 19 2 221 0 4]\n",
" [ 62 10 0 33 3 8 0 801 5]\n",
" [ 30 1 5 3 24 2 1 12 268]]\n",
"Validation F1: 0.9245278623743088, train loss: 0.008772710909356139\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ce55418c435649b0b711e3eab634cfee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4c19c8b62e7b4889bea4d7f06fea8aa4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345475 33 2 25 41 20 1 36 14]\n",
" [ 25 1782 3 4 2 21 0 4 1]\n",
" [ 21 16 1260 1 3 0 2 2 2]\n",
" [ 40 38 0 1170 30 34 1 28 0]\n",
" [ 34 0 10 3 674 9 7 0 14]\n",
" [ 16 7 0 28 3 1757 3 21 2]\n",
" [ 10 0 3 0 14 3 225 0 2]\n",
" [ 62 13 0 26 3 14 0 792 12]\n",
" [ 30 1 6 2 16 1 5 20 265]]\n",
"Validation F1: 0.9140443357471715, train loss: 0.00545991696425629\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "100b60679e3e4ec296f59634506868c8",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2aeff848b84b49fd9e48539bab9079ee",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345466 14 2 56 29 14 2 40 24]\n",
" [ 25 1767 2 23 1 18 1 4 1]\n",
" [ 21 10 1261 1 6 0 4 1 3]\n",
" [ 29 21 0 1224 16 24 1 26 0]\n",
" [ 34 1 7 9 677 5 6 0 12]\n",
" [ 7 5 0 46 2 1758 2 16 1]\n",
" [ 9 0 0 0 13 3 227 0 5]\n",
" [ 47 8 0 39 3 8 0 811 6]\n",
" [ 28 2 5 2 19 2 3 11 274]]\n",
"Validation F1: 0.9206964790962123, train loss: 0.0035528619892963883\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "77a2c7abe0de4abf9ac555ae74887f9a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3b221478ebb549c6aa349c3ac900d14f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345500 35 3 28 10 19 2 29 21]\n",
" [ 19 1794 2 7 1 14 0 4 1]\n",
" [ 22 12 1262 1 5 0 2 1 2]\n",
" [ 39 43 0 1201 10 28 1 18 1]\n",
" [ 44 1 12 7 650 6 12 1 18]\n",
" [ 10 7 0 41 2 1762 2 12 1]\n",
" [ 9 0 0 0 7 2 235 0 4]\n",
" [ 55 16 0 27 3 13 0 801 7]\n",
" [ 29 1 4 1 9 2 8 13 279]]\n",
"Validation F1: 0.9228745320291222, train loss: 0.0022250179169699172\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "74b69543d1f4401381a3d4ae19aa46d7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a7aa4d79afe240c1bbf648e6e6b4421f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345500 28 2 41 23 10 2 23 18]\n",
" [ 21 1784 2 15 1 13 0 5 1]\n",
" [ 21 10 1262 2 6 0 1 1 4]\n",
" [ 28 36 1 1219 13 26 0 18 0]\n",
" [ 39 1 12 6 667 6 9 1 10]\n",
" [ 5 7 0 46 3 1761 2 13 0]\n",
" [ 11 0 0 0 10 2 232 0 2]\n",
" [ 53 11 0 31 3 10 0 808 6]\n",
" [ 30 0 4 1 21 1 3 14 272]]\n",
"Validation F1: 0.9249685641215538, train loss: 0.0012310007892623906\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2c9529840fd4509bb55be62b7e30bbf",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1756 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a368684d8f064cb2aa64e61d8d87c937",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[345508 24 2 37 19 13 2 24 18]\n",
" [ 24 1780 2 14 1 15 0 5 1]\n",
" [ 20 9 1264 2 5 0 2 1 4]\n",
" [ 31 35 0 1220 12 26 1 16 0]\n",
" [ 41 0 12 8 665 7 7 1 10]\n",
" [ 7 4 0 43 3 1764 2 14 0]\n",
" [ 10 0 0 0 10 2 232 0 3]\n",
" [ 53 11 0 31 3 10 0 808 6]\n",
" [ 29 0 4 1 18 1 4 18 271]]\n",
"Validation F1: 0.925124536751276, train loss: 0.0008371544579220439\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_dl = DataLoader(datasets['train'], batch_size=batch_size, shuffle=True, collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"valid_dl = DataLoader(datasets['validation'], batch_size=len(datasets['validation']), collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"\n",
"# Create the optimizer\n",
"optimizer = Adam(model.parameters(), lr=lr)\n",
"scheduler = CyclicLR(optimizer, base_lr=0., max_lr=lr, step_size_up=1, step_size_down=len(train_dl)*n_epochs, cycle_momentum=False)\n",
"\n",
"# Train\n",
"losses, learning_rates = train(model, train_dl, valid_dl, optimizer, n_epochs, device, scheduler)\n",
"model.load_state_dict(torch.load('best_model'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Learning rate schedules"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Motivation: \n",
"- to speed up training\n",
"- to train a better model\n",
"\n",
"With Pytorch:\n",
"- choose a learning rate scheduler from `torch.optim.lr_scheduler`\n",
"- add a line in your training loop which calls the `step()` function of your scheduler\n",
"- this will automatically change your learning rate! \n",
"- **Note**: be aware of when to call `step()`; some schedulers change the learning rate after every epoch, and some change after every training step (batch). The one we will use here changes the learning rate after every training step. We'll define the scheduler in the cell that calls the `train()` function. "
]
},
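{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this pattern (a toy linear model and a hypothetical `LambdaLR` schedule, not the lab's model): call `scheduler.step()` once per batch, right after `optimizer.step()`:\n",
"\n",
"```python\n",
"import torch\n",
"from torch.optim import SGD\n",
"from torch.optim.lr_scheduler import LambdaLR\n",
"\n",
"model = torch.nn.Linear(4, 2)\n",
"optimizer = SGD(model.parameters(), lr=0.1)\n",
"# a per-step schedule: halve the learning rate every 10 batches\n",
"scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.5 ** (step // 10))\n",
"\n",
"for step in range(30):\n",
"    x = torch.randn(8, 4)\n",
"    loss = model(x).pow(2).mean()\n",
"    optimizer.zero_grad()\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    scheduler.step()  # per-step scheduler; per-epoch schedulers step once per epoch\n",
"\n",
"print(scheduler.get_last_lr())  # 0.1 * 0.5**3 = 0.0125\n",
"```"
]
},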
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up hyperparameters and create the model. Note the high learning rate -- this is partially due to the learning rate scheduler we will use."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Above we have used the `CyclicLR` scheduler. The cyclic learning rate schedule in general looks like this:\n",
"\n",
"[Source](https://arxiv.org/pdf/1506.01186.pdf)\n",
"\n",
"We are using it here to linearly decay the learning rate from a starting max learning rate (here 1e-2) down to 0 over the entire course of training (essentially one cycle that starts at the max and ends at 0). \n",
"\n",
"\"Allowing the learning rate to rise and fall is beneficial overall\n",
"even though it might temporarily harm the network’s performance\""
]
},
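{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this decay (a toy one-parameter optimizer and a hypothetical `n_steps`, not the training run above): with `step_size_up=1` and `step_size_down` spanning all training steps, the schedule peaks immediately and then falls linearly toward 0:\n",
"\n",
"```python\n",
"import torch\n",
"from torch.optim import Adam\n",
"from torch.optim.lr_scheduler import CyclicLR\n",
"\n",
"param = torch.nn.Parameter(torch.zeros(1))\n",
"optimizer = Adam([param], lr=1e-2)\n",
"n_steps = 100\n",
"scheduler = CyclicLR(optimizer, base_lr=0., max_lr=1e-2, step_size_up=1,\n",
"                     step_size_down=n_steps, cycle_momentum=False)\n",
"\n",
"lrs = []\n",
"for _ in range(n_steps):\n",
"    optimizer.step()\n",
"    scheduler.step()\n",
"    lrs.append(scheduler.get_last_lr()[0])\n",
"\n",
"# lrs[0] is max_lr (the peak); the values then decrease linearly toward base_lr\n",
"```"
]
},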
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAe8ElEQVR4nO3deZwU9Z3/8deHQzDxFjYSPFAX9aHrESSK2ejPJB6IBzGJ12a9kl03UX+rcXNgvIjJqokbk3giiTcqeKDgiiKeiOEahuEWGK5hLhhmYJhhGOb67h9d3dPdU33MTPdMV/t+Ph7zsLuqpvrTNfLu6k99q8qcc4iISPD16e0CREQkMxToIiJ5QoEuIpInFOgiInlCgS4ikif69dYLDxo0yA0bNqy3Xl5EJJAWLVq0zTk32G9erwX6sGHDKCgo6K2XFxEJJDPblGieWi4iInlCgS4ikicU6CIieUKBLiKSJxToIiJ5QoEuIpInFOgiInkicIG+urKO/5m5mppdTb1diohITglcoK+rqufRj4rZsrOxt0sREckpgQv0PmYA6L4cIiKxAhfoIiLiT4EuIpInAhvoDvVcRESiBS7QvRa6iIjECVygi4iIPwW6iEieCGyga9iiiEiswAW6WugiIv4CF+giIuJPgS4ikidSBrqZHWZmH5nZSjNbYWa3+CxjZvawmRWb2VIzG5GdckVEJJF+aSzTAvyXc67QzPYFFpnZLOfcyqhlLgCGez+nA094/80400B0ERFfKffQnXMVzrlC73EdsAoYGrfYWOB5FzIPOMDMhmS8WhERSahTPXQzGwZ8DZgfN2sosDnqeSkdQx8zu8HMCsysoKqqqnOViohIUmkHupntA7wO3Oqc29mVF3POTXTOjXTOjRw8eHBXVhG1rm79uohI3kkr0M2sP6Ewf9E5N9VnkTLgsKjnh3rTMk4ddBERf+mMcjHgKWCVc+6hBItNB67xRruMAmqdcxUZrFNERFJIZ5TLPwNXA8vMrMib9mvgcADn3ARgBjAGKAYagOszX6qIiCSTMtCdc3NI0elwzjngpkwVlQ5dD11EJFbgzhTVMHQREX+BC3QREfGnQBcRyROBDXSNQxcRiRXYQBcRkViBC3QdFBUR8Re4QBcREX8KdBGRPBHYQNcxURGRWIELdNPluUREfAUu0EVExJ8CXUQkTwQ20J3OLBIRiRG8QFcLXUTEV/ACXUREfCnQRUTyRGADXR10EZFYgQt0tdBFRPwFLtBFRMSfAl1EJE8ENtA1DF1EJFbgAt10QXQREV+BC3QREfGnQBcRyRMBDnQ10UVEogUu0NVBFxHxF7hAFxERfwp0EZE8EdhA1zh0EZFYgQt0DUMXEfEXuEAXERF/CnQRkTwR2EBXC11EJFbgAt00El1ExFfgAl1ERPwp0EVE8kRgA13j0EVEYqUMdDN72sy2mtnyBPPPNrNaMyvyfu7OfJnRr5fNtYuIBFe/NJZ5FngUeD7JMp865y7KSEUiItIlKffQnXOzgZoeqEVERLohUz30M8xsiZm9Y2YnJFrIzG4wswIzK6iqqurWCzo10UVEYmQi0AuBI5xzJwOPAG8mWtA5N9E5N9I5N3Lw4MFdejG10EVE/HU70J1zO51z9d7jGUB/MxvU7cpERKRTuh3oZnaIWWjsiZmd5q2zurvrFRGRzkk5ysXMXgbOBgaZWSlwD9AfwDk3AfgB8FMzawF2A1e6Hmhwq4MuIhIrZaA7565KMf9RQsMae4aa6CIivgJ7pqiIiMRSoIuI5InABrqGoYuIxApcoOt66CIi/gIX6CIi4k+BLiKSJxToIiJ5IrCB7nRqkYhIjMAFum5wISLiL3CBLiIi/hToIiJ5IriBrha6iEiMwAW6WugiIv4CF+giIuJPgS4ikicCG+hqoYuIxApcoHt3u2Nb/Z5erkREJLcELtBrdzcDcMvkol6uREQktwQu0Fvb2nq7BB
GRnBS4QBcREX8BDHSNRBcR8RPAQNf4FhERPwEMdBER8RPAQFfLRUTET+ACvX5PS2+XICKSkwIX6IUl23u7BBGRnBS4QBcREX+BC3R10EVE/AUu0PvopqIiIr4CF+giIuJPgS4ikicCF+jquIiI+AtcoKuHLiLiL3CBLiIi/hToIiJ5QoEuIpInUga6mT1tZlvNbHmC+WZmD5tZsZktNbMRmS+z3Yefb83m6kVEAiudPfRngdFJ5l8ADPd+bgCe6H5ZidXsasrm6kVEAitloDvnZgM1SRYZCzzvQuYBB5jZkEwVGE9jXERE/GWihz4U2Bz1vNSblh1KdBERXz16UNTMbjCzAjMrqKqq6smXFhHJe5kI9DLgsKjnh3rTOnDOTXTOjXTOjRw8eHAGXlpERMIyEejTgWu80S6jgFrnXEUG1utLHRcREX/9Ui1gZi8DZwODzKwUuAfoD+CcmwDMAMYAxUADcH22igXYb+/+7GzUbehEROKlDHTn3FUp5jvgpoxVlML3vjaUhz8s7qmXExEJjMCdKdqnj5ouIiJ+AhfoIiLiL3CBrsvnioj4C1ygK85FRPwFLtBFRMRf4AJdHRcREX8BDHQluoiIn8AF+r4DUw6dFxH5QgpcoA/eZ0BvlyAikpMCF+giIuIvcIGuFrqIiL/ABbqIiPgLYKBrF11ExE8AA11ERPwELtDVQxcR8Re4QBcREX8KdBGRPBG4QNflc0VE/AUu0BXnIiL+Ahfo0VrbXG+XICKSMwId6NX1e3q7BBGRnBG4QFcLXUTEnwJdRCRPBC7QRUTEnwJdRCRPBC7QTQMXRUR8BS7QYyjbRUQigh3oGoYuIhIR7EAXEZGI4AW62iwiIr6CF+giIuJLgS4ikieCHehqv4iIRAQu0KMz/LT//oAdDU29VouISC4JXKDH27BtV2+XICKSEwIX6Karc4mI+ApeoMc/V8CLiABpBrqZjTaz1WZWbGbjfOZfZ2ZVZlbk/fxb5kv1d+OkRdTubu6plxMRyVkpA93M+gKPARcAxwNXmdnxPotOcc6d4v38LcN1JlRe28hL80t66uVERHJWOnvopwHFzrn1zrkmYDIwNrtliYhIZ6UT6EOBzVHPS71p8b5vZkvN7DUzO8xvRWZ2g5kVmFlBVVVVF8rVHYtERBLJ1EHRt4BhzrmTgFnAc34LOecmOudGOudGDh48OEMvnV2zVm5h+y6NdReR3JdOoJcB0Xvch3rTIpxz1c65Pd7TvwGnZqa8rpu/vpqfTSnCua5fY3f7rib+/fkC/v35ggxWJiKSHekE+kJguJkdaWZ7AVcC06MXMLMhUU8vAVZlrsSuufqpBbyxuIym1rYur6PZ+91NNQ2ZKktEJGv6pVrAOddiZjcDM4G+wNPOuRVmdi9Q4JybDvynmV0CtAA1wHXZKli3oBMR8Zcy0AGcczOAGXHT7o56fDtwe2ZLyx3d6NqIiPSYwJ0p6sd18150y0praWvzWYe+DIhIgAQu0DM9bLFgYw0XPzqHJ2evz+yKRUR6WOACvbNeml/Cwo01CeeX7dgNwKqKnUnWop6LiOS+vAj0P7y7OuG837y1kssmzE04P1l/XAdgRSRI8iLQ/aTbVw8v59fK6W5vXkSkJwUu0DO9zxzeQ0++Xu2pi0juC1ygZ1ok0JMebdWeuojkvrwM9G31e1IvFMcvztVDF5EgSevEolwy/Cv7+k53zmFmLNq0ne8/8fe016d9bxHJF4HbQx+87wDf6a8UhK7wuzLp8MOOXBpN9J2NLcxbX92p9SbS1uaYVlRGq9+JTCIi3RC4QE/krjdXJJ1/0viZkdCPFo5Vv/ZKuK3e1NLGlRPnUdvgf6u7I29/m9teKUqrztcKS7llchHPfLYhreVFRNKVN4Ge6qqKOxtbGD+9PfT3tLQyc0UlE70zRH2HLcbtRCd6DedgamGZ77x41fWha6tX1XW+zy8ikkzgeuippHMYc8GGGi5/MvZko546/LmjIRTo1bppho
hkWN7soQMp+9J9vN1wv3643x56V64b09zaRkuSbwvha8a8tqiUleWd6/eLiCSTV4H+j3fMYG6Sg5fJ8jmdIYrpBPzxd7/LNx74MPWCQHFVfVrLiYikI68C3Tl4e2lFwvl1e1pYXVlHY3Nrl1+jaPMO/vz+moTzm1sdWzPYH3/kg7VMK2rvz3+6toqmlq7fhSmZ5WW1TJq3KSvrFpHsy7seeirn/3m27/Q+Ph9t8TvkTS1tfPexzwC49ZxjMlyZvz/OCn14jD1lKIUl27n6qQX82zeP5M6Ljs/Ya/zh3c95f9UW1mwJfWP411FHZGzdItJz8moPvXtS91OiL8N7eZIrOIa1tLYxad4mWtscK8preWl+SbcqrPFGyGzYtivy31PufY/S7d275+njH6+LhLmIBNcXbg89ke27mphWVMaJQ/fnqMH7+C5zy+T2seYLfK6xvmtPS8zzZz7byH/PWMVri0op2ryj0zXNXlOVdP7kBSXsaGjmrSUV/PTsozu9fhHJLwp0z7srKnl3RSUAj/9wBGcOH8Tf5nTu5J8/vhfbW9+xO7RHnSjMXZKLsVfU7uaapxf4ztvuDX2MnBTVjTGXd765rOu/nAXXPbOAbxx9MDecpQ8okc5SoPu48cXCLv3e7riDralGztwyuYjR/3QIA/r17TCvoanjgdtwgBeWhD4gwh8I3RlDP2le99pAmfbx6io+Xl2lQBfpAvXQMyh+TzmdPeeGPa0Ub61nccn2Tr9e+6V/O/2rGbejocn3RturK+s47q53KPdu9Sci2aNA74ZfvLqEUfd9EHleENdXTzdnz3noEy59PPYKkTt3+183xk+ybwIrymt9LzMwtbCUT9cm79Gna1v9Hk65d5bvcM4X5m2ksbmN91dtSbqOBRtq+Lzyi3ui1eaahsjBbpGuUsulG15dVBrzPHqkSHNrG9u6cXp/fD8eYOLsdTHPU/X4a3Y1ceHDc9h/7/4suee8mHm3vbKky7XFC19/fuaKLdx23rEx89K7IxQdLsXwRXPmHz4CYOMDF/ZyJdlVv6eFfn2Mgf07thml+7SHniW3T13W6WGKK8t3ctNLhUycvY6WttiTh4q31rFwY3tbJrq94ddyKdhYw4jfzgKgNm5vP509wU3Vuzp9ApbfPVjbD9yGipw0b1O39kQXl2xPejBZcs/WukYWbQr9v/tP98zk2//zce8WlMcU6FnyzrLEZ6wmMubhT3l7aQX3zficeetj2zfb6mP39udv6Dhs8hevLuGv3rVilpfVJnydVAd9W1rb+H8PfsxNaRwcnr++moU+tYRF9/mdc9z55nLGPjon5Xr9zFxRyaWP/50pC9svg7y5poFHP1wbE/Iry3eydktdl15DMu+ih+fE3HSmvLaxF6vJb4FsuRxx8JfYVN29k2myLfk9StstL08cvNEammLHuF/113kxr3XC3e+yyxsZc/4Jh1CwqfMHWcPC0fhJknHwy8tqOeLgL3HFxHkJl4ldW3u418WN109XeM9+fdQe/nXPLGBd1S4uHXEoe/fvy4Ff6s+Yhz8FOte+2L6rifLa3bxaUMrN3/5HBu3jfyMV6bxMXgpDkgtkoE/68emRnmOuqk8ztK5+yn+sebwfPVuQcJ5BJMwBznqw47apa2xm34H9WbuljtUpDj5OLyoHoC1Ba6OtzXHRI3P4+rADE65j7ZY6zv3TbI71bhloWGR9faI+7NZuqWPq4o7Xkm9qaWOvfrFfIP1+Pzy8s3zHbi6bMJdfnB/bww8r3lpHwcbtXHjSEPYd2D9mXl1jMxc9MocybyTOtvo9PPovIxK+N5FcFciWy35790+90BfIA+98nnKZu95czsZtuzj3T7NJdfe7/3o1dMA0erl3l1cyZWGJNz00I7qnD7E3BHmtMHTAeLXX+jBrX18fC4VmwcYarpg4jyc+jj3YC3DMne8wY1lFzHDH8HGDPlFffsKvWel9jf8gajRNRW3od99bUck5D81m3NRl3PzSYgCufX
oBJ42fyZuLyzhx/HuRMIfUl2EWyVWBDPSB/QNZdtakulsTwJtF5ZzdjYNRP5m0iF+9Hjqr9OTfvJdwuVUVO1ldWceTn6yPmW60fxCYGRc9PIcfTJhLTZKRQDe+WMglXr+9qaWNusbQt56+UYkeXucK79ry0Vl8xv2hyxg/NKt9xNDS0h2s2VLHJ2uq2NnY4ttWiv5genH+prRvL5gJbW2Oos07OrTYwvOmLCxJ+2qbM5ZVUNeYePjr9CXlPKtbIWbEQ7PWMGzc277nYvQkJaN02i6fs1gB1m6t54K/fOp7RctxU5dx9oMfA6FwrtyZ3oGx8MHgUfd/ELk5SHTLJfzPZ8Inob38+DbR/e+sinm+vaGZkqjjL763Hozq+9/xxnKmFpbFjPhpbXM889kG9rR07TLM89dXc8vkxbS0tjFlYQn3TFsembekdAfffewzbp3c8UNk2pIyfvX6Mh7/uDhm+u6mVlbF3Ry9eGs9N75YyM9fTTw89T9fXsz4t1Z26T1IrMc+Cv1NWnt5BFYgA12j1nrOKwWbk15jvjPSDXE/0Xvy4UDfXNPQ4aSp+HbJk5+s7zDtraXlkcd9fRK9tc2xvqo+5ho8x931buTx64Wl/OatlTz2UcdWEcAbi0s7nGQW7YqJ85hWVM5v3lrJr15fxnNz269Bv9v7sHxv5RbueGMZjc2t7PT2ssPfUMLj/sNunbKYC/7yaWRvfGdjM1u8bV26vWfP0J04ex3Dxr2d9K5d+SgTl+HIhEAeFJWe88vXlvZ2Cfz/lxfHPO9j8MTH6/j9ux2PHfh94w0HYVh0m6WPT6C/v2or76/a6ltLSXVDZJs89el6zhw+iK8POyhmmZ9NCe0VR4+y+WRNFTsamhh7ytDItBd8biYSXf6L80t40TuXYeMDF0ZGTsW/x/AQ1pZWx56WVk4a394Sy8bOz+uLSrlvxioW3HEOffsYK8t3MmVhCeMvOSHS3mpqbaNf30DuL3ZJruxjfnG2uGTEz6b0XD857K0l5THP/zhrjW+YAx1aD9Dxm8GOhva+8pSCzfGLJ/Xz19pbGLuaWrlswtyEfer566sj9Vz79AJumVzEsHFvJ11/sgAOHzp4aX5JZD2tbS7yfpaW1XLC3TNjfmdlxU6mFZVR29DMtKIyRt33AS2tbWkdSI+3rX4PCzbUcOeby6ne1RRpOf3rU/N5bu6mmBuf1zW28OnaKrYm+Vb23N83Rk446q4nP1nHcXe9k5F1dUWudA0CuYfer09vf7H54nrDZ4jhF8Wna6t8d8X+5a/z+eXoY5mxrJLzjv9KZHrqMfod+Z1tG+bXHloQdVLXtQkut3xLXD++YNP2yDGHeMPGvc33Rgzl4pO+CsC3jvuHyLzLJsxlw7Zd7B132n64KuegsTnUahn959lsb2hmyP4DI8vVNsR+8N0zfQWQ+HyBppY2+vUx+iT4916/p4XirfWcctgB3J/kA6pmVxP7DezXI98Y4v96La1t1O5u5uAeOq8hkHvo/fr24YSv7tfbZcgXzNVPLfC9scmyslqufmoBLy8o4fpnF3b7Nfy8s6yCSfNjWzSXPzk3cm38znh+7sak86cWlnH9sws7vJfwiV2N3p75Kws38+7yisie+ZaovfHtXnhXRJ0VevK9/qOjlpfVxtwcprXN8d6KSo658x1ueKGAsh27cc5x1cR5zFxRGXMZge8+9lnMiKD4y0KUVDcw4rezGP/Wisi0Rz5Ym/SbUmVtI9X1e1hZ3rWLxUXXc/vUZZz6u/ezdh/geIHcQwd48AcnR84IFMl3P/W5DMOCDTUxe+jpmrGsMub5b/93JXdddHzSG4RvrmkfGRTOzPFvrWTfge0RclfUaJ3OuOiR0NDUST8+nW8OH8Qzn23gd2+HRieFjmd8yJrfXcDc9dXMXV8NwLSb/jny+4tL2g9et7Y5mlpbuXvaCm479xj+Y9Ki0LrnlXDyoQcw9pShkfv0Oud8z+gedX/7FVSfvm4kjc1tVNY2ct
Vph7O4ZDulO3Zz4YlD+PKA0HsPj3CB0M1sLpswlx+efnjk+AeELtYXf6JcNlg6Fzoys9HAX4C+wN+ccw/EzR8APA+cClQDVzjnNiZb58iRI11BQeKzH1Mp37GbbzzwYZd/X0TanXHUwZGw7E3r7xvDUb+e0WH6IfsNTGuU1MFf3iuml5/K/nv3p+jucznq1zM49fADufjkr0ZaQfFGHH5A5OYyAPddeiKLNm3n9cL2q64O2mdAh1FIAHv168PXhx3I0AP2ZuO2Bl75yRlp1xjPzBY550b6zksV6GbWF1gDnAuUAguBq5xzK6OWuRE4yTn3EzO7ErjUOXdFsvV2N9CBlAeYRERy0ayfncVw77IYnZUs0NP5DnAaUOycW++cawImA2PjlhkLPOc9fg34jqV7dapu2PjAhay7bwy/Gn0cf7riZL4adQBGRCRXnfunjiffZUI6PfShQPTYrlLg9ETLOOdazKwWOBjYFr2Qmd0A3ABw+OGHd7HkWH37WOSO92NOHEJdYwv7792fvuZ/dLyxuZXFJTso37Gbg/bZi211e9hU3cCN3zqaOWu3sd/e/Snx+oW/fG0p3znuH6hrbGHYoC/xwaqtHb7O7d2/b4d7iWbat44dzEerM3N3IRHpffdcfHxW1tujB0WdcxOBiRBquWR6/QP69WXAPsnvhDKwf1/OOPpg33nnnXAIAKOOCs2/fORhmS1QRCSL0mm5lAHRyXaoN813GTPrB+xP6OCoiIj0kHQCfSEw3MyONLO9gCuB6XHLTAeu9R7/APjQ6T5hIiI9KmXLxeuJ3wzMJDRs8Wnn3AozuxcocM5NB54CXjCzYqCGUOiLiEgPSquH7pybAcyIm3Z31ONG4LLMliYiIp0RyFP/RUSkIwW6iEieUKCLiOQJBbqISJ5I6+JcWXlhsyog8eXdkhtE3FmoOS5I9arW7AlSvao1e7pb7xHOucF+M3ot0LvDzAoSXZwmFwWpXtWaPUGqV7VmTzbrVctFRCRPKNBFRPJEUAN9Ym8X0ElBqle1Zk+Q6lWt2ZO1egPZQxcRkY6CuocuIiJxFOgiInkicIFuZqPNbLWZFZvZuF6q4TAz+8jMVprZCjO7xZs+3szKzKzI+xkT9Tu3ezWvNrPze/L9mNlGM1vm1VTgTTvIzGaZ2Vrvvwd6083MHvbqWWpmI6LWc623/FozuzbR63Wz1mOjtl+Rme00s1tzZdua2dNmttXMlkdNy9i2NLNTvb9Vsfe7Xb6VY4JaHzSzz7163jCzA7zpw8xsd9T2nZCqpkTvO8P1ZuzvbqFLgM/3pk+x0OXAM1nrlKg6N5pZkTe957atcy4wP4Qu37sOOArYC1gCHN8LdQwBRniP9yV0E+3jgfHAz32WP96rdQBwpPce+vbU+wE2AoPipv0BGOc9Hgf83ns8BngHMGAUMN+bfhCw3vvvgd7jA3vg710JHJEr2xY4CxgBLM/GtgQWeMua97sXZLjW84B+3uPfR9U6LHq5uPX41pTofWe43oz93YFXgCu9xxOAn2ay1rj5fwTu7ultG7Q99HRuWJ11zrkK51yh97gOWEXovqqJjAUmO+f2OOc2AMWE3ktvvp/oG3s/B3w3avrzLmQecICZDQHOB2Y552qcc9uBWcDoLNf4HWCdcy7ZGcU9um2dc7MJXfM/voZub0tv3n7OuXku9C/5+ah1ZaRW59x7zrkW7+k8QncgSyhFTYned8bqTaJTf3dvz/fbhG5i3+16k9XqvdblwMvJ1pGNbRu0QPe7YXWyIM06MxsGfA2Y70262fs6+3TU16REdffU+3HAe2a2yEI36gb4inOuwntcCXwlR2qNdiWx/yhycdtC5rblUO9x/PRs+RGhvcKwI81ssZl9YmZnetOS1ZTofWdaJv7uBwM7oj7MsrltzwS2OOfWRk3rkW0btEDPKWa2D/A6cKtzbifwBHA0cApQQehrVy74pnNuBHABcJOZnRU909s7yKnxq15/8xLgVW9Srm7bGL
m4Lf2Y2R1AC/CiN6kCONw59zXgNuAlM9sv3fVl8X0H4u8e5ypid0R6bNsGLdDTuWF1jzCz/oTC/EXn3FQA59wW51yrc64N+Cuhr3+QuO4eeT/OuTLvv1uBN7y6tnhf+cJf/bbmQq1RLgAKnXNbvNpzctt6MrUty4htgWSlZjO7DrgI+KEXFniti2rv8SJCfehjUtSU6H1nTAb/7tWEWl794qZnlLf+7wFTot5Dj23boAV6OjeszjqvR/YUsMo591DU9CFRi10KhI+ATweuNLMBZnYkMJzQwZCsvx8z+7KZ7Rt+TOig2HJib+x9LTAtqtZrLGQUUOt99ZsJnGdmB3pfe8/zpmVLzF5OLm7bKBnZlt68nWY2yvt/7JqodWWEmY0Gfglc4pxriJo+2Mz6eo+PIrQd16eoKdH7zmS9Gfm7ex9cHxG6iX3W6gXOAT53zkVaKT26bdM9qpsrP4RGDqwh9Cl3Ry/V8E1CX4GWAkXezxjgBWCZN306MCTqd+7wal5N1MiFbL8fQkf7l3g/K8KvQain+AGwFngfOMibbsBjXj3LgJFR6/oRoYNPxcD1Wdy+Xya0R7V/1LSc2LaEPmQqgGZCPc8fZ3JbAiMJhdY64FG8s7kzWGsxoR5z+P/bCd6y3/f+/ygCCoGLU9WU6H1nuN6M/d29fwsLvG3wKjAgk7V6058FfhK3bI9tW536LyKSJ4LWchERkQQU6CIieUKBLiKSJxToIiJ5QoEuIpInFOgiInlCgS4ikif+D1RPiPE6XsNjAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(losses)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD4CAYAAADlwTGnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3hUdfr+8feTSg8tItJCVYOAYChSklU6Kiiigl0RLCAl6+7KusWvu6vruhuagIKoWAERNTaaugkgLTTpEooUKaFIr/L5/THH/WWzAQJMMpPM/bquXJz5nDLPOQlz55wzecacc4iISOgJC3QBIiISGAoAEZEQpQAQEQlRCgARkRClABARCVERgS7gQlSsWNHFxcUFugwRkUJj8eLFe5xzsbnNK1QBEBcXR0ZGRqDLEBEpNMzsh7PN0yUgEZEQpQAQEQlRCgARkRClABARCVEKABGREJWnADCzTma2zswyzezpXOZHm9kkb/4CM4vzxiuY2TdmdtjMXs6xznVmtsJbZ4SZmT92SERE8ua8AWBm4cAooDMQD/Qys/gci/UG9jvn6gBDgRe98ePAH4Gnctn0GKAPUNf76nQxOyAiIhcnL2cAzYBM59xG59xJYCLQLccy3YAJ3vQUoK2ZmXPuiHNuDr4g+A8zqwyUcc7Nd75+1G8Bt17KjpzLiK/Ws3zrT/m1eRGRQikvAVAF2Jrt8TZvLNdlnHOngQNAhfNsc9t5tgmAmfU1swwzy8jKyspDuf/tp6MneW/BFm4bPZfnv1jDsZM/X/A2RESKoqC/CeycG+ucS3DOJcTG5vrXzOdUtkQUM5ITuatpdcamb6Tz8HTmbdibD5WKiBQueQmA7UC1bI+remO5LmNmEUAMcK5X2e3eds61Tb8pUyySF7o34L0+zXFAr3Hz+f1HKzh4/FR+PaWISNDLSwAsAuqaWU0ziwJ6Aqk5lkkFHvCmewBfu3N81qRzbgdw0MxaeO/+uR/45IKrv0Ata1dk2sBE+rSpycSFW+iQks5Xa3bl99OKiASl8waAd02/PzAdWANMds6tMrPnzKyrt9h4oIKZZQLJwH/eKmpmm4EU4EEz25btHURPAK8BmcAG4Ev/7NK5FY8K55mb4pn6RCtiikfSe0IGA95fyt7DJwri6UVEgoYVpg+FT0hIcP7sBnry9BlG/zuTUd9kUrpYJH++JZ6uja5Af5IgIkWFmS12ziXkNi/obwLnp6iIMAa1q8dnT7ahWvkSDJy4jEcmZLDjwLFAlyYiku9COgB+ceXlpZn6eEv+cNPVzN2whw4p6by3YAtnzhSesyMRkQulAPCEhxmPtKnF9EGJXFMlht9/tIK7X5vP5j1HAl2aiEi+UADkUKNCSd7r05y/d2/Aqu0H6TgsnbHpGzj985lAlyYi4lcKgFyYGT2bVWdmchJt6lbk+S/WcvuYb1m782CgSxMR8RsFwDlcHlOMcfcnMLJXY7btP8bNI+aQMvN7TpxWOwkRKfwUAOdhZtzS6ApmJidxc8PKjPhqPbeMnMPSLfsDXZqIyCVRAORR+ZJRDOvZmNcfTODQ8dN0H/Mtf/lsNUdPng50aSIiF0UBcIFuvKoSMwYnck/z6oyfs4mOw9KZm7kn0GWJiFwwBcBFKF0skr/e2oCJfVsQbsY9ry3g6Q+/48AxNZcTkcJDAXAJWtSqwLRBiTyaVIvJGVtpn5LGjFU7A12WiEieKAAuUbHIcIZ0vpqP+7WifMko+r69mP7vLWGPmsuJSJBTAPhJw6plSe3fml+3r8eMVbtol5LGR0u3UZia7YlIaFEA+FFURBhPtq3L5wNaU7NiSQZPWs7Dby7ix5/UXE5Ego8CIB/UrVSaKY+15E83xzN/4z7ap6Tx9vwf1FxORIKKAiCfhIcZD7euyYzBiTSuXo4/frySnmPnszHrcKBLExEBFAD5rlr5Erzduxn/uL0ha3YepPPw2bySpuZyIh
J4CoACYGbc2bQas5KTSKoXy9+/XMuto+ey+kc1lxORwFEAFKBKZYrx6n3XMfqeJuw8cJyuL8/hXzPWqbmciASEAqCAmRldGlRm5uAkul57BSO/zuSmEXNY/MO+QJcmIiFGARAg5UpGkXLntbz5UFOOnfyZHq/M49nUVRw5oeZyIlIwFAAB9qsrL2P64ETua1GDN7/dTMdh6cxenxXoskQkBCgAgkCp6Aie63YNkx+9nqjwMO4bv5DffLCcA0fVXE5E8o8CIIg0q1meLwa24Ylf1Wbq0u20G5rGtJVqLici+UMBEGSKRYbz205X8Um/VsSWiuaxdxbzxLuL2X3oeKBLE5EiRgEQpK6pEsMn/Vvxm45XMmvNbtqnpDNlsZrLiYj/KACCWGR4GP1uqMMXA9pQ57JSPPXBch54YxHb9h8NdGkiUgQoAAqBOpeV4oNHr+f/utYnY/M+OgxNZ8K3m9VcTkQuiQKgkAgLMx5oGceMwYkkxJXnz6mruPPVeWxQczkRuUgKgEKmarkSTHioKf+8oxHrdx+m8/DZjPomk1NqLiciF0gBUAiZGT2uq8rM5ETaXX0ZL01fR7eX57Jy+4FAlyYihUieAsDMOpnZOjPLNLOnc5kfbWaTvPkLzCwu27wh3vg6M+uYbXywma0ys5Vm9r6ZFfPHDoWSy0oXY/Q91/HKvU3YfegE3UbN5cVpazl+Ss3lROT8zhsAZhYOjAI6A/FALzOLz7FYb2C/c64OMBR40Vs3HugJ1Ac6AaPNLNzMqgADgATn3DVAuLecXIRO11Tmq+Qkujeuwph/b6DL8Nks2qzmciJybnk5A2gGZDrnNjrnTgITgW45lukGTPCmpwBtzcy88YnOuRPOuU1Aprc9gAiguJlFACWAHy9tV0JbTIlIXrqjEW893IwTp89wxyvz+NMnKzms5nIichZ5CYAqwNZsj7d5Y7ku45w7DRwAKpxtXefcduCfwBZgB3DAOTcjtyc3s75mlmFmGVlZapJ2Pon1YpkxOJEHW8bx9vwf6Dg0nbTvddxE5H8F5CawmZXDd3ZQE7gCKGlm9+a2rHNurHMuwTmXEBsbW5BlFloloyN4tmt9pjx2PcUiw3jg9YUkT17GT0dPBro0EQkieQmA7UC1bI+remO5LuNd0okB9p5j3XbAJudclnPuFDAVaHkxOyBnd12N8nw+oA39b6hD6rIfaZeSxhcrdqidhIgAeQuARUBdM6tpZlH4btam5lgmFXjAm+4BfO18rzKpQE/vXUI1gbrAQnyXflqYWQnvXkFbYM2l747kVCwynKc6Xskn/VtxeUwxnnh3CY+9s5jdB9VcTiTUnTcAvGv6/YHp+F6kJzvnVpnZc2bW1VtsPFDBzDKBZOBpb91VwGRgNTAN6Oec+9k5twDfzeIlwAqvjrF+3TP5L/WviOHjJ1rxu05X8c26LNqlpDE5Y6vOBkRCmBWmF4CEhASXkZER6DIKvY1Zh3n6wxUs3LyP1nUq8kL3BlQrXyLQZYlIPjCzxc65hNzm6S+BQ1Ct2FJM7NuCv9x6DUu37KfD0HTemLuJn9VcTiSkKABCVFiYcV+LGsxITqJ5rfL836erueOVb8ncfSjQpYlIAVEAhLgqZYvzxoNNGXpXIzbuOUKX4XMY+dV6NZcTCQEKAMHMuK1xVWYlJ9G+fiX+NfN7bhk5hxXb1FxOpChTAMh/VCwVzai7m/Dqfdex78hJuo2awwtfrlFzOZEiSgEg/6Nj/cuZmZzEnQnVeDVtI52Hz2bBxr2BLktE/EwBILmKKR7J329vyLuPNOf0mTPcNXY+f/h4BYeOnwp0aSLiJwoAOadWdSoyfVAivVvX5N0FW+g4NJ1v1u4OdFki4gcKADmvElER/PHmeD58vCUloyN46M1FDJ60jH1H1FxOpDBTAEieNalejs8GtGZA27p8uvxH2qek8enyH9VOQqSQUgDIBYmOCCe5fT0+fbI1VcoV58n3l9LnrcXsUnM5kUJHASAX5e
rKZZj6eEt+3+UqZq/3NZebuHCLzgZEChEFgFy0iPAw+ibWZvqgROIrl+HpqSu457UFbNl7NNCliUgeKADkksVVLMn7fVrw/G0N+G7bAToMS+O12RvVXE4kyCkAxC/Cwoy7m1dnZnIiLWtX5K+fr6H7mG9Zt1PN5USClQJA/KpyTHHGP5DA8J7XsnXfUW4eOZths77n5Gk1lxMJNgoA8Tszo9u1VZg5OJEuDSozbNZ6bhk5h+Vbfwp0aSKSjQJA8k2FUtEM79mY1+5P4MCxU9w2ei5/+3w1x06quZxIMFAASL5rF1+JGcmJ9GxWnXGzN9FpeDrzNqi5nEigKQCkQJQpFsnztzXgvT7NAeg1bj5Dpq7goJrLiQSMAkAKVMvaFZk2MJG+ibWYtGgL7VPSmLV6V6DLEglJCgApcMWjwvl9l6uZ+kQryhaP4pG3Mhjw/lL2Hj4R6NJEQooCQALm2mpl+fTJ1gxuV48vV+6gXUoanyzbrnYSIgVEASABFRURxsB2dfl8QBtqVCjJwInLeGRCBjsOHAt0aSJFngJAgkK9SqX58PGW/OGmq5m7YQ/tU9J5d8EPnFE7CZF8owCQoBEeZjzSphYzBiXRsGoMz3y0krtfm8/mPUcCXZpIkaQAkKBTvUIJ3n2kOX/v3oBV2w/ScVg6Y9M3cPpntZMQ8ScFgAQlM6Nns+rMTE6iTd1Ynv9iLd3HfMuaHQcDXZpIkaEAkKB2eUwxxt1/HS/f3Zjt+49xy8g5pMz8nhOn1U5C5FIpACTomRk3N7yCWclJ3NLoCkZ8tZ6bR8xhyZb9gS5NpFDLUwCYWSczW2dmmWb2dC7zo81skjd/gZnFZZs3xBtfZ2Yds42XNbMpZrbWzNaY2fX+2CEpusqVjGLoXdfyxoNNOXziNLeP+Za/fLaaoydPB7o0kULpvAFgZuHAKKAzEA/0MrP4HIv1BvY75+oAQ4EXvXXjgZ5AfaATMNrbHsBwYJpz7iqgEbDm0ndHQsENV13GjMGJ3NO8OuPnbKLjsHTmZu4JdFkihU5ezgCaAZnOuY3OuZPARKBbjmW6ARO86SlAWzMzb3yic+6Ec24TkAk0M7MYIBEYD+CcO+mcU7N4ybPSxSL5660NmNS3BRFhYdzz2gJ+N+U7DhxTczmRvMpLAFQBtmZ7vM0by3UZ59xp4ABQ4Rzr1gSygDfMbKmZvWZmJXN7cjPra2YZZpaRlZWVh3IllDSvVYEvB7bhsaTaTFmyjfYpacxYtTPQZYkUCoG6CRwBNAHGOOcaA0eA/7m3AOCcG+ucS3DOJcTGxhZkjVJIFIsM5+nOV/HxE62oUCqavm8vpt97S8g6pOZyIueSlwDYDlTL9riqN5brMmYWAcQAe8+x7jZgm3NugTc+BV8giFy0BlVjSO3fiqc61GPmql20H5rGR0u3qbmcyFnkJQAWAXXNrKaZReG7qZuaY5lU4AFvugfwtfP9r0sFenrvEqoJ1AUWOud2AlvN7EpvnbbA6kvcFxEiw8Pof2NdvhjYmloVSzJ40nIeenMR239SczmRnM4bAN41/f7AdHzv1JnsnFtlZs+ZWVdvsfFABTPLBJLxLuc451YBk/G9uE8D+jnnfvkLnieBd83sO+Ba4Hn/7ZaEujqXleaDx1ry51viWbBxHx1S0nh73mY1lxPJxgrT6XFCQoLLyMgIdBlSyGzdd5Tff7SC2ev30CyuPH+/vQG1YksFuiyRAmFmi51zCbnN018CS5FXrXwJ3nq4GS/1aMjanQfpNHw2Y/6t5nIiCgAJCWbGHQnVmJWcxA1XxvLitLXcOnouq39UczkJXQoACSmXlSnGq/clMOaeJuw8cIKuL8/hn9PXcfyUmstJ6FEASEjq3KAys5IT6XZtFV7+JpObRsxm8Q/7Al2WSIFSAEjIKlsiin/d2YgJDzfj+Kkz9HhlHs+mruLICTWXk9CgAJCQl1QvlumDE7m/RQ0mzNtMh6HppH+vtiNS9CkARIBS0RH8X7drmP
zo9URHhnH/6wt56oPlHDiq5nJSdCkARLJpGleeLwa04Ylf1eajpdtpNzSNaSt3BLoskXyhABDJoVhkOL/tdBWf9GtFbKloHntnCY+/s5jdh44HujQRv1IAiJzFNVVi+KR/K37T8Uq+Wrub9inpTFms5nJSdCgARM4hMjyMfjfU4YsBbah7WSme+mA597++kK37jga6NJFLpgAQyYM6l5Vi8qPX81y3+iz5YT8dh6Xz5txNai4nhZoCQCSPwsKM+6+PY/rgRBLiyvPsp6u589V5ZO4+HOjSRC6KAkDkAlUtV4IJDzXlX3c0Yv3uw3QZPptR32RySs3lpJBRAIhcBDPj9uuqMis5iXbxl/HS9HV0e3kuK7cfCHRpInmmABC5BLGloxl9z3W8cm8Tsg6foNuoubw4ba2ay0mhoAAQ8YNO11Rm1uAkbm9ShTH/3kCX4bNZtFnN5SS4KQBE/CSmRCT/6NGId3o35+TPZ7jjlXn86ZOVHFZzOQlSCgARP2tdtyLTByXyUKs43p7/Ax2HpvPvdbsDXZbI/1AAiOSDktER/PmW+kx5rCXFo8J58I1FJE9exv4jJwNdmsh/KABE8tF1Ncrx+YDWPHljHVKX/Uj7oWl8/t0OtZOQoKAAEMln0RHh/LrDlaT2b03lmOL0e28Jj769mN0H1VxOAksBIFJA4q8ow0dPtGRI56tI+z6LtilpTF60VWcDEjAKAJECFBEexqNJtflyYBuurlyG3374HfeNV3M5CQwFgEgA1IotxcQ+LfjrrdewbOtPdBiazutzNvGzmstJAVIAiARIWJhxb4sazBicSPNa5Xnus9Xc8cq3rN91KNClSYhQAIgE2BVli/PGg00Zdte1bNpzhJtGzGHkV+s5eVrN5SR/KQBEgoCZcWvjKsxMTqLjNZfzr5nf0/XlOXy37adAlyZFmAJAJIhULBXNyF6NGXd/AvuPnuTWUXN54Ys1ai4n+UIBIBKE2sdXYsbgJO5qWo1X0zfSaVg68zfuDXRZUsQoAESCVEzxSF7o3pD3HmnOGQc9x87nmY9WcOj4qUCXJkVEngLAzDqZ2TozyzSzp3OZH21mk7z5C8wsLtu8Id74OjPrmGO9cDNbamafXeqOiBRVLetUZNqgNjzSuibvL9xCh6HpfLNWzeXk0p03AMwsHBgFdAbigV5mFp9jsd7AfudcHWAo8KK3bjzQE6gPdAJGe9v7xUBgzaXuhEhRVyIqgj/cHM+Hj7ekVHQED725iEETl7JPzeXkEuTlDKAZkOmc2+icOwlMBLrlWKYbMMGbngK0NTPzxic650445zYBmd72MLOqwE3Aa5e+GyKhoXH1cnw2oDUD29bl8xU7aJeSRuryH9VOQi5KXgKgCrA12+Nt3liuyzjnTgMHgArnWXcY8FvgnG92NrO+ZpZhZhlZWVl5KFekaIuOCGdw+3p8+mRrqpUrzoD3l9LnrcXsPKDmcnJhAnIT2MxuBnY75xafb1nn3FjnXIJzLiE2NrYAqhMpHK66vAxTn2jFM12uZk5mFu1T0nh/4RadDUie5SUAtgPVsj2u6o3luoyZRQAxwN5zrNsK6Gpmm/FdUrrRzN65iPpFQlp4mNEnsRbTBiZSv0oZhkxdwd3jFvDD3iOBLk0KgbwEwCKgrpnVNLMofDd1U3Mskwo84E33AL52vl9DUoGe3ruEagJ1gYXOuSHOuarOuThve1875+71w/6IhKS4iiV575EWPH9bA1ZuP0DHYem8NnujmsvJOZ03ALxr+v2B6fjesTPZObfKzJ4zs67eYuOBCmaWCSQDT3vrrgImA6uBaUA/55z+pFEkH4SFGXc3r86M5ERa1a7IXz9fQ/cx37Jup5rLSe6sMF0vTEhIcBkZGYEuQyToOef49LsdPJu6ikPHT9Hvhjo88as6REXobz9DjZktds4l5DZPPw0iRZCZ0bXRFcxKTqJLg8oMm7WeW0bOYdlWNZeT/08BIFKElS8ZxfCejRn/QAIHjp2i++
i5/O3z1Rw7qSuxogAQCQltr67EjOREejarzrjZm+g4LJ1vN+wJdFkSYAoAkRBRplgkz9/WgPf7tMAM7h63gCFTV3BQzeVClgJAJMRcX7sC0wYm8mhiLSYt2kL7lDRmrd4V6LIkABQAIiGoeFQ4Q7pczcf9WlGuRBSPvJXBk+8vZe/hE4EuTQqQAkAkhDWsWpbU/q1Jbl+PaSt9zeU+WbZd7SRChAJAJMRFRYQxoG1dPh/QhhoVSjJw4jJ6T8jgx5+OBbo0yWcKABEBoF6l0nz4eEv+eHM88zbspcPQdN5d8ANn1E6iyFIAiMh/hIcZvVvXZPqgRBpVi+GZj1bSa9x8Nu1Rc7miSAEgIv+jeoUSvNO7OS/e3oDVOw7SaVg6r6Zt4PTP5/z4DilkFAAikisz466m1ZmVnERivVhe+HIt3cd8y5odBwNdmviJAkBEzqlSmWKMve86Rt3dhB9/OsYtI+eQMmMdJ06rnURhpwAQkfMyM25qWJmZg5Po2ugKRnydyc0j5rBky/5AlyaXQAEgInlWrmQUKXddyxsPNeXIidPcPuZbnvt0NUdPng50aXIRFAAicsFuuPIypg9O5N7mNXh9rq+53Jz1ai5X2CgAROSilC4WyV9uvYbJj15PRFgY945fwG+nLOfAMTWXKywUACJySZrVLM+XA9vw+K9q8+GS7bRPSWP6qp2BLkvyQAEgIpesWGQ4v+t0FR8/0YoKpaJ59O3F9Ht3CVmH1FwumCkARMRvGlSNIbV/K37T8Upmrt5F+6FpTF2yTc3lgpQCQET8KjI8jH431OGLga2pVbEkyZOX8+Abi9iu5nJBRwEgIvmizmWl+eCxljx7SzyLNu+jQ0oab83brOZyQUQBICL5JjzMeLCVr7lckxrl+NMnq7hr7Dw2ZB0OdGmCAkBECkC18iV46+FmvNSjIet2HqLz8NmM/nemmssFmAJARAqEmXFHQjVm/TqJG6+8jH9MW8eto+ey6scDgS4tZCkARKRAXVa6GK/cdx1j7mnCzgMn6PryXF6avpbjp9RcrqApAEQkIDo3qMys5ERua1yFUd9s4KYRs8nYvC/QZYUUBYCIBEzZElH8845GvPVwM46fOsMdr87j2dRVHDmh5nIFQQEgIgGXWC+WGYMTeeD6OCbM20yHoemkf58V6LKKPAWAiASFktERPNu1Ph88ej3RkWHc//pCnvpgOT8dPRno0oqsPAWAmXUys3VmlmlmT+cyP9rMJnnzF5hZXLZ5Q7zxdWbW0RurZmbfmNlqM1tlZgP9tUMiUrglxJXniwFt6HdDbT5aup12Kel8uWJHoMsqks4bAGYWDowCOgPxQC8zi8+xWG9gv3OuDjAUeNFbNx7oCdQHOgGjve2dBn7tnIsHWgD9ctmmiISoYpHh/KbjVaT2b0WlMtE8/u4SHn9nMbsPHQ90aUVKXs4AmgGZzrmNzrmTwESgW45lugETvOkpQFszM298onPuhHNuE5AJNHPO7XDOLQFwzh0C1gBVLn13RKQoqX9FDB/3a8XvOl3FV2t30z4lnQ8ytqq5nJ/kJQCqAFuzPd7G/75Y/2cZ59xp4ABQIS/repeLGgMLcntyM+trZhlmlpGVpZtCIqEmMjyMx39Vmy8HtqFepVL8Zsp33P/6QrbuOxro0gq9gN4ENrNSwIfAIOfcwdyWcc6Ndc4lOOcSYmNjC7ZAEQkatWNLManv9fylW32W/LCfjsPSeXPuJjWXuwR5CYDtQLVsj6t6Y7kuY2YRQAyw91zrmlkkvhf/d51zUy+meBEJLWFhxn3XxzF9cCJN48rz7KeruePVeWTuPhTo0gqlvATAIqCumdU0syh8N3VTcyyTCjzgTfcAvna+i3SpQE/vXUI1gbrAQu/+wHhgjXMuxR87IiKho2q5Erz5UFNS7mzEhqzDdBk+h1HfZHJKzeUuyHkDwLum3x+Yju9m7WTn3Coze87MunqLjQcqmFkmkAw87a27CpgMrAamAf2ccz8DrY
D7gBvNbJn31cXP+yYiRZiZ0b1JVWYOTqJ9/Uq8NH0d3V6ey8rtai6XV1aY7qYnJCS4jIyMQJchIkFo+qqd/OHjlew7cpK+ibUY2LYuxSLDA11WwJnZYudcQm7z9JfAIlIkdKx/ObMGJ9GjSVXG/HsDXYbPZuEmNZc7FwWAiBQZMSUiebFHQ97p3ZyTP5/hzlfn8cePV3JYzeVypQAQkSKndd2KzBicyMOtavLOgh/okJLGN+t2B7qsoKMAEJEiqURUBH+6JZ4pj7WkRHQED72xiORJy9h/RM3lfqEAEJEi7boa5fh8QGsG3FiH1OU/0n5oGp9/t0PtJFAAiEgIiI4IJ7nDlXz6ZGsqxxSn33tLePTtxew6GNrN5RQAIhIyrq5cho+eaMmQzleR9n0W7VLSmLRoS8ieDSgARCSkRISH8WhSbaYNSuTqymX43YcruHf8ArbsDb3mcgoAEQlJNSuWZGKfFvz11mtYvvUAHYelM37OJn4OoeZyCgARCVlhYca9LWowY3AiLWqV5y+frabHK9+yfldoNJdTAIhIyLuibHFef7Apw3tey+Y9R7hpxBxGfLWek6eLdnM5BYCICL7mct2urcKs5CQ6XnM5KTO/p+vLc1i+9adAl5ZvFAAiItlUKBXNyF6NGXd/AvuPnuS20XN54Ys1HDv5c6BL8zsFgIhILtrHV2JmchJ3Na3Gq+kb6Tw8nfkb9wa6LL9SAIiInEWZYpG80L0h7z3SnDMOeo6dzzMfreDQ8VOBLs0vFAAiIufRsk5Fpg9KpE+bmry/cAsdhqbz9dpdgS7rkikARETyoHhUOM/cFM/UJ1pRplgkD7+ZwcCJS9l7+ESgS7toCgARkQtwbbWyfPpkawa1q8sXK3bQfmg6qct/LJTtJBQAIiIXKCoijEHt6vHZk22oVr4EA95fSp+3Mth5oHA1l1MAiIhcpCsvL83Ux1vyh5uuZk7mHtqnpPH+wsLTXE4BICJyCcLDjEfa1GL6oESuqRLDkKkruHvcAn7YeyTQpZ2XAkBExA9qVCjJe32a80L3Bqzc7msuNy59Y1A3l1MAiIj4iZnRq1l1ZiYn0bpORf72xRq6j57Lup3B2VxOASAi4meXxxRj3P0JjOzVmG37j1M1W90AAAhBSURBVHHzyNkMnfl90DWXUwCIiOQDM+OWRlcwMzmJmxpUZvhX67l55GyWBVFzOQWAiEg+Kl8yimE9G/P6gwkcOn6a7qPn8tfPVgdFczkFgIhIAbjxqkrMGJxIr2bVeW3OJjoOS+fbDXsCWpMCQESkgJQuFsnfbmvAxL4tCDO4e9wChkz9jgPHAtNcTgEgIlLAWtSqwLRBiTyaVItJi7bSYWgaM1cXfHM5BYCISAAUiwxnSOer+bhfK8qViKLPWxn0f28JewqwuZwCQEQkgBpWLUtq/9b8un09ZqzaRfuUND5eur1A2knkKQDMrJOZrTOzTDN7Opf50WY2yZu/wMziss0b4o2vM7OOed2miEioiIoI48m2dfl8QGviKpZk0KRl9J6QwY8/HcvX5z1vAJhZODAK6AzEA73MLD7HYr2B/c65OsBQ4EVv3XigJ1Af6ASMNrPwPG5TRCSk1K1UmimPteRPN8czb8NeOgxN5535P3Amn9pJ5OUMoBmQ6Zzb6Jw7CUwEuuVYphswwZueArQ1M/PGJzrnTjjnNgGZ3vbysk0RkZATHmY83LomMwYncm21svzh45X0HDefoydP+/258hIAVYCt2R5v88ZyXcY5dxo4AFQ4x7p52SYAZtbXzDLMLCMrKysP5YqIFH7Vypfg7d7N+MftDalZoSQloiL8/hxBfxPYOTfWOZfgnEuIjY0NdDkiIgXGzLizaTVe7NEwX7aflwDYDlTL9riqN5brMmYWAcQAe8+xbl62KSIi+SgvAbAIqGtmNc0sCt9N3dQcy6QCD3jTPYCvne89TKlAT+9dQjWBusDCPG5TRETy0XkvKjnnTptZf2A6EA687pxbZWbPAR
nOuVRgPPC2mWUC+/C9oOMtNxlYDZwG+jnnfgbIbZv+3z0RETkbKyyfXQmQkJDgMjIyAl2GiEihYWaLnXMJuc0L+pvAIiKSPxQAIiIhSgEgIhKiFAAiIiGqUN0ENrMs4IeLXL0iENiP38k71Zp/ClO9qjX/FKZ6L7XWGs65XP+KtlAFwKUws4yz3QkPNqo1/xSmelVr/ilM9eZnrboEJCISohQAIiIhKpQCYGygC7gAqjX/FKZ6VWv+KUz15lutIXMPQERE/lsonQGIiEg2CgARkRBV5AMgGD583syqmdk3ZrbazFaZ2UBv/Fkz225my7yvLtnWGeLVvM7MOhb0/pjZZjNb4dWV4Y2VN7OZZrbe+7ecN25mNsKr6Tsza5JtOw94y683swfO9nyXUOeV2Y7fMjM7aGaDgunYmtnrZrbbzFZmG/PbsTSz67zvVaa3rvm51pfMbK1Xz0dmVtYbjzOzY9mO8Svnq+ls++3HWv32fTdfu/oF3vgk87Wu92etk7LVudnMlnnjBXdcnXNF9gtfq+kNQC0gClgOxAegjspAE2+6NPA9EA88CzyVy/LxXq3RQE1vH8ILcn+AzUDFHGP/AJ72pp8GXvSmuwBfAga0ABZ44+WBjd6/5bzpcvn8/d4J1AimYwskAk2AlflxLPF9xkYLb50vgc5+rrUDEOFNv5it1rjsy+XYTq41nW2//Vir377vwGSgpzf9CvC4P2vNMf9fwJ8K+rgW9TOAoPjweefcDufcEm/6ELCGs3wGsqcbMNE5d8I5twnIxLcvgd6fbsAEb3oCcGu28becz3ygrJlVBjoCM51z+5xz+4GZQKd8rK8tsME5d66/Fi/wY+ucS8f3ORk567jkY+nNK+Ocm+98//vfyrYtv9TqnJvhfJ/1DTAf3yf4ndV5ajrbfvul1nO4oO+795v1jcCU/K7Ve647gffPtY38OK5FPQDy/OHzBcXM4oDGwAJvqL93av16ttO2s9VdkPvjgBlmttjM+npjlZxzO7zpnUClIKoXfB9ElP0/UbAeW/DfsaziTecczy8P4/vN8xc1zWypmaWZWRtv7Fw1nW2//ckf3/cKwE/Zgi8/j2sbYJdzbn22sQI5rkU9AIKKmZUCPgQGOecOAmOA2sC1wA58p4HBorVzrgnQGehnZonZZ3q/gQTNe4i967NdgQ+8oWA+tv8l2I7l2ZjZM/g+2e9db2gHUN051xhIBt4zszJ53V4+7Xeh+b5n04v//sWlwI5rUQ+AoPnweTOLxPfi/65zbiqAc26Xc+5n59wZYBy+01E4e90Ftj/Oue3ev7uBj7zadnmnob+cju4OlnrxBdUS59wur+6gPbYefx3L7fz3JZl8qdvMHgRuBu7xXmDwLqfs9aYX47uWXu88NZ1tv/3Cj9/3vfguv0XkGPcrb/vdgUnZ9qHAjmtRD4Cg+PB57xrfeGCNcy4l23jlbIvdBvzyDoFUoKeZRZtZTaAuvps/BbI/ZlbSzEr/Mo3vJuBK77l+effJA8An2eq933xaAAe809HpQAczK+edinfwxvLDf/0WFazHNhu/HEtv3kEza+H9nN2fbVt+YWadgN8CXZ1zR7ONx5pZuDddC9+x3Hiems623/6q1S/fdy/kvgF65FetnnbAWufcfy7tFOhxzetd7ML6he9dFd/jS9FnAlRDa3ynZN8By7yvLsDbwApvPBWonG2dZ7ya15HtXR0FsT/43hGx3Pta9cvz4Lsu+hWwHpgFlPfGDRjl1bQCSMi2rYfx3XDLBB7Kp3pL4vuNLSbbWNAcW3zBtAM4he+6bW9/HksgAd8L3QbgZby/8PdjrZn4rpP/8rP7irfs7d7PxzJgCXDL+Wo62377sVa/fd+9/wcLvf3/AIj2Z63e+JvAYzmWLbDjqlYQIiIhqqhfAhIRkbNQAIiIhCgFgIhIiFIAiIiEKAWAiEiIUgCIiIQoBYCISIj6f3x0Oxh6UrCwAAAAAElFTkSuQmCC",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.plot(learning_rates)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 328,
"referenced_widgets": [
"6cdc774b00ab477d956c98d936c2a422"
]
},
"id": "RWxwYR7KV-RA",
"outputId": "15ce76cc-1373-4dde-88ed-bca45b3bd1b9"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7af82e38e7e94395b38173cdb0084ba9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/1 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[419531 57 20 157 74 86 12 76 47]\n",
" [ 113 1363 22 73 9 33 2 1 1]\n",
" [ 87 9 1042 2 13 1 2 0 0]\n",
" [ 116 27 2 1364 27 89 0 33 3]\n",
" [ 55 1 9 11 703 5 33 4 14]\n",
" [ 58 18 0 51 3 1502 4 31 1]\n",
" [ 27 0 6 0 14 7 202 0 1]\n",
" [ 71 9 1 37 6 16 0 557 5]\n",
" [ 32 0 3 1 14 3 8 6 149]]\n"
]
},
{
"data": {
"text/plain": [
"0.8411122480145246"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dl = DataLoader(datasets['test'], batch_size=len(datasets['test']), collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"\n",
"evaluate(model, test_dl)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YATCbeeTDNQG"
},
"source": [
"# Beam Search"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sylZvVBm0lo_"
},
"source": [
"Beam search illustration [(source)](https://towardsdatascience.com/foundations-of-nlp-explained-visually-beam-search-how-it-works-1586b9849a24)\n"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"id": "oKUr7j_y1J4p"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 6.931471805599453]\n",
"[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 7.154615356913663]\n",
"[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 7.154615356913663]\n"
]
}
],
"source": [
"# source https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/\n",
"\n",
"from math import log\n",
"from numpy import array\n",
"\n",
"def beam_search_decoder(data, k):\n",
" sequences = [[list(), 0.0]]\n",
" # walk over each step in sequence\n",
" for row in data:\n",
" all_candidates = list()\n",
" # expand each current candidate\n",
" for i in range(len(sequences)):\n",
" seq, score = sequences[i]\n",
" for j in range(len(row)):\n",
" candidate = [seq + [j], score - log(row[j])]\n",
" all_candidates.append(candidate)\n",
" # order all candidates by score\n",
" ordered = sorted(all_candidates, key=lambda tup:tup[1])\n",
" # select k best\n",
" sequences = ordered[:k]\n",
" return sequences\n",
" \n",
"# define a sequence of 10 words over a vocab of 5 words\n",
"data = [[0.1, 0.2, 0.3, 0.4, 0.5],\n",
" [0.5, 0.4, 0.3, 0.2, 0.1],\n",
" [0.1, 0.2, 0.3, 0.4, 0.5],\n",
" [0.5, 0.4, 0.3, 0.2, 0.1],\n",
" [0.1, 0.2, 0.3, 0.4, 0.5],\n",
" [0.5, 0.4, 0.3, 0.2, 0.1],\n",
" [0.1, 0.2, 0.3, 0.4, 0.5],\n",
" [0.5, 0.4, 0.3, 0.2, 0.1],\n",
" [0.1, 0.2, 0.3, 0.4, 0.5],\n",
" [0.5, 0.4, 0.3, 0.2, 0.1]]\n",
"data = array(data)\n",
"# decode sequence\n",
"result = beam_search_decoder(data, 3)\n",
"# print result\n",
"for seq in result:\n",
" print(seq)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lOTwIiaG1KPU"
},
"source": [
"- **Question: Can you spot the problem with the beam search above?**\n",
"\n",
"\n",
"- In the example above, the probability distribution for each step does not depend on the token chosen at the previous step.\n",
"- Beam search is usually employed with encoder-decoder architectures:\n",
"\n",
"- At each step, the decoder receives as input the prediction and the hidden state of the previous step.\n",
"- During training: at each step, feed the decoder either its own highest-probability prediction from the previous step or the previous step's gold label (teacher forcing).\n",
"- During testing: keep a beam of the top-k generated sequences and run the decoder one step further with each of them.\n",
"\n",
"Resources:\n",
"- Implementing an encoder-decoder model [example 1](https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html), [example 2](https://bastings.github.io/annotated_encoder_decoder/)\n",
"- Implementing beam search [example](https://github.com/budzianowski/PyTorch-Beam-Search-Decoding/blob/master/decode_beam.py)\n"
]
},
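{
"cell_type": "markdown",
"metadata": {},
"source": [
"For contrast, here is a small standalone sketch (not part of the lab) in which the next-step distribution is conditioned on the previously chosen token through a toy transition matrix: exactly the dependency that the toy example above is missing."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# hypothetical toy example: beam search where P(next token) depends on\n",
"# the previous token via a transition matrix (a stand-in for a real decoder)\n",
"import numpy as np\n",
"from math import log\n",
"\n",
"def conditional_beam_search(transitions, start, steps, k):\n",
"    # each candidate is (token list, accumulated negative log-likelihood)\n",
"    sequences = [([start], 0.0)]\n",
"    for _ in range(steps):\n",
"        all_candidates = []\n",
"        for seq, score in sequences:\n",
"            row = transitions[seq[-1]]  # conditioned on the last token\n",
"            for j, p in enumerate(row):\n",
"                all_candidates.append((seq + [j], score - log(p)))\n",
"        # keep the k candidates with the lowest negative log-likelihood\n",
"        sequences = sorted(all_candidates, key=lambda t: t[1])[:k]\n",
"    return sequences\n",
"\n",
"# row i holds P(next token | current token i) for a 3-token vocabulary\n",
"T = np.array([[0.1, 0.8, 0.1],\n",
"              [0.7, 0.1, 0.2],\n",
"              [0.3, 0.3, 0.4]])\n",
"for seq, score in conditional_beam_search(T, start=0, steps=4, k=3):\n",
"    print(seq, score)"
]
},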
{
"cell_type": "code",
"execution_count": 80,
"metadata": {
"id": "aUlZdPe7v5KJ"
},
"outputs": [],
"source": [
"class EncoderRNN(nn.Module):\n",
" \"\"\"\n",
" RNN Encoder model.\n",
" \"\"\"\n",
" def __init__(self, \n",
" pretrained_embeddings: torch.tensor, \n",
" lstm_dim: int,\n",
" dropout_prob: float = 0.1):\n",
" \"\"\"\n",
" Initializer for EncoderRNN network\n",
" :param pretrained_embeddings: A tensor containing the pretrained embeddings\n",
" :param lstm_dim: The dimensionality of the LSTM network\n",
" :param dropout_prob: Dropout probability\n",
" \"\"\"\n",
" # First thing is to call the superclass initializer\n",
" super(EncoderRNN, self).__init__()\n",
"\n",
" # We'll define the network in a ModuleDict, which makes organizing the model a bit nicer\n",
" # The components are an embedding layer, and an LSTM layer.\n",
" self.model = nn.ModuleDict({\n",
" 'embeddings': nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1),\n",
" 'lstm': nn.LSTM(pretrained_embeddings.shape[1], lstm_dim, 2, batch_first=True, bidirectional=True),\n",
" })\n",
" # Initialize the weights of the model\n",
" self._init_weights()\n",
"\n",
" def _init_weights(self):\n",
" all_params = list(self.model['lstm'].named_parameters())\n",
" for n, p in all_params:\n",
" if 'weight' in n:\n",
" nn.init.xavier_normal_(p)\n",
" elif 'bias' in n:\n",
" nn.init.zeros_(p)\n",
"\n",
" def forward(self, inputs, input_lens):\n",
" \"\"\"\n",
" Defines how tensors flow through the model\n",
" :param inputs: (b x sl) The IDs into the vocabulary of the input samples\n",
" :param input_lens: (b) The length of each input sequence\n",
" :return: (lstm output state, lstm hidden state) \n",
" \"\"\"\n",
" embeds = self.model['embeddings'](inputs)\n",
" lstm_in = nn.utils.rnn.pack_padded_sequence(\n",
" embeds,\n",
" input_lens.cpu(),\n",
" batch_first=True,\n",
" enforce_sorted=False\n",
" )\n",
" lstm_out, hidden_states = self.model['lstm'](lstm_in)\n",
" lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)\n",
" return lstm_out, hidden_states\n",
"\n",
"\n",
"class DecoderRNN(nn.Module):\n",
" \"\"\"\n",
" RNN Decoder model.\n",
" \"\"\"\n",
" def __init__(self, pretrained_embeddings: torch.tensor, \n",
" lstm_dim: int,\n",
" dropout_prob: float = 0.1,\n",
" n_classes: int = 2):\n",
" \"\"\"\n",
" Initializer for DecoderRNN network\n",
" :param pretrained_embeddings: A tensor containing the pretrained embeddings\n",
" :param lstm_dim: The dimensionality of the LSTM network\n",
" :param dropout_prob: Dropout probability\n",
" :param n_classes: Number of prediction classes\n",
" \"\"\"\n",
" # First thing is to call the superclass initializer\n",
" super(DecoderRNN, self).__init__()\n",
" # We'll define the network in a ModuleDict, which makes organizing the model a bit nicer\n",
" # The components are an embedding layer, a LSTM layer, and a feed-forward output layer\n",
" self.model = nn.ModuleDict({\n",
" 'embeddings': nn.Embedding.from_pretrained(pretrained_embeddings, padding_idx=pretrained_embeddings.shape[0] - 1),\n",
" 'lstm': nn.LSTM(pretrained_embeddings.shape[1], lstm_dim, 2, bidirectional=True, batch_first=True),\n",
" 'nn': nn.Linear(lstm_dim*2, n_classes),\n",
" })\n",
" # Initialize the weights of the model\n",
" self._init_weights() \n",
"\n",
" def forward(self, inputs, hidden, input_lens):\n",
" \"\"\"\n",
" Defines how tensors flow through the model\n",
" :param inputs: (b x sl) The IDs into the vocabulary of the input samples\n",
" :param hidden: (b) The hidden state of the previous step\n",
" :param input_lens: (b) The length of each input sequence\n",
" :return: (output predictions, lstm hidden states) the hidden states will be used as input at the next step\n",
" \"\"\"\n",
" embeds = self.model['embeddings'](inputs)\n",
"\n",
" lstm_in = nn.utils.rnn.pack_padded_sequence(\n",
" embeds,\n",
" input_lens.cpu(),\n",
" batch_first=True,\n",
" enforce_sorted=False\n",
" )\n",
" lstm_out, hidden_states = self.model['lstm'](lstm_in, hidden)\n",
" lstm_out, _ = nn.utils.rnn.pad_packed_sequence(lstm_out, batch_first=True)\n",
" output = self.model['nn'](lstm_out)\n",
" return output, hidden_states\n",
"\n",
" def _init_weights(self):\n",
" all_params = list(self.model['lstm'].named_parameters()) + list(self.model['nn'].named_parameters())\n",
" for n, p in all_params:\n",
" if 'weight' in n:\n",
" nn.init.xavier_normal_(p)\n",
" elif 'bias' in n:\n",
" nn.init.zeros_(p)\n",
"\n",
"# Define the model\n",
"class Seq2Seq(nn.Module):\n",
" \"\"\"\n",
" Basic Seq2Seq network\n",
" \"\"\"\n",
" def __init__(\n",
" self,\n",
" pretrained_embeddings: torch.tensor,\n",
" lstm_dim: int,\n",
" dropout_prob: float = 0.1,\n",
" n_classes: int = 2\n",
" ):\n",
" \"\"\"\n",
" Initializer for basic Seq2Seq network\n",
" :param pretrained_embeddings: A tensor containing the pretrained embeddings\n",
" :param lstm_dim: The dimensionality of the LSTM network\n",
" :param dropout_prob: Dropout probability\n",
" :param n_classes: The number of output classes\n",
" \"\"\"\n",
"\n",
" # First thing is to call the superclass initializer\n",
" super(Seq2Seq, self).__init__()\n",
"\n",
" # We'll define the network in a ModuleDict, which consists of an encoder and a decoder\n",
" self.model = nn.ModuleDict({\n",
" 'encoder': EncoderRNN(pretrained_embeddings, lstm_dim, dropout_prob),\n",
" 'decoder': DecoderRNN(pretrained_embeddings, lstm_dim, dropout_prob, n_classes),\n",
" })\n",
" self.loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([0.5]+[1]*(len(datasets[\"train\"].features[f\"ner_tags\"].feature.names)-1)).to(device))\n",
"\n",
"\n",
" def forward(self, inputs, input_lens, labels=None):\n",
" \"\"\"\n",
" Defines how tensors flow through the model. \n",
" For the Seq2Seq model this includes 1) encoding the whole input text, \n",
" and running *target_length* decoding steps to predict the tag of each token.\n",
"\n",
" :param inputs: (b x sl) The IDs into the vocabulary of the input samples\n",
" :param input_lens: (b) The length of each input sequence\n",
" :param labels: (b x sl) The gold tag of each token\n",
" :return: the mean per-step loss (this implementation requires `labels`)\n",
" \"\"\"\n",
"\n",
" # Get embeddings (b x sl x embedding dim)\n",
" encoder_output, encoder_hidden = self.model['encoder'](inputs, input_lens)\n",
" decoder_hidden = encoder_hidden\n",
" decoder_input = torch.tensor([tokenizer.encode(['[BOS]'])]*inputs.shape[0], device=device)\n",
" target_length = labels.size(1)\n",
"\n",
" loss = None\n",
" for di in range(target_length):\n",
" decoder_output, decoder_hidden = self.model['decoder'](\n",
" decoder_input, decoder_hidden, torch.tensor([1]*inputs.shape[0]))\n",
"\n",
" if loss is None: \n",
" loss = self.loss(decoder_output.squeeze(1), labels[:, di])\n",
" else:\n",
" loss += self.loss(decoder_output.squeeze(1), labels[:, di])\n",
" # Teacher forcing: Feed the target as the next input\n",
" decoder_input = labels[:, di].unsqueeze(-1) \n",
"\n",
" return loss / target_length"
]
},
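{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `forward` above always teacher-forces. A common variant is scheduled sampling: with some probability, feed the decoder its own prediction instead of the gold label. A minimal sketch, assuming the decoder output shape used above; the name `teacher_forcing_ratio` and the helper `next_decoder_input` are introduced here for illustration and are not part of the lab code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"def next_decoder_input(decoder_output, labels, di, teacher_forcing_ratio=0.5):\n",
"    # decoder_output: (b x 1 x n_classes), labels: (b x sl)\n",
"    if random.random() < teacher_forcing_ratio:\n",
"        # teacher forcing: feed the gold label of the current step\n",
"        return labels[:, di].unsqueeze(-1)\n",
"    # free running: feed the decoder's own most likely prediction\n",
"    return decoder_output.argmax(dim=-1)"
]
},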
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"id": "2OAd-kev7BcY"
},
"outputs": [],
"source": [
"def train(\n",
" model: nn.Module, \n",
" train_dl: DataLoader, \n",
" valid_dl: DataLoader, \n",
" optimizer: torch.optim.Optimizer, \n",
" n_epochs: int, \n",
" device: torch.device,\n",
"):\n",
" \"\"\"\n",
" The main training loop which will optimize a given model on a given dataset\n",
" :param model: The model being optimized\n",
" :param train_dl: The training dataset\n",
" :param valid_dl: A validation dataset\n",
" :param optimizer: The optimizer used to update the model parameters\n",
" :param n_epochs: Number of epochs to train for\n",
" :param device: The device to train on\n",
" :return: losses: The losses per iteration; the best model is saved to 'best_model'\n",
" \"\"\"\n",
"\n",
" # Keep track of the loss and best accuracy\n",
" losses = []\n",
" best_f1 = 0.0\n",
"\n",
" # Iterate through epochs\n",
" for ep in range(n_epochs):\n",
"\n",
" loss_epoch = []\n",
"\n",
"# Iterate through each batch in the dataloader\n",
" for batch in tqdm(train_dl):\n",
" # VERY IMPORTANT: Make sure the model is in training mode, which turns on\n",
" # things like dropout\n",
" model.train()\n",
"\n",
" # VERY IMPORTANT: zero out all of the gradients on each iteration -- PyTorch\n",
" # keeps track of these dynamically in its computation graph so you need to explicitly\n",
" # zero them out\n",
" optimizer.zero_grad()\n",
"\n",
" # Place each tensor on the GPU\n",
" batch = tuple(t.to(device) for t in batch)\n",
" input_ids = batch[0]\n",
" labels = batch[2]\n",
" input_lens = batch[1]\n",
"\n",
" # Pass the inputs through the model, get the current loss and logits\n",
" loss = model(input_ids, labels=labels, input_lens=input_lens)\n",
" losses.append(loss.item())\n",
" loss_epoch.append(loss.item())\n",
"\n",
" # Calculate all of the gradients and weight updates for the model\n",
" loss.backward()\n",
"\n",
" # Optional: clip gradients\n",
" #torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" # Finally, update the weights of the model\n",
" optimizer.step()\n",
"\n",
" # Perform inline evaluation at the end of the epoch\n",
" f1 = evaluate(model, valid_dl)\n",
" print(f'Validation F1: {f1}, train loss: {sum(loss_epoch) / len(loss_epoch)}')\n",
"\n",
" # Keep track of the best model based on the F1 score\n",
" if f1 > best_f1:\n",
" torch.save(model.state_dict(), 'best_model')\n",
" best_f1 = f1\n",
"\n",
" return losses"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {
"id": "OykuF5dURSd5"
},
"outputs": [],
"source": [
"softmax = nn.Softmax(dim=-1)\n",
"\n",
"def decode(model, inputs, input_lens, labels=None, beam_size=2):\n",
" \"\"\"\n",
" Decoding/predicting the labels for an input text by running beam search.\n",
"\n",
" :param inputs: (b x sl) The IDs into the vocabulary of the input samples\n",
" :param input_lens: (b) The length of each input sequence\n",
" :param labels: (1 x sl) gold labels, used here only to set the number of decoding steps\n",
" :param beam_size: the size of the beam \n",
" :return: predicted sequence of labels\n",
" \"\"\"\n",
"\n",
" assert inputs.shape[0] == 1\n",
" # first, encode the input text\n",
" encoder_output, encoder_hidden = model.model['encoder'](inputs, input_lens)\n",
" decoder_hidden = encoder_hidden\n",
"\n",
" # the decoder starts generating after the Beginning of Sentence (BOS) token\n",
" decoder_input = torch.tensor([tokenizer.encode(['[BOS]',]),], device=device)\n",
" target_length = labels.shape[1]\n",
" \n",
" # we use heapq (a min-heap) to keep the best sequences found so far in heap_queue\n",
" # tuples are ordered by their first item, the accumulated score\n",
" heap_queue = []\n",
" heap_queue.append((torch.tensor(0), tokenizer.encode(['[BOS]']), decoder_input, decoder_hidden))\n",
"\n",
" # Beam Decoding\n",
" for _ in range(target_length):\n",
" new_items = []\n",
" # for each item on the beam\n",
" for j in range(len(heap_queue)): \n",
" # 1. remove from heap\n",
" score, tokens, decoder_input, decoder_hidden = heapq.heappop(heap_queue)\n",
" # 2. decode one more step\n",
" decoder_output, decoder_hidden = model.model['decoder'](\n",
" decoder_input, decoder_hidden, torch.tensor([1]))\n",
" decoder_output = softmax(decoder_output)\n",
" # 3. get top-k predictions\n",
" best_idx = torch.argsort(decoder_output[0], descending=True)[0]\n",
" for i in range(beam_size):\n",
" decoder_input = torch.tensor([[best_idx[i]]], device=device)\n",
" \n",
" new_items.append((score + decoder_output[0,0, best_idx[i]],\n",
" tokens + [best_idx[i].item()], \n",
" decoder_input, \n",
" decoder_hidden))\n",
" # add new sequences to the heap\n",
" for item in new_items:\n",
" heapq.heappush(heap_queue, item)\n",
" # drop the lowest-scoring sequences (heappop removes the smallest item first)\n",
" while len(heap_queue) > beam_size:\n",
" heapq.heappop(heap_queue)\n",
" \n",
" final_sequence = heapq.nlargest(1, heap_queue)[0]\n",
" assert labels.shape[1] == len(final_sequence[1][1:])\n",
" return final_sequence"
]
},
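{
"cell_type": "markdown",
"metadata": {},
"source": [
"`heapq` maintains a min-heap: `heappop` always removes the smallest item, so pushing every candidate and popping while the heap exceeds the beam size leaves only the highest-scoring sequences. A standalone illustration of that pattern:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import heapq\n",
"\n",
"beam_size = 2\n",
"heap = []\n",
"for score, seq in [(0.2, 'a'), (0.9, 'b'), (0.5, 'c'), (0.7, 'd')]:\n",
"    heapq.heappush(heap, (score, seq))\n",
"    # trimming the min-heap discards the lowest-scoring items\n",
"    while len(heap) > beam_size:\n",
"        heapq.heappop(heap)\n",
"print(sorted(heap, reverse=True))  # the two highest scores remain"
]
},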
{
"cell_type": "code",
"execution_count": 85,
"metadata": {
"id": "9tV1Sdlze9eO"
},
"outputs": [],
"source": [
"def evaluate(model: nn.Module, valid_dl: DataLoader, beam_size:int = 1):\n",
" \"\"\"\n",
" Evaluates the model on the given dataset\n",
" :param model: The model under evaluation\n",
" :param valid_dl: A `DataLoader` reading validation data\n",
" :param beam_size: The beam size passed on to `decode`\n",
" :return: The macro F1 of the model on the dataset\n",
" \"\"\"\n",
" # VERY IMPORTANT: Put your model in \"eval\" mode -- this disables things like\n",
" # dropout\n",
" model.eval()\n",
" labels_all = []\n",
" logits_all = []\n",
" tags_all = []\n",
"\n",
" # ALSO IMPORTANT: Don't accumulate gradients during this process\n",
" with torch.no_grad():\n",
" for batch in tqdm(valid_dl, desc='Evaluation'):\n",
" batch = tuple(t.to(device) for t in batch)\n",
" input_ids = batch[0]\n",
" input_lens = batch[1]\n",
" labels = batch[2]\n",
"\n",
" best_seq = decode(model, input_ids, input_lens, labels=labels, beam_size=beam_size)\n",
" mask = (input_ids != 0)\n",
" labels_all.extend([l for seq,samp in zip(list(labels.detach().cpu().numpy()), input_ids) for l,i in zip(seq,samp) if i != 0])\n",
" tags_all += best_seq[1][1:]\n",
" P, R, F1, _ = precision_recall_fscore_support(labels_all, tags_all, average='macro')\n",
" print(confusion_matrix(labels_all, tags_all))\n",
" return F1"
]
},
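{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of the metric reported above: macro-averaged F1 computes F1 per class and then averages the classes with equal weight, so rare tags count as much as the dominant `O` tag. A small standalone example (using scikit-learn, which `evaluate` already relies on):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_fscore_support\n",
"\n",
"y_true = [0, 0, 1, 1, 2]\n",
"y_pred = [0, 1, 1, 1, 2]\n",
"# per-class F1 is computed first, then averaged with equal class weights\n",
"P, R, F1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')\n",
"print(round(F1, 3))"
]
},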
{
"cell_type": "code",
"execution_count": 88,
"metadata": {
"id": "4KjDoQkl8Omy"
},
"outputs": [],
"source": [
"lstm_dim = 300\n",
"dropout_prob = 0.1\n",
"batch_size = 64\n",
"lr = 1e-3\n",
"n_epochs = 20\n",
"n_workers = 0\n",
"\n",
"device = torch.device(\"cpu\")\n",
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"\n",
"# Create the model\n",
"model = Seq2Seq(\n",
" pretrained_embeddings=torch.FloatTensor(pretrained_embeddings), \n",
" lstm_dim=lstm_dim, \n",
" dropout_prob=dropout_prob, \n",
" n_classes=len(datasets[\"train\"].features[f\"ner_tags\"].feature.names)\n",
" ).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
"e69a0f48b8104bd29adfab7a1409d1d4",
"ce96995f8d93445a93d4cb357ec3c72c",
"a86d9662ed444599a707f383ef9293ad",
"7ec12acc502c4bafa871a289054204b2",
"927b8def881c449abcfb01750074aaef",
"79da7ceb7e17457aaea7dd19ec96c0a2",
"8d6f82061cee47aca199947061c51514",
"66098eb69aaf4486bb2eddf5de1be7f4",
"f2944271cf214122a7cbea8952d9b6ba",
"b09a25113e134475b06d53d60538dbd0",
"24d8f5b9572049f9aa34c0a91142473b"
]
},
"id": "iHvNb7nm6kLI",
"outputId": "0436841f-d3a3-44aa-bc76-b9eda50f2fac"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "13dd97cb90cf43cfb2de15168cf9e933",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f0a9c8b452314326a030e7e8f5a32c11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[40937 526 433 155 194 192 68 47 207]\n",
" [ 515 1052 218 15 13 14 3 2 10]\n",
" [ 379 146 758 3 9 3 3 1 5]\n",
" [ 467 49 19 686 66 32 5 10 7]\n",
" [ 285 12 25 65 323 10 13 5 13]\n",
" [ 627 35 10 91 23 951 46 35 19]\n",
" [ 92 4 5 12 26 22 79 1 16]\n",
" [ 535 16 8 24 9 34 6 246 44]\n",
" [ 154 5 5 22 24 18 12 14 92]]\n",
"Validation F1: 0.5148684803742811, train loss: 0.1254397713663903\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d432903540744d39b4e7d2ef0ef9628e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "71c753c0072648c492a0c07f09d6e260",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[40938 372 251 212 345 316 34 181 110]\n",
" [ 353 1221 148 35 26 28 2 20 9]\n",
" [ 329 94 822 14 28 5 0 3 12]\n",
" [ 275 39 10 820 114 43 2 32 6]\n",
" [ 190 2 17 58 435 8 7 6 28]\n",
" [ 404 21 3 66 24 1264 22 26 7]\n",
" [ 84 3 5 3 32 23 97 0 10]\n",
" [ 342 15 3 35 21 60 2 414 30]\n",
" [ 141 4 3 9 30 12 15 13 119]]\n",
"Validation F1: 0.5992961020420663, train loss: 0.08672628991983154\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fe22d796d2f54ecf91fb205bce63eab7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38cd9267f3b74ba3af01ff34e08fc615",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41156 413 324 116 301 233 76 99 41]\n",
" [ 341 1254 170 13 19 29 4 7 5]\n",
" [ 279 119 866 2 16 12 5 3 5]\n",
" [ 249 59 8 851 109 38 5 18 4]\n",
" [ 137 9 24 30 511 6 13 5 16]\n",
" [ 345 27 9 59 25 1314 26 29 3]\n",
" [ 51 3 8 8 22 27 126 1 11]\n",
" [ 350 22 4 39 10 27 0 459 11]\n",
" [ 132 3 5 11 30 8 12 28 117]]\n",
"Validation F1: 0.6389387782477018, train loss: 0.06806593586436727\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "953b998954c4469b9d5b4b76a56f1e94",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "33f55519bc774106b430edd393955fd4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[40798 449 320 148 528 183 57 120 156]\n",
" [ 225 1307 201 18 25 34 6 7 19]\n",
" [ 237 98 914 5 22 10 9 0 12]\n",
" [ 144 59 21 890 147 52 6 17 5]\n",
" [ 98 4 23 22 578 1 9 0 16]\n",
" [ 255 33 9 62 36 1372 25 21 24]\n",
" [ 36 1 5 4 38 14 139 2 18]\n",
" [ 266 33 8 66 27 30 1 450 41]\n",
" [ 83 5 10 10 68 1 6 12 151]]\n",
"Validation F1: 0.6429197452938862, train loss: 0.057959155670621175\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "19b33a42ee594495bc961ca822eeff6a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e4e5c8e95d8c44b7acac7b1ca343ab8b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41519 298 217 66 69 183 18 135 254]\n",
" [ 255 1427 105 1 5 23 0 7 19]\n",
" [ 183 120 979 1 3 6 1 0 14]\n",
" [ 234 79 7 853 32 60 1 52 23]\n",
" [ 147 7 24 49 424 16 13 9 62]\n",
" [ 229 29 10 35 1 1476 7 33 17]\n",
" [ 41 1 6 1 18 24 153 0 13]\n",
" [ 265 15 1 16 2 17 1 568 37]\n",
" [ 76 5 8 9 13 8 5 32 190]]\n",
"Validation F1: 0.7031936004879823, train loss: 0.04639621409164234\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d82c7d40de8b49ac9ebb039c9a7726cb",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5f69901f1544a2faed4986c4370d85e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41100 261 184 172 518 160 125 102 137]\n",
" [ 181 1429 119 22 27 34 5 4 21]\n",
" [ 168 75 997 7 27 11 5 0 17]\n",
" [ 84 29 12 985 160 43 4 19 5]\n",
" [ 45 3 12 9 633 2 23 3 21]\n",
" [ 146 14 2 57 25 1539 30 14 10]\n",
" [ 19 0 3 2 27 8 192 0 6]\n",
" [ 159 22 2 54 28 35 11 574 37]\n",
" [ 52 2 10 7 48 2 14 9 202]]\n",
"Validation F1: 0.7135904551805614, train loss: 0.03937693921510469\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "aac4212c66d349d2b25835f22a960d48",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fea1c973632640c19a76570b2650dd98",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41380 226 245 125 370 126 34 118 135]\n",
" [ 188 1430 132 13 37 13 4 9 16]\n",
" [ 163 58 1035 1 34 2 2 0 12]\n",
" [ 103 60 10 978 129 28 3 22 8]\n",
" [ 45 4 31 24 621 2 9 2 13]\n",
" [ 162 29 5 94 23 1445 17 47 15]\n",
" [ 18 1 10 2 35 14 160 0 17]\n",
" [ 207 22 5 43 19 13 0 580 33]\n",
" [ 55 4 15 8 48 4 5 12 195]]\n",
"Validation F1: 0.7209686415011278, train loss: 0.03529796007258648\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f167021136c94d229be4d81596dd9042",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e0e3a72de210442e91e5e964982be7cd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41846 201 116 128 205 100 19 87 57]\n",
" [ 165 1507 98 21 16 22 0 6 7]\n",
" [ 154 66 1045 5 25 3 3 1 5]\n",
" [ 92 22 10 1051 95 48 1 21 1]\n",
" [ 72 3 3 21 619 3 17 3 10]\n",
" [ 139 12 1 75 11 1557 15 20 7]\n",
" [ 31 0 3 2 21 14 175 2 9]\n",
" [ 176 16 0 48 11 30 2 610 29]\n",
" [ 70 2 6 5 29 3 8 13 210]]\n",
"Validation F1: 0.7763060675321695, train loss: 0.030404941890050063\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "710033482286413e8e7e6ea044048564",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2a06a7b53f1489dba3f404895e980b0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41926 220 107 83 116 119 8 128 52]\n",
" [ 169 1527 93 7 4 26 1 9 6]\n",
" [ 201 69 1021 1 5 5 1 0 4]\n",
" [ 97 45 3 998 73 66 1 51 7]\n",
" [ 80 3 17 22 546 6 14 5 58]\n",
" [ 154 15 1 41 4 1566 11 39 6]\n",
" [ 34 0 4 0 14 14 174 2 15]\n",
" [ 156 15 0 30 3 24 1 673 20]\n",
" [ 77 1 7 5 18 3 5 22 208]]\n",
"Validation F1: 0.7765224797107156, train loss: 0.024601146790452978\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "adea3499b9284a8687334a8241b9b860",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "979fb8ccfe4f42bfaba154726b45e3b2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41690 257 165 140 109 143 55 114 86]\n",
" [ 152 1519 104 21 2 24 7 6 7]\n",
" [ 144 80 1061 3 4 4 7 0 4]\n",
" [ 104 34 9 1072 61 39 1 14 7]\n",
" [ 90 6 12 26 578 5 15 3 16]\n",
" [ 126 14 6 72 5 1560 19 21 14]\n",
" [ 21 0 3 1 22 11 186 0 13]\n",
" [ 152 19 3 51 9 27 1 635 25]\n",
" [ 67 1 8 4 32 3 14 16 201]]\n",
"Validation F1: 0.7671100558971141, train loss: 0.021562039246782662\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e7506367fde44beb90d76dea3299c2d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f1d39f20a6fe4269b0b6f4363fb96144",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41929 121 84 86 314 85 32 76 32]\n",
" [ 149 1549 57 17 31 25 4 5 5]\n",
" [ 159 40 1061 6 34 1 4 0 2]\n",
" [ 66 25 4 1075 96 52 3 18 2]\n",
" [ 41 1 7 21 635 3 24 1 18]\n",
" [ 81 14 2 64 19 1622 16 17 2]\n",
" [ 20 1 4 3 37 8 180 0 4]\n",
" [ 132 16 2 47 18 22 1 665 19]\n",
" [ 62 3 10 2 23 3 16 12 215]]\n",
"Validation F1: 0.7923309583490342, train loss: 0.017838069995526562\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6df76dbe20244a48a208e804a9005d19",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf3801f400224652865fe7cdf11c8807",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41880 187 148 114 141 112 9 102 66]\n",
" [ 175 1561 61 11 5 16 1 8 4]\n",
" [ 113 92 1079 4 6 4 4 0 5]\n",
" [ 91 31 5 1093 66 25 2 20 8]\n",
" [ 59 3 7 23 618 2 11 4 24]\n",
" [ 111 23 3 84 5 1562 8 30 11]\n",
" [ 15 0 5 5 32 9 180 1 10]\n",
" [ 131 17 2 47 11 16 2 672 24]\n",
" [ 51 1 9 10 25 1 3 17 229]]\n",
"Validation F1: 0.7990632659575675, train loss: 0.01514953335757706\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf654d8b73dd418c987e7854857f56fd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c447d327234449a49315a3ccfbd7a860",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41927 225 153 97 89 94 21 88 65]\n",
" [ 157 1546 86 17 3 23 1 5 4]\n",
" [ 140 74 1072 7 6 4 0 1 3]\n",
" [ 86 43 4 1091 37 44 1 29 6]\n",
" [ 72 1 6 20 578 5 23 2 44]\n",
" [ 109 18 3 60 4 1606 7 25 5]\n",
" [ 18 0 3 4 21 7 192 1 11]\n",
" [ 146 17 2 39 3 18 1 676 20]\n",
" [ 52 2 4 7 19 2 4 26 230]]\n",
"Validation F1: 0.8022651715326824, train loss: 0.015754620486404747\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "51e5c42780204659ab0eaf9a5d6a47be",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "57f3e11cb93e46e3bc741592040add79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41978 195 158 90 87 129 13 67 42]\n",
" [ 149 1556 72 23 5 21 1 12 3]\n",
" [ 123 73 1088 4 12 0 2 1 4]\n",
" [ 107 30 5 1098 35 44 1 15 6]\n",
" [ 70 2 6 27 604 4 17 2 19]\n",
" [ 110 17 4 68 4 1599 11 18 6]\n",
" [ 22 0 1 1 22 14 185 1 11]\n",
" [ 157 24 3 46 5 23 3 643 18]\n",
" [ 58 2 8 12 27 6 9 23 201]]\n",
"Validation F1: 0.8001482986272885, train loss: 0.013831334230913357\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ed37a92f4c1747dc80ae059c944eca99",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e6011160d4f5481a9c9c3f5f61891331",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[42161 163 113 72 47 80 12 71 40]\n",
" [ 164 1550 81 8 2 29 0 6 2]\n",
" [ 139 74 1078 2 4 3 4 0 3]\n",
" [ 103 29 1 1084 35 59 1 26 3]\n",
" [ 70 2 7 26 590 6 25 3 22]\n",
" [ 108 14 3 45 1 1636 9 20 1]\n",
" [ 23 0 2 2 9 15 197 1 8]\n",
" [ 162 15 0 27 2 21 1 682 12]\n",
" [ 56 2 8 3 13 3 7 22 232]]\n",
"Validation F1: 0.8245979034012647, train loss: 0.011220586808948692\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ac67918f890e484aa71fa55efcbf7f98",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "926c67f91166414492168ee5debca915",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41968 188 124 98 86 96 27 106 66]\n",
" [ 167 1528 84 18 8 29 3 3 2]\n",
" [ 173 64 1048 0 4 3 8 5 2]\n",
" [ 81 34 6 1051 66 76 1 24 2]\n",
" [ 55 3 14 24 599 4 31 3 18]\n",
" [ 89 7 7 58 4 1632 11 26 3]\n",
" [ 18 1 2 2 19 6 199 1 9]\n",
" [ 139 10 2 32 5 16 3 698 17]\n",
" [ 59 1 5 3 17 2 8 20 231]]\n",
"Validation F1: 0.8035197642799284, train loss: 0.013202182106165724\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "824d22a71c55450aacfbcc060dade879",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cf71805e1100468e98c41ceaddced65e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[42038 168 125 68 86 94 43 84 53]\n",
" [ 163 1554 77 18 3 18 2 4 3]\n",
" [ 149 67 1082 1 2 3 1 0 2]\n",
" [ 84 39 8 1064 51 63 4 26 2]\n",
" [ 55 3 9 22 606 8 36 0 12]\n",
" [ 91 15 3 43 2 1643 19 18 3]\n",
" [ 11 0 6 0 16 9 209 0 6]\n",
" [ 140 20 0 30 4 13 1 693 21]\n",
" [ 49 4 8 4 25 2 5 12 237]]\n",
"Validation F1: 0.8160772657119046, train loss: 0.009503976078386503\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "337938da8c154bae9d300c26c4ae7217",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "199d539df6d24ced989c2a5b7092cdf6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[42097 164 114 80 66 105 6 91 36]\n",
" [ 148 1556 67 23 2 27 0 13 6]\n",
" [ 152 63 1063 1 10 8 1 0 9]\n",
" [ 87 28 4 1084 44 60 2 29 3]\n",
" [ 68 3 5 26 582 6 27 4 30]\n",
" [ 83 18 2 40 3 1657 10 23 1]\n",
" [ 21 0 2 0 15 6 204 0 9]\n",
" [ 132 16 0 24 2 22 3 709 14]\n",
" [ 61 0 5 4 11 2 8 13 242]]\n",
"Validation F1: 0.8246442255568004, train loss: 0.008092457134096714\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ebbd0b56207d46a7973c4b617b3f929f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "40834b84df014611a0cbd4a2dc0139ab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[42137 129 89 100 91 70 12 73 58]\n",
" [ 139 1578 62 17 9 24 0 10 3]\n",
" [ 127 51 1103 5 12 4 0 1 4]\n",
" [ 78 26 3 1105 53 41 0 31 4]\n",
" [ 53 1 8 20 627 4 16 3 19]\n",
" [ 80 19 2 51 2 1653 9 17 4]\n",
" [ 18 0 2 0 23 11 195 0 8]\n",
" [ 120 15 0 33 6 13 0 711 24]\n",
" [ 50 2 8 6 21 3 5 17 234]]\n",
"Validation F1: 0.8295417224894234, train loss: 0.013141678918724541\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "553bb14b1ee34f87818e208540d2e28f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/220 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "077f7816085b4e92bcd1a50f876059c5",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3250 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[41750 220 126 184 113 160 28 100 78]\n",
" [ 156 1528 93 27 11 16 3 8 0]\n",
" [ 185 44 1052 2 15 4 2 1 2]\n",
" [ 113 31 3 1047 86 39 4 14 4]\n",
" [ 72 1 8 27 601 5 19 1 17]\n",
" [ 150 26 3 58 2 1567 16 13 2]\n",
" [ 23 0 9 1 20 7 196 0 1]\n",
" [ 129 16 0 42 8 25 2 670 30]\n",
" [ 52 2 8 4 28 4 6 13 229]]\n",
"Validation F1: 0.7906017625430706, train loss: 0.0177222935853272\n"
]
},
{
"data": {
"text/plain": [
""
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create the data loaders; the validation loader uses batch_size=1 for evaluation\n",
"train_dl = DataLoader(datasets['train'], batch_size=batch_size, shuffle=True, collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"valid_dl = DataLoader(datasets['validation'], batch_size=1, collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"\n",
"# Create the optimizer\n",
"optimizer = Adam(model.parameters(), lr=lr)\n",
"\n",
"# Train, then reload the best checkpoint saved during training\n",
"losses = train(model, train_dl, valid_dl, optimizer, n_epochs, device)\n",
"model.load_state_dict(torch.load('best_model'))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R80wA-upBHDD"
},
"source": [
"**Question: Do you have any ideas for improving the model?**\n",
"How about adding an attention mechanism that lets the decoder attend to the encoder's hidden states at the individual token steps? (see the resources)\n"
]
},
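{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch of the suggested improvement: additive (Bahdanau-style) attention\n",
  "# that lets the decoder attend over the encoder's per-token hidden states.\n",
  "# Not wired into the model above; dimensions and names are illustrative.\n",
  "import torch\n",
  "import torch.nn as nn\n",
  "\n",
  "class AdditiveAttention(nn.Module):\n",
  "    def __init__(self, enc_dim, dec_dim, attn_dim):\n",
  "        super().__init__()\n",
  "        self.W_enc = nn.Linear(enc_dim, attn_dim)\n",
  "        self.W_dec = nn.Linear(dec_dim, attn_dim)\n",
  "        self.v = nn.Linear(attn_dim, 1)\n",
  "\n",
  "    def forward(self, dec_state, enc_states, mask=None):\n",
  "        # dec_state: (batch, dec_dim); enc_states: (batch, seq_len, enc_dim)\n",
  "        scores = self.v(torch.tanh(self.W_enc(enc_states) + self.W_dec(dec_state).unsqueeze(1))).squeeze(-1)\n",
  "        if mask is not None:  # mask out padding positions\n",
  "            scores = scores.masked_fill(~mask, float('-inf'))\n",
  "        weights = torch.softmax(scores, dim=-1)  # (batch, seq_len)\n",
  "        context = torch.bmm(weights.unsqueeze(1), enc_states).squeeze(1)\n",
  "        return context, weights"
 ]
},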
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 311
},
"id": "o1yYPFe5wKC5",
"outputId": "95db2ec2-2fdf-444b-9dad-42c2d7615094"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" cpuset_checked))\n",
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:17: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
"Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a8d16dfe21b74ea98e7853171b8c66a1",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3453 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[37092 237 111 313 218 97 32 133 90]\n",
" [ 188 1254 51 73 19 20 0 9 3]\n",
" [ 177 55 868 3 38 4 4 3 4]\n",
" [ 133 58 4 1264 77 69 2 53 1]\n",
" [ 93 4 12 33 639 7 30 2 15]\n",
" [ 140 42 5 107 24 1304 11 32 3]\n",
" [ 26 0 10 4 33 9 165 1 9]\n",
" [ 94 23 6 57 10 16 1 479 16]\n",
" [ 45 3 8 1 23 1 2 11 122]]\n"
]
},
{
"data": {
"text/plain": [
"0.727843916681434"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_dl = DataLoader(datasets['test'], batch_size=1, collate_fn=collate_batch_bilstm, num_workers=n_workers)\n",
"evaluate(model, test_dl, beam_size=1)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 311,
"referenced_widgets": [
"a3e56bc011584c638aba5ebb200cc409",
"2c62f6825ac549668b5adb970201f9ee",
"311130cce54541aeb64ed9ee04338e85",
"e194b1b75f4c40b88e7ecf6b7ce308f4",
"6ffb7f64d4924c73b468a160a730d785",
"6d84ccb32486488db4eb66ca7401b3c2",
"6aa949d213804e3ba5bfdfb1b8c89794",
"02dd4f78a3e74dbe8c6a3012fe1225d5",
"7926b0ae95a643ad90c0341340cf2b43",
"eff6ea1ca2d1415683784d3ce95a9f07",
"a8d985e224b64b6baa282f4d00a00a74"
]
},
"id": "5jBpWDyExq1q",
"outputId": "068ef3dc-d1db-4fec-8135-726334a8dca4"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:17: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
"Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a3e56bc011584c638aba5ebb200cc409",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Evaluation: 0%| | 0/3453 [00:00, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n",
" cpuset_checked))\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[37128 232 121 311 209 93 32 128 69]\n",
" [ 185 1270 49 68 14 18 0 10 3]\n",
" [ 163 60 887 4 32 2 1 3 4]\n",
" [ 135 64 5 1275 66 70 2 43 1]\n",
" [ 86 6 15 30 645 8 30 2 13]\n",
" [ 126 41 7 105 26 1314 16 31 2]\n",
" [ 28 0 10 4 36 7 167 0 5]\n",
" [ 92 26 6 61 10 16 2 476 13]\n",
" [ 48 2 8 1 22 1 2 12 120]]\n"
]
},
{
"data": {
"text/plain": [
"0.7355817928230471"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"evaluate(model, test_dl, beam_size=2)"
]
},
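{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch of the idea behind `beam_size`: keep the `beam_size` highest-scoring\n",
  "# tag sequences at each step instead of only the argmax. Simplified: the real\n",
  "# decoder also conditions each step on the previously chosen tags.\n",
  "def beam_search(step_log_probs, beam_size=2):\n",
  "    beams = [((), 0.0)]  # (tag sequence, cumulative log-probability)\n",
  "    for log_probs in step_log_probs:\n",
  "        candidates = [\n",
  "            (seq + (tag,), score + log_probs[tag])\n",
  "            for seq, score in beams\n",
  "            for tag in range(len(log_probs))\n",
  "        ]\n",
  "        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]\n",
  "    return beams[0]"
 ]
},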
{
"cell_type": "markdown",
"metadata": {
"id": "mDcsAYEKFJB3"
},
"source": [
"# Transformers for sequence classification\n",
"\n",
"- the label alignment has to be adjusted when the tokenizer splits a word into multiple word pieces\n",
"- [Tutorial on NER](https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb)\n",
"- Some generative transformers now match token-classification transformers on these tasks [e.g. T5 can extract the span of a tweet that contains a sentiment](https://github.com/enzoampil/t5-intro/blob/master/t5_qa_training_pytorch_span_extraction.ipynb)"
]
}
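,
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Sketch of the word-piece alignment from the linked NER tutorial: repeat each\n",
  "# word's label for every piece and ignore special tokens in the loss (-100).\n",
  "# Assumes the `transformers` library; the checkpoint name is illustrative.\n",
  "from transformers import AutoTokenizer\n",
  "\n",
  "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')\n",
  "words = ['John', 'visited', 'Copenhagen']\n",
  "labels = ['B-PER', 'O', 'B-LOC']\n",
  "\n",
  "encoding = tokenizer(words, is_split_into_words=True)\n",
  "aligned = [labels[i] if i is not None else -100 for i in encoding.word_ids()]\n",
  "print(list(zip(encoding.tokens(), aligned)))"
 ]
}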
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "lab_5",
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}