{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predicting the English word version of numbers using an RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We used RNNs as part of our language model in the previous lesson. Today, we will dive into the details of what RNNs are and how they work, using the problem of predicting the English word version of numbers.\n",
"\n",
"Let's predict what should come next in this sequence:\n",
"\n",
"*eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve...*\n",
"\n",
"Jeremy created this synthetic dataset to make it easier to check whether things are working, to debug, and to understand what is going on. When experimenting with new ideas, it helps to have a small dataset so you can quickly get a sense of whether your ideas are promising (for other examples, see [Imagenette and Imagewoof](https://github.com/fastai/imagenette)). This dataset of numbers written as English words is a good one for learning about RNNs. Our task today is to predict which word comes next when counting."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"from fastai.text import *"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"bs=64"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[PosixPath('/home/jhoward/.fastai/data/human_numbers/train.txt'),\n",
" PosixPath('/home/jhoward/.fastai/data/human_numbers/valid.txt')]"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"path = untar_data(URLs.HUMAN_NUMBERS)\n",
"path.ls()"
]
},
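{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
"# Read a file with one number per line and join the lines into a single comma-separated string (returned as a one-item list)\n",
"def readnums(d): return [', '.join(o.strip() for o in (path/d).read_text().splitlines())]"
]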
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`train.txt` contains a sequence of numbers written out as English words:"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_txt = readnums('train.txt'); train_txt[0][:80]"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"valid_txt = readnums('valid.txt'); valid_txt[0][-80:]"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"train = TextList(train_txt, path=path)\n",
"valid = TextList(valid_txt, path=path)\n",
"\n",
"src = ItemLists(path=path, train=train, valid=valid).label_for_lm()\n",
"data = src.databunch(bs=bs)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train[0].text[:80]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`bptt` stands for *back-propagation through time*. It sets how many time steps (tokens) each row of a batch contains, and therefore how many steps of history we backpropagate through."
]
},
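{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make `bptt` concrete, here is a minimal sketch in plain PyTorch. The toy sizes and the helper name `lm_batches` are made up for illustration. It follows the same idea as fastai's language-model loader but is not its actual code. One long token stream is split into `bs` contiguous rows, each row is cut into windows of `bptt` tokens, and the target for every window is the input shifted one token ahead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def lm_batches(stream, n_rows, bptt):\n",
"    # Split the stream into n_rows contiguous rows, dropping any remainder\n",
"    per_row = len(stream) // n_rows\n",
"    rows = stream[:per_row * n_rows].view(n_rows, per_row)\n",
"    # Cut each row into windows of bptt tokens; the target is the input shifted one token ahead\n",
"    batches = []\n",
"    for i in range(0, per_row - 1, bptt):\n",
"        seq_len = min(bptt, per_row - 1 - i)\n",
"        batches.append((rows[:, i:i + seq_len], rows[:, i + 1:i + 1 + seq_len]))\n",
"    return batches\n",
"\n",
"# Toy stand-in for the numericalized corpus: token ids 0..25 in one stream\n",
"toy_batches = lm_batches(torch.arange(26), n_rows=2, bptt=4)\n",
"for xb, yb in toy_batches:\n",
"    print(xb.tolist(), '->', yb.tolist())"
]
},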
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70, 3)"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.bptt, len(data.valid_dl)"
]
},
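{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have 3 batches in our validation set. It contains 13017 tokens, which are split into 64 rows (the batch size), and each batch takes about 70 tokens (`bptt`) from every row: 13017 / 64 ≈ 203 tokens per row, or about 2.9 windows of 70, which rounds up to 3 batches."
]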
},
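{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same count as a quick back-of-the-envelope calculation. Treat it as an approximation, because the loader's exact handling of leftover tokens may differ slightly:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"n_tokens, n_rows, window = 13017, 64, 70  # validation tokens, batch size, bptt (figures quoted above)\n",
"\n",
"per_row = n_tokens / n_rows              # ~203.4 tokens in each of the 64 rows\n",
"n_batches = math.ceil(per_row / window)  # ~2.9 windows of 70 tokens, rounded up\n",
"print(round(per_row, 1), n_batches)"
]
},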
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We store each batch in a separate variable so we can walk through them and better understand what the RNN does at each step:"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"it = iter(data.valid_dl)\n",
"x1,y1 = next(it)\n",
"x2,y2 = next(it)\n",
"x3,y3 = next(it)\n",
"it.close()"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"v = data.valid_ds.vocab"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"data = src.databunch(bs=bs, bptt=40)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([64, 40]), torch.Size([64, 40]))"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x,y = data.one_batch()\n",
"x.shape,y.shape"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"40"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nv = len(v.itos); nv"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"nh=56"
]
},
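{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"# Loss and accuracy using only the target at the last time step (for models that output one prediction per sequence)\n",
"def loss4(input,target): return F.cross_entropy(input, target[:,-1])\n",
"def acc4(input,target): return accuracy(input, target[:,-1])"
]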
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Layer names:\n",
"- `i_h`: input to hidden\n",
"- `h_h`: hidden to hidden\n",
"- `h_o`: hidden to output\n",
"- `bn`: batchnorm"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding a GRU"
]
},
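{
"cell_type": "markdown",
"metadata": {},
"source": [
"With long time scales and deeper networks, plain RNNs become very hard to train, because gradients tend to vanish or explode as they are backpropagated through many steps. One way to address this is to add small neural networks (gates) that decide how much of the new input and how much of the previous hidden state to keep (the green and orange arrows in the lesson's RNN diagram). These gated units are GRUs (gated recurrent units) or LSTMs (long short-term memory units)."
]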
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"class Model5(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(nv,nh)                  # input to hidden: token id -> nh-dim vector\n",
"        self.rnn = nn.GRU(nh, nh, 1, batch_first=True)  # 1-layer GRU on (batch, time, features) tensors\n",
"        self.h_o = nn.Linear(nh,nv)                     # hidden to output: one logit per vocab word\n",
"        self.bn = BatchNorm1dFlat(nh)\n",
"        self.h = torch.zeros(1, bs, nh).cuda()          # hidden state: (num_layers, batch, nh)\n",
"\n",
"    def forward(self, x):\n",
"        res,h = self.rnn(self.i_h(x), self.h)\n",
"        self.h = h.detach()  # carry the state to the next batch, but stop backprop at the batch boundary\n",
"        return self.h_o(self.bn(res))"
]
},
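{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check of `nn.GRU` with `batch_first=True` makes `Model5` easier to follow. This is a standalone sketch with made-up toy sizes (`B`, `T`, `H`), and random tensors stand in for the embedded tokens. `nn.GRU` returns the output for every time step plus the final hidden state:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"B, T, H = 4, 7, 56  # toy batch size and sequence length; H matches nh above\n",
"\n",
"gru = nn.GRU(H, H, 1, batch_first=True)\n",
"emb = torch.randn(B, T, H)  # stands in for self.i_h(x): (batch, time, features)\n",
"h0 = torch.zeros(1, B, H)   # (num_layers, batch, hidden)\n",
"\n",
"out, h_last = gru(emb, h0)\n",
"print(out.shape, h_last.shape)  # every time step vs. only the final step\n",
"\n",
"# With one layer, the returned hidden state is the last time step of the output\n",
"print(torch.allclose(out[:, -1], h_last[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is why `Model5` keeps `h` between calls. After `h.detach()`, the next batch starts from where this one ended, but gradients stop at the batch boundary (truncated backpropagation through time)."
]
},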
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(40, 56)"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"nv, nh"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"learn = Learner(data, Model5(), metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"epoch  train_loss  valid_loss  accuracy  time\n",
"0      3.540133    3.465589    0.240951  00:00\n",
"1      2.788598    2.131198    0.388867  00:00\n",
"2      2.147029    1.868370    0.441536  00:00\n",
"3      1.768475    1.858901    0.475521  00:00\n",
"4      1.472611    1.808398    0.623893  00:00\n",
"5      1.204455    1.676029    0.621549  00:00\n",
"6      0.971170    1.593996    0.674219  00:00\n",
"7      0.779145    1.554770    0.663021  00:00\n",
"8      0.633210    1.524638    0.700195  00:00\n",
"9      0.529737    1.528556    0.704883  00:00"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn.fit_one_cycle(10, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Let's make our own GRU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using PyTorch's GRUCell"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since our tensors are batch-first, axis 0 is the batch dimension and axis 1 is the time dimension. We want to loop through axis 1, feeding the cell one time step at a time:"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"def rnn_loop(cell, h, x):\n",
"    res = []\n",
"    for x_ in x.transpose(0,1):  # step through time; each x_ is (batch, features)\n",
"        h = cell(x_, h)\n",
"        res.append(h)\n",
"    return torch.stack(res, dim=1)  # (batch, time, nh)"
]
},
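{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check that this loop really reproduces a GRU layer, here is a standalone sketch with made-up toy sizes. It copies the weights of a one-layer `nn.GRU` into an `nn.GRUCell` (both stack their gates in the same reset, update, new order) and compares the looped output with the fused layer's output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"def rnn_loop(cell, h, x):\n",
"    # Same loop as above, repeated so this cell runs on its own\n",
"    res = []\n",
"    for x_ in x.transpose(0, 1):\n",
"        h = cell(x_, h)\n",
"        res.append(h)\n",
"    return torch.stack(res, dim=1)\n",
"\n",
"B, T, H = 3, 5, 8  # toy sizes\n",
"gru = nn.GRU(H, H, 1, batch_first=True)\n",
"cell = nn.GRUCell(H, H)\n",
"\n",
"# Copy the single-layer GRU's parameters into the cell\n",
"with torch.no_grad():\n",
"    cell.weight_ih.copy_(gru.weight_ih_l0)\n",
"    cell.weight_hh.copy_(gru.weight_hh_l0)\n",
"    cell.bias_ih.copy_(gru.bias_ih_l0)\n",
"    cell.bias_hh.copy_(gru.bias_hh_l0)\n",
"\n",
"seq = torch.randn(B, T, H)\n",
"res_loop = rnn_loop(cell, torch.zeros(B, H), seq)\n",
"res_gru, _ = gru(seq, torch.zeros(1, B, H))\n",
"print(torch.allclose(res_loop, res_gru, atol=1e-5))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the two agree, the Python loop computes the same thing as the fused `nn.GRU` layer, just one step at a time, which is typically slower."
]
},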
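{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"class Model6(Model5):\n",
"    def __init__(self):\n",
"        super().__init__()  # keeps Model5's i_h, h_o and bn (its nn.GRU is no longer used)\n",
"        self.rnnc = nn.GRUCell(nh, nh)       # processes one time step per call\n",
"        self.h = torch.zeros(bs, nh).cuda()  # a GRUCell's hidden state has no layer axis: (batch, nh)\n",
"\n",
"    def forward(self, x):\n",
"        res = rnn_loop(self.rnnc, self.h, self.i_h(x))\n",
"        self.h = res[:,-1].detach()  # last time step becomes the next batch's starting state\n",
"        return self.h_o(self.bn(res))"
]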
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"epoch  train_loss  valid_loss  accuracy  time\n",
"0      3.436362    3.394695    0.352865  00:00\n",
"1      2.739339    2.135875    0.464258  00:00\n",
"2      2.118275    1.790151    0.483854  00:00\n",
"3      1.757279    1.725953    0.506576  00:00\n",
"4      1.468234    1.623885    0.566797  00:00\n",
"5      1.191478    1.475327    0.656771  00:00\n",
"6      0.942030    1.302567    0.728320  00:00\n",
"7      0.739052    1.320555    0.759049  00:00\n",
"8      0.588394    1.269386    0.763737  00:00\n",
"9      0.483452    1.246448    0.773112  00:00"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, Model6(), metrics=accuracy)\n",
"learn.fit_one_cycle(10, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### With a custom GRUCell"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is based on code from [emadRad](https://github.com/emadRad/lstm-gru-pytorch/blob/master/lstm_gru.ipynb):"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [],
"source": [
"class GRUCell(nn.Module):\n",
"    def __init__(self, ni, nh):\n",
"        super().__init__()\n",
"        self.ni,self.nh = ni,nh\n",
"        self.i2h = nn.Linear(ni, 3*nh)  # input contributions to the reset, update and new gates\n",
"        self.h2h = nn.Linear(nh, 3*nh)  # hidden-state contributions to the same three gates\n",
"\n",
"    def forward(self, x, h):\n",
"        gate_x = self.i2h(x)  # (batch, 3*nh); no .squeeze(), so a batch of size 1 also works\n",
"        gate_h = self.h2h(h)\n",
"        i_r,i_u,i_n = gate_x.chunk(3, 1)\n",
"        h_r,h_u,h_n = gate_h.chunk(3, 1)\n",
"\n",
"        resetgate = torch.sigmoid(i_r + h_r)          # how much of the old state feeds the candidate\n",
"        updategate = torch.sigmoid(i_u + h_u)         # how much of the old state to keep\n",
"        newgate = torch.tanh(i_n + (resetgate*h_n))   # candidate new state\n",
"        return updategate*h + (1-updategate)*newgate  # interpolate between old state and candidate"
]
},
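{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, these are the GRU equations the cell implements, written in the notation of PyTorch's `nn.GRU` documentation. Here $r_t$ is `resetgate`, $z_t$ is `updategate`, $n_t$ is `newgate`, $\\sigma$ is the sigmoid and $\\odot$ is elementwise multiplication:\n",
"\n",
"$$r_t = \\sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr})$$\n",
"\n",
"$$z_t = \\sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz})$$\n",
"\n",
"$$n_t = \\tanh(W_{in} x_t + b_{in} + r_t \\odot (W_{hn} h_{t-1} + b_{hn}))$$\n",
"\n",
"$$h_t = z_t \\odot h_{t-1} + (1 - z_t) \\odot n_t$$\n",
"\n",
"`i2h` holds the three input weight matrices and biases stacked together, and `h2h` holds the three hidden ones. As a sanity check, here is a minimal sketch that copies the parameters of PyTorch's `nn.GRUCell` (which stacks its gates in the same $r, z, n$ order) into a copy of the cell above and compares outputs. `MyGRUCell` and the toy sizes are made up for this check, so the notebook's own `GRUCell` is not overwritten:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class MyGRUCell(nn.Module):\n",
"    # Hypothetical copy of the custom GRUCell above, written without .squeeze()\n",
"    def __init__(self, ni, nh):\n",
"        super().__init__()\n",
"        self.i2h = nn.Linear(ni, 3*nh)\n",
"        self.h2h = nn.Linear(nh, 3*nh)\n",
"\n",
"    def forward(self, x, h):\n",
"        i_r, i_u, i_n = self.i2h(x).chunk(3, 1)\n",
"        h_r, h_u, h_n = self.h2h(h).chunk(3, 1)\n",
"        resetgate = torch.sigmoid(i_r + h_r)\n",
"        updategate = torch.sigmoid(i_u + h_u)\n",
"        newgate = torch.tanh(i_n + resetgate*h_n)\n",
"        return updategate*h + (1-updategate)*newgate\n",
"\n",
"n_in, n_hidden, n_batch = 6, 8, 4  # toy sizes\n",
"ours, ref = MyGRUCell(n_in, n_hidden), nn.GRUCell(n_in, n_hidden)\n",
"\n",
"# nn.GRUCell stores weight_ih as (3*n_hidden, n_in) and weight_hh as (3*n_hidden, n_hidden)\n",
"with torch.no_grad():\n",
"    ours.i2h.weight.copy_(ref.weight_ih)\n",
"    ours.i2h.bias.copy_(ref.bias_ih)\n",
"    ours.h2h.weight.copy_(ref.weight_hh)\n",
"    ours.h2h.bias.copy_(ref.bias_hh)\n",
"\n",
"xb, hb = torch.randn(n_batch, n_in), torch.randn(n_batch, n_hidden)\n",
"print(torch.allclose(ours(xb, hb), ref(xb, hb), atol=1e-5))"
]
},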
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"class Model7(Model6):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.rnnc = GRUCell(nh,nh)  # swap in our hand-written cell; forward is inherited from Model6"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"epoch  train_loss  valid_loss  accuracy  time\n",
"0      3.538269    3.513104    0.276562  00:01\n",
"1      2.815921    2.241536    0.364844  00:01\n",
"2      2.140461    1.959162    0.418424  00:01\n",
"3      1.704312    1.870930    0.490104  00:01\n",
"4      1.350341    1.672874    0.611784  00:01\n",
"5      1.040827    1.563094    0.664974  00:01\n",
"6      0.793811    1.539096    0.716276  00:01\n",
"7      0.606719    1.480668    0.727604  00:01\n",
"8      0.472559    1.456170    0.731836  00:01\n",
"9      0.380289    1.481691    0.732617  00:01"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = Learner(data, Model7(), metrics=accuracy)\n",
"learn.fit_one_cycle(10, 1e-2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Connection to ULMFit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the previous lesson, we essentially swapped out `self.h_o` for a classifier in order to classify text.\n",
"\n",
"An RNN is just a refactored fully connected neural network: the same layers applied at every time step, with their weights shared across steps.\n",
"\n",
"You can use the same approach for any sequence labeling task (e.g. part-of-speech tagging, or classifying whether material is sensitive)."
]
},
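{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of that idea. The `Tagger` class, its sizes and the random data are hypothetical and only for illustration. It keeps the embedding and GRU body, gives `h_o` one output per label instead of one per vocabulary word, and computes the loss at every time step:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Tagger(nn.Module):\n",
"    # Hypothetical sequence labeler: embedding + GRU body, one label prediction per token\n",
"    def __init__(self, vocab_sz, n_tags, n_hidden=56):\n",
"        super().__init__()\n",
"        self.i_h = nn.Embedding(vocab_sz, n_hidden)\n",
"        self.rnn = nn.GRU(n_hidden, n_hidden, 1, batch_first=True)\n",
"        self.h_o = nn.Linear(n_hidden, n_tags)  # n_tags outputs instead of vocab_sz\n",
"\n",
"    def forward(self, x):\n",
"        res, _ = self.rnn(self.i_h(x))  # hidden state defaults to zeros; res is (batch, time, n_hidden)\n",
"        return self.h_o(res)            # (batch, time, n_tags)\n",
"\n",
"tagger = Tagger(vocab_sz=40, n_tags=5)\n",
"tokens = torch.randint(0, 40, (2, 7))  # 2 sequences of 7 token ids\n",
"labels = torch.randint(0, 5, (2, 7))   # one (random) label per token\n",
"logits = tagger(tokens)\n",
"loss = F.cross_entropy(logits.reshape(-1, 5), labels.reshape(-1))\n",
"print(logits.shape, loss.item())"
]
},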
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## fin"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}