{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# A language model from scratch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.text.all import *\n", "path = untar_data(URLs.HUMAN_NUMBERS)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "Path.BASE_PATH = path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#2) [Path('train.txt'),Path('valid.txt')]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path.ls()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#9998) ['one \\n','two \\n','three \\n','four \\n','five \\n','six \\n','seven \\n','eight \\n','nine \\n','ten \\n'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lines = L()\n", "with open(path/'train.txt') as f: lines += L(*f.readlines())\n", "with open(path/'valid.txt') as f: lines += L(*f.readlines())\n", "lines" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'one . two . three . four . five . six . seven . eight . nine . ten . eleven . twelve . thirteen . fo'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "text = ' . '.join([l.strip() for l in lines])\n", "text[:100]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['one', '.', 'two', '.', 'three', '.', 'four', '.', 'five', '.']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokens = text.split(' ')\n", "tokens[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#30) ['one','.','two','three','four','five','six','seven','eight','nine'...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab = L(*tokens).unique()\n", "vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#63095) [0,1,2,1,3,1,4,1,5,1...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "word2idx = {w:i for i,w in enumerate(vocab)}\n", "nums = L(word2idx[i] for i in tokens)\n", "nums" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Our first language model from scratch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#21031) [(['one', '.', 'two'], '.'),(['.', 'three', '.'], 'four'),(['four', '.', 'five'], '.'),(['.', 'six', '.'], 'seven'),(['seven', '.', 'eight'], '.'),(['.', 'nine', '.'], 'ten'),(['ten', '.', 'eleven'], '.'),(['.', 'twelve', '.'], 'thirteen'),(['thirteen', '.', 'fourteen'], '.'),(['.', 'fifteen', '.'], 'sixteen')...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "L((tokens[i:i+3], tokens[i+3]) for i in range(0,len(tokens)-4,3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#21031) [(tensor([0, 1, 2]), 1),(tensor([1, 3, 1]), 4),(tensor([4, 1, 5]), 1),(tensor([1, 6, 1]), 7),(tensor([7, 1, 8]), 1),(tensor([1, 9, 1]), 10),(tensor([10, 1, 11]), 1),(tensor([ 1, 12, 1]), 13),(tensor([13, 1, 14]), 1),(tensor([ 1, 15, 1]), 16)...]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "seqs = L((tensor(nums[i:i+3]), nums[i+3]) for i in range(0,len(nums)-4,3))\n", "seqs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bs = 64\n", "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(seqs[:cut], seqs[cut:], bs=64, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Our language model in PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel1(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " \n", " def forward(self, x):\n", " h = F.relu(self.h_h(self.i_h(x[:,0])))\n", " h = h + self.i_h(x[:,1])\n", " h = F.relu(self.h_h(h))\n", " h = h + self.i_h(x[:,2])\n", " h = F.relu(self.h_h(h))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.8242971.9709410.46755400:05
11.3869731.8232420.46755400:05
21.4175561.6544970.49441400:05
31.3764401.6508490.49441400:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel1(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor(29), 'thousand', 0.15165200855716662)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "n,counts = 0,torch.zeros(len(vocab))\n", "for x,y in dls.valid:\n", " n += y.shape[0]\n", " for i in range_of(vocab): counts[i] += (y==i).long().sum()\n", "idx = torch.argmax(counts)\n", "idx, vocab[idx.item()], counts[idx].item()/n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Our first recurrent neural network" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel2(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " \n", " def forward(self, x):\n", " h = 0\n", " for i in range(3):\n", " h = h + self.i_h(x[:,i])\n", " h = F.relu(self.h_h(h))\n", " return self.h_o(h)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.8162741.9641430.46018500:04
11.4238051.7399640.47325900:05
21.4303271.6851720.48538200:05
31.3883901.6570330.47040600:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel2(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy)\n", "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Improving the RNN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Maintaining the state of an RNN" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel3(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " self.h = 0\n", " \n", " def forward(self, x):\n", " for i in range(3):\n", " self.h = self.h + self.i_h(x[:,i])\n", " self.h = F.relu(self.h_h(self.h))\n", " out = self.h_o(self.h)\n", " self.h = self.h.detach()\n", " return out\n", " \n", " def reset(self): self.h = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def group_chunks(ds, bs):\n", " m = len(ds) // bs\n", " new_ds = L()\n", " for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))\n", " return new_ds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
01.6770741.8273670.46754800:06
11.2827221.8709130.38894200:06
21.0907051.6517940.46250000:05
31.0052151.6159900.51514400:06
40.9630201.6058940.55120200:06
50.9261501.7216080.54326900:06
60.9015291.6508390.55937500:05
70.8299931.7439130.56995200:06
80.8105081.7464860.58413500:06
90.7959211.7562000.58221200:04
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel3(len(vocab), 64), loss_func=F.cross_entropy, metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(10, 3e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Creating more signal" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sl = 16\n", "seqs = L((tensor(nums[i:i+sl]), tensor(nums[i+1:i+sl+1])) for i in range(0,len(nums)-sl-1,sl))\n", "cut = int(len(seqs) * 0.8)\n", "dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs), group_chunks(seqs[cut:], bs), bs=bs, drop_last=True, shuffle=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LMModel4(Module):\n", " def __init__(self, vocab_sz, n_hidden):\n", " self.i_h = nn.Embedding(vocab_sz, n_hidden) \n", " self.h_h = nn.Linear(n_hidden, n_hidden) \n", " self.h_o = nn.Linear(n_hidden,vocab_sz)\n", " self.h = 0\n", " \n", " def forward(self, x):\n", " outs = []\n", " for i in range(sl):\n", " self.h = self.h + self.i_h(x[:,i])\n", " self.h = F.relu(self.h_h(self.h))\n", " outs.append(self.h_o(self.h))\n", " self.h = self.h.detach()\n", " return torch.stack(outs, dim=1)\n", " \n", " def reset(self): self.h = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_func(inp, targ): return F.cross_entropy(inp.view(-1, len(vocab)), targ.view(-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
03.2859313.0720320.21256500:02
12.3303711.9695220.42578100:02
21.7423171.8413780.44148800:02
31.4701201.8108560.49430300:02
41.2988101.8231290.49283900:02
51.1768401.7554350.50903300:02
61.0704331.6892500.51749700:02
70.9729991.8673140.51302100:02
80.8965051.7162960.58268200:02
90.8358171.6732660.59228500:02
100.7825971.7070470.58032200:02
110.7442301.7190310.58129900:02
120.7105331.7905400.59326200:02
130.6903071.8010580.58756500:02
140.6781951.7653760.60017900:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(dls, LMModel4(len(vocab), 64), loss_func=loss_func, metrics=accuracy, cbs=ModelReseter)\n", "learn.fit_one_cycle(15, 3e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }