{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "bOChJSNXtC9g" }, "source": [ "# Recurrent Neural Networks" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "OLIxEDq6VhvZ" }, "source": [ "\n", "\n", "When working with sequential data (time series, sentences, etc.), the order of the inputs is crucial for the task at hand. Recurrent neural networks (RNNs) process sequential data by accounting for both the current input and what has been learned from previous inputs. In this notebook, we'll learn how to create and train RNNs on sequential data.\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VoMq0eFRvugb" }, "source": [ "# Overview" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qWro5T5qTJJL" }, "source": [ "* **Objective:** Process sequential data by accounting for both the current input and what has been learned from previous inputs.\n", "* **Advantages:**\n", "    * Account for order and previous inputs in a meaningful way.\n", "    * Generate sequences conditioned on an input (conditioned generation).\n", "* **Disadvantages:**\n", "    * Each time step depends on the hidden state from the previous time step, so it's difficult to parallelize RNN operations across time.\n", "    * Processing long sequences can cause memory and computation issues.\n", "    * Interpretability is difficult, but there are a few [techniques](https://arxiv.org/abs/1506.02078) that use RNN activations to see which parts of the inputs are being processed.\n", "* **Miscellaneous:**\n", "    * Architectural tweaks to make RNNs faster and more interpretable are an ongoing area of research." 
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rsHeBbehrKzl" }, "source": [ "\n", "\n", "RNN forward pass for a single time step $X_t$:\n", "\n", "$h_t = \\tanh(h_{t-1}W_{hh} + X_tW_{xh} + b_h)$\n", "\n", "$y_t = h_tW_{hy} + b_y$\n", "\n", "$P(y_t) = \\mathrm{softmax}(y_t)$, applied to each sample's row of $y_t$: $\\mathrm{softmax}(y_t)_c = \\frac{e^{y_{t,c}}}{\\sum_{c'=1}^{C} e^{y_{t,c'}}}$\n", "\n", "*where*:\n", "* $X_t$ = input at time step $t$ | $\\in \\mathbb{R}^{N \\times E}$ ($N$ is the batch size, $E$ is the embedding dim)\n", "* $W_{hh}$ = hidden-to-hidden weights | $\\in \\mathbb{R}^{H \\times H}$ ($H$ is the hidden dim)\n", "* $h_{t-1}$ = previous time step's hidden state | $\\in \\mathbb{R}^{N \\times H}$\n", "* $W_{xh}$ = input-to-hidden weights | $\\in \\mathbb{R}^{E \\times H}$\n", "* $b_h$ = hidden bias | $\\in \\mathbb{R}^{H}$ (broadcast across the batch)\n", "* $W_{hy}$ = hidden-to-output weights | $\\in \\mathbb{R}^{H \\times C}$ ($C$ is the number of classes)\n", "* $b_y$ = output bias | $\\in \\mathbb{R}^{C}$\n", "\n", "You repeat this for every time step's input ($X_1, X_2, \\ldots, X_T$, where $T$ is the sequence length) to get the predicted output at each time step.\n", "\n", "**Note**: At the first time step there is no previous hidden state, so the initial hidden state $h_0$ is either a zero vector (unconditioned) or initialized from some conditioning information (conditioned). If we are conditioning the RNN, $h_0$ can encode a specific condition, or we can concatenate the specific condition to the hidden vectors at each time step. More on this in the subsequent notebooks on RNNs." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dIXlGMExJD6w" }, "source": [ "Let's see what the forward pass looks like with an RNN on a synthetic task: processing reviews (sequences of words) to predict the sentiment after the whole review has been read." 
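, "\n", "\n", "As a sketch of the equations above (not part of the original notebook), here is the same forward pass written directly in NumPy with the row-vector convention, so each row of $X_t$ is one sample. The sizes mirror the PyTorch cells below; the random weights are arbitrary placeholders, and the softmax subtracts each row's maximum for numerical stability, which does not change its result.\n", "\n", "```python
import numpy as np

# Sketch of the vanilla RNN equations with the row-vector convention:
# X_t is (N, E), W_xh is (E, H), W_hh is (H, H), W_hy is (H, C)
rng = np.random.default_rng(0)
N, T, E, H, C = 5, 10, 100, 256, 4  # batch, time steps, embedding, hidden, classes

X = rng.standard_normal((N, T, E))
W_xh = rng.standard_normal((E, H)) * 0.01
W_hh = rng.standard_normal((H, H)) * 0.01
b_h = np.zeros(H)
W_hy = rng.standard_normal((H, C)) * 0.01
b_y = np.zeros(C)


def softmax(y):
    # Subtracting the row max avoids overflow; each row still sums to 1
    e = np.exp(y - y.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)


h = np.zeros((N, H))  # unconditioned h_0
hiddens, probs = [], []
for t in range(T):
    h = np.tanh(h @ W_hh + X[:, t] @ W_xh + b_h)  # h_t
    y = h @ W_hy + b_y  # y_t
    hiddens.append(h)
    probs.append(softmax(y))  # P(y_t)

hiddens = np.stack(hiddens, axis=1)  # (N, T, H)
probs = np.stack(probs, axis=1)  # (N, T, C)
print(hiddens.shape, probs.shape)  # (5, 10, 256) (5, 10, 4)
```
", "\n", "The explicit loop over `t` is also why RNNs are hard to parallelize across time: each $h_t$ needs $h_{t-1}$ first."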
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "RcWE5cw0_cKA", "outputId": "a44156b9-b43f-409c-f0ce-4a4bd871d6a0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.0.0)\n" ] } ], "source": [ "# Load PyTorch library\n", "!pip3 install torch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "o6eEK1wM_dXG" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "Qi9hIEV6COLF" }, "outputs": [], "source": [ "batch_size = 5\n", "seq_size = 10 # max length per input (masking will be used for sequences that aren't this max length)\n", "x_lengths = [8, 5, 4, 10, 5] # lengths of each input sequence\n", "embedding_dim = 100\n", "rnn_hidden_dim = 256\n", "output_dim = 4" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "bLEzfxjhB94C", "outputId": "f2feefbf-8635-4b23-ef53-b5713cf2cdb2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 10, 100])\n" ] } ], "source": [ "# Initialize synthetic inputs\n", "x_in = torch.randn(batch_size, seq_size, embedding_dim)\n", "x_lengths = torch.tensor(x_lengths)\n", "print (x_in.size())" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "dr6oLqtXB98N", "outputId": "9817e88d-6e73-414a-dfa6-2386f40db0d9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 256])\n" ] } ], "source": [ "# Initialize hidden state\n", "hidden_t = 
torch.zeros((batch_size, rnn_hidden_dim))\n", "print (hidden_t.size())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "ryZMOLLgB9-v", "outputId": "14ec0a2a-bf37-4e03-b69b-099180f8f149" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RNNCell(100, 256)\n" ] } ], "source": [ "# Initialize RNN cell\n", "rnn_cell = nn.RNNCell(embedding_dim, rnn_hidden_dim)\n", "print (rnn_cell)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "rlbZ7ujxExXb", "outputId": "6c83ba2b-94c5-4f76-c8fb-ef0c1ccdeb37" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 10, 256])\n" ] } ], "source": [ "# Forward pass through RNN\n", "x_in = x_in.permute(1, 0, 2) # time-major (seq_size, batch_size, embedding_dim), so x_in[t] is one time step\n", "\n", "# Loop through the input's time steps\n", "hiddens = []\n", "for t in range(seq_size):\n", "    hidden_t = rnn_cell(x_in[t], hidden_t)\n", "    hiddens.append(hidden_t)\n", "hiddens = torch.stack(hiddens)\n", "hiddens = hiddens.permute(1, 0, 2) # bring batch_size back to dim 0\n", "print (hiddens.size())" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "3TTL-jmg-MHa", "outputId": "3fae323f-c37d-4dac-c8a8-7fea7a45c95c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out: torch.Size([5, 10, 256])\n", "h_n: torch.Size([1, 5, 256])\n" ] } ], "source": [ "# We could also have used the higher-level nn.RNN layer\n", "x_in = torch.randn(batch_size, seq_size, embedding_dim)\n", "rnn = nn.RNN(embedding_dim, rnn_hidden_dim, batch_first=True)\n", "out, h_n = rnn(x_in) # h_n is the final hidden state: (num_layers, batch_size, rnn_hidden_dim)\n", "print (\"out: \", out.size())\n", "print (\"h_n: \", h_n.size())" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "iAsyRNnbHwcT" }, "outputs": [], "source": [ "def gather_last_relevant_hidden(hiddens, x_lengths):\n", "    # Index of the last real (non-padded) time step of each sequence\n", "    x_lengths = x_lengths.long().detach().cpu().numpy() - 1\n", "    out = []\n", "    for batch_index, column_index in enumerate(x_lengths):\n", "        out.append(hiddens[batch_index, column_index])\n", "    return torch.stack(out)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "PVhp1KLqHqpA", "outputId": "d04be3ef-c2d6-48b9-f0f5-a93f619ec594" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 256])\n" ] } ], "source": [ "# Gather the last relevant hidden state\n", "z = gather_last_relevant_hidden(hiddens, x_lengths)\n", "print (z.size())" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "id": "yGk_iZ5cITZl", "outputId": "84749ff2-1e45-4599-a38d-8c83cee116a9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 4])\n", "tensor([[0.3030, 0.2351, 0.2168, 0.2452],\n", "        [0.2614, 0.1912, 0.2617, 0.2858],\n", "        [0.2428, 0.2600, 0.2254, 0.2717],\n", "        [0.2379, 0.2226, 0.1901, 0.3494],\n", "        [0.2629, 0.2854, 0.2146, 0.2371]], grad_fn=<SoftmaxBackward>)\n" ] } ], "source": [ "# Forward pass through FC layer\n", "fc1 = nn.Linear(rnn_hidden_dim, output_dim)\n", "y_pred = fc1(z)\n", "y_pred = F.softmax(y_pred, dim=1)\n", "print (y_pred.size())\n", "print (y_pred)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "hPBQpki_n6yY" }, "source": [ "# Sequential data" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "kP1awuluoCSr" }, "source": [ "There is a variety of sequential tasks that RNNs can help with.\n", "\n", "1. **One to one**: one input produces one output. 
\n", "    * Ex. Given a word, predict its class (verb, noun, etc.).\n", "2. **One to many**: one input generates many outputs.\n", "    * Ex. Given a sentiment (positive, negative, etc.), generate a review.\n", "3. **Many to one**: many inputs are sequentially processed to generate one output.\n", "    * Ex. Process the words in a review to predict the sentiment.\n", "4. **Many to many**: many inputs are sequentially processed to generate many outputs.\n", "    * Ex. Given a sentence in French, process the entire sentence and then generate the English translation.\n", "    * Ex. Given a sequence of time-series data, predict the probability of an event (risk of disease) at each time step.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "tnxUIEMdukYY" }, "source": [ "# Issues with vanilla RNNs" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "uMx2s93VLUTt" }, "source": [ "There are several issues with the vanilla RNN that we've seen so far.\n", "\n", "1. When an input sequence has many time steps, it becomes difficult for the model to retain information seen early on as it processes more and more of the later time steps. The goal of the model is to retain the useful information from previously seen time steps, but this becomes cumbersome when there are so many time steps to process.\n", "\n", "2. During backpropagation, the gradient of the loss has to travel all the way back to the first time step, and it is multiplied by a per-step factor at every step along the way. If that factor is even slightly larger than 1 (${1.01}^{1000} \\approx 20959$) or slightly smaller than 1 (${0.99}^{1000} \\approx 4.32 \\times 10^{-5}$) and we have lots of time steps, the gradient can quickly explode or vanish.\n", "\n", "To address both of these issues, the concept of gating was introduced to RNNs. Gating allows an RNN to learn how much information flows between time steps for the task at hand. Selectively allowing information to pass through allows the model to process inputs with many time steps. 
The most common gated RNN variants are long short-term memory units ([LSTMs](https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM)) and gated recurrent units ([GRUs](https://pytorch.org/docs/stable/nn.html#torch.nn.GRU)). You can read more about how these units work [here](http://colah.github.io/posts/2015-08-Understanding-LSTMs/).\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "tirko0kwp-9J" }, "outputs": [], "source": [ "# GRU in PyTorch\n", "gru = nn.GRU(input_size=embedding_dim, hidden_size=rnn_hidden_dim,\n", "             batch_first=True)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "UZjUhh4VBWxM", "outputId": "9fe275fe-c8d9-42f0-e5d0-0295268ed83d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([5, 10, 100])\n" ] } ], "source": [ "# Initialize synthetic input\n", "x_in = torch.randn(batch_size, seq_size, embedding_dim)\n", "print (x_in.size())" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "xJ_SE7AvBfa4", "outputId": "b9411aaa-fab1-4104-aee7-8f9a423332ab" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out: torch.Size([5, 10, 256])\n", "h_n: torch.Size([1, 5, 256])\n" ] } ], "source": [ "# Forward pass\n", "out, h_n = gru(x_in)\n", "print (\"out:\", out.size())\n", "print (\"h_n:\", h_n.size())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ij_GA2Rr9BbA" }, "source": [ "**Note**: Choosing between a GRU and an LSTM really depends on the data and on empirical performance. GRUs often offer comparable performance with fewer parameters, while LSTMs have more parameters (a separate cell state and an extra gate), which may make the difference in performance for your particular task." 
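, "\n", "\n", "To make the parameter difference concrete, the sketch below counts the weights of a single-layer, unidirectional recurrent layer with the sizes used above. It assumes PyTorch's parameter layout, in which each gate has its own input-to-hidden and hidden-to-hidden weights plus two bias vectors (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0`); the helper name is hypothetical.\n", "\n", "```python
# Hypothetical helper: parameter count of a single-layer, unidirectional
# recurrent layer, assuming PyTorch's layout (two weight matrices and two
# bias vectors per gate)
def recurrent_param_count(input_size, hidden_size, num_gates):
    per_gate = (hidden_size * input_size       # input-to-hidden weights
                + hidden_size * hidden_size    # hidden-to-hidden weights
                + 2 * hidden_size)             # input and hidden biases
    return num_gates * per_gate


E, H = 100, 256  # embedding_dim and rnn_hidden_dim from the cells above
counts = {
    'RNN': recurrent_param_count(E, H, num_gates=1),
    'GRU': recurrent_param_count(E, H, num_gates=3),   # reset, update, new
    'LSTM': recurrent_param_count(E, H, num_gates=4),  # input, forget, cell, output
}
for name, count in counts.items():
    print(name, count)
# RNN 91648
# GRU 274944
# LSTM 366592
```
", "\n", "You can cross-check these numbers with `sum(p.numel() for p in gru.parameters())` on the `gru` layer defined above."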
] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9agJw4gwK1LC" }, "source": [ "# Bidirectional RNNs" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Xck0n-KpmXkV" }, "source": [ "There have been many advancements with RNNs ([attention](https://www.oreilly.com/ideas/interpretability-via-attentional-and-memory-based-interfaces-using-tensorflow), Quasi-RNNs, etc.) that we will cover in later lessons, but one of the most basic and widely used is the bidirectional RNN (Bi-RNN). The motivation behind bidirectional RNNs is to process an input sequence in both directions. Accounting for context from both sides can improve performance when the entire input sequence is known at inference time. A common application of Bi-RNNs is in translation, where it's advantageous to read the entire source sentence in both directions when translating it to another language (e.g., Japanese → English).\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "gSk_5XrvApCd" }, "outputs": [], "source": [ "# BiGRU in PyTorch\n", "bi_gru = nn.GRU(input_size=embedding_dim, hidden_size=rnn_hidden_dim,\n", "                batch_first=True, bidirectional=True)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "Fx7-GTptBCtZ", "outputId": "f0242cc5-534a-460b-ebe0-4e8c504fab22" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "out: torch.Size([5, 10, 512])\n", "h_n: torch.Size([2, 5, 256])\n" ] } ], "source": [ "# Forward pass\n", "out, h_n = bi_gru(x_in)\n", "print (\"out:\", out.size()) # hidden states at every time step (forward and backward concatenated)\n", "print (\"h_n:\", h_n.size()) # final hidden state per direction: (num_layers * num_directions, batch_size, rnn_hidden_dim)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "k5lvJirLBjI6" }, "source": [ "Notice that the output for each sample at each time step has 
size 512 (double the hidden dim). This is because it concatenates the hidden states from both the forward and backward directions of the Bi-RNN." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mJSknbofK2S9" }, "source": [ "# Document classification with RNNs" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "JgYdEZmHlmft" }, "source": [ "Let's apply RNNs to the document classification task from the [embeddings notebook](https://colab.research.google.com/drive/1yDa5ZTqKVoLl-qRgH-N9xs3pdrDJ0Fb4), where we want to predict an article's category given its title." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "eIvXqvPQEiDC" }, "source": [ "## Set up" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "muTcvMynlmAu" }, "outputs": [], "source": [ "import os\n", "from argparse import Namespace\n", "import collections\n", "import copy\n", "import json\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import re\n", "import torch" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "00ESjecep-_y" }, "outputs": [], "source": [ "# Set NumPy and PyTorch seeds\n", "def set_seeds(seed, cuda):\n", "    np.random.seed(seed)\n", "    torch.manual_seed(seed)\n", "    if cuda:\n", "        torch.cuda.manual_seed_all(seed)\n", "\n", "# Create directories\n", "def create_dirs(dirpath):\n", "    if not os.path.exists(dirpath):\n", "        os.makedirs(dirpath)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 34 }, "colab_type": "code", "id": "m67THDvxEl1e", "outputId": "7118c77b-cbf9-4d7e-ff7a-b9dc1fb63cbb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using CUDA: True\n" ] } ], "source": [ "# Arguments\n", "args = Namespace(\n", "    seed=1234,\n", "    cuda=True,\n", "    shuffle=True,\n", 
"    data_file=\"news.csv\",\n", "    split_data_file=\"split_news.csv\",\n", "    vectorizer_file=\"vectorizer.json\",\n", "    model_state_file=\"model.pth\",\n", "    save_dir=\"news\",\n", "    train_size=0.7,\n", "    val_size=0.15,\n", "    test_size=0.15,\n", "    pretrained_embeddings=None,\n", "    cutoff=25, # token must appear at least cutoff (25) times to be in SequenceVocabulary\n", "    num_epochs=5,\n", "    early_stopping_criteria=5,\n", "    learning_rate=1e-3,\n", "    batch_size=64,\n", "    embedding_dim=100,\n", "    rnn_hidden_dim=128,\n", "    hidden_dim=100,\n", "    num_layers=1,\n", "    bidirectional=False,\n", "    dropout_p=0.1,\n", ")\n", "\n", "# Set seeds\n", "set_seeds(seed=args.seed, cuda=args.cuda)\n", "\n", "# Create save dir\n", "create_dirs(args.save_dir)\n", "\n", "# Expand filepaths\n", "args.vectorizer_file = os.path.join(args.save_dir, args.vectorizer_file)\n", "args.model_state_file = os.path.join(args.save_dir, args.model_state_file)\n", "\n", "# Check CUDA\n", "if not torch.cuda.is_available():\n", "    args.cuda = False\n", "args.device = torch.device(\"cuda\" if args.cuda else \"cpu\")\n", "print(\"Using CUDA: {}\".format(args.cuda))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "s7T-_kGvExVW" }, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "XVyK25xOEwjN" }, "outputs": [], "source": [ "import re\n", "import urllib.request" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "M_gclwECEwll" }, "outputs": [], "source": [ "# Download data from GitHub to the notebook's local drive\n", "url = \"https://raw.githubusercontent.com/LisonEvf/practicalAI-cn/master/data/news.csv\"\n", "response = urllib.request.urlopen(url)\n", "html = response.read()\n", "with open(args.data_file, 'wb') as fp:\n", "    fp.write(html)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, 
"colab_type": "code", "id": "V244zOIPEwoP", "outputId": "ab8b5cab-4e25-436e-9cb3-0db6f524eb9a" }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>category</th>\n", "      <th>title</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>Business</td>\n", "      <td>Wall St. Bears Claw Back Into the Black (Reuters)</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>Business</td>\n", "      <td>Carlyle Looks Toward Commercial Aerospace (Reu...</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>Business</td>\n", "      <td>Oil and Economy Cloud Stocks' Outlook (Reuters)</td>\n", "    </tr>\n", "    <tr>\n", "      <th>3</th>\n", "      <td>Business</td>\n", "      <td>Iraq Halts Oil Exports from Main Southern Pipe...</td>\n", "    </tr>\n", "    <tr>\n", "      <th>4</th>\n", "      <td>Business</td>\n", "      <td>Oil prices soar to all-time record, posing new...</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " category title\n", "0 Business Wall St. Bears Claw Back Into the Black (Reuters)\n", "1 Business Carlyle Looks Toward Commercial Aerospace (Reu...\n", "2 Business Oil and Economy Cloud Stocks' Outlook (Reuters)\n", "3 Business Iraq Halts Oil Exports from Main Southern Pipe...\n", "4 Business Oil prices soar to all-time record, posing new..." ] }, "execution_count": 25, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Raw data\n", "df = pd.read_csv(args.data_file, header=0)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "id": "ICl2MNK4EwrL", "outputId": "d2073597-71e5-40b1-a845-90bf4913ea7a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Business: 30000\n", "Sci/Tech: 30000\n", "Sports: 30000\n", "World: 30000\n" ] } ], "source": [ "# Split by category\n", "by_category = collections.defaultdict(list)\n", "for _, row in df.iterrows():\n", "    by_category[row.category].append(row.to_dict())\n", "for category in by_category:\n", "    print (\"{0}: {1}\".format(category, len(by_category[category])))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "76PwKQHLEww5" }, "outputs": [], "source": [ "# Create split data (each category is split separately)\n", "final_list = []\n", "for _, item_list in sorted(by_category.items()):\n", "    if args.shuffle:\n", "        np.random.shuffle(item_list)\n", "    n = len(item_list)\n", "    n_train = int(args.train_size*n)\n", "    n_val = int(args.val_size*n)\n", "    n_test = int(args.test_size*n)  # unused: test takes the remainder below\n", "\n", "    # Give each data point a split attribute\n", "    for item in item_list[:n_train]:\n", "        item['split'] = 'train'\n", "    for item in item_list[n_train:n_train+n_val]:\n", "        item['split'] = 'val'\n", "    for item in item_list[n_train+n_val:]:\n", "        item['split'] = 'test'\n", "\n", "    # Add to final list\n", "    final_list.extend(item_list)" ] }, { 
"cell_type": "code", "execution_count": 28, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "id": "CQeS0KHOEwzm", "outputId": "93c9aadb-25c4-4029-f002-8a43f3956045" }, "outputs": [ { "data": { "text/plain": [ "train 84000\n", "val 18000\n", "test 18000\n", "Name: split, dtype: int64" ] }, "execution_count": 28, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# df with split datasets\n", "split_df = pd.DataFrame(final_list)\n", "split_df[\"split\"].value_counts()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "pPJDyVusEw3-" }, "outputs": [], "source": [ "# Preprocessing\n", "def preprocess_text(text):\n", " text = ' '.join(word.lower() for word in text.split(\" \"))\n", " text = re.sub(r\"([.,!?])\", r\" \\1 \", text)\n", " text = re.sub(r\"[^a-zA-Z.,!?]+\", r\" \", text)\n", " text = text.strip()\n", " return text\n", " \n", "split_df.title = split_df.title.apply(preprocess_text)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 204 }, "colab_type": "code", "id": "IAetKendEw6b", "outputId": "d5946f7e-840e-4a0b-e492-d3da68cefd44" }, "outputs": [ { "data": { "text/html": [ "
<div>\n", "<table border=\"1\" class=\"dataframe\">\n", "  <thead>\n", "    <tr style=\"text-align: right;\">\n", "      <th></th>\n", "      <th>category</th>\n", "      <th>split</th>\n", "      <th>title</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n", "    <tr>\n", "      <th>0</th>\n", "      <td>Business</td>\n", "      <td>train</td>\n", "      <td>general electric posts higher rd quarter profit</td>\n", "    </tr>\n", "    <tr>\n", "      <th>1</th>\n", "      <td>Business</td>\n", "      <td>train</td>\n", "      <td>lilly to eliminate up to us jobs</td>\n", "    </tr>\n", "    <tr>\n", "      <th>2</th>\n", "      <td>Business</td>\n", "      <td>train</td>\n", "      <td>s amp p lowers america west outlook to negative</td>\n", "    </tr>\n", "    <tr>\n", "      <th>3</th>\n", "      <td>Business</td>\n", "      <td>train</td>\n", "      <td>does rand walk the talk on labor policy ?</td>\n", "    </tr>\n", "    <tr>\n", "      <th>4</th>\n", "      <td>Business</td>\n", "      <td>train</td>\n", "      <td>housekeeper advocates for changes</td>\n", "    </tr>\n", "  </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " category split title\n", "0 Business train general electric posts higher rd quarter profit\n", "1 Business train lilly to eliminate up to us jobs\n", "2 Business train s amp p lowers america west outlook to negative\n", "3 Business train does rand walk the talk on labor policy ?\n", "4 Business train housekeeper advocates for changes" ] }, "execution_count": 30, "metadata": { "tags": [] }, "output_type": "execute_result" } ], "source": [ "# Save to CSV\n", "split_df.to_csv(args.split_data_file, index=False)\n", "split_df.head()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NHzGXAI3E7lF" }, "source": [ "## Vocabulary" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ZIRUjX0MEw88" }, "outputs": [], "source": [ "class Vocabulary(object):\n", "    def __init__(self, token_to_idx=None):\n", "\n", "        # Token to index\n", "        if token_to_idx is None:\n", "            token_to_idx = {}\n", "        self.token_to_idx = token_to_idx\n", "\n", "        # Index to token\n", "        self.idx_to_token = {idx: token for token, idx in self.token_to_idx.items()}\n", "\n", "    def to_serializable(self):\n", "        return {'token_to_idx': self.token_to_idx}\n", "\n", "    @classmethod\n", "    def from_serializable(cls, contents):\n", "        return cls(**contents)\n", "\n", "    def add_token(self, token):\n", "        if token in self.token_to_idx:\n", "            index = self.token_to_idx[token]\n", "        else:\n", "            index = len(self.token_to_idx)\n", "            self.token_to_idx[token] = index\n", "            self.idx_to_token[index] = token\n", "        return index\n", "\n", "    def add_tokens(self, tokens):\n", "        return [self.add_token(token) for token in tokens]\n", "\n", "    def lookup_token(self, token):\n", "        return self.token_to_idx[token]\n", "\n", "    def lookup_index(self, index):\n", "        if index not in self.idx_to_token:\n", "            raise KeyError(\"the index (%d) is not in the Vocabulary\" % index)\n", "        return self.idx_to_token[index]\n", "\n", "    def __str__(self):\n", "        return \"<Vocabulary(size=%d)>\" % len(self)\n", "\n", "    def __len__(self):\n", "        return len(self.token_to_idx)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 85 }, "colab_type": "code", "id": "1LtYf3lpExBb", "outputId": "617297a7-3fdb-4789-bbca-dea82d06c8ce" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<Vocabulary(size=4)>\n", "4\n", "0\n", "Business\n" ] } ], "source": [ "# Vocabulary instance\n", "category_vocab = Vocabulary()\n", "for index, row in df.iterrows():\n", "    category_vocab.add_token(row.category)\n", "print (category_vocab) # __str__\n", "print (len(category_vocab)) # __len__\n", "index = category_vocab.lookup_token(\"Business\")\n", "print (index)\n", "print (category_vocab.lookup_index(index))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Z0zkF6CsE_yH" }, "source": [ "## Sequence vocabulary" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "QtntaISyE_1c" }, "source": [ "Next, we'll create a `SequenceVocabulary` for the article titles, which are sequences of tokens." 
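, "\n", "\n", "Before the full class, here is a minimal dict-based sketch (the function names are hypothetical, not part of the notebook) of what a sequence vocabulary adds on top of `Vocabulary`: reserved indices for the mask, unknown, begin, and end tokens, and a fallback to the unknown index for words that didn't make the cutoff. The special tokens must be distinct strings; if they were identical, they would all collapse onto a single index.\n", "\n", "```python
# Hypothetical, dict-only sketch of a sequence vocabulary
SPECIAL_TOKENS = ['<MASK>', '<UNK>', '<BEGIN>', '<END>']


def build_vocab(tokens, special_tokens=SPECIAL_TOKENS):
    # Special tokens are inserted first, so they get indices 0-3
    token_to_idx = {}
    for token in list(special_tokens) + list(tokens):
        if token not in token_to_idx:
            token_to_idx[token] = len(token_to_idx)
    return token_to_idx


def encode(title, token_to_idx):
    # Unknown words fall back to the <UNK> index instead of raising KeyError
    unk = token_to_idx['<UNK>']
    ids = [token_to_idx.get(token, unk) for token in title.split(' ')]
    return [token_to_idx['<BEGIN>']] + ids + [token_to_idx['<END>']]


vocab = build_vocab(['federer', 'wins', 'the', 'tennis'])
print(encode('roger federer wins the tennis', vocab))  # [2, 1, 4, 5, 6, 7, 3]
```
", "\n", "Because the special tokens are added first, the begin and end markers occupy indices 2 and 3, the same pattern that appears at the start and end of the vectorized title later in this notebook."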
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ovI8QRefEw_p" }, "outputs": [], "source": [ "from collections import Counter\n", "import string" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "4W3ZouuTEw1_" }, "outputs": [], "source": [ "class SequenceVocabulary(Vocabulary):\n", "    def __init__(self, token_to_idx=None, unk_token=\"<UNK>\",\n", "                 mask_token=\"<MASK>\", begin_seq_token=\"<BEGIN>\",\n", "                 end_seq_token=\"<END>\"):\n", "\n", "        super(SequenceVocabulary, self).__init__(token_to_idx)\n", "\n", "        self.mask_token = mask_token\n", "        self.unk_token = unk_token\n", "        self.begin_seq_token = begin_seq_token\n", "        self.end_seq_token = end_seq_token\n", "\n", "        # On a fresh vocabulary the special tokens are added first (indices 0-3)\n", "        self.mask_index = self.add_token(self.mask_token)\n", "        self.unk_index = self.add_token(self.unk_token)\n", "        self.begin_seq_index = self.add_token(self.begin_seq_token)\n", "        self.end_seq_index = self.add_token(self.end_seq_token)\n", "\n", "        # Index to token\n", "        self.idx_to_token = {idx: token for token, idx in self.token_to_idx.items()}\n", "\n", "    def to_serializable(self):\n", "        contents = super(SequenceVocabulary, self).to_serializable()\n", "        contents.update({'unk_token': self.unk_token,\n", "                         'mask_token': self.mask_token,\n", "                         'begin_seq_token': self.begin_seq_token,\n", "                         'end_seq_token': self.end_seq_token})\n", "        return contents\n", "\n", "    def lookup_token(self, token):\n", "        # Tokens not in the vocabulary map to the <UNK> index\n", "        return self.token_to_idx.get(token, self.unk_index)\n", "\n", "    def lookup_index(self, index):\n", "        if index not in self.idx_to_token:\n", "            raise KeyError(\"the index (%d) is not in the SequenceVocabulary\" % index)\n", "        return self.idx_to_token[index]\n", "\n", "    def __str__(self):\n", "        return \"<SequenceVocabulary(size=%d)>\" % len(self.token_to_idx)\n", "\n", "    def __len__(self):\n", "        return len(self.token_to_idx)\n" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "colab": { "base_uri": "https://localhost:8080/", 
"height": 85 }, "colab_type": "code", "id": "g5UHjpi3El37", "outputId": "cb20aa34-2bd5-4178-b219-d845fdc4968e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<SequenceVocabulary(size=4400)>\n", "4400\n", "4\n", "general\n" ] } ], "source": [ "# Get word counts (punctuation tokens are skipped)\n", "word_counts = Counter()\n", "for title in split_df.title:\n", "    for token in title.split(\" \"):\n", "        if token not in string.punctuation:\n", "            word_counts[token] += 1\n", "\n", "# Create SequenceVocabulary instance\n", "title_vocab = SequenceVocabulary()\n", "for word, word_count in word_counts.items():\n", "    if word_count >= args.cutoff:\n", "        title_vocab.add_token(word)\n", "print (title_vocab) # __str__\n", "print (len(title_vocab)) # __len__\n", "index = title_vocab.lookup_token(\"general\")\n", "print (index)\n", "print (title_vocab.lookup_index(index))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "4Dag6H0SFHAG" }, "source": [ "## Vectorizer" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "VQIfxcUuKwzz" }, "source": [ "Something new that we introduce in this Vectorizer is calculating the length of each input sequence. We will use these lengths later on to extract the last relevant (non-padded) hidden state for each input sequence." 
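, "\n", "\n", "A minimal NumPy sketch of how these lengths get used (the helper names and toy indices are hypothetical): sequences of different lengths are right-padded with the mask index (0) to a common length, and the recorded lengths then pick out the hidden state at each sequence's last real time step with one fancy-indexing call instead of a Python loop. The hidden states here are random placeholders standing in for RNN outputs.\n", "\n", "```python
import numpy as np

MASK_INDEX = 0  # assumed padding index, as in SequenceVocabulary


def pad_batch(sequences, mask_index=MASK_INDEX):
    # Record each true length, then right-pad with the mask index
    lengths = np.array([len(seq) for seq in sequences])
    batch = np.full((len(sequences), lengths.max()), mask_index, dtype=np.int64)
    for row, seq in enumerate(sequences):
        batch[row, :len(seq)] = seq
    return batch, lengths


def gather_last(hiddens, lengths):
    # hiddens: (batch, seq_len, hidden_dim); lengths count real time steps
    return hiddens[np.arange(hiddens.shape[0]), lengths - 1]


batch, lengths = pad_batch([[2, 4, 5, 3], [2, 6, 3], [2, 7, 8, 9, 3]])
hiddens = np.random.default_rng(0).standard_normal((3, batch.shape[1], 8))
last = gather_last(hiddens, lengths)
print(batch)
print(lengths, last.shape)
```
", "\n", "The same indexing works on PyTorch tensors, e.g. `hiddens[torch.arange(batch_size), x_lengths - 1]`, as a vectorized alternative to the loop in `gather_last_relevant_hidden`."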
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "tsNtEnhBEl6s" }, "outputs": [], "source": [ "class NewsVectorizer(object):\n", "    def __init__(self, title_vocab, category_vocab):\n", "        self.title_vocab = title_vocab\n", "        self.category_vocab = category_vocab\n", "\n", "    def vectorize(self, title):\n", "        indices = [self.title_vocab.lookup_token(token) for token in title.split(\" \")]\n", "        indices = [self.title_vocab.begin_seq_index] + indices + [self.title_vocab.end_seq_index]\n", "\n", "        # Create vector; title_length also counts the <BEGIN> and <END> tokens\n", "        title_length = len(indices)\n", "        vector = np.zeros(title_length, dtype=np.int64)\n", "        vector[:len(indices)] = indices\n", "\n", "        return vector, title_length\n", "\n", "    def unvectorize(self, vector):\n", "        tokens = [self.title_vocab.lookup_index(index) for index in vector]\n", "        title = \" \".join(token for token in tokens)\n", "        return title\n", "\n", "    @classmethod\n", "    def from_dataframe(cls, df, cutoff):\n", "\n", "        # Create class vocab\n", "        category_vocab = Vocabulary()\n", "        for category in sorted(set(df.category)):\n", "            category_vocab.add_token(category)\n", "\n", "        # Get word counts\n", "        word_counts = Counter()\n", "        for title in df.title:\n", "            for token in title.split(\" \"):\n", "                word_counts[token] += 1\n", "\n", "        # Create title vocab\n", "        title_vocab = SequenceVocabulary()\n", "        for word, word_count in word_counts.items():\n", "            if word_count >= cutoff:\n", "                title_vocab.add_token(word)\n", "\n", "        return cls(title_vocab, category_vocab)\n", "\n", "    @classmethod\n", "    def from_serializable(cls, contents):\n", "        title_vocab = SequenceVocabulary.from_serializable(contents['title_vocab'])\n", "        category_vocab = Vocabulary.from_serializable(contents['category_vocab'])\n", "        return cls(title_vocab=title_vocab, category_vocab=category_vocab)\n", "\n", "    def to_serializable(self):\n", "        return {'title_vocab': self.title_vocab.to_serializable(),\n", "                'category_vocab': 
self.category_vocab.to_serializable()}" ] }, { "cell_type": "code", "execution_count": 50, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 119 }, "colab_type": "code", "id": "JtRRXU53El9Y", "outputId": "ba63f1e4-d50e-458c-cb38-da4cc69e5dfa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "(10,)\n", "title_length: 10\n", "[ 2 1 4151 1231 25 1 2392 4076 38 3]\n", " federer wins the tennis tournament . \n" ] } ], "source": [ "# Vectorizer instance\n", "vectorizer = NewsVectorizer.from_dataframe(split_df, cutoff=args.cutoff)\n", "print (vectorizer.title_vocab)\n", "print (vectorizer.category_vocab)\n", "vectorized_title, title_length = vectorizer.vectorize(preprocess_text(\n", " \"Roger Federer wins the Wimbledon tennis tournament.\"))\n", "print (np.shape(vectorized_title))\n", "print (\"title_length:\", title_length)\n", "print (vectorized_title)\n", "print (vectorizer.unvectorize(vectorized_title))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "uk_QvpVfFM0S" }, "source": [ "## Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "oU7oDdelFMR9" }, "outputs": [], "source": [ "from torch.utils.data import Dataset, DataLoader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "pB7FHmiSFMXA" }, "outputs": [], "source": [ "class NewsDataset(Dataset):\n", " def __init__(self, df, vectorizer):\n", " self.df = df\n", " self.vectorizer = vectorizer\n", "\n", " # Data splits\n", " self.train_df = self.df[self.df.split=='train']\n", " self.train_size = len(self.train_df)\n", " self.val_df = self.df[self.df.split=='val']\n", " self.val_size = len(self.val_df)\n", " self.test_df = self.df[self.df.split=='test']\n", " self.test_size = len(self.test_df)\n", " self.lookup_dict = {'train': (self.train_df, self.train_size), \n", " 'val': (self.val_df, self.val_size),\n", " 
'test': (self.test_df, self.test_size)}\n", " self.set_split('train')\n", "\n", " # Class weights (for imbalances)\n", " class_counts = df.category.value_counts().to_dict()\n", " def sort_key(item):\n", " return self.vectorizer.category_vocab.lookup_token(item[0])\n", " sorted_counts = sorted(class_counts.items(), key=sort_key)\n", " frequencies = [count for _, count in sorted_counts]\n", " self.class_weights = 1.0 / torch.tensor(frequencies, dtype=torch.float32)\n", "\n", " @classmethod\n", " def load_dataset_and_make_vectorizer(cls, split_data_file, cutoff):\n", " df = pd.read_csv(split_data_file, header=0)\n", " train_df = df[df.split=='train']\n", " return cls(df, NewsVectorizer.from_dataframe(train_df, cutoff))\n", "\n", " @classmethod\n", " def load_dataset_and_load_vectorizer(cls, split_data_file, vectorizer_filepath):\n", " df = pd.read_csv(split_data_file, header=0)\n", " vectorizer = cls.load_vectorizer_only(vectorizer_filepath)\n", " return cls(df, vectorizer)\n", "\n", " @staticmethod\n", " def load_vectorizer_only(vectorizer_filepath):\n", " with open(vectorizer_filepath) as fp:\n", " return NewsVectorizer.from_serializable(json.load(fp))\n", "\n", " def save_vectorizer(self, vectorizer_filepath):\n", " with open(vectorizer_filepath, \"w\") as fp:\n", " json.dump(self.vectorizer.to_serializable(), fp)\n", "\n", " def set_split(self, split=\"train\"):\n", " self.target_split = split\n", " self.target_df, self.target_size = self.lookup_dict[split]\n", "\n", " def __str__(self):\n", " return \" software firm to cut jobs \n", "tensor([3.3333e-05, 3.3333e-05, 3.3333e-05, 3.3333e-05])\n" ] } ], "source": [ "# Dataset instance\n", "dataset = NewsDataset.load_dataset_and_make_vectorizer(args.split_data_file,\n", " args.cutoff)\n", "print (dataset) # __str__\n", "input_ = dataset[5] # __getitem__\n", "print (input_['title'], input_['title_length'], input_['category'])\n", "print (dataset.vectorizer.unvectorize(input_['title']))\n", "print (dataset.class_weights)" ] }, {
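"cell_type": "markdown", "metadata": {}, "source": [ "Titles in a batch are padded to a common length. For every title shorter than the longest one in its batch, the RNN output at the final column therefore comes from padding, not from the end of the title. The model below selects the correct time step with a Python loop (`gather_last_relevant_hidden`). As a minimal sketch, the same selection can be done in one vectorized `torch.gather` call. The toy tensors and this swapped-in approach are illustrative assumptions, not part of the notebook's model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Toy padded RNN outputs: (batch=2, max_seq_len=4, hidden_dim=3)\n", "hiddens = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)\n", "x_lengths = torch.tensor([4, 2])  # true (unpadded) lengths\n", "\n", "# Index of the last real time step, broadcast over the hidden dim: (batch, 1, hidden_dim)\n", "last_idx = (x_lengths - 1).view(-1, 1, 1).expand(-1, 1, hiddens.size(2))\n", "last_hidden = hiddens.gather(1, last_idx).squeeze(1)  # (batch, hidden_dim)\n", "\n", "# Same rows a loop-based selection would pick: hiddens[0, 3] and hiddens[1, 1]\n", "assert torch.equal(last_hidden, torch.stack([hiddens[0, 3], hiddens[1, 1]]))\n", "print(last_hidden)" ] }, {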
"cell_type": "markdown", "metadata": { "colab_type": "text", "id": "_IUIqtbvFUAG" }, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "xJV5WlDiFVVz" }, "source": [ "input → embedding → RNN → FC " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "rZCzdZZ9FMhm" }, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "wbWO4lZcIdqZ" }, "outputs": [], "source": [ "def gather_last_relevant_hidden(hiddens, x_lengths):\n", " # For each sequence, pick the RNN output at its last non-padded time step\n", " x_lengths = x_lengths.long().detach().cpu().numpy() - 1\n", " out = []\n", " for batch_index, column_index in enumerate(x_lengths):\n", " out.append(hiddens[batch_index, column_index])\n", " return torch.stack(out)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "9TT66Y-UFMcZ" }, "outputs": [], "source": [ "class NewsModel(nn.Module):\n", " def __init__(self, embedding_dim, num_embeddings, rnn_hidden_dim, \n", " hidden_dim, output_dim, num_layers, bidirectional, dropout_p, \n", " pretrained_embeddings=None, freeze_embeddings=False, \n", " padding_idx=0):\n", " super(NewsModel, self).__init__()\n", " \n", " if pretrained_embeddings is None:\n", " self.embeddings = nn.Embedding(embedding_dim=embedding_dim,\n", " num_embeddings=num_embeddings,\n", " padding_idx=padding_idx)\n", " else:\n", " pretrained_embeddings = torch.from_numpy(pretrained_embeddings).float()\n", " self.embeddings = nn.Embedding(embedding_dim=embedding_dim,\n", " num_embeddings=num_embeddings,\n", " padding_idx=padding_idx,\n", " _weight=pretrained_embeddings)\n", " \n", " # RNN weights\n", " self.gru = nn.GRU(input_size=embedding_dim, hidden_size=rnn_hidden_dim, \n", " num_layers=num_layers, batch_first=True, \n", " 
bidirectional=bidirectional)\n", " \n", " # FC weights\n", "
 self.dropout = nn.Dropout(dropout_p)\n", " # A bidirectional GRU concatenates both directions, doubling its output size\n", " rnn_output_dim = rnn_hidden_dim * 2 if bidirectional else rnn_hidden_dim\n", " self.fc1 = nn.Linear(rnn_output_dim, hidden_dim)\n", " self.fc2 = nn.Linear(hidden_dim, output_dim)\n", " \n", " if freeze_embeddings:\n", " self.embeddings.weight.requires_grad = False\n", "\n", " def forward(self, x_in, x_lengths, apply_softmax=False):\n", " \n", " # Embed\n", " x_in = self.embeddings(x_in)\n", " \n", " # Feed into RNN\n", " out, h_n = self.gru(x_in)\n", " \n", " # Gather the last relevant hidden state\n", " out = gather_last_relevant_hidden(out, x_lengths)\n", "\n", " # FC layers\n", " z = self.dropout(out)\n", " z = self.fc1(z)\n", " z = self.dropout(z)\n", " y_pred = self.fc2(z)\n", "\n", " if apply_softmax:\n", " y_pred = F.softmax(y_pred, dim=1)\n", " return y_pred" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "jHPYCPd7Fl3M" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "D3seBMA7FlcC" }, "outputs": [], "source": [ "import torch.optim as optim" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "HnRKWLekFlnM" }, "outputs": [], "source": [ "class Trainer(object):\n", " def __init__(self, dataset, model, model_state_file, save_dir, device, shuffle, \n", " num_epochs, batch_size, learning_rate, early_stopping_criteria):\n", " self.dataset = dataset\n", " self.class_weights = dataset.class_weights.to(device)\n", " self.model = model.to(device)\n", " self.save_dir = save_dir\n", " self.device = device\n", " self.shuffle = shuffle\n", " self.num_epochs = num_epochs\n", " self.batch_size = batch_size\n", " self.loss_func = nn.CrossEntropyLoss(weight=self.class_weights)\n", " self.optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)\n", " self.scheduler = 
optim.lr_scheduler.ReduceLROnPlateau(\n", " optimizer=self.optimizer, mode='min', factor=0.5, patience=1)\n", " self.train_state = {\n", " 'stop_early': False, \n", "
'early_stopping_step': 0,\n", " 'early_stopping_best_val': 1e8,\n", " 'early_stopping_criteria': early_stopping_criteria,\n", " 'learning_rate': learning_rate,\n", " 'epoch_index': 0,\n", " 'train_loss': [],\n", " 'train_acc': [],\n", " 'val_loss': [],\n", " 'val_acc': [],\n", " 'test_loss': -1,\n", " 'test_acc': -1,\n", " 'model_filename': model_state_file}\n", " \n", " def update_train_state(self):\n", "\n", " # Record the learning rate actually in use (the scheduler may have lowered it)\n", " self.train_state['learning_rate'] = self.optimizer.param_groups[0]['lr']\n", "\n", " # Verbose\n", " print (\"[EPOCH]: {0:02d} | [LR]: {1} | [TRAIN LOSS]: {2:.2f} | [TRAIN ACC]: {3:.1f}% | [VAL LOSS]: {4:.2f} | [VAL ACC]: {5:.1f}%\".format(\n", " self.train_state['epoch_index'], self.train_state['learning_rate'], \n", " self.train_state['train_loss'][-1], self.train_state['train_acc'][-1], \n", " self.train_state['val_loss'][-1], self.train_state['val_acc'][-1]))\n", "\n", " # Save one model at least\n", " if self.train_state['epoch_index'] == 0:\n", " torch.save(self.model.state_dict(), self.train_state['model_filename'])\n", " self.train_state['early_stopping_best_val'] = self.train_state['val_loss'][-1]\n", " self.train_state['stop_early'] = False\n", "\n", " # Save model if performance improved\n", " elif self.train_state['epoch_index'] >= 1:\n", " loss_t = self.train_state['val_loss'][-1]\n", "\n", " # If loss did not improve on the best so far\n", " if loss_t >= self.train_state['early_stopping_best_val']:\n", " # Update step\n", " self.train_state['early_stopping_step'] += 1\n", "\n", " # Loss improved on the best so far\n", " else:\n", " # Save the best model and record its loss\n", " torch.save(self.model.state_dict(), self.train_state['model_filename'])\n", " self.train_state['early_stopping_best_val'] = loss_t\n", "\n", " # Reset early stopping step\n", " self.train_state['early_stopping_step'] = 0\n", "\n", " # Stop early?\n", " 
self.train_state['stop_early'] = self.train_state['early_stopping_step'] \\\n", " >= self.train_state['early_stopping_criteria']\n", " return self.train_state\n", " \n", " def compute_accuracy(self, y_pred, y_target):\n", " _, y_pred_indices = y_pred.max(dim=1)\n", " n_correct = torch.eq(y_pred_indices, y_target).sum().item()\n", "
return n_correct / len(y_pred_indices) * 100\n", " \n", " def pad_seq(self, seq, length):\n", " vector = np.zeros(length, dtype=np.int64)\n", " vector[:len(seq)] = seq\n", " vector[len(seq):] = self.dataset.vectorizer.title_vocab.mask_index\n", " return vector\n", " \n", " def collate_fn(self, batch):\n", " \n", " # Make a deep copy\n", " batch_copy = copy.deepcopy(batch)\n", " processed_batch = {\"title\": [], \"title_length\": [], \"category\": []}\n", " \n", " # Get max sequence length\n", " get_length = lambda sample: len(sample[\"title\"])\n", " max_seq_length = max(map(get_length, batch))\n", " \n", " # Pad\n", " for i, sample in enumerate(batch_copy):\n", " padded_seq = self.pad_seq(sample[\"title\"], max_seq_length)\n", " processed_batch[\"title\"].append(padded_seq)\n", " processed_batch[\"title_length\"].append(sample[\"title_length\"])\n", " processed_batch[\"category\"].append(sample[\"category\"])\n", " \n", " # Convert to appropriate tensor types\n", " processed_batch[\"title\"] = torch.LongTensor(\n", " processed_batch[\"title\"])\n", " processed_batch[\"title_length\"] = torch.LongTensor(\n", " processed_batch[\"title_length\"])\n", " processed_batch[\"category\"] = torch.LongTensor(\n", " processed_batch[\"category\"])\n", " \n", " return processed_batch \n", " \n", " def run_train_loop(self):\n", " for epoch_index in range(self.num_epochs):\n", " self.train_state['epoch_index'] = epoch_index\n", " \n", " # Iterate over train dataset\n", "\n", " # initialize batch generator, set loss and acc to 0, set train mode on\n", " self.dataset.set_split('train')\n", " batch_generator = self.dataset.generate_batches(\n", " batch_size=self.batch_size, collate_fn=self.collate_fn, \n", " shuffle=self.shuffle, device=self.device)\n", " running_loss = 0.0\n", " running_acc = 0.0\n", " self.model.train()\n", "\n", " for batch_index, batch_dict in enumerate(batch_generator):\n", " # zero the gradients\n", " self.optimizer.zero_grad()\n", "\n", " # compute the 
output\n", " y_pred = self.model(batch_dict['title'], batch_dict['title_length'])\n", "\n", " # compute the loss\n", " loss = self.loss_func(y_pred, batch_dict['category'])\n", " loss_t = loss.item()\n", " running_loss += (loss_t - running_loss) / (batch_index + 1)\n", "\n", " # compute gradients using loss\n", " loss.backward()\n", "\n", " # use optimizer to take a gradient step\n", " self.optimizer.step()\n", " \n", " # compute the accuracy\n", " acc_t = self.compute_accuracy(y_pred, batch_dict['category'])\n", " running_acc += (acc_t - running_acc) / (batch_index + 1)\n", "\n", " self.train_state['train_loss'].append(running_loss)\n", " self.train_state['train_acc'].append(running_acc)\n", "\n", " # Iterate over val dataset\n", "\n", " # initialize batch generator, set loss and acc to 0; set eval mode on\n", " self.dataset.set_split('val')\n", " batch_generator = self.dataset.generate_batches(\n", " batch_size=self.batch_size, collate_fn=self.collate_fn, \n", " shuffle=self.shuffle, device=self.device)\n", " running_loss = 0.\n", " running_acc = 0.\n", " self.model.eval()\n", "\n", " for batch_index, batch_dict in enumerate(batch_generator):\n", "\n", " # compute the output\n", " y_pred = self.model(batch_dict['title'], batch_dict['title_length'])\n", "\n", " # compute the loss\n", " loss = self.loss_func(y_pred, batch_dict['category'])\n", " loss_t = loss.to(\"cpu\").item()\n", " running_loss += (loss_t - running_loss) / (batch_index + 1)\n", "\n", " # compute the accuracy\n", " acc_t = self.compute_accuracy(y_pred, batch_dict['category'])\n", " running_acc += (acc_t - running_acc) / (batch_index + 1)\n", "\n", " self.train_state['val_loss'].append(running_loss)\n", " self.train_state['val_acc'].append(running_acc)\n", "\n", " self.train_state = self.update_train_state()\n", " self.scheduler.step(self.train_state['val_loss'][-1])\n", " if self.train_state['stop_early']:\n", " break\n", " \n", " def run_test_loop(self):\n", " # initialize batch generator, set 
loss and acc to 0; set eval mode on\n", " self.dataset.set_split('test')\n", " batch_generator = self.dataset.generate_batches(\n", " batch_size=self.batch_size, collate_fn=self.collate_fn, \n", " shuffle=self.shuffle, device=self.device)\n", " running_loss = 0.0\n", " running_acc = 0.0\n", " self.model.eval()\n", "\n", " for batch_index, batch_dict in enumerate(batch_generator):\n", " # compute the output\n", " y_pred = self.model(batch_dict['title'], batch_dict['title_length'])\n", "\n", " # compute the loss\n", " loss = self.loss_func(y_pred, batch_dict['category'])\n", " loss_t = loss.item()\n", " running_loss += (loss_t - running_loss) / (batch_index + 1)\n", "\n", " # compute the accuracy\n", " acc_t = self.compute_accuracy(y_pred, batch_dict['category'])\n", " running_acc += (acc_t - running_acc) / (batch_index + 1)\n", "\n", " self.train_state['test_loss'] = running_loss\n", " self.train_state['test_acc'] = running_acc\n", " \n", " def plot_performance(self):\n", " # Figure size\n", " plt.figure(figsize=(15,5))\n", "\n", " # Plot Loss\n", " plt.subplot(1, 2, 1)\n", " plt.title(\"Loss\")\n", " plt.plot(self.train_state[\"train_loss\"], label=\"train\")\n", " plt.plot(self.train_state[\"val_loss\"], label=\"val\")\n", " plt.legend(loc='upper right')\n", "\n", " # Plot Accuracy\n", " plt.subplot(1, 2, 2)\n", " plt.title(\"Accuracy\")\n", " plt.plot(self.train_state[\"train_acc\"], label=\"train\")\n", " plt.plot(self.train_state[\"val_acc\"], label=\"val\")\n", " plt.legend(loc='lower right')\n", "\n", " # Save figure\n", " plt.savefig(os.path.join(self.save_dir, \"performance.png\"))\n", "\n", " # Show plots\n", " plt.show()\n", " \n", " def save_train_state(self):\n", " with open(os.path.join(self.save_dir, \"train_state.json\"), \"w\") as fp:\n", " json.dump(self.train_state, fp)" ] }, { "cell_type": "code", "execution_count": 88, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 136 }, "colab_type": "code", "id": 
"ICkiOaGtFlk-", "outputId": "57f7f143-7899-407a-acbd-17f767eb56c3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Initialization\n", "dataset = NewsDataset.load_dataset_and_make_vectorizer(args.split_data_file,\n", " args.cutoff)\n", "dataset.save_vectorizer(args.vectorizer_file)\n", "vectorizer = dataset.vectorizer\n", "model = NewsModel(embedding_dim=args.embedding_dim, \n", " num_embeddings=len(vectorizer.title_vocab), \n", " rnn_hidden_dim=args.rnn_hidden_dim,\n", " hidden_dim=args.hidden_dim,\n", " output_dim=len(vectorizer.category_vocab),\n", " num_layers=args.num_layers,\n", " bidirectional=args.bidirectional,\n", " dropout_p=args.dropout_p, \n", " pretrained_embeddings=None, \n", " padding_idx=vectorizer.title_vocab.mask_index)\n", "print (model.named_modules)" ] }, { "cell_type": "code", "execution_count": 89, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 102 }, "colab_type": "code", "id": "tuaRZ4DiFlh1", "outputId": "fba7ac04-7e1d-4372-b358-7340a013960d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[EPOCH]: 00 | [LR]: 0.001 | [TRAIN LOSS]: 0.75 | [TRAIN ACC]: 70.7% | [VAL LOSS]: 0.54 | [VAL ACC]: 80.5%\n", "[EPOCH]: 01 | [LR]: 0.001 | [TRAIN LOSS]: 0.48 | [TRAIN ACC]: 82.7% | [VAL LOSS]: 0.49 | [VAL ACC]: 82.3%\n", "[EPOCH]: 02 | [LR]: 0.001 | [TRAIN LOSS]: 0.41 | [TRAIN ACC]: 85.0% | [VAL LOSS]: 0.47 | [VAL ACC]: 83.1%\n", "[EPOCH]: 03 | [LR]: 0.001 | [TRAIN LOSS]: 0.37 | [TRAIN ACC]: 86.6% | [VAL LOSS]: 0.47 | [VAL ACC]: 83.3%\n", "[EPOCH]: 04 | [LR]: 0.001 | [TRAIN LOSS]: 0.33 | [TRAIN ACC]: 88.2% | [VAL LOSS]: 0.49 | [VAL ACC]: 83.0%\n" ] } ], "source": [ "# Train\n", "trainer = Trainer(dataset=dataset, model=model, \n", " model_state_file=args.model_state_file, \n", " save_dir=args.save_dir, device=args.device,\n", " shuffle=args.shuffle, num_epochs=args.num_epochs, \n", " batch_size=args.batch_size, learning_rate=args.learning_rate, \n", " 
early_stopping_criteria=args.early_stopping_criteria)\n", "trainer.run_train_loop()" ] }, { "cell_type": "code", "execution_count": 73, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 335 }, "colab_type": "code", "id": "mzRJIz88Flfe", "outputId": "a7ac8786-01ea-4421-e70c-d79c22c7ed4a" }, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAA2gAAAE+CAYAAAD4XjP+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAIABJREFUeJzs3Xl41fWd//3nWbLvy8keQhISAiGJ7PuqCALuothinWrvuXVse/9+wzg4mU5RR2t7jXRq7dQuP39tdTruQUQUF0qQTbZAEgIBErZsZF8IgSznnPuPwIGwK0nOOcnrcV1e5Lud8zpHOMk77+/n8zHY7XY7IiIiIiIi4nRGZwcQERERERGRbirQREREREREXIQKNBERERERERehAk1ERERERMRFqEATERERERFxESrQREREREREXIQKNJFvafjw4Zw8edLZMURERPrFkiVLuOuuu5wdQ2TAU4EmIiIiItd06NAhAgICiImJYc+ePc6OIzKgqUAT6WXt7e389Kc/Zd68edxxxx38/Oc/x2q1AvDf//3f3HHHHcyfP58HHniAw4cPX3O/iIiIK1i1ahXz589n0aJFfPjhh479H374IfPmzWPevHk8/fTTdHR0XHX/9u3bmTt3ruPai7dfffVVfvKTn/DAAw/w5z//GZvNxnPPPce8efOYM2cOTz/9NJ2dnQA0NDTwxBNPcOutt3LnnXeyefNmcnNzWbRoUY/M9913H19++WVfvzUivc7s7AAiA81f/vIXTp48ydq1a+nq6mLp0qV8/PHH3Hrrrbzyyits2LABf39/Pv30U3Jzc4mOjr7i/pSUFGe/FBEREaxWK1988QVPPfUUJpOJlStX0tHRQU1NDb/4xS/48MMPiYiI4Ec/+hFvvPEG8+fPv+L+jIyMaz7Pxo0bWb16NaGhoXz22Wfs2rWLjz/+GJvNxr333ssnn3zC3XffzcqVK0lOTuZ3v/sd+/fv5/vf/z6bNm2itraW4uJi0tLSqKys5MSJE8yYMaOf3iWR3qMCTaSX5ebm8thjj2E2mzGbzdx5551s2bKFBQsWYDAYeP/991m0aBF33HEHAJ2dnVfcLyIi4go2b95MRkYG/v7+AEyYMIENGzbQ1NTE6NGjiYyMBGDlypWYTCY++OCDK+7fvXv3NZ8nKyuL0NBQAObNm8fs2bPx8PAAICMjg7KyMqC7kPvjH/8IwMiRI1m/fj2enp7MmzePtWvXkpaWxpdffsmtt96Kp6dn778hIn1MtziK9LKGhgaCgoIc20FBQdTX1+Ph4cGf//xn8vLymDdvHt/5znc4ePDgVfeLiIi4gpycHHJzcxk3bhzjxo3j888/Z9WqVTQ2NhIYGOg4z8vLC7PZfNX913Px986GhgaWL1/OvHnzmD9/PuvXr8dutwPQ1NREQECA49zzhePChQtZu3YtAF9++SULFiy4uRcu4iQq0ER6WXh4OE1NTY7tpqYmwsPDge7f9P36179m27ZtTJs2jRUrVlxzv4iIiDM1NzezY8cOtm/fzq5du9i1axc7d+6ksLAQo9FIY2Oj49zW1lbq6uoICQm54n6TyeQYkw3Q0tJy1ef9z//8T8xmM2vWrGHdunXMnDnTcSw4OLjH45eXl9
PZ2cn48ePp6upiw4YNHD58mClTpvTW2yDSr1SgifSyWbNm8f7772O1Wmlra2P16tXMnDmTgwcP8uMf/5iOjg48PT0ZNWoUBoPhqvtFREScbe3atUyaNKnHrYJms5lp06bR0dFBXl4e5eXl2O12VqxYwfvvv8/MmTOvuN9isVBbW0t9fT1Wq5U1a9Zc9Xnr6+tJTU3F09OT4uJi9uzZQ1tbGwBz5sxh1apVAJSUlHDfffdhtVoxGo0sWLCAf//3f2fOnDmO2yNF3I3GoInchEceeQSTyeTYfuGFF3jkkUcoKytj4cKFGAwG5s+f7xhXFhcXx6JFi/Dw8MDPz4+f/vSnpKamXnG/iIiIs3344Yc8+uijl+2fO3cuv/3tb3n++ed59NFHMZlMZGRk8P3vfx8vL6+r7r///vu55557iImJ4e677+bAgQNXfN7HHnuM5cuXk5OTw7hx41i+fDn/+q//SmZmJk8//TTLly9nzpw5+Pn58fLLL+Pt7Q103+b4pz/9Sbc3ilsz2M/f0CsiIiIi4sbq6uq49957yc3N7fELVBF3olscRURERGRA+PWvf83DDz+s4kzcmgo0EREREXFrdXV13HrrrdTV1fHYY485O47ITdEtjiIiIiIiIi5CHTQREREREREXoQJNRERERETERfT7NPu1tadu+jFCQnxpbGzrhTT9w53yKmvfcKes4F55lbVv9FZWiyWgF9IMHoPte6Q7ZQX3yqusfUNZ+4475e2NrNf6/uiWHTSz2b1m5nGnvMraN9wpK7hXXmXtG+6UVXpyp/937pQV3CuvsvYNZe077pS3r7O6ZYEmIiIiIiIyEKlAExERERERcREq0ERERERERFyECjQREREREREXoQJNRERERETERahAExERERERcREq0ERERERERFyECjQRkUEkN3f9DZ334osvUllZ0cdpRERE5FIq0EREBomqqkq+/PKzGzr3X//1X4mJie3jRCIiInIps7MDfFPtHVbWbDrCLUkheHu6XXwREaf55S9/wYEDRUyfPp7bb7+DqqpKfvWr3/LSS89TW1vDmTNneOyxv2fq1Ok88sgj/PCH/8iGDes5fbqVEyeOU1FRzo9/vIzJk6c6+6WIiIj0qy6rjePVpzhS0cK4UdGE+PRdHeJ2Fc7Bskb+8GEht46N47tzU50dR0TEbTz88CPk5LxLYmIyJ04c47e//T80NjYwYcIk7rhjERUV5fzbvz3D1KnTe1xXU1PNyy//mq+/3srq1R+oQBMRkQGv7WwXRyqbOVTeTEl5E0cqW+josgFwvLaVHywY0WfP7XYF2sihoUSH+5G7p4LbxsURGeLr7EgiIt/Yu38rYWdxTa8+5vi0CB6cM+yGzh0xIh2AgIBADhwo4qOPcjAYjLS0NF92bmbmLQBERETQ2trae4FFRERcREPLWQ6XN3O4vInD5c2U17RiP3fMAMRa/EmJDyIlLohbJw7l9KmzfZbF7Qo0s8nIowtG8vM3dvJBbin/cG+GsyOJiLgdDw8PAL74Yh0tLS3813/9H1paWvjBDx657FyTyeT42m63X3ZcRETEndjsdirrTl8oyMqaqW+5UHB5mI2kxAeTEhdESlwww2ID8fX2cBz39fZQgXapKZnRJMUEsutgLaUVzSTHBjk7kojIN/LgnGE33O3qLUajEavV2mNfU1MT0dExGI1GNm78G52dnf2aSUREpK91dlk5WnXK0R0rKW+mrb3Lcdzfx4PRKeGkxHUXZQlRAZhNzptL0S0LNIPBwIOzh/Hzv+bxzoYS/uW7YzAYDM6OJSLi0hISEjl4sJjo6BiCg4MBmDVrDs8884/s37+PhQvvIiIigj/96Y9OTup6Tp8+zfLly2lubqazs5OnnnqKP/zhD47jNTU13HvvvTzxxBOOfa+++ipr1qwhMjISgLvuuovFixf3e3YRkcGm9UwnJRUXblc8VtVCl/XCHSARwT7dBdm5LllUqK9L1RJuWaABpMYHMzolnD
2H69hzuI4xqRZnRxIRcWkhISHk5KztsS86Ooa//OVtx/btt98BgMUSQG3tKZKSLnT5kpKG8Zvf/IHBaNWqVSQmJrJs2TKqq6t59NFHWbduneP4D37wA+6+++7Lrvve977H0qVL+zOqiMigYrfbqW/uOX6sou6047jBAEMiA0iJCyI1LphhcUEE+3s5MfH1uW2BBvDArGTyS+p5L7eUzOQwp7YiRURk4AoJCeHgwYMAtLS0EBIS4ji2detWhg4dSnR0tLPiiYgMGjabnfLa1h4FWeOpdsdxTw8jIxJCusePxQeTFB2Ij5d7lTzulfYS0WF+zLglhtw9FWzKr2T2mDhnRxIRkQFo4cKF5OTkMHfuXFpaWvj973/vOPbGG2+QnZ19xevWrVvH+vXr8fT05Cc/+Qnx8fH9FVlEZEBo77RytLLlwviximbOdlwYTx3o58nY4RbH+LH4CH+3b9q4dYEGcPe0RLbtO8nqzUeZlB7ldhWyiIi4vtWrVxMTE8Prr79OcXEx2dnZ5OTkUF1dTVtbG0OGDLnsmpkzZzJp0iTGjx/P2rVreeGFF3oUdlcSEuKL2Wy65jk3wmIJuOnH6C/ulBXcK6+y9g1l7TsWSwBNp9o5cKye/Ucb2H+0ntLyZqy2C+PHYi3+jEwMZWRiGCOTQokO83PK+LG+fG/dvpoJ8vPkjolD+HDzUT7bcYJ7pic5O5KIiAwweXl5TJs2DYC0tDRqamqwWq1s3LiRSZMmXfGazMxMx9dz5szh5Zdfvu7zNDa23XTW8+MH3YE7ZQX3yqusfUNZe5fdbqem8QyHypsor2ujoKSO6oYLn4Mmo4GEqIAL093HBRHo63nxA1BX1//rc/bGe3utAs/tCzSA2yfEs2FPBet2nGDmLbGEBLj2wD8REXEvCQkJ5OfnM2/ePCoqKvDz88NkMlFYWMjs2bOveM0LL7zA/PnzGTduHDt27CAlJaWfU4uIuJYuq42ymlYOlzU5xpC1tF1Y3sXHy8SopFBS4oJJjQtiaHQgXh43f1eBuxkQBZq3p5l7pifyl3UHWb35KH93R5qzI4mIyADy0EMPkZ2dzdKlS+nq6uLZZ58FoLa2lrCwMMd5tbW1vPrqqzz//PMsXryYFStWYDabMRgMvPDCC05KLyLiHGfauzhy0fix0spmOjptjuMhAV5MGBFBSlwwEzNj8DUZMBpdZ7p7ZxkQBRrAtMxoPt9ZxqaCSuaOjyc23M/ZkURE3NIDD9zJJ5+svf6Jg4ifnx+vvPLKZft/97vf9di2WCw8//zzAAwfPpy33377smtERAaqxlPt3euPneuQnag5hf3C8DFiLX6OyTxS4oIIC/R2jB9zh1sy+8uAKdBMRiOLZw3j1x8U8P6GEv6/xVnOjiQiIiIiMiDZ7Haq6tsoKb9wu2Jt01nHcbPJwLDYIEdBlhwbhL+PhxMTu48BU6ABZA0LY3h8MPml9RQfbyQtIeT6F4mIDBKPPfZdfvazlURFRXHyZBX/8i/LsFgiOHPmDGfPnuV//++nGTlylLNjioiIC+rssnH85CnH7YqHy5s4fbbLcdzP20xWchgp8d0F2dCoADx6YVbawWhAFWgGg4HFs4fxwhu7eHdDCT95dBxGJ0y7KSLiimbMmM2WLV9x//0PsmnTRmbMmE1ycgozZsxi9+6d/PWvf+HFF//D2TFFRMQFtJ3t7L5dsbz7lsUjVafosl4YPxYe5E1mcpijQxYd7qefu3vJgCrQAJJiApkwIoIdB2rYeaCGiSMjnR1JROQyOSUfs6emsFcfc3REBvcNW3TV4zNmzOY3v/kV99//IJs3b+SHP/zfvP32m7z11pt0dnbi7e3dq3lERMR91Def7dEdq6g9zfnhYwYDxFv8u4ux+CCGxQYRGqjvGX1lwBVoAPfNTGb3wVo+2FjKmFQLHmb3Xk1cRKQ3JCUlU19fS3X1SU6dOsWmTbmEh0fwb//27xQX7+c3v/mVsyOKiEg/sNnsVNSd7lGQNb
S0O457mo0MHxLsKMiSY4Lw8RqQZYNLGpDvdESwD3PGxPHFrjI27Kng9vHxzo4kItLDfcMWXbPb1VcmT57GH/7wW6ZPn0lTUyPJyd1rc23cuIGurq7rXC0iIu6oo9PK0aqWc8VYMyUVzZxpv/CZH+DrwZhUi2NB6CGR/phNanA4y4As0ADunDqUzYVVrNlylGkZUfh6a9YYEZGZM2fzxBOP8ec/v8XZs2d44YUVbNjwJfff/yBffvk5a9d+5OyIIiJyk1rPdHJkXxW79p/kcHkTx6pOYbVdmO8+MsSHsecLsvhgIkN8HNPdi/MN2ALN38eDRZMTeC+3lLXbjrN49jBnRxIRcboRI9LZuHG7Y/uvf33f8fW0aTMBWLjwLvz8/Ghr03o0IiLuoKHlLIfKmrr/K2+msu6045jRYCAhyt8xmcewuGCC/DydmFauZ8AWaAC3jo1jfV45X+wqZ86YOMKCNJhRRERERNyX/dz6Y4fKmzhc1sShsmbqWy6sP+bpYWREQgijh0cQG+pDUkwQXp6a7t6dDOgCzdPDxL3Tk3h97QFyvjrC/3PnSGdHEhERERG5YVabjRPVrY4O2eHyZlrPdDqO+/t4MDolnJS4YFLjL4wfs1gCqK3VnRDuaEAXaACTR0Xx+c4yvi46ye3j40mICnB2JBERERGRK+rotHKkssXRISupbKG9w+o4HhroxaSkSFLjgkmJDyY6zFfrjw0wA75AMxoMPDh7GCvf2cv7uSUsWzLa2ZFERERERIDuBaEPlzefGz92+YQe0WG+pMYHnyvIgggP8nFiWukPN1Sg/exnPyM/Px+DwUB2djaZmZkAVFdX80//9E+O88rKyli2bBl33nln36T9ltITQ0lPDKXoaAP7jtQzKinM2ZFEREREZBBqPNXO4fJzE3qUNVNR2+pYEPriCT1S47sn9Qjw1YQeg811C7QdO3Zw/Phx3nnnHUpLS8nOzuadd94BIDIykjfffBOArq4uHnnkEebMmdO3ib+lxbOS2X+0gXc3lDJyaChGo1rBIiIiItJ37HY7NY1nOFh2bkKP8iZqmy5M6OFx0YLQqfHBJMcG4u054G9wk+u47t+Abdu2cdtttwGQnJxMc3Mzra2t+Pv79zhv1apVzJs3Dz8/v75JepOGRAYwZVQUW/adZFvRSaZmRDs7koiIiIgMIDabnbKa1gszLJY303K6w3Hc18tMVnJYd3csPpihUQFaEFouc90Cra6ujvT0dMd2aGgotbW1lxVo7733Hv/3//7f6z5hSIgvZvPNT/VpsXzzyT4evyeTncU1rN58lDumJ+Pl0X9Tjn6bvM6irH3DnbKCe+VV1r7hTllFRJyhs8vG0aoWx/ix0opmzrRfmNAj2N+TCSMiHB2yWIufJvSQ6/rGPVS73X7Zvj179pCUlHRZ0XYljY1t3/QpL3Mz04beNi6eT74+zluf7mfh5KE3neVGuNM0p8raN9wpK7hXXmXtG72VVUWeiAwkbWc7KTxS3z3dfVkTR6pO0WW1OY5HhvoybniQo0NmCfLGoIJMvqHrFmgRERHU1dU5tmtqarBYLD3Oyc3NZfLkyb2frg8smJTAV/mVfPL1cWZkxWjgpYiIiIhcUfPpjnOLQXd3yMprWjk/waLBAPER/qReNKFHkL+XcwPLgHDdAm3q1Km8+uqrLFmyhKKiIiIiIi7rlBUWFrJgwYI+C9mbfL3N3DllKG+tP8yaLcf4ztxUZ0cSERERESez2+3UNp+9qCBrprrhwp1fZpOREYlhDI30JzU+mGGxQfh4aUIP6X3X/Vs1ZswY0tPTWbJkCQaDgRUrVpCTk0NAQABz584FoLa2lrAw95m6fvaYWL7cXcaGPRXcOi6OyBBfZ0cSERERkX5ks9uprD3NIceU9000tV6Y0MPb08SopFBHhywxOoCY6GC3uVVd3NcNlf0Xr3UGkJaW1mN7zZo1vZeoH5hNRu6fmczvVheRs/EIT94zytmRRERERKQPdVltHD
t5ytEhK6lo5vTZLsfxQD9Pxg23kHJuUej4CH8tyyROMWj7suPTIvhsRxk7i2u4vbKZ5JggZ0cSERERkV5ytqOL0oruGRYPlzdxpLKFjq4LE3pYgr25ZVg4KfHBDI8PJiLERxN6iEsYtAWawWDgwdnJ/OJ/9vDe30pY/t0x+kcpIiIi4qZOtXVwuLzZUZAdP9mK7dzs4wYg1uJPavy5GRbjggkJ0IQe4poGbYEGMHxICLcMC2dvSR17S+oYnWK5/kUiIiIi4nT1zWd7jB+rqr8woYfJaCAxJoDUuO7p7lPigvDz9nBiWpEbN6gLNIAHZiVTUFrP+7mlZCaHYTJqNXcRERERV2K326msb+seP1bevQZZfUu747iXh4n0oSGO8WOJMYF4eZicmFjk2xv0BVpMuB8zsqLJ3VvJpvwqZo2OdXYkERERkUHNarNxorqVgye6b1c8XN5M65lOx3F/Hw9Gp4STGt89w+KQSH/9kl0GjEFfoAHcPS2RbUXVfLj5KJPSI/H21NsiIiIi0l/aO60cqWxxdMhKK1po77Q6jocFepORFOrokEWH+WruABmwVIkAQf5ezJsQz0dbjrFu+wnumZ7k7EgiIuJCTp8+zfLly2lubqazs5OnnnqKP/zhD7S1teHr272W5vLlyxk16sKyLZ2dnTzzzDNUVlZiMpl46aWXiI+Pd9ZLEHE5NU1n2FpYxcHyZkrKmrDa7I5jMeF+pMYFOQqysCBvJyYV6V8q0M6ZP3EIuXsr+WxHGbNGxxLsr5l9RESk26pVq0hMTGTZsmVUV1fz6KOPYrFYeOmll0hNTb3iNR9//DGBgYGsXLmSzZs3s3LlSn71q1/1c3IR19LeaWX3wRo2F1RRfKIJAKPRQEJkQPcMi3HBDIsLIsDX08lJRZxHBdo53p5m7pmWyBufHeSjzUf53vy0618kIiKDQkhICAcPHgSgpaWFkJCQ616zbds27rnnHgCmTJlCdnZ2n2YUcVV2u50jlS1sKqhix4FqznZ037o4PD6YaZnR3D4lkdOnzjo5pYjrUIF2kelZ0Xyxq4yv8qu4bVw8MeF+zo4kIiIuYOHCheTk5DB37lxaWlr4/e9/z8qVK/n1r39NY2MjycnJZGdn4+194Tasuro6QkNDATAajRgMBjo6OvD0vHpnICTEF7P55mees1gCbvox+os7ZQX3yuvsrI0tZ9mwu4wvd56grLoVgPAgb+6ekcyt44cQfdHPWb5uNAW+s9/Xb8KdsoJ75e3LrCrQLmIyGnlgVjKvflDI+7ml/PiBTGdHEhERF7B69WpiYmJ4/fXXKS4uJjs7myeffJLhw4czZMgQVqxYwV//+lcef/zxqz6G3W6/6rHzGhvbrnvO9VgsAdTWnrrpx+kP7pQV3Cuvs7J2WW0UlNazuaCKgtJ6bHY7ZpOBCSMimJYZzciEUIxGA9htjnx6X/uGO2UF98rbG1mvVeCpQLvELcPCSY0LYm9JHQdPNDJ8yPVvYxERkYEtLy+PadOmAZCWlkZNTQ1z5szBZOruds2ZM4dPPvmkxzURERHU1taSlpZGZ2cndrv9mt0zEXdWUdvK5sIqtu07SUtb93T4CZEBTMuMZuLISPx93KdDJuJsWjDiEgaDgcVzhgHw7obSG/qNp4iIDGwJCQnk5+cDUFFRga+vL48//jgtLS0AbN++nZSUlB7XTJ06lXXr1gGwYcMGJk6c2L+hRfpY29kucvdU8O9/2cW/vb6Dz3aUYbPDbWPjePb741nx/fHcOjZOxZnIN6QO2hUkxwQxPi2CncU17CyuYcKISGdHEhERJ3rooYfIzs5m6dKldHV18dxzz9HY2Mjf/d3f4ePjQ2RkJD/60Y8AePLJJ3nttddYsGABW7du5eGHH8bT05Of//znTn4VIjfPZrdz8Hgjmwqr2H2wls4uGwYDZCSFMT0zmqxh4XiY9ft/kZuhAu0q7p+ZRN6hWj7YWMqYVAtmkz5sREQGKz8/P1555ZXL9i9YsO
Cyfa+99hqAY+0zkYGgrvkMWwpPsqWwirrm7hkXI0J8mJ4ZzZRR0YQEaHkikd6iAu0qIkJ8mT0mli93lbMhr4K547W4qIiIiAweHZ1W8g7VsrmwigPHGrEDXh4mpmVEMy0zmpS4IAwGg7Njigw4KtCu4c4pQ9lSWMWarceYmhHlVlPAioiIiHxTdrudYydPsbmgiq/3V3OmvQuAlLggpmVGMz4tAm9P/fgo0pf0L+waAnw9WTApgQ82HuGTr0/wwKxkZ0cSERER6XUtbR18ve8kmwqrqKg9DUCQvydzxiQwNSOaqFBfJycUGTxUoF3H3HHx/C2vgi92lTFnTCyhgd7Xv0hERETExVltNgqPNLC5oIr8kjqsNjsmo4Gxwy1Mz4wmPTEUk1Fj8EX6mwq06/D0MHHfjCReX3uAVV8d4fFFI50dSURERORbq6o/zeaCKrbuO0nz6Q4A4iz+TM+MZlJ6JAG+Wq9PxJlUoN2AyelRfLajjK37TjJ3fDxDIq++8reIiIiIqznT3sXO4ho2F1RRUtEMgK+XmTljYpmeGcOQSH9N+CHiIlSg3QCj0cCDc5L55Tv5vJdbyrKHbnF2JBEREZFrstvtHCprYnNBFTsP1tDRacMApCeGMj0zmtEp4XiYTc6OKSKXUIF2g0YlhpE+NISiow3sO1rPqMQwZ0cSERERuUxDy1n+ll/J59uOU9N0BoDwIG+mZUYzdVQ0YUEaTy/iylSgfQOLZw9j/5928t6GUkYODcWoWwFERETEBXR22dhzuJbNBVUUHW3ADniajUxOj2J6ZjSpQ4L1c4uIm1CB9g0MiQxg8qgotu47ybZ9J5maEe3sSCIiIjKIHXesWXaS02e71yxLjgnkjqmJpMUG4eutH/VE3I3+1X5D905PYseBGlZtOsL4tAg8PXTvtoiIiPSf1jOdfF10ks0FVZyoaQUg0M+T+ROHMC0jmphwPyyWAGprTzk5qYh8GyrQvqGwIG/mjovj0+0nWL+7nDsmJTg7koiIiAxwNpudomMNbCqoYu/hWrqs3WuWjU4JZ3pmDKOSQjGbtGaZyECgAu1bWDg5ga/yK/l423GmZ8Xg7+Ph7EgiIiIyAFU3tjnWLGs81Q5ATLgf0zKimTwqiiA/rVkmMtCoQPsWfL09uHNqIm+vP8yaLcd4+LYUZ0cSERGRAeJsRxe7imvZXFDJofLuNct8vEzMuiWGaZkxJEYHaM0ykQFMBdq3NHt0LF/uKuNveeXcOjaWiBBfZ0cSERERN2W32ympaGZTQRU7i2to77ACMCIhhGmZ0YxJteClce8ig4IKtG/Jw2zkgVnJ/G51ETlfHeGJu0c5O5KIiIi4mcZT7WzdV8XmwpNUN7QBEBboxbzx8UzNiMYS7OPkhCLS31Sg3YRxaREk7jjBjgM13D6+haSYQGdHEhERERfXZbWRX1LHpoIqCo/UY7eD2WRk0shIpmZGMyIhRGuWiQxxdaM7AAAgAElEQVRiKtBugtFg4MHZw/jF/+zh3Q0lLP/OaN0TLiIiIldUXtPKpoIqthWdpPVMJwCJ0QFMy4hmwshI/Lw16ZiIqEC7acOHhHDLsHD2ltSRX1LPLSnhzo4kIiIiLuL02U62769mU0EVx092r0vm7+PB7ePjmZYRTVyEv5MTioirUYHWC+6flUx+aR3v5ZaQkRyKyah1SERERAYrm93OgWONbCqoJO9QHV1WGwYDZCWHMS0zhqxhYVqzTESuSgVaL4gN92N6Zgxf5VeyqaCKWbfEOjuSiIiI9LOapjNsKahi674q6lu61yyLCvVlemb3mmXB/l5OTigi7kAFWi+5Z3oiX+8/yepNR5k0MhJvT721IiIiA117p5XdB2vYXFBF8YkmALw8TczIimZaRgzJsYEany4i34iqiF4S7O/F/AlD+GjLMT7fUcZd0xKdHUlERET6gN1u50hVC5sLqthxoJoz7d1rlg2PD2ZaZjTjhkfg5ak1ywY6m92G1Waly27FarNitV
vpslmx2rvO/Wm78PW5445zbF3d15279vLHsNJl6+p+jPPnnjvPdtH5XfYurLaLnuei8+wGOzabDQADhnN/nnfRV5f8AsHAJduGS6+9sGW4cFKPbcdjGC6+4sLRK13rYTJ23w7M5b/QuFrGqz1/z+fpmeXa78Wl116yfW7H7anTSffvuyW2VKD1onkThpC7p4JPt59g5i0xBOlWBhERkQGj8dRZ1m0/waaCSqrqu9csCwnw4taxcUzNiCYyxNfJCd2X3W7vLnguKlCuWsxcWsCc//r8/vPFziXb3cXMpUXQpUVS1xWKoIsKLKx0Wbuf12a3OfttA8BkMGEymjAbTD2+9vAwY7PasWPvPtHxh91xreMru73nNj237faLrzr/KPaLN7Hb7Ve8tsdzXiHD+W2jwYDNZr/smssy2S/ZvsLzXe/1XH78Cpnsl2c8L7UpUQWau/DxMnP39CTe/Owgq7cc43vzhjs7koiIiNyklrYO3lh3kL0lddhsdswmA+PTIpieGc3IoaEYja53C6PdbqfL1tWjELFd1uG5uAg6V8Cc79hcWsBcVCxZ7TZHsXPlYubcObauqxRBF29fONcVGDBcKHaM5woegwlPkwcmow/eHh7YrZefY3b8acZkMGIymh3HzQYTxovONxvMmIxGTAaz4zpTj+MmTOcep/vxTFc4r/vxjQbjVW+htVgCqK091c/v4LfnTnn7OqsKtF42PTOaL3aW8dXeSuaOiyM6zM/ZkUREROQm5Gw8Qt6hGobG+DNhZDijU8Pw8jJgtVmpO1v3zbsxPTpBVyhmrtD5sdl6FkXXux3OVbo7RoOxZ4Fx7msPk1d3oWIw4u3pic1q6D7nfKFzjWLmSsXKlYoiU4+iyYTJaDz3eFcqrrof32i49uya7lREiPtSgdbLzCYjD8xK5jc5hbyfW8qP7s90diQREblJp0+fZvny5TQ3N9PZ2clTTz2FxWLh+eefx2g0EhgYyMqVK/Hx8XFck5OTwyuvvMKQIUMAmDJlCk8++aSzXoJ8C22dZ9hyYg/bz27GZ3w91QY7a1pgzS7nZbpSN8ZsMOHl4dWjgPH29MTWxbmi51wBc77IuawIupGOzeXFjOkK+7uLoAuPd72CB1T0iFxKBVofGJ0STkpcEHsO13GorAmLJcDZkURE5CasWrWKxMREli1bRnV1NY8++ijh4eE888wzZGZm8otf/IKcnBy++93v9rhuwYIFLF++3Emp5ds423WWgrr95NXkc6D+EF12K8YgCDFbGBJuwdZl71mgnC9OrtD5MV6zY3Pl2+EuP37RbXLXuJ3tUip6RNyXCrQ+YDAYeHD2MF58czfvbihhyug4Z0cSEZGbEBISwsGDBwFoaWkhJCSE3/3ud/j7+wMQGhpKU1OTMyPKTeiwdrCvvpjd1Xspqi+m89x4qCifKMoPBxLcNZTn/m4OUZFBKnpEpM+pQOsjybFBjBtuYdfBWrYUVDI8JtDZkURE5FtauHAhOTk5zJ07l5aWFn7/+987irO2tjZWr17NK6+8ctl1O3bs4PHHH6erq4vly5czcuTI/o4uV9Fp7WR/w0F2V+dTWH+ADmsHAFG+EYyJzGJsRBafbWrgaGUl9ywaicl4/Vv1RER6gwq0PnT/rGT2HK7jjbUHeO6x8ZhN+nAXEXFHq1evJiYmhtdff53i4mKys7PJycmhra2NJ598kscee4zk5OQe12RlZREaGsqsWbPYs2cPy5cvZ82aNdd8npAQX8zmm18/y51ure/PrF3WLgqqi9latoudFfmc6TwLQKS/halDxjIlfhzxQTEYDAZqGtrYXFBErMWPhTOSMZ37Hq73tm8oa99wp6zgXnn7MqsKtD4UGeLLrNGxrN9dTu6eCm4bF+/sSCIi8i3k5eUxbdo0ANLS0qipqaGjo4N/+Id/YNGiRdx3332XXZOcnOwo2kaPHk1DQwNWqxWT6eoFWGNj201ndaexR/2R1WqzcrjpCLur89lbW0hb1xkAQr
yCmTpkImMjsogPiO0e29UJdXWtALyxrhirzc4dE4fQ0HC63/L2FmXtG8rad9wpb29kvVaBpwKtj905dSjbik7y0ZZjTBkVja+33nIREXeTkJBAfn4+8+bNo6KiAj8/P15//XUmTJjA4sWLr3jNH//4R6Kjo1m0aBGHDh0iNDT0msWZ9B6b3UZp01F21xSwp6aA1s7uAivIM4DZcdMYE5lFYuCQq064Udd8hs0FVUSG+DBxZGR/RhcRUYHW1wJ9PXlgTgpvfHKAT7cf5/6Zyde/SEREXMpDDz1EdnY2S5cupauri2effZann36auLg4tm3bBsDEiRP54Q9/yJNPPslrr73GnXfeydNPP83bb79NV1cXL774opNfxcBmt9s52nKCvOp88moKaO5oAcDfw4/psZMZG5FJcnDiDU37/sm241htdu6cOlRjz0Sk391Qgfazn/2M/Px8DAYD2dnZZGZeWNurqqqKf/zHf6Szs5ORI0fy/PPP91lYd3Xn9CTWbDrC5zvLmD06ltBAb2dHEhGRb8DPz++ySUA2b958xXNfe+01AKKionjzzTf7PNtgZrfbKTtVwe6afHZX59PY3j2Tpq/ZhynRExgbmUVKcBIm4413Luuaz7BJ3TMRcaLrFmg7duzg+PHjvPPOO5SWlpKdnc0777zjOP7zn/+cxx57jLlz5/Lcc89RWVlJTExMn4Z2N96eZu6ZnsifPilm1aYjPL5Qs3iJiIh8G3a7ncrTJ8mrzmd3TT61Z+oB8DZ5MzFqLGMiMkkLTcFs/HY3Cal7JiLOdt1Pr23btnHbbbcB3QOem5ubaW1txd/fH5vNxu7du/nlL38JwIoVK/o2rRubOiqaL3aWsbXwJLePH0J8hL+zI4mIiLiNk6dr2F2TT151PifbagDwNHowNiKLsZFZjAwdjofJ46ae43z3LELdMxFxousWaHV1daSnpzu2Q0NDqa2txd/fn4aGBvz8/HjppZcoKipi3LhxLFu2rE8Duyuj0cDi2cP4z3fzeS+3hH988BZnRxIREXFpdWfq2X2uU1bRWgWAh9HMLZZRjI28hVFhaXiaPHvt+RzdsynqnomI83zj/r/dbu/xdXV1Nd/73veIjY3l7//+78nNzWXWrFlXvX4wrvEC3Xlnh/vztz0V5B+uo6LxDLekRjg71hW503urrH3HnfIqa99wp6wycDSebXKMKTtxqhwAk8FERvgIxkRkkRk+Em9z74/lrm8+6+ieTUpX90xEnOe6BVpERAR1dXWO7ZqaGiwWCwAhISHExMQwZMgQACZPnszhw4evWaANtjVeoGfee6Ymkn+4jj+uKuSn3x+P8SpT/DqLO723ytp33CmvsvaN3sqqIk9uRHN7C3k1BeTV5HOk+TgARoOREaGpjI3IIsuSjq+Hb59mWPu1umci4hquW6BNnTqVV199lSVLllBUVERERAT+/t3jp8xmM/Hx8Rw7doyhQ4dSVFTEwoUL+zy0O0uICmByeiTbiqrZXlTN5FFRzo4kIiLS7051tLKnZA8bS7dT0nQUO3YMGEgNTmZsZBa3WDLw9/Trlyz1zWfZlF+p7pmIuITrFmhjxowhPT2dJUuWYDAYWLFiBTk5OQQEBDB37lyys7N55plnsNvtpKamMmfOnP7I7dbunZHEzuJacr4qZVyaBY9euOVTRETE1bV1trG3tojd1Xs51FSKzW4DIDloKGMisxhtySTIq/+7ruqeiYgruaExaP/0T//UYzstLc3xdUJCAm+99VbvphrgwoN8uG1cHOu2n+DL3eXcMTHB2ZFERET6xJmusxTUFpFXk8+BhsNY7VYAEgLjmZk0gVTf4YR4Bzstn6N7FqzumYi4hm+3SIjctIWTE9iUX8nHW48zPTMGf5+bmxpYRETEVbRbO9hXt5/dNQUU1RfTZesCIM4/hrERWYyJzCTcJ8wlxmV+8rXWPRMR16ICzUn8vD1YNGUo7/ythI+3HmPJrSnOjiQiIvKtdVo7KWo4SF
51PoV1++mwdQIQ5RfJuIgsxkRmEelrcXLKnhpazvKVumci4mJUoDnRnDFxrN9dzvrd5dw6Ng5LsI+zI4mIiNywLlsXBxoOsbu6gMK6Is5a2wGI8AlnTGQWYyOyiPF33cmw1m5T90xEXI8KNCfyMBu5b2YSf/hoPzlfHeH/vSv9+heJiIg4kdVm5VBjKbtr8smv3Udb1xkAQr1DmB47mbGRWcT5x2BwsWVkLqXumYi4KhVoTjZhRCSf7Shj+/5qbh8fT2J0oLMjiYiI9GCz2yhpOsrumnz21hTS2nkagGCvICZFj2NMRBZDA+Ndvii72Pnu2SLN3CgiLkYFmpMZDQYenD2M/3hrD+/+rYR//s5ot/oGJyIiA5PNbuNYywl2Veezt6aA5o7uyTwCPPyZETuFsZFZJAUlYDS4X3HT0HKWTQXd3bPJo9Q9ExHXogLNBYxICCEzOYyC0noKSuvJGhbu7EgiIjII2e12TpwqZ3d1Pnk1BTS2NwHgZ/ZlaswExkRkkRKchMno3ut3rv36OF1Wdc9ExDWpQHMRi2clU3iknvdySxmVFKpvGCIi0i/sdjsVrVXsrsknrzqfurMNAPiYvZkYNZaxkbeQFjLM7Yuy8xpaLqx7pu6ZiLgiFWguItbiz/TMaL7Kr2JL4UlmZMU4O5KIiAxgJ09Xs7s6n901+VS31QLgafJkXOQtjI3IYkTYcDyMA+/HBHXPRMTVDbxPXjd297Qkvi6qZtWmI0wcEYmX58D4baWIiLiGmrY68mry2V2dT+XpkwB4GM2MtmQwJjKLUWFpeJo8nZyy75zvnlmCvdU9ExGXpQLNhYQEeHH7hCF8vPUYn+08wV1TE50dSURE3Fz9mUbyavLJq8nnxKkKAMwGExnhIxkbkUVG+Ai8zd5OTtk/PlH3TETcgAo0F3PHxCFs3FvBp9tPMPOWWIL8Bu5vMkVEpG80tTezp6aQ3dV7OdpyAgCjwcjIsOGMjcgiMzwdXw8fJ6fsX+fXPbMEezM53XUXzxYRUYHmYny8zNw9LZH//vwQH205yiO3D3d2JBERcRMHGg7xm8KNFNeWYMeOAQPDQ4YxNiKLrIhR+Hv4OTui01zcPTOb1D0TEdelAs0FzciK4Ytd5WzcU8ltY+OIDhu831BFROTG/e3EJoobSkgKGsrYyCxGR2QQ6Bng7FhOp+6ZiLgTFWguyGwy8sDMZP5rVSEfbDzCD+/LcHYkERFxA4+P+i6BIV50nDI4O4pLcXTPJqt7JiKuT59SLmpMajjDYoPIO1TL4fImZ8cRERE34G32Jsg70NkxXErjqXa+yq8kPMibyaPUPRMR16cCzUUZDAYenD0MgHc3lGC3252cSERExP18sq27e3anxp6JiJvQJ5ULGxYXxNjhFkorWth9sNbZcURERNxK46l2NuZXqHsmIm5FY9Bc3P0zk9l7uI73N5ZyS0q4fvsnIuIEp0+fZvny5TQ3N9PZ2clTTz2FxWLh2WefBWD48OE899xzPa7p7OzkmWeeobKyEpPJxEsvvUR8fLwT0g9e6p6JiDvSp5WLiwr1ZeYtMdQ0nmHj3kpnxxERGZRWrVpFYmIib775Jq+88govvvgiL774ItnZ2bz99tu0traycePGHtd8/PHHBAYG8tZbb/HEE0+wcuVKJ6UfnNQ9ExF3pQLNDdw1NRFvTxOrNx/lTHuXs+OIiAw6ISEhNDV1T9jU0tJCcHAwFRUVZGZmAjB79my2bdvW45pt27Yxd+5cAKZMmUJeXl7/hh7ktO6ZiLgrfWK5gUA/T+6YlEDrmU4+3X7c2XFERAadhQsXUllZydy5c1m6dCn//M//TGDghdkSw8LCqK3tOVa4rq6O0NBQAIxGIwaDgY6Ojn7NPVg1nmpn497umRunqHsmIm5GY9DcxO3j49mQV87nO8qYPTqOkAAvZ0cSERk0Vq9eTUxMDK+//jrFxcU89dRTBARcWAD6RmbavZFzQkJ8MZtNN5UVwGJxn8Wp+yJrzu
ajdFltPDwvjeiooF597MH+3vYVZe0b7pQV3CtvX2ZVgeYmvDxM3Ds9iT99WsyqTUd4bMEIZ0cSERk08vLymDZtGgBpaWm0t7fT1XXhlvPq6moiIiJ6XBMREUFtbS1paWl0dnZit9vx9PS85vM0NrbddFaLJYDa2lM3/Tj9oS+yNp5qZ92244QHeZORENyrjz/Y39u+oqx9w52ygnvl7Y2s1yrwdIujG5maEU1suB9bCqsor2l1dhwRkUEjISGB/Px8ACoqKvDz8yM5OZldu3YB8PnnnzN9+vQe10ydOpV169YBsGHDBiZOnNi/oQep7rFnNo09ExG3pU8uN2I0Glg8Oxm7Hd7LLXV2HBGRQeOhhx6ioqKCpUuXsmzZMp599lmys7P55S9/yZIlSxgyZAhTpkwB4MknnwRgwYIF2Gw2Hn74Yf7617+ybNkyZ76EQUFjz0RkINAtjm4mIymMEQkhFB6pZ/+xBkYODXV2JBGRAc/Pz49XXnnlsv3/8z//c9m+1157DcCx9pn0n0/VPRORAUCfXm7GYOjuogG8t6EU2w0MOhcRERnoGk+1k6vumYgMACrQ3NDQqEAmjYzkePUptu+vdnYcERERp1P3TEQGCn2Cuan7ZiRhNhnI2XiEzi6rs+OIiIg4jbpnIjKQqEBzU+HBPtw6No76lrOs313h7DgiIiJO8+n27u7ZwskJ6p6JiNtzu0+xDmsnn5d8xZHm49jsNmfHcapFU4bi523m463HaD3T6ew4IiIi/a6ptXvmxrBAb6ZmRDs7jojITXO7WRwPNx3h/+S/BYC/hx8jw4YzKmwEI8NS8TH7ODld//Lz9mDh5KG8u6GEtduO8dCcFGdHEhER6VeffH2czi4bi6aoeyYiA4PbFWgjQ1N5Zvo/sKl0F/vqitlxMo8dJ/MwGowMC0okPTyNjLARRPhaMBgMzo7b524dG8v63eWs313OnDFxWIIHV5EqIiKDl7pnIjIQuV2BZjAYGBOTQbzHUOx2O+WtleyrO0Bh/QEONx3hUFMpq0rWYvEJY1TYCEaFj2BYcCJmo9u91BviYTZx/8wk/rBmP6u+OsLf35Xu7EgiIiL9Qt0zERmI3LpqMRgMxAfEEh8Qyx2Jt9HScYqi+oPsqzvAgYaDbCjfzIbyzXibvEgLTWVUWBrp4WkEegY4O3qvmjAyks92lPH1/mpunxDP0KhAZ0cSERHpUxe6Z17qnonIgOLWBdqlAj0DmBw9jsnR4+iydVHSdJR99QfYV3eAvbWF7K0tBCAhMJ5RYWmMCh9BvH+s298KaTQYeHB2Mv/x9l7e/VsJTz882u1fk4iIyLV8+vUJOrtsLNS6ZyIywAyoAu1iZqOZtNAU0kJTeCDlLqrbatlX112slTQf5XhLGWuPfkGQZyCjwtNIDxtBWmgKXiZPZ0f/VkYMDSUzOYyC0noKj9STmRzu7EgiIiJ9oqm1ndy9FYQFejFN3TMRGWAGbIF2qUhfC5FDLNw6ZAZnus5woOEw++oOUFRfzJbKHWyp3IHZaCY1OJlR4SMYFZZGmE+os2N/Iw/MSqbwSD3vbShlVGIYRqO6aCIiMvCoeyYiA9mgKdAu5mP2YUxEJmMiMrHZbRxrKaPo3EQj+xsOsr/hIO8C0X6RjolGEgOHYDKanB39muIs/kzNiGZzQRWbC6uYkRXj7EgiIiK9St0zERnoBmWBdjGjwUhSUAJJQQncmTyfxrNN7KsvZl/dAQ42HuaLE7l8cSIXX7MPI8OGkxE2ghFhw/Hz8HV29Cu6d3oSO/ZX8+GmI0wcEYmXp2sXlSIiIt/Euu3numeT1T0TkYFp0BdolwrxDmZ67CSmx06iw9rBocZSR8G2q3ovu6r3YsBAUtBQMsK7u2tRvhEuMylHSIAXt0+I5+Otx/l8Vxl3Thnq7EgiIiK9orm1nQ17KggN9GJaprpnIjIwqUC7Bk+TZ/d4tPAR2FPvofL0SQrrDlBUf4AjzccobT7Kh6WfEOYdcm7c2g
hSgpPwMHk4NfcdExPYuLeST78+zsysGAL93HPiExERkYt9eq57tkjdMxEZwFSg3SCDwUCsfzSx/tHMHzqHUx2t7K8/yL76AxxoOMTG8q1sLN+Kp9Gje8218DTSw9II9grq96w+XmbumprIX784xEdbjrL09uH9nkFERKQ3qXsmIoOFCrRvKcDTn4nRY5kYPRarzUpp87HuafzriymoK6KgrgiA+IBYJg7JIsknmfiAWIyG/vmN38xbYvhyVxkb91Zy27h4okJdc8yciIjIjVD3TEQGCxVovcBkNJEakkxqSDL3pSyipq2OonPj1g43HeH9ogqgu6hLD0sj49yaa95m7z7LZDYZuX9mMr/9cB8f5Jby1H0ZffZcIiIifam5tZ1cdc9EZJBQgdYHInzDifCdxuz4aZztOkultZwtR/Ioqivm66pdfF21C5PBREpwkmPsmsU3rNdzjB1uITk2kN2Haikpb2ZYXP/fbikiInKzPt1+gg7N3Cgig4QKtD7mbfZmYvRokryGYbPbKDtVQWHdAfbVH6C48TDFjYd5//BHRPpGMCosjVHhI0gOGtora64ZDAYenD2Ml/47j3c3lPAvS8e4zGyTIiIiN6JH90zrnonIIKACrR8ZDUYSAuNJCIxnUdLtNLU3n7sVspjihkOsL/uK9WVf4WP2ZmTocNLDuica8ff0+9bPmRIXzJhUC3mHask7VMvY4RG9+IpERET61sXdMw+zumciMvDdUIH2s5/9jPz8fAwGA9nZ2WRmZjqOzZkzh6ioKEym7o7Pyy+/TGRkZN+kHWCCvYKYGjORqTET6bR2crjpCPvqD7Cv7gC7a/LZXZOPAQOJQUNIDxtBRvgIYvyivnEX7IFZyew9XMf7uaVkDQvX7SEiIuIWmk93qHsmIoPOdQu0HTt2cPz4cd555x1KS0vJzs7mnXfe6XHOH//4R/z8vn2XR8DD5MHIsOGMDBvO4pS7qTpdTVF9MYV13WuuHWk+zpoj6wjxCj43bi2N1JBheN7AmmtRob7MHB3DhrwKvsqvZM6YuH54RSIiIjdn3fbj6p6JyKBz3QJt27Zt3HbbbQAkJyfT3NxMa2sr/v7+fR5usDIYDMT4RxHjH8XchFmc7mxzrLm2v/4gmyq2saliGx5GD4aHDGNUeBqjwkYQ4h181ce8a2oiW/edZPXmo0xOj8LHS3e3ioiI62o+3cGGvApCAtQ9E5HB5bo/pdfV1ZGenu7YDg0Npba2tkeBtmLFCioqKhg7dizLli3TRBS9zM/Dl/FRoxkfNRqrzcrRlhPn1ly78B+sItY/moywEaSHj2BoYHyPNdeC/DxZMHEIqzYd5dPtJ7hvRpLzXpCIiMh1nO+ePTQ5Qd0zERlUvnEbxW6399j+8Y9/zPTp0wkKCuKpp57is88+Y/78+Ve9PiTEF7P55mcotFgCbvox+lNv5o2KDGZySvc4wJrWOvKq9pFXWci+mkNUtFax7vjfCPDyZ3R0OmOiM7glaiS+nj58546RbMyv5POdZTxwWyphQT59nrWvKWvfcae8yto33ClrX3vvvff46KOPHNv5+flkZWU5tmtqarj33nt54oknHPteffVV1qxZ4xiXfdddd7F48eL+C+3GenTPMmOcHUdEpF9dt0CLiIigrq7OsV1TU4PFYnFs33PPPY6vZ8yYwaFDh65ZoDU2tn3brA4WSwC1tadu+nH6S1/mNeDF2OCxjA0ey9nUdg42lrCv7gBF9Qf46th2vjq2HaPByLCgREaFj2D2pDA++Lya1z8s5PsLRvRr1t6mrH3HnfIqa9/orawDpchbvHixo7jasWMHn376KStWrHAc/8EPfsDdd9992XXf+973WLp0ab/lHCjUPRORwey6BdrUqVN59dVXWbJkCUVFRURERDhubzx16hT/63/9L1577TU8PT3ZuXMn8+bN6/PQcmXeZi+yLOlkWdKx2+2UtVacuxWymENNpRxqKgXAb7QfX9eHk3K0k0kJIzEbNR
5NRORG/dd//Rcvv/yyY3vr1q0MHTqU6GiNk+oNLeqeicggd92fzMeMGUN6ejpLlizBYDCwYsUKcnJyCAgIYO7cucyYMYOHHnoILy8vRo4cec3umfQfg8HAkIA4hgTEsSBxLs3tp9hfX8y++gMU1R3EHHWct44eZ9UJL9JCUxkVPoL0sOFYGBi/7RYR6QsFBQVER0f3uJPkjTfeIDs7+4rnr1u3jvXr1+Pp6clPfvIT4uPj+yuq21p3bt2zB9U9E5FBymC/dFBZH+utW2bc5TYhcL28HdZOXlr1BZUdRwmLa6a5qwkAAwYi/cMJMAcQ5BVIkFcgwV5BBHl2/xnsFUigVyAeLtJxc7X39VrcKSu4V15l7Ru6xfHKfvrTn7Jw4UImTpwIQHV1NU8//TRvvPHGZecWFAvBiSMAACAASURBVBTQ3t7O+PHjWbt2LR999BG///3vr/n4XV3WXhmn7a6aTrXz+ItfEODrwR+zb+P/b+/Oo6Oq833vv2vMWBkq80wGZAphBgUZhaiIfRSFo7e7z3JpT9J2r9XX7qVN6/Icr0Pbj/roUVtP63P69sP1qC3SynFiFkUiCYOEBASSABkIIfNA5qTuHwkFYUiCpFJV5PNasmTv2rXry0/Zv/rm+xssI7gtRGTk8oxv2jKsrCYL/zLnRv7X33yxtgfy2IpR5FcfIr/qOypaK6loqsLB5fP2QEvABclb7+99ggjuTeQCLQFazVPEA7V3tdPQ3kRjeyON7U00tjf1HHf0PedjsbBq4k8ItGqPy/Pt2rWLxx57zHm8fft2rr/++ktem5GR4fz9okWL+gyLvJyRNk/7wlj/vrWA9o4ubp2VSt0QtMVQ8+a29WSK1TW8KVbwrniHItb+foCpBG2ESo4JYtb4KHYdrODEiW4Wj5/P4sT5RETYOFVRR0N7I3Vt9dS3NVDX1kBdWz11bQ3UtzdQ31ZPVUs1ZU3ll72/2WAiyKcneQv2CSbEel5FzudcUmc1WYfxTy1y7XE4HLR0tpxLujrO0HBeotV4XjLW0NFEe1f7gPcMsPgT7B/TZ6sO6amWBQQEYLWee24dOHCAhQsXXvL6p556iltuuYXp06eTnZ3N6NGjhytUr9Rwpp2te0sJtfkwV3PPRGQEU4I2gi2fl8Kew6dZt72IaddFOsf6m4wmQn1D+t34GqC1s9WZvNW3NfQkc+29iVzv+eMNJXQ7Tlz2Hn5mv56Ezdo3eTtbiQv2CSLIatMXRRlRurq7aOo4cy7B6mi6dNLVe02Xo6vf+xkNRmyWQKL8wgm0BhJktWGzBvb8svT+22ojyBpIoCUAk9HkVT/JHC6VlZXY7faLzoWFhfU5fuWVV3jyySdZsWIFTzzxBGazGYPBwFNPPTXcIXuVz7M190xEBJSgjWgRIX4smhrPxpwStu4t5eaZiVf0fl+zL9FmX6IDIi97Tbejm8b2M9S31fepwvVJ6trqOXWm4rL3MBqMBFltzuGUwb2JXEJTJKZ2H2dS52vy1bBK8VjtXR2cPlPNiYZTfYcWnj/UsKOJpvYmznQ09zvMGMBqtGCz2ki0xfUmXT1J1tmEK8h6LvHyM/vqhxxDID09nbfeeqvPuTfeeKPPcUREBE8++SQAY8aM4d133x22+LyZqmciIucoQRvhls0exY7ccj7eeZwbM2KIGPgtV8RoMBLsYyPYx0Yi8Ze9rr2r3Vl5q2+rp679XPJ29nxZ00lONJace1NR33tYTdY+QynPzYs7N18u2MembQVkSPQMLWy9zLDCs9WtRhrae5Ku1q62Ae/pb/bDZrURExB1Lumy2LBZA3qSr/MSMR8ND5ZryOfZxbR3dLNigapnIiL6pjrCBfpZuG12Eu9vK+STnSf45T/bB36TC1hNViL9w4n0D7/sNQ6HgzMdzT3Vt/YGuixtlFSfPpfU9SZ0p1uqLnsPAJslsM9QyvPnxJ1N5AIs/qrGjUDdju4+QwsvGlbYu5DG2aSrcxBDCwMtAYT52Qmy2g
i3hWDt9j03vLB3WKGtd2ihfnggI9H51bN5k7SXnIiIvg0Ii6fFs3VPKZv3lHD3kjF46s8uDQYDgdYAAq0BxBPbM0cm6OI5Mp3dndS3NVJ/wXy48/99urmS0qaTl/0ss9HcW3Hru0plT4XuXHXOarK48o8sQ6Cjq6M3sbpgWGHHueOm3mRsMEMLLUYzNquNOFtsb4Ur0Fndsl0w1NDf4tdnaKHmdYlcrG/1TMvqi4goQRMsZhPL56Xy5scH+X/f2cvd81JIivbevYvMRjNhfqGE+YVe9hqHw0FrV99FTurOq8SdTeaO1Z/o9wu7v9nvkkMpzyV1wdiGcZlyh8PhjPfsFocOHL3nzx6dfe28c47e63CAA3xaobG9qfe9zrtccE/6ucelPu/c1d3Oz+/5vPOv7vt55712mZiL2gyUVlZeclhhQ3sTrV2tA7abn9kPmzWAKP/I8+ZuBfYdVtg71NDH5KPqqsgQqW9qU/VMROQCStAEgFkTovg6r5z8omryi6oZmxhC5sxEMlLDMF6DX0YNBgN+Zj/8zH7EBERd9rqu7i4aO5ouqMKd9/v2Bmrb6jh55tRl72E0GAnyCcTRDc6E5YJkB2dy03umN0npvkSyc/b4/MRloKrPSGOgp9oa5hd63iqF5yVdlgDnSoaB1kCP2XxdZKT5xxcFqp6JiFxA30oEAKPBwMP/PJmSmhbe33SY/OO1fFdcR0yYP0tmJDB7QjRWy8jrPE1Gk3NuWhIJl72urau9zzy481eorGtroKW7mc6ubgz0JA89FRgDBgCDgfOOnNUZIwbna73vouef886d/7qB81/pc9+z97z0PcDQ82nO9/j4WGhv67z8PS75eWfv2/Na3z/fxXGe/55L36PPn/yCtjh3HB4cjLHd0mf5+ACLv1YtFPFwDc3tfPz1MUICraqeiYicRwmaOBkMBqaNjSIxzJ+S001szC7mm4MV/P+fH2bd9iIWTY1j4dR4ggO0etyFfExWIv0jiPS/9DqY3jb3yJvi9aZYReScDbuKaWvv4u75qaqeiYicRwmaXFJCZCAPLBvPXQtS2bKnlC/2lbH+6+N8+k0xN0yIInNGAnERge4OU0REvFBDcztb9pZiD/JV9UxE5AJK0KRfIYE+3DU/lWU3jOLrvHI25pTwVW45X+WWk55i5+aZiYxPCtWiCSIiMmgbzq7cuGy0qmciIhdQgiaD4mM1sWhqPAsmx7G/oIoN2cXkFdWQV1RDfEQgN89MYOa4KG0wKiIi/WpobmfrnjJCAq1kzkqivq7Z3SGJiHgUJWhyRYxGA1Oui2DKdREcK29gQ3Yxu7+r5P/75BBrtxdy09R4FkyJI9BP+4OJiMjFNmQX09bRxV3zU0bk4lMiIgNRgibfW3JMEL/4p3SqFrSwZU8p2789ybovi/g46zg3ToxhyYwEokL93R2miIh4iLPVs+BAK/Mnx7o7HBERj6QETa5aeLAf/7xoND+Yk8xX+0+yaXcJW/eWsW1vGZNHh3PzzERGxwdrnpqIyAh3fvVMc89ERC5NCZoMGT8fM5kzE7lpejx7DleyIbuYfUer2He0iuQYG5kzEpk+NgKTUfPURERGmkZVz0REBkUJmgw5k9HIzHFRzBgbydHSejbmlLDvSCX/sT6ftV/4sHh6AnMzYvH31f9+IiIjxYbsElXPREQGQd+QxWUMBgPXJYRwXUIIFbXNbMopYceBct7bWsBHO44xb1Isi6fHEx7s5+5QRUTEhRqb29myp1TVMxGRQVCCJsMiKtSfH2WO4Y65KWz/tozNe0rZmFPC5t2lTB8bQeaMRFJig9wdpoiIuMDZ6tlyVc9ERAakBE2GVaCfhdtuGMXNMxPZdbCCDdklZB86Tfah04yOD+bmmYlMTgvHaNSCIiIi14I+1bNJqp6JiAxECZq4hdlkZM7EGGanR3PoRC0bsks4UFTN0dIDRIb6sWR6AjdOjMHHqp+0ioh4s40556pn2vdMRGRgStDErQwGA+NH2R
k/yk5Z1Rk25RSzM6+Ctzcd4cOvilgwJY5FU+MJtfm4O1QREblCjc3tbFb1TETkiihBE48RFx7AfbeO4855qWzbW8rWvWV8knWCz3cVM2t8FJkzEkiMsrk7TBERGaSNOSW0tXexfK6qZyIig6UETTxOcICVO+amsPT6JLLyT7Exp4SdeafYmXeK8aNCyZyRyMQUu7vDFBGRfjirZwFauVFE5EooQROPZbWYmD85jrmTYjlQWM3GnBIOHq/l4PFaYsMDWL5wNBOTgrUimIiIB1L1TETk+1GCJh7PaDAwKS2cSWnhnDjVyMacErIPVfDq+99i87ewaGo8C6fGEeRvdXeoIiICNLV0qHomIvI9KUETr5IUbeOnt4/n7gWp7DxYwWc7j/PRjmN8+s0JZqdHkzkjgZiwAHeHKSIyom3ILlb1TETke1KCJl4p1ObDfcsmcNOUWHbklrNpdwnbvz3J9m9PkpEaxs0zEhibFIrBoP3URESGk6pnIiJXRwmaeDVfq5nF0xNYNDWefUcr2ZBdQm5hNbmF1SRGBnLzzERmjIvEbDK6O1QR8WLvv/8+69evdx7n5eWRnp5Oc3Mz/v7+ADzyyCOkp6c7r+no6ODRRx/l5MmTmEwmnn32WRISEoY99uF2tnp2p6pnIiLfixI0uSYYjQamjYlk2phICsvq2ZBTwp7Dp3nz44Os3V7ITdPimT85lgBfi7tDFREvtGLFClasWAFAdnY2n332GQUFBTz77LNcd911l3zPxx9/TFBQEC+88AI7duzghRde4KWXXhrOsIfd+dWzBaqeiYh8LyoryDUnNS6YVXek89zPb2DJ9ASa2zpZ+0Uhv31tJ29vOsLpuhZ3hygiXuy1115j1apVA16XlZXFkiVLAJg9ezZ79+51dWhutzGnp3p26/VJqp6JiHxPqqDJNSs8xI97F4/mn24cxZf7e+apbdlTyta9pUwdHcHNMxNJiw92d5gi4kVyc3OJiYkhIiICgH//93+ntraW1NRUVq9eja+vr/Paqqoq7PaePRuNRiMGg4H29nas1mtzxdmmlg4271b1TETkailBk2uev6+FW2Ylsnh6PLsPn2ZDdgl7jlSy50glqbFBZM5MZOp14ZiMKiiLSP/Wrl3LnXfeCcC//Mu/MGbMGBITE3niiSd4++23eeCBBy77XofDMeD9Q0P9MQ/B3o4REbarvseV+vyzQ7S2d/HDW8YRFxsy6Pe5I9ar4U3xKlbXUKyu403xujJWJWgyYphNRq4fH82scVEcKaljQ3YJ+wuqeP3DPMKDfVk8PYG5GTH4+eivhYhc2q5du3jssccAnMMXARYtWsSnn37a59rIyEgqKysZO3YsHR0dOByOAatntbXNVx1jRISNysrGq77PlWhq6WD9l4UEBViZPjps0J/vjlivhjfFq1hdQ7G6jjfFOxSx9pfgqWQgI47BYGBMYii/vjuDp392PQunxNFwpp13txzlt3/+mr9vLaCmodXdYYqIh6moqCAgIACr1YrD4eC+++6joaEB6EncRo8e3ef6OXPm8PnnnwOwbds2Zs2aNewxD5eNOcW0tnexdFYiPpp7JiJyVVQqkBEt2u7Pj28ewx1zk/liXxlb9pbxeXYxm3aXMGNsJJkzExgVHeTuMEXEA1RWVjrnlBkMBlauXMl9992Hn58fUVFR/OpXvwLgwQcf5PXXX2fp0qXs3LmTe++9F6vVyh//+Ed3hu8yZ+eeBQVYmT8lzt3hiIh4PSVoIoDN38rtc5K5ZVYS3xw8xcacEr45WME3BysYkxBC5swEJqWFY9TG1yIjVnp6Om+99ZbzeOnSpSxduvSi615//XUA595n17qNOSW0tndxx43Jqp6JiAwBJWgi57GYjczNiOXGiTHkH69hQ3YJ+cdqOFxSR5Tdn8wZCcxOj9aXEBERzlbPSlQ9ExEZQkrQRC7BYDCQnhxGenIYpZVNbMwu4ZuDp1iz4TD/+LKIBVPiuGlqHMGBPu4OVUTEbc
5Wz/5J1TMRkSGjBE1kAPERgdx/2zjump/Clr1lfLGvjI93HufzXSe4fnw0mTMTiI8IdHeYIiLDylk987ewQNUzEZEhowRNZJCCA31YPi+F225IYmdezzy1HQfK2XGgnAnJdm6emcCEUXYMmqcmIiPAJlXPRERcQgmayBXysZhYOCWO+ZNjyS2oZkN2MfnHasg/VkNcRACZMxK4fnw0FrN2sRCRa1NTSweb96h6JiLiCkrQRL4no8HA5NHhTB4dzvFTDWzMLiH70Gn++ul3fLC9iJumxrFwajyBfhZ3hyoiMqQ25ZTQ0tbFDxapeiYiMtSUoIkMgVHRQfzsBxO4e0Eqm/eUsv3bk/zjq2N8knWC2RNjuOfmsVjdHaSIyBBQ9UxExLWUoIkMIXuQLysXpnH77FF8lVvOppwSvtjXs7BIYmQgGWnhTE4LZ1SMTXuqiYhXOls9u32hqmciIq6gBE3EBfx8zGTOSOCmaXHsO1JF1sEKDhRWUXy6iY93HifI30JGajiT0sIYP8qOn4/+KoqI5zvTeq56tlDVMxERl9C3QhEXMhmNTB8bya1zUykureXg8Vr2F1SRW1jlXAHSZDQwNjGEjLRwJqWFExni5+6wRUQuqU/1zKrqmYiIKwwqQXvmmWfYv38/BoOB1atXk5GRcdE1L7zwAt9++y1r1qwZ8iBFrgV+PmamjYlg2pgIuh0Ojpc3sr+giv2FVeQfryX/eC3vbD5KbHgAk1LDmJQWTmpcECajVoMUEfc709rBpt2qnomIuNqACVp2djYnTpzgvffeo7CwkNWrV/Pee+/1uaagoICcnBwsFq1WJzIYRoOBlNggUmKDuHNeCjUNreQWVbP/aBUHT9Ty2a5iPttVTICvmYkpPclaeoqdAF/9HRMR91D1TERkeAyYoGVlZbF48WIAUlNTqa+vp6mpicDAQOc1f/zjH/nNb37Dq6++6rpIRa5h9iBfFkyOY8HkONo6uvjuRC37C6vZX1DFNwcr+OZgBUaDgdHxwUxK65m7Fm3316bYIjIszlbPbKqeiYi43IAJWlVVFRMmTHAe2+12KisrnQnaunXrmDlzJnFxemCLDAUfi6k3CQvHkXkdJaebeodCVnOkpI7DJXX8fVsBkSF+zmTtuoQQzCYNhRQR1zhbPVup6pmIiMtd8SIhDofD+fu6ujrWrVvHX//6VyoqKgb1/tBQf8zmq3+4R0TYrvoew8mb4lWsrvF9Y42MDGJaeiwAtY2t7Dl0mpxDp9h3+DSbdpewaXcJfj5mpo6JZMb4KKaNjSLE5uO2eN1BsbqGN8UqrtPc2sGm3aWqnomIDJMBE7TIyEiqqqqcx6dPnyYiIgKAb775hpqaGn74wx/S3t5OcXExzzzzDKtXr77s/Wprm6866IgIG5WVjVd9n+HiTfEqVtcYylgnJYcyKTmUjswxHCmpY39BFd8WVPF17km+zj2JAUiJDXJW4eIjAq54KORIbVtXG4mxKsnzfhtzSmhp62TlwjRVz0REhsGACdqcOXN45ZVXuOeee8jPzycyMtI5vPGWW27hlltuAaC0tJTf//73/SZnIjJ0LGYjE5LtTEi2c+/i0ZRXN7O/sIr9BdUcLa2j8GQD674swh7kw6TePdfGJoZi1cayIjJIqp6JiAy/ARO0qVOnMmHCBO655x4MBgNPPPEE69atw2azsWTJkuGIUUQGYDAYiA0PIDY8gFtnJdHU0kFeUTX7C6s5UFjNtn1lbNtXhtViZHySnUlpYWSkhhM6BEMhReTadbZ6tmJhqqpnIiLDZFBz0H7729/2OR47duxF18THx2sPNBEPEehn4foJ0Vw/IZqu7m4KSuudq0J+2/sLDpMUZWNSWs8y/knRNoxaFVJEep2tngX6WVg0Jd7d4YiIjBhXvEiIiHgXk9HImMRQxiSGsnJhGhW1zeQWVLO/sIrDxXWcqGhk/dfHCQ6wktG7Qfa8ID93hy0ibrZpd6mqZyIibqAETWSEiQr1Z8kMf5
bMSKClrZP8YzXsL6wit7Car3LL+Sq3nDc+ymdsYkjPQiOpYYSHKGETGUmaWzvYmFOi6pmIiBsoQRMZwfx8zEwfG8n0sZF0OxwcO9nA/sIq8o/XkneshrxjNby9CeIiApwLjaTGBmM0aiikyLVM1TMREfdRgiYiABgNBlLjgkmNC+bnd9k4XFjpnLd26EQtn35zgk+/OUGgn4WJKXYmpYWTnmzH39fi7tBFZAipeiYi4l5K0ETkkuxBviycEsfCKXG0dXRx6Hht7zL+VWTlV5CVX4HJaGB0fLBzz7Vou7+7wxaRq7T5bPVsgapnIiLuoARNRAbkYzExeXQ4k0eH43A4KK5ocu659l1xHd8V1/He1gKiQv2cydro+GDMJqO7QxeRK3B+9WzhVO17JiLiDkrQROSKGAwGkqJtJEXb+MGcZOqb2sgt7NlzLf9YDRtzStiYU4Kfj4n05DAmpYUxMSUMm7/V3aGLyAA27y6lubd65mvVVwQREXfQ01dErkpwoA9zJ8Uyd1IsHZ3dHC6pZX9Bz9y1nO9Ok/PdaQwGSI0LZlLvMv5x4QEYtOeaiEdR9UxExDMoQRORIWMxG0lPDiM9OYz/sXg0J6vOOBcaKSirp6C0ng+2FxEW5OvcIHtsYggWs+a5iGd7//33Wb9+vfM4Ly+Pd955hyeffBKj0UhQUBAvvPACfn7ntqRYt24dL7/8MomJiQDMnj2bBx98cNhjH6yz1bO7VT0TEXErPYFFxCUMBgNxEYHERQSy9Pokmlo6OFDUk6wdKKph694ytu4tw8diYvyoUCalhZORGkZIoI+7Qxe5yIoVK1ixYgUA2dnZfPbZZzz11FM8+uijZGRk8Nxzz7Fu3Tp++MMf9nnf0qVLeeSRR9wR8hVpbu08t3KjqmciIm6lBE1EhkWgn4UbJkRzw4RoOru6KSitdy40su9oFfuOVgEwKtrWu9BIGElRNg2FFI/z2muv8fzzz+Pn50dgYCAAdruduro6N0f2/W3eU6LqmYiIh9BTWESGndlkZGxSKGOTQvnnRaOpqGl2DoU8UlLH8VONfLTjGCGBVjJ6N8gen2TXkt/idrm5ucTExBAREeE819zczEcffcTLL7980fXZ2dk88MADdHZ28sgjjzB+/Ph+7x8a6o95CIb8RkTYBn3tmZYONu0uxeZvZWXmWPx8hverwZXE6gm8KV7F6hqK1XW8KV5XxqoETUTcLsruT6bdn8wZCTS3dpJ/vIb9BVXkFlbz5f6TfLn/JGaTkXFJoT1z11LDCQv2dXfYMgKtXbuWO++803nc3NzMgw8+yP33309qamqfaydNmoTdbmfBggXs27ePRx55hP/+7//u9/61tc1XHWNEhI3KysZBX7/+62Ocaeng7gWpNDW00HTVEQzelcbqbt4Ur2J1DcXqOt4U71DE2l+CpwRNRDyKv6+ZGWMjmTE2ku5uB0XlDewvqOqdu1bNgaJq/g9HiI8IZFJaGPOnJWL3N2M0aiikuN6uXbt47LHHAOjs7GTVqlUsW7aM5cuXX3RtamqqM2mbMmUKNTU1dHV1YTJ5TiW4ubWTjdmaeyYi4kmUoImIxzIaDaTFBZMWF8xd81Opqm/p2XOtoJpDJ2opzWrik6wTBPiaSU8JY2KKnfSUMIK055q4QEVFBQEBAVitPf9/vfnmm8ycOdO5eMiF3nzzTWJiYli2bBlHjhzBbrd7VHIGsKV37tld81M090xExEPoaSwiXiM82I9FU+NZNDWetvYuDh6v4Wh5I7vyytl1sIJdByswAKNigshIDSMjNYykaBtGLTQiQ6CyshK73e48fvvtt4mPjycrKwuAWbNm8dBDD/Hggw/y+uuvc/vtt/O73/2Od999l87OTp5++ml3hX5JfVdujHd3OCIi0ksJmoh4JR+riSnXRZA5J4UV85IpqzrDgcJqcgurOVpaz7HyBj7acQybv4WJKWFMTAljQrKdQD+Lu0MXL5Wens5bb73lPN6xY8clr3v99d
cBiI6OZs2aNcMS2/exZU8JZ1p7qmfDvTCIiIhcnp7IIuL1DAYD8RGBxEcEcuv1STS3dnLweA25RdUcKKxmZ94pduadwmCA1LhgMlJ6qmsJkYFaxl9GJFXPREQ8lxI0Ebnm+PuamT42kuljI+l2OCipaHIma4Vl9RSU1rPuyyKCA61MTAkjIyWM8aPs+PvqkSgjg6pnIiIX++KLLSxYcNOA1z399NMsW3YXsbGuWVxJT2URuaYZDQaSom0kRdu4ffYomlo6yD9WQ25hz4qQO3LL2ZFbjsloYHR8MBNTe4ZDxoUHqLom16SWtp7qWYCvWdUzEZFe5eUn2bx5w6AStD/84Q8u3RJACZqIjCiBfhZmjY9i1vgouh0Ojpc3kltYxYGiGg4X1/FdcR3vbyvEHuRDRkoYE1PDGJcUqhXu5JqxeU+pqmciIhd48cXnOHQon7lzZ5CZeSvl5Sd56aU/8+yzT1JZeZqWlhbuv/9nzJkzlx//+Mc89ND/ZNu2LZw500Rx8QnKykr59a8f5oYb5lx1LHoyi8iIZTQYSIkNIiU2iDvmptBwpp28Yz0LjeQfq+GLb0/yxbcnMZsMXJcQ4kzYou3+qq6JV2pp62RjdrGqZyLi0f6+tYCc704P6T1njI1k5aK0y75+770/Zt26v5OcnEpx8XH+/Oe3qK2tYebM67n11mWUlZXy+OOPMmfO3D7vO326guef/3e++WYnH330gRI0EZGhFBRgZXZ6DLPTY+jq7ubYyUZyi6rILazm4PFaDh6v5d2tBUSE+JKREs7EVDtjEkPxsXjW3lYil6PqmYjIwMaNmwCAzRbEoUP5rF+/DoPBSEND/UXXZmRMBiAyMpKmpqYh+Xw9nUVELsFkNJIWH0xafDDL56VS29hGXlHPvLX84zVs2VvKlr2lWMxGxiaGkpHaU12LDPFzd+gil6TqmYh4i5WL0vqtdrmaxdKzJc+mTZ/T0NDAa6+9RUNDAz/5yY8vutZkOvdDWofDMSSfrwRNRGQQQm0+zJ0Uy9xJsXR2dVNYVk9uYXXP6pC9v9gE0Xb/nmQtJYzrEkKwmI3uDl0EgC291bPl81Q9ExG5kNFopKurq8+5uro6YmJiMRqNbN++lY6OjmGJRU9oEZErZDYZGZMYypjEUFYsTKO6vtWZpB08XsvGnBI25pTgYzExLinUmbBFRNjcHbqMUC1tnWzorZ7dNE3VMxGRCyUlJXP48HfExMQSEhICwIIFi3j00f/JwYN53HbbD4iMjOSvf33T5bEoQRMRuUphwb4smBLHgilxdHR2c6S0jgOFPYuNfFtQxbcFVQAkRtsYnxRKRkoYafHBmE2qrsnwUPVMRKR/AP8HOwAADTRJREFUoaGhrFv3SZ9zMTGx/O1v7zqPMzNvBSAiwkZlZSMpKeeGYaakpPHqq38Zklj0lBYRGUIWs5EJo+xMGGXnnptGc7quhQO9e659d6KW4lONfL6rGD8fE+NH2clICSM9JYxQm4+7Q5drlKpnIiLeRQmaiIgLRYb4cdO0eG6aFk9QiD879pRwoKia3MIq9hyuZM/hSgASIwOZmBpGRmoYKbFBmIyqrsnQUPVMRMS76EktIjJMfCwmMnqTsP+xeDQVtS3kFlZzoLCKwyV1FJ9u4pOsEwT4mpmQbGdiSs/ctaAAq7tDFy+l6pmIiPdRgiYi4gYGg4Fouz/Rdn8yZyTQ2t7JdyfqelaFLKwi+9Bpsg+dxgCMirH1JGupYSRHB2E0apNsGZyte3uqZ3eqeiYi4jX0tBYR8QC+VjOTR4czeXQ4Dsd1nKw605usVXO0tJ5j5Y2s//o4gX4WJqbYmZgaRnpyGIF+FneHLh6qpa2Tz3f1VM8Wq3omIuI1lKCJiHgYg8FAXEQgcRGB3DoriZa2Tg4er3Huu5aVX0FWfgUGA6TGBvfMXUsJIyEqEKNB1TXpoeqZiIh30hNbRMTD+fmYmTYmkmljInE4HJScbupdaKSagrJ6Csrq+ceXRQQHWJ
mY0jPHbfwoO/6+esSPVKqeiYi4xt13386nn34y8IVXQb23iIgXMRgMJEbZSIyycdsNozjT2kH+sZ7qWl5RNTsOlLPjQDlGg4G0+OCeRUlSwoiLCMCg6tqI4ayezU1W9UxExMvoqS0i4sUCfC3MHBfFzHFRdDscnDjV2LNJdlE1R0vqOFJSx9ovCgm1+TiTtXGjQvG16vF/rWpu7WBDdknvyo0J7g5HRMQr3H//D3nmmReIjo7m1Klyfv/7h4mIiKSlpYXW1lZ+85vfMX58+rDEoh5aROQaYTQYSI4JIjkmiB/cmExDczv5RTUcKOrZKHv7tyfZ/u1JTEYD1yWEOJf8j7b7q7p2Dfnk62M0tXRw59xkDXMVEa+0ruBj9p0+MKT3nBI5keVpyy77+rx5C/n66y+5666VfPXVdubNW0hq6mjmzVvAnj05vP3233j66f9nSGO6HD25RUSuUUH+Vm5Ij+aG9Gi6ux0UlTf07rtWzaETtRw6Uct7WwsID/YlI7Vnz7WxSaH4WEzuDl2+p5a2Tv7xRaGqZyIiV2jevIW8+upL3HXXSnbs2M5DD/2Gd99dwzvvrKGjowNfX99hi0UJmojICGA0GkiLCyYtLpjl81Kob2rjQFENuUXV5B+rYeveMrbuLcNsMjI2KYSF0xOZlByqVSG9zNa9pTQ2t6t6JiJebXnasn6rXa6QkpJKdXUlFRWnaGxs5KuvviA8PJLHH/9ffPfdQV599aVhi0VPbxGRESg40IcbM2K4MSOGzq5uCsvqexK2wmryimrIK6rh2Z9fT1Sov7tDlSuQf6wGm79F1TMRke/hhhtu5C9/+TNz586nrq6W1NTRAGzfvo3Ozs5hi0MJmojICGc2GRmTGMqYxFDuXpBKTUMrmE3Y/bUJtrf5ybLxBIf4Y+rudncoIiJeZ/78hfziF/fzv//3O7S2tvDUU0+wbdtm7rprJZs3b+STT9YPSxxK0EREpA97kC8RETYqKxvdHYpcIXuQLxFhAfpvJyLyPYwbN4Ht23c5j99+e63z9zfeOB+A2277AQEBATQ3u+45a3TZnUVEREREROSKqIImIiIygPfff5/1688NbcnLy+Odd97hX//1XwEYM2YM//Zv/9bnPR0dHTz66KOcPHkSk8nEs88+S0KC5oaJiEj/VEETEREZwIoVK1izZg1r1qzhV7/6FXfccQdPP/00q1ev5t1336WpqYnt27f3ec/HH39MUFAQ77zzDr/4xS944YUX3BS9iIh4EyVoIiIiV+C1117jpz/9KWVlZWRkZACwcOFCsrKy+lyXlZXFkiVLAJg9ezZ79+4d9lhFRMT7aIijiIjIIOXm5hITE4PJZCIoKMh5PiwsjMrKyj7XVlVVYbfbATAajRgMBtrb27FarZe9f2ioP2bz1W8UHhFhu+p7DBdvihW8K17F6hqK1XW8KV5XxqoETUREZJDWrl3LnXfeedF5h8Mx4HsHc01tbfP3iut83rQCpzfFCt4Vr2J1DcXqOt4U71DE2l+CpyGOIiIig7Rr1y6mTJmC3W6nrq7Oeb6iooLIyMg+10ZGRjqrah0dHTgcjn6rZyIiIqAETUREZFAqKioICAjAarVisVhISUlh9+7dAGzcuJG5c+f2uX7OnDl8/vnnAGzbto1Zs2YNe8wiIuJ9lKCJiIgMQmVlpXNOGcDq1at58cUXueeee0hMTGT27NkAPPjggwAsXbqU7u5u7r33Xt5++20efvhht8QtIiLeRXPQREREBiE9PZ233nrLeZyWlsZ//dd/XXTd66+/DuDc+0xERORKGByDmbUsIiIiIiIiLqchjiIiIiIiIh5CCZqIiIiIiIiHUIImIiIiIiLiIZSgiYiIiIiIeAglaCIiIiIiIh5CCZqIiIiIiIiH8Ph90J555hn279+PwWBg9erVZGRkOF/buXMnL774IiaTiXnz5vHLX/7SjZH2H+uiRYuIjo7GZDIB8PzzzxMVFeWuUAE4cuQIq1at4r777u
NHP/pRn9c8rW37i9XT2vZPf/oTe/bsobOzk5///OdkZmY6X/O0du0vVk9q15aWFh599FGqq6tpa2tj1apVLFy40Pm6J7XrQLF6Uruer7W1lWXLlrFq1SqWL1/uPO9JbSt9eVP/CN7VR6p/dB31kUNPfaRruaV/dHiwXbt2OX72s585HA6Ho6CgwLFy5co+r996662OkydPOrq6uhz33nuv4+jRo+4I0+FwDBzrwoULHU1NTe4I7ZLOnDnj+NGPfuR47LHHHGvWrLnodU9q24Fi9aS2zcrKcvzkJz9xOBwOR01NjWP+/Pl9Xvekdh0oVk9q108++cTxl7/8xeFwOBylpaWOzMzMPq97UrsOFKsntev5XnzxRcfy5csdH3zwQZ/zntS2co439Y8Oh3f1keofXUd9pGuoj3Qtd/SPHj3EMSsri8WLFwOQmppKfX09TU1NAJSUlBAcHExMTAxGo5H58+eTlZXlkbF6IqvVyptvvklkZORFr3la2/YXq6eZMWMGL7/8MgBBQUG0tLTQ1dUFeF679herp1m6dCk//elPASgvL+/z0zRPa9f+YvVUhYWFFBQUsGDBgj7nPa1t5Rxv6h/Bu/pI9Y+uoz7SNdRHuo67+kePHuJYVVXFhAkTnMd2u53KykoCAwOprKzEbrf3ea2kpMQdYQL9x3rWE088QVlZGdOmTePhhx/GYDC4I1QAzGYzZvOl//N7Wtv2F+tZntK2JpMJf39/ANauXcu8efOcZXpPa9f+Yj3LU9r1rHvuuYdTp07xxhtvOM95WruedalYz/K0dn3uued4/PHH+fDDD/uc99S2Fe/qH8G7+kj1j66jPtK11EcOPXf1jx6doF3I4XC4O4RBuzDWX//618ydO5fg4GB++ctfsmHDBm655RY3RXdt8cS23bx5M2vXruU///M/3RrHYFwuVk9s13fffZdDhw7xu9/9jvXr17u9M+zP5WL1tHb98MMPmTx5MgkJCW6LQa6eN/WPoD5yuHhqu6qPdA31kUPLnf2jRw9xjIyMpKqqynl8+vRpIiIiLvlaRUWFW0v8/cUKcMcddxAWFobZbGbevHkcOXLEHWEOiqe17UA8rW2/+uor3njjDd58801sNpvzvCe26+ViBc9q17y8PMrLywEYN24cXV1d1NTUAJ7Xrv3FCp7VrgBffPEFW7ZsYeXKlbz//vv8+c9/ZufOnYDnta2c4039I1w7faQntm1/PLFd1UcOPfWRruHO/tGjE7Q5c+awYcMGAPLz84mMjHQOh4iPj6epqYnS0lI6OzvZtm0bc+bM8chYGxsbeeCBB2hvbwcgJyeH0aNHuy3WgXha2/bH09q2sbGRP/3pT/zHf/wHISEhfV7ztHbtL1ZPa9fdu3c7f3pZVVVFc3MzoaGhgOe1a3+xelq7Arz00kt88MEH/P3vf2fFihWsWrWK2bNnA57XtnKON/WPcO30kZ7Ytpfjie2qPtI11Ee6hjv7R4PDw8dFPP/88+zevRuDwcATTzzBwYMHsdlsLFmyhJycHJ5//nkAMjMzeeCBBzw21r/97W98+OGH+Pj4MH78eB5//HG3lp7z8vJ47rnnKCsrw2w2ExUVxaJFi4iPj/e4th0oVk9q2/fee49XXnmF5ORk57lZs2YxZswYj2vXgWL1pHZtbW3lD3/4A+Xl5bS2tvLQQw9RV1fnkc+CgWL1pHa90CuvvEJcXByAR7at9OVN/SN4Tx+p/tF11Ee6hvpI1xvu/tHjEzQREREREZGRwqOHOIqIiIiIiIwkStBEREREREQ8hBI0ERERERERD6EETURERERExEMoQRMREREREfEQStBEREREREQ8hBI0ERERERERD6EETURERERExEP8X8DXTIaExxBpAAAAAElFTkSuQmCC", "text/plain": [ "" ] }, "metadata": { "tags": [] }, "output_type": "display_data" } ], "source": [ "# 
Plot performance\n", "trainer.plot_performance()" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 51 }, "colab_type": "code", "id": "4EmFhiX-FMaV", "outputId": "c689c3e6-972b-4499-81b6-8812a25076d1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test loss: 0.49\n", "Test Accuracy: 82.9%\n" ] } ], "source": [ "# Test performance\n", "trainer.run_test_loop()\n", "print(\"Test loss: {0:.2f}\".format(trainer.train_state['test_loss']))\n", "print(\"Test Accuracy: {0:.1f}%\".format(trainer.train_state['test_acc']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "zVU1zakYFMVF" }, "outputs": [], "source": [ "# Save all results\n", "trainer.save_train_state()" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qLoKfjSpFw7t" }, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "ANrPcS7Hp_CP" }, "outputs": [], "source": [ "class Inference(object):\n", " def __init__(self, model, vectorizer):\n", " self.model = model\n", " self.vectorizer = vectorizer\n", " \n", " def predict_category(self, title):\n", " # Vectorize\n", " vectorized_title, title_length = self.vectorizer.vectorize(title)\n", " vectorized_title = torch.tensor(vectorized_title).unsqueeze(0)\n", " title_length = torch.tensor([title_length]).long()\n", " \n", " # Forward pass\n", " self.model.eval()\n", " y_pred = self.model(x_in=vectorized_title, x_lengths=title_length, \n", " apply_softmax=True)\n", "\n", " # Top category\n", " y_prob, indices = y_pred.max(dim=1)\n", " index = indices.item()\n", "\n", " # Predicted category (use the instance's vectorizer, not a global)\n", " category = self.vectorizer.category_vocab.lookup_index(index)\n", " probability = y_prob.item()\n", " return {'category': category, 'probability': probability}\n", " \n", " def predict_top_k(self, title, k):\n", " # 
Vectorize\n", " vectorized_title, title_length = self.vectorizer.vectorize(title)\n", " vectorized_title = torch.tensor(vectorized_title).unsqueeze(0)\n", " title_length = torch.tensor([title_length]).long()\n", " \n", " # Forward pass\n", " self.model.eval()\n", " y_pred = self.model(x_in=vectorized_title, x_lengths=title_length, \n", " apply_softmax=True)\n", " \n", " # Top k categories\n", " y_prob, indices = torch.topk(y_pred, k=k)\n", " probabilities = y_prob.detach().numpy()[0]\n", " indices = indices.detach().numpy()[0]\n", "\n", " # Results\n", " results = []\n", " for probability, index in zip(probabilities, indices):\n", " category = self.vectorizer.category_vocab.lookup_index(index)\n", " results.append({'category': category, 'probability': probability})\n", "\n", " return results" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 136 }, "colab_type": "code", "id": "W6wr68o2p_Eh", "outputId": "3e94c736-3ad3-4c70-b24c-591edbe069ad" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "# Load the model\n", "dataset = NewsDataset.load_dataset_and_load_vectorizer(\n", " args.split_data_file, args.vectorizer_file)\n", "vectorizer = dataset.vectorizer\n", "model = NewsModel(embedding_dim=args.embedding_dim, \n", " num_embeddings=len(vectorizer.title_vocab), \n", " rnn_hidden_dim=args.rnn_hidden_dim,\n", " hidden_dim=args.hidden_dim,\n", " output_dim=len(vectorizer.category_vocab),\n", " num_layers=args.num_layers,\n", " bidirectional=args.bidirectional,\n", " dropout_p=args.dropout_p, \n", " pretrained_embeddings=None, \n", " padding_idx=vectorizer.title_vocab.mask_index)\n", "model.load_state_dict(torch.load(args.model_state_file))\n", "model = model.to(\"cpu\")\n", "# Print the model architecture\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": { "colab": { "base_uri": 
"https://localhost:8080/", "height": 51 }, "colab_type": 
"code", "id": "JPKgHxsfN954", "outputId": "c9f21a76-8307-4737-c785-01f1004891b6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Enter a title to classify: President Obama signed the petition during the White House dinner.\n", "President Obama signed the petition during the White House dinner. → World (p=0.62)\n" ] } ], "source": [ "# Inference\n", "inference = Inference(model=model, vectorizer=vectorizer)\n", "title = input(\"Enter a title to classify: \")\n", "prediction = inference.predict_category(preprocess_text(title))\n", "print(\"{} → {} (p={:0.2f})\".format(title, prediction['category'], \n", " prediction['probability']))" ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 102 }, "colab_type": "code", "id": "JRdz4wzuQR4N", "outputId": "9a349bf0-16ba-402d-9133-11ce27e1ec59" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "President Obama signed the petition during the White House dinner.\n", "World (p=0.62)\n", "Sci/Tech (p=0.17)\n", "Business (p=0.15)\n", "Sports (p=0.06)\n" ] } ], "source": [ "# Top-k inference\n", "top_k = inference.predict_top_k(preprocess_text(title), k=len(vectorizer.category_vocab))\n", "print(\"{}\".format(title))\n", "for result in top_k:\n", " print(\"{} (p={:0.2f})\".format(result['category'], \n", " result['probability']))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "noPtpaHZ6NAW" }, "source": [ "# Layer normalization" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "t3bu5cEP6PSb" }, "source": [ "Recall from our [CNN notebook](https://colab.research.google.com/github/LisonEvf/practicalAI-cn/blob/master/notebooks/11_Convolutional_Neural_Networks.ipynb) that we used batch normalization to deal with internal covariate shift. 
Our RNN activations can suffer from the same issue, but instead of batchnorm we will use a technique known as [layer normalization](https://arxiv.org/abs/1607.06450) (layernorm) to keep the activations at zero mean and unit variance. \n", "\n", "Layernorm differs from batchnorm in the axis over which the statistics are computed: instead of computing the mean and variance of each hidden unit across the batch, we compute them for every single sample across its hidden units, independently for each layer, and normalize the activations before they go through the nonlinearity. PyTorch's [LayerNorm](https://pytorch.org/docs/stable/nn.html#torch.nn.LayerNorm) class handles all of this for us when we feed inputs to the layer." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "-G-mVmUL61Fk" }, "source": [ "$ LN = \\frac{a - \\mu_{L}}{\\sqrt{\\sigma^2_{L} + \\epsilon}} * \\gamma + \\beta $\n", "\n", "where:\n", "* $a$ = activation | $\\in \\mathbb{R}^{NXH}$ ($N$ is the number of samples, $H$ is the hidden dim)\n", "* $\\mu_{L}$ = mean of $a$ over the hidden dim | $\\in \\mathbb{R}^{NX1}$\n", "* $\\sigma^2_{L}$ = variance of $a$ over the hidden dim | $\\in \\mathbb{R}^{NX1}$\n", "* $\\epsilon$ = small constant added for numerical stability\n", "* $\\gamma$ = scale parameter (learned parameter)\n", "* $\\beta$ = shift parameter (learned parameter)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "P0e9TnQ581-1" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oAhAHcgZBMFe" }, "source": [ "The most useful place to apply layernorm is inside the RNN, on the activations before the nonlinearity. However, this is somewhat involved: although PyTorch has a [LayerNorm](https://pytorch.org/docs/stable/nn.html#torch.nn.LayerNorm) class, its built-in RNN modules do not include layernorm. 
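You could implement the RNN cell yourself and add layernorm manually, with a setup similar to the sketch below. Here `W_xh` and `W_hh` are hypothetical `nn.Linear` layers corresponding to $W_{xh}$ and $W_{hh}$ in the forward-pass equation above, `x` is a batch of embedded inputs of shape `(N, seq_size, E)` and `h` is the previous hidden state.\n", "\n", "```python\n", "import torch\n", "import torch.nn as nn\n", "\n", "# Create the module once (e.g. in __init__) so that gamma and beta are\n", "# learned and shared across time steps instead of re-created at every step\n", "layernorm = nn.LayerNorm(args.rnn_hidden_dim)\n", "for t in range(seq_size):\n", " # Pre-activation at step t, normalized over the hidden dim\n", " a_t = layernorm(W_xh(x[:, t]) + W_hh(h))\n", " # Nonlinearity applied after normalization gives the new hidden state\n", " h = torch.tanh(a_t)\n", "```" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1YHneO3SStOp" }, "source": [ "# TODO" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gGHaKTe1SuEk" }, "source": [ "- interpretability with task to see which words were most influential" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "13_Recurrent_Neural_Networks", "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }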