{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Lesson 7: Human numbers" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai.text import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "bs = 64" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[PosixPath('/home/ubuntu/.fastai/data/human_numbers/valid.txt'),\n", " PosixPath('/home/ubuntu/.fastai/data/human_numbers/train.txt')]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.HUMAN_NUMBERS)\n", "path.ls()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def readnums(d): return [', '.join(o.strip() for o in open(path / d).readlines())]" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_txt = readnums('train.txt')\n", "train_txt[0][:80]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_txt = readnums('valid.txt')\n", "valid_txt[0][-80:]" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "train = TextList(train_txt, path=path)\n", "valid = TextList(valid_txt, path=path)\n", "\n", "src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n", "data = src.databunch(bs=bs)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train[0].text[:80]" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13017" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(data.valid_ds[0][0].data)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(70, 3)" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data.bptt, len(data.valid_dl)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.905580357142857" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "13017/70/bs" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "it = iter(data.valid_dl)\n", "x1, y1 = next(it)\n", "x2, y2 = next(it)\n", "x3, y3 = next(it)\n", "it.close()" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12928" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.numel() + x2.numel() + x3.numel()" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 95]), torch.Size([64, 95]))" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.shape, y1.shape" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 76]), torch.Size([64, 76]))" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2.shape, y2.shape" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 2, 8, 10, 10, 23, 18, 10, 22, 18, 8, 21, 18, 9, 20, 18, 9, 18, 18,\n", " 9, 15, 18, 9, 8, 18, 9, 8, 8, 9, 8, 9, 9, 8, 19, 19, 26, 10,\n", " 9, 8, 8, 22, 19, 13, 21, 19, 9, 20, 19, 9, 31, 19, 9, 16, 19, 9,\n", " 8, 19, 9, 8, 9, 9, 8, 10, 9, 8], device='cuda:0')" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1[:, 0]" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([18, 18, 26, 11, 12, 10, 12, 13, 10, 18, 14, 10, 27, 15, 10, 26, 10, 10,\n", " 25, 8, 10, 24, 18, 10, 23, 18, 18, 22, 18, 18, 21, 18, 9, 10, 14, 11,\n", " 23, 19, 19, 11, 10, 9, 12, 10, 27, 13, 10, 26, 8, 10, 25, 9, 10, 24,\n", " 19, 10, 23, 19, 34, 22, 19, 19, 21, 19], device='cuda:0')" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y1[:, 0]" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "v = data.valid_ds.vocab" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([95])" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1[0].shape" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three'" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x1[0])" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three ,'" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(y1[0])" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight'" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x2[0])" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty'" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[0])" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight'" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x1[1])" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty'" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x2[1])" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six ,'" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[1])" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine'" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "v.textify(x3[-1])" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
idxtext
0xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand
1, eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand
2thousand eighty seven , eight thousand eighty eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred
3thousand one hundred twenty three , eight thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two
4fifty two , eight thousand one hundred fifty three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "data.show_batch(ds_type=DatasetType.Valid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Single fully connected model" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "data = src.databunch(bs=bs, bptt=3, max_len=0, p_bptt=1.)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 3]), torch.Size([64, 3]))" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = data.one_batch()\n", "x.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "38" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nv = len(v.itos)\n", "nv" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "nh = 64" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "def loss4(input, target): return F.cross_entropy(input, target[:, -1])\n", "def acc4(input, target): return accuracy(input, target[:, -1])" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "class Model0(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh) # green arrow\n", " self.h_h = nn.Linear(nh, nh) # brown arrow\n", " self.h_o = nn.Linear(nh, nv) # blue arrow\n", " self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = self.bn(F.relu(self.i_h(x[:,0])))\n", " if x.shape[1] > 1:\n", " h = h + self.i_h(x[:,1])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " if x.shape[1] > 2:\n", " h = h + self.i_h(x[:,2])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:12

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4
13.6097553.6228310.046186
23.1032263.2177040.425781
32.5232982.7336830.449908
42.1928902.4474490.452665
52.0659912.3421710.454274
62.0399802.3268650.454274
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Same thing with a loop" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "class Model1(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh) # green arrow\n", " self.h_h = nn.Linear(nh, nh) # brown arrow\n", " self.h_o = nn.Linear(nh, nv) # blue arrow\n", " self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = torch.zeros(x.shape[0], nh).to(device=x.device)\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:13

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4
13.5507853.5657740.039062
22.9946963.0569800.434283
32.4447302.5761630.462546
42.1474892.3367810.463925
52.0302402.2525410.465533
62.0056492.2403130.465763
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi fully connected model" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "data = src.databunch(bs=bs, bptt=20)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 45]), torch.Size([64, 45]))" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x, y = data.one_batch()\n", "x.shape, y.shape" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [], "source": [ "class Model2(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh) # green arrow\n", " self.h_h = nn.Linear(nh, nh) # brown arrow\n", " self.h_o = nn.Linear(nh, nv) # blue arrow\n", " self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = torch.zeros(x.shape[0], nh).to(device=x.device)\n", " res = []\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " res.append(self.h_o(self.bn(h)))\n", " return torch.stack(res, dim=1)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model2(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:08

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.6137503.4680910.094572
23.5077503.3668670.201829
33.3786573.2632560.307511
43.2486443.1722830.352421
53.1291753.0948440.377534
63.0316863.0339680.382643
72.9548972.9996210.390039
82.9032332.9984930.382097
92.8705472.9578480.398152
102.8565502.9631680.395220
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Maintain state" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [], "source": [ "class Model3(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh) # green arrow\n", " self.h_h = nn.Linear(nh, nh) # brown arrow\n", " self.h_o = nn.Linear(nh, nv) # blue arrow\n", " self.bn = nn.BatchNorm1d(nh)\n", " self.h = torch.zeros(bs, nh).cuda()\n", "\n", " def forward(self, x):\n", " res = []\n", " h = self.h\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " res.append(self.bn(h))\n", " self.h = h.detach()\n", " res = torch.stack(res, dim=1)\n", " res = self.h_o(res)\n", " return res" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model3(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "Total time: 00:15

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.4988323.4574150.145312
23.1618682.8749710.449906
32.5159882.0377220.467106
41.9673172.1105900.316073
51.6720882.1352500.337800
61.4930222.1551370.344380
71.3352572.1160410.394331
81.2016542.2990780.408730
91.0900302.6243110.427448
101.0084972.4624560.422197
110.9711752.3526040.437458
120.9063722.4588780.453475
130.8439472.4557680.461694
140.7940682.4691550.458606
150.7548082.4906520.453702
160.7225302.5971340.453636
170.6895902.6330520.452083
180.6704932.5252330.467502
190.6567202.6630350.460008
200.6575962.5455380.464172
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PyTorch `nn.RNN`" ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [], "source": [ "class Model4(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh)\n", " self.rnn = nn.RNN(nh, nh, batch_first=True)\n", " self.h_o = nn.Linear(nh, nv)\n", " self.bn = BatchNorm1dFlat(nh)\n", " self.h = torch.zeros(1, bs, nh).cuda()\n", "\n", " def forward(self, x):\n", " res, h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(self.bn(res))" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model4(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": 69, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "Total time: 00:09

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
13.4884413.3926910.156386
23.0167692.5165890.461222
32.3466801.9426150.467149
41.8799932.0020770.312014
51.6313001.9042400.432754
61.4622271.9046210.482044
71.3006461.8518800.492365
81.1537441.6531380.492104
91.0083731.5493630.494715
100.8789321.6008240.500036
110.7772831.5082910.520914
120.7124061.5329920.570639
130.6260531.4263480.569779
140.5558441.7154790.545101
150.4990431.6261620.542316
160.4581311.5367220.548794
170.4380001.5486770.545291
180.4095151.4620340.552396
190.3945071.4777350.554738
200.3905831.5181020.549247
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2-layer GRU" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [], "source": [ "class Model5(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.i_h = nn.Embedding(nv, nh)\n", " self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n", " self.h_o = nn.Linear(nh, nv)\n", " self.bn = BatchNorm1dFlat(nh)\n", " self.h = torch.zeros(2, bs, nh).cuda()\n", "\n", " def forward(self, x):\n", " res, h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(self.bn(res))" ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, Model5(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:05

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracy
12.9836262.3215480.444593
21.8804451.6211040.546462
31.0155261.0406780.796203
40.5259280.8221730.829538
50.2685911.0003920.813538
60.1407870.8204610.842801
70.0796020.8827890.833222
80.0479940.7956630.843396
90.0376210.8724510.833012
100.0304430.8754170.833148
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-2)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }