{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lesson 7: Human numbers"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"bs = 64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/ubuntu/.fastai/data/human_numbers/valid.txt'),\n",
" PosixPath('/home/ubuntu/.fastai/data/human_numbers/train.txt')]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def readnums(d): return [', '.join(o.strip() for o in open(path / d).readlines())]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_txt = readnums('train.txt')\n",
"train_txt[0][:80]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid_txt = readnums('valid.txt')\n",
"valid_txt[0][-80:]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"train = TextList(train_txt, path=path)\n",
"valid = TextList(valid_txt, path=path)\n",
"\n",
"src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n",
"data = src.databunch(bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train[0].text[:80]"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"13017"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(data.valid_ds[0][0].data)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70, 3)"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.bptt, len(data.valid_dl)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2.905580357142857"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"13017/70/bs"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"it = iter(data.valid_dl)\n",
"x1, y1 = next(it)\n",
"x2, y2 = next(it)\n",
"x3, y3 = next(it)\n",
"it.close()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"12928"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.numel() + x2.numel() + x3.numel()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 95]), torch.Size([64, 95]))"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1.shape, y1.shape"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 76]), torch.Size([64, 76]))"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x2.shape, y2.shape"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2, 8, 10, 10, 23, 18, 10, 22, 18, 8, 21, 18, 9, 20, 18, 9, 18, 18,\n",
" 9, 15, 18, 9, 8, 18, 9, 8, 8, 9, 8, 9, 9, 8, 19, 19, 26, 10,\n",
" 9, 8, 8, 22, 19, 13, 21, 19, 9, 20, 19, 9, 31, 19, 9, 16, 19, 9,\n",
" 8, 19, 9, 8, 9, 9, 8, 10, 9, 8], device='cuda:0')"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[:, 0]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([18, 18, 26, 11, 12, 10, 12, 13, 10, 18, 14, 10, 27, 15, 10, 26, 10, 10,\n",
" 25, 8, 10, 24, 18, 10, 23, 18, 18, 22, 18, 18, 21, 18, 9, 10, 14, 11,\n",
" 23, 19, 19, 11, 10, 9, 12, 10, 27, 13, 10, 26, 8, 10, 25, 9, 10, 24,\n",
" 19, 10, 23, 19, 34, 22, 19, 19, 21, 19], device='cuda:0')"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y1[:, 0]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"v = data.valid_ds.vocab"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([95])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x1[0].shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three'"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x1[0])"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three ,'"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(y1[0])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two , eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight'"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x2[0])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty'"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[0])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine , eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight'"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x1[1])"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty'"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x2[1])"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"', eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six ,'"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[1])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine'"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"v.textify(x3[-1])"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" idx | \n",
" text | \n",
"
\n",
" \n",
" 0 | \n",
" xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand | \n",
"
\n",
" \n",
" 1 | \n",
" , eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand | \n",
"
\n",
" \n",
" 2 | \n",
" thousand eighty seven , eight thousand eighty eight , eight thousand eighty nine , eight thousand ninety , eight thousand ninety one , eight thousand ninety two , eight thousand ninety three , eight thousand ninety four , eight thousand ninety five , eight thousand ninety six , eight thousand ninety seven , eight thousand ninety eight , eight thousand ninety nine , eight thousand one hundred | \n",
"
\n",
" \n",
" 3 | \n",
" thousand one hundred twenty three , eight thousand one hundred twenty four , eight thousand one hundred twenty five , eight thousand one hundred twenty six , eight thousand one hundred twenty seven , eight thousand one hundred twenty eight , eight thousand one hundred twenty nine , eight thousand one hundred thirty , eight thousand one hundred thirty one , eight thousand one hundred thirty two | \n",
"
\n",
" \n",
" 4 | \n",
" fifty two , eight thousand one hundred fifty three , eight thousand one hundred fifty four , eight thousand one hundred fifty five , eight thousand one hundred fifty six , eight thousand one hundred fifty seven , eight thousand one hundred fifty eight , eight thousand one hundred fifty nine , eight thousand one hundred sixty , eight thousand one hundred sixty one , eight thousand | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"data.show_batch(ds_type=DatasetType.Valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Single fully connected model"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"data = src.databunch(bs=bs, bptt=3, max_len=0, p_bptt=1.)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 3]), torch.Size([64, 3]))"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x, y = data.one_batch()\n",
"x.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"38"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nv = len(v.itos)\n",
"nv"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"nh = 64"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"def loss4(input, target): return F.cross_entropy(input, target[:, -1])\n",
"def acc4(input, target): return accuracy(input, target[:, -1])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"class Model0(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh) # green arrow\n",
" self.h_h = nn.Linear(nh, nh) # brown arrow\n",
" self.h_o = nn.Linear(nh, nv) # blue arrow\n",
" self.bn = nn.BatchNorm1d(nh)\n",
" \n",
" def forward(self, x):\n",
" h = self.bn(F.relu(self.i_h(x[:,0])))\n",
" if x.shape[1] > 1:\n",
" h = h + self.i_h(x[:,1])\n",
" h = self.bn(F.relu(self.h_h(h)))\n",
" if x.shape[1] > 2:\n",
" h = h + self.i_h(x[:,2])\n",
" h = self.bn(F.relu(self.h_h(h)))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:12 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.609755 | \n",
" 3.622831 | \n",
" 0.046186 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.103226 | \n",
" 3.217704 | \n",
" 0.425781 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.523298 | \n",
" 2.733683 | \n",
" 0.449908 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.192890 | \n",
" 2.447449 | \n",
" 0.452665 | \n",
"
\n",
" \n",
" 5 | \n",
" 2.065991 | \n",
" 2.342171 | \n",
" 0.454274 | \n",
"
\n",
" \n",
" 6 | \n",
" 2.039980 | \n",
" 2.326865 | \n",
" 0.454274 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Same thing with a loop"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"class Model1(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh) # green arrow\n",
" self.h_h = nn.Linear(nh, nh) # brown arrow\n",
" self.h_o = nn.Linear(nh, nv) # blue arrow\n",
" self.bn = nn.BatchNorm1d(nh)\n",
" \n",
" def forward(self, x):\n",
" h = torch.zeros(x.shape[0], nh).to(device=x.device)\n",
" for i in range(x.shape[1]):\n",
" h = h + self.i_h(x[:,i])\n",
" h = self.bn(F.relu(self.h_h(h)))\n",
" return self.h_o(h)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:13 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" acc4 | \n",
"
\n",
" \n",
" 1 | \n",
" 3.550785 | \n",
" 3.565774 | \n",
" 0.039062 | \n",
"
\n",
" \n",
" 2 | \n",
" 2.994696 | \n",
" 3.056980 | \n",
" 0.434283 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.444730 | \n",
" 2.576163 | \n",
" 0.462546 | \n",
"
\n",
" \n",
" 4 | \n",
" 2.147489 | \n",
" 2.336781 | \n",
" 0.463925 | \n",
"
\n",
" \n",
" 5 | \n",
" 2.030240 | \n",
" 2.252541 | \n",
" 0.465533 | \n",
"
\n",
" \n",
" 6 | \n",
" 2.005649 | \n",
" 2.240313 | \n",
" 0.465763 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(6, 1e-4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi fully connected model"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"data = src.databunch(bs=bs, bptt=20)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 45]), torch.Size([64, 45]))"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x, y = data.one_batch()\n",
"x.shape, y.shape"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"class Model2(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh) # green arrow\n",
" self.h_h = nn.Linear(nh, nh) # brown arrow\n",
" self.h_o = nn.Linear(nh, nv) # blue arrow\n",
" self.bn = nn.BatchNorm1d(nh)\n",
" \n",
" def forward(self, x):\n",
" h = torch.zeros(x.shape[0], nh).to(device=x.device)\n",
" res = []\n",
" for i in range(x.shape[1]):\n",
" h = h + self.i_h(x[:,i])\n",
" h = F.relu(self.h_h(h))\n",
" res.append(self.h_o(self.bn(h)))\n",
" return torch.stack(res, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model2(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:08 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.613750 | \n",
" 3.468091 | \n",
" 0.094572 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.507750 | \n",
" 3.366867 | \n",
" 0.201829 | \n",
"
\n",
" \n",
" 3 | \n",
" 3.378657 | \n",
" 3.263256 | \n",
" 0.307511 | \n",
"
\n",
" \n",
" 4 | \n",
" 3.248644 | \n",
" 3.172283 | \n",
" 0.352421 | \n",
"
\n",
" \n",
" 5 | \n",
" 3.129175 | \n",
" 3.094844 | \n",
" 0.377534 | \n",
"
\n",
" \n",
" 6 | \n",
" 3.031686 | \n",
" 3.033968 | \n",
" 0.382643 | \n",
"
\n",
" \n",
" 7 | \n",
" 2.954897 | \n",
" 2.999621 | \n",
" 0.390039 | \n",
"
\n",
" \n",
" 8 | \n",
" 2.903233 | \n",
" 2.998493 | \n",
" 0.382097 | \n",
"
\n",
" \n",
" 9 | \n",
" 2.870547 | \n",
" 2.957848 | \n",
" 0.398152 | \n",
"
\n",
" \n",
" 10 | \n",
" 2.856550 | \n",
" 2.963168 | \n",
" 0.395220 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-4, pct_start=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Maintain state"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"class Model3(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh) # green arrow\n",
" self.h_h = nn.Linear(nh, nh) # brown arrow\n",
" self.h_o = nn.Linear(nh, nv) # blue arrow\n",
" self.bn = nn.BatchNorm1d(nh)\n",
" self.h = torch.zeros(bs, nh).cuda()\n",
"\n",
" def forward(self, x):\n",
" res = []\n",
" h = self.h\n",
" for i in range(x.shape[1]):\n",
" h = h + self.i_h(x[:,i])\n",
" h = F.relu(self.h_h(h))\n",
" res.append(self.bn(h))\n",
" self.h = h.detach()\n",
" res = torch.stack(res, dim=1)\n",
" res = self.h_o(res)\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model3(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:15 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.498832 | \n",
" 3.457415 | \n",
" 0.145312 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.161868 | \n",
" 2.874971 | \n",
" 0.449906 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.515988 | \n",
" 2.037722 | \n",
" 0.467106 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.967317 | \n",
" 2.110590 | \n",
" 0.316073 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.672088 | \n",
" 2.135250 | \n",
" 0.337800 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.493022 | \n",
" 2.155137 | \n",
" 0.344380 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.335257 | \n",
" 2.116041 | \n",
" 0.394331 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.201654 | \n",
" 2.299078 | \n",
" 0.408730 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.090030 | \n",
" 2.624311 | \n",
" 0.427448 | \n",
"
\n",
" \n",
" 10 | \n",
" 1.008497 | \n",
" 2.462456 | \n",
" 0.422197 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.971175 | \n",
" 2.352604 | \n",
" 0.437458 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.906372 | \n",
" 2.458878 | \n",
" 0.453475 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.843947 | \n",
" 2.455768 | \n",
" 0.461694 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.794068 | \n",
" 2.469155 | \n",
" 0.458606 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.754808 | \n",
" 2.490652 | \n",
" 0.453702 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.722530 | \n",
" 2.597134 | \n",
" 0.453636 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.689590 | \n",
" 2.633052 | \n",
" 0.452083 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.670493 | \n",
" 2.525233 | \n",
" 0.467502 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.656720 | \n",
" 2.663035 | \n",
" 0.460008 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.657596 | \n",
" 2.545538 | \n",
" 0.464172 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## PyTorch `nn.RNN`"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"class Model4(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh)\n",
" self.rnn = nn.RNN(nh, nh, batch_first=True)\n",
" self.h_o = nn.Linear(nh, nv)\n",
" self.bn = BatchNorm1dFlat(nh)\n",
" self.h = torch.zeros(1, bs, nh).cuda()\n",
"\n",
" def forward(self, x):\n",
" res, h = self.rnn(self.i_h(x), self.h)\n",
" self.h = h.detach()\n",
" return self.h_o(self.bn(res))"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model4(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:09 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 3.488441 | \n",
" 3.392691 | \n",
" 0.156386 | \n",
"
\n",
" \n",
" 2 | \n",
" 3.016769 | \n",
" 2.516589 | \n",
" 0.461222 | \n",
"
\n",
" \n",
" 3 | \n",
" 2.346680 | \n",
" 1.942615 | \n",
" 0.467149 | \n",
"
\n",
" \n",
" 4 | \n",
" 1.879993 | \n",
" 2.002077 | \n",
" 0.312014 | \n",
"
\n",
" \n",
" 5 | \n",
" 1.631300 | \n",
" 1.904240 | \n",
" 0.432754 | \n",
"
\n",
" \n",
" 6 | \n",
" 1.462227 | \n",
" 1.904621 | \n",
" 0.482044 | \n",
"
\n",
" \n",
" 7 | \n",
" 1.300646 | \n",
" 1.851880 | \n",
" 0.492365 | \n",
"
\n",
" \n",
" 8 | \n",
" 1.153744 | \n",
" 1.653138 | \n",
" 0.492104 | \n",
"
\n",
" \n",
" 9 | \n",
" 1.008373 | \n",
" 1.549363 | \n",
" 0.494715 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.878932 | \n",
" 1.600824 | \n",
" 0.500036 | \n",
"
\n",
" \n",
" 11 | \n",
" 0.777283 | \n",
" 1.508291 | \n",
" 0.520914 | \n",
"
\n",
" \n",
" 12 | \n",
" 0.712406 | \n",
" 1.532992 | \n",
" 0.570639 | \n",
"
\n",
" \n",
" 13 | \n",
" 0.626053 | \n",
" 1.426348 | \n",
" 0.569779 | \n",
"
\n",
" \n",
" 14 | \n",
" 0.555844 | \n",
" 1.715479 | \n",
" 0.545101 | \n",
"
\n",
" \n",
" 15 | \n",
" 0.499043 | \n",
" 1.626162 | \n",
" 0.542316 | \n",
"
\n",
" \n",
" 16 | \n",
" 0.458131 | \n",
" 1.536722 | \n",
" 0.548794 | \n",
"
\n",
" \n",
" 17 | \n",
" 0.438000 | \n",
" 1.548677 | \n",
" 0.545291 | \n",
"
\n",
" \n",
" 18 | \n",
" 0.409515 | \n",
" 1.462034 | \n",
" 0.552396 | \n",
"
\n",
" \n",
" 19 | \n",
" 0.394507 | \n",
" 1.477735 | \n",
" 0.554738 | \n",
"
\n",
" \n",
" 20 | \n",
" 0.390583 | \n",
" 1.518102 | \n",
" 0.549247 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(20, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2-layer GRU"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"class Model5(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.i_h = nn.Embedding(nv, nh)\n",
" self.rnn = nn.GRU(nh, nh, 2, batch_first=True)\n",
" self.h_o = nn.Linear(nh, nv)\n",
" self.bn = BatchNorm1dFlat(nh)\n",
" self.h = torch.zeros(2, bs, nh).cuda()\n",
"\n",
" def forward(self, x):\n",
" res, h = self.rnn(self.i_h(x), self.h)\n",
" self.h = h.detach()\n",
" return self.h_o(self.bn(res))"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model5(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:05 \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" 1 | \n",
" 2.983626 | \n",
" 2.321548 | \n",
" 0.444593 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.880445 | \n",
" 1.621104 | \n",
" 0.546462 | \n",
"
\n",
" \n",
" 3 | \n",
" 1.015526 | \n",
" 1.040678 | \n",
" 0.796203 | \n",
"
\n",
" \n",
" 4 | \n",
" 0.525928 | \n",
" 0.822173 | \n",
" 0.829538 | \n",
"
\n",
" \n",
" 5 | \n",
" 0.268591 | \n",
" 1.000392 | \n",
" 0.813538 | \n",
"
\n",
" \n",
" 6 | \n",
" 0.140787 | \n",
" 0.820461 | \n",
" 0.842801 | \n",
"
\n",
" \n",
" 7 | \n",
" 0.079602 | \n",
" 0.882789 | \n",
" 0.833222 | \n",
"
\n",
" \n",
" 8 | \n",
" 0.047994 | \n",
" 0.795663 | \n",
" 0.843396 | \n",
"
\n",
" \n",
" 9 | \n",
" 0.037621 | \n",
" 0.872451 | \n",
" 0.833012 | \n",
"
\n",
" \n",
" 10 | \n",
" 0.030443 | \n",
" 0.875417 | \n",
" 0.833148 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}