{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.test import *\n", "from fastai2.data.all import *\n", "from fastai2.text.core import *\n", "from fastai2.text.models.awdlstm import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp text.models.core\n", "#default_cls_lvl 3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Core text modules\n", "\n", "> Contain the modules common between different architectures and the generic functions to get models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "_model_meta = {AWD_LSTM: {'hid_name':'emb_sz', 'url':URLs.WT103_FWD, 'url_bwd':URLs.WT103_BWD,\n", " 'config_lm':awd_lstm_lm_config, 'split_lm': awd_lstm_lm_split,\n", " 'config_clas':awd_lstm_clas_config, 'split_clas': awd_lstm_clas_split},\n", " AWD_QRNN: {'hid_name':'emb_sz',\n", " 'config_lm':awd_qrnn_lm_config, 'split_lm': awd_lstm_lm_split,\n", " 'config_clas':awd_qrnn_clas_config, 'split_clas': awd_lstm_clas_split},}\n", " # Transformer: {'hid_name':'d_model', 'url':URLs.OPENAI_TRANSFORMER,\n", " # 'config_lm':tfmer_lm_config, 'split_lm': tfmer_lm_split,\n", " # 'config_clas':tfmer_clas_config, 'split_clas': tfmer_clas_split},\n", " # TransformerXL: {'hid_name':'d_model',\n", " # 'config_lm':tfmerXL_lm_config, 'split_lm': tfmerXL_lm_split,\n", " # 'config_clas':tfmerXL_clas_config, 'split_clas': tfmerXL_clas_split}}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Language models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LinearDecoder(Module):\n", " \"To go on top of a RNNCore module and create a Language Model.\"\n", " initrange=0.1\n", "\n", " def __init__(self, n_out, n_hid, output_p=0.1, tie_encoder=None, bias=True):\n", " self.decoder = nn.Linear(n_hid, n_out, bias=bias)\n", " self.decoder.weight.data.uniform_(-self.initrange, self.initrange)\n", " self.output_dp = RNNDropout(output_p)\n", " if bias: self.decoder.bias.data.zero_()\n", " if tie_encoder: self.decoder.weight = tie_encoder.weight\n", "\n", " def forward(self, input):\n", " raw_outputs, outputs = input\n", " decoded = self.decoder(self.output_dp(outputs[-1]))\n", " return decoded, raw_outputs, outputs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai2.text.models.awdlstm import *\n", "enc = AWD_LSTM(100, 20, 10, 2)\n", "x = torch.randint(0, 100, (10,5))\n", "r = enc(x)\n", "\n", "tst = LinearDecoder(100, 20, 0.1)\n", "y = tst(r)\n", "test_eq(y[1], r[0])\n", "test_eq(y[2], r[1])\n", "test_eq(y[0].shape, [10, 5, 100])\n", "\n", "tst = LinearDecoder(100, 20, 0.1, tie_encoder=enc.encoder)\n", "test_eq(tst.decoder.weight, enc.encoder.weight)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SequentialRNN(nn.Sequential):\n", " \"A sequential module that passes the reset call to its children.\"\n", " def reset(self):\n", " for c in self.children(): getattr(c, 'reset', noop)()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _TstMod(Module):\n", " def reset(self): print('reset')\n", "\n", "tst = SequentialRNN(_TstMod(), 
_TstMod())\n", "test_stdout(tst.reset, 'reset\\nreset')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_language_model(arch, vocab_sz, config=None, drop_mult=1.):\n", " \"Create a language model from `arch` and its `config`.\"\n", " meta = _model_meta[arch]\n", " config = ifnone(config, meta['config_lm']).copy()\n", " for k in config.keys():\n", " if k.endswith('_p'): config[k] *= drop_mult\n", " tie_weights,output_p,out_bias = map(config.pop, ['tie_weights', 'output_p', 'out_bias'])\n", " init = config.pop('init') if 'init' in config else None\n", " encoder = arch(vocab_sz, **config)\n", " enc = encoder.encoder if tie_weights else None\n", " decoder = LinearDecoder(vocab_sz, config[meta['hid_name']], output_p, tie_encoder=enc, bias=out_bias)\n", " model = SequentialRNN(encoder, decoder)\n", " return model if init is None else model.apply(init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default `config` used can be found in `_model_meta[arch]['config_lm']`. `drop_mult` is applied to all the probabilities of dropout in that config." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_lm_config.copy()\n", "config.update({'n_hid':10, 'emb_sz':20})\n", "\n", "tst = get_language_model(AWD_LSTM, 100, config=config)\n", "x = torch.randint(0, 100, (10,5))\n", "y = tst(x)\n", "test_eq(y[0].shape, [10, 5, 100])\n", "test_eq(tst[1].decoder.weight, tst[0].encoder.weight)\n", "for i in range(1,3): test_eq([h_.shape for h_ in y[1]], [[10, 5, 10], [10, 5, 10], [10, 5, 20]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test drop_mult\n", "tst = get_language_model(AWD_LSTM, 100, config=config, drop_mult=0.5)\n", "test_eq(tst[1].output_dp.p, config['output_p']*0.5)\n", "for rnn in tst[0].rnns: test_eq(rnn.weight_p, config['weight_p']*0.5)\n", "for dp in tst[0].hidden_dps: test_eq(dp.p, config['hidden_p']*0.5)\n", "test_eq(tst[0].encoder_dp.embed_p, config['embed_p']*0.5)\n", "test_eq(tst[0].input_dp.p, config['input_p']*0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classification models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _pad_tensor(t, bs, val=0.):\n", " if t.size(0) < bs: return torch.cat([t, val + t.new_zeros(bs-t.size(0), *t.shape[1:])])\n", " return t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SentenceEncoder(Module):\n", " \"Create an encoder over `module` that can process a full sentence.\"\n", " def __init__(self, bptt, module, pad_idx=1): store_attr(self, 'bptt,module,pad_idx')\n", "\n", " def _concat(self, arrs, bs):\n", " return [torch.cat([_pad_tensor(l[si],bs) for l in arrs], dim=1) for si in range(len(arrs[0]))]\n", "\n", " def reset(self): getattr(self.module, 'reset', noop)()\n", "\n", " def forward(self, input):\n", " bs,sl = input.size()\n", " self.reset()\n", " raw_outputs,outputs,masks = [],[],[]\n", " for i in range(0, sl, self.bptt):\n", " r,o = self.module(input[:,i: min(i+self.bptt, sl)])\n", " masks.append(input[:,i: min(i+self.bptt, sl)] == self.pad_idx)\n", " raw_outputs.append(r)\n", " outputs.append(o)\n", " return self._concat(raw_outputs, bs),self._concat(outputs, bs),torch.cat(masks,dim=1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class 
DoubleEmbedding(nn.Embedding):\n", " def forward(self, x): \n", " y = super().forward(x)\n", " return ([y],[y+1])\n", " \n", "mod = DoubleEmbedding(5, 10,)\n", "tst = SentenceEncoder(5, mod, pad_idx=0)\n", "x = torch.randint(1, 5, (3, 15))\n", "x[2,10:]=0\n", "raw,out,mask = tst(x) \n", "test_eq(raw[0], mod(x)[0][0])\n", "test_eq(out[0], mod(x)[0][0]+1)\n", "test_eq(mask, x==0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class PoolingLinearClassifier(nn.Module):\n", " \"Create a linear classifier with pooling.\"\n", "\n", " def __init__(self, layers, drops):\n", " super().__init__()\n", " mod_layers = []\n", " activs = [nn.ReLU(inplace=True)] * (len(layers) - 2) + [None]\n", " for n_in, n_out, p, actn in zip(layers[:-1], layers[1:], drops, activs):\n", " mod_layers += bn_drop_lin(n_in, n_out, p=p, actn=actn)\n", " self.layers = nn.Sequential(*mod_layers)\n", "\n", " def forward(self, input):\n", " raw_outputs,outputs,mask = input\n", " output = outputs[-1]\n", " lengths = output.size(1) - mask.long().sum(dim=1)\n", " avg_pool = output.masked_fill(mask[:,:,None], 0).sum(dim=1)\n", " avg_pool.div_(lengths.type(avg_pool.dtype)[:,None])\n", " max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]\n", " x = torch.cat([output[torch.arange(0, output.size(0)),lengths-1], max_pool, avg_pool], 1) #Concat pooling.\n", " x = self.layers(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def masked_concat_pool(outputs, mask):\n", " \"Pool `SentenceEncoder` outputs into one vector [last_hidden, max_pool, avg_pool]\"\n", " output = outputs[-1]\n", " lens = output.size(1) - mask.long().sum(dim=1)\n", " avg_pool = output.masked_fill(mask[:, :, None], 0).sum(dim=1)\n", " avg_pool.div_(lens.type(avg_pool.dtype)[:,None])\n", " max_pool = output.masked_fill(mask[:,:,None], -float('inf')).max(dim=1)[0]\n", " x = torch.cat([output[torch.arange(0, output.size(0)),lens-1], max_pool, avg_pool], 1) #Concat pooling.\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "out = torch.randn(2,3,5)\n", "mask = tensor([[False,False,True], [False,False,False]])\n", "x = masked_concat_pool([out], mask)\n", "test_close(x[0,:5], out[0,-2])\n", "test_close(x[1,:5], out[1,-1])\n", "test_close(x[0,5:10], out[0,:2].max(dim=0)[0])\n", "test_close(x[1,5:10], out[1].max(dim=0)[0])\n", "test_close(x[0,10:], out[0,:2].mean(dim=0))\n", "test_close(x[1,10:], out[1].mean(dim=0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test the result is independent of padding\n", "out1 = torch.randn(2,4,5)\n", "out1[:,:-1] = out.clone()\n", "mask1 = tensor([[False,False,True,True], [False,False,False,True]])\n", "x1 = masked_concat_pool([out1], mask1)\n", "test_eq(x, x1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class PoolingLinearClassifier(Module):\n", " \"Create a linear classifier with pooling\"\n", " def __init__(self, dims, ps):\n", " mod_layers = []\n", " if len(ps) != len(dims)-1: raise ValueError(\"Number of layers and dropout values do not match.\")\n", " acts = [nn.ReLU(inplace=True)] * (len(dims) - 2) + [None]\n", " layers = [LinBnDrop(i, o, p=p, act=a) for i,o,p,a in zip(dims[:-1], dims[1:], ps, acts)]\n", " self.layers = nn.Sequential(*layers)\n", "\n", " def forward(self, input):\n", " raw,out,mask = 
input\n", " x = masked_concat_pool(out, mask)\n", " x = self.layers(x)\n", " return x, raw, out" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mod = DoubleEmbedding(5, 10)\n", "tst = nn.Sequential(SentenceEncoder(5, mod, pad_idx=0), PoolingLinearClassifier([10*3,4], [0.]))\n", "\n", "x = torch.randint(1, 5, (3, 14))\n", "x[2,10:] = 0\n", "res,raw,out = tst(x) \n", "test_eq(raw[0], mod(x)[0][0])\n", "test_eq(out[0], mod(x)[0][0]+1)\n", "test_eq(res.shape, [3,4])\n", "\n", "x1 = torch.cat([x, tensor([0,0,0])[:,None]], dim=1)\n", "res1,raw1,out1 = tst(x1) \n", "test_eq(res, res1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_text_classifier(arch, vocab_sz, n_class, bptt=72, config=None, drop_mult=1., lin_ftrs=None,\n", " ps=None, pad_idx=1):\n", " \"Create a text classifier from `arch` and its `config`, maybe `pretrained`\"\n", " meta = _model_meta[arch]\n", " config = ifnone(config, meta['config_clas']).copy()\n", " for k in config.keys():\n", " if k.endswith('_p'): config[k] *= drop_mult\n", " if lin_ftrs is None: lin_ftrs = [50]\n", " if ps is None: ps = [0.1]*len(lin_ftrs)\n", " layers = [config[meta['hid_name']] * 3] + lin_ftrs + [n_class]\n", " ps = [config.pop('output_p')] + ps\n", " init = config.pop('init') if 'init' in config else None\n", " encoder = SentenceEncoder(bptt, arch(vocab_sz, **config), pad_idx=pad_idx)\n", " model = SequentialRNN(encoder, PoolingLinearClassifier(layers, ps))\n", " return model if init is None else model.apply(init)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "config = awd_lstm_clas_config.copy()\n", "config.update({'n_hid':10, 'emb_sz':20})\n", "\n", "tst = get_text_classifier(AWD_LSTM, 100, 3, config=config)\n", "x = torch.randint(2, 100, (10,5))\n", "y = tst(x)\n", "test_eq(y[0].shape, [10, 3])\n", "for i in range(1,3): test_eq([h_.shape for h_ in y[1]], [[10, 5, 10], [10, 5, 10], [10, 5, 20]])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test padding gives same results\n", "tst.eval()\n", "y = tst(x)\n", "x1 = torch.cat([x, tensor([2,1,1,1,1,1,1,1,1,1])[:,None]], dim=1)\n", "y1 = tst(x1)\n", "test_close(y[0][1:],y1[0][1:])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test drop_mult\n", "tst = get_text_classifier(AWD_LSTM, 100, 3, config=config, drop_mult=0.5)\n", "test_eq(tst[1].layers[1][2].p, 0.1)\n", "test_eq(tst[1].layers[0][3].p, config['output_p']*0.5)\n", "for rnn in tst[0].module.rnns: test_eq(rnn.weight_p, config['weight_p']*0.5)\n", "for dp in tst[0].module.hidden_dps: test_eq(dp.p, config['hidden_p']*0.5)\n", "test_eq(tst[0].module.encoder_dp.embed_p, config['embed_p']*0.5)\n", "test_eq(tst[0].module.input_dp.p, config['input_p']*0.5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core_foundation.ipynb.\n", "Converted 01a_core_utils.ipynb.\n", "Converted 01b_core_dispatch.ipynb.\n", "Converted 01c_core_transform.ipynb.\n", "Converted 02_core_script.ipynb.\n", "Converted 03_torchcore.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_data_load.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 
07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 71_callback_tensorboard.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }