{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline\n", "\n", "from fastai.io import *\n", "from fastai.conv_learner import *\n", "\n", "from fastai.column_data import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We're going to download the collected works of Nietzsche to use as our data for this class." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": true }, "outputs": [], "source": [ "PATH='data/nietzsche/'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "corpus length: 600893\n" ] } ], "source": [ "get_data(\"https://s3.amazonaws.com/text-datasets/nietzsche.txt\", f'{PATH}nietzsche.txt')\n", "text = open(f'{PATH}nietzsche.txt').read()\n", "print('corpus length:', len(text))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? Is there not ground\\nfor suspecting that all philosophers, in so far as they have been\\ndogmatists, have failed to understand women--that the terrible\\nseriousness and clumsy importunity with which they have usually paid\\ntheir addresses to Truth, have been unskilled and unseemly methods for\\nwinning a woman? 
Certainly she has never allowed herself '" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text[:400]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total chars: 85\n" ] } ], "source": [ "chars = sorted(list(set(text)))\n", "vocab_size = len(chars)+1\n", "print('total chars:', vocab_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes it's useful to have a zero value in the dataset, e.g. for padding" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\n !\"\\'(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxy'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "chars.insert(0, \"\\0\")\n", "\n", "''.join(chars[1:-6])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Map from chars to indices and back again" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [], "source": [ "char_indices = {c: i for i, c in enumerate(chars)}\n", "indices_char = {i: c for i, c in enumerate(chars)}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*idx* will be the data we use from now on - it simply converts all the characters to their index (based on the mapping above)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[40, 42, 29, 30, 25, 27, 29, 1, 1, 1]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "idx = [char_indices[c] for c in text]\n", "\n", "idx[:10]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'PREFACE\\n\\n\\nSUPPOSING that Truth is a woman--what then? 
Is there not gro'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "''.join(indices_char[i] for i in idx[:70])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Three char model" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Create inputs" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Create four lists, each containing every 3rd character (`cs=3`), starting at the 0th, 1st, 2nd, then 3rd characters respectively" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "cs=3\n", "c1_dat = [idx[i] for i in range(0, len(idx)-cs, cs)]\n", "c2_dat = [idx[i+1] for i in range(0, len(idx)-cs, cs)]\n", "c3_dat = [idx[i+2] for i in range(0, len(idx)-cs, cs)]\n", "c4_dat = [idx[i+3] for i in range(0, len(idx)-cs, cs)]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Our inputs" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "x1 = np.stack(c1_dat)\n", "x2 = np.stack(c2_dat)\n", "x3 = np.stack(c3_dat)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Our output" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "y = np.stack(c4_dat)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "The first 4 inputs and outputs" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "hidden": true, "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(array([40, 30, 29, 1]), array([42, 25, 1, 43]), array([29, 27, 1, 45]))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1[:4], x2[:4], x3[:4]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [
"array([30, 29, 1, 40])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y[:4]" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "((200295,), (200295,))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.shape, y.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create and train model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pick a size for our hidden state" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "collapsed": true }, "outputs": [], "source": [ "n_hidden = 256" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The number of latent factors to create (i.e. the size of the embedding matrix)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "n_fac = 42" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class Char3Model(nn.Module):\n", " def __init__(self, vocab_size, n_fac):\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", "\n", " # The 'green arrow' from our diagram - the layer operation from input to hidden\n", " self.l_in = nn.Linear(n_fac, n_hidden)\n", "\n", " # The 'orange arrow' from our diagram - the layer operation from hidden to hidden\n", " self.l_hidden = nn.Linear(n_hidden, n_hidden)\n", " \n", " # The 'blue arrow' from our diagram - the layer operation from hidden to output\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " \n", " def forward(self, c1, c2, c3):\n", " in1 = F.relu(self.l_in(self.e(c1)))\n", " in2 = F.relu(self.l_in(self.e(c2)))\n", " in3 = F.relu(self.l_in(self.e(c3)))\n", " \n", " h = V(torch.zeros(in1.size()).cuda())\n", " h = F.tanh(self.l_hidden(h+in1))\n", " h = F.tanh(self.l_hidden(h+in2))\n", " h = F.tanh(self.l_hidden(h+in3))\n", " 
\n", "        # dim=-1: normalize over the vocabulary logits, matching the later models (an implicit dim is deprecated in PyTorch)\n", "        return F.log_softmax(self.l_out(h), dim=-1)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# NOTE: the validation indices are [-1], i.e. only the last row is held out, so the validation loss reported by fit() below is computed on a single example and is not a useful metric here\n", "md = ColumnarModelData.from_arrays('.', [-1], np.stack([x1,x2,x3], axis=1), y, bs=512)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = Char3Model(vocab_size, n_fac).cuda()" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "it = iter(md.trn_dl)\n", "*xs,yt = next(it)\n", "t = m(*V(xs))" ] }, { "cell_type": "code", "execution_count": 191, "metadata": { "collapsed": true }, "outputs": [], "source": [ "opt = optim.Adam(m.parameters(), 1e-2)" ] }, { "cell_type": "code", "execution_count": 192, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "73483a3ac1804c3e81c8de6744d5c4bd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 2.09627 6.52849] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 193, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 0.001)" ] }, { "cell_type": "code", "execution_count": 194, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d7278ec0864e451795d91bac0ff944c7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0.
1.84525 6.52312] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Test model" ] }, { "cell_type": "code", "execution_count": 195, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "def get_next(inp):\n", " idxs = T(np.array([char_indices[c] for c in inp]))\n", " p = m(*VV(idxs))\n", " i = np.argmax(to_np(p))\n", " return chars[i]" ] }, { "cell_type": "code", "execution_count": 196, "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "'T'" ] }, "execution_count": 196, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('y. ')" ] }, { "cell_type": "code", "execution_count": 197, "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "'e'" ] }, "execution_count": 197, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('ppl')" ] }, { "cell_type": "code", "execution_count": 198, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'e'" ] }, "execution_count": 198, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next(' th')" ] }, { "cell_type": "code", "execution_count": 199, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "' '" ] }, "execution_count": 199, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('and')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our first RNN!" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Create inputs" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "This is the size of our unrolled RNN." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "cs=8" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "For every starting position in the text, create a list of the 8 consecutive characters beginning there (so adjacent windows overlap by 7 characters). Position *j* (0 through 7) of each window will be the *j*-th of the 8 inputs to our model." ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(len(idx)-cs)]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Then create a list of the next character after each of these windows. This will be the labels for our model." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "c_out_dat = [idx[j+cs] for j in range(len(idx)-cs)]" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "xs = np.stack(c_in_dat, axis=0)" ] }, { "cell_type": "code", "execution_count": 67, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(600884, 8)" ] }, "execution_count": 67, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs.shape" ] }, { "cell_type": "code", "execution_count": 68, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "y = np.stack(c_out_dat)" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "So each row below is one input sequence of 8 consecutive characters from the text, and each row starts one character later than the row above."
] }, { "cell_type": "code", "execution_count": 69, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[40, 42, 29, 30, 25, 27, 29, 1],\n", " [42, 29, 30, 25, 27, 29, 1, 1],\n", " [29, 30, 25, 27, 29, 1, 1, 1],\n", " [30, 25, 27, 29, 1, 1, 1, 43],\n", " [25, 27, 29, 1, 1, 1, 43, 45],\n", " [27, 29, 1, 1, 1, 43, 45, 40],\n", " [29, 1, 1, 1, 43, 45, 40, 40],\n", " [ 1, 1, 1, 43, 45, 40, 40, 39]])" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs[:cs,:cs]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "...and this is the next character after each sequence." ] }, { "cell_type": "code", "execution_count": 70, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([ 1, 1, 43, 45, 40, 40, 39, 43])" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y[:cs]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create and train model" ] }, { "cell_type": "code", "execution_count": 71, "metadata": { "collapsed": true }, "outputs": [], "source": [ "val_idx = get_cv_idxs(len(idx)-cs-1)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": { "collapsed": true }, "outputs": [], "source": [ "md = ColumnarModelData.from_arrays('.', val_idx, xs, y, bs=512)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharLoopModel(nn.Module):\n", " # This is an RNN!\n", " def __init__(self, vocab_size, n_fac):\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.l_in = nn.Linear(n_fac, n_hidden)\n", " self.l_hidden = nn.Linear(n_hidden, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " \n", " def forward(self, *cs):\n", " bs = cs[0].size(0)\n", " h = V(torch.zeros(bs, n_hidden).cuda())\n", " for c in cs:\n", " inp = F.relu(self.l_in(self.e(c)))\n", " h = 
F.tanh(self.l_hidden(h+inp))\n", " \n", " return F.log_softmax(self.l_out(h), dim=-1)" ] }, { "cell_type": "code", "execution_count": 82, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharLoopModel(vocab_size, n_fac).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-2)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6d1c8fb012c74fe191921d467c80b5ea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 2.02986 1.99268] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 0.001)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6e4a151c0f274c129e346a22fd4bdece", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 
1.73588 1.75103] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 92, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharLoopConcatModel(nn.Module):\n", " def __init__(self, vocab_size, n_fac):\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.l_in = nn.Linear(n_fac+n_hidden, n_hidden)\n", " self.l_hidden = nn.Linear(n_hidden, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " \n", " def forward(self, *cs):\n", " bs = cs[0].size(0)\n", " h = V(torch.zeros(bs, n_hidden).cuda())\n", " for c in cs:\n", " inp = torch.cat((h, self.e(c)), 1)\n", " inp = F.relu(self.l_in(inp))\n", " h = F.tanh(self.l_hidden(inp))\n", " \n", " return F.log_softmax(self.l_out(h), dim=-1)" ] }, { "cell_type": "code", "execution_count": 93, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharLoopConcatModel(vocab_size, n_fac).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 94, "metadata": { "collapsed": true }, "outputs": [], "source": [ "it = iter(md.trn_dl)\n", "*xs,yt = next(it)\n", "t = m(*V(xs))" ] }, { "cell_type": "code", "execution_count": 95, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d1b3572e787441d8b2e5d80317245596", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 
1.81654 1.78501] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 96, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-4)" ] }, { "cell_type": "code", "execution_count": 97, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9aa67fbb4a2f42509dbe7753bc86d9a1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.69008 1.69936] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Test model" ] }, { "cell_type": "code", "execution_count": 98, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "def get_next(inp):\n", " idxs = T(np.array([char_indices[c] for c in inp]))\n", " p = m(*VV(idxs))\n", " i = np.argmax(to_np(p))\n", " return chars[i]" ] }, { "cell_type": "code", "execution_count": 99, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'e'" ] }, "execution_count": 99, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('for thos')" ] }, { "cell_type": "code", "execution_count": 100, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'t'" ] }, "execution_count": 100, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('part of ')" ] }, { "cell_type": "code", "execution_count": 101, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'n'" ] }, "execution_count": 101, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('queens a')" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## RNN with pytorch" ] }, { "cell_type": "code", "execution_count": 108, "metadata": { "collapsed": true }, "outputs": 
[], "source": [ "class CharRnn(nn.Module):\n", " def __init__(self, vocab_size, n_fac):\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.RNN(n_fac, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " \n", " def forward(self, *cs):\n", " bs = cs[0].size(0)\n", " h = V(torch.zeros(1, bs, n_hidden))\n", " inp = self.e(torch.stack(cs))\n", " outp,h = self.rnn(inp, h)\n", " \n", " return F.log_softmax(self.l_out(outp[-1]), dim=-1)" ] }, { "cell_type": "code", "execution_count": 109, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharRnn(vocab_size, n_fac).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 110, "metadata": { "collapsed": true }, "outputs": [], "source": [ "it = iter(md.trn_dl)\n", "*xs,yt = next(it)" ] }, { "cell_type": "code", "execution_count": 111, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([8, 512, 42])" ] }, "execution_count": 111, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = m.e(V(torch.stack(xs)))\n", "t.size()" ] }, { "cell_type": "code", "execution_count": 112, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([8, 512, 256]), torch.Size([1, 512, 256]))" ] }, "execution_count": 112, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ht = V(torch.zeros(1, 512,n_hidden))\n", "outp, hn = m.rnn(t, ht)\n", "outp.size(), hn.size()" ] }, { "cell_type": "code", "execution_count": 113, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([512, 85])" ] }, "execution_count": 113, "metadata": {}, "output_type": "execute_result" } ], "source": [ "t = m(*V(xs)); t.size()" ] }, { "cell_type": "code", "execution_count": 114, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "497078d15ec348149442681039df2e50", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A 
Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.86065 1.84255] \n", "[ 1. 1.68014 1.67387] \n", "[ 2. 1.58828 1.59169] \n", "[ 3. 1.52989 1.54942] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 115, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-4)" ] }, { "cell_type": "code", "execution_count": 116, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "65a2f7bedaa34de2a40296a07387c1c9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.46841 1.50966] \n", "[ 1. 1.46482 1.5039 ] \n", "\n" ] } ], "source": [ "fit(m, md, 2, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Test model" ] }, { "cell_type": "code", "execution_count": 117, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "def get_next(inp):\n", " idxs = T(np.array([char_indices[c] for c in inp]))\n", " p = m(*VV(idxs))\n", " i = np.argmax(to_np(p))\n", " return chars[i]" ] }, { "cell_type": "code", "execution_count": 118, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'e'" ] }, "execution_count": 118, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('for thos')" ] }, { "cell_type": "code", "execution_count": 119, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "def get_next_n(inp, n):\n", " res = inp\n", " for i in range(n):\n", " c = get_next(inp)\n", " res += c\n", " inp = inp[1:]+c\n", " return res" ] }, { "cell_type": "code", "execution_count": 120, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "'for those the same the same the same the same 
th'" ] }, "execution_count": 120, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next_n('for thos', 40)" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "## Multi-output model" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Setup" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Let's take non-overlapping sets of characters this time" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "c_in_dat = [[idx[i+j] for i in range(cs)] for j in range(0, len(idx)-cs-1, cs)]" ] }, { "cell_type": "markdown", "metadata": { "hidden": true }, "source": [ "Then create the exact same thing, offset by 1, as our labels" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "c_out_dat = [[idx[i+j] for i in range(cs)] for j in range(1, len(idx)-cs, cs)]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(75111, 8)" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs = np.stack(c_in_dat)\n", "xs.shape" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "(75111, 8)" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ys = np.stack(c_out_dat)\n", "ys.shape" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[40, 42, 29, 30, 25, 27, 29, 1],\n", " [ 1, 1, 43, 45, 40, 40, 39, 43],\n", " [33, 38, 31, 2, 73, 61, 54, 73],\n", " [ 2, 44, 71, 74, 73, 61, 2, 62],\n", " [72, 2, 54, 2, 76, 68, 66, 54],\n", " [67, 9, 9, 76, 61, 54, 73, 2],\n", " [73, 61, 58, 67, 24, 2, 33, 72],\n", " [ 2, 73, 61, 58, 71, 58, 2, 
67]])" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xs[:cs,:cs]" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "array([[42, 29, 30, 25, 27, 29, 1, 1],\n", " [ 1, 43, 45, 40, 40, 39, 43, 33],\n", " [38, 31, 2, 73, 61, 54, 73, 2],\n", " [44, 71, 74, 73, 61, 2, 62, 72],\n", " [ 2, 54, 2, 76, 68, 66, 54, 67],\n", " [ 9, 9, 76, 61, 54, 73, 2, 73],\n", " [61, 58, 67, 24, 2, 33, 72, 2],\n", " [73, 61, 58, 71, 58, 2, 67, 68]])" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ys[:cs,:cs]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create and train model" ] }, { "cell_type": "code", "execution_count": 127, "metadata": { "collapsed": true }, "outputs": [], "source": [ "val_idx = get_cv_idxs(len(xs)-cs-1)" ] }, { "cell_type": "code", "execution_count": 128, "metadata": { "collapsed": true }, "outputs": [], "source": [ "md = ColumnarModelData.from_arrays('.', val_idx, xs, ys, bs=512)" ] }, { "cell_type": "code", "execution_count": 133, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharSeqRnn(nn.Module):\n", " def __init__(self, vocab_size, n_fac):\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.RNN(n_fac, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " \n", " def forward(self, *cs):\n", " bs = cs[0].size(0)\n", " h = V(torch.zeros(1, bs, n_hidden))\n", " inp = self.e(torch.stack(cs))\n", " outp,h = self.rnn(inp, h)\n", " return F.log_softmax(self.l_out(outp), dim=-1)" ] }, { "cell_type": "code", "execution_count": 134, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharSeqRnn(vocab_size, n_fac).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 135, "metadata": { "collapsed": true }, "outputs": [], "source": [ "it = iter(md.trn_dl)\n", 
"*xst,yt = next(it)" ] }, { "cell_type": "code", "execution_count": 136, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def nll_loss_seq(inp, targ):\n", " sl,bs,nh = inp.size()\n", " targ = targ.transpose(0,1).contiguous().view(-1)\n", " return F.nll_loss(inp.view(-1,nh), targ)" ] }, { "cell_type": "code", "execution_count": 137, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "725ca331d28b482e9c7a4f83f741498e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 2.59241 2.40251] \n", "[ 1. 2.28474 2.19859] \n", "[ 2. 2.13883 2.08836] \n", "[ 3. 2.04892 2.01564] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, nll_loss_seq)" ] }, { "cell_type": "code", "execution_count": 138, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-4)" ] }, { "cell_type": "code", "execution_count": 139, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "adb9aa22524d4bfd8b001d2efd10dbc3", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.99819 2.00106] \n", "\n" ] } ], "source": [ "fit(m, md, 1, opt, nll_loss_seq)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### Identity init!" ] }, { "cell_type": "code", "execution_count": 140, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "m = CharSeqRnn(vocab_size, n_fac).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-2)" ] }, { "cell_type": "code", "execution_count": 141, "metadata": { "hidden": true }, "outputs": [ { "data": { "text/plain": [ "\n", " 1 0 0 ... 0 0 0\n", " 0 1 0 ... 0 0 0\n", " 0 0 1 ... 0 0 0\n", " ... ⋱ ... 
\n", " 0 0 0 ... 1 0 0\n", " 0 0 0 ... 0 1 0\n", " 0 0 0 ... 0 0 1\n", "[torch.cuda.FloatTensor of size 256x256 (GPU 0)]" ] }, "execution_count": 141, "metadata": {}, "output_type": "execute_result" } ], "source": [ "m.rnn.weight_hh_l0.data.copy_(torch.eye(n_hidden))" ] }, { "cell_type": "code", "execution_count": 142, "metadata": { "hidden": true, "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8e141251f24d4083a6e8b2fa15dea724", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 2.39428 2.21111] \n", "[ 1. 2.10381 2.03275] \n", "[ 2. 1.99451 1.96393] \n", "[ 3. 1.93492 1.91763] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, nll_loss_seq)" ] }, { "cell_type": "code", "execution_count": 143, "metadata": { "collapsed": true, "hidden": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-3)" ] }, { "cell_type": "code", "execution_count": 144, "metadata": { "hidden": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ddf833e8b7ec4a3aa29dd271911f76ec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.84035 1.85742] \n", "[ 1. 1.82896 1.84887] \n", "[ 2. 1.81879 1.84281] \n", "[ 3. 
1.81337 1.83801] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, nll_loss_seq)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Stateful model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[0m\u001b[01;34mmodels\u001b[0m/ nietzsche.txt \u001b[01;34mtrn\u001b[0m/ \u001b[01;34mval\u001b[0m/\r\n" ] } ], "source": [ "from torchtext import vocab, data\n", "\n", "from fastai.nlp import *\n", "from fastai.lm_rnn import *\n", "\n", "PATH='data/nietzsche/'\n", "\n", "TRN_PATH = 'trn/'\n", "VAL_PATH = 'val/'\n", "TRN = f'{PATH}{TRN_PATH}'\n", "VAL = f'{PATH}{VAL_PATH}'\n", "\n", "# Note: The student needs to practice her shell skills and prepare her own dataset before proceeding:\n", "# - trn/trn.txt (first 80% of nietzsche.txt)\n", "# - val/val.txt (last 20% of nietzsche.txt)\n", "\n", "%ls {PATH}" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trn.txt\r\n" ] } ], "source": [ "%ls {PATH}trn" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(963, 56, 1, 493747)" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "TEXT = data.Field(lower=True, tokenize=list)\n", "bs=64; bptt=8; n_fac=42; n_hidden=256\n", "\n", "FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)\n", "md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=3)\n", "\n", "len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharSeqStatefulRnn(nn.Module):\n", " def __init__(self, vocab_size, n_fac, 
bs):\n", " self.vocab_size = vocab_size\n", " super().__init__()\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.RNN(n_fac, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " self.init_hidden(bs)\n", " \n", " def forward(self, cs):\n", " bs = cs[0].size(0)\n", " if self.h.size(1) != bs: self.init_hidden(bs)\n", " outp,h = self.rnn(self.e(cs), self.h)\n", " self.h = repackage_var(h)\n", " return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n", " \n", " def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharSeqStatefulRnn(md.nt, n_fac, 512).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2a9e0a39ef174c72bac575be7e20579c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.81983 1.81247] \n", "[ 1. 1.63097 1.66228] \n", "[ 2. 1.54433 1.57824] \n", "[ 3. 1.48563 1.54505] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8b15bf8bcc7445e694dbcb3beb658b74", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.4187 1.50374] \n", "[ 1. 1.41492 1.49391] \n", "[ 2. 1.41001 1.49339] \n", "[ 3. 
1.40756 1.486 ] \n", "\n" ] } ], "source": [ "set_lrs(opt, 1e-4)\n", "\n", "fit(m, md, 4, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RNN loop" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# From the pytorch source\n", "\n", "def RNNCell(input, hidden, w_ih, w_hh, b_ih, b_hh):\n", "    return F.tanh(F.linear(input, w_ih, b_ih) + F.linear(hidden, w_hh, b_hh))" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharSeqStatefulRnn2(nn.Module):\n", "    \"\"\"Same architecture as CharSeqStatefulRnn, but the recurrence is unrolled by hand\n", "    with nn.RNNCell, one timestep per loop iteration. The hidden state lives in self.h\n", "    and is carried from batch to batch (passed through repackage_var).\n", "    \"\"\"\n", "    def __init__(self, vocab_size, n_fac, bs):\n", "        super().__init__()\n", "        self.vocab_size = vocab_size\n", "        self.e = nn.Embedding(vocab_size, n_fac)\n", "        self.rnn = nn.RNNCell(n_fac, n_hidden)\n", "        self.l_out = nn.Linear(n_hidden, vocab_size)\n", "        self.init_hidden(bs)\n", "\n", "    def forward(self, cs):\n", "        # cs is indexed by timestep first: cs[0] holds the bs characters of step 0\n", "        bs = cs[0].size(0)\n", "        if self.h.size(0) != bs: self.init_hidden(bs)\n", "        outp = []\n", "        o = self.h\n", "        for c in cs:\n", "            # nn.RNNCell advances a single timestep and returns the new (bs, n_hidden) state\n", "            o = self.rnn(self.e(c), o)\n", "            outp.append(o)\n", "        outp = self.l_out(torch.stack(outp))\n", "        self.h = repackage_var(o)\n", "        return F.log_softmax(outp, dim=-1).view(-1, self.vocab_size)\n", "\n", "    # Unlike nn.RNN, whose hidden state is (num_layers, bs, n_hidden), nn.RNNCell takes a\n", "    # 2-D (bs, n_hidden) hidden state; a leading layer dim would only 'work' via broadcasting.\n", "    def init_hidden(self, bs): self.h = V(torch.zeros(bs, n_hidden))" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "m = CharSeqStatefulRnn2(md.nt, n_fac, 512).cuda()\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8c46f24bfa194e1ba9d73e22283ca6af", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.81013 1.7969 ] \n", "[ 1. 1.62515 1.65346] \n", "[ 2. 
1.53913 1.58065] \n", "[ 3. 1.48698 1.54217] \n", "\n" ] } ], "source": [ "fit(m, md, 4, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GRU" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharSeqStatefulGRU(nn.Module):\n", " def __init__(self, vocab_size, n_fac, bs):\n", " super().__init__()\n", " self.vocab_size = vocab_size\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.GRU(n_fac, n_hidden)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " self.init_hidden(bs)\n", " \n", " def forward(self, cs):\n", " bs = cs[0].size(0)\n", " if self.h.size(1) != bs: self.init_hidden(bs)\n", " outp,h = self.rnn(self.e(cs), self.h)\n", " self.h = repackage_var(h)\n", " return F.log_softmax(self.l_out(outp), dim=-1).view(-1, self.vocab_size)\n", " \n", " def init_hidden(self, bs): self.h = V(torch.zeros(1, bs, n_hidden))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# From the pytorch source code - for reference\n", "\n", "def GRUCell(input, hidden, w_ih, w_hh, b_ih, b_hh):\n", " gi = F.linear(input, w_ih, b_ih)\n", " gh = F.linear(hidden, w_hh, b_hh)\n", " i_r, i_i, i_n = gi.chunk(3, 1)\n", " h_r, h_i, h_n = gh.chunk(3, 1)\n", "\n", " resetgate = F.sigmoid(i_r + h_r)\n", " inputgate = F.sigmoid(i_i + h_i)\n", " newgate = F.tanh(i_n + resetgate * h_n)\n", " return newgate + inputgate * (hidden - newgate)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharSeqStatefulGRU(md.nt, n_fac, 512).cuda()\n", "\n", "opt = optim.Adam(m.parameters(), 1e-3)" ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e518384d71c345a8b145b35d4ee894fa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A 
Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.68409 1.67784] \n", "[ 1. 1.49813 1.52661] \n", "[ 2. 1.41674 1.46769] \n", "[ 3. 1.36359 1.43818] \n", "[ 4. 1.33223 1.41777] \n", "[ 5. 1.30217 1.40511] \n", "\n" ] } ], "source": [ "fit(m, md, 6, opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": { "collapsed": true }, "outputs": [], "source": [ "set_lrs(opt, 1e-4)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "scrolled": false }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "be385370c27f4b788920caf48f90aeea", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.22708 1.36926] \n", "[ 1. 1.21948 1.3696 ] \n", "[ 2. 1.22541 1.36969] \n", "\n" ] } ], "source": [ "fit(m, md, 3, opt, F.nll_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Putting it all together: LSTM" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "from fastai import sgdr\n", "\n", "n_hidden=512" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "class CharSeqStatefulLSTM(nn.Module):\n", " def __init__(self, vocab_size, n_fac, bs, nl):\n", " super().__init__()\n", " self.vocab_size,self.nl = vocab_size,nl\n", " self.e = nn.Embedding(vocab_size, n_fac)\n", " self.rnn = nn.LSTM(n_fac, n_hidden, nl, dropout=0.5)\n", " self.l_out = nn.Linear(n_hidden, vocab_size)\n", " self.init_hidden(bs)\n", " \n", " def forward(self, cs):\n", " bs = cs[0].size(0)\n", " if self.h[0].size(1) != bs: self.init_hidden(bs)\n", " outp,h = self.rnn(self.e(cs), self.h)\n", " self.h = repackage_var(h)\n", " return F.log_softmax(self.l_out(outp), dim=-1).view(-1, 
self.vocab_size)\n", " \n", " def init_hidden(self, bs):\n", " self.h = (V(torch.zeros(self.nl, bs, n_hidden)),\n", " V(torch.zeros(self.nl, bs, n_hidden)))" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], "source": [ "m = CharSeqStatefulLSTM(md.nt, n_fac, 512, 2).cuda()\n", "lo = LayerOptimizer(optim.Adam, m, 1e-2, 1e-5)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "os.makedirs(f'{PATH}models', exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6943ca600bbf4a49a0020b2467c2ddb8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.72032 1.64016] \n", "[ 1. 1.62891 1.58176] \n", "\n" ] } ], "source": [ "fit(m, md, 2, lo.opt, F.nll_loss)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "765d0d78da6647d48276a638f70aeec9", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.47969 1.4472 ] \n", "[ 1. 1.51411 1.46612] \n", "[ 2. 1.412 1.39909] \n", "[ 3. 1.53689 1.48337] \n", "[ 4. 1.47375 1.43169] \n", "[ 5. 1.39828 1.37963] \n", "[ 6. 1.34546 1.35795] \n", "[ 7. 1.51999 1.47165] \n", "[ 8. 1.48992 1.46146] \n", "[ 9. 1.45492 1.42829] \n", "[ 10. 1.42027 1.39028] \n", "[ 11. 1.3814 1.36539] \n", "[ 12. 1.33895 1.34178] \n", "[ 13. 1.30737 1.32871] \n", "[ 14. 
1.28244 1.31518] \n", "\n" ] } ], "source": [ "on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')\n", "cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]\n", "fit(m, md, 2**4-1, lo.opt, F.nll_loss, callbacks=cb)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "scrolled": true }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4394818ec37f4b419397628b7cc8b815", "version_major": 2, "version_minor": 0 }, "text/plain": [ "A Jupyter Widget" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "[ 0. 1.46053 1.43462] \n", "[ 1. 1.51537 1.47747] \n", "[ 2. 1.39208 1.38293] \n", "[ 3. 1.53056 1.49371] \n", "[ 4. 1.46812 1.43389] \n", "[ 5. 1.37624 1.37523] \n", "[ 6. 1.3173 1.34022] \n", "[ 7. 1.51783 1.47554] \n", "[ 8. 1.4921 1.45785] \n", "[ 9. 1.44843 1.42215] \n", "[ 10. 1.40948 1.40858] \n", "[ 11. 1.37098 1.36648] \n", "[ 12. 1.32255 1.33842] \n", "[ 13. 1.28243 1.31106] \n", "[ 14. 1.25031 1.2918 ] \n", "[ 15. 1.49236 1.45316] \n", "[ 16. 1.46041 1.43622] \n", "[ 17. 1.45043 1.4498 ] \n", "[ 18. 1.43331 1.41297] \n", "[ 19. 1.43841 1.41704] \n", "[ 20. 1.41536 1.40521] \n", "[ 21. 1.39829 1.37656] \n", "[ 22. 1.37001 1.36891] \n", "[ 23. 1.35469 1.35909] \n", "[ 24. 1.32202 1.34228] \n", "[ 25. 1.29972 1.32256] \n", "[ 26. 1.28007 1.30903] \n", "[ 27. 1.24503 1.29125] \n", "[ 28. 1.22261 1.28316] \n", "[ 29. 1.20563 1.27397] \n", "[ 30. 1.18764 1.27178] \n", "[ 31. 1.18114 1.26694] \n", "[ 32. 1.44344 1.42405] \n", "[ 33. 1.43344 1.41616] \n", "[ 34. 1.4346 1.40442] \n", "[ 35. 1.42152 1.41359] \n", "[ 36. 1.42072 1.40835] \n", "[ 37. 1.41732 1.40498] \n", "[ 38. 1.41268 1.395 ] \n", "[ 39. 1.40725 1.39433] \n", "[ 40. 1.40181 1.39864] \n", "[ 41. 1.38621 1.37549] \n", "[ 42. 1.3838 1.38587] \n", "[ 43. 1.37644 1.37118] \n", "[ 44. 1.36287 1.36211] \n", "[ 45. 1.35942 1.36145] \n", "[ 46. 
1.34712 1.34924] \n", "[ 47. 1.32994 1.34884] \n", "[ 48. 1.32788 1.33387] \n", "[ 49. 1.31553 1.342 ] \n", "[ 50. 1.30088 1.32435] \n", "[ 51. 1.28446 1.31166] \n", "[ 52. 1.27058 1.30807] \n", "[ 53. 1.26271 1.29935] \n", "[ 54. 1.24351 1.28942] \n", "[ 55. 1.23119 1.2838 ] \n", "[ 56. 1.2086 1.28364] \n", "[ 57. 1.19742 1.27375] \n", "[ 58. 1.18127 1.26758] \n", "[ 59. 1.17475 1.26858] \n", "[ 60. 1.15349 1.25999] \n", "[ 61. 1.14718 1.25779] \n", "[ 62. 1.13174 1.2524 ] \n", "\n" ] } ], "source": [ "on_end = lambda sched, cycle: save_model(m, f'{PATH}models/cyc_{cycle}')\n", "cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2, on_cycle_end=on_end)]\n", "fit(m, md, 2**6-1, lo.opt, F.nll_loss, callbacks=cb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_next(inp):\n", "    \"\"\"Sample one character to follow the string `inp` from the global model `m`.\n", "\n", "    Dropout is switched off (eval mode) for the forward pass and the model's previous\n", "    train/eval mode is restored afterwards. `m` is stateful, so its hidden state is\n", "    advanced by `inp`.\n", "    \"\"\"\n", "    idxs = TEXT.numericalize(inp)\n", "    was_training = m.training\n", "    m.eval()  # e.g. the LSTM above uses dropout=0.5, which must not be active at inference\n", "    try:\n", "        p = m(VV(idxs.transpose(0,1)))\n", "    finally:\n", "        m.train(was_training)\n", "    # p holds log-probabilities (the model ends in log_softmax); sample from the last position\n", "    r = torch.multinomial(p[-1].exp(), 1)\n", "    return TEXT.vocab.itos[to_np(r)[0]]" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "'e'" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "get_next('for thos')" ] }, { "cell_type": "code", "execution_count": 47, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def get_next_n(inp, n):\n", "    \"\"\"Return the seed `inp` followed by `n` characters sampled with `get_next`.\n", "\n", "    The model is stateful: after priming it with the whole seed, only the newly sampled\n", "    character is fed at each step. Re-feeding a sliding window would make the carried\n", "    hidden state see the same characters over and over.\n", "    \"\"\"\n", "    m.init_hidden(1)  # start from a clean state, not whatever earlier calls left in m.h\n", "    res = inp\n", "    for i in range(n):\n", "        c = get_next(inp)\n", "        res += c\n", "        inp = c\n", "    return res" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "for those the skemps), or\n", "imaginates, though they deceives. 
it should so each ourselvess and new\n", "present, step absolutely for the\n", "science.\" the contradity and\n", "measuring, \n", "the whole!\n", "\n", "293. perhaps, that every life a values of blood\n", "of\n", "intercourse when it senses there is unscrupulus, his very rights, and still impulse, love?\n", "just after that thereby how made with the way anything, and set for harmless philos\n" ] } ], "source": [ "print(get_next_n('for thos', 400))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4" }, "nav_menu": {}, "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": true, "nav_menu": { "height": "216px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 4, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 1 }