{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Human numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.text.all import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs=64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('/home/jhoward/.fastai/data/human_numbers/train.txt'),Path('/home/jhoward/.fastai/data/human_numbers/valid.txt')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def readnums(d):\n",
"    # Read `path/d` and join its stripped lines into a single ', '-separated string.\n",
"    # The `with` block guarantees the file handle is closed after reading.\n",
"    with open(path/d) as f: return ', '.join(o.strip() for o in f)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_txt = readnums('train.txt'); train_txt[:80]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid_txt = readnums('valid.txt'); valid_txt[-80:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_tok = tokenize1(train_txt)\n",
"valid_tok = tokenize1(valid_txt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsets = Datasets([train_tok, valid_tok], tfms=Numericalize, dl_type=LMDataLoader, splits=[[0], [1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = dsets.dataloaders(bs=bs, val_bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dsets.show((dsets.train[0][0][:80],))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13017"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dsets.valid[0][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(dls.valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(72, 3)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dls.seq_len, len(dls.valid)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.8248697916666665"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Approximate number of validation batches: tokens / (tokens per row) / (rows per batch)\n",
"len(dsets.valid[0][0]) / dls.seq_len / bs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"it = iter(dls.valid)\n",
"x1,y1 = next(it)\n",
"x2,y2 = next(it)\n",
"x3,y3 = next(it)\n",
"it.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12992"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.numel()+x2.numel()+x3.numel()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
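"This is the closest multiple of 64 (`bs`) below 13017: the few leftover tokens that don't divide evenly across the 64 rows of each batch are dropped."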
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 72]), torch.Size([64, 72]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape,y1.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 72]), torch.Size([64, 72]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x2.shape,y2.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2, 19, 11, 12, 9, 19, 11, 13, 9, 19, 11, 14, 9, 19, 11, 15, 9, 19,\n",
" 11, 16, 9, 19, 11, 17, 9, 19, 11, 18, 9, 19, 11, 19, 9, 19, 11, 20,\n",
" 9, 19, 11, 29, 9, 19, 11, 30, 9, 19, 11, 31, 9, 19, 11, 32, 9, 19,\n",
" 11, 33, 9, 19, 11, 34, 9, 19, 11, 35, 9, 19, 11, 36, 9, 19, 11, 37],\n",
" device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([19, 11, 12, 9, 19, 11, 13, 9, 19, 11, 14, 9, 19, 11, 15, 9, 19, 11,\n",
" 16, 9, 19, 11, 17, 9, 19, 11, 18, 9, 19, 11, 19, 9, 19, 11, 20, 9,\n",
" 19, 11, 29, 9, 19, 11, 30, 9, 19, 11, 31, 9, 19, 11, 32, 9, 19, 11,\n",
" 33, 9, 19, 11, 34, 9, 19, 11, 35, 9, 19, 11, 36, 9, 19, 11, 37, 9],\n",
" device='cuda:5')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y1[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = dls.vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x1[0]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen ,'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in y1[0]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x2[0]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x3[0]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x1[1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x2[1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x3[1]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'seven , nine thousand nine hundred eighty eight , nine thousand nine hundred eighty nine , nine thousand nine hundred ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"' '.join([v[x] for x in x3[-1]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single fully connected model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = dsets.dataloaders(bs=bs, seq_len=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3]), torch.Size([64, 3]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.one_batch()\n",
"x.shape,y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"40"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nv = len(v); nv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nh=64"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Model0/Model1 emit a single prediction per sequence, so both the loss and the\n",
"# metric compare it against only the last target token of each row.\n",
"def loss4(input, target):\n",
"    last_tok = target[:, -1]\n",
"    return F.cross_entropy(input, last_tok)\n",
"\n",
"def acc4(input, target):\n",
"    last_tok = target[:, -1]\n",
"    return accuracy(input, last_tok)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model0(Module):\n",
"    # Hand-unrolled RNN over at most 3 tokens: every step reuses the same embedding\n",
"    # (input->hidden), linear (hidden->hidden) and batchnorm layers.\n",
"    def __init__(self):\n",
"        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
"        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
"        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        # x holds token ids indexed as x[:, position]; positions beyond 2 are ignored.\n",
"        def step(state): return self.bn(F.relu(self.h_h(state)))\n",
"        seq_len = x.shape[1]\n",
"        h = step(self.i_h(x[:,0]))\n",
"        if seq_len > 1: h = step(h + self.i_h(x[:,1]))\n",
"        if seq_len > 2: h = step(h + self.i_h(x[:,2]))\n",
"        # One set of vocab logits per sequence.\n",
"        return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, Model0(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.459452 | \n",
" 3.417839 | \n",
" 0.144213 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.519120 | \n",
" 2.569264 | \n",
" 0.456250 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.031360 | \n",
" 2.176257 | \n",
" 0.459722 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.840601 | \n",
" 2.040740 | \n",
" 0.463657 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.772740 | \n",
" 2.000901 | \n",
" 0.463657 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 5 | \n",
" 1.758649 | \n",
" 1.995709 | \n",
" 0.464120 | \n",
" 00:10 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Same thing with a loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model1(Module):\n",
"    # Same computation as Model0, written as a loop so it consumes every position of x.\n",
"    def __init__(self):\n",
"        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
"        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
"        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        # Allocate the zero initial state directly on the input's device rather than\n",
"        # building it on the CPU and copying it across.\n",
"        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = self.bn(F.relu(self.h_h(h)))\n",
"        return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, Model1(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.445585 | \n",
" 3.383623 | \n",
" 0.194213 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.568218 | \n",
" 2.707002 | \n",
" 0.425694 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.063069 | \n",
" 2.317326 | \n",
" 0.460185 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.860497 | \n",
" 2.152390 | \n",
" 0.466667 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.787315 | \n",
" 2.100394 | \n",
" 0.467593 | \n",
" 00:10 | \n",
"
\n",
" \n",
" | 5 | \n",
" 1.772113 | \n",
" 2.092769 | \n",
" 0.467593 | \n",
" 00:10 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi fully connected model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dls = dsets.dataloaders(bs=bs, seq_len=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 20]), torch.Size([64, 20]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = dls.one_batch()\n",
"x.shape,y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model2(Module):\n",
"    # Predicts a token after every input position: output is (batch, seq_len, nv) logits.\n",
"    def __init__(self):\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.h_h = nn.Linear(nh,nh)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        # Fresh zero state for every batch, allocated directly on the input's device.\n",
"        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        res = []\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = F.relu(self.h_h(h))\n",
"            # BatchNorm is applied only on the path to the output, not to the carried state.\n",
"            res.append(self.h_o(self.bn(h)))\n",
"        return torch.stack(res, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, Model2(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.736573 | \n",
" 3.754480 | \n",
" 0.063566 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 3.540261 | \n",
" 3.523310 | \n",
" 0.124826 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 3.300708 | \n",
" 3.304820 | \n",
" 0.248735 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 3.063205 | \n",
" 3.128578 | \n",
" 0.299777 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 4 | \n",
" 2.861345 | \n",
" 3.009128 | \n",
" 0.335367 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 5 | \n",
" 2.705495 | \n",
" 2.929025 | \n",
" 0.353894 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 6 | \n",
" 2.593792 | \n",
" 2.878335 | \n",
" 0.367832 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 7 | \n",
" 2.519732 | \n",
" 2.850741 | \n",
" 0.373140 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 8 | \n",
" 2.475534 | \n",
" 2.840007 | \n",
" 0.375546 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 9 | \n",
" 2.452727 | \n",
" 2.838413 | \n",
" 0.375918 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maintain state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model3(Module):\n",
"    # Like Model2, but the final hidden state of each batch seeds the next batch.\n",
"    # It is detached, so backprop stops at the batch boundary; reset() clears it.\n",
"    def __init__(self):\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.h_h = nn.Linear(nh,nh)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"        # Device-agnostic: forward() moves the state to whatever device the input is on,\n",
"        # instead of hard-coding .cuda() (which fails without a GPU).\n",
"        self.h = torch.zeros(bs, nh)\n",
"\n",
"    def forward(self, x):\n",
"        res = []\n",
"        # A batch of a different size (e.g. the last one) cannot reuse the saved state.\n",
"        if x.shape[0]!=self.h.shape[0]: self.h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        h = self.h.to(x.device)\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = F.relu(self.h_h(h))\n",
"            res.append(self.bn(h))\n",
"        self.h = h.detach()\n",
"        # (batch, seq_len, nh) -> (batch, seq_len, nv)\n",
"        return self.h_o(torch.stack(res, dim=1))\n",
"\n",
"    def reset(self): self.h = torch.zeros(bs, nh, device=self.h.device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ModelResetter calls model.reset() before each training and validation phase, so the\n",
"# hidden state Model3 carries between batches does not leak from training into validation.\n",
"learn = Learner(dls, Model3(), metrics=accuracy, loss_func=CrossEntropyLossFlat(), cbs=ModelResetter())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.482397 | \n",
" 3.442618 | \n",
" 0.139980 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.828804 | \n",
" 2.455908 | \n",
" 0.417783 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.134592 | \n",
" 2.153767 | \n",
" 0.315203 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.763576 | \n",
" 2.096672 | \n",
" 0.316940 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.589015 | \n",
" 2.090171 | \n",
" 0.317088 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 5 | \n",
" 1.497501 | \n",
" 2.057994 | \n",
" 0.331374 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 6 | \n",
" 1.414305 | \n",
" 1.895721 | \n",
" 0.441195 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 7 | \n",
" 1.307273 | \n",
" 2.044791 | \n",
" 0.437872 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 8 | \n",
" 1.165429 | \n",
" 1.991641 | \n",
" 0.461210 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 9 | \n",
" 1.033335 | \n",
" 1.776033 | \n",
" 0.542783 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 10 | \n",
" 0.923316 | \n",
" 1.810016 | \n",
" 0.564509 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 11 | \n",
" 0.834117 | \n",
" 1.762270 | \n",
" 0.565005 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 12 | \n",
" 0.758906 | \n",
" 1.723969 | \n",
" 0.591940 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 13 | \n",
" 0.699892 | \n",
" 1.808163 | \n",
" 0.578944 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 14 | \n",
" 0.653839 | \n",
" 1.802881 | \n",
" 0.592039 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 15 | \n",
" 0.620560 | \n",
" 1.769326 | \n",
" 0.614509 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 16 | \n",
" 0.595637 | \n",
" 1.782574 | \n",
" 0.616667 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 17 | \n",
" 0.578359 | \n",
" 1.772477 | \n",
" 0.623785 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 18 | \n",
" 0.567210 | \n",
" 1.772950 | \n",
" 0.623115 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 19 | \n",
" 0.561052 | \n",
" 1.781880 | \n",
" 0.621751 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## nn.RNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model4(Module):\n",
"    # Model3 rebuilt on PyTorch's nn.RNN; the hidden state is still carried across batches.\n",
"    def __init__(self):\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.rnn = nn.RNN(nh,nh, batch_first=True)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = BatchNorm1dFlat(nh)\n",
"        # nn.RNN takes its state as (num_layers, batch, nh) even with batch_first=True.\n",
"        # Kept device-agnostic; forward() moves it to the input's device.\n",
"        self.h = torch.zeros(1, bs, nh)\n",
"\n",
"    def forward(self, x):\n",
"        if x.shape[0]!=self.h.shape[1]: self.h = torch.zeros(1, x.shape[0], nh, device=x.device)\n",
"        res,h = self.rnn(self.i_h(x), self.h.to(x.device))\n",
"        self.h = h.detach()\n",
"        return self.h_o(self.bn(res))\n",
"\n",
"    def reset(self): self.h = torch.zeros(1, bs, nh, device=self.h.device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, Model4(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.462379 | \n",
" 3.272240 | \n",
" 0.265749 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.669984 | \n",
" 2.254657 | \n",
" 0.462872 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 2 | \n",
" 2.026915 | \n",
" 2.119816 | \n",
" 0.315923 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.709504 | \n",
" 2.164839 | \n",
" 0.316964 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.538079 | \n",
" 2.037120 | \n",
" 0.388790 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 5 | \n",
" 1.376378 | \n",
" 2.241062 | \n",
" 0.339459 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 6 | \n",
" 1.182906 | \n",
" 2.094107 | \n",
" 0.371429 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 7 | \n",
" 1.019852 | \n",
" 1.614843 | \n",
" 0.476141 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 8 | \n",
" 0.871662 | \n",
" 1.549297 | \n",
" 0.486880 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 9 | \n",
" 0.743875 | \n",
" 1.525240 | \n",
" 0.522867 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 10 | \n",
" 0.636371 | \n",
" 1.434942 | \n",
" 0.558606 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 11 | \n",
" 0.549575 | \n",
" 1.398644 | \n",
" 0.553646 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 12 | \n",
" 0.480547 | \n",
" 1.357781 | \n",
" 0.564410 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 13 | \n",
" 0.427223 | \n",
" 1.290959 | \n",
" 0.583606 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 14 | \n",
" 0.388108 | \n",
" 1.209717 | \n",
" 0.606944 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 15 | \n",
" 0.356891 | \n",
" 1.256806 | \n",
" 0.609722 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 16 | \n",
" 0.332150 | \n",
" 1.269009 | \n",
" 0.610045 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 17 | \n",
" 0.315104 | \n",
" 1.244885 | \n",
" 0.617956 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 18 | \n",
" 0.304269 | \n",
" 1.261909 | \n",
" 0.615501 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 19 | \n",
" 0.297769 | \n",
" 1.279711 | \n",
" 0.611533 | \n",
" 00:01 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2-layer GRU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model5(Module):\n",
"    # Stateful language model on a stacked GRU (n_layers defaults to 2, as in this section).\n",
"    def __init__(self, n_layers=2):\n",
"        self.n_layers = n_layers\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.rnn = nn.GRU(nh, nh, n_layers, batch_first=True)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = BatchNorm1dFlat(nh)\n",
"        # GRU state is (n_layers, batch, nh) regardless of batch_first; kept\n",
"        # device-agnostic and moved to the input's device in forward().\n",
"        self.h = torch.zeros(n_layers, bs, nh)\n",
"\n",
"    def forward(self, x):\n",
"        if x.shape[0]!=self.h.shape[1]: self.h = torch.zeros(self.n_layers, x.shape[0], nh, device=x.device)\n",
"        res,h = self.rnn(self.i_h(x), self.h.to(x.device))\n",
"        self.h = h.detach()\n",
"        return self.h_o(self.bn(res))\n",
"\n",
"    def reset(self): self.h = torch.zeros(self.n_layers, bs, nh, device=self.h.device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(dls, Model5(), loss_func=CrossEntropyLossFlat(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 2.666392 | \n",
" 2.114901 | \n",
" 0.497594 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.436292 | \n",
" 1.357266 | \n",
" 0.624330 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 2 | \n",
" 0.678816 | \n",
" 1.007875 | \n",
" 0.745387 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 3 | \n",
" 0.329509 | \n",
" 0.735918 | \n",
" 0.813219 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.168463 | \n",
" 0.633921 | \n",
" 0.837922 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 5 | \n",
" 0.089841 | \n",
" 0.612871 | \n",
" 0.851290 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 6 | \n",
" 0.051091 | \n",
" 0.690696 | \n",
" 0.840972 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 7 | \n",
" 0.031449 | \n",
" 0.706523 | \n",
" 0.834896 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 8 | \n",
" 0.020642 | \n",
" 0.633427 | \n",
" 0.843948 | \n",
" 00:01 | \n",
"
\n",
" \n",
" | 9 | \n",
" 0.014271 | \n",
" 0.636002 | \n",
" 0.844072 | \n",
" 00:01 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}