{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Human numbers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.text.all import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs=64" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('/home/jhoward/.fastai/data/human_numbers/train.txt'),Path('/home/jhoward/.fastai/data/human_numbers/valid.txt')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.HUMAN_NUMBERS)\n", "path.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def readnums(d): return ', '.join(o.strip() for o in open(path/d).readlines())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_txt = readnums('train.txt'); train_txt[:80]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_txt = readnums('valid.txt'); valid_txt[-80:]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_tok = tokenize1(train_txt)\n", "valid_tok = tokenize1(valid_txt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dsets = Datasets([train_tok, valid_tok], tfms=Numericalize, dl_type=LMDataLoader, splits=[[0], 
[1]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = dsets.dataloaders(bs=bs, val_bs=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dsets.show((dsets.train[0][0][:80],))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "13017" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(dsets.valid[0][0])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(dls.valid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(72, 3)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dls.seq_len, len(dls.valid)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.8248697916666665" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "13017/72/bs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "it = iter(dls.valid)\n", "x1,y1 = next(it)\n", "x2,y2 = next(it)\n", "x3,y3 = next(it)\n", "it.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "12992" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1.numel()+x2.numel()+x3.numel()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the closest multiple of 64 (our batch size) below 13017: the validation tokens are split into 64 equal-length rows, so the few leftover tokens never appear in an input batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 72]), torch.Size([64, 72]))" ] }, "execution_count": null, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "x1.shape,y1.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 72]), torch.Size([64, 72]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x2.shape,y2.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 2, 19, 11, 12, 9, 19, 11, 13, 9, 19, 11, 14, 9, 19, 11, 15, 9, 19,\n", " 11, 16, 9, 19, 11, 17, 9, 19, 11, 18, 9, 19, 11, 19, 9, 19, 11, 20,\n", " 9, 19, 11, 29, 9, 19, 11, 30, 9, 19, 11, 31, 9, 19, 11, 32, 9, 19,\n", " 11, 33, 9, 19, 11, 34, 9, 19, 11, 35, 9, 19, 11, 36, 9, 19, 11, 37],\n", " device='cuda:5')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x1[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([19, 11, 12, 9, 19, 11, 13, 9, 19, 11, 14, 9, 19, 11, 15, 9, 19, 11,\n", " 16, 9, 19, 11, 17, 9, 19, 11, 18, 9, 19, 11, 19, 9, 19, 11, 20, 9,\n", " 19, 11, 29, 9, 19, 11, 30, 9, 19, 11, 31, 9, 19, 11, 32, 9, 19, 11,\n", " 33, 9, 19, 11, 34, 9, 19, 11, 35, 9, 19, 11, 36, 9, 19, 11, 37, 9],\n", " device='cuda:5')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y1[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = dls.vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight 
thousand sixteen , eight thousand seventeen , eight thousand eighteen'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x1[0]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen ,'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in y1[0]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x2[0]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five'" ] }, 
"execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x3[0]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x1[1]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x2[1]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] 
for x in x3[1]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'seven , nine thousand nine hundred eighty eight , nine thousand nine hundred eighty nine , nine thousand nine hundred ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "' '.join([v[x] for x in x3[-1]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Single fully connected model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = dsets.dataloaders(bs=bs, seq_len=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 3]), torch.Size([64, 3]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x,y = dls.one_batch()\n", "x.shape,y.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "40" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nv = len(v); nv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "nh=64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss4(input,target): return F.cross_entropy(input, target[:,-1])\n", "def acc4 (input,target): return accuracy(input, target[:,-1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model0(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh) # green arrow\n", " self.h_h = nn.Linear(nh,nh) # brown arrow\n", " self.h_o = nn.Linear(nh,nv) # blue arrow\n", " 
self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = self.bn(F.relu(self.h_h(self.i_h(x[:,0]))))\n", " if x.shape[1]>1:\n", " h = h + self.i_h(x[:,1])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " if x.shape[1]>2:\n", " h = h + self.i_h(x[:,2])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model0(), loss_func=loss4, metrics=acc4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4time
03.4594523.4178390.14421300:10
12.5191202.5692640.45625000:10
22.0313602.1762570.45972200:10
31.8406012.0407400.46365700:10
41.7727402.0009010.46365700:10
51.7586491.9957090.46412000:10
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Same thing with a loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model1(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh) # green arrow\n", " self.h_h = nn.Linear(nh,nh) # brown arrow\n", " self.h_o = nn.Linear(nh,nv) # blue arrow\n", " self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = torch.zeros(x.shape[0], nh).to(device=x.device)\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = self.bn(F.relu(self.h_h(h)))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model1(), loss_func=loss4, metrics=acc4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossacc4time
03.4455853.3836230.19421300:10
12.5682182.7070020.42569400:10
22.0630692.3173260.46018500:10
31.8604972.1523900.46666700:10
41.7873152.1003940.46759300:10
51.7721132.0927690.46759300:10
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(6, 1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Multi fully connected model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = dsets.dataloaders(bs=bs, seq_len=20)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([64, 20]), torch.Size([64, 20]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x,y = dls.one_batch()\n", "x.shape,y.shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model2(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh)\n", " self.h_h = nn.Linear(nh,nh)\n", " self.h_o = nn.Linear(nh,nv)\n", " self.bn = nn.BatchNorm1d(nh)\n", " \n", " def forward(self, x):\n", " h = torch.zeros(x.shape[0], nh).to(device=x.device)\n", " res = []\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " res.append(self.h_o(self.bn(h)))\n", " return torch.stack(res, dim=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model2(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.7365733.7544800.06356600:02
13.5402613.5233100.12482600:02
23.3007083.3048200.24873500:02
33.0632053.1285780.29977700:02
42.8613453.0091280.33536700:02
52.7054952.9290250.35389400:02
62.5937922.8783350.36783200:02
72.5197322.8507410.37314000:02
82.4755342.8400070.37554600:02
92.4527272.8384130.37591800:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-4, pct_start=0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Maintain state" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model3(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh)\n", " self.h_h = nn.Linear(nh,nh)\n", " self.h_o = nn.Linear(nh,nv)\n", " self.bn = nn.BatchNorm1d(nh)\n", " self.h = torch.zeros(bs, nh).cuda()\n", " \n", " def forward(self, x):\n", " res = []\n", " if x.shape[0]!=self.h.shape[0]: self.h = torch.zeros(x.shape[0], nh).cuda()\n", " h = self.h\n", " for i in range(x.shape[1]):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " res.append(self.bn(h))\n", " self.h = h.detach()\n", " res = torch.stack(res, dim=1)\n", " res = self.h_o(res)\n", " return res\n", " \n", " def reset(self): self.h = torch.zeros(bs, nh).cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model3(), metrics=accuracy, loss_func=CrossEntropyLossFlat())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.4823973.4426180.13998000:02
12.8288042.4559080.41778300:02
22.1345922.1537670.31520300:02
31.7635762.0966720.31694000:02
41.5890152.0901710.31708800:02
51.4975012.0579940.33137400:02
61.4143051.8957210.44119500:02
71.3072732.0447910.43787200:02
81.1654291.9916410.46121000:02
91.0333351.7760330.54278300:02
100.9233161.8100160.56450900:02
110.8341171.7622700.56500500:02
120.7589061.7239690.59194000:02
130.6998921.8081630.57894400:02
140.6538391.8028810.59203900:02
150.6205601.7693260.61450900:02
160.5956371.7825740.61666700:02
170.5783591.7724770.62378500:02
180.5672101.7729500.62311500:02
190.5610521.7818800.62175100:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## nn.RNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model4(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh)\n", " self.rnn = nn.RNN(nh,nh, batch_first=True)\n", " self.h_o = nn.Linear(nh,nv)\n", " self.bn = BatchNorm1dFlat(nh)\n", " self.h = torch.zeros(1, bs, nh).cuda()\n", " \n", " def forward(self, x):\n", " if x.shape[0]!=self.h.shape[1]: self.h = torch.zeros(1, x.shape[0], nh).cuda()\n", " res,h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(self.bn(res))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model4(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.4623793.2722400.26574900:01
12.6699842.2546570.46287200:01
22.0269152.1198160.31592300:01
31.7095042.1648390.31696400:01
41.5380792.0371200.38879000:01
51.3763782.2410620.33945900:01
61.1829062.0941070.37142900:01
71.0198521.6148430.47614100:01
80.8716621.5492970.48688000:01
90.7438751.5252400.52286700:01
100.6363711.4349420.55860600:01
110.5495751.3986440.55364600:01
120.4805471.3577810.56441000:01
130.4272231.2909590.58360600:01
140.3881081.2097170.60694400:01
150.3568911.2568060.60972200:01
160.3321501.2690090.61004500:01
170.3151041.2448850.61795600:01
180.3042691.2619090.61550100:01
190.2977691.2797110.61153300:01
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(20, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2-layer GRU" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Model5(Module):\n", " def __init__(self):\n", " self.i_h = nn.Embedding(nv,nh)\n", " self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n", " self.h_o = nn.Linear(nh,nv)\n", " self.bn = BatchNorm1dFlat(nh)\n", " self.h = torch.zeros(2, bs, nh).cuda()\n", " \n", " def forward(self, x):\n", " if x.shape[0]!=self.h.shape[1]: self.h = torch.zeros(2, x.shape[0], nh).cuda()\n", " res,h = self.rnn(self.i_h(x), self.h)\n", " self.h = h.detach()\n", " return self.h_o(self.bn(res))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(dls, Model5(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
02.6663922.1149010.49759400:01
11.4362921.3572660.62433000:01
20.6788161.0078750.74538700:01
30.3295090.7359180.81321900:01
40.1684630.6339210.83792200:01
50.0898410.6128710.85129000:01
60.0510910.6906960.84097200:01
70.0314490.7065230.83489600:01
80.0206420.6334270.84394800:01
90.0142710.6360020.84407200:01
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(10, 1e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fin" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }