{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Summary of my results:\n", "\n", "model | train_loss | valid_loss | seq2seq_acc | bleu\n", "-------------------|----------|----------|----------|----------\n", "seq2seq | 3.355085 | 4.272877 | 0.382089 | 0.291899\n", "\\+ teacher forcing | 3.154585 | 4.022432 | 0.407792 | 0.310715\n", "\\+ attention | 1.452292 | 3.420485 | 0.498205 | 0.413232\n", "transformer | 1.913152 | 2.349686 | 0.781749 | 0.612880" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Translation with an RNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook is modified from [this one](https://github.com/fastai/fastai_docs/blob/master/dev_course/dl2/translation.ipynb) created by Sylvain Gugger.\n", "\n", "Today we will be tackling the task of translation. We will be translating from French to English, and to keep our task a manageable size, we will limit ourselves to translating questions.\n", "\n", "This task is an example of sequence-to-sequence (seq2seq) learning. Seq2seq can be more challenging than classification, since the output has a variable length (typically different from the length of the input)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use French/English parallel texts from http://www.statmt.org/wmt15/translation-task.html. The corpus was created by Chris Callison-Burch, who crawled millions of web pages and then used *a set of simple heuristics to transform French URLs onto English URLs (i.e. replacing \"fr\" with \"en\" and about 40 other hand-written rules), and assume that these documents are translations of each other*." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Translation is much tougher in straight PyTorch: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download and preprocess our data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will start by reducing the original dataset to questions. You only need to execute this once; uncomment the cells below to run them. The dataset can be downloaded [here](https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz)." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# ! wget https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz -P {path}" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# ! 
tar xf {path}/giga-fren.tgz -C {path} " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.en.300.bin'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/data_save.pkl'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.fr.300.bin'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/questions_easy.csv')]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = Config().data_path()/'giga-fren'\n", "path.ls()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# with open(path/'giga-fren.release2.fixed.fr') as f: fr = f.read().split('\\n')" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# with open(path/'giga-fren.release2.fixed.en') as f: en = f.read().split('\\n')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use regex to pick out questions by finding the strings in the English dataset that start with \"Wh\" and end with a question mark. 
You only need to run these lines once:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# re_eq = re.compile('^(Wh[^?.!]+\\?)')\n", "# re_fq = re.compile('^([^?.!]+\\?)')\n", "# en_fname = path/'giga-fren.release2.fixed.en'\n", "# fr_fname = path/'giga-fren.release2.fixed.fr'" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# lines = ((re_eq.search(eq), re_fq.search(fq)) \n", "# for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))\n", "# qs = [(e.group(), f.group()) for e,f in lines if e and f]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "# qs = [(q1,q2) for q1,q2 in qs]\n", "# df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr'])\n", "# df.to_csv(path/'questions_easy.csv', index=False)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.en.300.bin'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/data_save.pkl'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.fr.300.bin'),\n", " PosixPath('/home/racheltho/.fastai/data/giga-fren/questions_easy.csv')]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load our data into a DataBunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our questions look like this now:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
enfr
0What is light ?Qu’est-ce que la lumière?
1Who are we?Où sommes-nous?
2Where did we come from?D'où venons-nous?
3What would we do without it?Que ferions-nous sans elle ?
4What is the absolute location (latitude and lo...Quelle sont les coordonnées (latitude et longi...
\n", "
" ], "text/plain": [ " en \\\n", "0 What is light ? \n", "1 Who are we? \n", "2 Where did we come from? \n", "3 What would we do without it? \n", "4 What is the absolute location (latitude and lo... \n", "\n", " fr \n", "0 Qu’est-ce que la lumière? \n", "1 Où sommes-nous? \n", "2 D'où venons-nous? \n", "3 Que ferions-nous sans elle ? \n", "4 Quelle sont les coordonnées (latitude et longi... " ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(path/'questions_easy.csv')\n", "df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To keep things simple, we lowercase everything." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "df['en'] = df['en'].apply(lambda x:x.lower())\n", "df['fr'] = df['fr'].apply(lambda x:x.lower())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we need to collate inputs and targets into batches. They have different lengths, so we add padding to make every sequence in a batch the same length." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):\n", "    \"Function that collects samples and adds padding. 
Flips token order if needed\"\n", "    samples = to_data(samples)\n", "    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n", "    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n", "    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n", "    if backwards: pad_first = not pad_first\n", "    for i,s in enumerate(samples):\n", "        if pad_first: \n", "            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n", "        else: \n", "            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])\n", "    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n", "    return res_x,res_y" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we create a special `DataBunch` that uses this collate function." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "doc(Dataset)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "doc(DataLoader)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [], "source": [ "doc(DataBunch)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqDataBunch(TextDataBunch):\n", "    \"Create a `TextDataBunch` suitable for training a seq2seq model.\"\n", "    @classmethod\n", "    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n", "               dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n", "        \"Function that transforms the `datasets` into a `DataBunch` for seq2seq training. 
Passes `**dl_kwargs` on to `DataLoader()`\"\n", "        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n", "        val_bs = ifnone(val_bs, bs)\n", "        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n", "        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n", "        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n", "        dataloaders = [train_dl]\n", "        for ds in datasets[1:]:\n", "            lengths = [len(t) for t in ds.x.items]\n", "            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n", "            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n", "        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "SortishSampler??" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need a subclass of `TextList` that uses this `DataBunch` class when we call `.databunch`, and that uses `TextList` for labelling (since our targets are also texts)." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqTextList(TextList):\n", "    _bunch = Seq2SeqDataBunch\n", "    _label_cls = TextList" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That's all we need to use the data block API!" 
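] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see what the padding in `seq2seq_collate` does to a batch, here is a minimal sketch in plain PyTorch. The helper `pad_batch` is hypothetical (it is not part of fastai). It pads a single list of token-id sequences instead of (input, target) pairs, but it follows the same `pad_first`/`backwards` logic." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"\n",
"def pad_batch(seqs, pad_idx=1, pad_first=False, backwards=False):\n",
"    # Hypothetical helper: same padding logic as seq2seq_collate, for one list of sequences\n",
"    max_len = max(len(s) for s in seqs)\n",
"    res = torch.full((len(seqs), max_len), pad_idx, dtype=torch.long)\n",
"    if backwards: pad_first = not pad_first\n",
"    for i, s in enumerate(seqs):\n",
"        if pad_first: res[i, max_len - len(s):] = torch.tensor(s)  # padding goes before the tokens\n",
"        else: res[i, :len(s)] = torch.tensor(s)  # padding goes after the tokens\n",
"    if backwards: res = res.flip(1)  # reverse token order along the sequence dimension\n",
"    return res\n",
"\n",
"batch = [[2, 5, 7], [2, 9]]\n",
"print(pad_batch(batch).tolist())                  # [[2, 5, 7], [2, 9, 1]]\n",
"print(pad_batch(batch, pad_first=True).tolist())  # [[2, 5, 7], [1, 2, 9]]\n",
"print(pad_batch(batch, backwards=True).tolist())  # [[7, 5, 2], [9, 2, 1]]"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Seq2SeqDataBunch.create` passes `pad_first=False` by default, so in this notebook the padding ends up after the real tokens." 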
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_by_rand_pct(seed=42).label_from_df(cols='en', label_cls=TextList)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "28.0" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "23.0" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The 90th percentile of the lengths is 28 tokens for the French inputs and 23 tokens for the English targets, so we remove the items where either the input or the target is more than 30 tokens long." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "48352" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(src.train) + len(src.valid)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "scrolled": true }, "outputs": [], "source": [ "data = src.databunch()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "data.save()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Seq2SeqDataBunch;\n", "\n", "Train: LabelList (38706 items)\n", "x: Seq2SeqTextList\n", "xxbos qu’est - ce que la lumière ?,xxbos où sommes - nous ?,xxbos d'où venons - nous ?,xxbos que ferions - nous sans elle ?,xxbos quel est le groupe autochtone principal sur l’île de 
vancouver ?\n", "y: TextList\n", "xxbos what is light ?,xxbos who are we ?,xxbos where did we come from ?,xxbos what would we do without it ?,xxbos what is the major aboriginal group on vancouver island ?\n", "Path: /home/racheltho/.fastai/data/giga-fren;\n", "\n", "Valid: LabelList (9646 items)\n", "x: Seq2SeqTextList\n", "xxbos quels pourraient être les effets sur l’instrument de xxunk et sur l’aide humanitaire qui ne sont pas co - xxunk ?,xxbos quand la source primaire a - t - elle été créée ?,xxbos pourquoi tant de soldats ont - ils fait xxunk de ne pas voir ce qui s'est passé le 4 et le 16 mars ?,xxbos quels sont les taux d'impôt sur le revenu au canada pour 2007 ?,xxbos pourquoi le programme devrait - il intéresser les employeurs et les fournisseurs de services ?\n", "y: TextList\n", "xxbos what would be the resulting effects on the pre - accession instrument and humanitarian aid that are not co - decided ?,xxbos when was the primary source created ?,xxbos why did so many soldiers look the other way in relation to the incidents of march 4th and march xxunk ?,xxbos what are the income tax rates in canada for 2007 ?,xxbos why is the program good for employers and service providers ?\n", "Path: /home/racheltho/.fastai/data/giga-fren;\n", "\n", "Test: None" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/racheltho/.fastai/data/giga-fren')" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "data = load_data(path)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
texttarget
xxbos quelle position devrait - il défendre pour concilier les objectifs stratégiques des divers traités internationaux sur la propriété intellectuelle , l’environnement , et les droits sociaux et économiques ?xxbos what position should canada advocate with respect to xxunk the policy objectives of various international treaties on intellectual property , the environment , and social and economic rights ?
xxbos que faire s’il semble que pour sauver un stock local de poisson de fond , il xxunk réduire ou éliminer la prédation par les phoques dans le secteur ?xxbos what if it appears that in some xxunk , saving a local groundfish stock would require reducing or xxunk seal predation in that area ?
xxbos quels sont les impacts économiques produits par les xxunk millions de dollars dépensés par les résidents du yukon qui ont participé à des activités reliées à la nature ?xxbos what are the economic impacts that result from participation in nature - related activities by residents of the yukon ?
xxbos quelles pourraient être les raisons pour lesquelles un programme n ' a pas marché aussi bien que prévu , même si les employés ont effectué un travail excellent ?xxbos what would be some of the reasons why a program could be less than successful , even if staff were excellent ?
xxbos quand les pièces , les feuilles ou les fils métalliques contenant des substances de l’inrp figurant dans les parties 1a et 1b perdent - ils leur statut xxunk ?xxbos when do metal parts , sheets or xxunk containing npri part xxunk and xxunk substances lose their status as articles ?
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create our Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pretrained embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You will need to download the word embeddings (crawl vectors) from the fastText docs. FastText has [pre-trained word vectors](https://fasttext.cc/docs/en/crawl-vectors.html) for 157 languages, trained on Common Crawl and Wikipedia. These models were trained using CBOW.\n", "\n", "If you need a refresher on word embeddings, you can check out my gentle intro in [this word embedding workshop](https://www.youtube.com/watch?v=25nC0n9ERq4&list=PLtmWHNX-gukLQlMvtRJ19s7-8MrnRV6h6&index=10&t=0s) with accompanying [github repo](https://github.com/fastai/word-embeddings-workshop). \n", "\n", "More reading on CBOW (Continuous Bag of Words vs. Skip-grams):\n", "\n", "- [fastText tutorial](https://fasttext.cc/docs/en/unsupervised-tutorial.html#advanced-readers-skipgram-versus-cbow)\n", "- [StackOverflow](https://stackoverflow.com/questions/38287772/cbow-v-s-skip-gram-why-invert-context-and-target-words)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To install fastText:\n", "```\n", "$ git clone https://github.com/facebookresearch/fastText.git\n", "$ cd fastText\n", "$ pip install .\n", "```" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "import fastText as ft" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The lines to download the word vectors only need to be run once:" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path}\n", "# ! 
wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz -P {path}" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "# ! gunzip {path}/cc.en.300.bin.gz\n", "# ! gunzip {path}/cc.fr.300.bin.gz" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "scrolled": true }, "outputs": [], "source": [ "fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))\n", "en_vecs = ft.load_model(str((path/'cc.en.300.bin')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create an embedding module initialized with the pretrained vectors. Words missing from the fastText vocabulary keep the module's random initialization." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "def create_emb(vecs, itos, em_sz=300, mult=1.):\n", "    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)\n", "    wgts = emb.weight.data\n", "    vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}\n", "    miss = []\n", "    for i,w in enumerate(itos):\n", "        try: wgts[i] = tensor(vec_dic[w])\n", "        except KeyError: miss.append(w)\n", "    return emb" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [], "source": [ "emb_enc = create_emb(fr_vecs, data.x.vocab.itos)\n", "emb_dec = create_emb(en_vecs, data.y.vocab.itos)" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([11336, 300]), torch.Size([8152, 300]))" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emb_enc.weight.size(), emb_dec.weight.size()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "model_path = Config().model_path()" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "torch.save(emb_enc, model_path/'fr_emb.pth')\n", "torch.save(emb_dec, model_path/'en_emb.pth')" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ 
"emb_enc = torch.load(model_path/'fr_emb.pth')\n", "emb_dec = torch.load(model_path/'en_emb.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Review Question: What are the two types of numbers in deep learning?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Encoders & Decoders" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model itself consists of an encoder and a decoder:\n", "\n", "![Seq2seq model](images/seq2seq.png)\n", "\n", "
Diagram from Smerity's Peeking into the neural network architecture used for Google's Neural Machine Translation
\n", "\n", "The encoder is a recurrent neural net: we feed it our input sentence, and it produces an output (which we discard for now) and a hidden state. The **hidden state** is the set of activations that come out of an RNN.\n", "\n", "That hidden state is then given to the decoder (another RNN), which uses it together with its own previous predictions to produce the translation. We loop until the decoder produces a padding token for every sentence in the batch, or stop after 30 iterations so that the loop cannot run forever at the beginning of training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use a GRU for our encoder and a separate GRU for our decoder. Other options are LSTMs or QRNNs (see the links below). GRUs, LSTMs, and QRNNs all address the difficulty plain RNNs have in retaining long-term memory.\n", "\n", "Links:\n", "- [Illustrated Guide to LSTM’s and GRU’s: A step by step explanation](https://towardsdatascience.com/illustrated-guide-to-lstms-and-gru-s-a-step-by-step-explanation-44e9eb85bf21)\n", "- [fast.ai implementation of seq2seq with QRNNs](https://github.com/fastai/fastai_docs/blob/master/dev_course/dl2/translation.ipynb)" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN(nn.Module):\n", "    def __init__(self, emb_enc, emb_dec, \n", "                 nh, out_sl, \n", "                 nl=2, bos_idx=0, pad_idx=1):\n", "        super().__init__()\n", "        self.nl,self.nh,self.out_sl = nl,nh,out_sl\n", "        self.bos_idx,self.pad_idx = bos_idx,pad_idx\n", "        self.em_sz_enc = emb_enc.embedding_dim\n", "        self.em_sz_dec = emb_dec.embedding_dim\n", "        self.voc_sz_dec = emb_dec.num_embeddings\n", "        \n", "        self.emb_enc = emb_enc\n", "        self.emb_enc_drop = nn.Dropout(0.15)\n", "        self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl,\n", "                              dropout=0.25, batch_first=True)\n", "        self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False)\n", "        \n", "        self.emb_dec = emb_dec\n", "        self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl,\n", "                              dropout=0.1, 
batch_first=True)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec)\n", " self.out.weight.data = self.emb_dec.weight.data\n", " \n", " def encoder(self, bs, inp):\n", " h = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " _, h = self.gru_enc(emb, h)\n", " h = self.out_enc(h)\n", " return h\n", " \n", " def decoder(self, dec_inp, h):\n", " emb = self.emb_dec(dec_inp).unsqueeze(1)\n", " outp, h = self.gru_dec(emb, h)\n", " outp = self.out(self.out_drop(outp[:,0]))\n", " return h, outp\n", " \n", " def forward(self, inp):\n", " bs, sl = inp.size()\n", " h = self.encoder(bs, inp)\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " \n", " res = []\n", " for i in range(self.out_sl):\n", " h, outp = self.decoder(dec_inp, h)\n", " dec_inp = outp.max(1)[1]\n", " res.append(outp)\n", " if (dec_inp==self.pad_idx).all(): break\n", " return torch.stack(res, dim=1)\n", " \n", " def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "xb,yb = next(iter(data.valid_dl))" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([64, 30])" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xb.shape" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "rnn = Seq2SeqRNN(emb_enc, emb_dec, 256, 30)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "Seq2SeqRNN(\n", " (emb_enc): Embedding(11336, 300, padding_idx=1)\n", " (emb_enc_drop): Dropout(p=0.15)\n", " (gru_enc): GRU(300, 256, num_layers=2, batch_first=True, dropout=0.25)\n", " (out_enc): Linear(in_features=256, out_features=300, bias=False)\n", " (emb_dec): Embedding(8152, 300, padding_idx=1)\n", " (gru_dec): 
GRU(300, 300, num_layers=2, batch_first=True, dropout=0.1)\n", " (out_drop): Dropout(p=0.35)\n", " (out): Linear(in_features=300, out_features=8152, bias=True)\n", ")" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rnn" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "30" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(xb[0])" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "h = rnn.encoder(64, xb.cpu())" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([2, 64, 300])" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h.size()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loss pads output and target so that they are of the same size before using the usual flattened version of cross entropy. We do the same for accuracy." 
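] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a minimal sketch of that padding in plain PyTorch: `padded_flat_ce` below is a hypothetical stand-in for the `seq2seq_loss` defined in the next cell, with `F.cross_entropy` on flattened tensors playing the role of fastai's `CrossEntropyFlat`. Like `seq2seq_loss`, it fills missing decoder steps with logits equal to `pad_idx` and missing target steps with `pad_idx` itself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def padded_flat_ce(out, targ, pad_idx=1):\n",
"    # out: (bs, out_len, vocab) logits; targ: (bs, targ_len) token ids\n",
"    _, targ_len = targ.size()\n",
"    _, out_len, vs = out.size()\n",
"    # F.pad takes (left, right) amounts starting from the last dimension\n",
"    if targ_len > out_len: out = F.pad(out, (0, 0, 0, targ_len - out_len, 0, 0), value=pad_idx)\n",
"    if out_len > targ_len: targ = F.pad(targ, (0, out_len - targ_len, 0, 0), value=pad_idx)\n",
"    # Flatten batch and sequence dims: (bs*len, vocab) logits vs (bs*len,) targets\n",
"    return F.cross_entropy(out.reshape(-1, vs), targ.reshape(-1))\n",
"\n",
"torch.manual_seed(0)\n",
"out = torch.randn(4, 5, 10)          # decoder stopped after 5 steps, vocab of 10\n",
"targ = torch.randint(0, 10, (4, 7))  # targets are 7 tokens long\n",
"print(padded_flat_ce(out, targ))"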
] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def seq2seq_loss(out, targ, pad_idx=1):\n", " bs,targ_len = targ.size()\n", " _,out_len,vs = out.size()\n", " if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", " if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", " return CrossEntropyFlat()(out, targ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train our model" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, rnn, loss_func=seq2seq_loss)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": 56, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl4XNV9xvHvb0b7blmSLa+yHdsY2xiwDBgIOwQIO4FAyxZoSNosJZSmaXmetA0la5O0CWmABkM2nAQILVvYC4TFNrLxFttgkG0sr5Ily9bImtGMTv+YkTw4lrXNPu/neebRzJ07c3/HI8+rc8+955pzDhERyV6eZBcgIiLJpSAQEclyCgIRkSynIBARyXIKAhGRLKcgEBHJcgoCEZEspyAQEclyCgIRkSyXk+wCBqOqqsrV1dUluwwRkbSyfPnyFudc9UDrpUUQ1NXV0dDQkOwyRETSipltGcx62jUkIpLlFAQiIllOQSAikuXiFgRmtsjMdpvZ2qhld5nZajNbaWbPm9m4eG1fREQGJ549goeA8w9Z9j3n3DHOuWOBp4Cvx3H7IiIyCHELAufca0DrIcv2RT0sBnRVHBGRJEv44aNmdjdwA9AOnJno7YuIyEclfLDYOXenc24i8Gvgi/2tZ2a3mlmDmTU0NzcnrkARkRTQ5gvwrT+sp7G5I+7bSuZRQw8DV/b3pHPufudcvXOuvrp6wBPjREQyyvItbdz3aiPN+/1x31ZCg8DMpkc9vATYkMjti4iki7e3tJLrNeZNrIj7tuI2RmBmi4EzgCozawL+GbjQzGYCPcAW4PPx2r6ISDpbvrmN2ePKKcj1xn1bcQsC59y1h1n8QLy2JyKSKbq6Q6xuaufGkycnZHs6s1hEJMWs3dZOINTD/MmVCdmegkBEJMU0bGkDYP7kUQnZnoJARCTFNGxuY0pVMdWl+QnZnoJARCSFOOdYvqU1Yb0BUBCIiKSUD5p9tHV2U68gEBHJTsu3hKdoq69LzEAxKAhERFLK25vbGFWUy7Tq4oRtU0EgIpJClm9pY/7kUZhZwrapIBARSREtHX42tfgSulsIFAQiIiljeeT8gUQOFIOCQEQkZTRsbiXP62HO+PKEbldBICKSIhq2tHHMhMRMNBdNQSAikgK6ukOs3dbO/LrE7hYCBYGISEpYtXUv3SFHfYImmoumIBARSQGJnmgumoJARCTJWjr8vLh+F9Oqi6kszkv49uN2YRoRETmyDTv3sej1TfzPyu0Egj3804VHJaUOBYGISIJ90NzB1/93LW+8v4eCXA9X10/gM6dMYVp1SVLqURCIiCTQ9r0HuO5nS+nqDvHV82fyFydMoqIo8buDoikIREQSpM0X4IZFy+joCvK7zy9kVm1ZsksCFAQiIgnRGQhy88/f5sPWTn5x8wkpEwKgo4ZEROKuO9TDFx9+h5Vb9/Kja47lpKmjk13SRygIRETi7J9+v4aXN+zmrkvncP6c2mSX82cUBCIicbRh5z4eWd7E506fynUnTU52OYelIBARiaNnVu/AY/BXp05Ndin9UhCIiMSJc46n1+zghCmVVJfmJ7ucfikIRETi5L1dHXzQ7OOTx4xLdilHpCAQEYmTp1dvx2Nw/uyxyS7liBQEIiJxkC67hUBBICISF327heam3uGih1IQiIjEwdNrdmAGn5iT2ruFQEEgIhIXz6zZwQl1ldSUFiS7lAEpCEREYuy9Xft5f3cHnzwm9XcLgYJARCTmnl4d3i10fhrsFgIFgYhIzKXTbiFQEIiIxNTGXfvZmEa7hUBBICISU71HC6XLbiFQEIiIxNQr7zZz3MSKtNktBAoCEZGY6fAHWbOtnZOnVSW7lCGJWxCY2SIz221ma6OWfc/MNpjZajN73Mwq4rV9EZFEW76ljVCP48SplckuZUji2SN4CDj/kGUvAHOcc8cA7wH/GMfti4gk1JLGPeR4jPmTRyW7lCGJWxA4514DWg9
Z9rxzLhh5uASYEK/ti4gk2pLGPcybWEFRXk6ySxmSZI4R3Az8IYnbFxGJGZ8/yOqmdk6ckl67hSBJQWBmdwJB4NdHWOdWM2sws4bm5ubEFSciMgy94wMnTR2d7FKGLOFBYGY3AhcBf+mcc/2t55y73zlX75yrr66uTlyBIiLDkK7jAwAJ3ZFlZucD/wCc7pzrTOS2RUTiaUnjHo6ZUE5xfnqND0B8Dx9dDLwFzDSzJjO7BbgHKAVeMLOVZnZvvLYvIpIonYHI+EAa7haCOPYInHPXHmbxA/HanohIsizf0kYwTccHQGcWi4iM2JLGPXg9Rn0ajg+AgkBEZMSWNLYyd3x6jg+AgkBEZETC4wN703a3ECgIRERGZMWWvXSHHCel2fxC0RQEIiIj0Dc+UKcgEBHJSksa9zBnfDklaTo+AAoCEZFhOxAIsappb1rvFgIFgYjIsDVsaQ2PD0xJ34FiUBCIiAzbS+t3k5/jSbsL0RxKQSAiMgzOOV7asItTPlaVdtcfOJSCQERkGDbu7mBr6wHOnlWT7FJGTEEgIjIML6zbBcDZR41JciUjpyAQERmGl9bvYu74csaWFyS7lBFTEIiIDFFLh593tu7NiN1CoCAQERmy/9uwG+fgnFnpv1sIFAQiIkP20vrdjC0rYPa4smSXEhMKAhGRIfAHQ/xxYzNnzarBzJJdTkwoCEREhmBJYyu+QIhzM2S3ECgIRESG5MV1uyjM9bJwWnpPKxFNQSAiMkjOOV5av4tTp1dRkOtNdjkxoyAQERmk9Tv2s729i3My5LDRXgoCEZFBemn9LszgrAw4mziagkBEZJBe2rCbeRMqqC7NT3YpMaUgEBEZpA9bO5kzPjPOHYimIBARGSSfP0hxGl+Ssj8KAhGRQQiGevAHeyhO82sPHI6CQERkEDq7QwAU5WXOYaO9FAQiIoPQ6Q8HgXYNiYhkKV8gCKhHICKStfp6BBojEBHJTn09gnz1CEREslJnJAjUIxARyVK+vsFi9QhERLJSZ99gsXoEIiJZyafBYhGR7Obza7BYRCSr+QIh8nI85Hoz72sz81okIhIHnYEgxRl4MhkMMgjMbJqZ5Ufun2FmXzaziviWJiKSOnz+UEYOFMPgewSPASEz+xjwADAFePhILzCzRWa228zWRi27ysz+ZGY9ZlY/7KpFRBKsMxDMyENHYfBB0OOcCwKXA//hnPsKUDvAax4Czj9k2VrgCuC1oRQpIpJsvkDm9ggG26puM7sWuBG4OLIs90gvcM69ZmZ1hyxbD2BmQ6tSRCTJOv3qEXwGWAjc7ZzbZGZTgF/FrywRkdSS9T0C59w64MsAZjYKKHXOfTuehZnZrcCtAJMmTYrnpkREBqSjhsxeMbMyM6sEVgEPmtkP4lmYc+5+51y9c66+uro6npsSERmQzx+iKAMvSgOD3zVU7pzbR3ig90Hn3HzgnPiVJSKSWrK+RwDkmFktcDXw1GBeYGaLgbeAmWbWZGa3mNnlZtZEeLzhaTN7blhVi4gkUE+PozPbxwiAbwDPAW845942s6nAxiO9wDl3bT9PPT6E+kREkq73wvUlGbpraLCDxY8Aj0Q9bgSujFdRIiKppDODJ5yDwQ8WTzCzxyNnCu8ys8fMbEK8ixMRSQW+QOZOQQ2DHyN4EHgCGAeMB56MLBMRyXh9U1Bn+WBxtXPuQedcMHJ7CNAxnSKSFTp7ewQZOkYw2CBoMbPrzMwbuV0H7IlnYSIiqcIXUI8A4GbCh47uBHYAnyI87YSISMbr9KtHgHPuQ+fcJc65audcjXPuMsInl4mIZDz1CPp3e8yqEBFJYb2Hj2b7UUOHo7mkRSQr9B4+mtXnEfTDxawKEZEU5vMHyfEYeRl44XoY4MxiM9vP4b/wDSiMS0UiIikmPM+QN2MvqnXEIHDOlSaqEBGRVOXzBzN2niEY2a4hEZGs0BnI3GsRgIJARGRAvgy+FgEoCEREBtTpz9xrEYCCQERkQL5AkOIMPXQUFAQ
iIgPK5KuTgYJARGRAPr96BCIiWS3TewSZ2zLgkYatvNW4h5L8HIrzc8I/IyeFdId6CPU4gj0O5xylBbmUF4ZvZYW5FOR6sENm0cjxGl6PkePp/ekh12vk5njI83rI9XrwejLzhBORbOWcy/ijhjI6CLa2HWBpYyu+QBCfP0h3KP6zYng9Rn6OJ3Lzkp/roSgvh5J8b+RnDoV53r7nC3I95OWEQ8dFTuJ2h5TZezKj18Khk+Mx8nI85Hg8faHUewPoDvUQCPaEf4Yc+Tmej4ZhvpeC3IM15Od4yM/1UJDjxaMgE/mIru4enCOjzyPI3JYBt587g9vPndH32B8M4YvMK+71GLneg1+e+7uCtB/o7rv5u3sOeTdHqAeCPQd7EsGQI9gT/tINRH35+rt78AfDjw90h+gMhPD5g7R1Bmhq6+RAIIQ/2BO5hRISUIOV6zUKcrwU5HkpyguHV1Hkfkl+zkd6TWWFuYwuzqOqJJ/RJeGfZQU5GXsavmSnjr6ZR9UjyAjhv34P/2Hml3ipKslPcEVhoZ5wEBgH//rv/TJ17mAvIeQc3aEeukOOYORnyDlCkUDqcQ7nIC8nvJsqL8dDrseDPxiiwx/E5w//7AwE8Qd76OoO/dnPru7ex70BFuJAd5D9XUF2tHf1BWUgeGhQhuV6jYqiPEYV5TKqKI9RRXlUleYxtqyAMWUF1JYXMrY8n+qSAsoKFRqS+jr7rkWQuV+XmduyNHKkcYXeL0oz8GDkDmv2w1xqhllbf7q6Q7Qf6KbVF6Clw8+ejvDPlo4AezsDtPoC7O3s5v3mDpZs8rO3s/vP3iPP66GqJI+q0nxqSvMZV1HYdxtfUcD4iiJqSvO1u0qSypfhVycDBYEMU0FueJxhTFnBoNbv6g6xs72Lnfu62LWvi+b9fpo7/LTsD9Dc4aep7QDLNrWyryv4kdfleT2MqyhgwqgiJowqZPLoYupGFzF5dDGTRxdl9H9OSQ29PYJMPnxU/4skIQpyvdRVFVNXVXzE9fZ3dbOjvYttbQfYtvcATW0HaGrrZNveA7y4fhctHYGPrF9bXsDsceXMHV/OnPFlzBlfTk1pvnY5Scz0XZRGu4ZEEqO0IJfSglxmjDn8DOgd/iBb9vjYsqeTTS0+3tu1n7Xb2nlpw66+o61GFeUyvaaU6WNKmF5TwoyxpcwaW8ao4rwEtkQyRd9lKtUjEEkNJfk5zB5Xzuxx5R9Z3uEPsn7HPtZua+e9XR1s3LWfJ1dt/8iupnHlBRw9royja8s4ely4BzG+olC9Bzmi3h5Bpl6vGBQEkiFK8nNYUFfJgrrKvmXOOZr3+9mwcz/rd+xj3Y59rNu+j5c37KYnqvcwZ3w4WI6fVEF9XSWV6jlIlINHDalHIJJ2zIyasgJqygo4bUZ13/Ku7hAbdoZ3Kf1peztrtrXzwOuN3Bs5n2NqdTELJldSXzeKE6eMZmKleg3ZTEcNiWSgglwvx06s4NiJFX3LurpDrNnWTsPmNho2t/Lcup38tmErEB6QPmFKJSdOGc3J00YzeXSRgiGLdAaCeAzyczJ3ajYFgQjhcDi4a2kaPT2O95s7WNq4h6WbWnnzgz3878rtAEwYVcjHp1fz8elVnDxtNBVF2pWUyXz+EMV5mX3yo4JA5DA8HmPGmFJmjCnl+oV1OOdobPHxxvst/HFjC0+t2s7iZR/iMaivq+TcWWM4e1YNU6tLkl26xJjPH6Qog48YAgWByKCYGdOqS5hWXcINC+sIhnpY1bSXV95t5sX1u7n7mfXc/cx6plYXc+7RY7hgTi3zJpRn9F+R2SI882hmf1VmdutE4iTH62H+5ErmT67k786bSVNbJy9v2M0L63bxwB83cd+rjdSWF/CJ2WM5f85YFtRVaoryNNUZCKlHICIDmzCqiBsW1nHDwjraO7t5acMu/rB2J4uXfchDb26mqiSPc48eywVzxrJw2uhhzhklyeDzq0cgIkNUXpTLFcdP4IrjJ+DzB3nl3Wb+sHY
HT6zcxuJlH1JemMs5s8ZwwZyxnDq9ioLczP5rM911BkJUlyZnZuJEURCIxFFxfg6fPKaWTx5TS1d3iD9ubOEPa3fwwrqdPLaiiZL8HM46qoYL5ozljJk1FGbwSUvpyhcIMjmvKNllxJWCQCRBCnK9nHv0GM49egyBYA9vftDCs2t38tyfdvLEqu0U5no5a1YNF82tVSikkM7I4aOZLG6tM7NFwEXAbufcnMiySuC3QB2wGbjaOdcWrxpEUlVejoczZtZwxswa/u2yOSzb1MrTa3bw7NqdPL16B0V5Xs6eNYZL5o3j9BnV5GXwyUypzhfQ4aMj8RBwD/CLqGVfA15yzn3bzL4WefwPcaxBJOXleD2c/LEqTv5YFf96yWyWbWrlqUgoPLlqO+WFuVw4dyyXzBvPiVMqdaGeBHLO0RlQj2DYnHOvmVndIYsvBc6I3P858AoKApE+h4bC6++38MTK7TyxcjuLl21lTFk+n5w7jovn1XLsxAqdpxBn/mD4GuXqEcTWGOfcDgDn3A4z6/cKimZ2K3ArwKRJkxJUnkjqyPV6OHNmDWfOrOFAIMSL63fx5Krt/GrJFha9sYkJowq5eN44Ljt2PDPHHv76DTIynVkwBTWk8GCxc+5+4H6A+vp6l+RyRJKqMM/LxfPGcfG8cezr6ua5tTt5cvUO7n+tkZ++8gFH15ZxxfHjuWTeOGoGeflQGZjPn/lTUEPig2CXmdVGegO1wO4Eb18k7ZUV5HJV/USuqp9IS4efp1Zt5/F3tvFvT6/nm8+s59Tp1Vx5/HjOO3qsjjwaIV/f9YpT9m/mmEh0654AbgS+Hfn5vwnevkhGqSrJ56ZTpnDTKVP4oLmD/3lnG79fsY2//c1KSvJz+OTcWq6cP4EFdaM0njAMvdciUI9gmMxsMeGB4SozawL+mXAA/M7MbgE+BK6K1/ZFss206hL+7ryZfOWcGSzd1MpjK5p4avV2ftuwlUmVRXxq/gSunD+B8RWFyS41bXSqRzAyzrlr+3nq7HhtU0TCU2gvnDaahdNG841LZ/Ps2p08uryJH7zwHj988T1OmVbFp+ZP4Pw5YzW9xQD6rk6mwWIRSVdFeTl98x5tbe3ksRVNPLq8idt+u5LyJ3K5/LjxXHvCJB111I+DPYLMDkwFgUiWmFhZxG3nzODLZ01nSeMeFr+9lYeXhmdHPX5SBX9x4mQuOqZWvYQovkDvGEFmf1VmdutE5M94PNZ30lqrL8DvVzTx8LIPueORVdz11Do+NX8Cf3HiJKbpamt0+tUjEJEMV1mcx199fCq3nDqFJY2t/HrpFn7x1mYeeH0TJ08bzQ0LJ3POrDHkZOn1E3yBEGZQkKMgEJEMZ3ZwgLl5v5/fNYR3G33+VysYV17AdQsnc82CSVQW5yW71ITq9AcpyvVm/PxO2RnzItKv6tJ8vnDmx3jtq2dy3/XzmTy6mO8++y4Lv/USX310FRt37U92iQnjC4QoyvBDR0E9AhHph9djfGL2WD4xeyzv7tzPz9/azO9XNPG7hibOPqqGz50+LeNPVOsMBCnO8JPJQD0CERmEmWNL+eblc3nza2dz2znTWfFhG1ff9xaX/9ebvLBuF85l5nRgPn8w448YAgWBiAxBZXEet50zgze/djZ3XTqbPT4/n/1FA5f95A1efa854wLB5w9l/BFDoCAQkWEozPNy/cI6Xv67M/jOlXNp6Qhw46JlXH3fWyxp3JPs8mKmM6AegYjIEeV6PXx6wSRevuN07rp0Nlv2dHLN/Uv4618tZ0f7gWSXN2K+QIiSLBgsVhCIyIjl54R7CK999UzuOG8GL2/Yzdnff5X/fq2R7lBPsssbtk5/MONnHgUFgYjEUEGuly+eNZ0Xbz+dk6aO5u5n1nPxj19n+ZbWZJc2LL5AKONnHgUFgYjEwcTKIh64sZ57r5tP+4FuPnXvW3zzmfV0dYeSXdqQhMcI1CMQERkWM+P8OWN54fbTufaESdz/WiOX3PM6a7e1J7u
0QQkEe+gOOfUIRERGqiQ/h29ePpcHP7OAvZ3dXPaTN/jPFzem/NhB7xTU6hGIiMTImTNreP4rp3Hh3Fp++OJ7XP/AUvZ0+JNdVr96p6DO9IvSgIJARBKooiiPH117HD+4eh4rPtzLJfe8wbrt+5Jd1mH1TkFdpBPKRERi74rjJ/DI5xYS6nFc+dM3eXr1jqTW0x3q4Z8eX/OR8Qv1CERE4mzexAqe+NIpzKot5QsPr+Dfn3uXnp7kTFHx7s79PLz0Q25+6G12tncB4XmGQGMEIiJxVVNawOJbT+LT9RO55//e55afv017Z3fC69jU4gNgjy/Arb9soKs71BcEOmpIRCTO8nO8fPvKudx12Rxef7+FS37yOht2JnbcoDcI/uPTx7JmWztffXQ1Ph01JCKSOGbG9SdN5je3nsSBQIjLf/ImT6zanrDtb2rxMb6ikIvnjeOO82byxKrt3PtKI4DmGhIRSaT5kyt56kunMntcGV9e/A7ff/7dhGy3scXHlKpiAP7mjGlceuw43o1ciS0brlCmIBCRlFJTVsDDnw2PG/z45ff56SsfxHV7zjkamzuYWh0OAjPjO1cew7wJ5eR5PRTmZv6uocyPOhFJO3k5Hr51xVw6u0N859kNVBTlcu0Jk+KyrT2+APu7gn09AghPnveLW06ksbkDb4ZfuB4UBCKSojwe4/tXzWPfgW7ufHwNFYW5XDC3Nubb6R0ojg4CgPLCXI6bNCrm20tF2jUkIikrL8fDT687nuMmjeJvf7OS1ze2xHwbm5rDQTC1qiTm750uFAQiktKK8nJYdOMCplQVc+svG1i+pS2m7/9BSwd5Xg/jRxXG9H3TiYJARFJeeVEuv7zlBGpK87lx0bKYXuhmU7OPyaOLsmIsoD8KAhFJCzVlBfzm1oVUl+ZzwwPLaNgcmzDYFHXoaLZSEIhI2hhbXsBvbj2JMWUF3LBoGcs2jSwMQj2OLXs6mVKtIBARSRtjysJhMLa8gJseXMbSxj3Dfq9tbQcIhHqYqh6BiEh6qYmEQW15ATc/9Dbv7+4Y1vs0toRfN7U6e48YAgWBiKSpmtICfvVXJ1KY5+Vzv2ygIzJb6FD0dw5BtlEQiEjaqi0v5MfXHs/mPZ3c8btVODe06xlsavFRWpDD6OK8OFWYHhQEIpLWFk4bzT9ecBTP/mkn977aOKTXbmrxMbWqGLPsPXQUFAQikgFuOXUKFx1Ty/ee28AfNzYP+nWNzTp0FJIUBGb2t2a21sz+ZGa3JaMGEckcZsZ3P3UM02tK+dLid9ja2jnga7q6Q2zbeyDrB4ohCUFgZnOAzwInAPOAi8xseqLrEJHMUpSXw73XzyfU47hx0TKa9/uPuP7mPRoo7pWMHsEsYIlzrtM5FwReBS5PQh0ikmGmVBWz6KYF7Gjv4rqfLaXNF+h33d7J5hQEyQmCtcBpZjbazIqAC4GJSahDRDLQgrpKfnZjPZv2+Lh+0VLaD3Qfdr1GHTraJ+FB4JxbD3wHeAF4FlgF/NkBwGZ2q5k1mFlDc/PgB39ERE75WBX3XTefd3fu56YHlx32HIPGZh9jywoozoJLUQ4kKYPFzrkHnHPHO+dOA1qBjYdZ537nXL1zrr66ujrxRYpIWjvzqBp+dM1xrG5q55aH3qarO/SR5ze1dKg3EJGso4ZqIj8nAVcAi5NRh4hktgvm1vL9q+axdFMrdz6+9iMnnG1q8WX9ZHO9ktUneszMRgPdwBecc7G90oSISMRlx42nscXHj17ayJzxZXzmlCm0+QK0dXZn/WRzvZISBM65jydjuyKSnW47ezrrtu/j355ez8yxpeTneAGYqh4BoDOLRSQLeDzGDz89jylVxXzh1yv6zj6eksXXKY6mIBCRrFBakMv9188n2OP4jxc3kuMxJmTxdYqjKQhEJGtMrS7hx9cehxlMqiwi16uvQEjeYLGISFKcMbOGH1w9DyO7ZxyNpiAQkaxz+XETkl1CSlG/SEQkyykIRESynIJARCT
LKQhERLKcgkBEJMspCEREspyCQEQkyykIRESynEXPz52qzKwZ2BK1qBxoP2S1wSyLftzf/SqgZYQlH66Woa7X33MDtXMwbY5FG49U41DWG2w70/mz7O/5bP2dHc5nq9/ZwYt+v8nOuYGv7OWcS7sbcP9wlkU/PsL9hnjUN9T1+ntuoHYOps2xaGOi25nOn+VgPrdMaGcsPsv+2qnf2di3MfqWrruGnhzmsicHcT8WBvt+R1qvv+cGaudg2xwLiWxnOn+W/T2frb+zw/1sYyFV2pkKbeyTFruGEsnMGpxz9cmuI56yoY2gdmaSbGgjJK+d6dojiKf7k11AAmRDG0HtzCTZ0EZIUjvVIxARyXLqEYiIZLmMDQIzW2Rmu81s7TBeO9/M1pjZ+2b2IzOzqOe+ZGbvmtmfzOy7sa166OLRTjP7FzPbZmYrI7cLY1/5kGuNy+cZef4OM3NmVhW7iocuTp/lXWa2OvI5Pm9m42Jf+ZBrjUc7v2dmGyJtfdzMKmJf+ZBrjUc7r4p89/SYWezGEmJxSFYq3oDTgOOBtcN47TJgIWDAH4ALIsvPBF4E8iOPazK0nf8C3JHstsW7nZHnJgLPET5PpSrT2giURa3zZeDeTPwsgfOAnMj97wDfydB2zgJmAq8A9bGqNWN7BM6514DW6GVmNs3MnjWz5Wb2RzM76tDXmVkt4f88b7nwv/wvgMsiT/818G3nnD+yjd3xbcXA4tTOlBPHdv4Q+CqQ9MGyeLTRObcvatViMredzzvngpFVlwBJvwRZnNq53jn3bqxrzdgg6Mf9wJecc/OBO4D/Osw644GmqMdNkWUAM4CPm9lSM3vVzBbEtdrhG2k7Ab4Y6WYvMrNR8St1REbUTjO7BNjmnFsV70JHYMSfpZndbWZbgb8Evh7HWkciFr+zvW4m/Fd0KoplO2Mma65ZbGYlwMnAI1G7iPMPt+phlvX+FZUDjAJOAhYAvzOzqZHUTgkxaudPgbsij+8Cvk/4P1fKGGk7zawIuJPwLoWUFKPPEufcncCdZvaPwBeBf45xqSMSq3ZG3utOIAj8OpY1xkIs2xlrWRMEhHs/e51zx0YvNDMvsDzy8AnCX4LR3coJwPbI/Sbg95FRoRXnAAAEV0lEQVQv/mVm1kN4bpDmeBY+RCNup3NuV9Tr/ht4Kp4FD9NI2zkNmAKsivynnACsMLMTnHM741z7YMXidzbaw8DTpFgQEKN2mtmNwEXA2an0x1mUWH+esZPsAZV43oA6ogZqgDeBqyL3DZjXz+veJvxXf+9AzYWR5Z8HvhG5PwPYSuRcjAxrZ23UOl8BfpPsNsajnYess5kkDxbH6bOcHrXOl4BHk93GOLXzfGAdUJ3stsWznVHPv0IMB4uT/g8Vxw9gMbAD6Cb8l/wthP8CfBZYFfml+Xo/r60H1gIfAPf0ftkDecCvIs+tAM7K0Hb+ElgDrCb8F0ptotqTyHYesk7SgyBOn+VjkeWrCc9BMz4TP0vgfcJ/mK2M3FLh6Kh4tPPyyHv5gV3Ac7GoVWcWi4hkuWw7akhERA6hIBARyXIKAhGRLKcgEBHJcgoCEZEspyCQtGRmHQne3s/M7OgYvVcoMhvoWjN7cqCZMs2swsz+JhbbFjkcHT4qacnMOpxzJTF8vxx3cNKyuIqu3cx+DrznnLv7COvXAU855+Ykoj7JPuoRSMYws2oze8zM3o7cToksP8HM3jSzdyI/Z0aW32Rmj5jZk8DzZnaGmb1iZo9G5rb/ddQ88K/0zv9uZh2RidxWmdkSMxsTWT4t8vhtM/vGIHstb3FwErwSM3vJzFZYeC76SyPrfBuYFulFfC+y7t9HtrPazP41hv+MkoUUBJJJ/hP4oXNuAXAl8LPI8g3Aac654wjPvvnNqNcsBG50zp0VeXwccBtwNDAVOOUw2ykGljjn5gGvAZ+N2v5/RrY/4NwwkTlmziZ89jZAF3C5c+54wte++H4kiL4GfOCcO9Y59/dmdh4wHTgBOBa
Yb2anDbQ9kf5k06RzkvnOAY6OmtmxzMxKgXLg52Y2nfAsjrlRr3nBORc9Z/wy51wTgJmtJDxXzOuHbCfAwYn4lgPnRu4v5OC1Dh4G/r2fOguj3ns58EJkuQHfjHyp9xDuKYw5zOvPi9zeiTwuIRwMr/WzPZEjUhBIJvEAC51zB6IXmtmPgf9zzl0e2d/+StTTvkPewx91P8Th/490u4ODa/2tcyQHnHPHmlk54UD5AvAjwtcLqAbmO+e6zWwzUHCY1xvwLefcfUPcrshhadeQZJLnCc+3D4CZ9U73Ww5si9y/KY7bX0J4lxTANQOt7JxrJ3z5yDvMLJdwnbsjIXAmMDmy6n6gNOqlzwE3R+a3x8zGm1lNjNogWUhBIOmqyMyaom63E/5SrY8MoK4jPG04wHeBb5nZG4A3jjXdBtxuZsuAWqB9oBc4594hPBPlNYQvplJvZg2EewcbIuvsAd6IHG76Pefc84R3Pb1lZmuAR/loUIgMiQ4fFYmRyFXPDjjnnJldA1zrnLt0oNeJJJvGCERiZz5wT+RIn72k2OU9RfqjHoGISJbTGIGISJZTEIiIZDkFgYhIllMQiIhkOQWBiEiWUxCIiGS5/wecQOM8/VnmIgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | time
------|------------|------------|------
0 | 5.826065 | 6.018060 | 00:47
1 | 5.041347 | 5.650850 | 00:44
2 | 4.651917 | 4.839034 | 00:47
3 | 4.046178 | 4.601678 | 00:53
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(4, 1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's free up some RAM." ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "del fr_vecs\n", "del en_vecs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the loss is not very interpretable, let's also look at the accuracy. Again, we pad the shorter of the output and the target so that both have the same length before comparing them token by token." ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "def seq2seq_acc(out, targ, pad_idx=1):\n", " bs,targ_len = targ.size()\n", " _,out_len,vs = out.size()\n", " if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", " if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", " out = out.argmax(2)\n", " return (out==targ).float().mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### BLEU metric (see dedicated notebook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In translation, the most commonly used metric is BLEU. It combines the precision of matching n-grams (here, from 1 to 4 tokens) with a brevity penalty that applies when the predictions are, in total, shorter than the targets.\n", "\n", "A great post by Rachael Tatman: [Evaluating Text Output in NLP: BLEU at your own risk](https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "class NGram():\n", " def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n", " def __eq__(self, other):\n", " if len(self.ngram) != len(other.ngram): return False\n", " return np.all(np.array(self.ngram) == np.array(other.ngram))\n", " def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "def get_grams(x, n, max_n=5000):\n", " return x if n==1 else 
[NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "def get_correct_ngrams(pred, targ, n, max_n=5000):\n", " pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)\n", " pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n", " return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "class CorpusBLEU(Callback):\n", " def __init__(self, vocab_sz):\n", " self.vocab_sz = vocab_sz\n", " self.name = 'bleu'\n", " \n", " def on_epoch_begin(self, **kwargs):\n", " self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n", " \n", " def on_batch_end(self, last_output, last_target, **kwargs):\n", " last_output = last_output.argmax(dim=-1)\n", " for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n", " self.pred_len += len(pred)\n", " self.targ_len += len(targ)\n", " for i in range(4):\n", " c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n", " self.corrects[i] += c\n", " self.counts[i] += t\n", " \n", " def on_epoch_end(self, last_metrics, **kwargs):\n", " precs = [c/t for c,t in zip(self.corrects,self.counts)]\n", " len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1\n", " bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)\n", " return add_metrics(last_metrics, bleu)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training with metrics" ] }, { "cell_type": "code", "execution_count": 64, "metadata": { "scrolled": true }, "outputs": [], "source": [ "learn = Learner(data, rnn, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))])" ] }, { "cell_type": "code", "execution_count": 65, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [], 
"text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xd8leX5x/HPlQWEkQAJM0AYQaasgAiKCoiKFtyiddZKq62rtfandrjaam0drQMRtdZZF4oTaQUZgpKAbGSPJEDCCiQh+/79cU5iCEkIcFaS7/v1yotznnGe6+Yk5zr3eO7bnHOIiIgAhAU7ABERCR1KCiIiUk5JQUREyikpiIhIOSUFEREpp6QgIiLllBRERKSckoKIiJRTUhARkXIRwQ7gWMXFxbnExMRghyEiUqekpqbuds7FH+24OpcUEhMTSUlJCXYYIiJ1ipltrc1xaj4SEZFySgoiIlJOSUFERMopKYiISDklBRERKefX0UdmtgU4CJQAxc655Er7fwz81vs0B7jZObfMnzGJiEj1AjEk9Szn3O5q9m0GznDO7TOz84CpwCkBiElERKoQ1OYj59zXzrl93qeLgIRgxiMiEqqe/O865q3P8vt1/J0UHPCFmaWa2eSjHHsj8FlVO8xsspmlmFlKVpb//1NEREKJc45/frmBRZv2+P1a/m4+GumcyzCzNsAsM1vrnJtb+SAzOwtPUjitqhdxzk3F07REcnKy82fAIiKhpqC4lJJSR9NG/m/x92tNwTmX4f03E5gODKt8jJmdDEwDJjrn/J8GRUTqmNyCYgCa1eWkYGZNzax52WNgHLCy0jGdgfeBa5xz6/wVi4hIXZZbUAJAdJT/k4I/r9AWmG5mZdd5wzn3uZn9HMA5NwX4A9AaeNZ73BHDVkVEGrqc8ppCuN+v5bek4JzbBAyoYvuUCo9/CvzUXzGIiNQHeYWepFDn+xREROTEldUUAtF8pKQgIhLiyvoU6nRHs4iI+EZuefOR//sUlBREREJc2ZDUpmo+EhGR8qSg5iMREcktLCEqPIyoCP9/ZCspiIiEuNyCYqID0J8ASgoiIiEvt6AkIP0JoKQgIhLycguKAzIcFZQURERCXm6hmo9ERMRLNQURESmnPgURESmXo9FHIiJSJq9QzUciIuKVW1ASkBlSQUlBRCSkFRaXUlhSGpAFdkBJQUQkpAVygR1QUhARCWk5AZwhFZQURERCWtkCO6opiIhIQBfYASUFEZGQFsi1FEBJQUQkpAVy1TVQUhARCWllfQq6eU1ERMr7FOrFNBdmtsXMVpjZd2aWUsV+M7N/mNkGM1tuZoP9GY+ISF0T6JpCIK5ylnNudzX7zgOSvD+nAM95/xURETx9CuFhRqMArM8MwW8+mgj823ksAmLNrH2QYxIRCRk5BcVER4VjZgG5nr+TggO+MLNUM5tcxf6OwPYKz9O820REhMDOkAr+bz4a6ZzLMLM2wCwzW+ucm1thf1Wpz1Xe4E0okwE6d+7sn0hFREJQbkFJwO5RAD/XFJxzGd5/M4HpwLBKh6QBnSo8TwAyqnidqc65ZOdccnx8vL/CFREJOTkFxTSNCszII/BjUjCzpmbWvOwxMA5YWemwGcC13lFIw4Fs
59wOf8UkIlLX5BUWB7Sm4M8rtQWmeztHIoA3nHOfm9nPAZxzU4BPgfHABiAPuMGP8YiI1Dk5BSUktIwK2PX8lhScc5uAAVVsn1LhsQN+4a8YRETqutz60nwkIiInLtDNR0oKIiIhLKdASUFERIDiklLyi0oDNkMqKCmIiISsvKKyVdfUpyAi0uAFeoEdUFIQEQlZSgoiIlLuh2mz1XwkItLgldUUotXRLCIiuYWBXWAHlBREREKW+hRERKRcTllS0DQXIiKSV6iagoiIeOUUlGAGTSJVUxARafByC4qJjgwnLCww6zODkoKISMgK9AypoKQgIhKycgpKAjocFZQURERCVm5BMdEBvJsZlBREREKWZ9U11RRERATILSxW85GIiHjkFpQQraQgIiLgaT4K5AypoKQgIhKy1KcgIiIAlJY6cgvVfCQiIsChosAvsANKCiIiISkYC+xAAJKCmYWb2VIz+7iKfZ3NbLZ3/3IzG+/veERE6oKyabPr45DU24E11ez7HfC2c24QMAl4NgDxiIiEvDzvqmv1au4jM0sAzgemVXOIA1p4H8cAGf6MR0SkrgjGAjsA/k5BTwJ3A82r2X8/8IWZ3Qo0Bcb6OR4RkTohGAvsgB9rCmZ2AZDpnEut4bArgX855xKA8cCrZnZETGY22cxSzCwlKyvLTxGLiISOnIL613w0EphgZluAt4DRZvZapWNuBN4GcM4tBBoDcZVfyDk31TmX7JxLjo+P92PIIiKhoWz0UdP6MiTVOXePcy7BOZeIpxP5S+fc1ZUO2waMATCz3niSgqoCItLg/ZAU6k9NoUpm9qCZTfA+/TVwk5ktA94ErnfOuUDHJCISanLLmo8CfJ9CQK7mnJsDzPE+/kOF7avxNDOJiEgFuYXFNI4MIzyA6zOD7mgWEQlJnhlSA1tLACUFEZGQlFtQHPD+BFBSEBEJSTkFJQGf9wiUFEREQlJeYeAX2AElBRGRkKTmIxERKZcThFXXQElBRCQk5RWWBPxuZlBSEBEJSTlqPhIREQDnnKdPQc1HIiJSUFxKqQv8vEegpCAiEnJ+WIpTfQoiIg1enncyPN28JiIiPyzFqeYjERHJLQzOAjugpCAiEnKCtcAOKCmIiIScsgV2NHW2iIiU1xSio0K0+cjMuptZI+/jM83sNjOL9W9oIiINU1mfQijXFN4DSsysB/Ai0BV4w29RiYg0YD/UFEI3KZQ654qBi4AnnXN3Au39F5aISMOVU1BCVHgYURGBb+Gv7RWLzOxK4DrgY++2SP+EJCLSsGUfKqR548DXEqD2SeEG4FTgT865zWbWFXjNf2GJiDRcq3ccJKlts6Bcu1ZJwTm32jl3m3PuTTNrCTR3zj3i59hERBqcopJS1uw4QP+OMUG5fm1HH80xsxZm1gpYBrxsZo/7NzQRkYZn/a4cCotL6RfKSQGIcc4dAC4GXnbODQHG1uZEMws3s6Vm9nE1+y83s9VmtsrMNKJJRBq0lenZAEGrKdS2JyPCzNoDlwP3HeM1bgfWAC0q7zCzJOAeYKRzbp+ZtTnG1xYRqVdWpGfTrFEEia2bBuX6ta0pPAjMBDY65xabWTdg/dFOMrME4HxgWjWH3AQ845zbB+Ccy6xlPCIi9dKK9Gz6dmhBWJgF5fq17Wh+xzl3snPuZu/zTc65S2px6pPA3UBpNft7Aj3NbIGZLTKzc2sVtYhIPVQc5E5mqH1Hc4KZTTezTDPbZWbveWsBNZ1zAZDpnEut4bAIIAk4E7gSmFbV9BlmNtnMUswsJSsrqzYhi4jUOeszcygoLqV/QognBeBlYAbQAegIfOTdVpORwAQz2wK8BYw2s8r3NqQBHzrnipxzm4Hv8SSJwzjnpjrnkp1zyfHx8bUMWUSkblnh7WQO1sgjqH1SiHfOveycK/b+/Auo8dPZOXePcy7BOZcITAK+dM5dXemwD4CzAMwsDk9z0qZjKYCISH2x0tvJ3DVIncxQ+6Sw28yu9g4vDTezq4E9
x3NBM3vQzCZ4n84E9pjZamA28Bvn3HG9rohIXbciPZs+QexkhtoPSf0J8DTwBOCAr/FMfVErzrk5wBzv4z9U2O6AX3l/REQarLJO5h+f0iWocdR29NE259wE51y8c66Nc+5CPDeyiYiID2zIyiG/qDSoI4/gxFZe07d7EREfWZEW/E5mOLGkELxGLxGRemZlejZNo8LpFhe8TmY4saTgfBaFiEgD57mTOSaoncxwlI5mMztI1R/+BjTxS0QiIg1McUkpq3cc4Kphwe1khqMkBedc80AFIiLSUG3MyvV0MiccMW9owAV+AVARETnMiiBPl12RkoKISJCtTM8mOiqcrnHBWYKzIiUFEZEgK5suOzzIncygpCAiElQlpY7VGQeCfn9CGSUFEZEgWp95kENFJSHRnwBKCiIiQbVgg2cO0GFdWwU5Eg8lBRGRIJq/PouucU1JaBkd7FAAJQURkaApLC7lm817Oa1HXLBDKaekICISJEu27SOvsITTkpQUREQavPnrdxMeZpzavXWwQymnpCAiEiTzNuxmQEIMLRpHBjuUckoKIiJBkJ1XxIq0/ZyWVONy9wGnpCAiEgRfb9xNqYPTQ6g/AZQURESCYt6G3TRrFMHATrHBDuUwSgoiIkEwb30Ww7u1IjI8tD6GQysaEZEGYOueXLbvPcTpIdafAEoKIiIBN2/9boCQuj+hjJKCiEiAzV+/mw4xjekW1zTYoRxBSUFEJIBKSh1fb9zNaUlxmAV//YTK/J4UzCzczJaa2cc1HHOpmTkzS/Z3PCIiwbQ8bT8H8otD7v6EMoGoKdwOrKlup5k1B24DvglALCIiQTXf258wMoSmtqjIr0nBzBKA84FpNRz2EPBXIN+fsYiIhIL5G3bTt0MLWjdrFOxQquTvmsKTwN1AaVU7zWwQ0Mk5V23Tkq/kF5Xwdsp2nHP+vpSISJVKSx0r07NJ7tIy2KFUy29JwcwuADKdc6nV7A8DngB+XYvXmmxmKWaWkpWVdVzxzPgug7vfXc6nK3Ye1/kiIidqy55ccgtL6NshNJberIo/awojgQlmtgV4CxhtZq9V2N8c6AfM8R4zHJhRVWezc26qcy7ZOZccH398nTOXDEmgd/sW/OmT1eQVFh/Xa4iInIiVGQcA6NOhRZAjqZ7fkoJz7h7nXIJzLhGYBHzpnLu6wv5s51yccy7Re8wiYIJzLsUf8YSHGff/qA8Z2flMmbPRH5cQEanRqoxsIsONnm2bBzuUagX8PgUze9DMJgT6ugCndGvNhAEdmDJ3E9v25AUjBBFpwFZnHKBn2+ZERYTuLWIBicw5N8c5d4H38R+cczOqOOZMf9USKrp3fG8iwoyHPlnt70uJiJRzztPJ3C+E+xOgAd7R3C6mMb8c3YNZq3fx1brj67QWETlWO7Lz2ZdXRN+OodufAA0wKQDceFpXusY15YEZqygsrnK0rIiIT63ydjL3DeFOZmigSaFRRDh/uKAPm3bn8vKCzcEOR0QagJXp2ZhB7/ZKCiHprF5tGNG9NW98uy3YoYhIA7Aq4wDd4poSHRUR7FBq1GCTAsCZJ8WzdU8eWQcLgh2KiNRzqzKyQ/qmtTINOikM6dIKgNSte4MciYjUZ3tzC9mRnU+/EO9khgaeFPp1bEFURBiLt+wLdigiUo+tysgGUE0h1DWKCGdgQiwpW5UURMR/VqbXjZFH0MCTAsCQxJasSs/mUGFJsEMRkXpqVUY2HWObEBsdFexQjqrBJ4WhiS0pLnV8t31/sEMRkXpqdcaBOlFLACUFBnf2zGteXWfzyvRsHv/ie0pLtQ6DiBy7nIJiNu3OpV/H0O9PAAjtAbMBEBsdRc+2zartbP7bF98z5/ssOrduyqVDEgIcnYjUdWt21J3+BFBNAfAMTV2ybd8RtYEd2YeYuy6LiDDjL5+uIftQUZAiFJG6alW6Z+RRXakpKCkAyV1acjC/mHWZBw/b/v6SdEod/OPKQezLK+SJWeuCFKGI1FUrMw4Q
1yyKNs1Dc03mypQUgKGJnpvYKjYhOed4O2U7w7u1Ynz/9lw9vAv/XrilfLyxiEhtrMo4QJ8OMZhZsEOpFSUFoFOrJsQ3b0Tqlh86m7/ZvJete/K4PLkTAL8++yRaRkfxxw9XqdNZRGqloLiE9bsO0q+O9CeAkgIAZsbQxJaH3cT2dsp2mjWK4Lx+7QGIiY7kt+f1ImXrPt5fmh6sUEWkDlm3M4fiUlcn7mQuo6TgNaRLK9L2HWJndj4H84v4dMUOfjSgA02iwsuPuXRwAoM7x6rTWURqZbG39aGujDwCJYVyQxM99yukbN3Lx8t3kF9UyuXJhw9BDQszHpzYj315hTz88WqcUzOSeJSWOqYvTeOT5TtYnraffbmF+v2oR4pLSvlsxQ6KS2q/KFdhcSkvzt/MwE6xdGkd7cfofKvB36dQpnf7FjSJDCdlyz6Wpe0nqU0zBnaKPeK4fh1juOXMHjw9ewMnd4rlmuFdghCthJq3Fm/n3ukrDtvWrFEEJyfE8ODEvvRo0zxIkYkvfPhdBr9+Zxn3/6gP14/sWqtz3k1NI33/If50Ub8608kMqimUiwwPY2CnWD5evoOl2/ZzxdBO1b6Rd57dk9G92vDAjFV8s2lPgCOVUJN1sIBHPlvD8G6t+OS203j+miH87vzeXDokgTU7DjD+H/OZNm+TBijUYe+mpgHwzJyN5BcdfZ60wuJSnpm9gQGdYjmjZ7y/w/MpJYUKhia2ZHdOARFhxoWDOlZ7XHiY8eSkgXRuFc0try8hff+hAEYpoebPn67hUFEJD1/Yn74dYjinbzt+eno37p/Ql5l3jmJUUhwPf7KGq6YtYvvevGCHK8do+948Fm7awxk948k6WMBri7Ye9Zz3lnhqCXeMTapTtQRQUjjMEO/9CmN6tyGuWc03mrRoHMnUa5MpLC7lZ6+m1Orbg9Q/CzbsZvrSdG4+ozs92jQ7Yn+b5o154dpk/nrJyaxMP8B5T83jzW+3Ba2/IWXLXmYsy1B/RwXb9+bx+jdbq/0/me4dbfjwhf0Y2aM1U77aSF5hcbWvV1hcytNfemoJZ9axWgIoKRxmaGJLTu3Wmsmjutfq+B5tmvHkpIGsyjjA/723XH9oDUx+UQm/+2AlXVpHc8tZPao9zsy4fGgnPrv9dPp3jOGe91dw1QvfsHVPbsBidc7x3JyNXP78Qm57cynXvvStarh4/l9uf2sp901fyazVu6rc/25qGqd2a02nVtHcObYnu3MKeXVh9bWF9+twLQGUFA4THRXBm5OHM6RLy1qfM6Z3W341ticffJfBf9dk+jE68becgmKem7ORUX+dzR1vLWV/XmGNx0/5aiObd+fy0MR+NI4Mr/FYgE6tonn9p6fw54v6syI9m3OenMu0eZso8XNfw4H8In72aiqPfr6W8/q354EJfUnduo9znpgb1FpLKPho+Q6WbNtPk8hwHv187RGjixZv2ce2vXnlk2EmJ7ZiVM94np+7idyCI2sLhcWlPD17AwMSYupkLQECkBTMLNzMlprZx1Xs+5WZrTaz5Wb2PzOrk0N5bj6zOx1jmzB17sZghyLHITuviKf+u56Rj3zJo5+vJb55Iz5evoNxT8xl9vdVJ/rNu3N5dvZGfjSgA6OO4Y8/LMy46pTOzPrVKEZ09/Q1XPzc16xI88/0KWt3HmDi0wv4cm0mv7+gD09fOYjrRiQy845R5bWWa1/6lqXb9jW45JBfVMKjn62lT/sWPH75ADZm5fKOt0O5zHupaURHhXNuv3bl2+4cm8Te3EJeWbjliNd8f0kaafsOccfYnnWylgCBqSncDqypZt9SINk5dzLwLvDXAMTjcxHhYfz09K4s3rKPVC3tWad8vnInIx/9kif+u45hXVvx4S9G8t7NI/jgFyOJjY7khpcXc+/0FeQWFLN1Ty5vfbuNO95aymVTvqZRZBi/v6D3cV23fUwTXrwumacmDSRtbx4TnpnP3e8uI/Ngvs/KtnjLXi565mtyCop5
46bh3Hha1/IPqrJay0MTPbWGi579mhGPfMn9M1axaNMev9deQsGL8zeTvv8Qv7+gD+f2a8eQLi15Yta68v6CvMJiPlmxg/H929O00Q+j9wd1bsnoXm2YOncTB/M9N7Fu35vHtHmb+PusdZ5awkl1s5YAYP78dmBmCcArwJ+AXznnLqjh2EHA0865kTW9ZnJysktJSfFtoD6QW1DMiEe+ZHi3Vjx/TfJxv85HyzJoGR3FaUlxPoxOquKcY+zjXxFmxlOTBtGn0l2n+UUlPD5rHS/M20RkeBiFxZ6mhbhmjRjerRXXDO/CKd1an3AcB/KLePrLDby8YDONIsK5dXQPrh+ZSKOIozdJVWf73jwmPrOA2CaRvDV5OG1aNK722Oy8Iv63dhefrdzJ3HVZFBSX0rlVNB/8YiStmob+8pHHI/NAPmf+bQ6nJ8WV/72mbNnLpVMWcte4nvxydBLTl6Zx53+W8dbk4Qyv9D6vSMvmR0/P54ye8WQeLChfM6F3+xY8ekl/Tk448h6nYDOzVOfcUT+c/H3z2pPA3UBt7ty5Efisqh1mNhmYDNC5c2efBedLTRtFcM3wLjwzZwObsnLoFn/kSJSjKSop5d73VxAZEcbsu84kpkmkHyKVMsvSstmYlcsjF/c/IiEANI4M597xvRnbuy3Tl6bTp0MLTu3Wiu7xzXzaNNCicST3ju/NpKGd+NMna/jLZ2v568zviWsWRVyzRsQ3b0Tb5o25ZEgCw7q2OurrHcwv4sZXFlNS6njx+qE1JgTwzOt18eAELh6cQG5BMbNW7+LX7yzjqf+u44GJ/XxVzJDyty++p6iklHvO+6Gml5zYirP7tGXKV5u4clhn3ktNJ6FlE4YlHvl/3j8hhvP6tePzVTsZ0rklvzu/N+P6tKNzHbpzuTp+SwpmdgGQ6ZxLNbMzj3Ls1UAycEZV+51zU4Gp4Kkp+DhUn7luRCJT523ihXmb+cvF/Y/5/GXb93OwoBgK4NnZG7hn/PE1TUjtvJeaRqOIMMaf3L7G44Z1bVWrD+MT1S2+GS9eP5T563ezcNNusg4WkHWwgN05hSzZuo//pGxnXJ+2/Pa8XnSv5ktHSanjtjeXsjErl1d/MoyucU2PKYamjSK4cFBHUrbu5bVvtnHNqYlVDrWty1amZ/NOaho3nd6NxEr/P7899yTGPTGX332wkgUbd3Pb6CTCwqr+AvDkpIEcKiwhNrp+1ab8WVMYCUwws/FAY6CFmb3mnLu64kFmNha4DzjDOVfgx3j8Lr55Iy4ZnMB7S9L41dk9iT/GRTXmrssizODsPm15ecEWfnxKl3rxzSMUFRSXMGNZBuf0bUeLxqFVIzstKe6I5sNDhSW8tGAzz87ewLgn5vLjUzpz+5gkWle6n+aRz9Yw+/ssHr6wHyN6HH8T5B1je/LB0gwe+WwN064betyvE2qcczz08WpaRkfxy9FHDiPu0aY5VwztxJvfbgfgksHVL8HbKCL8hJr4QpXfOpqdc/c45xKcc4nAJODLKhLCIOB5YIJzrl6M57zp9K4UlZTyytdbjvncr9bvZmCnWB6Y0I/wMOPRz9f6PsAQcqiwhEc/X8vUuRs5VBjYm//+tyaT7ENFdWbd7SZR4fzirB7M+c1ZXDmsE69/s43kP/2XkY98yVUvLOKe91fw+w9W8sK8zVw/IpGrT3BOrrhmjfjFWT3475pMvt6w20elCL7nvtrIN5v38utxPav9MnDH2J40jgxjWNdWDfJLWcDvUzCzB81sgvfpY0Az4B0z+87MZgQ6Hl/rFt+McX3a8uqirVWOY67O/rxClqft5/SkeNrFNOZnZ3TjkxU7SKmw8E99snbnASY8PZ/n5mzkz5+u5YzHZvPqoq3lnbn+9l5qGu1aNGbkCXybDob45o14+ML+zLxjFLePSWJY11YcKiph5qqdvLpoK2N6teF35/um2fGGkYl0jG3Cw5+sOebRSPtyC0ndupf3l6SxMj07JIa7zv4+k8dmfs8F
J7fnqmHV9022bdGYN28azt8vGxDA6EJHQGZJdc7NAeZ4H/+hwvaxgbh+oE0e1Z2Zq3bxdsp2bqjljIrzN+zGOcrHvE8e1Y03v93GQ5+sYfrNI8rbNbPzinhxwWbyCoq5d3zvats7Q5Vzjte+2cbDH6+mRZNIXr1xGI0iwnls5lp+/8FKps7dyK1nJdGrfXNaRkcRGx1Js0YRx9WxO2v1LppGhR/RjJJ1sIA567KYPKob4XXs/69MjzbNuGNsz8O25RYUEx0V7rNO8MaR4fzfeb249c2lvLckrXwVwspKSx1Lt+/ni9U7Sdmyj01ZOezLO3y9kR5tmnHRoI5MHNiBhJaB//a9KSuH295cSq92LfjrpScf9f9oUOfa38Ba32jqbD8Y0qUlQxNb8uycjVw4sCMtazGsb9663TRvHMGABM8KTdFREfzmnF7c9c4yPlqewVm92vDS/M28OG+zpzMaiGkSya1jkvxaFl/KPlTE3e8uY+aqXZzRM56/Xz6gfI6pt392KnPWZfHY599z93vLDzsvMtw4rUccU64ZUus23JyCYm57cynFpaW8csOwwxLDh9+lU1LqamwvrosqjqX3lQtObs9LCzbzt5nfc753vH5+UQk7s/PZvDuX/67ZxazVu8g86JlIclDnWM7t155ucU3pFt+UTq2iWbxlLx8sTeexmd/z2MzvGZbYivH923Fuv/a0i6l5ZJQvHMwvYvKrqUSEGVOvGUJ0lD72auLX+xT8IVTvU6hsZXo2Fz27gDG92vLc1YNr/GbinGPEI18ysFMsz109pHx7aaljwjPz2bE/n+JSR/ahIs7p25Y7xvZkylcbmbEsg1duGHZMd9QGy56cAq5+8Vs2ZB7k7nN6ceNpXaus5ZSWOpanZ5N1sIB9eYXszyskfd8hXlm4lSuHda71qK7/LN7Gb99bQdsWjcgrKOHtn59K7/YtcM5x3lPzaBQZzoe/qPGWGPFK3bqPS577mk6tmpBXUMKe3B+m/4iOCufMk+I5p287zjypTY3DqLfvzeODpenMWJbB+swcAAZ3juW8fu2ZOKgDbZr7PkGUljp+9loqX67N5NUbhzGie91qLvSlULlPocHq1zGGX487iUc+W8u7qWlcVk3VG2BDZg47svO5bczhH+5hYcbvz+/Dj6d9wxk947nz7J706+ipSfzl4v6s3XGQ299ayke3nhaUKnltZR0s4MfTFrF1Tx4vXje0xiQWFmZVLm4U3SiC5+Zs5OSEGK6soT24zNspafRo04xXfjKMS579mhteXsz7t4xgb24ha3ce5KEL6+f4e38Y0qUlt43uwbK0bDrENqZDTBPaxzahY2wTBnWOrdW8T+C5i/rWMUncOiaJDZk5fL5yB5+u2MmfPl3DlK828vINQ6u96Wv22kyWp2Xzy9E9at3kV1rq+Mtna5i1ehd//FGfBp0QjoVqCn5UUuq46oVFrEzP5tPbT6dL66qOOy5mAAAQ6klEQVTHjL84fzMPfbya+b89q8oP9/yikir/8DZl5TDx6QV0i2/K2z8/NSSHx+06kM9VLywiY38+L16ffNx/mCWljutf/pZvNu3lrZ8NZ3ANbb4bMnMY+/hX3Du+F5NHdWftzgNc9txC2sc25uSEWGZ8l8G3942pd+PL66o1Ow7w01dS2JdXyJSrhxz2paG4pJTHZ63j2TmeecVuOr0r953f56iveaiwhLveWcYnK3ZwzfAuPDixb52di8hXaltT0CypfhQeZjx+xUDCwow7//Ndteu7zl2XRbf4ptV+26/um1i3+Gb87fIBLEvL5oGPVvssbl/J2H+IK55fyM7sfF75yYlV3cPDjH9eOYi2MY24+bXUGucIeid1O+FhxkWDPH0Gvdq14Plrh7B5dy7vpqZxdp+2SgghpHf7Frx/ywg6t4rmJ/9azIffedYv2J1TwLUvfcuzczZy5bBOXD28My/M28wb32yr8fV2Zudz+fML+XTlDu4b31sJ4RgpKfhZx9gmPHxh
P5Zs21/+baei/KISvtm8h1FJx9cvcE7fdvz8jO688c02Hp+1jqJjWFjcnw4VljBp6iL25BTy7xtP8ckdwbHRUTx/dTLZh4r4xetLqhy+WlxSyvtL0jnrpDaH3Tw4onscf7tsAI0jw/jx8NCcKqUha9uiMW///FSSE1ty+1vf8fDHqzn/H/NI3bqPxy49mb9cfDL3/6gvZ54Uz+8/XMn89VXfO7EiLZuJz8xnU1YO065N5qZR3ZQQjpGSQgBMHOgZivfU/9YfcSNQypZ95BeVMqrn8X+LvmtcTy4a1JF//G89Fz6zgFUZ/pmG+Vh8tS6TbXvzeHLSwGNan+Jo+nRowaOXnMziLft48ONVVVw3i6yDBVyefOTIookDO7L8j+eobTlEtWgcyb9uGMb4/u2YNn8zjSPDef+WEeX9cRHhYfzzykH0iG/Gza+nsiHzIODpO1i8ZS/3TV/BZc9/TURYGO/ePIIxvdsGszh1ljqaA+TBif1YmZ7NdS9/ywMT+nHVKZ5vq3PXZxEZbkfMwngsIsLDeOKKgZzTtx2/+2AlE59ewC1ndueXo5OIighO3p+5ahex0ZF+WbR84sCOrM44wPNzN9GrXYvD7t59O2U7cc2iOKtXmyrPDdb/h9RO48hw/nnlYCYM2MWp3VsfMZqpeeNIXrw+mQuf+Zob/rWYCQM68OF3GaTtO0TjyDDO69eee8f3PuYpZuQH+gsJkJgmkbx/y0hGdI/j3umeKQmKSkqZuy6L5C6tfDJ2+tx+7Zh15ygmDOjAP77cwMRnFpTP9x5IRSWl/G/NLsb0aktEuH9+xe4+txdnnRTP/TNWsXDjHsDTBv2/NZlcPDiBSD9dV/wvPMw4t1+7aoe3JrSMZtp1yWQeKOC5ORvpFt+Mxy8fQMrvzuaJKwYqIZwg/eUEUEyTSF66fig/G9WNVxdt5YrnF7J250Gf3mfQsmkUj18xkOd+PJg1Ow7w0vwtPnvt2lq0aQ8H8os5p6//qu/hYcZTVw4iMa4pt7yeyrY9njHwxaWOy+rIfEZy/AZ2iuWLO0ex6N4x/Psnw7h4cALN/HDzXkOkpBBg4WHGPeN788QVA1iZ4VmY43Q/LKhzXv/2nNO3LS/M28S+3JrXGj5W6fsP8dL8zdXOUzRz1U6aRIb7/aa6Fo0jmXZtMqUObvp3Cm9+u41BnWNJalub5TukruvSuqlfbnhr6JQUguSiQQm89/MR/O783vStYoEXX/j1uJPILSxmig/Xjv5i1U7GPzWPBz9ezX9Sth+xv7TUMWu1ZxqL2t7UdCIS45ryzFWD2ZCVw8asXC4bUv1NgiJydEoKQdQ/IYafnu6/IXM92zbnwoEdeeXrLWQeOLG1fwuKS3jgo1VMfjWVzq2i6dexBc/N3nBEbWFZ2n52HShgnB+bjio7LSmOhyb2o1/HFlwwoOYFc0SkZkoK9dwdY5MoLnE8PXvDcb/G1j25XPrcQl5esIUbRiby7s2n8ptzepGRnc97S9IOO3bmql1EhBljegV2OOBVp3Tm41tPD7kFc0TqGiWFeq5L66ZcPrQTb367je178475/NyCYi6dspCte3J5/poh/PFHfWkUEc6opDgGdorlmdkbym+Yc87xxaqdDO/WmphofTiL1EVKCg3AbaOTMDOe+t/6Yz73tUVbyTpYwMs3DOWcvu3Kt5sZt49JIm3fIaYv8UxLsCEzh027c/066khE/EtJoQFoF9OY607twvtL0tiQeZCcgmLS9x9idcaBGlfFyissZurcTZyeFMeQLkdOU3HmSfGcnBDD097awherdwFwdp92RxwrInWDBvY2EDef2YM3vtnG2MfnHrHv1tE9+PW4k47Y/tqirezJLeSOsVUv5GNm3DY6iZ/+O4UPv8tg5qqdDOgUG5CFU0TEP5QUGohWTaP451WDSNmyj5gmkcQ0iSQ2OpJPV+zk6dkbGNkj7rCpNg4VljB17iZO61F1LaHMmN5t6NuhBX+b+T07D+Rz97lHJhcR
qTuUFBqQ0b3aMrrSqKDTk+JZkZ7Nnf/5js9vH1XeQfz6N1vZnVPI7dXUEsqYGbeNSeJnr6YCHNbvICJ1j/oUGrimjSL4x6RB7M4p4J7py3HOcaiwhClfbWJkj9YMTTz6lNdn925L7/YtOKltc7rHNwtA1CLiL6opCP0TYrhr3En85bO1vJ2ynZyCEnbnFPDsmMG1Oj8szPj3T4ZRUlq3VvETkSMpKQgAN53ejbnrs7h/xmqio8I5tVvrY1oYRzNTitQPaj4SwPNt/++XDaRxZBh7co/elyAi9ZPfawpmFg6kAOnOuQsq7WsE/BsYAuwBrnDObfF3TFK1djGNeeHaZFK37juhRX9EpO4KRPPR7cAaoKqpQG8E9jnnepjZJOBR4IoAxCTVSE5sRXItOpdFpH7ya/ORmSUA5wPTqjlkIvCK9/G7wBjTKtsiIkHj7z6FJ4G7gapXY4GOwHYA51wxkA2o3UJEJEj8lhTM7AIg0zmXWtNhVWw7YlyjmU02sxQzS8nKyvJZjCIicjh/1hRGAhPMbAvwFjDazF6rdEwa0AnAzCKAGGBv5Rdyzk11ziU755Lj4/27xKOISEPmt6TgnLvHOZfgnEsEJgFfOueurnTYDOA67+NLvcfoDigRkSAJ+M1rZvYgkOKcmwG8CLxqZhvw1BAmBToeERH5QUCSgnNuDjDH+/gPFbbnA5cFIgYRETk63dEsIiLlrK414ZtZFrC1wqYYPENZK6u8vabn1T2OA3afYMjVxXesx1W1vzbb6lI5ffFeVn5e9tgXZawpxmM5Tr+zNW+rS+WsS7+zXZxzRx+p45yr0z/A1Npsr+l5DY9T/BXfsR5X1f7abKtL5fTFe1ldOX1RxlAqZ6i/l9Xt1+9s6P7Olv3Uh+ajj2q5vabn1T32hdq+3tGOq2p/bbbVpXL64r2s/Ly+ljPUy1jdfv3OHv15sMoJ1MHmo0AysxTnXHKw4/C3hlDOhlBGUDnrk2CVsT7UFPxparADCJCGUM6GUEZQOeuToJRRNQURESmnmoKIiJRrMEnBzF4ys0wzW3kc5w4xsxVmtsHM/lFxem8zu9XMvjezVWb2V99Gfcxx+ryMZna/maWb2Xfen/G+j/yYY/XLe+ndf5eZOTOL813Ex8dP7+dDZrbc+15+YWYdfB/5McXpjzI+ZmZrveWcbmaxvo/8mGP1Rzkv837ulJqZ7/oefDHkqS78AKOAwcDK4zj3W+BUPLO6fgac591+FvBfoJH3eZt6WMb7gbuC/f75u5zefZ2AmXjug4mrj+UEWlQ45jZgSj0s4zggwvv4UeDRevpe9gZOwjNbRLKvYm0wNQXn3FwqzcBqZt3N7HMzSzWzeWbWq/J5ZtYezx/SQud5J/4NXOjdfTPwiHOuwHuNTP+WomZ+KmPI8WM5n8Cz/kdIdLT5o5zOuQMVDm1KkMvqpzJ+4TzrswAsAhL8W4qj81M51zjnvvd1rA0mKVRjKnCrc24IcBfwbBXHdMQzxXeZNO82gJ7A6Wb2jZl9ZWZD/Rrt8TnRMgL80lsVf8nMWvov1BNyQuU0swl41hFf5u9AT9AJv59m9icz2w78GPgDoccXv7NlfoLn23Uo8mU5fSbgs6SGCjNrBowA3qnQrNyoqkOr2Fb27SoCaAkMB4YCb5tZN29GDzoflfE54CHv84eAv+P5QwsZJ1pOM4sG7sPT7BCyfPR+4py7D7jPzO4Bfgn80cehHjdfldH7WvcBxcDrvozRF3xZTl9rsEkBTy1pv3NuYMWNZhYOlK0WNwPPh2LF6mcCkOF9nAa8700C35pZKZ75SkJlebgTLqNzbleF814APvZnwMfpRMvZHegKLPP+gSYAS8xsmHNup59jPxa++J2t6A3gE0IoKeCjMprZdcAFwJhQ+ZJWia/fS98JdgdMIH+ARCp09ABfA5d5HxswoJrzFuOpDZR19Iz3bv858KD3cU88601bPStj+wrH3Am8Fez30R/lrHTMFkKgo9lP72dShWNu
Bd6th2U8F1gNxAe7bP4sZ4X9c/BhR3PQ/6MC+Ia8CewAivB8w78Rz7fDz4Fl3l+iP1RzbjKwEtgIPF32wQ9EAa959y0BRtfDMr4KrACW4/nm0j5Q5QlkOSsdExJJwU/v53ve7cvxzInTsR6WcQOeL2jfeX+COsLKj+W8yPtaBcAuYKYvYtUdzSIiUq6hjz4SEZEKlBRERKSckoKIiJRTUhARkXJKCiIiUk5JQeoFM8sJ8PWmmVkfH71WiXfW0pVm9tHRZvU0s1gzu8UX1xapTENSpV4wsxznXDMfvl6E+2FSNb+qGLuZvQKsc879qYbjE4GPnXP9AhGfNCyqKUi9ZWbxZvaemS32/oz0bh9mZl+b2VLvvyd5t19vZu+Y2UfAF2Z2ppnNMbN3vfPzv15hLvs5ZXPYm1mOd5K5ZWa2yMzaerd39z5fbGYP1rI2s5AfJulrZmb/M7Ml5plPf6L3mEeA7t7axWPeY3/jvc5yM3vAh/+N0sAoKUh99hTwhHNuKHAJMM27fS0wyjk3CM8soX+ucM6pwHXOudHe54OAO4A+QDdgZBXXaQoscs4NAOYCN1W4/lPe6x91vhrvvDdj8Nw5DpAPXOScG4xn7Y6/e5PS/wEbnXMDnXO/MbNxQBIwDBgIDDGzUUe7nkhVGvKEeFL/jQX6VJiFsoWZNQdigFfMLAnPjJORFc6Z5ZyrOO/9t865NAAz+w7P/DXzK12nkB8mCkwFzvY+PpUf1mt4A/hbNXE2qfDaqcAs73YD/uz9gC/FU4NoW8X547w/S73Pm+FJEnOruZ5ItZQUpD4LA051zh2quNHM/gnMds5d5G2fn1Nhd26l1yio8LiEqv9mitwPnXPVHVOTQ865gWYWgye5/AL4B571DuKBIc65IjPbAjSu4nwD/uKce/4YrytyBDUfSX32BZ71AgAws7JpimOAdO/j6/14/UV4mq0AJh3tYOdcNp4lMu8ys0g8cWZ6E8JZQBfvoQeB5hVOnQn8xDtHP2bW0cza+KgM0sAoKUh9EW1maRV+foXnAzbZ2/m6Gs9U5wB/Bf5iZguAcD/GdAfwKzP7FmgPZB/tBOfcUjyzZk7CszhMspml4Kk1rPUeswdY4B3C+phz7gs8zVMLzWwF8C6HJw2RWtOQVBE/8a7odsg558xsEnClc27i0c4TCSb1KYj4zxDgae+Iof2E2DKmIlVRTUFERMqpT0FERMopKYiISDklBRERKaekICIi5ZQURESknJKCiIiU+382/R8NsQj4IgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
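epoch | train_loss | valid_loss | seq2seq_acc | bleu | time
------|------------|------------|-------------|------|------
0 | 4.004895 | 5.146360 | 0.297541 | 0.233810 | 01:02
1 | 4.265952 | 4.897265 | 0.321518 | 0.269219 | 01:03
2 | 3.971066 | 4.402504 | 0.366261 | 0.277486 | 01:05
3 | 3.240123 | 4.291171 | 0.378903 | 0.286524 | 01:06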
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(4, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So how good is our model? Let's see a few predictions." ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "def get_predictions(learn, ds_type=DatasetType.Valid):\n", " learn.model.eval()\n", " inputs, targets, outputs = [],[],[]\n", " with torch.no_grad():\n", " for xb,yb in progress_bar(learn.dl(ds_type)):\n", " out = learn.model(xb)\n", " for x,y,z in zip(xb,yb,out):\n", " inputs.append(learn.data.train_ds.x.reconstruct(x))\n", " targets.append(learn.data.train_ds.y.reconstruct(y))\n", " outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))\n", " return inputs, targets, outputs" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 100.00% [151/151 00:24<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "inputs, targets, outputs = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": 89, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos quels sont les résultats prévus à court et à long termes de xxunk , et dans quelle mesure ont - ils été obtenus ?,\n", " Text xxbos what are the short and long - term expected outcomes of the ali and to what extent have they been achieved ?,\n", " Text xxbos what were the results , the , , , , , , and and and and and)" ] }, "execution_count": 89, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[700], targets[700], outputs[700]" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos de quel(s ) xxunk ) a - t - on besoin pour xxunk les profits réels de la compagnie pour l'année qui vient ?,\n", " Text xxbos which of the following additional information is necessary to estimate the company 's actual profit for the coming year ?,\n", " Text xxbos what is the the to to to the the ( ( ) ))" ] }, "execution_count": 90, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[701], targets[701], outputs[701]" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos de quelles façons l'expérience et les capacités particulières des agences d'exécution contribuent - elles au projet ?,\n", " Text xxbos what experience and specific capacities do the implementing organizations bring to the project ?,\n", " Text xxbos what are the key and and and and and and of of of of of of ?)" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[2513], targets[2513], outputs[2513]" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos qu'est - ce 
que la maladie de xxunk - xxunk ( mcj ) ?,\n", " Text xxbos what is xxunk - xxunk disease ( cjd ) ?,\n", " Text xxbos what is the xxunk ( ( ) ))" ] }, "execution_count": 92, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[4000], targets[4000], outputs[4000]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The predictions usually start well, but then fall into repeating the same words (such as *and*, *of* or *the*) towards the end of the question." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Teacher forcing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to help training is to feed the decoder the real targets instead of its own predictions: if it starts with the wrong words, it is very unlikely to produce the right translation. We do this all the time at the beginning of training, then progressively reduce the amount of teacher forcing. The callback below passes the targets to the model during training and, at the start of each epoch, sets the probability of teacher forcing `pr_force` to `1 - epoch/end_epoch`, so it decreases linearly from 1 and is effectively off from epoch `end_epoch` onwards." ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "class TeacherForcing(LearnerCallback):\n", "    \n", "    def __init__(self, learn, end_epoch):\n", "        super().__init__(learn)\n", "        self.end_epoch = end_epoch\n", "    \n", "    def on_batch_begin(self, last_input, last_target, train, **kwargs):\n", "        # during training, pass the targets to the model along with the inputs\n", "        if train: return {'last_input': [last_input, last_target]}\n", "    \n", "    def on_epoch_begin(self, epoch, **kwargs):\n", "        # probability of teacher forcing decreases linearly with the epoch\n", "        self.learn.model.pr_force = 1 - epoch/self.end_epoch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will add the following code at the end of each step of the decoding loop in our `forward` method:\n", "\n", "```\n", "if (targ is not None) and (random.random()<self.pr_force):\n", "    if i>=targ.shape[1]: continue\n", "    dec_inp = targ[:,i]\n", "```\n", "Additionally, `forward` now takes the targets as an extra argument, `targ`."
] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_tf(nn.Module):\n", "    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):\n", "        super().__init__()\n", "        self.nl,self.nh,self.out_sl = nl,nh,out_sl\n", "        self.bos_idx,self.pad_idx = bos_idx,pad_idx\n", "        self.em_sz_enc = emb_enc.embedding_dim\n", "        self.em_sz_dec = emb_dec.embedding_dim\n", "        self.voc_sz_dec = emb_dec.num_embeddings\n", "        \n", "        self.emb_enc = emb_enc\n", "        self.emb_enc_drop = nn.Dropout(0.15)\n", "        self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl,\n", "                              dropout=0.25, batch_first=True)\n", "        self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False)\n", "        \n", "        self.emb_dec = emb_dec\n", "        self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl,\n", "                              dropout=0.1, batch_first=True)\n", "        self.out_drop = nn.Dropout(0.35)\n", "        self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec)\n", "        self.out.weight.data = self.emb_dec.weight.data  # tie the output layer to the decoder embedding\n", "        self.pr_force = 0.  # probability of teacher forcing, set by the TeacherForcing callback\n", "    \n", "    def encoder(self, bs, inp):\n", "        h = self.initHidden(bs)\n", "        emb = self.emb_enc_drop(self.emb_enc(inp))\n", "        _, h = self.gru_enc(emb, h)\n", "        h = self.out_enc(h)\n", "        return h\n", "    \n", "    def decoder(self, dec_inp, h):\n", "        emb = self.emb_dec(dec_inp).unsqueeze(1)\n", "        outp, h = self.gru_dec(emb, h)\n", "        outp = self.out(self.out_drop(outp[:,0]))\n", "        return h, outp\n", "    \n", "    def forward(self, inp, targ=None):\n", "        bs, sl = inp.size()\n", "        h = self.encoder(bs, inp)\n", "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", "        \n", "        res = []\n", "        for i in range(self.out_sl):\n", "            h, outp = self.decoder(dec_inp, h)\n", "            res.append(outp)\n", "            dec_inp = outp.max(1)[1]\n", "            if (dec_inp==self.pad_idx).all(): break\n", "            # teacher forcing: with probability pr_force, feed the true token instead of the prediction\n", "            if (targ is not None) and (random.random()<self.pr_force):\n", "                if i>=targ.shape[1]: continue\n", "                dec_inp = targ[:,i]\n", "        return torch.stack(res, dim=1)\n", "\n", "    def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh)" ] }, { "cell_type": "code", "execution_count": 90, "metadata": {}, "outputs": [], "source": [ "emb_enc = torch.load(model_path/'fr_emb.pth')\n", "emb_dec = torch.load(model_path/'en_emb.pth')" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "rnn_tf = Seq2SeqRNN_tf(emb_enc, emb_dec, 256, 30)\n", "\n", "learn = Learner(data, rnn_tf, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n", "                callback_fns=partial(TeacherForcing, end_epoch=3))" ] }, { "cell_type": "code", "execution_count": 74, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": 75, "metadata": { "collapsed": true }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmUXGW57/HvU1U9dzrppDsh6ZDREEVkSouACiigOBwBFa8svaJwZek9guhVj2exjnp04XBwOkfXvV5URK+CS6YjoCg4MBxkCoRACINABjKQ7kzdSXdVdw3P/aN2dYrQne6ka1dV1/591qpVVXvvqvd5U5166n3fvd/X3B0REYmuWKUDEBGRylIiEBGJOCUCEZGIUyIQEYk4JQIRkYhTIhARiTglAhGRiFMiEBGJOCUCEZGIS1Q6gIno6OjwRYsWVToMEZEp5ZFHHtnu7p3jHTclEsGiRYtYuXJlpcMQEZlSzGzDRI5T15CISMQpEYiIRJwSgYhIxCkRiIhEnBKBiEjEKRGIiEScEoGISMQpEYiIVKGBoQxfvXUt67cPhF6WEoGISBW6Y+1LXH3fOnr3DoVelhKBiEgVuunRzcxvb2LFgvbQy1IiEBGpMj39Ke57bjvnHtdFLGahl6dEICJSZW5ZvYWcwznHdZWlPCUCEZEqc/OqzRwzfzpLO1vLUp4SgYhIFXl22x6e3NJfttYAKBGIiFSVm1dtJh4z/uGYeWUrU4lARKRK5HLOb1dt5pRlHXS0NpStXCUCEZEq8eC6nWzpS3Hu8fPLWq4SgYhIlbh51SZa6uOc+Zo5ZS1XiUBEpAqk0lluf+IlzjpqLk318bKWHVoiMLOrzazHzNaMsu9zZuZm1hFW+SIiU8mfntrGnqEM7z2+fGcLFYTZIrgGOGv/jWZ2OHAmsDHEskVEpgx35zcrNzGnrYETl8wqe/mhJQJ3vwfYOcqu7wFfADysskVEppIbH93MPc/28tGTFxMvw5QS+yvrGIGZvQfY7O6ry1muiEi1er53L1/67RpOXDKTi09ZUpEYEuUqyMyagcuBt03w+IuBiwEWLFgQYmQiIpUxlMlyybWraEjE+P5/O64irQEob4tgKbAYWG1m64H5wKNmdthoB7v7Ve7e7e7dnZ2dZQxTRKQ8vvH7p1m7tZ9vn3cMh01vrFgcZWsRuPsTwOzC8yAZdLv79nLFICJSLe5cu41r/raej71xEaeX+bqB/YV5+uh1wP3AcjPbZGYXhVWWiMhU8lJfis/fsJrXzmvji+94daXDCa9F4O7nj7N/UVhli4hUsxsf3cTuwTQ3ffJkGhLlvXhsNLqyWESkzNZtH2BOWwNLyrTewHiUCEREymzDjgEWzmqpdBgjlAhERMps3fZBFisRiIhE096hDNv3DrGwo7nSoYxQIhARKaMNOwYAWKQWgYhINK3fPggoEYiIRNb6oEWwcJa6hkREImnDjgE6pzXQ0lC2iR3GpUQgIlJG66vsjCFQIhARKav1OwaqqlsIlAhERMpmYChDz54hFnWoRSAiEkkbdlTfGUOgRCAiUjYbqvCMIVAiEBEpm3WFi8nUNSQiEk0btg/S0dpAaxWdOgpKBCIiZbN+xwCLqqxbCJQIRETKZn2VTT9doEQgIlIGg8MZtvUPsbiKZh0tUCIQESmDjTvzp46qRSAiElHrt1ff9NMFSgQiImWwPriYrJoWpClQIhARKYP12weY1VJPW2NdpUN5hdASgZldbWY9ZramaNvXzOxxM3vMzO4ws3lhlS8iUk2qcbK5gjBbBNcAZ+237Up3P9rdjwVuA74UYvkiIlVjw47BqruiuCC0RODu9wA799vWX/S0BfCwyhcRqRbJ4Sxb+1JVOVAMUPbrnM3sCuAjQB/wlnKXLyJSbvtOHY1e19Co3P1ydz8c+BXwqbGOM7OLzWylma3s7e0tX4AiIiVWWKd4cdS6hibgWuB9Y+1096vcvdvduzs7O8s
YlohIaRWuIVg4U4kAM1tW9PQ9wNPlLF9EpBLW7xikvbmO6c3Vd+oohDhGYGbXAacBHWa2Cfgy8E4zWw7kgA3AJ8IqX0SkWqzfPlC1ZwxBiInA3c8fZfNPwypPRKRabdgxwBuWzKp0GGPSlcUiIiHq2ZNia3+qas8YAiUCEZFQffm3T1IXj/EPx1TvRApKBCIiIbn9ia3cvuYlLjtjGUs7WysdzpiUCEREQrB7cJh/+e2THNXVxsVvXlLpcA6oulZQFhGpEV+9bS27B4f5xYUnkIhX92/u6o5ORGQK+uszPdz06GY+edpSjpzXVulwxqVEICJSQntSaS6/6QmWzW7lU299VaXDmRB1DYmIlNAP/vIcW/tT3PjJk2lIxCsdzoSoRSAiUiLuzm2rt3DGa+Zw/IL2SoczYUoEIiIl8lzPXrb0pXjL8tmVDuWgKBGIiJTIXc/kp8w/dfnUmjFZiUBEpETufraXZbNb6ZrRVOlQDooSgYhICQwMZXho3U5Om2KtAVAiEBEpiQde2MFwNsepR0yt8QFQIhARKYm7numlqS7O6xdPnbOFCpQIREQmyd2569keTl46a8pcO1BMiUBEZJLWbR/gxZ3JKTk+AEoEIiKTdvezwWmjU3B8AJQIREQm7e5ne1nS0cKCKl6F7ECUCEREJiGVznL/8zs45Yip2S0ESgQiIpPy4LqdDGVyU3Z8AJQIREQm5a5nemhIxDhxyaxKh3LIQksEZna1mfWY2ZqibVea2dNm9riZ3WxmM8IqX0SkHO5+tpcTl8yisW7qnTZaEGaL4BrgrP223Qkc5e5HA88C/xxi+SIiodqwY4AXegc4dQqPD0CIicDd7wF27rftDnfPBE8fAOaHVb6ISJjcna/d9hT1iRhve+2cSoczKZUcI7gQuL2C5YuIHLLrV27iT09t4wtvX8789ql52mhBRRKBmV0OZIBfHeCYi81spZmt7O3tLV9wIiLjeHHnIP9665OctGQWF75xcaXDmbSyJwIzuwB4N/Ahd/exjnP3q9y92927Ozundv+biNSObM75X79ZTcyMb3/gGGIxq3RIk1bWxevN7Czgn4BT3X2wnGWLiJTCT+59gYfW7+Q75x0z5RagGUuYp49eB9wPLDezTWZ2EfBDYBpwp5k9ZmY/Cqt8EZFSe2prP9+541nOeu1hvPf4rkqHUzKhtQjc/fxRNv80rPJERML2L/+5hramOr7+3tdhNvW7hAp0ZbGIyAS9sH2As46aw8yW+kqHUlJKBCIiE5QcztJcX9ah1bJQIhARmQB3J5nOTumpJMaiRCAiMgFDmRwATUoEIiLRNDicBaC5XolARCSSkul8IlCLQEQkopJBi6BRLQIRkWhKRb1FYGZLzawheHyamV2qRWVEJErUNQQ3AlkzexX5q4MXA9eGFpWISJUpDBY31ddeR8pEa5QLFpQ5F/i+u38GmBteWCIi1aUwRtBUF90LytJmdj5wAXBbsK0unJBERKrPyBhBhAeLPwacBFzh7uvMbDHwy/DCEhGpLrU8RjChNo67rwUuBTCzdmCau38zzMBERKrJvq6h2ksEEz1r6C4zazOzmcBq4Gdm9t1wQxMRqR6FFkFjhAeLp7t7P/Be4GfuvgI4I7ywRESqS3I4S8ygPh7dRJAws7nAB9g3WCwiEhnJdH4K6lpakKZgoongq8Afgefd/WEzWwL8PbywRESqS61OQQ0THyy+Hri+6PkLwPvCCkpEpNqkhrM1eTEZTHyweL6Z3WxmPWa2zcxuNLP5YQcnIlItkulsTZ4xBBPvGvoZcAswD+gCbg22iYhEghIBdLr7z9w9E9yuATpDjEtEpKoMDtfuGMFEE8F2M/uwmcWD24eBHQd6gZldHXQlrSnadp6ZPWlmOTPrnkzgIiLllEpna3J1Mph4IriQ/KmjLwFbgfeTn3biQK4Bztpv2xry1yLcM/EQRUQqLzmcrcl5hmDiZw1tBN5TvM3MLgO+f4DX3GNmi/bb9lTw2oO
NU0Skomr59NHJnAv12ZJFISJS5VIaLB5VqD/rzexiM1tpZit7e3vDLEpEZFzJYSWC0XjJohjtzd2vcvdud+/u7NQJSiJSOe7OYDqiYwRmtofRv/ANaAolIhGRKjOUyeFem4vSwDiJwN2nHeobm9l1wGlAh5ltAr4M7AR+QP4ahN+Z2WPu/vZDLUNEpBxSNbwoDUzwrKFD4e7nj7Hr5rDKFBEJQy2vTgaTGyMQEYmEkdXJarRrSIlARGQcI6uTqUUgIhJNtbxeMSgRiIiMq9AiiPpcQyIikVVoEahrSEQkokbOGlKLQEQkmmr9OgIlAhGRcWiwWEQk4gbVNSQiEm2p4Sxm0JCoza/M2qyViEgJFRaur9VFtZQIRETGkazhRWlAiUBEZFzJ4VzNXkMASgQiIuNKpjM1O1AMSgQiIuOq5WUqQYlARGRcyRpephKUCERExpVM59QiEBGJspS6hkREok1dQyIiETc4nNXpoyIiUZbSBWUiItHl7iTT2ZpdnQxCTARmdrWZ9ZjZmqJtM83sTjP7e3DfHlb5IiKlkM462ZxrjOAQXQOctd+2LwJ/dvdlwJ+D5yIiVauwOpnGCA6Bu98D7Nxv89nAz4PHPwfOCat8EZFSqPXVyaD8YwRz3H0rQHA/u8zli4gclMHC6mT1tTukWrU1M7OLzWylma3s7e2tdDgiElG1vkwllD8RbDOzuQDBfc9YB7r7Ve7e7e7dnZ2dZQtQRKRYcmSZykSFIwlPuRPBLcAFweMLgN+WuXwRkYOiMYJJMLPrgPuB5Wa2ycwuAr4JnGlmfwfODJ6LiFStKHQNhdbWcffzx9h1elhlioiU2r6uoaodUp202q2ZiEgJFFoEuo5ARCSikhojEBGJtkIiaNZZQyIi0VToGmpI1O7XZe3WTESkBFLpLI11MWIxq3QooVEiEBE5gGSNr0UASgQiIgc0WOPrFYMSgYjIASXTWRpreC0CUCIQETmglFoEIiLRVuvLVIISgYjIASXT2Zq+qhiUCEREDiipriERkWhLprM1vXA9KBGIiByQWgQiIhGnMQIRkYhL6awhEZHoSmdzpLOuriERkajatzqZEoGISCSlIrA6GSgRiIiMKQqrk4ESgYjImNQ1JCIScYXVyZQIQmBmnzazNWb2pJldVokYRETGo66hkJjZUcDHgROAY4B3m9mycschIjKekRaBEkHJvQZ4wN0H3T0D3A2cW4E4REQOKCpjBIkKlLkGuMLMZgFJ4J3AylAK2tzHxp2DxGNGImbEgnt3yLmTcyebCx7nnKw72ZwHz8HJ73N3cg7u4PjI+7szsi8X3BsQM4jFDDPDADPIP8o/Ln59gVnwOjNiZsRjRmNdjIZEnIa6GI11cepiseC4/P6YQSIeIxEzEnEjEYuN1DNmBMfYy8qEfCyF19v+O0VkRFRaBGVPBO7+lJl9C7gT2AusBjL7H2dmFwMXAyxYsOCQyvr1wxv55QMbDz3YCIjH9iWFAiOfPOoTMRoSQTJKxKhP5BNNPGYjCagunt9en4jREDyuixduNvK4oS5GfdGx9fEYiZH9+fvGuvjLyht5XhejMREnFlPSkvJKpaNxHUElWgS4+0+BnwKY2deBTaMccxVwFUB3d7fvv38iLn3rMv77iYvI5HJkc04ml//FbxD8ajbiwS/mfV+I+Xsjv63wC7xwD1D8dVR4n1jRr/5CayMXtBgKwRdaAI6/ooXgDtmgZeIO6VyOoXSOoUyWVDpHKpMlm823WgqtkHx9cmSy+bplsjkyuaCFUtTCGeXflmyOYH/+NYzElpfLOelsjqFM4ZZlKJ0j6x6Ul2M4k2NgKMNQJsdwNv98OJN/v3RhWzb3spbPZNTFjcZEnIa6OI11MZrq4jQ3JJjRVMf0pjpmNOfvpzUmmNZYR2tDgtbG/P55M5robG1QMpGDUugaqvW5hiqSCMxstrv3mNkC4L3ASWGUM7utkdltjWG8tRyEbM5HksRQNstwJj9/SyZIFOls0f5
C4ktnGc7m74cy+ftUUWIcSmdJZbLsSWXYPTjMhh0D7E6m6U+mGSX3AVAfjzF3RiNdM5qYPa2BGc31zGiuoz24n95UR1tTHW2NdbQ1JZjeVEdDora/AOTAksM5QC2CsNwYjBGkgX90910VikPKIB4zmurjwYBbXahl5XLOYDrLnlSavakMe4Yy7BoYZvPuJJt3JdkU3D+ycRe7B9LsGXpFr+TLNCRiTA9aHNOb6pjZUs+s1gY6WutHHs+eFtzaGmmpj2vcpYYMpjPUJ2LEa7wlWamuoTdXolypfbGY5buEGhIwffzj09kcfck0uweH6U9l6AtaFf2pDP3JNH3JNH2DwX0yzcadgzy6cTc7B4ZGbXk01cXpam9iSUcLSzpbWdLZwpKOFuYGXVP1CV3DOZWkIrAoDVSuRSBSFeriMTpaG+hobTio1+Vyzu5kmu17h+jdM0TPnhQ9/UP07Bli485BXtg+wF+f6SGdfXm2mNVSz+y2Rg5ra2DejCa62pvomtHE/PYm5rQ10jmtQd1RVSSZViIQkTHEYsbMlnz30BFzpo16TCabY9OuJOt2DLCtL8W2/iG27UnR059ia1+KVS/uZvdg+hWvm9FcF3Q3NTK/vYnDZzaP3M+d3kh7c33N91lXi2Q6V/PXEIASgUhoEvEYizpaWNTRMuYxe4cybAnGLbb1p+gJWhe9e4Z4qX+IO9duY8fA8Cte11Ifpz1IRJ1Bi6ZzWv42e1q+tTFvRhOzWup1ptQkRGG9YlAiEKmo1oYER8yZNmarAmBgKMOmXUle3DnItj0pdg0Ms3Mgza7BYXYMDLO1L8Xjm/vYsfeV4xaFM6XmTd/XDdXVnu+Kmj+jmcOmN2rc4gBS6axaBCJSeS0NCZYfNo3lh42dLCB/mu6uwWG29afYujvFlr7kyNlSW3YnuffvvfTsGXrFFe2HtTWOjFMsmNnM/JnNLAhuh7U1RrpFMTicobm+9r8ma7+GIhERj9nIwPdr541+ytRwJsfWvpefSrtpV5LNuwdZuWEXt6ze8rJWRWNdjMUd+bOflna2snzONE5YPJPOaQc3uD5VJdM5ZraoRSAiNaQ+EWPhrBYWzhp93CKdzbFld5KNOwfZsGOQddsHeKF3L09s6uP2J7aOJIlls1s5aeksTloyi6WzW2muj9Ncn6C5Pj8tSK1cS6GuIRGJnLr4vkTx5v0mh0+lszzz0h7uf2EHf3t+Bzc8solf3L9hlPcwDm9vZnFHS/7W2cLCmS3Mm9HIvBlNU+qMp/xgce2PoSgRiMiENNbFOebwGRxz+Aw+cepS0tkcj2/qY2tfksGhLAPDGQaHs/QHF96t2z7Afc9vJ5XOvex9ZrbUM29GI4tmtbBs9jSWzWnlVbNbWTSrpeoGrpPprMYIRETGUhePsWJhO9A+5jG5nPNSf4qNOwfZ2pdky+4Um3fnxyVWb9rNbY9vHTm2Ph7j6PnT6V40kxMWt7Ni4UymN4U7Jcl4kunslGrBHColAhEJTSxmI9c0jGZwOMMLvQM817OXtVv7eXj9Tn5y7wv86G7HDFYsaOe87vm86+h5+WlDyqgwWaKuIxARCVFzfYKjuqZzVNd0zjmuC8j3y696cRcPrdvJrau38E83PsFXblnLu46eywe6D6d7YXtZTmndtzpZdXVXhUGJQESqSlN9nJOXdnDy0g4+ffoyVr24m+tXvsitq7dywyObmDu9kXccNZd3vu4wjl8QXlKIyupkoEQgIlXMzDh+QTvHL2jnX959JHc8uY3fPbGVXz64gavvW8ectgbOObaLT5y6lPaW+pKWHZXVyUCJQESmiOb6BOcc18U5x3WxJ5XmL0/38LvHt/Lje1/g2oc2cslbX8VHTlpUsi/ufauT1f7XZO13folIzZnWWMfZx3Zx1Ue6+cNlp9C9sJ2v//5pzvju3dyyegtegvVRB4ejM0ZQ+zUUkZp2xJxp/OxjJ/DLi97AtMY6Lr1uFW/61l+54ndrWbVx1yEnhcIYgbq
GRESmiDct6+C2S97EbY9v4T9Xbeaav63nx/euo2tGE+86ei4fesOCMafWGE1hjECDxSIiU0g8Zpx9bBdnH9tF32CaO5/axu+f2MrV/7WOH9/7Ame+Zg4XvWkxJyyeOe58SPtOH1UiEBGZkqY31/H+FfN5/4r5bOtP8Yv71/OrBzdyx9ptHNXVxsffvIR3vW4uifjoPeRROn1UYwQiUvPmtDXy+be/mvu/eDpXnHsUyeEsn/71Y5zx3bv5zcMvMpzJveI1UWoRKBGISGQ01cf50BsWcudnTuVHH15Ba2OCL9z4OKdd+VeuuW8dazb30dOfIpPNRapFUJGuITP7DPA/AAeeAD7m7qlKxCIi0ROLGWcddRhvf+0c7n62lx/+5Tm+cuvakf1m+UnwQGcNhcLMuoBLgSPdPWlmvwE+CFxT7lhEJNrMjNOWz+bUIzp5auseNu4cpHfvEL178rfZ0xqoG2MMoZZUarA4ATSZWRpoBrZUKA4REcyMI+e1ceS8tkqHUhFlT3Xuvhn4NrAR2Ar0ufsd5Y5DRETyyp4IzKwdOBtYDMwDWszsw6Mcd7GZrTSzlb29veUOU0QkMirR+XUGsM7de909DdwEnLz/Qe5+lbt3u3t3Z2dn2YMUEYmKSiSCjcCJZtZs+Uv7TgeeqkAcIiJCZcYIHgRuAB4lf+poDLiq3HGIiEheRc4acvcvA1+uRNkiIvJytX+CrIiIHJASgYhIxFkpVvIJm5n1AhuKNk0H+vY7bCLbip+P9bgD2D7JkEeL5WCPG2vfePWcSJ1LUccDxXgwx020nlP5sxxrf1T/Zg/ls9Xf7MQVv99Cdx//tEt3n3I34KpD2Vb8/ACPV4YR38EeN9a+8eo5kTqXoo7lrudU/iwn8rnVQj1L8VmOVU/9zZa+jsW3qdo1dOshbrt1Ao9LYaLvd6Djxto3Xj0nWudSKGc9p/JnOdb+qP7NHupnWwrVUs9qqOOIKdE1VE5mttLduysdR5iiUEdQPWtJFOoIlavnVG0RhCkK1zREoY6getaSKNQRKlRPtQhERCJOLQIRkYir2URgZlebWY+ZrTmE164wsyfM7Dkz+49gTqTCvkvM7Bkze9LM/q20UR+8MOppZl8xs81m9lhwe2fpIz/oWEP5PIP9nzMzN7OO0kV88EL6LL9mZo8Hn+MdZjav9JEfdKxh1PNKM3s6qOvNZjaj9JEfdKxh1PO84LsnZ2alG0soxSlZ1XgDTgGOB9YcwmsfAk4CDLgdeEew/S3An4CG4PnsGq3nV4DPVbpuYdcz2Hc48Efy16l01FodgbaiYy4FflSLnyXwNiARPP4W8K0aredrgOXAXUB3qWKt2RaBu98D7CzeZmZLzewPZvaImd1rZq/e/3VmNpf8f577Pf8v/wvgnGD3J4FvuvtQUEZPuLUYX0j1rDoh1vN7wBfIr59dUWHU0d37iw5toXbreYe7Z4JDHwDmh1uL8YVUz6fc/ZlSx1qziWAMVwGXuPsK4HPA/x7lmC5gU9HzTcE2gCOAN5vZg2Z2t5m9PtRoD91k6wnwqaCZfbXlFxOqRpOqp5m9B9js7qvDDnQSJv1ZmtkVZvYi8CHgSyHGOhml+JstuJD8r+hqVMp6lkyl1iwuOzNrJb8AzvVFXcQNox06yrbCr6gE0A6cCLwe+I2ZLQmydlUoUT3/D/C14PnXgO+Q/89VNSZbTzNrBi4n36VQlUr0WeLulwOXm9k/A5+iymb+LVU9g/e6HMgAvypljKVQynqWWmQSAfnWz253P7Z4o5nFgUeCp7eQ/xIsblbOB7YEjzcBNwVf/A+ZWY783CDVtJbmpOvp7tuKXvdj4LYwAz5Ek63nUvLLpa4O/lPOBx41sxPc/aWQY5+oUvzNFrsW+B1VlggoUT3N7ALg3cDp1fTjrEipP8/SqfSASpg3YBFFAzXA34DzgscGHDPG6x4m/6u/MFDzzmD7J4CvBo+PAF4kuBajxuo5t+iYzwC/rnQdw6jnfsesp8K
DxSF9lsuKjrkEuKHSdQypnmcBa4HOStctzHoW7b+LEg4WV/wfKsQP4DpgK5Am/0v+IvK/AP8ArA7+aL40xmu7gTXA88APC1/2QD3wy2Dfo8Bba7Se/4/86nGPk/+FMrdc9SlnPfc7puKJIKTP8sZg++Pk56DpqsXPEniO/A+zx4JbNZwdFUY9zw3eawjYBvyxFLHqymIRkYiL2llDIiKyHyUCEZGIUyIQEYk4JQIRkYhTIhARiTglApmSzGxvmcv7iZkdWaL3ygazga4xs1vHmynTzGaY2f8sRdkio9HpozIlmdled28t4fslfN+kZaEqjt3Mfg486+5XHOD4RcBt7n5UOeKT6FGLQGqGmXWa2Y1m9nBwe2Ow/QQz+5uZrQrulwfbP2pm15vZrcAdZnaamd1lZjcEc9v/qmge+LsK87+b2d5gIrfVZvaAmc0Jti8Nnj9sZl+dYKvlfvZNgtdqZn82s0ctPxf92cEx3wSWBq2IK4NjPx+U87iZ/WsJ/xklgpQIpJb8O/A9d3898D7gJ8H2p4FT3P048rNvfr3oNScBF7j7W4PnxwGXAUcCS4A3jlJOC/CAux8D3AN8vKj8fw/KH3dumGCOmdPJX70NkALOdffjya998Z0gEX0ReN7dj3X3z5vZ24BlwAnAscAKMztlvPJExhKlSeek9p0BHFk0s2ObmU0DpgM/N7Nl5GdxrCt6zZ3uXjxn/EPuvgnAzB4jP1fMf+1XzjD7JuJ7BDgzeHwS+9Y6uBb49hhxNhW99yPAncF2A74efKnnyLcU5ozy+rcFt1XB81byieGeMcoTOSAlAqklMeAkd08WbzSzHwB/dfdzg/72u4p2D+z3HkNFj7OM/n8k7fsG18Y65kCS7n6smU0nn1D+EfgP8usFdAIr3D1tZuuBxlFeb8A33P3/HmS5IqNS15DUkjvIz7cPgJkVpvudDmwOHn80xPIfIN8lBfDB8Q529z7yy0d+zszqyMfZEySBtwALg0P3ANOKXvpH4MJgfnvMrMvMZpeoDhJBSgQyVTWb2aai22fJf6l2BwOoa8lPGw7wb8A3zOw+IB5iTJcBnzWzh4C5QN94L3D3VeRnovwg+cVUus1sJfnWwdPBMTt9BbZIAAAAcElEQVSA+4LTTa909zvIdz3db2ZPADfw8kQhclB0+qhIiQSrniXd3c3sg8D57n72eK8TqTSNEYiUzgrgh8GZPrupsuU9RcaiFoGISMRpjEBEJOKUCEREIk6JQEQk4pQIREQiTolARCTilAhERCLu/wMXM/+Y9ZMneAAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 92, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "<table>\n", "  <thead>\n", "    <tr>\n", "      <th>epoch</th>\n", "      <th>train_loss</th>\n", "      <th>valid_loss</th>\n", "      <th>seq2seq_acc</th>\n", "      <th>bleu</th>\n", "      <th>time</th>\n", "    </tr>\n", "  </thead>\n", "  <tbody>\n",
"    <tr>\n", "      <td>0</td>\n", "      <td>2.305473</td>\n", "      <td>5.401867</td>\n", "      <td>0.195743</td>\n", "      <td>0.094855</td>\n", "      <td>01:25</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>1</td>\n", "      <td>2.663129</td>\n", "      <td>4.858545</td>\n", "      <td>0.372653</td>\n", "      <td>0.335771</td>\n", "      <td>01:13</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>2</td>\n", "      <td>3.337267</td>\n", "      <td>4.305145</td>\n", "      <td>0.386822</td>\n", "      <td>0.319585</td>\n", "      <td>01:07</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>3</td>\n", "      <td>4.280678</td>\n", "      <td>4.937834</td>\n", "      <td>0.314167</td>\n", "      <td>0.240478</td>\n", "      <td>01:01</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>4</td>\n", "      <td>3.461964</td>\n", "      <td>4.086816</td>\n", "      <td>0.401147</td>\n", "      <td>0.304925</td>\n", "      <td>01:06</td>\n", "    </tr>\n",
"    <tr>\n", "      <td>5</td>\n", "      <td>3.154585</td>\n", "      <td>4.022432</td>\n", "      <td>0.407792</td>\n", "      <td>0.310715</td>\n", "      <td>01:07</td>\n", "    </tr>\n",
"  </tbody>\n", "</table>" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 3e-3)" ] }, { "cell_type": "code", "execution_count": 77, "metadata": { "collapsed": true }, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 100.00% [151/151 00:23<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "inputs, targets, outputs = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos qui a le pouvoir de modifier le règlement sur les poids et mesures et le règlement sur l'inspection de l'électricité et du gaz ?,\n", " Text xxbos who has the authority to change the electricity and gas inspection regulations and the weights and measures regulations ?,\n", " Text xxbos who has the xxunk and xxunk and xxunk xxunk ?)" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[700],targets[700],outputs[700]" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos quelles sont les deux tendances qui ont nuit à la pêche au saumon dans cette province ?,\n", " Text xxbos what two trends negatively affected the province ’s salmon fishery ?,\n", " Text xxbos what are the main reasons for the xxunk of the xxunk ?)" ] }, "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[2513], targets[2513], outputs[2513]" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos où les aires marines nationales de conservation du canada seront - elles situées ?,\n", " Text xxbos where will national marine conservation areas of canada be located ?,\n", " Text xxbos where are the canadian regulations located in the canadian ?)" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[4000], targets[4000], outputs[4000]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": 
{ "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }