{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Summary of my results:\n",
"\n",
"model | train_loss | valid_loss | seq2seq_acc | bleu\n",
"-------------------|----------|----------|----------|----------\n",
"seq2seq | 3.355085 | 4.272877 | 0.382089 | 0.291899\n",
"\\+ teacher forcing | 3.154585 | 4.022432 | 0.407792 | 0.310715\n",
"\\+ attention | 1.452292 | 3.420485 | 0.498205 | 0.413232\n",
"transformer | 1.913152 | 2.349686 | 0.781749 | 0.612880"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Translation with an RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook is modified from [this one](https://github.com/fastai/fastai_docs/blob/master/dev_course/dl2/translation.ipynb) created by Sylvain Gugger.\n",
"\n",
"Today we will be tackling the task of translation. We will be translating from French to English, and to keep our task a manageable size, we will limit ourselves to translating questions.\n",
"\n",
"This task is an example of sequence to sequence (seq2seq). Seq2seq can be more challenging than classification, since the output is of variable length (and typically different from the length of the input)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The French/English parallel texts come from http://www.statmt.org/wmt15/translation-task.html . The corpus was created by Chris Callison-Burch, who crawled millions of web pages and then used *a set of simple heuristics to transform French URLs onto English URLs (i.e. replacing \"fr\" with \"en\" and about 40 other hand-written rules), and assume that these documents are translations of each other*."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Translation is much tougher in straight PyTorch: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download and preprocess our data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will start by reducing the original dataset to questions. You only need to execute this once; uncomment the following cells to run them. The dataset can be downloaded [here](https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz)."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = Config().data_path()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# ! wget https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz -P {path}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# ! tar xf {path}/giga-fren.tgz -C {path} "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.en.300.bin'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/data_save.pkl'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.fr.300.bin'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/questions_easy.csv')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = Config().data_path()/'giga-fren'\n",
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# with open(path/'giga-fren.release2.fixed.fr') as f: fr = f.read().split('\\n')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# with open(path/'giga-fren.release2.fixed.en') as f: en = f.read().split('\\n')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use regexes to pick out questions: we keep the lines whose English side starts with \"Wh\" and ends with a question mark. You only need to run these lines once:"
]
},
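The pairing logic above can be illustrated in isolation. A minimal, self-contained sketch using the same two patterns on made-up sample lines (the real run streams over the full corpus files):

```python
import re

# English questions: start with "Wh", no sentence-ending punctuation before the "?"
re_eq = re.compile(r'^(Wh[^?.!]+\?)')
# French questions: any text with no sentence-ending punctuation until the "?"
re_fq = re.compile(r'^([^?.!]+\?)')

en_lines = ["What is light?", "The sky is blue.", "Where are we?"]
fr_lines = ["Qu'est-ce que la lumière?", "Le ciel est bleu.", "Où sommes-nous?"]

lines = ((re_eq.search(e), re_fq.search(f)) for e, f in zip(en_lines, fr_lines))
qs = [(e.group(), f.group()) for e, f in lines if e and f]
print(len(qs))  # 2
```

A pair is kept only when both sides match, so declarative sentences (and misaligned lines) drop out together.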
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# re_eq = re.compile('^(Wh[^?.!]+\\?)')\n",
"# re_fq = re.compile('^([^?.!]+\\?)')\n",
"# en_fname = path/'giga-fren.release2.fixed.en'\n",
"# fr_fname = path/'giga-fren.release2.fixed.fr'"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# lines = ((re_eq.search(eq), re_fq.search(fq)) \n",
"# for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))\n",
"# qs = [(e.group(), f.group()) for e,f in lines if e and f]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# qs = [(q1,q2) for q1,q2 in qs]\n",
"# df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr'])\n",
"# df.to_csv(path/'questions_easy.csv', index=False)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.fr'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.en.300.bin'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/data_save.pkl'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/giga-fren.release2.fixed.en'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/cc.fr.300.bin'),\n",
" PosixPath('/home/racheltho/.fastai/data/giga-fren/questions_easy.csv')]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load our data into a DataBunch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our questions look like this now:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" en \\\n",
"0 What is light ? \n",
"1 Who are we? \n",
"2 Where did we come from? \n",
"3 What would we do without it? \n",
"4 What is the absolute location (latitude and lo... \n",
"\n",
" fr \n",
"0 Qu’est-ce que la lumière? \n",
"1 Où sommes-nous? \n",
"2 D'où venons-nous? \n",
"3 Que ferions-nous sans elle ? \n",
"4 Quelle sont les coordonnées (latitude et longi... "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df = pd.read_csv(path/'questions_easy.csv')\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make it simple, we lowercase everything."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"df['en'] = df['en'].apply(lambda x:x.lower())\n",
"df['fr'] = df['fr'].apply(lambda x:x.lower())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, we need to collate inputs and targets in a batch: they have different lengths, so we add padding to make all sequences in a batch the same length."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):\n",
"    \"Collate `samples` and add padding. Flips token order if needed.\"\n",
" samples = to_data(samples)\n",
" max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n",
" res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n",
" res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n",
" if backwards: pad_first = not pad_first\n",
" for i,s in enumerate(samples):\n",
" if pad_first: \n",
" res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n",
" else: \n",
"            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])\n",
" if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n",
" return res_x,res_y"
]
},
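The padding step in `seq2seq_collate` can be previewed without fastai or PyTorch. A minimal pure-Python sketch of the same pad-first logic (the token ids and pad index below are made up for illustration):

```python
def pad_batch(seqs, pad_idx=1, pad_first=True):
    """Pad variable-length id sequences to the longest one in the batch."""
    max_len = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        pad = [pad_idx] * (max_len - len(s))
        padded.append(pad + s if pad_first else s + pad)
    return padded

batch = [[2, 5, 9], [2, 7], [2, 4, 6, 8]]
left  = pad_batch(batch)                   # pad at the front
right = pad_batch(batch, pad_first=False)  # pad at the back
print(left)   # [[1, 2, 5, 9], [1, 1, 2, 7], [2, 4, 6, 8]]
```

Left padding keeps the real tokens adjacent to the end of the sequence, which is what an RNN reads last.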
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we create a special `DataBunch` that uses this collate function."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"doc(Dataset)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"doc(DataLoader)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"doc(DataBunch)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqDataBunch(TextDataBunch):\n",
"    \"Create a `TextDataBunch` suitable for training a seq2seq model.\"\n",
" @classmethod\n",
" def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n",
" dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n",
"        \"Function that transforms the `datasets` into a `DataBunch` for seq2seq. Passes `**dl_kwargs` on to `DataLoader()`\"\n",
" datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n",
" val_bs = ifnone(val_bs, bs)\n",
" collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n",
" train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n",
" train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n",
" dataloaders = [train_dl]\n",
" for ds in datasets[1:]:\n",
" lengths = [len(t) for t in ds.x.items]\n",
" sampler = SortSampler(ds.x, key=lengths.__getitem__)\n",
" dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n",
" return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"SortishSampler??"
]
},
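The idea behind `SortishSampler` (batches of roughly similar lengths, with some randomness) can be sketched in pure Python. The chunked sort below is a simplified stand-in for fastai's implementation, not a copy of it:

```python
import random

def sortish_order(lengths, bs, chunk_mult=2, seed=0):
    """Return indices ordered so consecutive items have similar lengths."""
    rng = random.Random(seed)
    idxs = list(range(len(lengths)))
    rng.shuffle(idxs)
    # sort by length (descending) within large chunks: length-ish, not deterministic
    chunk = bs * chunk_mult
    return [i for start in range(0, len(idxs), chunk)
              for i in sorted(idxs[start:start + chunk],
                              key=lambda i: lengths[i], reverse=True)]

lengths = [3, 10, 2, 8, 5, 7, 4, 6]
order = sortish_order(lengths, bs=2)
# drawing batches of 2 in this order groups similar lengths -> less padding
```

Because lengths inside each batch are close, the padded tensors waste far fewer positions than with a uniformly random order.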
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And we create a subclass of `TextList` that will use this `DataBunch` class in the call to `.databunch`, and that will use `TextList` for the labels (since our targets are also texts)."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqTextList(TextList):\n",
" _bunch = Seq2SeqDataBunch\n",
" _label_cls = TextList"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's all we need to use the data block API!"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_by_rand_pct(seed=42).label_from_df(cols='en', label_cls=TextList)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"28.0"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"23.0"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We remove the items where either the input or the target is more than 30 tokens long."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"48352"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(src.train) + len(src.valid)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"data = src.databunch()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"data.save()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Seq2SeqDataBunch;\n",
"\n",
"Train: LabelList (38706 items)\n",
"x: Seq2SeqTextList\n",
"xxbos qu’est - ce que la lumière ?,xxbos où sommes - nous ?,xxbos d'où venons - nous ?,xxbos que ferions - nous sans elle ?,xxbos quel est le groupe autochtone principal sur l’île de vancouver ?\n",
"y: TextList\n",
"xxbos what is light ?,xxbos who are we ?,xxbos where did we come from ?,xxbos what would we do without it ?,xxbos what is the major aboriginal group on vancouver island ?\n",
"Path: /home/racheltho/.fastai/data/giga-fren;\n",
"\n",
"Valid: LabelList (9646 items)\n",
"x: Seq2SeqTextList\n",
"xxbos quels pourraient être les effets sur l’instrument de xxunk et sur l’aide humanitaire qui ne sont pas co - xxunk ?,xxbos quand la source primaire a - t - elle été créée ?,xxbos pourquoi tant de soldats ont - ils fait xxunk de ne pas voir ce qui s'est passé le 4 et le 16 mars ?,xxbos quels sont les taux d'impôt sur le revenu au canada pour 2007 ?,xxbos pourquoi le programme devrait - il intéresser les employeurs et les fournisseurs de services ?\n",
"y: TextList\n",
"xxbos what would be the resulting effects on the pre - accession instrument and humanitarian aid that are not co - decided ?,xxbos when was the primary source created ?,xxbos why did so many soldiers look the other way in relation to the incidents of march 4th and march xxunk ?,xxbos what are the income tax rates in canada for 2007 ?,xxbos why is the program good for employers and service providers ?\n",
"Path: /home/racheltho/.fastai/data/giga-fren;\n",
"\n",
"Test: None"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('/home/racheltho/.fastai/data/giga-fren')"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"data = load_data(path)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<table>\n",
"  <thead><tr><th>text</th><th>target</th></tr></thead>\n",
"  <tbody>\n",
"    <tr><td>xxbos quelle position devrait - il défendre pour concilier les objectifs stratégiques des divers traités internationaux sur la propriété intellectuelle , l’environnement , et les droits sociaux et économiques ?</td><td>xxbos what position should canada advocate with respect to xxunk the policy objectives of various international treaties on intellectual property , the environment , and social and economic rights ?</td></tr>\n",
"    <tr><td>xxbos que faire s’il semble que pour sauver un stock local de poisson de fond , il xxunk réduire ou éliminer la prédation par les phoques dans le secteur ?</td><td>xxbos what if it appears that in some xxunk , saving a local groundfish stock would require reducing or xxunk seal predation in that area ?</td></tr>\n",
"    <tr><td>xxbos quels sont les impacts économiques produits par les xxunk millions de dollars dépensés par les résidents du yukon qui ont participé à des activités reliées à la nature ?</td><td>xxbos what are the economic impacts that result from participation in nature - related activities by residents of the yukon ?</td></tr>\n",
"    <tr><td>xxbos quelles pourraient être les raisons pour lesquelles un programme n ' a pas marché aussi bien que prévu , même si les employés ont effectué un travail excellent ?</td><td>xxbos what would be some of the reasons why a program could be less than successful , even if staff were excellent ?</td></tr>\n",
"    <tr><td>xxbos quand les pièces , les feuilles ou les fils métalliques contenant des substances de l’inrp figurant dans les parties 1a et 1b perdent - ils leur statut xxunk ?</td><td>xxbos when do metal parts , sheets or xxunk containing npri part xxunk and xxunk substances lose their status as articles ?</td></tr>\n",
"  </tbody>\n",
"</table>"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create our Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pretrained embeddings"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You will need to download the word embeddings (crawl vectors) from the fastText docs. FastText has [pre-trained word vectors](https://fasttext.cc/docs/en/crawl-vectors.html) for 157 languages, trained on Common Crawl and Wikipedia. These models were trained using CBOW.\n",
"\n",
"If you need a refresher on word embeddings, you can check out my gentle intro in [this word embedding workshop](https://www.youtube.com/watch?v=25nC0n9ERq4&list=PLtmWHNX-gukLQlMvtRJ19s7-8MrnRV6h6&index=10&t=0s) with accompanying [github repo](https://github.com/fastai/word-embeddings-workshop). \n",
"\n",
"More reading on CBOW (Continuous Bag of Words) vs. Skip-gram:\n",
"\n",
"- [fastText tutorial](https://fasttext.cc/docs/en/unsupervised-tutorial.html#advanced-readers-skipgram-versus-cbow)\n",
"- [StackOverflow](https://stackoverflow.com/questions/38287772/cbow-v-s-skip-gram-why-invert-context-and-target-words)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To install fastText:\n",
"```\n",
"$ git clone https://github.com/facebookresearch/fastText.git\n",
"$ cd fastText\n",
"$ pip install .\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"import fastText as ft"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The lines to download the word vectors only need to be run once:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path}\n",
"# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz -P {path}"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"# ! gunzip {path}/cc.en.300.bin.gz\n",
"# ! gunzip {path}/cc.fr.300.bin.gz"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))\n",
"en_vecs = ft.load_model(str((path/'cc.en.300.bin')))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We create an embedding module with the pretrained vectors and random data for the missing parts."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def create_emb(vecs, itos, em_sz=300, mult=1.):\n",
"    \"Create an `nn.Embedding` from pretrained `vecs`; tokens missing from them keep their random init.\"\n",
"    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)\n",
"    wgts = emb.weight.data\n",
"    vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}\n",
"    miss = []\n",
"    for i,w in enumerate(itos):\n",
"        try: wgts[i] = tensor(vec_dic[w])\n",
"        except KeyError: miss.append(w)\n",
"    return emb"
]
},
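The initialisation strategy in `create_emb` can be sketched without fastText or PyTorch: copy a pretrained vector when the token is covered, otherwise leave the row at its random initialisation. The toy vocabulary and 4-dimensional vectors below are made up:

```python
import random

def build_matrix(itos, vec_dic, em_sz, seed=0):
    """Embedding matrix: pretrained rows where available, random rows otherwise."""
    rng = random.Random(seed)
    matrix, miss = [], []
    for w in itos:
        if w in vec_dic:
            matrix.append(list(vec_dic[w]))                            # pretrained
        else:
            matrix.append([rng.gauss(0, 0.1) for _ in range(em_sz)])   # random init
            miss.append(w)
    return matrix, miss

vec_dic = {"what": [0.5, 0.6, 0.7, 0.8], "light": [0.1, 0.2, 0.3, 0.4]}
matrix, miss = build_matrix(["xxbos", "what", "light"], vec_dic, em_sz=4)
print(miss)  # ['xxbos'] -- special tokens are rarely in the pretrained vocab
```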
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"emb_enc = create_emb(fr_vecs, data.x.vocab.itos)\n",
"emb_dec = create_emb(en_vecs, data.y.vocab.itos)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([11336, 300]), torch.Size([8152, 300]))"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb_enc.weight.size(), emb_dec.weight.size()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"model_path = Config().model_path()"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"torch.save(emb_enc, model_path/'fr_emb.pth')\n",
"torch.save(emb_dec, model_path/'en_emb.pth')"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"emb_enc = torch.load(model_path/'fr_emb.pth')\n",
"emb_dec = torch.load(model_path/'en_emb.pth')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Review Question: What are the two types of numbers in deep learning?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Encoders & Decoders"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model itself consists of an encoder and a decoder.\n",
"\n",
"![Seq2seq model](images/seq2seq.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inputs, targets, outputs = get_predictions(learn)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos quels sont les résultats prévus à court et à long termes de xxunk , et dans quelle mesure ont - ils été obtenus ?,\n",
" Text xxbos what are the short and long - term expected outcomes of the ali and to what extent have they been achieved ?,\n",
" Text xxbos what were the results , the , , , , , , and and and and and)"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[700], targets[700], outputs[700]"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos de quel(s ) xxunk ) a - t - on besoin pour xxunk les profits réels de la compagnie pour l'année qui vient ?,\n",
" Text xxbos which of the following additional information is necessary to estimate the company 's actual profit for the coming year ?,\n",
" Text xxbos what is the the to to to the the ( ( ) ))"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[701], targets[701], outputs[701]"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos de quelles façons l'expérience et les capacités particulières des agences d'exécution contribuent - elles au projet ?,\n",
" Text xxbos what experience and specific capacities do the implementing organizations bring to the project ?,\n",
" Text xxbos what are the key and and and and and and of of of of of of ?)"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[2513], targets[2513], outputs[2513]"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos qu'est - ce que la maladie de xxunk - xxunk ( mcj ) ?,\n",
" Text xxbos what is xxunk - xxunk disease ( cjd ) ?,\n",
" Text xxbos what is the xxunk ( ( ) ))"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[4000], targets[4000], outputs[4000]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It usually begins well, but falls into repeated words at the end of the question."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Teacher forcing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to help training is to feed the decoder the real targets instead of its own predictions (if it starts with the wrong words, it is very unlikely to produce the right translation). We do this all the time at the beginning of training, then progressively reduce the amount of teacher forcing."
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"class TeacherForcing(LearnerCallback):\n",
" \n",
" def __init__(self, learn, end_epoch):\n",
" super().__init__(learn)\n",
" self.end_epoch = end_epoch\n",
" \n",
" def on_batch_begin(self, last_input, last_target, train, **kwargs):\n",
" if train: return {'last_input': [last_input, last_target]}\n",
" \n",
" def on_epoch_begin(self, epoch, **kwargs):\n",
" self.learn.model.pr_force = 1 - epoch/self.end_epoch"
]
},
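The schedule set in `on_epoch_begin` is a linear decay of the forcing probability from 1 to 0 over `end_epoch` epochs. A quick sketch (the `max` clamp is added here for readability; the callback's unclamped value behaves identically, since `random.random()` is never below zero):

```python
def pr_force(epoch, end_epoch):
    """Probability of feeding the true target token at a decoding step."""
    return max(0.0, 1 - epoch / end_epoch)

schedule = [pr_force(e, end_epoch=3) for e in range(5)]
print(schedule)  # decays 1.0 -> 0.0 over 3 epochs, then stays at 0
```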
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will add the following code to our `forward` method:\n",
"\n",
"```\n",
"if (targ is not None) and (random.random()<self.pr_force):\n",
"    if i>=targ.shape[1]: break\n",
"    dec_inp = targ[:,i]\n",
"```\n",
"Additionally, `forward` will take an additional argument, `targ` (the target sequence)."
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"class Seq2SeqRNN_tf(nn.Module):\n",
" def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):\n",
" super().__init__()\n",
" self.nl,self.nh,self.out_sl = nl,nh,out_sl\n",
" self.bos_idx,self.pad_idx = bos_idx,pad_idx\n",
" self.em_sz_enc = emb_enc.embedding_dim\n",
" self.em_sz_dec = emb_dec.embedding_dim\n",
" self.voc_sz_dec = emb_dec.num_embeddings\n",
" \n",
" self.emb_enc = emb_enc\n",
" self.emb_enc_drop = nn.Dropout(0.15)\n",
" self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl,\n",
" dropout=0.25, batch_first=True)\n",
" self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False)\n",
" \n",
" self.emb_dec = emb_dec\n",
" self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl,\n",
" dropout=0.1, batch_first=True)\n",
" self.out_drop = nn.Dropout(0.35)\n",
" self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec)\n",
" self.out.weight.data = self.emb_dec.weight.data\n",
" self.pr_force = 0.\n",
" \n",
" def encoder(self, bs, inp):\n",
" h = self.initHidden(bs)\n",
" emb = self.emb_enc_drop(self.emb_enc(inp))\n",
" _, h = self.gru_enc(emb, h)\n",
" h = self.out_enc(h)\n",
" return h\n",
" \n",
" def decoder(self, dec_inp, h):\n",
" emb = self.emb_dec(dec_inp).unsqueeze(1)\n",
" outp, h = self.gru_dec(emb, h)\n",
" outp = self.out(self.out_drop(outp[:,0]))\n",
" return h, outp\n",
" \n",
" def forward(self, inp, targ=None):\n",
" bs, sl = inp.size()\n",
" h = self.encoder(bs, inp)\n",
" dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
" \n",
" res = []\n",
" for i in range(self.out_sl):\n",
" h, outp = self.decoder(dec_inp, h)\n",
" res.append(outp)\n",
" dec_inp = outp.max(1)[1]\n",
" if (dec_inp==self.pad_idx).all(): break\n",
"            if (targ is not None) and (random.random()<self.pr_force):\n",
"                if i>=targ.shape[1]: continue\n",
"                dec_inp = targ[:,i]\n",
" return torch.stack(res, dim=1)\n",
"\n",
" def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh)"
]
},
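The decoding loop in `forward` interleaves greedy decoding with the teacher-forcing coin flip. Its control flow can be sketched with a stub step function standing in for the real decoder (the stub and the toy token ids are made up for illustration):

```python
import random

def decode(step, bos, out_sl, targ=None, pr_force=0.0):
    """Greedy decode out_sl steps; with probability pr_force feed the true token."""
    dec_inp, res = bos, []
    for i in range(out_sl):
        pred = step(dec_inp)          # model's next-token prediction
        res.append(pred)
        dec_inp = pred                # by default, feed the model its own output
        if targ is not None and random.random() < pr_force:
            if i >= len(targ): break
            dec_inp = targ[i]         # teacher forcing: feed the ground truth
    return res

step = lambda t: t + 1                # stub "decoder"
free   = decode(step, bos=0, out_sl=4)                                   # -> [1, 2, 3, 4]
forced = decode(step, bos=0, out_sl=4, targ=[10, 20, 30, 40], pr_force=1.0)
print(free, forced)
```

With full forcing, each prediction is conditioned on the true previous token rather than on the model's possibly wrong guess, which is exactly what stabilises early training.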
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"emb_enc = torch.load(model_path/'fr_emb.pth')\n",
"emb_dec = torch.load(model_path/'en_emb.pth')"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"rnn_tf = Seq2SeqRNN_tf(emb_enc, emb_dec, 256, 30)\n",
"\n",
"learn = Learner(data, rnn_tf, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n",
" callback_fns=partial(TeacherForcing, end_epoch=3))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n"
]
}
],
"source": [
"learn.lr_find()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"collapsed": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3XmUXGW57/HvU1U9dzrppDsh6ZDREEVkSouACiigOBwBFa8svaJwZek9guhVj2exjnp04XBwOkfXvV5URK+CS6YjoCg4MBxkCoRACINABjKQ7kzdSXdVdw3P/aN2dYrQne6ka1dV1/591qpVVXvvqvd5U5166n3fvd/X3B0REYmuWKUDEBGRylIiEBGJOCUCEZGIUyIQEYk4JQIRkYhTIhARiTglAhGRiFMiEBGJOCUCEZGIS1Q6gIno6OjwRYsWVToMEZEp5ZFHHtnu7p3jHTclEsGiRYtYuXJlpcMQEZlSzGzDRI5T15CISMQpEYiIRJwSgYhIxCkRiIhEnBKBiEjEKRGIiEScEoGISMQpEYiIVKGBoQxfvXUt67cPhF6WEoGISBW6Y+1LXH3fOnr3DoVelhKBiEgVuunRzcxvb2LFgvbQy1IiEBGpMj39Ke57bjvnHtdFLGahl6dEICJSZW5ZvYWcwznHdZWlPCUCEZEqc/OqzRwzfzpLO1vLUp4SgYhIFXl22x6e3NJfttYAKBGIiFSVm1dtJh4z/uGYeWUrU4lARKRK5HLOb1dt5pRlHXS0NpStXCUCEZEq8eC6nWzpS3Hu8fPLWq4SgYhIlbh51SZa6uOc+Zo5ZS1XiUBEpAqk0lluf+IlzjpqLk318bKWHVoiMLOrzazHzNaMsu9zZuZm1hFW+SIiU8mfntrGnqEM7z2+fGcLFYTZIrgGOGv/jWZ2OHAmsDHEskVEpgx35zcrNzGnrYETl8wqe/mhJQJ3vwfYOcqu7wFfADysskVEppIbH93MPc/28tGTFxMvw5QS+yvrGIGZvQfY7O6ry1muiEi1er53L1/67RpOXDKTi09ZUpEYEuUqyMyagcuBt03w+IuBiwEWLFgQYmQiIpUxlMlyybWraEjE+P5/O64irQEob4tgKbAYWG1m64H5wKNmdthoB7v7Ve7e7e7dnZ2dZQxTRKQ8vvH7p1m7tZ9vn3cMh01vrFgcZWsRuPsTwOzC8yAZdLv79nLFICJSLe5cu41r/raej71xEaeX+bqB/YV5+uh1wP3AcjPbZGYXhVWWiMhU8lJfis/fsJrXzmvji+94daXDCa9F4O7nj7N/UVhli4hUsxsf3cTuwTQ3ffJkGhLlvXhsNLqyWESkzNZtH2BOWwNLyrTewHiUCEREymzDjgEWzmqpdBgjlAhERMps3fZBFisRiIhE096hDNv3DrGwo7nSoYxQIhARKaMNOwYAWKQWgYhINK3fPggoEYiIRNb6oEWwcJa6hkREImnDjgE6pzXQ0lC2iR3GpUQgIlJG66vsjCFQIhARKav1OwaqqlsIlAhERMpmYChDz54hFnWoRSAiEkkbdlTfGUOgRCAiUjYbqvCMIVAiEBEpm3WFi8nUNSQiEk0btg/S0dpAaxWdOgpKBCIiZbN+xwCLqqxbCJQIRETKZn2VTT9doEQgIlIGg8MZtvUPsbiKZh0tUCIQESmDjTvzp46qRSAiElHrt1ff9NMFSgQiImWwPriYrJoWpClQIhARKYP12weY1VJPW2NdpUN5hdASgZldbWY9ZramaNvXzOxxM3vMzO4ws3lhlS8iUk2qcbK5gjBbBNcAZ+237Up3P9rdjwVuA74UYvkiIlVjw47BqruiuCC0RODu9wA799vWX/S0BfCwyhcRqRbJ4Sxb+1JVOVAMUPbrnM3sCuAjQB/wlnKXLyJSbvtOHY1e19Co3P1ydz8c+BXwqbGOM7OLzWylma3s7e0tX4AiIiVWWKd4cdS6hibgWuB9Y+1096
vcvdvduzs7O8sYlohIaRWuIVg4U4kAM1tW9PQ9wNPlLF9EpBLW7xikvbmO6c3Vd+oohDhGYGbXAacBHWa2Cfgy8E4zWw7kgA3AJ8IqX0SkWqzfPlC1ZwxBiInA3c8fZfNPwypPRKRabdgxwBuWzKp0GGPSlcUiIiHq2ZNia3+qas8YAiUCEZFQffm3T1IXj/EPx1TvRApKBCIiIbn9ia3cvuYlLjtjGUs7WysdzpiUCEREQrB7cJh/+e2THNXVxsVvXlLpcA6oulZQFhGpEV+9bS27B4f5xYUnkIhX92/u6o5ORGQK+uszPdz06GY+edpSjpzXVulwxqVEICJSQntSaS6/6QmWzW7lU299VaXDmRB1DYmIlNAP/vIcW/tT3PjJk2lIxCsdzoSoRSAiUiLuzm2rt3DGa+Zw/IL2SoczYUoEIiIl8lzPXrb0pXjL8tmVDuWgKBGIiJTIXc/kp8w/dfnUmjFZiUBEpETufraXZbNb6ZrRVOlQDooSgYhICQwMZXho3U5Om2KtAVAiEBEpiQde2MFwNsepR0yt8QFQIhARKYm7numlqS7O6xdPnbOFCpQIREQmyd2569keTl46a8pcO1BMiUBEZJLWbR/gxZ3JKTk+AEoEIiKTdvezwWmjU3B8AJQIREQm7e5ne1nS0cKCKl6F7ECUCEREJiGVznL/8zs45Yip2S0ESgQiIpPy4LqdDGVyU3Z8AJQIREQm5a5nemhIxDhxyaxKh3LIQksEZna1mfWY2ZqibVea2dNm9riZ3WxmM8IqX0SkHO5+tpcTl8yisW7qnTZaEGaL4BrgrP223Qkc5e5HA88C/xxi+SIiodqwY4AXegc4dQqPD0CIicDd7wF27rftDnfPBE8fAOaHVb6ISJjcna/d9hT1iRhve+2cSoczKZUcI7gQuL2C5YuIHLLrV27iT09t4wtvX8789ql52mhBRRKBmV0OZIBfHeCYi81spZmt7O3tLV9wIiLjeHHnIP9665OctGQWF75xcaXDmbSyJwIzuwB4N/Ahd/exjnP3q9y92927Ozundv+biNSObM75X79ZTcyMb3/gGGIxq3RIk1bWxevN7Czgn4BT3X2wnGWLiJTCT+59gYfW7+Q75x0z5RagGUuYp49eB9wPLDezTWZ2EfBDYBpwp5k9ZmY/Cqt8EZFSe2prP9+541nOeu1hvPf4rkqHUzKhtQjc/fxRNv80rPJERML2L/+5hramOr7+3tdhNvW7hAp0ZbGIyAS9sH2As46aw8yW+kqHUlJKBCIiE5QcztJcX9ah1bJQIhARmQB3J5nOTumpJMaiRCAiMgFDmRwATUoEIiLRNDicBaC5XolARCSSkul8IlCLQEQkopJBi6BRLQIRkWhKRb1FYGZLzawheHyamV2qRWVEJErUNQQ3AlkzexX5q4MXA9eGFpWISJUpDBY31ddeR8pEa5QLFpQ5F/i+u38GmBteWCIi1aUwRtBUF90LytJmdj5wAXBbsK0unJBERKrPyBhBhAeLPwacBFzh7uvMbDHwy/DCEhGpLrU8RjChNo67rwUuBTCzdmCau38zzMBERKrJvq6h2ksEEz1r6C4zazOzmcBq4Gdm9t1wQxMRqR6FFkFjhAeLp7t7P/Be4GfuvgI4I7ywRESqS3I4S8ygPh7dRJAws7nAB9g3WCwiEhnJdH4K6lpakKZgoongq8Afgefd/WEzWwL8PbywRESqS61OQQ0THyy+Hri+6PkLwPvCCkpEpNqkhrM1eTEZTHyweL6Z3WxmPWa2zcxuNLP5YQcnIlItkulsTZ4xBBPvGvoZcAswD+gCbg22iYhEghIBdLr7z9w9E9yuATpDjEtEpKoMDtfuGMFEE8F2M/uwmcWD24eBHQd6gZldHXQlrSnadp6ZPWlmOTPrnkzgIiLllEpna3J1Mph4IriQ/KmjLwFbgfeTn3biQK4Bztpv2xry1yLcM/EQRUQqLzmcrcl5hmDiZw1tBN5TvM3MLgO+f4DX3G
[base64-encoded PNG plot data elided]",
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"inputs, targets, outputs = get_predictions(learn)"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos qui a le pouvoir de modifier le règlement sur les poids et mesures et le règlement sur l'inspection de l'électricité et du gaz ?,\n",
" Text xxbos who has the authority to change the electricity and gas inspection regulations and the weights and measures regulations ?,\n",
" Text xxbos who has the xxunk and xxunk and xxunk xxunk ?)"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[700],targets[700],outputs[700]"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos quelles sont les deux tendances qui ont nuit à la pêche au saumon dans cette province ?,\n",
" Text xxbos what two trends negatively affected the province ’s salmon fishery ?,\n",
" Text xxbos what are the main reasons for the xxunk of the xxunk ?)"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[2513], targets[2513], outputs[2513]"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Text xxbos où les aires marines nationales de conservation du canada seront - elles situées ?,\n",
" Text xxbos where will national marine conservation areas of canada be located ?,\n",
" Text xxbos where are the canadian regulations located in the canadian ?)"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs[4000], targets[4000], outputs[4000]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}