{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural text generation" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from seq2seq import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "path = Config().data_path()/'giga-fren'\n", "data = load_data(path)\n", "model_path = Config().model_path()\n", "emb_enc = torch.load(model_path/'fr_emb.pth')\n", "emb_dec = torch.load(model_path/'en_emb.pth')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class Seq2SeqRNN_attn(nn.Module):\n", " def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):\n", " super().__init__()\n", " self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1\n", " self.bos_idx,self.pad_idx = bos_idx,pad_idx\n", " self.emb_enc,self.emb_dec = emb_enc,emb_dec\n", " self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim\n", " self.voc_sz_dec = emb_dec.num_embeddings\n", " \n", " self.emb_enc_drop = nn.Dropout(0.15)\n", " self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, \n", " batch_first=True, bidirectional=True)\n", " self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n", " self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,\n", " dropout=0.1, batch_first=True)\n", " self.out_drop = nn.Dropout(0.35)\n", " self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)\n", " self.out.weight.data = self.emb_dec.weight.data\n", " \n", " self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)\n", " self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)\n", " self.V = self.init_param(self.emb_sz_dec)\n", " \n", " def encoder(self, bs, inp):\n", " h = self.initHidden(bs)\n", " emb = self.emb_enc_drop(self.emb_enc(inp))\n", " enc_out, hid = self.gru_enc(emb, 2*h)\n", " pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()\n", " pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)\n", " hid = self.out_enc(pre_hid)\n", " return hid,enc_out\n", " \n", " def decoder(self, dec_inp, hid, enc_att, enc_out):\n", " hid_att = self.hid_att(hid[-1])\n", " u = torch.tanh(enc_att + hid_att[:,None])\n", " attn_wgts = F.softmax(u @ self.V, 1)\n", " ctx = (attn_wgts[...,None] * enc_out).sum(1)\n", " emb = self.emb_dec(dec_inp)\n", " outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)\n", " outp = self.out(self.out_drop(outp[:,0]))\n", " return hid, outp\n", " \n", " def forward(self, inp, targ=None):\n", " bs, sl = inp.size()\n", " hid,enc_out = self.encoder(bs, inp)\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " enc_att = self.enc_att(enc_out)\n", " \n", " res = []\n", " for i in range(self.out_sl):\n", " hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)\n", " res.append(outp)\n", " dec_inp = outp.max(1)[1]\n", " if (dec_inp==self.pad_idx).all(): break\n", " if (targ is not None) and (random.random()=targ.shape[1]: continue\n", " dec_inp = targ[:,i]\n", " return torch.stack(res, dim=1)\n", "\n", " def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)\n", " def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,\n", " 
{ "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [], "source": [ "model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)\n", "learn = Learner(data, model, loss_func=seq2seq_loss, metrics=seq2seq_acc,\n", "                callback_fns=partial(TeacherForcing, end_epoch=30))" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/html": [ "<table>\n",
"  <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>seq2seq_acc</th><th>bleu</th><th>time</th></tr>\n",
"  <tr><td>0</td><td>2.411852</td><td>3.832309</td><td>0.489762</td><td>0.266689</td><td>02:08</td></tr>\n",
"  <tr><td>1</td><td>1.962503</td><td>4.193816</td><td>0.490516</td><td>0.398154</td><td>01:52</td></tr>\n",
"  <tr><td>2</td><td>1.705578</td><td>4.441123</td><td>0.456639</td><td>0.390446</td><td>01:51</td></tr>\n",
"  <tr><td>3</td><td>1.562641</td><td>4.048814</td><td>0.472333</td><td>0.403023</td><td>01:49</td></tr>\n",
"  <tr><td>4</td><td>1.373088</td><td>4.048331</td><td>0.477947</td><td>0.413261</td><td>01:50</td></tr>\n", "</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5, 3e-3)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# learn.save('5')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jhoward/anaconda3/lib/python3.7/site-packages/torch/serialization.py:256: UserWarning: Couldn't retrieve source code for container of type Seq2SeqRNN_attn. It won't be checked for correctness upon loading.\n", " \"type \" + obj.__name__ + \". It won't be checked \"\n" ] } ], "source": [ "learn.load('5');" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def preds_acts(learn, ds_type=DatasetType.Valid):\n", " \"Same as `get_predictions` but also returns non-reconstructed activations\"\n", " learn.model.eval()\n", " ds = learn.data.train_ds\n", " rxs,rys,rzs,xs,ys,zs = [],[],[],[],[],[] # 'r' == 'reconstructed'\n", " with torch.no_grad():\n", " for xb,yb in progress_bar(learn.dl(ds_type)):\n", " out = learn.model(xb)\n", " for x,y,z in zip(xb,yb,out):\n", " rxs.append(ds.x.reconstruct(x))\n", " rys.append(ds.y.reconstruct(y))\n", " preds = z.argmax(1)\n", " rzs.append(ds.y.reconstruct(preds))\n", " for a,b in zip([xs,ys,zs],[x,y,z]): a.append(b)\n", " return rxs,rys,rzs,xs,ys,zs" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 100.00% [151/151 00:27<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "rxs,rys,rzs,xs,ys,zs = preds_acts(learn)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Text xxbos quelles sont les lacunes qui existent encore dans notre connaissance du travail autonome et sur lesquelles les recherches devraient se concentrer à l’avenir ?,\n", " Text xxbos what gaps remain in our knowledge of xxunk on which future research should focus ?,\n", " Text xxbos what gaps are needed in our work and what is the research of the work and what research will be in place to future ?)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx=701\n", "rx,ry,rz = rxs[idx],rys[idx],rzs[idx]\n", "x,y,z = xs[idx],ys[idx],zs[idx]\n", "rx,ry,rz" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def select_topk(outp, k=5):\n", " probs = F.softmax(outp,dim=-1)\n", " vals,idxs = probs.topk(k, dim=-1)\n", " return idxs[torch.randint(k, (1,))]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751)." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "from random import choice\n", "\n", "def select_nucleus(outp, p=0.5):\n", " probs = F.softmax(outp,dim=-1)\n", " idxs = torch.argsort(probs, descending=True)\n", " res,cumsum = [],0.\n", " for idx in idxs:\n", " res.append(idx)\n", " cumsum += probs[idx]\n", " if cumsum>p: return idxs.new_tensor([choice(res)])" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "def decode(self, inp):\n", " inp = inp[None]\n", " bs, sl = inp.size()\n", " hid,enc_out = self.encoder(bs, inp)\n", " dec_inp = inp.new_zeros(bs).long() + self.bos_idx\n", " enc_att = self.enc_att(enc_out)\n", "\n", " res = []\n", " for i in range(self.out_sl):\n", " hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)\n", " dec_inp = select_nucleus(outp[0], p=0.3)\n", "# dec_inp = select_topk(outp[0], k=2)\n", " res.append(dec_inp)\n", " if (dec_inp==self.pad_idx).all(): break\n", " return torch.cat(res)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def predict_with_decode(learn, x, y):\n", " learn.model.eval()\n", " ds = learn.data.train_ds\n", " with torch.no_grad():\n", " out = decode(learn.model, x)\n", " rx = ds.x.reconstruct(x)\n", " ry = ds.y.reconstruct(y)\n", " rz = ds.y.reconstruct(out)\n", " return rx,ry,rz" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text xxbos what gaps are needed in our understanding of work and security and how research will need to be put in place ?" ] }, "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ "rx,ry,rz = predict_with_decode(learn, x, y)\n", "rz" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" } }, "nbformat": 4, "nbformat_minor": 2 }