{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## NLP model creation and training" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.text import * \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main thing here is [`RNNLearner`](/text.learner.html#RNNLearner). There are also some utility functions to help create and update text models." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quickly get a learner" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
language_model_learner

(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`arch`**, **`config`**:`dict`=***`None`***, **`drop_mult`**:`float`=***`1.0`***, **`pretrained`**:`bool`=***`True`***, **`pretrained_fnames`**:`OptStrTuple`=***`None`***, **\*\*`learn_kwargs`**) → `LanguageLearner`

Create a [`Learner`](/basic_train.html#Learner) with a language model from `data` and `arch`. The model used is given by `arch` and `config`; it can be:

- an [`AWD_LSTM`](/text.models.html#AWD_LSTM) ([Merity et al.](https://arxiv.org/abs/1708.02182))
- a [`Transformer`](/text.models.html#Transformer) decoder ([Vaswani et al.](https://arxiv.org/abs/1706.03762))
- a [`TransformerXL`](/text.models.html#TransformerXL) ([Dai et al.](https://arxiv.org/abs/1901.02860))

Each of these has a default config for language modelling stored in `{lower_case_class_name}_lm_config` (e.g. `awd_lstm_lm_config`), which you can copy and edit if you want to change any of the default parameters. At this stage, only the AWD LSTM and Transformer support `pretrained=True`, but we hope to add more pretrained models soon. `drop_mult` is applied to all the dropout weights of the `config`, and `learn_kwargs` are passed to the [`Learner`](/basic_train.html#Learner) initialization.

If your [`data`](/text.data.html#text.data) is backward, the pretrained model downloaded will also be a backward one (only available for [`AWD_LSTM`](/text.models.awd_lstm.html#AWD_LSTM)).
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"text_classifier_learner
[source][test]text_classifier_learner
(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`arch`**:`Callable`, **`bptt`**:`int`=***`70`***, **`max_len`**:`int`=***`1400`***, **`config`**:`dict`=***`None`***, **`pretrained`**:`bool`=***`True`***, **`drop_mult`**:`float`=***`1.0`***, **`lin_ftrs`**:`Collection`\\[`int`\\]=***`None`***, **`ps`**:`Collection`\\[`float`\\]=***`None`***, **\\*\\*`learn_kwargs`**) → `TextClassifierLearner`\n",
"\n",
"class
RNNLearner
[source][test]RNNLearner
(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`split_func`**:`OptSplitFunc`=***`None`***, **`clip`**:`float`=***`None`***, **`alpha`**:`float`=***`2.0`***, **`beta`**:`float`=***`1.0`***, **`metrics`**=***`None`***, **\\*\\*`learn_kwargs`**) :: [`Learner`](/basic_train.html#Learner)\n",
"\n",
"No tests found for RNNLearner
. To contribute a test please refer to this guide and this discussion.
data
and a `model` with a text data using a certain `bptt`. The `split_func` is used to properly split the model in different groups for gradual unfreezing and differential learning rates. Gradient clipping of `clip` is optionally applied. `alpha` and `beta` are all passed to create an instance of [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer). Can be used for a language model or an RNN classifier. It also handles the conversion of weights from a pretrained model as well as saving or loading the encoder."
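The `alpha` and `beta` passed to [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer) weight the activation regularization (AR) and temporal activation regularization (TAR) penalties of Merity et al. A minimal sketch of the two terms, using plain floats in place of the real activation tensors:

```python
# Sketch of the two regularization terms that `alpha` and `beta` control
# (AR and TAR from Merity et al.), added to the loss by RNNTrainer.
# A list of per-timestep floats stands in for one unit's activations.

def ar_tar_penalty(activations, alpha=2.0, beta=1.0):
    # AR: penalize large activations (mean of squared values).
    ar = alpha * sum(a * a for a in activations) / len(activations)
    # TAR: penalize large jumps between consecutive timesteps.
    diffs = [(b - a) ** 2 for a, b in zip(activations, activations[1:])]
    tar = beta * sum(diffs) / len(diffs)
    return ar + tar

penalty = ar_tar_penalty([1.0, 1.0, 3.0], alpha=2.0, beta=1.0)
```

Setting `alpha=0.0` or `beta=0.0` when creating the learner disables the corresponding penalty.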
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"get_preds
[source][test]get_preds
(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`No tests found for get_preds
. To contribute a test please refer to this guide and this discussion.
class
TextClassificationInterpretation
[source][test]TextClassificationInterpretation
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`No tests found for TextClassificationInterpretation
. To contribute a test please refer to this guide and this discussion.
| Text | Prediction | Actual | Loss | Probability |
|---|---|---|---|---|
| xxbos i have to agree with what many of the other reviewers concluded . a subject which could have been thought - provoking and shed light on a reversed double - standard , failed miserably . \n \n xxmaj rape being a crime of violence and forced abusive control , the scenes here were for the most part pathetic . xxmaj it would have been a better idea to | pos | neg | 8.25 | 0.00 |
| xxbos xxmaj betty xxmaj sizemore ( xxmaj renee xxmaj zellweger ) lives her life through soap xxmaj opera " a xxmaj reason to xxmaj love " as a way to escape her slob husband and dull life . xxmaj after a shocking incident involving two hit men ( xxmaj morgan xxmaj freeman and xxmaj chris xxmaj rock ) , xxmaj betty goes into shock and travels to xxup la , | pos | pos | 7.71 | 1.00 |
| xxbos xxmaj when people harp on about how " they do n't make 'em like they used to " then just point them towards this fantastically entertaining , and quaint - looking , comedy horror from writer - director xxmaj glenn mcquaid . \n \n xxmaj it 's a tale of graverobbers ( played by xxmaj dominic xxmaj monaghan and xxmaj larry xxmaj fessenden ) who end up digging | pos | pos | 7.47 | 1.00 |
| xxbos i have to agree with all the previous xxunk -- this is simply the best of all frothy comedies , with xxmaj bardot as sexy as xxmaj marilyn xxmaj monroe ever was , and definitely with a prettier face ( maybe there 's less mystique , but look how xxmaj marilyn paid for that . ) i do n't think i 've ever seen such a succulent - looking | pos | pos | 6.55 | 1.00 |
| xxbos i will freely admit that i have n't seen the original movie , but i 've read the play , so i 've some background with the " original . " xxmaj if you shuck off the fact that this is a remake of an old classic , this movie is smart , witty , fresh , and hilarious . xxmaj yes , the casting decisions may seem strange | pos | pos | 6.38 | 1.00 |
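The selection behind this kind of top-losses display amounts to ranking examples by their loss, highest first, and keeping the top `k`. A sketch with made-up texts and losses:

```python
# Sketch of the ranking behind a `show_top_losses`-style report: sort
# the examples by loss, descending, and keep the k worst. The texts and
# losses below are toy data, not real model output.

def top_losses(texts, losses, k):
    ranked = sorted(zip(losses, texts), key=lambda pair: pair[0], reverse=True)
    return ranked[:k]

texts = ['great movie', 'dull plot', 'not bad at all']
losses = [0.1, 2.3, 8.2]
worst = top_losses(texts, losses, 2)
```

The highest-loss rows are exactly the ones most worth inspecting: confidently wrong predictions (high loss, wrong label) and pathological inputs.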
load_encoder

(**`name`**:`str`, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***)

Load the encoder `name` from the model directory.
save_encoder

(**`name`**:`str`)

Save the encoder to `name` inside the model directory.
load_pretrained

(**`wgts_fname`**:`str`, **`itos_fname`**:`str`, **`strict`**:`bool`=***`True`***)

Load the weights in `wgts_fname` and the dictionary in `itos_fname`, then adapt the pretrained weights to the vocabulary of the `data`. The two files should be in the models directory of the `learner.path`.
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Utility functions"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"convert_weights
[source][test]convert_weights
(**`wgts`**:`Weights`, **`stoi_wgts`**:`Dict`\\[`str`, `int`\\], **`itos_new`**:`StrList`) → `Weights`\n",
"\n",
"No tests found for convert_weights
. To contribute a test please refer to this guide and this discussion.
class LanguageLearner

(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`split_func`**:`OptSplitFunc`=***`None`***, **`clip`**:`float`=***`None`***, **`alpha`**:`float`=***`2.0`***, **`beta`**:`float`=***`1.0`***, **`metrics`**=***`None`***, **\*\*`learn_kwargs`**) :: [`RNNLearner`](/text.learner.html#RNNLearner)

Subclass of [`RNNLearner`](/text.learner.html#RNNLearner) for predictions.
predict

(**`text`**:`str`, **`n_words`**:`int`=***`1`***, **`no_unk`**:`bool`=***`True`***, **`temperature`**:`float`=***`1.0`***, **`min_p`**:`float`=***`None`***, **`sep`**:`str`=***`' '`***, **`decoder`**=***`'decode_spec_tokens'`***)

Return `text` followed by the `n_words` predicted to come after it. If `no_unk=True`, the unknown token is never sampled. Predictions are sampled after scaling the logits by `temperature`; if `min_p` is given, tokens whose probability falls below it are excluded. `sep` joins the predicted tokens and `decoder` post-processes the special tokens.
beam_search

(**`text`**:`str`, **`n_words`**:`int`, **`no_unk`**:`bool`=***`True`***, **`top_k`**:`int`=***`10`***, **`beam_sz`**:`int`=***`1000`***, **`temperature`**:`float`=***`1.0`***, **`sep`**:`str`=***`' '`***, **`decoder`**=***`'decode_spec_tokens'`***)

Return the `n_words` that come after `text` using beam search: at each step, only the `top_k` most likely next tokens are considered, and the `beam_sz` best-scoring hypotheses are kept.
get_language_model

(**`arch`**:`Callable`, **`vocab_sz`**:`int`, **`config`**:`dict`=***`None`***, **`drop_mult`**:`float`=***`1.0`***)

Create a language model from `arch` and its `config`, with a vocabulary of size `vocab_sz`.
get_text_classifier

(**`arch`**:`Callable`, **`vocab_sz`**:`int`, **`n_class`**:`int`, **`bptt`**:`int`=***`70`***, **`max_len`**:`int`=***`1400`***, **`config`**:`dict`=***`None`***, **`drop_mult`**:`float`=***`1.0`***, **`lin_ftrs`**:`Collection`\[`int`\]=***`None`***, **`ps`**:`Collection`\[`float`\]=***`None`***, **`pad_idx`**:`int`=***`1`***) → [`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)

Create a text classifier from `arch` and its `config`: an encoder (wrapped in a [`MultiBatchEncoder`](/text.learner.html#MultiBatchEncoder)) followed by a classifier head with hidden layers of sizes `lin_ftrs` and dropout probabilities `ps`.
forward

(**`input`**:`LongTensor`) → `Tuple`\[`List`\[`Tensor`\], `List`\[`Tensor`\], `Tensor`\]

Apply the underlying module to `input` one `bptt` chunk at a time and concatenate the results.
show_results

(**`ds_type`**=***`<DatasetType.Valid>`***, **`rows`**:`int`=***`5`***, **`max_len`**:`int`=***`20`***)

Show `rows` results of predictions on the `ds_type` dataset, truncating texts at `max_len` tokens.
concat

(**`arrs`**:`Sequence`\[`Sequence`\[`Tensor`\]\]) → `List`\[`Tensor`\]

Concatenate the `arrs` along the batch dimension.
class MultiBatchEncoder

(**`bptt`**:`int`, **`max_len`**:`int`, **`module`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`pad_idx`**:`int`=***`1`***) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`Module`](/torch_core.html#Module)

Create an encoder over `module` that can process a full sentence: the text is read `bptt` tokens at a time, and only the last `max_len` activations are kept for the classifier head.
decode_spec_tokens

(**`tokens`**)

Decode the special tokens in `tokens` (such as `xxmaj` or `xxup`) back into regular text.
reset

()

Reset the hidden states of the model.