{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()/'giga-fren'\n", "path.ls()" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Reduce original dataset to questions" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#with open(path/'giga-fren.release2.fixed.fr') as f:\n", "#    fr = f.read().split('\\n')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#with open(path/'giga-fren.release2.fixed.en') as f:\n", "#    en = f.read().split('\\n')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#re_eq = re.compile('^(Wh[^?.!]+\\?)')\n", "#re_fq = re.compile('^([^?.!]+\\?)')\n", "#en_fname = path/'giga-fren.release2.fixed.en'\n", "#fr_fname = path/'giga-fren.release2.fixed.fr'" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#lines = ((re_eq.search(eq), re_fq.search(fq)) \n", "#         for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))\n", "#qs = [(e.group(), f.group()) for e,f in lines if e and f]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#qs = [(q1,q2) for q1,q2 in qs]\n", "#df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr'])\n", "#df.to_csv(path/'questions_easy.csv', index=False)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Put them in a DataBunch" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(path/'questions_easy.csv')\n", "df.head()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:\n",
 "    \"Function that collects samples and adds padding. Flips token order if needed.\"\n",
 "    samples = to_data(samples)\n",
 "    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])\n",
 "    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx\n",
 "    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx\n",
 "    if backwards: pad_first = not pad_first\n",
 "    for i,s in enumerate(samples):\n",
 "        if pad_first:\n",
 "            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])\n",
 "        else:\n",
 "            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])\n",
 "    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)\n",
 "    return res_x,res_y" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "class Seq2SeqDataBunch(TextDataBunch):\n",
 "    \"Create a `TextDataBunch` suitable for training a sequence-to-sequence model.\"\n",
 "    @classmethod\n",
 "    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,\n",
 "               pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:\n",
 "        \"Function that transforms the `datasets` into a `DataBunch` for sequence-to-sequence tasks. 
Passes `**dl_kwargs` on to `DataLoader()`\"\n", " datasets = cls._init_ds(train_ds, valid_ds, test_ds)\n", " val_bs = ifnone(val_bs, bs)\n", " collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)\n", " train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)\n", " train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)\n", " dataloaders = [train_dl]\n", " for ds in datasets[1:]:\n", " lengths = [len(t) for t in ds.x.items]\n", " sampler = SortSampler(ds.x, key=lengths.__getitem__)\n", " dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))\n", " return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqTextList(TextList):\n", " _bunch = Seq2SeqDataBunch\n", " _label_cls = TextList" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "src = Seq2SeqTextList.from_df(df, path = path, cols='en').split_by_rand_pct().label_from_df(cols='fr', label_cls=TextList)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "len(src.train) + len(src.valid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = src.databunch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.save('en2fr')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = load_data(path, 'en2fr')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pretrained embeddings for the decoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To install fastText:\n", "```\n", "$ git clone https://github.com/facebookresearch/fastText.git\n", "$ cd fastText\n", "$ pip install .\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import fastText as ft" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We create an embedding module with the pretrained vectors and random data for the missing parts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_emb(vecs, itos, em_sz=300, mult=1.):\n", " emb = nn.Embedding(len(itos), em_sz, padding_idx=1)\n", " wgts = emb.weight.data\n", " vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}\n", " miss = []\n", " for i,w in enumerate(itos):\n", " try: wgts[i] = tensor(vec_dic[w])\n", " except: miss.append(w)\n", " return emb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_dec = create_emb(fr_vecs, data.y.vocab.itos)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.save(emb_dec, path/'models'/'fr_dec_emb.pth')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "del fr_vecs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pretrained encoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text.models.qrnn import QRNN, QRNNLayer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class SimpleQRNN(nn.Module):\n", " \n", " def __init__(self, vocab_sz, emb_sz=300, n_hid=256, n_layers=2, p_inp=0.1, p_hid=0.1, p_out=0.1):\n", " super().__init__()\n", " self.embed = nn.Embedding(vocab_sz, emb_sz)\n", " self.inp_drop = nn.Dropout(p_inp)\n", " self.encoder = QRNN(emb_sz, n_hid, n_layers=n_layers, dropout=p_hid, save_prev_x=True)\n", " self.out_enc = nn.Linear(n_hid, emb_sz)\n", " self.out_drop = nn.Dropout(p_out)\n", " self.decoder = nn.Linear(emb_sz, vocab_sz)\n", " self.decoder.weight = self.embed.weight\n", " self.n_layers,self.n_hid,self.bs = n_layers,n_hid,1\n", " \n", " def forward(self, inp):\n", " if self.bs != inp.size(0):\n", " self.bs = inp.size(0)\n", " self.init_hidden(inp.size(0))\n", " enc = self.inp_drop(self.embed(inp))\n", " enc, h = self.encoder(enc, self.hidden)\n", " self.hidden = h.detach()\n", " return self.decoder(self.out_drop(self.out_enc(enc)))\n", " \n", " def reset(self):\n", " self.encoder.reset()\n", " \n", " def init_hidden(self, bs):\n", " self.hidden = one_param(self).new_zeros(self.n_layers, bs, self.n_hid)\n", " param = one_param(self)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def convert_weights(wgts:Weights, stoi_wgts:Dict[str,int], itos_new:Collection[str]) -> Weights:\n", " \"Convert the model `wgts` to go with a new vocabulary.\"\n", " dec_bias, enc_wgts = wgts.get('decoder.bias', None), wgts['embed.weight']\n", " wgts_m = enc_wgts.mean(0)\n", " if dec_bias is not None: bias_m = dec_bias.mean(0)\n", " new_w = enc_wgts.new_zeros((len(itos_new),enc_wgts.size(1))).zero_()\n", " if dec_bias is not None: new_b = dec_bias.new_zeros((len(itos_new),)).zero_()\n", " for i,w in enumerate(itos_new):\n", " r = stoi_wgts[w] if w in stoi_wgts else -1\n", " new_w[i] = enc_wgts[r] if r>=0 else wgts_m\n", " if dec_bias is not None: new_b[i] = dec_bias[r] if r>=0 else bias_m\n", " wgts['embed.weight'] = new_w\n", " wgts['decoder.weight'] = new_w.clone()\n", " if dec_bias is not None: wgts['decoder.bias'] = new_b\n", " return wgts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data_lm = TextList.from_df(df, path = path, cols='en', vocab=data.x.vocab).split_by_rand_pct().label_for_lm().databunch()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
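 "# data_lm (previous cell) is built with vocab=data.x.vocab, so sizing the model from\n",
 "# len(data.x.vocab.itos) matches the vocabulary that convert_weights remaps the\n",
 "# pretrained QRNN weights to in the next cell.\n",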
"model = SimpleQRNN(len(data.x.vocab.itos))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pretrained_wgts = torch.load(path/'small_qrnn.pth')['model']\n", "pretrained_vocab = Vocab.load(path/'small_qrnn_vocab.pkl')\n", "model.load_state_dict(convert_weights(pretrained_wgts, pretrained_vocab.stoi, data_lm.vocab.itos))\n", "learn = Learner(data_lm, model, metrics=[accuracy, Perplexity()])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save('finetuned', with_opt=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### QRNN seq2seq" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqQRNN(nn.Module):\n", " def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n", " p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n", " super().__init__()\n", " self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n", " self.emb_enc = emb_enc\n", " self.emb_enc_drop = nn.Dropout(p_inp)\n", " self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)\n", " self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)\n", " self.hid_dp = nn.Dropout(p_hid)\n", " self.emb_dec = emb_dec\n", " self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n", " self.out_drop = nn.Dropout(p_out)\n", " self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n", " self.out.weight.data = self.emb_dec.weight.data\n", " \n", " def forward(self, inp):\n", " bs,sl = inp.size()\n", " self.encoder.reset()\n", " self.decoder.reset()\n", " hid = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, hid = self.encoder(emb, hid)\n", " hid = self.out_enc(self.hid_dp(hid))\n", "\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " outs = []\n", " for i in range(self.max_len):\n", " emb = self.emb_dec(dec_inp).unsqueeze(1)\n", " out, hid = self.decoder(emb, hid)\n", " out = self.out(self.out_drop(out[:,0]))\n", " outs.append(out)\n", " dec_inp = out.max(1)[1]\n", " if (dec_inp==self.pad_idx).all(): break\n", " return torch.stack(outs, dim=1)\n", " \n", " def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def seq2seq_loss(out, targ, pad_idx=1):\n", " bs,targ_len = targ.size()\n", " _,out_len,vs = out.size()\n", " if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", " if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", " return CrossEntropyFlat()(out, targ)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def seq2seq_acc(out, targ, pad_idx=1):\n", " bs,targ_len = targ.size()\n", " _,out_len,vs = out.size()\n", " if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx)\n", " if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)\n", " out = out.argmax(2)\n", " return (out==targ).float().mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Bleu metric (see dedicated 
notebook)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class NGram():\n", " def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n\n", " def __eq__(self, other):\n", " if len(self.ngram) != len(other.ngram): return False\n", " return np.all(np.array(self.ngram) == np.array(other.ngram))\n", " def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_grams(x, n, max_n=5000):\n", " return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_correct_ngrams(pred, targ, n, max_n=5000):\n", " pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)\n", " pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)\n", " return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class CorpusBLEU(Callback):\n", " def __init__(self, vocab_sz):\n", " self.vocab_sz = vocab_sz\n", " self.name = 'bleu'\n", " \n", " def on_epoch_begin(self, **kwargs):\n", " self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4\n", " \n", " def on_batch_end(self, last_output, last_target, **kwargs):\n", " last_output = last_output.argmax(dim=-1)\n", " for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):\n", " self.pred_len += len(pred)\n", " self.targ_len += len(targ)\n", " for i in range(4):\n", " c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)\n", " self.corrects[i] += c\n", " self.counts[i] += t\n", " \n", " def on_epoch_end(self, last_metrics, **kwargs):\n", " precs = [c/t for c,t in zip(self.corrects,self.counts)]\n", " len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1\n", " bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)\n", " return add_metrics(last_metrics, bleu)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_enc = learn.model.embed\n", "emb_dec = torch.load(path/'models'/'fr_dec_emb.pth')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model1 = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n", "learn1 = Learner(data, model1, loss_func=seq2seq_loss, metrics=seq2seq_acc)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "new_wgts = model.state_dict()\n", "wgts = model1.state_dict()\n", "for k,k1 in zip(wgts.keys(), list(new_wgts.keys())[:-4]): wgts[k].data = new_wgts[k1].data\n", "model1.load_state_dict(wgts)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn1.save('init')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model1, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))])\n", "learn = learn.load('init')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def seq2seq_split(model):\n", " return [[model.emb_enc, model.emb_enc_drop, model.encoder],\n", " [model.out_enc, model.hid_dp, model.emb_dec, model.decoder, model.out_drop, model.out]]" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = learn.split(seq2seq_split)\n", "learn.freeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(8, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TeacherForcing(LearnerCallback):\n", " \n", " def __init__(self, learn, end_epoch):\n", " super().__init__(learn)\n", " self.end_epoch = end_epoch\n", " \n", " def on_batch_begin(self, last_input, last_target, train, **kwargs):\n", " if train: return {'last_input': [last_input, last_target]}\n", " \n", " def on_epoch_begin(self, epoch, **kwargs):\n", " self.learn.model.pr_force = 1 - 0.5 * epoch/self.end_epoch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqQRNN(nn.Module):\n", " def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n", " p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n", " super().__init__()\n", " self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n", " self.emb_enc = emb_enc\n", " self.emb_enc_drop = nn.Dropout(p_inp)\n", " self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)\n", " self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)\n", " self.hid_dp = nn.Dropout(p_hid)\n", " self.emb_dec = emb_dec\n", " self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n", " self.out_drop = nn.Dropout(p_out)\n", " self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n", " self.out.weight.data = self.emb_dec.weight.data\n", " self.pr_force = 0.\n", " \n", " def forward(self, inp, targ=None):\n", " bs,sl = inp.size()\n", " hid = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, hid = self.encoder(emb, hid)\n", " hid = self.out_enc(self.hid_dp(hid))\n", "\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " res = []\n", " for i in range(self.max_len):\n", " emb = self.emb_dec(dec_inp).unsqueeze(1)\n", " outp, hid = self.decoder(emb, hid)\n", " outp = self.out(self.out_drop(outp[:,0]))\n", " res.append(outp)\n", " dec_inp = outp.data.max(1)[1]\n", " if (dec_inp==self.pad_idx).all(): break\n", " if (targ is not None) and (random.random()=targ.shape[1]: break\n", " dec_inp = targ[:,i]\n", " return torch.stack(res, dim=1)\n", " \n", " def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],\n", " callback_fns=partial(TeacherForcing, end_epoch=8))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = learn.load('init')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"learn.fit_one_cycle(8, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs, targets, outputs, tensor_inputs = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[700],targets[700],outputs[700]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[705],targets[705],outputs[705]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('init');" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(8, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs, targets, outputs, tensor_inputs = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[700],targets[700],outputs[700]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[705],targets[705],outputs[705]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Bidir" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqQRNN(nn.Module):\n", " def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n", " p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n", " super().__init__()\n", " self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n", " self.emb_enc = emb_enc\n", " self.emb_enc_drop = nn.Dropout(p_inp)\n", " self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)\n", " self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)\n", " self.hid_dp = nn.Dropout(p_hid)\n", " self.emb_dec = emb_dec\n", " self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n", " self.out_drop = nn.Dropout(p_out)\n", " self.out = nn.Linear(emb_dec.weight.size(1), emb_dec.weight.size(0))\n", " self.out.weight.data = self.emb_dec.weight.data\n", " self.pr_force = 0.\n", " \n", " def forward(self, inp, targ=None):\n", " bs,sl = inp.size()\n", " hid = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, hid = self.encoder(emb, hid)\n", " \n", " hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n", " hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n", "\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " res = []\n", " for i in range(self.max_len):\n", " emb = self.emb_dec(dec_inp).unsqueeze(1)\n", " outp, hid = self.decoder(emb, hid)\n", " outp = self.out(self.out_drop(outp[:,0]))\n", " res.append(outp)\n", " dec_inp = outp.data.max(1)[1]\n", " if (dec_inp==self.pad_idx).all(): break\n", " if (targ is not None) and (random.random()=targ.shape[1]: break\n", " dec_inp = targ[:,i]\n", " return torch.stack(res, dim=1)\n", " \n", " def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_enc = torch.load(path/'models'/'fr_emb.pth')\n", "emb_dec = 
torch.load(path/'models'/'en_emb.pth')\n", "emb_enc = torch.load(path/'models'/'en_emb2.pth')\n", "emb_dec = torch.load(path/'models'/'fr_emb2.pth')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,\n", " callback_fns=partial(TeacherForcing, end_epoch=8))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(8, 1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs, targets, outputs, input_tensors = get_predictions(learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[700], targets[700], outputs[700]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[701], targets[701], outputs[701]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[4001], targets[4001], outputs[4001]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = load_data(path, 'en2fr')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def init_param(*sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqQRNN(nn.Module):\n", " def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n", " p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n", " super().__init__()\n", " self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n", " self.emb_enc = emb_enc\n", " self.emb_enc_drop = nn.Dropout(p_inp)\n", " self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc)\n", " self.out_enc = nn.Linear(n_hid, emb_enc.weight.size(1), bias=False)\n", " self.hid_dp = nn.Dropout(p_hid)\n", " self.emb_dec = emb_dec\n", " self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n", " self.out_drop = nn.Dropout(p_out)\n", " emb_sz = emb_dec.weight.size(1)\n", " self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))\n", " self.out.weight.data = self.emb_dec.weight.data #Try tying\n", " self.W1 = init_param(n_hid, emb_sz)\n", " self.l2 = nn.Linear(emb_sz, emb_sz)\n", " self.l3 = nn.Linear(emb_sz+n_hid, emb_sz)\n", " self.V = init_param(emb_sz)\n", " self.pr_force = 0.\n", " \n", " def forward(self, inp, targ=None):\n", " bs,sl = inp.size()\n", " hid = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, hid = self.encoder(emb, hid)\n", " \n", " hid = self.out_enc(self.hid_dp(hid))\n", "\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " res = []\n", " w1e = enc_out @ self.W1\n", " for i in range(self.max_len):\n", " w2h = self.l2(hid[-1])\n", " u = 
torch.tanh(w1e + w2h[:,None])\n",
 "            a = F.softmax(u @ self.V, 1)\n",
 "            Xa = (a.unsqueeze(2) * enc_out).sum(1)\n",
 "            emb = self.emb_dec(dec_inp)\n",
 "            wgt_enc = self.l3(torch.cat([emb, Xa], 1))\n",
 "            outp, hid = self.decoder(wgt_enc[:,None], hid)\n",
 "            outp = self.out(self.out_drop(outp[:,0]))\n",
 "            res.append(outp)\n",
 "            dec_inp = outp.data.max(1)[1]\n",
 "            if (dec_inp==self.pad_idx).all(): break\n",
 "            if (targ is not None) and (random.random()<self.pr_force):\n",
 "                if i>=targ.shape[1]: break\n",
 "                dec_inp = targ[:,i]\n",
 "        return torch.stack(res, dim=1)\n",
 "\n",
 "    def initHidden(self, bs): return one_param(self).new_zeros(self.n_layers, bs, self.n_hid)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#emb_enc = torch.load(path/'models'/'fr_emb.pth')\n", "#emb_dec = torch.load(path/'models'/'en_emb.pth')\n", "emb_enc = torch.load(path/'models'/'en_emb2.pth')\n", "emb_dec = torch.load(path/'models'/'fr_emb2.pth')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,\n", "                callback_fns=partial(TeacherForcing, end_epoch=8))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.load('init', strict=False);" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(8, 3e-3)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs, targets, outputs, input_tensors = get_predictions(learn)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[700], targets[700], outputs[700]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[701], targets[701], outputs[701]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[4002], targets[4002], outputs[4002]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = load_data(path, 'en2fr')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def init_param(*sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "class Seq2SeqQRNN(nn.Module):\n",
 "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n",
 "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n",
 "        super().__init__()\n",
 "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
 "        self.emb_enc = emb_enc\n",
 "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
 "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)\n",
 "        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)\n",
 "        self.hid_dp = nn.Dropout(p_hid)\n",
 "        self.emb_dec = emb_dec\n",
 "        self.decoder = QRNN(emb_dec.weight.size(1), emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n",
 "        self.out_drop = nn.Dropout(p_out)\n",
 "        emb_sz = emb_dec.weight.size(1)\n",
 "        self.out 
= nn.Linear(emb_sz, emb_dec.weight.size(0))\n",
 "        self.out.weight.data = self.emb_dec.weight.data #Try tying\n",
 "        self.W1 = init_param(2*n_hid, emb_sz)\n",
 "        self.l2 = nn.Linear(emb_sz, emb_sz)\n",
 "        self.l3 = nn.Linear(emb_sz+2*n_hid, emb_sz)\n",
 "        self.V = init_param(emb_sz)\n",
 "        self.pr_force = 0.\n",
 "\n",
 "    def forward(self, inp, targ=None):\n",
 "        bs,sl = inp.size()\n",
 "        hid = self.initHidden(bs)\n",
 "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
 "        enc_out, hid = self.encoder(emb, hid)\n",
 "\n",
 "        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n",
 "        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n",
 "\n",
 "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
 "        res = []\n",
 "        w1e = enc_out @ self.W1\n",
 "        for i in range(self.max_len):\n",
 "            w2h = self.l2(hid[-1])\n",
 "            u = torch.tanh(w1e + w2h[:,None])\n",
 "            a = F.softmax(u @ self.V, 1)\n",
 "            Xa = (a.unsqueeze(2) * enc_out).sum(1)\n",
 "            emb = self.emb_dec(dec_inp)\n",
 "            wgt_enc = self.l3(torch.cat([emb, Xa], 1))\n",
 "            outp, hid = self.decoder(wgt_enc[:,None], hid)\n",
 "            outp = self.out(self.out_drop(outp[:,0]))\n",
 "            res.append(outp)\n",
 "            dec_inp = outp.data.max(1)[1]\n",
 "            if (dec_inp==self.pad_idx).all(): break\n",
 "            if (targ is not None) and (random.random()<self.pr_force):\n",
 "                if i>=targ.shape[1]: break\n",
 "                dec_inp = targ[:,i]\n",
 "        return torch.stack(res, dim=1)\n",
 "\n",
 "    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
 "class Seq2SeqQRNN(nn.Module):\n",
 "    def __init__(self, emb_enc, emb_dec, n_hid, max_len, n_layers=2, p_inp:float=0.15, p_enc:float=0.25, \n",
 "                 p_dec:float=0.1, p_out:float=0.35, p_hid:float=0.05, bos_idx:int=0, pad_idx:int=1):\n",
 "        super().__init__()\n",
 "        self.n_layers,self.n_hid,self.max_len,self.bos_idx,self.pad_idx = n_layers,n_hid,max_len,bos_idx,pad_idx\n",
 "        self.emb_enc = emb_enc\n",
 "        self.emb_enc_drop = nn.Dropout(p_inp)\n",
 "        self.encoder = QRNN(emb_enc.weight.size(1), n_hid, n_layers=n_layers, dropout=p_enc, bidirectional=True)\n",
 "        self.out_enc = nn.Linear(2*n_hid, emb_enc.weight.size(1), bias=False)\n",
 "        self.hid_dp = nn.Dropout(p_hid)\n",
 "        self.emb_dec = emb_dec\n",
 "        emb_sz = emb_dec.weight.size(1)\n",
 "        self.decoder = QRNN(emb_sz + 2*n_hid, emb_dec.weight.size(1), n_layers=n_layers, dropout=p_dec)\n",
 "        self.out_drop = nn.Dropout(p_out)\n",
 "        self.out = nn.Linear(emb_sz, emb_dec.weight.size(0))\n",
 "        self.out.weight.data = self.emb_dec.weight.data #Try tying\n",
 "        self.enc_att = nn.Linear(2*n_hid, emb_sz, bias=False)\n",
 "        self.hid_att = nn.Linear(emb_sz, emb_sz)\n",
 "        self.V = init_param(emb_sz)\n",
 "        self.pr_force = 0.\n",
 "\n",
 "    def forward(self, inp, targ=None):\n",
 "        bs,sl = inp.size()\n",
 "        hid = self.initHidden(bs)\n",
 "        emb = self.emb_enc_drop(self.emb_enc(inp))\n",
 "        enc_out, hid = self.encoder(emb, hid)\n",
 "\n",
 "        hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()\n",
 "        hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))\n",
 "\n",
 "        dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n",
 "        res = []\n",
 "        enc_att = self.enc_att(enc_out)\n",
 "        for i in range(self.max_len):\n",
 "            hid_att = self.hid_att(hid[-1])\n",
 "            u = torch.tanh(enc_att + hid_att[:,None])\n",
 "            attn_wgts = F.softmax(u @ self.V, 1)\n",
 "            ctx = (attn_wgts[...,None] * enc_out).sum(1)\n",
 "            emb = self.emb_dec(dec_inp)\n",
 "            outp, hid = self.decoder(torch.cat([emb, ctx], 1)[:,None], 
hid)\n",
 "            outp = self.out(self.out_drop(outp[:,0]))\n",
 "            res.append(outp)\n",
 "            dec_inp = outp.data.max(1)[1]\n",
 "            if (dec_inp==self.pad_idx).all(): break\n",
 "            if (targ is not None) and (random.random()<self.pr_force):\n",
 "                if i>=targ.shape[1]: break\n",
 "                dec_inp = targ[:,i]\n",
 "        return torch.stack(res, dim=1)\n",
 "\n",
 "    def initHidden(self, bs): return one_param(self).new_zeros(2*self.n_layers, bs, self.n_hid)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "emb_enc = torch.load(path/'models'/'fr_emb.pth')\n", "emb_dec = torch.load(path/'models'/'en_emb.pth')\n", "#emb_enc = torch.load(path/'models'/'en_emb2.pth')\n", "#emb_dec = torch.load(path/'models'/'fr_emb2.pth')" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = Seq2SeqQRNN(emb_enc, emb_dec, 256, 30, n_layers=2)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,\n", "                callback_fns=partial(TeacherForcing, end_epoch=8))" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.lr_find()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot()" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(8, 3e-3)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs, targets, outputs, input_tensors = get_predictions(learn)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[700], targets[700], outputs[700]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[701], targets[701], outputs[701]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "inputs[4002], targets[4002], outputs[4002]" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }