{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.basics import *\n", "from fastai.text.core import *\n", "from fastai.text.data import *\n", "from fastai.text.models.core import *\n", "from fastai.text.models.awdlstm import *\n", "from fastai.callback.rnn import *\n", "from fastai.callback.progress import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp text.learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Learner for the text application\n", "\n", "> All the functions necessary to build `Learner` suitable for transfer learning in NLP" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The most important functions of this module are `language_model_learner` and `text_classifier_learner`. They will help you define a `Learner` using a pretrained model. See the [text tutorial](http://docs.fast.ai/tutorial.text) for examples of use." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading a pretrained model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In text, to load a pretrained model, we need to adapt the embeddings of the vocabulary used for the pre-training to the vocabulary of our current corpus." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def match_embeds(old_wgts, old_vocab, new_vocab):\n", " \"Convert the embedding in `old_wgts` to go from `old_vocab` to `new_vocab`.\"\n", " bias, wgts = old_wgts.get('1.decoder.bias', None), old_wgts['0.encoder.weight']\n", " wgts_m = wgts.mean(0)\n", " new_wgts = wgts.new_zeros((len(new_vocab),wgts.size(1)))\n", " if bias is not None:\n", " bias_m = bias.mean(0)\n", " new_bias = bias.new_zeros((len(new_vocab),))\n", " old_o2i = old_vocab.o2i if hasattr(old_vocab, 'o2i') else {w:i for i,w in enumerate(old_vocab)}\n", " for i,w in enumerate(new_vocab):\n", " idx = old_o2i.get(w, -1)\n", " new_wgts[i] = wgts[idx] if idx>=0 else wgts_m\n", " if bias is not None: new_bias[i] = bias[idx] if idx>=0 else bias_m\n", " old_wgts['0.encoder.weight'] = new_wgts\n", " if '0.encoder_dp.emb.weight' in old_wgts: old_wgts['0.encoder_dp.emb.weight'] = new_wgts.clone()\n", " old_wgts['1.decoder.weight'] = new_wgts.clone()\n", " if bias is not None: old_wgts['1.decoder.bias'] = new_bias\n", " return old_wgts" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For words in `new_vocab` that don't have a corresponding match in `old_vocab`, we use the mean of all pretrained embeddings. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "wgts = {'0.encoder.weight': torch.randn(5,3)}\n", "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n", "old,new = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n", "test_eq(new[0], old[0])\n", "test_eq(new[1], old[2])\n", "test_eq(new[2], old.mean(0))\n", "test_eq(new[3], old[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#With bias\n", "wgts = {'0.encoder.weight': torch.randn(5,3), '1.decoder.bias': torch.randn(5)}\n", "new_wgts = match_embeds(wgts.copy(), ['a', 'b', 'c'], ['a', 'c', 'd', 'b'])\n", "old_w,new_w = wgts['0.encoder.weight'],new_wgts['0.encoder.weight']\n", "old_b,new_b = wgts['1.decoder.bias'], new_wgts['1.decoder.bias']\n", "test_eq(new_w[0], old_w[0])\n", "test_eq(new_w[1], old_w[2])\n", "test_eq(new_w[2], old_w.mean(0))\n", "test_eq(new_w[3], old_w[1])\n", "test_eq(new_b[0], old_b[0])\n", "test_eq(new_b[1], old_b[2])\n", "test_eq(new_b[2], old_b.mean(0))\n", "test_eq(new_b[3], old_b[1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _get_text_vocab(dls):\n", " vocab = dls.vocab\n", " if isinstance(vocab, L): vocab = vocab[0]\n", " return vocab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def load_ignore_keys(model, wgts):\n", " \"Load `wgts` in `model` ignoring the names of the keys, just taking parameters in order\"\n", " sd = model.state_dict()\n", " for k1,k2 in zip(sd.keys(), wgts.keys()): sd[k1].data = wgts[k2].data.clone()\n", " return model.load_state_dict(sd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _rm_module(n):\n", " t = n.split('.')\n", " for i in range(len(t)-1, -1, -1):\n", " if t[i] == 'module':\n", " t.pop(i)\n", " break\n", " return '.'.join(t)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "#For previous versions compatibility, remove for release\n", "def clean_raw_keys(wgts):\n", " keys = list(wgts.keys())\n", " for k in keys:\n", " t = k.split('.module')\n", " if f'{_rm_module(k)}_raw' in keys: del wgts[k]\n", " return wgts" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "#For previous versions compatibility, remove for release\n", "def load_model_text(file, model, opt, with_opt=None, device=None, strict=True):\n", " \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n", " distrib_barrier()\n", " if isinstance(device, int): device = torch.device('cuda', device)\n", " elif device is None: device = 'cpu'\n", " state = torch.load(file, map_location=device)\n", " hasopt = set(state)=={'model', 'opt'}\n", " model_state = state['model'] if hasopt else state\n", " get_model(model).load_state_dict(clean_raw_keys(model_state), strict=strict)\n", " if hasopt and ifnone(with_opt,True):\n", " try: opt.load_state_dict(state['opt'])\n", " except:\n", " if with_opt: warn(\"Could not load the optimizer state.\")\n", " elif with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(Learner.__init__)\n", "class TextLearner(Learner):\n", " \"Basic class for a `Learner` in NLP.\"\n", " def 
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(Learner.__init__)\n", "class TextLearner(Learner):\n", "    \"Basic class for a `Learner` in NLP.\"\n", "    def __init__(self, dls, model, alpha=2., beta=1., moms=(0.8,0.7,0.8), **kwargs):\n", "        super().__init__(dls, model, moms=moms, **kwargs)\n", "        self.add_cbs(rnn_cbs(alpha=alpha, beta=beta))\n", "\n", "    def save_encoder(self, file):\n", "        \"Save the encoder to `file` in the model directory\"\n", "        if rank_distrib(): return # don't save if child proc in distributed training\n", "        encoder = get_model(self.model)[0]\n", "        if hasattr(encoder, 'module'): encoder = encoder.module\n", "        torch.save(encoder.state_dict(), join_path_file(file, self.path/self.model_dir, ext='.pth'))\n", "\n", "    def load_encoder(self, file, device=None):\n", "        \"Load the encoder `file` from the model directory, optionally ensuring it's on `device`\"\n", "        encoder = get_model(self.model)[0]\n", "        if device is None: device = self.dls.device\n", "        if hasattr(encoder, 'module'): encoder = encoder.module\n", "        distrib_barrier()\n", "        wgts = torch.load(join_path_file(file,self.path/self.model_dir, ext='.pth'), map_location=device)\n", "        encoder.load_state_dict(clean_raw_keys(wgts))\n", "        self.freeze()\n", "        return self\n", "\n", "    def load_pretrained(self, wgts_fname, vocab_fname, model=None):\n", "        \"Load a pretrained model and adapt it to the data vocabulary.\"\n", "        old_vocab = load_pickle(vocab_fname)\n", "        new_vocab = _get_text_vocab(self.dls)\n", "        distrib_barrier()\n", "        wgts = torch.load(wgts_fname, map_location=lambda storage,loc: storage)\n", "        if 'model' in wgts: wgts = wgts['model'] #Just in case the pretrained model was saved with an optimizer\n", "        wgts = match_embeds(wgts, old_vocab, new_vocab)\n", "        load_ignore_keys(self.model if model is None else model, clean_raw_keys(wgts))\n", "        self.freeze()\n", "        return self\n", "\n", "    #For previous versions compatibility. Remove at release\n", "    @delegates(load_model_text)\n", "    def load(self, file, with_opt=None, device=None, **kwargs):\n", "        if device is None: device = self.dls.device\n", "        if self.opt is None: self.create_opt()\n", "        file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n", "        load_model_text(file, self.model, self.opt, device=device, **kwargs)\n", "        return self" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adds a `ModelResetter` and an `RNNRegularizer` (with `alpha` and `beta`) to the callbacks; the rest is the same as the `Learner` init.\n", "\n", "This `Learner` adds the following functionality to the base class:" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "#### `TextLearner.load_pretrained` [source]\n", "\n", "> `TextLearner.load_pretrained`(**`wgts_fname`**, **`vocab_fname`**, **`model`**=*`None`*)\n", "\n", "Load a pretrained model and adapt it to the data vocabulary." ], "text/plain": [ "<IPython.core.display.Markdown object>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TextLearner.load_pretrained)" ] },
"TextLearner.save_encoder
[source]TextLearner.save_encoder
(**`file`**)\n",
"\n",
"Save the encoder to `file` in the model directory"
],
"text/plain": [
"TextLearner.load_encoder
[source]TextLearner.load_encoder
(**`file`**, **`device`**=*`None`*)\n",
"\n",
"Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
],
"text/plain": [
"class
LMLearner
[source]LMLearner
(**`dls`**, **`model`**, **`alpha`**=*`2.0`*, **`beta`**=*`1.0`*, **`moms`**=*`(0.8, 0.7, 0.8)`*, **`loss_func`**=*`None`*, **`opt_func`**=*`Adam`*, **`lr`**=*`0.001`*, **`splitter`**=*`trainable_params`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*) :: [`TextLearner`](/text.learner.html#TextLearner)\n",
"\n",
"Add functionality to [`TextLearner`](/text.learner.html#TextLearner) when dealing with a language model"
],
"text/plain": [
"LMLearner.predict
[source]LMLearner.predict
(**`text`**, **`n_words`**=*`1`*, **`no_unk`**=*`True`*, **`temperature`**=*`1.0`*, **`min_p`**=*`None`*, **`no_bar`**=*`False`*, **`decoder`**=*`decode_spec_tokens`*, **`only_last_word`**=*`False`*)\n",
"\n",
"Return `text` and the `n_words` that come after"
],
"text/plain": [
"