{ "cells": [
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "from local.test import *\n",
  "from local.data.all import *\n",
  "from local.text.core import *" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "from local.notebook.showdoc import *" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#default_exp text.models.awdlstm\n",
  "#default_cls_lvl 3" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "# AWD-LSTM\n",
  "\n",
  "> AWD LSTM from [Smerity et al.](https://arxiv.org/pdf/1708.02182.pdf)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "## Basic NLP modules" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "On top of the PyTorch or fastai [`layers`](/layers.html#layers), the language models use some custom layers specific to NLP." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def dropout_mask(x, sz, p):\n",
  "    \"Return a dropout mask of the same type as `x`, size `sz`, with probability `p` to cancel an element.\"\n",
  "    # Kept elements are scaled by 1/(1-p), so the expected value of `x * mask` is `x`.\n",
  "    # With `p=1` every element is cancelled: return zeros instead of dividing 0 by 0 (which gives NaNs).\n",
  "    if p == 1: return x.new(*sz).zero_()\n",
  "    return x.new(*sz).bernoulli_(1-p).div_(1-p)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "t = dropout_mask(torch.randn(3,4), [4,3], 0.25)\n",
  "test_eq(t.shape, [4,3])\n",
  "assert ((t == 4/3) + (t==0)).all()\n",
  "#with `p=1` every element is cancelled (and no NaNs appear)\n",
  "test_eq(dropout_mask(torch.randn(3,4), [2,5], 1.), torch.zeros(2,5))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class RNNDropout(Module):\n",
  "    \"Dropout with probability `p` that is consistent on the seq_len dimension.\"\n",
  "    def __init__(self, p=0.5): self.p=p\n",
  "\n",
  "    def forward(self, x):\n",
  "        if not self.training or self.p == 0.: return x\n",
  "        # A single mask of size (x.size(0), 1, x.size(2)) is broadcast along dimension 1, so the same\n",
  "        # elements are dropped at every position of that dimension (seq_len for batch-first inputs).\n",
  "        return x * dropout_mask(x.data, (x.size(0), 1, x.size(2)), self.p)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "dp = RNNDropout(0.3)\n",
"tst_inp = torch.randn(4,3,7)\n",
  "tst_out = dp(tst_inp)\n",
  "for i in range(4):\n",
  "    for j in range(7):\n",
  "        if tst_out[i,0,j] == 0: assert (tst_out[i,:,j] == 0).all()\n",
  "        else: test_close(tst_out[i,:,j], tst_inp[i,:,j]/(1-0.3))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "import warnings" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class WeightDropout(Module):\n",
  "    \"A module that wraps another layer in which some weights will be replaced by 0 during training.\"\n",
  "\n",
  "    def __init__(self, module, weight_p, layer_names='weight_hh_l0'):\n",
  "        self.module,self.weight_p,self.layer_names = module,weight_p,L(layer_names)\n",
  "        for layer in self.layer_names:\n",
  "            #Makes a copy of the weights of the selected layers.\n",
  "            w = getattr(self.module, layer)\n",
  "            self.register_parameter(f'{layer}_raw', nn.Parameter(w.data))\n",
  "            self.module._parameters[layer] = F.dropout(w, p=self.weight_p, training=False)\n",
  "\n",
  "    def _drop_weights(self, training):\n",
  "        \"Replace each selected weight of `module` by its raw copy, with dropout applied if `training`.\"\n",
  "        for layer in self.layer_names:\n",
  "            raw_w = getattr(self, f'{layer}_raw')\n",
  "            self.module._parameters[layer] = F.dropout(raw_w, p=self.weight_p, training=training)\n",
  "\n",
  "    def _setweights(self):\n",
  "        \"Apply dropout to the raw weights.\"\n",
  "        self._drop_weights(self.training)\n",
  "\n",
  "    def forward(self, *args):\n",
  "        self._setweights()\n",
  "        with warnings.catch_warnings():\n",
  "            #To avoid the warning that comes because the weights aren't flattened.\n",
  "            #The filter is restricted to `UserWarning` (the category of that warning) so other kinds still surface.\n",
  "            warnings.simplefilter(\"ignore\", category=UserWarning)\n",
  "            #Call the module itself (not `.forward`) so hooks registered on it still run.\n",
  "            return self.module(*args)\n",
  "\n",
  "    def reset(self):\n",
  "        \"Put back the raw weights (without dropout) and reset `module` if it has a `reset` method.\"\n",
  "        self._drop_weights(False)\n",
  "        if hasattr(self.module, 'reset'): self.module.reset()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#run on GPU when one is available, on CPU otherwise\n",
  "dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
  "module = nn.LSTM(5,7).to(dev)\n",
  "dp_module = WeightDropout(module, 0.4)\n",
  "wgts = getattr(dp_module.module, 'weight_hh_l0')\n",
  "tst_inp = torch.randn(10,20,5, device=dev)\n",
  "h = torch.zeros(1,20,7, device=dev), torch.zeros(1,20,7, device=dev)\n",
  "x,h = dp_module(tst_inp,h)\n",
  "new_wgts = getattr(dp_module.module, 'weight_hh_l0')\n",
  "test_eq(wgts, getattr(dp_module, 'weight_hh_l0_raw'))\n",
  "assert 0.2 <= (new_wgts==0).sum().float()/new_wgts.numel() <= 0.6" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class EmbeddingDropout(Module):\n",
  "    \"Apply dropout with probability `embed_p` to an embedding layer `emb`.\"\n",
  "\n",
  "    def __init__(self, emb, embed_p):\n",
  "        self.emb,self.embed_p = emb,embed_p\n",
  "\n",
  "    def forward(self, words, scale=None):\n",
  "        if self.training and self.embed_p != 0:\n",
  "            #The mask has one value per row, so whole rows (words) of the embedding matrix are dropped.\n",
  "            size = (self.emb.weight.size(0),1)\n",
  "            mask = dropout_mask(self.emb.weight.data, size, self.embed_p)\n",
  "            masked_embed = self.emb.weight * mask\n",
  "        else: masked_embed = self.emb.weight\n",
  "        #Out of place: `masked_embed` may be `self.emb.weight` itself, which must not be modified.\n",
  "        if scale: masked_embed = masked_embed * scale\n",
  "        return F.embedding(words, masked_embed, ifnone(self.emb.padding_idx, -1), self.emb.max_norm,\n",
  "                           self.emb.norm_type, self.emb.scale_grad_by_freq, self.emb.sparse)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "enc = nn.Embedding(10, 7, padding_idx=1)\n",
  "enc_dp = EmbeddingDropout(enc, 0.5)\n",
  "tst_inp = torch.randint(0,10,(8,))\n",
  "tst_out = enc_dp(tst_inp)\n",
  "for i in range(8):\n",
  "    assert (tst_out[i]==0).all() or torch.allclose(tst_out[i], 2*enc.weight[tst_inp[i]])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#`scale` multiplies the embeddings without modifying the weights of `emb`\n",
  "enc_dp.eval()\n",
  "w = enc.weight.detach().clone()\n",
  "test_close(enc_dp(tst_inp, scale=2.).detach(), 2*w[tst_inp])\n",
  "test_eq(enc.weight.detach(), w)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class AWD_LSTM(Module):\n",
  "    \"AWD-LSTM inspired by https://arxiv.org/abs/1708.02182\"\n",
  "    initrange=0.1\n",
  "\n",
  "    def __init__(self, vocab_sz, emb_sz, n_hid, n_layers, pad_token=1, hidden_p=0.2, input_p=0.6, embed_p=0.1,\n",
  "                 weight_p=0.5, bidir=False, packed=False):\n",
  "        store_attr(self, 'emb_sz,n_hid,n_layers,pad_token,packed')\n",
  "        self.bs = 1\n",
  "        self.n_dir = 2 if bidir else 1\n",
  "        self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)\n",
  "        self.encoder_dp = EmbeddingDropout(self.encoder, embed_p)\n",
  "        self.rnns = nn.ModuleList([self._one_rnn(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir,\n",
  "                                                 bidir, weight_p, l) for l in range(n_layers)])\n",
  "        self.encoder.weight.data.uniform_(-self.initrange, self.initrange)\n",
  "        self.input_dp = RNNDropout(input_p)\n",
  "        self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])\n",
  "        #Create the hidden states now: `forward` only calls `reset` when the batch size differs from `self.bs`,\n",
  "        #so without this a first batch of size 1 would find no `self.hidden`.\n",
  "        self.reset()\n",
  "\n",
  "    def forward(self, inp, from_embeds=False):\n",
  "        \"Return the list of raw outputs and the list of outputs (after `hidden_p` dropout) of each layer\"\n",
  "        bs,sl = inp.shape[:2] if from_embeds else inp.shape\n",
  "        if bs!=self.bs:\n",
  "            self.bs=bs\n",
  "            self.reset()\n",
  "        if self.packed: inp,lens = self._pack_sequence(inp, sl)\n",
  "\n",
  "        raw_output = self.input_dp(inp if from_embeds else self.encoder_dp(inp))\n",
  "        new_hidden,raw_outputs,outputs = [],[],[]\n",
  "        for l, (rnn,hid_dp) in enumerate(zip(self.rnns, self.hidden_dps)):\n",
  "            if self.packed: raw_output = pack_padded_sequence(raw_output, lens, batch_first=True)\n",
  "            raw_output, new_h = rnn(raw_output, self.hidden[l])\n",
  "            if self.packed: raw_output = pad_packed_sequence(raw_output, batch_first=True)[0]\n",
  "            new_hidden.append(new_h)\n",
  "            raw_outputs.append(raw_output)\n",
  "            #No dropout after the last layer, so its raw output and output are the same tensor.\n",
  "            if l != self.n_layers - 1: raw_output = hid_dp(raw_output)\n",
  "            outputs.append(raw_output)\n",
  "        #Keep the (detached) hidden states for the next batch.\n",
  "        self.hidden = to_detach(new_hidden, cpu=False, gather=False)\n",
  "        return raw_outputs, outputs\n",
  "\n",
  "    def _one_rnn(self, n_in, n_out, bidir, weight_p, l):\n",
  "        \"Return one of the inner rnn\"\n",
  "        rnn = nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir)\n",
  "        return WeightDropout(rnn, weight_p)\n",
  "\n",
  "    def _one_hidden(self, l):\n",
  "        \"Return one hidden state\"\n",
  "        nh = (self.n_hid if l != self.n_layers - 1 else self.emb_sz) // self.n_dir\n",
  "        return (one_param(self).new_zeros(self.n_dir, self.bs, nh), one_param(self).new_zeros(self.n_dir, self.bs, nh))\n",
  "\n",
  "    def reset(self):\n",
  "        \"Reset the hidden states\"\n",
  "        for r in self.rnns:\n",
  "            if hasattr(r, 'reset'): r.reset()\n",
  "        self.hidden = [self._one_hidden(l) for l in range(self.n_layers)]\n",
  "\n",
  "    def _pack_sequence(self, inp, sl):\n",
  "        \"Return `inp` and the lengths of its sequences, dropping fully padded sequences\"\n",
  "        #Lengths are computed by counting `pad_token`s, which assumes the padding is only at the end.\n",
  "        mask = (inp == self.pad_token)\n",
  "        lens = sl - mask.long().sum(1)\n",
  "        n_empty = (lens == 0).sum()\n",
  "        if n_empty > 0:\n",
  "            #`pack_padded_sequence` rejects empty sequences; they are expected to be the last ones of the batch.\n",
  "            inp,lens = inp[:-n_empty],lens[:-n_empty]\n",
  "            self.hidden = [(h[0][:,:inp.size(0)], h[1][:,:inp.size(0)]) for h in self.hidden]\n",
  "        return (inp,lens)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#the hidden states exist from the start, so a first batch of size 1 works\n",
  "tst = AWD_LSTM(100, 20, 10, 2)\n",
  "r = tst(torch.randint(0, 100, (1,5)))\n",
  "test_eq([h_.shape for h_ in tst.hidden[1]], [[1,1,20], [1,1,20]])" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "This is the core of an AWD-LSTM model, with embeddings from `vocab_sz` to `emb_sz` and `n_layers` stacked LSTMs (bidirectional if `bidir`): the first one goes from `emb_sz` to `n_hid`, the last one from `n_hid` to `emb_sz`, and all the inner ones from `n_hid` to `n_hid`. `pad_token` is passed to the PyTorch embedding layer. The dropouts are applied as follows:\n",
  "- the embeddings are wrapped in `EmbeddingDropout` of probability `embed_p`;\n",
  "- the result of this embedding layer goes through an `RNNDropout` of probability `input_p`;\n",
  "- each LSTM has `WeightDropout` applied with probability `weight_p`;\n",
  "- between two consecutive LSTMs, an `RNNDropout` is applied with probability `hidden_p`.\n",
  "\n",
  "If `packed=True`, the padding is removed with `pack_padded_sequence` before each LSTM: sequences are expected to be padded with `pad_token` at the end and sorted by decreasing length, and fully padded sequences at the end of the batch are dropped.\n",
  "\n",
  "The module returns two lists: the raw outputs (before the `hidden_p` dropout is applied) of each LSTM and the list of outputs with dropout. Since there is no dropout applied on the last output, those two lists have the same last element, which is the output that should be fed to a decoder (in the case of a language model)."
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "tst = AWD_LSTM(100, 20, 10, 2)\n",
  "x = torch.randint(0, 100, (10,5))\n",
  "r = tst(x)\n",
  "test_eq(tst.bs, 10)\n",
  "test_eq(len(tst.hidden), 2)\n",
  "test_eq([h_.shape for h_ in tst.hidden[0]], [[1,10,10], [1,10,10]])\n",
  "test_eq([h_.shape for h_ in tst.hidden[1]], [[1,10,20], [1,10,20]])\n",
  "test_eq(len(r), 2)\n",
  "test_eq(r[0][-1], r[1][-1]) #No dropout for last output\n",
  "for i in range(2): test_eq([h_.shape for h_ in r[i]], [[10,5,10], [10,5,20]])\n",
  "for i in range(2): test_eq(r[0][i][:,-1], tst.hidden[i][0][0]) #hidden state is the last timestep in raw outputs" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#test packed with padding\n",
  "tst = AWD_LSTM(100, 20, 10, 2, packed=True)\n",
  "x = torch.randint(2, 100, (10,5))\n",
  "x[9,3:] = 1\n",
  "r = tst(x)\n",
  "test_eq(tst.bs, 10)\n",
  "test_eq(len(tst.hidden), 2)\n",
  "test_eq([h_.shape for h_ in tst.hidden[0]], [[1,10,10], [1,10,10]])\n",
  "test_eq([h_.shape for h_ in tst.hidden[1]], [[1,10,20], [1,10,20]])\n",
  "test_eq(len(r), 2)\n",
  "test_eq(r[0][-1], r[1][-1]) #No dropout for last output\n",
  "for i in range(2): test_eq([h_.shape for h_ in r[i]], [[10,5,10], [10,5,20]])\n",
  "#hidden state is the last timestep in raw outputs\n",
  "for i in range(2): test_eq(r[0][i][:,-1][:9], tst.hidden[i][0][0][:9])\n",
  "for i in range(2): test_eq(r[0][i][:,-3][9], tst.hidden[i][0][0][9])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def awd_lstm_lm_split(model):\n",
  "    \"Split a RNN `model` in groups for differential learning rates.\"\n",
  "    awd = model[0]\n",
  "    #One group per (LSTM, hidden dropout) pair, then the embedding grouped with `model[1]`.\n",
  "    layer_groups = [nn.Sequential(rnn, dp) for rnn, dp in zip(awd.rnns, awd.hidden_dps)]\n",
  "    layer_groups.append(nn.Sequential(awd.encoder, awd.encoder_dp, model[1]))\n",
  "    return L(layer_groups).map(params)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {},
"outputs": [], "source": [
  "splits = awd_lstm_lm_split" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "awd_lstm_lm_config = dict(\n",
  "    emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, bidir=False, output_p=0.1, packed=False,\n",
  "    hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2, tie_weights=True, out_bias=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def awd_lstm_clas_split(model):\n",
  "    \"Split a RNN `model` in groups for differential learning rates.\"\n",
  "    awd = model[0].module\n",
  "    #The embedding first, then one group per (LSTM, hidden dropout) pair, and `model[1]` last.\n",
  "    layer_groups = [nn.Sequential(awd.encoder, awd.encoder_dp)]\n",
  "    layer_groups.extend(nn.Sequential(rnn, dp) for rnn, dp in zip(awd.rnns, awd.hidden_dps))\n",
  "    layer_groups.append(model[1])\n",
  "    return L(layer_groups).map(params)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "awd_lstm_clas_config = dict(\n",
  "    emb_sz=400, n_hid=1152, n_layers=3, pad_token=1, bidir=False, output_p=0.4,\n",
  "    hidden_p=0.3, input_p=0.4, embed_p=0.05, weight_p=0.5, packed=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "## QRNN" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "class AWD_QRNN(AWD_LSTM):\n",
  "    \"Same as an AWD-LSTM, but using QRNNs instead of LSTMs\"\n",
  "    def _one_rnn(self, n_in, n_out, bidir, weight_p, l):\n",
  "        #Imported here, presumably so the QRNN module is only loaded when a QRNN is actually built.\n",
  "        from local.text.models.qrnn import QRNN\n",
  "        window = 2 if l == 0 else 1\n",
  "        rnn = QRNN(n_in, n_out, 1, save_prev_x=True, zoneout=0, window=window, output_gate=True, bidirectional=bidir)\n",
  "        first_layer = rnn.layers[0]\n",
  "        first_layer.linear = WeightDropout(first_layer.linear, weight_p, layer_names='weight')\n",
  "        return rnn\n",
  "\n",
  "    def _one_hidden(self, l):\n",
  "        \"Return one hidden state\"\n",
  "        is_last = l == self.n_layers - 1\n",
  "        nh = (self.emb_sz if is_last else self.n_hid) // self.n_dir\n",
  "        #Unlike `AWD_LSTM`, a single tensor (no separate cell state).\n",
  "        return one_param(self).new_zeros(self.n_dir, self.bs, nh)" ] },
{
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "#Default hyper-parameters for an `AWD_QRNN` language model\n",
  "awd_qrnn_lm_config = dict(emb_sz=400, n_hid=1552, n_layers=4, pad_token=1, bidir=False, output_p=0.1,\n",
  "                          hidden_p=0.15, input_p=0.25, embed_p=0.02, weight_p=0.2, tie_weights=True, out_bias=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "#Default hyper-parameters for an `AWD_QRNN` classifier (no `packed` key, so `AWD_QRNN` uses its default `packed=False`)\n",
  "awd_qrnn_clas_config = dict(emb_sz=400, n_hid=1552, n_layers=4, pad_token=1, bidir=False, output_p=0.4,\n",
  "                            hidden_p=0.3, input_p=0.4, embed_p=0.05, weight_p=0.5)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [
  "## Export -" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "from local.notebook.export import notebook2script\n",
  "notebook2script(all_fs=True)" ] }
],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } },
"nbformat": 4, "nbformat_minor": 2 }