{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Human numbers\n", "\n", "Language modelling on the *human numbers* dataset (numbers spelled out as English words), building up from a hand-unrolled fully connected network to a 2-layer GRU." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=64" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/ubuntu/.fastai/data/human_numbers/train.txt'),\n", " PosixPath('/home/ubuntu/.fastai/data/human_numbers/valid.txt')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.HUMAN_NUMBERS)\n", "path.ls()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def readnums(d):\n",
  "    \"Read `path/d` and join its stripped lines with ', ' into a single document, returned as a 1-element list.\"\n",
  "    # `with` guarantees the file is closed once its lines have been consumed by the join.\n",
  "    with open(path/d) as f: return [', '.join(o.strip() for o in f)]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_txt = readnums('train.txt'); train_txt[0][:80]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_txt = readnums('valid.txt'); valid_txt[0][-80:]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train = TextList(train_txt, path=path)\n", "valid = TextList(valid_txt, path=path)\n", "\n", "src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n", "data = src.databunch(bs=bs)" ] },
{ "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train[0].text[:80]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13017" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(data.valid_ds[0][0].data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70, 3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.bptt, len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.905580357142857" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "13017/70/bs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "it = iter(data.valid_dl)\n", "x1,y1 = next(it)\n", "x2,y2 = next(it)\n", "x3,y3 = next(it)\n", "it.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13440" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.numel()+x2.numel()+x3.numel()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 70]), torch.Size([64, 70]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.shape,y1.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 70]), torch.Size([64, 70]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2.shape,y2.shape" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 2, 8, 10, 11, 12, 10, 9, 8, 9, 13, 18, 24, 18, 14, 15, 10, 18, 8,\n", " 9, 8, 18, 24, 18, 10, 18, 10, 9, 8, 18, 19, 10, 25, 19, 22, 19, 19,\n", " 23, 19, 10, 13, 10, 10, 8, 13, 8, 19, 9, 19, 34, 16, 10, 9, 8, 16,\n", " 8, 19, 9, 19, 10, 19, 10, 19, 19, 19], device='cuda:0')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1[:,0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([18, 18, 26, 9, 8, 11, 31, 18, 25, 9, 10, 14, 10, 9, 8, 14, 10, 18,\n", " 25, 18, 10, 17, 10, 17, 8, 17, 20, 18, 9, 9, 19, 8, 10, 15, 10, 10,\n", " 12, 10, 12, 8, 12, 13, 19, 9, 19, 10, 23, 10, 8, 8, 15, 16, 19, 9,\n", " 19, 10, 23, 10, 18, 8, 18, 10, 10, 9], device='cuda:0')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y1[:,0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = data.valid_ds.vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x1[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , 
eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(y1[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x2[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , 
eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x1[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x2[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety 
eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[-1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxtext
0thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty
1eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight thousand one hundred one , eight thousand one
2thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight thousand one hundred thirty three , eight thousand
3three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred sixty two , eight thousand one hundred sixty three
4thousand one hundred eighty three , eight thousand one hundred eighty four , eight thousand one hundred eighty five , eight thousand one hundred eighty six , eight thousand one hundred eighty seven , eight thousand one hundred eighty eight , eight thousand one hundred eighty nine , eight thousand one hundred ninety , eight thousand one hundred ninety one , eight thousand one hundred ninety two , eight thousand
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data.show_batch(ds_type=DatasetType.Valid)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Single fully connected model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = src.databunch(bs=bs, bptt=3)  # rows of just 3 tokens each" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 3]), torch.Size([64, 3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = data.one_batch()\n", "x.shape, y.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "38" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nv = len(v.itos)  # vocabulary size\n", "nv" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nh = 64  # hidden size" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "def loss4(input, target):\n",
  "    \"Cross-entropy of the predictions against only the last target token of each row.\"\n",
  "    return F.cross_entropy(input, target[:,-1])\n",
  "\n",
  "def acc4(input, target):\n",
  "    \"Accuracy of the predictions against only the last target token of each row.\"\n",
  "    return accuracy(input, target[:,-1])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model0(nn.Module):\n",
  "    \"Hand-unrolled network over (at most) the first 3 tokens of each row; returns one set of vocab scores per row.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
  "        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
  "        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
  "        self.bn = nn.BatchNorm1d(nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        n_steps = x.shape[1]\n",
  "        # token 1: embed it, then apply the shared hidden layer\n",
  "        h = self.bn(F.relu(self.h_h(self.i_h(x[:,0]))))\n",
  "        # tokens 2 and 3 (when present): add their embedding to h, then reuse the same hidden layer\n",
  "        if n_steps > 1:\n",
  "            h = self.bn(F.relu(self.h_h(h + self.i_h(x[:,1]))))\n",
  "        if n_steps > 2:\n",
  "            h = self.bn(F.relu(self.h_h(h + self.i_h(x[:,2]))))\n",
  "        return self.h_o(h)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4
13.5962863.5888690.046645
23.0861003.2057630.274816
32.4944112.7493650.392004
42.1447532.4635370.415671
52.0109152.3528870.409237
61.9839922.3369670.408778
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Same thing with a loop" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model1(nn.Module):\n",
  "    \"The Model0 recurrence written as a loop, so it runs for however many time steps `x` has.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
  "        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
  "        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
  "        self.bn = nn.BatchNorm1d(nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        # start from an all-zero hidden state on the same device as the input\n",
  "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
  "        for tok in x.unbind(dim=1):  # x[:,0], x[:,1], ... in order\n",
  "            h = self.bn(F.relu(self.h_h(h + self.i_h(tok))))\n",
  "        return self.h_o(h)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4
13.4935253.4202310.156250
22.9876002.9378930.376149
32.4401992.4779950.388787
42.1328372.2565690.391774
52.0113052.1813370.392923
61.9859132.1708740.393153
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Multi fully connected model" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = src.databunch(bs=bs, bptt=20)  # longer rows: 20 tokens each" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 20]), torch.Size([64, 20]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = data.one_batch()\n", "x.shape, y.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model2(nn.Module):\n",
  "    \"Loop model that makes a prediction after every time step, not just the last one.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)\n",
  "        self.h_h = nn.Linear(nh,nh)\n",
  "        self.h_o = nn.Linear(nh,nv)\n",
  "        self.bn = nn.BatchNorm1d(nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
  "        outs = []\n",
  "        for tok in x.unbind(dim=1):  # x[:,0], x[:,1], ... in order\n",
  "            h = F.relu(self.h_h(h + self.i_h(tok)))\n",
  "            # batchnorm is applied only on the way to the output; the recurrent h is not normalised\n",
  "            outs.append(self.h_o(self.bn(h)))\n",
  "        return torch.stack(outs, dim=1)  # one set of vocab scores per time step" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model2(), metrics=accuracy)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:06

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.6392853.7092780.058949
23.5511513.5656770.151776
33.4399083.4318500.207741
43.3230833.3142370.283949
53.2134223.2199060.321662
63.1196733.1511620.336790
73.0466453.1066300.341690
82.9953793.0825520.346662
92.9638003.0733270.349645
102.9473123.0719510.349787
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Maintain state" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model3(nn.Module):\n",
  "    \"Like Model2, but the hidden state is carried over (detached) from each batch to the next.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)\n",
  "        self.h_h = nn.Linear(nh,nh)\n",
  "        self.h_o = nn.Linear(nh,nv)\n",
  "        self.bn = nn.BatchNorm1d(nh)\n",
  "        # initial hidden state: created on the CPU and moved to the input's device in forward,\n",
  "        # so the model also runs without a GPU (a plain attribute is not moved by model.to())\n",
  "        self.h = torch.zeros(bs, nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        # reuse the previous batch's state, or start from zeros if the batch size changed\n",
  "        # (a 1-row batch would otherwise silently broadcast against the bs-row state)\n",
  "        if self.h.shape[0] == x.shape[0]: h = self.h.to(x.device)\n",
  "        else: h = torch.zeros(x.shape[0], nh, device=x.device)\n",
  "        res = []\n",
  "        for i in range(x.shape[1]):\n",
  "            h = h + self.i_h(x[:,i])\n",
  "            h = F.relu(self.h_h(h))\n",
  "            res.append(self.bn(h))\n",
  "        self.h = h.detach()  # keep the values, drop the graph so backprop stops at the batch boundary\n",
  "        res = torch.stack(res, dim=1)\n",
  "        res = self.h_o(res)\n",
  "        return res" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model3(), metrics=accuracy)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:11

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.5981833.5563620.050710
23.2746162.9756990.401634
32.6242062.0368940.467330
42.0227021.9564390.316193
51.6818131.9349520.336861
61.4530071.9482010.351349
71.2769712.0057760.368679
81.1384992.0812610.360156
91.0292172.1458530.360795
100.9399492.2153880.372230
110.8654412.2404380.401491
120.8053102.1958460.409375
130.7550352.3243730.422727
140.7130732.3055420.449716
150.6773932.3501550.446449
160.6458412.4187380.446591
170.6218092.4569030.446165
180.6053002.5416990.443040
190.5940992.5398240.443040
200.5875632.5514230.442827
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## nn.RNN" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model4(nn.Module):\n",
  "    \"Stateful model like Model3, but with PyTorch's nn.RNN layer doing the recurrence.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)\n",
  "        self.rnn = nn.RNN(nh,nh, batch_first=True)\n",
  "        self.h_o = nn.Linear(nh,nv)\n",
  "        self.bn = BatchNorm1dFlat(nh)\n",
  "        # initial hidden state, (num_layers, batch, nh); created on the CPU and moved to the\n",
  "        # input's device in forward, so the model also runs without a GPU\n",
  "        self.h = torch.zeros(1, bs, nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        # reuse the previous batch's state, or start from zeros if the batch size changed\n",
  "        if self.h.shape[1] == x.shape[0]: h0 = self.h.to(x.device)\n",
  "        else: h0 = torch.zeros(self.h.shape[0], x.shape[0], nh, device=x.device)\n",
  "        res,h = self.rnn(self.i_h(x), h0)\n",
  "        self.h = h.detach()  # keep the values, drop the graph so backprop stops at the batch boundary\n",
  "        return self.h_o(self.bn(res))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model4(), metrics=accuracy)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:04

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.4514323.2683440.224148
22.9749382.4565690.466051
32.3167321.9469690.465625
41.8661511.9919520.314702
51.6185161.8024030.437216
61.4115171.7311070.436293
71.1719161.6559790.504048
80.9658871.5799630.522088
90.7970461.4798190.565057
100.6593781.4878310.579048
110.5532821.4419220.597798
120.4751671.4981480.600781
130.4161311.5469840.606463
140.3723951.5942610.607386
150.3370931.5783210.613352
160.3113851.5809730.623366
170.2928691.6257450.618253
180.2794861.6239600.626065
190.2700541.6820900.611719
200.2638571.6756760.614702
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## 2-layer GRU" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "class Model5(nn.Module):\n",
  "    \"Stateful model with a 2-layer nn.GRU doing the recurrence.\"\n",
  "    def __init__(self):\n",
  "        super().__init__()\n",
  "        self.i_h = nn.Embedding(nv,nh)\n",
  "        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n",
  "        self.h_o = nn.Linear(nh,nv)\n",
  "        self.bn = BatchNorm1dFlat(nh)\n",
  "        # initial hidden state, (num_layers, batch, nh); created on the CPU and moved to the\n",
  "        # input's device in forward, so the model also runs without a GPU\n",
  "        self.h = torch.zeros(2, bs, nh)\n",
  "\n",
  "    def forward(self, x):\n",
  "        # reuse the previous batch's state, or start from zeros if the batch size changed\n",
  "        if self.h.shape[1] == x.shape[0]: h0 = self.h.to(x.device)\n",
  "        else: h0 = torch.zeros(self.h.shape[0], x.shape[0], nh, device=x.device)\n",
  "        res,h = self.rnn(self.i_h(x), h0)\n",
  "        self.h = h.detach()  # keep the values, drop the graph so backprop stops at the batch boundary\n",
  "        return self.h_o(self.bn(res))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model5(), metrics=accuracy)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:02

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
12.8648542.3149430.454545
21.7989881.3571160.629688
30.9327291.3074630.796733
40.4519691.3296990.788636
50.2257871.2935700.800142
60.1180851.2659260.803338
70.0653061.2070960.806960
80.0380981.2053610.813920
90.0240691.2394110.807813
100.0170781.2534090.807102
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-2)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## fin" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }