{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A language model from scratch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastai2.text.all import *\n",
"path = untar_data(URLs.HUMAN_NUMBERS)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"Path.BASE_PATH = path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#2) [Path('train.txt'),Path('valid.txt')]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path.ls()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lines = L()\n",
"with open(path/'train.txt') as f: lines += L(*f.readlines())\n",
"with open(path/'valid.txt') as f: lines += L(*f.readlines())\n",
"lines"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text = ' . '.join([l.strip() for l in lines])\n",
"text[:100]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokens = text.split(' ')\n",
"tokens[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab = L(*tokens).unique()\n",
"vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#63095) [0,1,2,1,3,1,4,1,5,1...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"word2idx = {w:i for i,w in enumerate(vocab)}\n",
"nums = L(word2idx[i] for i in tokens)\n",
"nums"
]
},
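{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (our addition, not needed for training), we can map the first few indices back through `vocab` to confirm that numericalization is reversible."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip the first ten indices back to tokens; this should read\n",
"# 'one . two . three . four . five .' if word2idx and vocab agree.\n",
"' '.join(vocab[i] for i in nums[:10])"
]
},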
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Our first language model from scratch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n",
"seqs"
]
},
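{
"cell_type": "markdown",
"metadata": {},
"source": [
"Decoding the first `(input, target)` pair back into words (a quick check we added) confirms that the dependent variable is the token that comes right after the three input tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# seqs[0][0] holds three token indices; seqs[0][1] is the index of the\n",
"# token that follows them.\n",
"[vocab[int(o)] for o in seqs[0][0]], vocab[int(seqs[0][1])]"
]
},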
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bs = 64\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)"
]
},
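{
"cell_type": "markdown",
"metadata": {},
"source": [
"Peeking at one batch (our addition) shows the shapes the model will receive: `bs` rows of three token indices as the input, and `bs` target indices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# One batch from the training DataLoader: inputs are bs x 3, targets bs.\n",
"xb,yb = dls.one_batch()\n",
"xb.shape, yb.shape"
]
},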
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our language model in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel1(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = F.relu(self.h_h(self.i_h(x[:,0])))\n",
" h = h + self.i_h(x[:,1])\n",
" h = F.relu(self.h_h(h))\n",
" h = h + self.i_h(x[:,2])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
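{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal smoke test of the untrained model (our addition): it should map a batch of three token indices to one activation per vocabulary item."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expected shape: (bs, len(vocab)). `.cpu()` keeps the batch on the same\n",
"# device as the freshly created model.\n",
"m = LMModel1(len(vocab), 64)\n",
"m(xb.cpu()).shape"
]
},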
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.824297 | \n",
" 1.970941 | \n",
" 0.467554 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.386973 | \n",
" 1.823242 | \n",
" 0.467554 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.417556 | \n",
" 1.654497 | \n",
" 0.494414 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.376440 | \n",
" 1.650849 | \n",
" 0.494414 | \n",
" 00:05 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(29), 'thousand', 0.15165200855716662)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n,counts = 0,torch.zeros(len(vocab))\n",
"for x,y in dls.valid:\n",
" n += y.shape[0]\n",
" for i in range_of(vocab): counts[i] += (y==i).long().sum()\n",
"idx = torch.argmax(counts)\n",
"idx, vocab[idx.item()], counts[idx].item()/n"
]
},
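{
"cell_type": "markdown",
"metadata": {},
"source": [
"So a model that always predicted the most common token, \"thousand\", would score only about 15% accuracy; the roughly 47% we reached above therefore reflects genuine learning rather than guessing the majority class."
]
},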
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Our first recurrent neural network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel2(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" \n",
" def forward(self, x):\n",
" h = 0\n",
" for i in range(3):\n",
" h = h + self.i_h(x[:,i])\n",
" h = F.relu(self.h_h(h))\n",
" return self.h_o(h)"
]
},
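{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LMModel2` is just `LMModel1` with the repeated lines rolled into a loop, so with identical weights the two should compute identical outputs. A quick check (our addition, reusing the batch `xb` from earlier):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The parameter names match, so we can copy LMModel1's weights straight\n",
"# into LMModel2 and compare outputs on the same batch.\n",
"m1,m2 = LMModel1(len(vocab), 64),LMModel2(len(vocab), 64)\n",
"m2.load_state_dict(m1.state_dict())\n",
"torch.allclose(m1(xb.cpu()), m2(xb.cpu()))"
]
},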
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.816274 | \n",
" 1.964143 | \n",
" 0.460185 | \n",
" 00:04 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.423805 | \n",
" 1.739964 | \n",
" 0.473259 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.430327 | \n",
" 1.685172 | \n",
" 0.485382 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.388390 | \n",
" 1.657033 | \n",
" 0.470406 | \n",
" 00:05 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n",
"learn.fit_one_cycle(4, 1e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Improving the RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Maintaining the state of an RNN"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel3(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" for i in range(3):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" out = self.h_o(self.h)\n",
" self.h = self.h.detach()\n",
" return out\n",
" \n",
" def reset(self): self.h = 0"
]
},
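{
"cell_type": "markdown",
"metadata": {},
"source": [
"The call to `detach` is what makes this truncated backpropagation through time: the hidden state keeps its values from batch to batch, but its gradient history is dropped. A tiny illustration of `detach` on its own (our addition):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A detached tensor has the same values but is cut out of the\n",
"# computation graph, so backprop stops there.\n",
"a = tensor([1.,2.]).requires_grad_()\n",
"b = (a*2).detach()\n",
"a.requires_grad, b.requires_grad"
]
},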
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def group_chunks(ds, bs):\n",
" m = len(ds) // bs\n",
" new_ds = L()\n",
" for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n",
" return new_ds"
]
},
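{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `group_chunks` does, we can apply it to a toy dataset (our example). With `m = len(ds)//bs` items per stream, item `i` of one batch is continued by item `i` of the next batch, which is what lets the hidden state carry over between batches."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 12 items, batch size 3: the streams are 0-3, 4-7 and 8-11, and reading\n",
"# the result in batches of 3 gives [0,4,8], [1,5,9], [2,6,10], [3,7,11].\n",
"group_chunks(L(range(12)), 3)"
]
},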
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.677074 | \n",
" 1.827367 | \n",
" 0.467548 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 1 | \n",
" 1.282722 | \n",
" 1.870913 | \n",
" 0.388942 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.090705 | \n",
" 1.651794 | \n",
" 0.462500 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.005215 | \n",
" 1.615990 | \n",
" 0.515144 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 4 | \n",
" 0.963020 | \n",
" 1.605894 | \n",
" 0.551202 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 5 | \n",
" 0.926150 | \n",
" 1.721608 | \n",
" 0.543269 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 6 | \n",
" 0.901529 | \n",
" 1.650839 | \n",
" 0.559375 | \n",
" 00:05 | \n",
"
\n",
" \n",
" | 7 | \n",
" 0.829993 | \n",
" 1.743913 | \n",
" 0.569952 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 8 | \n",
" 0.810508 | \n",
" 1.746486 | \n",
" 0.584135 | \n",
" 00:06 | \n",
"
\n",
" \n",
" | 9 | \n",
" 0.795921 | \n",
" 1.756200 | \n",
" 0.582212 | \n",
" 00:04 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(10, 3e-3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating more signal"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sl = 16\n",
"seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1])) for i in range(0,len(nums)-sl-1,sl))\n",
"cut = int(len(seqs) * 0.8)\n",
"dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)"
]
},
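{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check we added: each target sequence is simply the input sequence shifted one token ahead."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# First five input tokens and first five target tokens of the first\n",
"# sequence; the targets run one step ahead of the inputs.\n",
"[vocab[int(o)] for o in seqs[0][0][:5]], [vocab[int(o)] for o in seqs[0][1][:5]]"
]
},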
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class LMModel4(Module):\n",
" def __init__(self, vocab_sz, n_hidden):\n",
" self.i_h = nn.Embedding(vocab_sz, n_hidden) \n",
" self.h_h = nn.Linear(n_hidden, n_hidden) \n",
" self.h_o = nn.Linear(n_hidden,vocab_sz)\n",
" self.h = 0\n",
" \n",
" def forward(self, x):\n",
" outs = []\n",
" for i in range(sl):\n",
" self.h = self.h + self.i_h(x[:,i])\n",
" self.h = F.relu(self.h_h(self.h))\n",
" outs.append(self.h_o(self.h))\n",
" self.h = self.h.detach()\n",
" return torch.stack(outs, dim=1)\n",
" \n",
" def reset(self): self.h = 0"
]
},
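{
"cell_type": "markdown",
"metadata": {},
"source": [
"A smoke test (our addition): `LMModel4` returns one prediction after every input token, so its output has a sequence dimension as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Expected shape: (bs, sl, len(vocab)) rather than (bs, len(vocab)).\n",
"m = LMModel4(len(vocab), 64)\n",
"m(torch.zeros(bs, sl, dtype=torch.long)).shape"
]
},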
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss_func(inp, targ): return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))"
]
},
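{
"cell_type": "markdown",
"metadata": {},
"source": [
"`F.cross_entropy` expects predictions of shape `(N, C)` and integer targets of shape `(N,)`, so the loss function above flattens the batch and sequence dimensions together before calling it. A shape-only sketch on random data (our addition):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Flattening (bs, sl, vocab) -> (bs*sl, vocab) and (bs, sl) -> (bs*sl,)\n",
"# lets cross entropy score every position of every sequence at once.\n",
"inp = torch.randn(bs, sl, len(vocab))\n",
"targ = torch.randint(0, len(vocab), (bs, sl))\n",
"loss_func(inp, targ)"
]
},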
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 3.285931 | \n",
" 3.072032 | \n",
" 0.212565 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 1 | \n",
" 2.330371 | \n",
" 1.969522 | \n",
" 0.425781 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 2 | \n",
" 1.742317 | \n",
" 1.841378 | \n",
" 0.441488 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 3 | \n",
" 1.470120 | \n",
" 1.810856 | \n",
" 0.494303 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 4 | \n",
" 1.298810 | \n",
" 1.823129 | \n",
" 0.492839 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 5 | \n",
" 1.176840 | \n",
" 1.755435 | \n",
" 0.509033 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 6 | \n",
" 1.070433 | \n",
" 1.689250 | \n",
" 0.517497 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 7 | \n",
" 0.972999 | \n",
" 1.867314 | \n",
" 0.513021 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 8 | \n",
" 0.896505 | \n",
" 1.716296 | \n",
" 0.582682 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 9 | \n",
" 0.835817 | \n",
" 1.673266 | \n",
" 0.592285 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 10 | \n",
" 0.782597 | \n",
" 1.707047 | \n",
" 0.580322 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 11 | \n",
" 0.744230 | \n",
" 1.719031 | \n",
" 0.581299 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 12 | \n",
" 0.710533 | \n",
" 1.790540 | \n",
" 0.593262 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 13 | \n",
" 0.690307 | \n",
" 1.801058 | \n",
" 0.587565 | \n",
" 00:02 | \n",
"
\n",
" \n",
" | 14 | \n",
" 0.678195 | \n",
" 1.765376 | \n",
" 0.600179 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func, metrics=accuracy, cbs=ModelReseter)\n",
"learn.fit_one_cycle(15, 3e-3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}