{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Human numbers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs=64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/ubuntu/.fastai/data/human_numbers/train.txt'),\n",
" PosixPath('/home/ubuntu/.fastai/data/human_numbers/valid.txt')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def readnums(d):\n",
"    \"Read `path/d`, strip every line and join them with ', ' into one document; returns a 1-element list.\"\n",
"    # the context manager guarantees the file handle is closed after reading\n",
"    with open(path/d) as f: return [', '.join(o.strip() for o in f)]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_txt = readnums('train.txt'); train_txt[0][:80]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid_txt = readnums('valid.txt'); valid_txt[0][-80:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train = TextList(train_txt, path=path)\n",
"valid = TextList(valid_txt, path=path)\n",
"\n",
"src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n",
"data = src.databunch(bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train[0].text[:80]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13017"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data.valid_ds[0][0].data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70, 3)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.bptt, len(data.valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.905580357142857"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# validation tokens / bptt / bs, computed from the live objects instead of numbers copied from earlier outputs\n",
"# (compare with len(data.valid_dl) above)\n",
"len(data.valid_ds[0][0].data)/data.bptt/bs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"it = iter(data.valid_dl)\n",
"x1,y1 = next(it)\n",
"x2,y2 = next(it)\n",
"x3,y3 = next(it)\n",
"it.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13440"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.numel()+x2.numel()+x3.numel()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 70]), torch.Size([64, 70]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape,y1.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 70]), torch.Size([64, 70]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x2.shape,y2.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2, 8, 10, 11, 12, 10, 9, 8, 9, 13, 18, 24, 18, 14, 15, 10, 18, 8,\n",
" 9, 8, 18, 24, 18, 10, 18, 10, 9, 8, 18, 19, 10, 25, 19, 22, 19, 19,\n",
" 23, 19, 10, 13, 10, 10, 8, 13, 8, 19, 9, 19, 34, 16, 10, 9, 8, 16,\n",
" 8, 19, 9, 19, 10, 19, 10, 19, 19, 19], device='cuda:0')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[:,0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([18, 18, 26, 9, 8, 11, 31, 18, 25, 9, 10, 14, 10, 9, 8, 14, 10, 18,\n",
" 25, 18, 10, 17, 10, 17, 8, 17, 20, 18, 9, 9, 19, 8, 10, 15, 10, 10,\n",
" 12, 10, 12, 8, 12, 13, 19, 9, 19, 10, 23, 10, 8, 8, 15, 16, 19, 9,\n",
" 19, 10, 23, 10, 18, 8, 18, 10, 10, 9], device='cuda:0')"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y1[:,0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = data.valid_ds.vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x1[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(y1[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x2[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x1[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x2[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[-1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" idx | \n",
" text | \n",
"
\n",
" \n",
" 0 | \n",
" thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty | \n",
"
\n",
" \n",
" 1 | \n",
" eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred , eight thousand one hundred one , eight thousand one | \n",
"
\n",
" \n",
" 2 | \n",
" thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two , eight thousand one hundred thirty three , eight thousand | \n",
"
\n",
" \n",
" 3 | \n",
" three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand one hundred sixty two , eight thousand one hundred sixty three | \n",
"
\n",
" \n",
" 4 | \n",
" thousand one hundred eighty three , eight thousand one hundred eighty four , eight thousand one hundred eighty five , eight thousand one hundred eighty six , eight thousand one hundred eighty seven , eight thousand one hundred eighty eight , eight thousand one hundred eighty nine , eight thousand one hundred ninety , eight thousand one hundred ninety one , eight thousand one hundred ninety two , eight thousand | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch(ds_type=DatasetType.Valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single fully connected model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = src.databunch(bs=bs, bptt=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3]), torch.Size([64, 3]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = data.one_batch()\n",
"x.shape,y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"38"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nv = len(v.itos); nv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"nh=64"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss4(input, target):\n",
"    \"Cross-entropy between the model output and the last column of `target` (the final time step).\"\n",
"    last = target[:, -1]\n",
"    return F.cross_entropy(input, last)\n",
"\n",
"def acc4(input, target):\n",
"    \"Accuracy between the model output and the last column of `target` (the final time step).\"\n",
"    last = target[:, -1]\n",
"    return accuracy(input, last)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model0(nn.Module):\n",
"    \"Hand-unrolled RNN over (at most) the first 3 tokens, sharing one hidden->hidden layer across steps.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)  # green arrow: token -> hidden\n",
"        self.h_h = nn.Linear(nh,nh)     # brown arrow: hidden -> hidden (reused at every step)\n",
"        self.h_o = nn.Linear(nh,nv)     # blue arrow: hidden -> one score per vocab token\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        # x holds token ids, one row per sequence; only columns 0..2 are read\n",
"        h = self.bn(F.relu(self.h_h(self.i_h(x[:,0]))))\n",
"        if x.shape[1]>1:\n",
"            h = h + self.i_h(x[:,1])\n",
"            h = self.bn(F.relu(self.h_h(h)))\n",
"        if x.shape[1]>2:\n",
"            h = h + self.i_h(x[:,2])\n",
"            h = self.bn(F.relu(self.h_h(h)))\n",
"        return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:07 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.596286 | \n",
" 3.588869 | \n",
" 0.046645 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.086100 | \n",
" 3.205763 | \n",
" 0.274816 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.494411 | \n",
" 2.749365 | \n",
" 0.392004 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.144753 | \n",
" 2.463537 | \n",
" 0.415671 | \n",
"
\n",
" \n",
" 5 | \n",
" 2.010915 | \n",
" 2.352887 | \n",
" 0.409237 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.983992 | \n",
" 2.336967 | \n",
" 0.408778 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Same thing with a loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model1(nn.Module):\n",
"    \"Model0's recurrence written as a loop, so it consumes every time step of `x`.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)  # green arrow\n",
"        self.h_h = nn.Linear(nh,nh)     # brown arrow\n",
"        self.h_o = nn.Linear(nh,nv)     # blue arrow\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        # zero initial state allocated directly on the input's device (no CPU->GPU copy)\n",
"        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = self.bn(F.relu(self.h_h(h)))\n",
"        # only the final hidden state is projected: one prediction per sequence\n",
"        return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:07 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.493525 | \n",
" 3.420231 | \n",
" 0.156250 | \n",
"
\n",
" \n",
" 2 | \n",
" 2.987600 | \n",
" 2.937893 | \n",
" 0.376149 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.440199 | \n",
" 2.477995 | \n",
" 0.388787 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.132837 | \n",
" 2.256569 | \n",
" 0.391774 | \n",
"
\n",
" \n",
" 5 | \n",
" 2.011305 | \n",
" 2.181337 | \n",
" 0.392923 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.985913 | \n",
" 2.170874 | \n",
" 0.393153 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi fully connected model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = src.databunch(bs=bs, bptt=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 20]), torch.Size([64, 20]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = data.one_batch()\n",
"x.shape,y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model2(nn.Module):\n",
"    \"Loop RNN that emits a prediction after every time step; output is stacked along dim 1.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.h_h = nn.Linear(nh,nh)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"\n",
"    def forward(self, x):\n",
"        h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        res = []\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = F.relu(self.h_h(h))\n",
"            # batchnorm only feeds the output head; the recurrent state h is not normalised\n",
"            res.append(self.h_o(self.bn(h)))\n",
"        return torch.stack(res, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model2(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:06 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.639285 | \n",
" 3.709278 | \n",
" 0.058949 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.551151 | \n",
" 3.565677 | \n",
" 0.151776 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.439908 | \n",
" 3.431850 | \n",
" 0.207741 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.323083 | \n",
" 3.314237 | \n",
" 0.283949 | \n",
"
\n",
" \n",
" 5 | \n",
" 3.213422 | \n",
" 3.219906 | \n",
" 0.321662 | \n",
"
\n",
" \n",
" 6 | \n",
" 3.119673 | \n",
" 3.151162 | \n",
" 0.336790 | \n",
"
\n",
" \n",
" 7 | \n",
" 3.046645 | \n",
" 3.106630 | \n",
" 0.341690 | \n",
"
\n",
" \n",
" 8 | \n",
" 2.995379 | \n",
" 3.082552 | \n",
" 0.346662 | \n",
"
\n",
" \n",
" 9 | \n",
" 2.963800 | \n",
" 3.073327 | \n",
" 0.349645 | \n",
"
\n",
" \n",
" 10 | \n",
" 2.947312 | \n",
" 3.071951 | \n",
" 0.349787 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maintain state"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model3(nn.Module):\n",
"    \"Like Model2, but the final hidden state is kept (detached) and used to start the next batch.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.h_h = nn.Linear(nh,nh)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = nn.BatchNorm1d(nh)\n",
"        # created lazily in forward on the input's device, so the model does not require CUDA\n",
"        self.h = None\n",
"\n",
"    def forward(self, x):\n",
"        if self.h is None or self.h.shape[0] != x.shape[0] or self.h.device != x.device:\n",
"            # first call, or batch size/device changed: the carried state cannot be reused\n",
"            self.h = torch.zeros(x.shape[0], nh, device=x.device)\n",
"        res = []\n",
"        h = self.h\n",
"        for i in range(x.shape[1]):\n",
"            h = h + self.i_h(x[:,i])\n",
"            h = F.relu(self.h_h(h))\n",
"            res.append(self.bn(h))\n",
"        # carry the state to the next batch; detach stops backprop at the batch boundary\n",
"        self.h = h.detach()\n",
"        res = torch.stack(res, dim=1)\n",
"        res = self.h_o(res)\n",
"        return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model3(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:11 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.598183 | \n",
" 3.556362 | \n",
" 0.050710 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.274616 | \n",
" 2.975699 | \n",
" 0.401634 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.624206 | \n",
" 2.036894 | \n",
" 0.467330 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.022702 | \n",
" 1.956439 | \n",
" 0.316193 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.681813 | \n",
" 1.934952 | \n",
" 0.336861 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.453007 | \n",
" 1.948201 | \n",
" 0.351349 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.276971 | \n",
" 2.005776 | \n",
" 0.368679 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.138499 | \n",
" 2.081261 | \n",
" 0.360156 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.029217 | \n",
" 2.145853 | \n",
" 0.360795 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.939949 | \n",
" 2.215388 | \n",
" 0.372230 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.865441 | \n",
" 2.240438 | \n",
" 0.401491 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.805310 | \n",
" 2.195846 | \n",
" 0.409375 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.755035 | \n",
" 2.324373 | \n",
" 0.422727 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.713073 | \n",
" 2.305542 | \n",
" 0.449716 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.677393 | \n",
" 2.350155 | \n",
" 0.446449 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.645841 | \n",
" 2.418738 | \n",
" 0.446591 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.621809 | \n",
" 2.456903 | \n",
" 0.446165 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.605300 | \n",
" 2.541699 | \n",
" 0.443040 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.594099 | \n",
" 2.539824 | \n",
" 0.443040 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.587563 | \n",
" 2.551423 | \n",
" 0.442827 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## nn.RNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model4(nn.Module):\n",
"    \"Model3 rebuilt on nn.RNN, carrying the (detached) hidden state across batches.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.rnn = nn.RNN(nh,nh, batch_first=True)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = BatchNorm1dFlat(nh)  # fastai layer, applied to the 3-d RNN output\n",
"        # hidden state (n_layers=1, batch, nh) is created lazily on the input's device\n",
"        self.h = None\n",
"\n",
"    def forward(self, x):\n",
"        if self.h is None or self.h.shape[1] != x.shape[0] or self.h.device != x.device:\n",
"            self.h = torch.zeros(1, x.shape[0], nh, device=x.device)\n",
"        res,h = self.rnn(self.i_h(x), self.h)\n",
"        # detach so gradients do not flow into previous batches\n",
"        self.h = h.detach()\n",
"        return self.h_o(self.bn(res))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model4(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:04 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.451432 | \n",
" 3.268344 | \n",
" 0.224148 | \n",
"
\n",
" \n",
" 2 | \n",
" 2.974938 | \n",
" 2.456569 | \n",
" 0.466051 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.316732 | \n",
" 1.946969 | \n",
" 0.465625 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.866151 | \n",
" 1.991952 | \n",
" 0.314702 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.618516 | \n",
" 1.802403 | \n",
" 0.437216 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.411517 | \n",
" 1.731107 | \n",
" 0.436293 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.171916 | \n",
" 1.655979 | \n",
" 0.504048 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.965887 | \n",
" 1.579963 | \n",
" 0.522088 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.797046 | \n",
" 1.479819 | \n",
" 0.565057 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.659378 | \n",
" 1.487831 | \n",
" 0.579048 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.553282 | \n",
" 1.441922 | \n",
" 0.597798 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.475167 | \n",
" 1.498148 | \n",
" 0.600781 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.416131 | \n",
" 1.546984 | \n",
" 0.606463 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.372395 | \n",
" 1.594261 | \n",
" 0.607386 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.337093 | \n",
" 1.578321 | \n",
" 0.613352 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.311385 | \n",
" 1.580973 | \n",
" 0.623366 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.292869 | \n",
" 1.625745 | \n",
" 0.618253 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.279486 | \n",
" 1.623960 | \n",
" 0.626065 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.270054 | \n",
" 1.682090 | \n",
" 0.611719 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.263857 | \n",
" 1.675676 | \n",
" 0.614702 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2-layer GRU"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class Model5(nn.Module):\n",
"    \"2-layer GRU language model carrying the (detached) hidden state across batches.\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.n_layers = 2\n",
"        self.i_h = nn.Embedding(nv,nh)\n",
"        self.rnn = nn.GRU(nh, nh, self.n_layers, batch_first=True)\n",
"        self.h_o = nn.Linear(nh,nv)\n",
"        self.bn = BatchNorm1dFlat(nh)  # fastai layer, applied to the 3-d GRU output\n",
"        # hidden state (n_layers, batch, nh) is created lazily on the input's device\n",
"        self.h = None\n",
"\n",
"    def forward(self, x):\n",
"        if self.h is None or self.h.shape[1] != x.shape[0] or self.h.device != x.device:\n",
"            self.h = torch.zeros(self.n_layers, x.shape[0], nh, device=x.device)\n",
"        res,h = self.rnn(self.i_h(x), self.h)\n",
"        # detach so gradients do not flow into previous batches\n",
"        self.h = h.detach()\n",
"        return self.h_o(self.bn(res))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model5(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:02 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 2.864854 | \n",
" 2.314943 | \n",
" 0.454545 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.798988 | \n",
" 1.357116 | \n",
" 0.629688 | \n",
"
\n",
" \n",
" 3 | \n",
" 0.932729 | \n",
" 1.307463 | \n",
" 0.796733 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.451969 | \n",
" 1.329699 | \n",
" 0.788636 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.225787 | \n",
" 1.293570 | \n",
" 0.800142 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.118085 | \n",
" 1.265926 | \n",
" 0.803338 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.065306 | \n",
" 1.207096 | \n",
" 0.806960 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.038098 | \n",
" 1.205361 | \n",
" 0.813920 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.024069 | \n",
" 1.239411 | \n",
" 0.807813 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.017078 | \n",
" 1.253409 | \n",
" 0.807102 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}