{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Summary of my results:\n", "\n", "model | train_loss | valid_loss | seq2seq_acc | bleu\n", "-------------------|----------|----------|----------|----------\n", "seq2seq | 3.355085 | 4.272877 | 0.382089 | 0.291899\n", "\\+ teacher forcing | 3.154585 | 4.022432 | 0.407792 | 0.310715\n", "\\+ attention | 1.452292 | 3.420485 | 0.498205 | 0.413232\n", "transformer | 1.913152 | 2.349686 | 0.781749 | 0.612880" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Seq2Seq Translation with Attention" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Attention is a technique that uses the output of our encoder: instead of discarding it entirely, we use it together with our hidden state to pay attention to specific words in the input sentence when making predictions for the output sentence. Specifically, we compute attention weights, then add to the input of the decoder a linear combination of the encoder outputs, weighted by those attention weights." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A nice illustration of attention comes from [this blog post](http://jalammar.github.io/illustrated-transformer/) by Jay Alammar (visualization originally from [Tensor2Tensor notebook](https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb)):\n", "\n", "\"attention\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A second thing that might help is to use a bidirectional model for the encoder.
We set the `bidirectional` parameter to `True` for our GRU encoder, and double the number of inputs to the linear output layer of the encoder.\n", "\n", "We also need to reshape the hidden state, since a bidirectional GRU returns one hidden state per direction:\n", "\n", "```\n", "hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n", "hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Code to re-run from start" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()/'giga-fren'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n", "    \"Function that collects samples and adds padding. 
Flips token order if needed\"\n", "    samples = to_data(samples)\n", "    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n", "    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n", "    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n", "    if backwards: pad_first = not pad_first\n", "    for i,s in enumerate(samples):\n", "        if pad_first: \n", "            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n", "        else: \n", "            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])\n", "    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n", "    return res_x,res_y" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqDataBunch(TextDataBunch):\n", "    \"Create a `TextDataBunch` suitable for training a seq2seq model.\"\n", "    @classmethod\n", "    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n", "               dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n", "        \"Function that transforms the `datasets` into a `DataBunch` for seq2seq training. 
Passes `**dl_kwargs` on to `DataLoader()`\"\n", "        datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n", "        val_bs = ifnone(val_bs, bs)\n", "        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n", "        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n", "        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n", "        dataloaders = [train_dl]\n", "        for ds in datasets[1:]:\n", "            lengths = [len(t) for t in ds.x.items]\n", "            sampler = SortSampler(ds.x, key=lengths.__getitem__)\n", "            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n", "        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqTextList(TextList):\n", "    _bunch = Seq2SeqDataBunch\n", "    _label_cls = TextList" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "data = load_data(path)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model_path = Config().model_path()" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "emb_enc = torch.load(model_path/'fr_emb.pth')\n", "emb_dec = torch.load(model_path/'en_emb.pth')" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def seq2seq_loss(out, targ, pad_idx=1):\n", "    bs,targ_len = targ.size()\n", "    _,out_len,vs = out.size()\n", "    if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", "    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", "    return CrossEntropyFlat()(out, targ)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def seq2seq_acc(out, targ, pad_idx=1):\n", "    
bs,targ_len = targ.size()\n", "    _,out_len,vs = out.size()\n", "    if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", "    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", "    out = out.argmax(2)\n", "    return (out==targ).float().mean()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class NGram():\n", "    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n", "    def __eq__(self, other):\n", "        if len(self.ngram) != len(other.ngram): return False\n", "        return np.all(np.array(self.ngram) == np.array(other.ngram))\n", "    def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def get_grams(x, n, max_n=5000):\n", "    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "def get_correct_ngrams(pred, targ, n, max_n=5000):\n", "    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)\n", "    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n", "    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "def get_predictions(learn, ds_type=DatasetType.Valid):\n", "    learn.model.eval()\n", "    inputs, targets, outputs = [],[],[]\n", "    with torch.no_grad():\n", "        for xb,yb in progress_bar(learn.dl(ds_type)):\n", "            out = learn.model(xb)\n", "            for x,y,z in zip(xb,yb,out):\n", "                inputs.append(learn.data.train_ds.x.reconstruct(x))\n", "                targets.append(learn.data.train_ds.y.reconstruct(y))\n", "                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))\n", "    return inputs, targets, outputs" ] }, { "cell_type": "code", "execution_count": 21, "metadata": 
{}, "outputs": [], "source": [ "class CorpusBLEU(Callback):\n", "    def __init__(self, vocab_sz):\n", "        self.vocab_sz = vocab_sz\n", "        self.name = 'bleu'\n", "    \n", "    def on_epoch_begin(self, **kwargs):\n", "        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n", "    \n", "    def on_batch_end(self, last_output, last_target, **kwargs):\n", "        last_output = last_output.argmax(dim=-1)\n", "        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n", "            self.pred_len += len(pred)\n", "            self.targ_len += len(targ)\n", "            for i in range(4):\n", "                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n", "                self.corrects[i] += c\n", "                self.counts[i] += t\n", "    \n", "    def on_epoch_end(self, last_metrics, **kwargs):\n", "        precs = [c/t for c,t in zip(self.corrects,self.counts)]\n", "        len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1\n", "        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)\n", "        return add_metrics(last_metrics, bleu)" ] }, { "cell_type": "code", "execution_count": 136, "metadata": {}, "outputs": [], "source": [ "class TeacherForcing(LearnerCallback):\n", "    def __init__(self, learn, end_epoch):\n", "        super().__init__(learn)\n", "        self.end_epoch = end_epoch\n", "    \n", "    def on_batch_begin(self, last_input, last_target, train, **kwargs):\n", "        if train: return {'last_input': [last_input, last_target]}\n", "    \n", "    def on_epoch_begin(self, epoch, **kwargs):\n", "        self.learn.model.pr_force = 1 - epoch/self.end_epoch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Implementing attention" ] }, { "cell_type": "code", "execution_count": 137, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_attn(nn.Module):\n", "    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):\n", "        super().__init__()\n", "        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1\n", "        self.bos_idx,self.pad_idx = bos_idx,pad_idx\n", "        
self.emb_enc,self.emb_dec = emb_enc,emb_dec\n", "        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim\n", "        self.voc_sz_dec = emb_dec.num_embeddings\n", "        \n", "        self.emb_enc_drop = nn.Dropout(0.15)\n", "        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, \n", "                              batch_first=True, bidirectional=True)\n", "        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n", "        \n", "        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,\n", "                              dropout=0.1, batch_first=True)\n", "        self.out_drop = nn.Dropout(0.35)\n", "        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)\n", "        self.out.weight.data = self.emb_dec.weight.data\n", "        \n", "        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n", "        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)\n", "        self.V = self.init_param(self.emb_sz_dec)\n", "    \n", "    def encoder(self, bs, inp):\n", "        h = self.initHidden(bs)\n", "        emb = self.emb_enc_drop(self.emb_enc(inp))\n", "        enc_out, hid = self.gru_enc(emb, 2*h)\n", "        \n", "        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()\n", "        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)\n", "        hid = self.out_enc(pre_hid)\n", "        return hid,enc_out\n", "    \n", "    def decoder(self, dec_inp, hid, enc_att, enc_out):\n", "        hid_att = self.hid_att(hid[-1])\n", "        # we have put enc_out and hid through linear layers\n", "        u = torch.tanh(enc_att + hid_att[:,None])\n", "        # we want to learn the importance of each time step\n", "        attn_wgts = F.softmax(u @ self.V, 1)\n", "        # weighted average of enc_out (which is the output at every time step)\n", "        ctx = (attn_wgts[...,None] * enc_out).sum(1)\n", "        emb = self.emb_dec(dec_inp)\n", "        # concatenate decoder embedding with context (we could have just\n", "        # used the hidden state that came out of the decoder, if we weren't\n", "        # using attention)\n", "        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)\n", "        outp = 
self.out(self.out_drop(outp[:,0]))\n", "        return hid, outp\n", "    \n", "    def show(self, nm,v):\n", "        if False: print(f\"{nm}={v[nm].shape}\")\n", "    \n", "    def forward(self, inp, targ=None):\n", "        bs, sl = inp.size()\n", "        hid,enc_out = self.encoder(bs, inp)\n", "#        self.show(\"hid\",vars())\n", "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", "        enc_att = self.enc_att(enc_out)\n", "        \n", "        res = []\n", "        for i in range(self.out_sl):\n", "            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)\n", "            res.append(outp)\n", "            dec_inp = outp.max(1)[1]\n", "            if (dec_inp==self.pad_idx).all(): break\n", "            if (targ is not None) and (random.random()<self.pr_force):\n", "                if i>=targ.shape[1]: continue\n", "                dec_inp = targ[:,i]\n", "        return torch.stack(res, dim=1)\n", "\n", "    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)\n", "    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```\n", "hid=torch.Size([2, 64, 300])\n", "dec_inp=torch.Size([64])\n", "enc_att=torch.Size([64, 30, 300])\n", "hid_att=torch.Size([64, 300])\n", "u=torch.Size([64, 30, 300])\n", "attn_wgts=torch.Size([64, 30])\n", "enc_out=torch.Size([64, 30, 512])\n", "ctx=torch.Size([64, 512])\n", "emb=torch.Size([64, 300])\n", "```" ] }, { "cell_type": "code", "execution_count": 142, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n", "                callback_fns=partial(TeacherForcing, end_epoch=30))" ] }, { "cell_type": "code", "execution_count": 55, "metadata": { "collapsed": true }, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] }, { "data": { 
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEKCAYAAAAfGVI8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xl8VPW9//HXZyaTPSFAEpawCggiQpBgRVu1Um2v2ipavWqt+/W2drG1y21/vb/eLr/u7e2qtbZVrNdq3WirXitqXVoBJSwCZXMhQCJZWLJnkpnJ9/fHTDDFAAEyc2Z5Px/MIzNnzsz5fJkk75zz/Z7vMeccIiKSuXxeFyAiIt5SEIiIZDgFgYhIhlMQiIhkOAWBiEiGUxCIiGQ4BYGISIZTEIiIZLi4BYGZ3WVmjWa2od+yH5jZZjNbZ2ZLzKwkXtsXEZHBsXidWWxmZwDtwO+cc7Niy84F/uqcC5vZ9wCcc/9xuPcqLS11kyZNikudIiLpatWqVbudc2WHWy8rXgU45140s0kHLFva7+EK4MODea9JkyZRXV09dMWJiGQAM9s+mPW87CO4HnjyYE+a2U1mVm1m1U1NTQksS0Qks3gSBGb2FSAM3HewdZxzdzrnqpxzVWVlh92zERGRoxS3Q0MHY2bXABcAC52mPhUR8VxCg8DMPgD8B3Cmc64zkdsWEZGBxXP46P3AcmC6mdWa2Q3AL4Ai4GkzW2tmd8Rr+yIiMjjxHDV0xQCLfxuv7YmIyNHRmcUiIhlOQSAikoT2dvTw3Sc382ZTe9y3pSAQEUlCr9Y2c8cLb9DY1h33bSkIRESS0PraFszgxLHFcd+WgkBEJAmtr2thcmkBRbmBuG9LQSAikoQ21LVwUsWwhGxLQSAikmSa2rrZ1RJUEIiIZKoNdS0ACgIRkUy1vi7WUawgEBHJTH0dxYU5iZkOTkEgIpJk1te2MDtBewOgIBARSSpNbd3UtwaZpSAQEclMie4oBgWBiEhSSXRHMSgIRESSyrraFo5LYEcxKAhERJJKIs8o7qMgEBFJEo1twYR3FIOCQEQkafR1FM8eV5LQ7SoIRESSxPra1oRNPd2fgkBEJEmsr4t2FBcksKMYFAQiIknDi45iUBCIiCSFvo7ikxLcPwAKAhGRpODFGcV9FAQiIknAq45iUBCIiCQFrzqKQUEgIpIUduztYEpZoSfbVhCIiCSBhtZuRg/L9WTbCgIREY8FQxFaukKMKlYQiIhkpIbWIICCQEQkU9W39AVBjifbVxCIiHisoa0bgNHaIxARyUwNsT2CcgWBiEhmamgNkhfwU5yb+HMIQEEgIuK5+tYgo4pzMDNPth+3IDCzu8ys0cw29Ft2qZn9w8x6zawqXtsWEUklja3dno0YgvjuESwGPnDAsg3AxcCLcdyuiEhKie4ReBcEcTsg5Zx70cwmHbBsE+DZ7o+ISLJxztHQGvTsrGJI4j4CM7vJzKrNrLqpqcnrckRE4qKlK0R3uJfyIm/OIYAkDgLn3J3OuSrnXFVZWZnX5YiIxEVDa+wcAu0RiIhkpnqPp5cABYGIiKf65hny6qxiiO/w0fuB5cB0M6s1sxvMbJGZ1QILgCfM7Kl4bV9EJBX0nVVc5mEfQTxHDV1xkKeWxGubIiKppqEtyPD8ALkBv2c16NCQiIiH6lu8PZkMFAQiIp5qbPP2ZDJQEIiIeKq+JejZdQj6KAhERDwSjvSyu12HhkREMtbu9h56nbfnEICCQETEM15fq7iPgkBExCP1SXAyGSgIREQ809jq7UXr+ygIREQ8Ut8axO8zRhYqCEREMlJDazdlhTn4fd5eo0VBICLikYbWIKM8nH66j4JARMQjDa1BRnk42VwfBYGIiEcaWrs9vSBNHwWBiIgHgqEILV0hz88hAAWBiIgnkuVkMlAQiIh4or4lOc4hAAWBiIgnGtpiF63XHoGISGbqu0RluYJARCQzNb
QGyQv4Kc6N2xWDB01BICLigfrW6AVpzLw9qxgUBCIinmhs9f6CNH0UBCIiHojuESgIREQyknOOhtZgUpxVDAoCEZGEa+kK0R3upTwJ5hkCBYGISMI1tMbOIdAegYhIZnqruQtIjuklQEEgIpJwL2xtIjvLxwljir0uBVAQiIgkVKTX8b/rd/He6WUU5nh/MhkoCEREEqq6Zi+Nbd2cP3us16XspyAQEUmgJ9bvIjfgY+GMcq9L2U9BICKSINHDQvWcPaOcgiQ5LAQKAhGRhHll2152t3dz/knJc1gIFAQiIgnzxPq3yAv4ee+MMq9L+SdxCwIzu8vMGs1sQ79lI8zsaTN7LfZ1eLy2LyKSTMKRXp5cX8/ZJ5STn508h4UgvnsEi4EPHLDsS8CzzrlpwLOxxyIiae/lbXvZ09HDBSeN8bqUd4hbEDjnXgT2HrD4QuCe2P17gIvitX0RkWTy+Lpd5Gf7OWt68owW6pPoPoJRzrldALGvyfc/IiIyxMKRXv6yYRcLTxhFXrbf63LeIWk7i83sJjOrNrPqpqYmr8sRETlqy9/cw77OEOcn4WEhSHwQNJjZGIDY18aDreicu9M5V+WcqyorS64edhGRI/HEul0UZPs5a3py/i5LdBD8Gbgmdv8a4E8J3r6ISEJ1hyM8uaGec2aOIjeQfIeFIL7DR+8HlgPTzazWzG4AvgucY2avAefEHouIpK3nNjfS0hXiorkVXpdyUHEbzOqcu+IgTy2M1zZFRJLNo6vrKC3M4d1TS70u5aCStrNYRCTV7evo4bktjVxYOZYsf/L+uk3eykREUtzj63cRijgWJfFhIVAQiIjEzZLVtUwrL+TEsclxJbKDURCIiMRBze4OVu9oZtHJFZiZ1+UckoJARCQO/ri2DjO4qDK5DwuBgkBEZMg551iypo5TJ49kbEme1+UcloJARGSIrd7RzPY9nSw6Ofn3BmCQQWBmU8wsJ3b/LDP7tJmVxLc0EZHUtGRNLTlZPv5l1mivSxmUwe4RPAJEzGwq8FtgMvD7uFUlIpKiesK9PL5uF+eeOJqi3IDX5QzKYIOg1zkXBhYBP3HOfRZIzmn0REQ89OLWJpo7Q1yc5OcO9DfYIAiZ2RVEJ4p7PLYsNaJORCSBNrzVAsCCKSM9rmTwBhsE1wELgG8557aZ2WTgf+JXlohIatq+p5Oxw3KTdqbRgQxq0jnn3Ebg0wCxC84XOec0c6iIyAG27e5gUmmB12UckcGOGnrezIrNbATwKnC3mf13fEsTEUk9NXvSNAiAYc65VuBi4G7n3DzgffErS0Qk9TR39tDcGWLSyHyvSzkigw2CrNilJS/j7c5iERHpp2ZPJwCTRqbnHsE3gKeAN5xzK83sOOC1+JUlIpJ6anZ3ADA5xQ4NDbaz+CHgoX6P3wQuiVdRIiKpaNvuDsxg/Ig0PDRkZuPMbImZNZpZg5k9Ymbj4l2ciEgqqdnTwdhheSk1dBQGf2jobuDPwFigAngstkxERGJqdnek3GEhGHwQlDnn7nbOhWO3xUBZHOsSEUk5NXs6mZhiI4Zg8EGw28yuMjN/7HYVsCeehYmIpJJ9HT20dIXSeo/geqJDR+uBXcCHiU47ISIiwLY90RFDqTZ0FAYZBM65Hc65Dznnypxz5c65i4ieXCYiIrw9dDTVziqGY7tC2a1DVoWISIqr2dOJz2D8iOS/NOWBjiUIbMiqEBFJcTW7OxhbkkdOVmoNHYVjCwI3ZFWIiKS4mj2pOXQUDnNmsZm1MfAvfANSb/9HRCQOnHNs293BRZWpc1Wy/g4ZBM65okQVIiKSqvZ29NAWDKdkRzEc26EhERHh7VlHJ5em3slkoCAQETlmfUNHJ6bgOQSgIBAROWY1ezqiQ0eHa49ARCQjbdvdwbjh+WRnpeav1NSsWkQkiWzf05myHcXgURCY2S1mtsHM/mFmn/GiBhGRoeCci04/nYKzjvZJeBCY2Szg34BTgDnABWY2LR7bqt3XyRtN7b
R3h+Px9iIi7Onooa07nLIdxTDIS1UOsROAFc65TgAzewFYBHx/qDf0qxfe5N4V2wHIz/YzqjiXsqIc8rP9ZPl8ZGcZWT4fZtDZE6GrJ0JnT5jOnggAAb+PLL9Fv/oM58DhcLFT7HxmZPkNv8/I8kXfK5DlI9sffe9svw+z6Ewcvc7tf31Olp+8gJ/cgI/cgJ/CnCxK8rMZnh9geEE2JXkB8rL95Ab8ZPls/3uISPJJ1esU9+dFEGwAvmVmI4Eu4DygOh4b+sipE5g3cTgNrUEaWrtpaAvS1NrN3o4eQhFHKNJLONKLA/ICfvKz/eRnZzGiIAczYs9H1+sJ92IGhhH7R6TX0R2OEO51hCOOcG8voYijJ9xLT+w1vc7hM8MsGhwA3aEIXaEIvYOYpMNnkBvwk53lw2cWu0XfqzA3i2F5gf234twsinIDFP3T17fvF+Zk7f+qcBEZGttSeNbRPgkPAufcJjP7HvA00A68Crzj2I2Z3QTcBDBhwoSj2taM0cXMGF189MXGkXOOUMQRDEdoD4bZ19lDc2do/9dgKBK79RIMReiJREOl10VfG444OnrCtHSFaGwL8lpjGy2dIdq7w4cNmIDfGJYX2wPJz2ZEQTalRdmMLMihtCiHssJsRg/LY8ywXEoLc/D7FBoiB7N9Tyd+nzFueOrOuuPFHgHOud8CvwUws28DtQOscydwJ0BVVVXaTXBnZtHDR1k+inMDjC0Zmm8i5xwdPRHagiHagmHagmHau8O0B8O0BUO0BkPs6wzR3NnDvo4Qezt7eKOpnZe3dbOvM/SO9/P7jFFFOYwbns+EkflMHBH9Orm0gKnlheRne/ItJJI0tu3pYPzwPAL+1B2E6clPsZmVO+cazWwC0QvcLPCijnRkZhTmRA//jBl2ZK8NR3rZ29lDY2s3Da1BdrUE2dXSxa7mILX7uvjba0083Nrdb1swYUQ+00cVMWN0EdNHFzN9dBGTRuaTlcI/FCJHomZ3R0ofFgKPggB4JNZHEAI+4Zzb51Ed0k+W30d5US7lRbnMqhg4Rbp6IuzY28m23R1sbWhjS30bm+tbeWZTw/5DUtlZPqaWFTJjdBGzKoYxe9wwZo4t1t6DpJ2WrhCvN7Zz6nEjvS7lmHh1aOg9XmxXjl1etp/po4uYPrqID8wavX95MBTh9cZ2ttS3saWhjc31bfzt9d08uqYOiHZ6Tysv4qRxw6gcX0Ll+BJmjC7SnoOktIdX1dId7mXR3NScfrqP/kSTIZEb8DOrYtg79iQaWoOsr21hXV0L62qb+evmRh5eVRt7jY/ZFSWcM3MUF1aOpbw414vSRY5Kb6/j3uU1zJs4/KB70KlCQSBxNao4l1Ezc3nfzFFAtDN7594u1uzcx9qdzays2cu3/ncT33lyE6dPLeWiygo+MGs0BTn61pTk9sJrTdTs6eTWc6d7XcoxM+eSf0BOVVWVq66Oy6kGkgTeaGrnT2vqWLK2jp17uyjMyeLSqnFcd9pkJqTwafuS3q67+xU2vNXKS/9xdtJONmdmq5xzVYdbT392ieemlBVy67nT+ew5x1O9fR/3rdjOvcu3s3hZDeecMIob3j2ZUyaP0ElwkjRqdnfw/NYmblk4LWlD4EgoCCRpmBnzJ41g/qQRfPm8E/jd8hrue3kHSzc2UDm+hFsWTuOs6WUKBPHcvSu24zfjylOO7mTXZJP6USZpaVRxLl94/wyWf2kh/++iWTS1dXPd4pV86BcvsfQf9aTCIU1JTx3dYR6s3sl5J41JmwEOCgJJannZfq46dSLPf+Esvn/JbFqDIW66dxXn/ezvPKVAEA/8cW0dbcEw15w20etShoyCQFJCwO/jsvnjefbWM/nxv84hGIrw7/eu4oO/+Dt/3dygQJCEcM5xz7IaZlUUc/KE4V6XM2QUBJJSsvw+Fs0dx9OfPYMfXjqH1q4w1y+uZtHty1j+xh6vy5M0t+LNvWxtaOfqBZPSqq9KQSApKcvv48PzxvHs587kux
efRGNrkCt+vYJP/H41bzV3eV2epKmlG+vJDfj40JyxXpcypBQEktICfh+XnzKBv37+LD77vuN5ZmMDZ//oeX7+7GsEQxGvy5M0s3ZnM7MrSsgN+L0uZUgpCCQt5Ab83PK+aTz7uTN57/RyfvT0Vt7/kxfZXN/qdWmSJnrCvfzjrVYqJ5R4XcqQUxBIWhk3PJ9fXjWP+258F8FQhEt/uZy/v7bb67IkDWza1UpPuJfK8QoCkZRw+tRSltx8OhXD87j27ld4cOVOr0uSFLd2ZzOAgkAklYwtyeOhjy1gwZSRfPGRdfxo6RYNM5WjtnZnM+VFOYwZlh4nkfWnIJC0VpQb4K5r53P5/PH8/K+vc+uDrxKK9HpdlqSgNTv2UTm+JK2GjfbRXEOS9gJ+H9+5+CQqSvL40dNbaekKcduVJ5OXnV4jPyR+9nX0ULOnk8vmj/e6lLjQHoFkBDPjUwun8a1Fs3huSyNX/fZlWjpDXpclKWJtbbR/YO749DmbuD8FgWSUj7xrIrddeTLra1u47FfLaWgNel2SpIC1O5rxGcwel9pXIjsYBYFknPNOGsPd182ndl8nl/xyGW80tXtdkiS5tTubOX5UUdpeOU9BIBnp9Kml3H/TqXT1RLj49mW8sm2v1yVJknLO8Wptc1oOG+2jIJCMNXtcCUtuPp2Rhdlc9ZuX+dPaOq9LkiRUs6eT5s6QgkAkXU0Ymc+jHz+Nygkl3PLAWm577nWdayD/ZO3OfQBpObVEHwWBZLyS/GzuveEULqwcyw+e2sLXH9uoMJD91u5opiDbz7TyIq9LiZv07PkQOUI5WX5+fFklIwqyufulGo4rK+DqBZO8LkuSwNqdzZw0bhh+X/qdSNZHewQiMT6f8Z/nz2ThjHK+/thGlr2uyeoyXTAUYeOuVirT9PyBPgoCkX78PuMnl1cypayAm3+/mu17OrwuSTy0cVcroYhL645iUBCIvENRboBfX10FwA33VNMW1BnImWrtjtgZxWncUQwKApEBTRxZwO1Xnsy23R3c8sBaIr3qPM5Ea3c2M2ZYLqOK02/G0f4UBCIHcdrUUr72wZn8dXMjN9+3iq4eXfoy06zdmd4nkvXRqCGRQ/jogkmEex3feHwjV/x6Bb+5porSwpyE1+Gco3ZfF6u272P1jn30hHsZlh9geH42JXkBxg3P5/SpI9NyimSv1DV3sWNvJ1cvmOh1KXGnIBA5jOtOn8zYkjxueWANi25/icXXncKUssK4ba87HGHb7g5eb2zn9cZ2Nu1qZfWOZpraugEoyPZTkJNFc2eInn7XVrj2tEl89YKZ+NJ4mGMiPf2PegAWnjDK40riT0EgMgjvP3E0D9y0gBvvWcnFty/jx/86hzOPLz+mseW9vY665i427Wplc30bm+tb2byrjZo9HfR1SZjBhBH5vHtqKSdPHM68CcOZProIv89wztEVirCvM8Tdf9/Gb/6+ja6eCN+++KS0HvOeKE9vamBqeSGTSwu8LiXuFAQig1Q5voRHP3461y5+hesXV1NamM05M0dx7omjOW3KSHKy/vlCN729jrbuMC2dIZq7etjb0cMbTR1srW9jS0MbrzW00RHrd+j7hT99VBHnzx7D1PJCppYXclxp4UEvoGNm5GdnkZ+dxVfOP4H8nCx+9uxrBMMRfnjpHAJ+dQEerZauEC+/uZd/O+M4r0tJCE+CwMw+C9wIOGA9cJ1zThPDS9KbMDKfJz71Hp7Z1MBT/6jnz2vf4v5XdpKf7Sc/O4tIby/hiCPc6+gORxhosNGIgmymjyri0qrxHD+qiBPGFB3zFMdmxq3nHE9ewM/3/rKZrp4IP79y7jvCSQbn+S2NhHsd58xM/8NC4EEQmFkF8GlgpnOuy8weBC4HFie6FpGjkZft54NzxvLBOWPpDkdY9voeXtjaRE+klyyf4fcZAb+PnCwfw/IClMQ6dEvyA0wqLYhrZ/PHz5pCXsDH1x7byI33VHPHVfPSdg79eFq6sYGyoh
wqx6X/iCHw7tBQFpBnZiEgH3jLozpEjklOlp/3zijnvTPKvS5lv2tPn0x+ThZfemQdV/56BXddO5+RHox0SlXd4QgvbGnig3PGZEzHe8IPIjrn6oAfAjuAXUCLc25pousQSWeXVY3nVx+tYnN9Gx++Yzk793Z6XVLKWPHmXtq7wxlzWAg8CAIzGw5cCEwGxgIFZnbVAOvdZGbVZlbd1NSU6DJFUt45M0dx343vYk97N5f8chmbdrV6XVJKeHpjPfnZfk6bUup1KQnjxbCC9wHbnHNNzrkQ8Chw2oErOefudM5VOeeqysrKEl6kSDqomjSChz9+Gj4zLvvVcpbGxsbLwJxzPLOxkTOmlZEbyJyOdi+CYAdwqpnlW/Q0yIXAJg/qEMkIx48q4pGbT2PCiHxuuncV//ePGwiGNF3GQNbXtVDfGsyow0LgTR/By8DDwGqiQ0d9wJ2JrkMkk1SU5PHozadx47snc++K7Vz4i5fYUt/mdVlJ5+mNDfh9xtlJ1PmfCJYKl+Srqqpy1dXVXpchkhae39LI5x96lbZgmCtOmcDoYbn7h7cW5QYI9zqCocj+W0d3hOauEM2dPTR3hugKRfj3M46jatIIr5sy5N7/4xcpyQ/wh39f4HUpQ8LMVjnnqg63ngYYi2SYs6aX8+QtZ/DlR9dx38vbCUUO/8egGRTnRsOiPRjmusUreehjC5gxujgBFSfGjj2dbGlo4z/PP8HrUhJOQSCSgcqKcvjNNfP3z1fU3BmiuTNEazBEwG/kZPnJDfjJDfgoyM6iOC+wf/6iuuYuLr79Ja656xUevfl0KkryPG7N0Hgq1pF+7szRHleSeAoCkQzWf76isYP8hV5Rksc915/CpXcs5+rfvszDHzuN4QXZca40voKhCL/9+zbmTRzOhJH5XpeTcJqVSkSO2IzRxfz66ip27u3ixt9Vp/wopHuW1VDfGuSL75/udSme0B6BiByVU48byU8ur+QTv1/NotuXUVqYTVdPhI6eaCdzTpaP4twAxXkBivOymDiigBvfMznp5j5q6Qpx+/NvcNb0Mt513Eivy/GE9ghE5Kidd9IYvnfxbADagmFyAj4qSvKYVTGMCSPyMYv2Kbz85l5+8uxWzv/Z31i7s9njqv/ZHS+8QWswxBffP8PrUjyTXNEsIinnsvnjuWz++MOut+LNPXzuwVe55JfLuGXhNG4+awpZHl8zoaE1yN0vbePCOWOZOTZ9RkAdKe0RiEhCnHrcSP73lvdwwewx/PfTW/nXO1dQu8/byfB++uxrRHodt56TmX0DfRQEIpIww/IC/PTyufz08kq21rdx/eKVdPV409H8ZlM7f1i5kytPmZCRI4X6UxCISMJdWFnBbR85ma0N7Xzj8Y2e1PCjpVvJyfLxybOnebL9ZKIgEBFPnHF8GR87cwr3v7KDx9cl9tpUbza188T6Xdzw7smUFemiPQoCEfHM5849nsrxJXz5kfUJvXjOkjV1+Aw+eurEhG0zmSkIRMQzAb+Pn18xFww+ef8aQpHeuG+zt9fx6Oo6Tp9aSnlxbty3lwoUBCLiqfEj8vnuxbN5dWczP1y6Je7bW1mzl7rmLi45eVzct5UqFAQi4rnzZ4/hyndN4FcvvMmT63fFdVtL1tSRn+3n3BMz6+Izh6IgEJGk8NULZjJ3QgmffXAt62rjc/ZxMBThiXW7+JdZY8jP1vm0fRQEIpIUcgN+7vxoFSMLcrjxnmp2tXQN+Tae2dRAW3eYi0+uGPL3TmUKAhFJGmVFOdx17Xw6eyLcsLiaju7wkL7/o6vrGF2cy6kZOrncwSgIRCSpTB9dxC+unMvm+lZueWANkd6huZzu7vZuXtjaxEVzK/ZfZEeiFAQiknTOml7Of33wRJ7Z1Mh//nE94SEYVvrntW8R6XU6LDQA9ZaISFK65rRJNLYFue25N9ixt5PbrjyZkvyjvxLakjV1zKoo5vhRRUNYZXrQHoGIJK0vvH8GP/jwbFZu28dFt73E64
3tR/U+rzW0sb6uhUVzde7AQBQEIpLULq0az/03vYv27jCLbnuJ57Y0HvF73P/KTvw+40NzxsahwtSnIBCRpDdv4gj+9Ml3M25EPjcsXsnvX94x6Ne+9Ppu7l62jUVzKzTB3EEoCEQkJVSU5PHIxxdw5vFl/J8l67ntuddx7tAjiupbgtzywBqmlBXy9Q+dmKBKU4+CQERSRn52FndeXcVFlWP5wVNb+Objm+g9yPDSUKSXT92/ms6eCHdcdTIFORobczD6nxGRlBLw+/jvyyoZXpDNXS9tY19nD9//8GwCB1z/+IdPbWFlzT5+enklU8s1UuhQFAQiknJ8PuOrF8xkZEE2P1y6lfV1Lbx7ainzJg5n/qQRrKtt5lcvvslVp07gwkqdN3A4CgIRSUlmxifPnsa44fn8YeVO/rByJ4uX1QDgM5g9bhj/94KZ3haZIhQEIpLSLppbwUVzKwhFetn4Visra/byWkM7n1o4lZwsv9flpQQFgYikhYDfx5zxJcwZX+J1KSlHo4ZERDKcgkBEJMMpCEREMlzCg8DMppvZ2n63VjP7TKLrEBGRqIR3FjvntgCVAGbmB+qAJYmuQ0REorw+NLQQeMM5t93jOkREMpbXQXA5cL/HNYiIZDTPgsDMsoEPAQ8d5PmbzKzazKqbmpoSW5yISAaxw03jGrcNm10IfMI5d+4g1m0C+h8+Gga0HLDaYJb1f3yw+6XA7sPVdBgD1XKk68WrjcncvoGWD/ZxqnyGAy1Pt+/TgZanWxuT5WfxYLX0meicKzvsOzjnPLkBDwDXHeVr7zyaZf0fH+J+9RC07R21HOl68WpjMrfvcO051ONU+QyPtI2p+H2aCW1Mlp/FI2njoW6eHBoys3zgHODRo3yLx45y2WODuD8UBvt+h1ovmdsYr/YNtHywj1PlMxxoebp9nw60PN3amCw/i0Pyfp4dGkpWZlbtnKvyuo54Sff2gdqYLtK9jcnUPq9HDSWjO70uIM7SvX2gNqaLdG9j0rRPewQiIhlOewQiIhkubYPAzO4ys0Yz23AUr51nZuvN7HUz+5mZWb/nPmVmW8zsH2b2/aGt+ojrHPI2mtnXzKyu31xQ5w195UdUZ1w+x9jznzczZ2alQ1fxkYs0gCKhAAAGP0lEQVTT5/hNM1sX+wyXmtnYoa980DXGo30/MLPNsTYuMTNPL0IQpzZeGvs902tm8e1LGIrhS8l4A84ATgY2HMVrXwEWAAY8CfxLbPl7gWeAnNjj8jRs49eAz3v9+cWzjbHnxgNPET0/pTTd2ggU91vn08Adada+c4Gs2P3vAd9Lw8/wBGA68DxQFc/603aPwDn3IrC3/zIzm2JmfzGzVWb2NzObceDrzGwM0R+i5S76afwOuCj29MeB7zrnumPbaIxvKw4tTm1MKnFs44+BLwKed5LFo43OudZ+qxbgYTvj1L6lzrlwbNUVwLj4tuLQ4tTGTS46SWfcpW0QHMSdwKecc/OAzwO3D7BOBVDb73FtbBnA8cB7zOxlM3vBzObHtdqjc6xtBPhkbJf7LjMbHr9Sj9oxtdHMPgTUOedejXehx+CYP0cz+5aZ7QQ+Anw1jrUejaH4Pu1zPdG/pJPNULYxrjLmmsVmVgicBjzU71BxzkCrDrCs76+pLGA4cCowH3jQzI6LJbnnhqiNvwS+GXv8TeBHRH/QksKxttGiJzN+heihhaQ0RJ8jzrmvAF8xsy8DnwT+a4hLPSpD1b7Ye30FCAP3DWWNx2oo25gIGRMERPd+mp1zlf0XWvSaCKtiD/9M9Bdh/93MccBbsfu1wKOxX/yvmFkv0flCkmVWvGNuo3Ouod/rfg08Hs+Cj8KxtnEKMBl4NfYDOg5YbWanOOfq41z7YA3F92p/vweeIEmCgCFqn5ldA1wALEyWP8b6GerPML687GCJ9w2YRL/OG2AZcGnsvgFzDvK6lUT/6u/rvDkvtvxjwDdi948HdhI7FyON2jim3zqfBR5It8/xgHVq8L
izOE6f47R+63wKeDjN2vcBYCNQ5vVnF6829nv+eeLcWez5f14cP5T7gV1AiOhf8jcQ/UvwL8CrsW+irx7ktVXABuAN4Bd9v+yBbOB/Ys+tBs5OwzbeC6wH1hH9i2VMotqTqDYesI7nQRCnz/GR2PJ1ROeiqUiz9r1O9A+xtbGbZ6Oi4tjGRbH36gYagKfiVb/OLBYRyXCZNmpIREQOoCAQEclwCgIRkQynIBARyXAKAhGRDKcgkJRkZu0J3t5vzGzmEL1XJDYr6AYze+xwM2eaWYmZ3TwU2xYZiIaPSkoys3bnXOEQvl+We3sSs7jqX7uZ3QNsdc596xDrTwIed87NSkR9knm0RyBpw8zKzOwRM1sZu50eW36KmS0zszWxr9Njy681s4fM7DFgqZmdZWbPm9nDsbnu7+s3N/zzfXPCm1l7bEK3V81shZmNii2fEnu80sy+Mci9luW8PRleoZk9a2arLTo//YWxdb4LTIntRfwgtu4XYttZZ2ZfH8L/RslACgJJJz8Ffuycmw9cAvwmtnwzcIZzbi7RWTi/3e81C4BrnHNnxx7PBT4DzASOA04fYDsFwArn3BzgReDf+m3/p7HtH3a+mNi8MwuJnsENEAQWOedOJnrtix/FguhLwBvOuUrn3BfM7FxgGnAKUAnMM7MzDrc9kYPJpEnnJP29D5jZb7bHYjMrAoYB95jZNKIzOwb6veZp51z/eeRfcc7VApjZWqLzx/z9gO308PZkfKuAc2L3F/D2NQ9+D/zwIHXm9XvvVcDTseUGfDv2S72X6J7CqAFef27stib2uJBoMLx4kO2JHJKCQNKJD1jgnOvqv9DMfg4855xbFDve/ny/pzsOeI/ufvcjDPwzEnJvd64dbJ1D6XLOVZrZMKKB8gngZ0SvG1AGzHPOhcysBsgd4PUGfMc596sj3K7IgHRoSNLJUqLz7gNgZn1TAA8D6mL3r43j9lcQPSQFcPnhVnbOtRC9jOTnzSxAtM7GWAi8F5gYW7UNKOr30qeA62Nz3mNmFWZWPkRtkAykIJBUlW9mtf1utxL9pVoV60DdSHTacIDvA98xs5cAfxxr+gxwq5m9AowBWg73AufcGqKzU15O9OIqVWZWTXTvYHNsnT3AS7Hhpj9wzi0leuhpuZmtBx7mn4NC5Iho+KjIEIld/azLOefM7HLgCufchYd7nYjX1EcgMnTmAb+IjfRpJoku8SlyKNojEBHJcOojEBHJcAoCEZEMpyAQEclwCgIRkQynIBARyXAKAhGRDPf/AeRJPEVzBxEPAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find()\n", "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": 141, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | seq2seq_acc | bleu | time
------|------------|------------|-------------|------|------
0 | 1.887606 | 3.718430 | 0.556508 | 0.341307 | 01:25
1 | 1.506591 | 3.122954 | 0.551054 | 0.402871 | 01:36
2 | 1.548055 | 3.860026 | 0.473495 | 0.410226 | 01:33
3 | 1.735435 | 3.202152 | 0.529949 | 0.441509 | 01:36
4 | 1.808599 | 3.711476 | 0.472057 | 0.408111 | 01:32
5 | 1.891013 | 3.141925 | 0.523613 | 0.437650 | 01:34
6 | 1.952281 | 3.402686 | 0.485307 | 0.413826 | 01:29
7 | 2.096382 | 3.790361 | 0.442190 | 0.379218 | 01:28
8 | 1.994412 | 3.421550 | 0.474625 | 0.396157 | 01:29
9 | 2.207710 | 3.459248 | 0.473342 | 0.392821 | 01:26
10 | 1.987739 | 3.538437 | 0.468906 | 0.380963 | 01:29
11 | 1.819864 | 3.483137 | 0.479421 | 0.392705 | 01:30
12 | 1.412151 | 3.555584 | 0.479795 | 0.396348 | 01:27
13 | 1.363241 | 3.424492 | 0.496230 | 0.408865 | 01:23
14 | 1.452292 | 3.420485 | 0.498205 | 0.413232 | 01:31
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(15, 3e-3)" ] }, { "cell_type": "code", "execution_count": 146, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 100.00% [151/151 00:34<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "inputs, targets, outputs = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": 147, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos qui a le pouvoir de modifier le règlement sur les poids et mesures et le règlement sur l'inspection de l'électricité et du gaz ?,\n", " Text xxbos who has the authority to change the electricity and gas inspection regulations and the weights and measures regulations ?,\n", " Text xxbos what do we regulations and and regulations ? ?)" ] }, "execution_count": 147, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[700], targets[700], outputs[700]" ] }, { "cell_type": "code", "execution_count": 148, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos ´ ` ou sont xxunk leurs grandes convictions en ce qui a trait a la ` ` ´ transparence et a la responsabilite ?,\n", " Text xxbos what happened to their great xxunk about transparency and accountability ?,\n", " Text xxbos what are the and and and and and and and and and to to ? 
?)" ] }, "execution_count": 148, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[701], targets[701], outputs[701]" ] }, { "cell_type": "code", "execution_count": 149, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos quelles ressources votre communauté possède - t - elle qui favoriseraient la guérison ?,\n", " Text xxbos what resources exist in your community that would promote recovery ?,\n", " Text xxbos what resources would your community community community community community community ?)" ] }, "execution_count": 149, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs[4002], targets[4002], outputs[4002]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }