{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. This is where the basic training loop is defined for the [`fit`](/basic_train.html#fit) function. The [`Learner`](/basic_train.html#Learner) object is the entry point for most of the [`Callback`](/callback.html#Callback) functions that customize this training loop in different ways (these are made available through the [`train`](/train.html#train) module), notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that will help you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will launch training using the 1cycle policy, to help you train your model fast.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision and help you train in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Learner[source]

\n", "\n", "> Learner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `opt_func`:`Callable`=`'Adam'`, `loss_func`:`Callable`=`None`, `metrics`:`Collection`\\[`Callable`\\]=`None`, `true_wd`:`bool`=`True`, `bn_wd`:`bool`=`True`, `wd`:`Floats`=`0.01`, `train_bn`:`bool`=`True`, `path`:`str`=`None`, `model_dir`:`str`=`'models'`, `callback_fns`:`Collection`\\[`Callable`\\]=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=``, `layer_groups`:`ModuleList`=`None`)\n", "\n", "Train `model` using `data` to minimize `loss_func` with optimizer `opt_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main purpose of [`Learner`](/basic_train.html#Learner) is to train `model` using [`Learner.fit`](/basic_train.html#Learner.fit). After every epoch, all *metrics* will be printed, and will also be available to callbacks.\n", "\n", "The default weight decay will be `wd`, which will be handled using the method from [Fixing Weight Decay Regularization in Adam](https://arxiv.org/abs/1711.05101) if `true_wd` is set (otherwise it's L2 regularization). If `bn_wd` is False then weight decay will be removed from batchnorm layers, as recommended in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). 
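\n", "\n", "The difference between the two modes is easiest to see on a single scalar weight. The following is a minimal plain-Python sketch, not fastai's implementation, and `adam_step` is a hypothetical helper. With L2 regularization, `wd*w` is added to the gradient and then rescaled by Adam along with everything else. With decoupled weight decay (the method of the paper linked above), the weight is instead shrunk directly by a factor `1 - lr*wd`.\n", "\n", "
```python
import math


def adam_step(w, grad, m, v, t, lr=1e-3, wd=1e-2, decoupled=True,
              beta1=0.9, beta2=0.999, eps=1e-8):
    # One Adam update of a scalar weight w at step t (t starts at 1).
    # Hypothetical helper for illustration only, not a fastai function.
    if decoupled:
        # Decoupled weight decay: shrink the weight directly,
        # outside of Adam's per-parameter rescaling.
        w = w * (1 - lr * wd)
    else:
        # L2 regularization: add wd*w to the gradient, so Adam
        # rescales the decay term like the rest of the gradient.
        grad = grad + wd * w
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v


# Zero loss gradient, so any change to w comes from weight decay alone.
w_dec, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1, decoupled=True)
w_l2, _, _ = adam_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, wd=0.1, decoupled=False)
print(w_dec, w_l2)
```
\n", "\n", "With a zero loss gradient, `lr=0.1` and `wd=0.1`, the decoupled step shrinks the unit weight to about `0.99`, whereas the L2 step moves it to about `0.9`. Adam normalizes the small decay gradient into a step of roughly `lr`, so the effective strength of L2 regularization no longer tracks `wd`.\n", "\n", "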
Enabling `train_bn` ensures that the learnable parameters of batchnorm layers are trained even in frozen layer groups.\n", "\n", "To use [discriminative layer training](#Discriminative-layer-training), pass an [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) in `layer_groups` for each layer group to be optimized with different settings.\n", "\n", "Any model files created will be saved in `path`/`model_dir`.\n", "\n", "You can pass a list of [`callbacks`](/callbacks.html#callbacks) that you have already created, or (more commonly) simply pass a list of callback functions to `callback_fns`. Each function will be called (passing `self`) on object initialization, and the results are stored as callback objects. For a walk-through, see the [training overview](/training.html) page. You may also want to use an `application` to fit your model, e.g. using the [`create_cnn`](/vision.learner.html#create_cnn) method:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/html": [ "Total time: 00:09

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "\n", " \n", "
epoch  train_loss  valid_loss  accuracy
1      0.142597    0.085823    0.968106
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = create_cnn(data, models.resnet18, metrics=accuracy)\n", "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model fitting methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit[source]

\n", "\n", "> fit(`epochs`:`int`, `lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `wd`:`Floats`=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`None`)\n", "\n", "Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.fit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses [discriminative layer training](#Discriminative-layer-training) if multiple learning rates or weight decay values are passed. To control training behaviour, use the [`callback`](/callback.html#callback) system or one or more of the pre-defined [`callbacks`](/callbacks.html#callbacks)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit_one_cycle[source]

\n", "\n", "> fit_one_cycle(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.fit_one_cycle)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses the [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) callback." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_find[source]

\n", "\n", "> lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Runs the learning rate finder defined in [`LRFinder`](/callbacks.lr_finder.html#LRFinder), as discussed in [Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### See results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_preds[source]

\n", "\n", "> get_preds(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `with_loss`:`bool`=`False`, `n_batch`:`Optional`\\[`int`\\]=`None`, `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`) → `List`\\[`Tensor`\\]\n", "\n", "Return predictions and targets on the valid, train, or test set, depending on `ds_type`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.get_preds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

validate[source]

\n", "\n", "> validate(`dl`=`None`, `callbacks`=`None`, `metrics`=`None`)\n", "\n", "Validate on `dl` with potential `callbacks` and `metrics`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

show_results[source]

\n", "\n", "> show_results(`ds_type`=``, `rows`:`int`=`5`, `kwargs`)\n", "\n", "Show `rows` results of predictions on the `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.show_results)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

predict[source]

\n", "\n", "> predict(`img`:[`ItemBase`](/core.html#ItemBase), `kwargs`)\n", "\n", "Return predicted class, label and probabilities for `img`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.predict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

pred_batch[source]

\n", "\n", "> pred_batch(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `batch`:`Tuple`=`None`) → `List`\\[`Tensor`\\]\n", "\n", "Return output of the model on one batch from valid, train, or test set, depending on `ds_type`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.pred_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test time augmentation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

TTA[source]

\n", "\n", "> TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `with_loss`:`bool`=`False`) → `Tensors`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.TTA, full_name = 'TTA')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Applies Test Time Augmentation to `learn` on the dataset `ds_type`. We take a weighted average of our regular predictions (with weight `beta`) and the average of the predictions obtained through augmented versions of that same dataset (with weight `1-beta`). The transforms defined for the training set are applied with a few changes: `scale` controls the scale for zoom (which isn't random), and the cropping isn't random but always takes the four corners of the image. Flipping isn't random either, but is applied once to each of those corner images (so that makes 8 augmented versions in total)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient clipping" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

clip_grad[source]

\n", "\n", "> clip_grad(`learn`:[`Learner`](/basic_train.html#Learner), `clip`:`float`=`0.1`) → [`Learner`](/basic_train.html#Learner)\n", "\n", "Gradient clipping during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.clip_grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mixed precision training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp16[source]

\n", "\n", "> to_fp16(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n", "\n", "Transform `learn` to FP16 precision. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses the [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), following the [NVIDIA recommendations](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) for ensuring speed and accuracy." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Discriminative layer training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When fitting a model, you can pass a list of learning rates (and/or weight decay amounts), which will apply a different value to each *layer group* (i.e. the parameters of each module in `self.layer_groups`). See the [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146) paper for details and experimental results in NLP (we also frequently use discriminative learning rates successfully in computer vision, but have not published a paper on this topic yet). When working with a [`Learner`](/basic_train.html#Learner) on which you've called `split`, you can set hyperparameters in four ways:\n", "\n", "1. `param = [val1, val2, ..., valn]` (n = number of layer groups)\n", "2. `param = val`\n", "3. `param = slice(start,end)`\n", "4. `param = slice(end)`\n", "\n", "If you choose way 1, you must specify exactly one value per layer group. If you choose way 2, the chosen value will be repeated for all layer groups. 
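\n", "\n", "These four ways can be sketched as a small plain-Python helper that expands a hyperparameter specification into one value per layer group. `expand_hyperparam` is hypothetical (it is not part of fastai). Its handling of the two `slice` forms follows the behaviour described for [`Learner.lr_range`](/basic_train.html#Learner.lr_range) further down this page.\n", "\n", "
```python
def expand_hyperparam(param, n_groups):
    # Hypothetical helper: expand a hyperparameter spec into one value per layer group.
    if isinstance(param, slice):
        end = param.stop
        if param.start is None:
            # Way 4, slice(end): the last group gets end, the others end/10.
            return [end / 10] * (n_groups - 1) + [end]
        if n_groups == 1:
            return [end]
        # Way 3, slice(start, end): geometric spacing from start to end.
        ratio = (end / param.start) ** (1 / (n_groups - 1))
        return [param.start * ratio ** i for i in range(n_groups)]
    if isinstance(param, (list, tuple)):
        # Way 1: exactly one value per layer group.
        if len(param) != n_groups:
            raise ValueError('expected one value per layer group')
        return list(param)
    # Way 2: a single value is repeated for every layer group.
    return [param] * n_groups


print(expand_hyperparam(1e-3, 3))
print(expand_hyperparam(slice(1e-5, 1e-3), 3))
print(expand_hyperparam(slice(1e-3), 3))
```
\n", "\n", "For three layer groups, `slice(1e-5, 1e-3)` expands to approximately `[1e-05, 1e-04, 1e-03]` and `slice(1e-3)` to approximately `[1e-04, 1e-04, 1e-03]`, in line with the `lr_range` output shown further down.\n", "\n", "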
See [`Learner.lr_range`](/basic_train.html#Learner.lr_range) for an explanation of the `slice` syntax (ways 3 and 4).\n", "\n", "Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call [`Learner.split`](/basic_train.html#Learner.split) in this case, since fastai uses this exact function as the default split for `resnet18`; this is just to show how to customize it):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# creates 3 layer groups\n", "learn.split(lambda m: (m[0][6], m[1]))\n", "# only randomly initialized head now trainable\n", "learn.freeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:08\n", "epoch train_loss valid_loss accuracy\n", "1 0.036884 0.023377 0.993621 (00:08)\n", "\n" ] } ], "source": [ "learn.fit_one_cycle(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:11\n", "epoch train_loss valid_loss accuracy\n", "1 0.025823 0.008318 0.997547 (00:11)\n", "\n" ] } ], "source": [ "# all layers now trainable\n", "learn.unfreeze()\n", "# optionally, separate LR and WD for each group\n", "learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_range[source]

\n", "\n", "> lr_range(`lr`:`Union`\\[`float`, `slice`\\]) → `ndarray`\n", "\n", "Build differential learning rates. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.lr_range)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rather than manually setting an LR for every group, it's often easier to use [`Learner.lr_range`](/basic_train.html#Learner.lr_range). This is a convenience method that returns one learning rate for each layer group. If you pass `slice(start,end)`, then the first group's learning rate is `start`, the last is `end`, and the remaining ones are geometrically spaced in between (each is a constant multiple of the previous one).\n", "\n", "If you pass just `slice(end)`, then the last group's learning rate is `end`, and all the other groups get `end/10`. For instance (for our learner that has 3 layer groups):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

unfreeze[source]

\n", "\n", "> unfreeze()\n", "\n", "Unfreeze entire model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.unfreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sets every layer group to *trainable* (i.e. `requires_grad=True`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

freeze[source]

\n", "\n", "> freeze()\n", "\n", "Freeze up to last layer. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.freeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sets every layer group except the last to *untrainable* (i.e. `requires_grad=False`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

freeze_to[source]

\n", "\n", "> freeze_to(`n`:`int`)\n", "\n", "Freeze layers up to layer `n`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.freeze_to)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

split[source]

\n", "\n", "> split(`split_on`:`SplitFuncOrIdxList`)\n", "\n", "Split the model at `split_on`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.split)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A convenience method that sets `layer_groups` based on the result of [`split_model`](/torch_core.html#split_model). If `split_on` is a function, it calls that function and passes the result to [`split_model`](/torch_core.html#split_model) (see above for example)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving and loading models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Simply call [`Learner.save`](/basic_train.html#Learner.save) and [`Learner.load`](/basic_train.html#Learner.load) to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the `path`/`model_dir` directory." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

load[source]

\n", "\n", "> load(`name`:`PathOrStr`, `device`:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=`None`, `strict`:`bool`=`True`)\n", "\n", "Load model `name` from `self.model_dir` using `device`, defaulting to `self.data.device`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.load)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

save[source]

\n", "\n", "> save(`name`:`PathOrStr`, `return_path`:`bool`=`False`) → `Union`\\[`NoneType`, `str`\\]\n", "\n", "Save model with `name` to `self.model_dir`, and return path if `return_path`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.save)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Segmentation model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

unet_learner[source]

\n", "\n", "> unet_learner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `arch`:`Callable`, `pretrained`:`bool`=`True`, `all_wn`:`bool`=`False`, `blur_final`:`bool`=`True`, `split_on`:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=`None`, `blur`:`bool`=`False`, `kwargs`:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(unet_learner, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build a Unet [`Learner`](/basic_train.html#Learner) for segmentation tasks from [`data`](/vision.data.html#vision.data), using `arch` that may be `pretrained` if that flag is `True`. `split_on` will overwrite the default way the layers are split for differential learning rates. `kwargs` are passed to the [`Learner`](/basic_train.html#Learner) constructor." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Other methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

init[source]

\n", "\n", "> init(`init`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializes all weights (except batchnorm) using function `init`, which will often be from PyTorch's [`nn.init`](https://pytorch.org/docs/stable/nn.html#torch-nn-init) module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

mixup[source]

\n", "\n", "> mixup(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n", "\n", "Add mixup https://arxiv.org/abs/1710.09412 to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.mixup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

backward[source]

\n", "\n", "> backward(`item`)\n", "\n", "Pass `item` through the model and compute the gradient. Useful if `backward_hooks` are attached. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.backward)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

create_opt[source]

\n", "\n", "> create_opt(`lr`:`Floats`, `wd`:`Floats`=`0.0`)\n", "\n", "Create optimizer with `lr` learning rate and `wd` weight decay. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.create_opt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You generally won't need to call this yourself - it's used to create the [`optim`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) optimizer before fitting the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dl[source]

\n", "\n", "> dl(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``)\n", "\n", "Return DataLoader for DatasetType `ds_type`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Recorder[source]

\n", "\n", "> Recorder(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "A [`LearnerCallback`](/basic_train.html#LearnerCallback) that records epoch, loss, opt and metric data during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A [`Learner`](/basic_train.html#Learner) creates a [`Recorder`](/basic_train.html#Recorder) object automatically (you do not need to explicitly pass it to `callback_fns`), because other callbacks rely on it being available. It records the smoothed loss and hyperparameter values for each batch, as well as the metrics, and provides plotting methods for each. Note that [`Learner`](/basic_train.html#Learner) automatically sets an attribute with the snake-cased name of each callback, so you can access this one through `learn.recorder`, as shown below." ] }, { "cell_type": "markdown", "metadata": { "hide_input": true }, "source": [ "### Plotting methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot[source]

\n", "\n", "> plot(`skip_start`:`int`=`10`, `skip_end`:`int`=`5`)\n", "\n", "Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is mainly used with the learning rate finder, since it shows a scatterplot of loss vs learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEKCAYAAAD9xUlFAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xd4VFX+x/H3N52EkBAINQmhdwgQQGygoiK7wtph7bq6q2LXXXWb666ubV11sfxYe8Uu4qLYsSBI6L0jHUJNKOnn90fGGCMkAXJzM5PP63nmeWbunLnzPQyZz9x7zz3XnHOIiIgAhPldgIiI1B0KBRERKaNQEBGRMgoFEREpo1AQEZEyCgURESmjUBARkTIKBRERKeNpKJjZMDNbamYrzOy2Azzfxsw+NbN5ZvaFmaV4WY+IiFTOvDqj2czCgWXAycB6YAYw2jm3qFybN4D3nXPPm9mJwKXOuQsrW2/Tpk1denq6JzWLiISqmTNnbnPOJVfVLsLDGgYAK5xzqwDMbDwwElhUrk034KbA/c+Bd6taaXp6OllZWTVcqohIaDOz76vTzsvdR62BdeUerw8sK28ucGbg/hlAvJk1qbgiM7vSzLLMLCs7O9uTYkVExP8DzbcAg81sNjAY2AAUV2zknBvnnMt0zmUmJ1e59SMiIofJy91HG4DUco9TAsvKOOc2EthSMLOGwFnOuV0e1iQiIpXwckthBtDRzNqaWRQwCnivfAMza2pmP9RwO/CMh/WIiEgVPAsF51wRMAaYDCwGXnfOLTSzu8xsRKDZEGCpmS0DmgN3e1WPiIhUzbMhqV7JzMx0Gn0kInJozGymcy6zqnZ+H2gWEZE6xMsDzXIY8gqLeW3GOpxzNGsUQ7P4aJrFx5Ca1AAz87s8EQlxCgWPPfb5CjbvzqNvm0T6pDamTZPYg365F5c4bhg/hw8Xbv7Zc33TEnngnN60T27odckiUo8pFDw0fdV2Hpi8lMhw48VppScTNomL4qSuzbj9tK40josqa+uc40/vLuDDhZv50y+68qs+rdmak0/2nnxWbN3Do58uZ/gjX3HLKZ257Ni2hIdpq0FEap4ONHukuMQxYuzX7NxbwMc3DWbtjn3MXruLrDU7eG/uRhrHRXH/Wb04oUszAP710VL+89kKrh7Snt8P6/Kz9W3NyeOOdxbwyeIt9GvTmL+N6E73Vo20S0lEqqW6B5oVCh55fcY6fv/WPB4ZlcHIjJ/O7rFw425uem0uS7fkMqp/KulN47j3gyWcl5nKvWf1POgXvXOOCXM28tf3
FrJ7fyFpSbEM7dqck7s1p396YyLCNW5ARA5MoeCj3LxCTnhwCmlJDXjrqqMP+CWfX1TMvz9ezrgvV1Li4JRuzXn8/L7V+mLfsbeADxZs4pNFW/hm5XYKikqICDPioiNoGB1BXHQ4DSLDKXaOwiJHYXEJxc6R2SaJczJTGNg2SVsYIvWMQsFH9324hCe+WMmEa46hd2pipW1nfr+DL5Zmc80JHYiJDD/k99qbX8RXy7cxb/0u9uYXsSe/mL35RewvLCYizIgMDyMyIozikhK+XLaNPflFpCXFcna/FEYPSCM5PvpwuykiQUSh4JO12/cx9KEp/LJ3Sx46N8Pvcn5if0ExHy7cxBtZ65m6cjvNG0XzwmUD6dwivkbfp6i4hGmrdvDBgk0s2ZxLWlIs7ZPjaJfckE7NG9I+uaG2VERqmULBJ797cSZTlmXz+S1DaJEQ43c5B7Vw424ufXYG+UUlPHNJJv3aJB3xOuev382L09bw8aIt7NxXSIPIcLq3asSGXfvZtDuvrF3P1glcfmxbhvdsSVTEj7vLnHNszc0noUHkYW01icjBKRRqWWFxCX+buJCXpq3lllM6MebEjn6XVKV1O/Zx4dPT2ZyTxxMX9OOEzs0Oaz3FJY4np6zkoY+X0SAynKFdmzGsR0sGd0qmQVTpl/ve/CJWb9vLrLU7eX7qGlZm76V5o2guGNiGsDBjzrpdzF23i625+bRMiGHchZn0TEmoye6K1GsKhVq0c28BV788i29Xbee3g9vx+1O7BM15BNm5+Vzy7Hcs3ZzLP8/sydn9Ug5p186m3fu58bU5TFu1g1/2asndZ/QkoUFkpa8pKXFMWZ7NM1+v5qvl2wBo1zSO3qmJdG0Zz/NTv2fbnnzuP7vXz0ZuicjhUSjUkmVbcrn8+RlsycnnvrN6ckafFL9LOmQ5eYVc+UIW01btYEB6Enf8oisZ5Q6Qr9+5j6e/Xs3bszbQMDqCtk3jSG8aS9OG0Tz7zZrSraQR3Q85UAA27NpPw6gIEmJ/DJJte/K5+qVZfLdmB78b3J5bT+0cNCErUlcpFDxWWFzCK9PXcv+HS4iNjmDchf3ok9bY77IOW1FxCeNnrOPhT5axbU8Bp/duxTn9Unhr1nren7cJA4b1aEF4mLFm215Wb9tLTl4RvVISeGRUH9o2javRegqKSrhz4kJemb6Wo9s3YcyJHRjUrskBQ2dfQRENIsN18FqkEgoFD32xdCv/+N9iVmzdw9Htm/Cvc3vTMqGBrzXVlD35RfzflJX896tV5BWW0DA6gtEDUrn0mLa0Svyxj845cvKKaBQT4emX8cvTv+fByUvZua+QTs0bctGgdIb3bMnCjbv5avk2vlyWzZLNuWS2aczNp3RmUPufXeJbRFAo1Jj8omI2785j0+48Nu3ez7uzNzJlWTbpTWK5Y3hXTu7WPCR/oW7encf01dsZ0rlZlccIvJZXWMx7czfy/NQ1LNyYU7Y8KjyMzPTG9GydwLtzNrAlJ59jOzTlllM7/2T3l4goFI7Yuh37uPmNuXy3esdPlsfHRHD9SR25aFD6T4ZTivecc8xau5Ovl2+nV2oCA9smERtVOqdjXmExL037nse/WMmOvQWMzGjF30Z0JzE2qoq1itQPCoUKnHPV/kU/eeFmbn1jLs7Bpcekk9YkjpYJMbRIiKF1YgONoa/D9uQXMW7KSh7/YiVNGkZx/9m9Gdwp2e+yRHynUKjgqa9W8dDHy0hsEElCbBSNYyNpHBdFl+bx9EhJoGfrBOJjIvjnpCU8N3UNvVISGDu6L2lNYj3ohXhtwYbd3PjaHJZv3cMFR6Vxx/CuZVsVIvWRQqGCaau288miLezaX8iufQXs2ldI9p58vt++r6xNXFQ4ewuKueyYtvzhtM5ER2iLIJjlFRbz4OSlPP3Nato2iWPsr/vSrVUjv8sS8YVCoZpy8wpZuDGHBRt2s2xLLqd2b8FJXZvX2PrFf1NXbuOG8XPYvb+QO0d0Z1T/1JAc
HCBSGYWCSDnb9uRz42tz+Gr5NkZmtOKeM3oSF63dSVJ/VDcUNHxG6oWmDaN5/tIB3HJKJybO3cjpY79mUbnhrSJSSqEg9UZYmDHmxI68/Juj2JNXxK8e/4YXvl1DsG0ti3hJoSD1zqD2Tfjg+uM4pn0T/jJhIb99cSa79hX4XZZInaBQkHqpScNonr64P3/6RVc+X7qV4Y98xdLNuX6XJeI7hYLUW2Fhxm+Oa8dbVx1NUYnjomems27HvqpfKBLCFApS7/VKSeTFyweyv6CYC5+eTnZuvt8lifhGoSACdG4Rz7OXDmBLTulFh3LyCv0uScQXCgWRgH5tGvPEBX1ZujmX3zyfxf6CYr9LEql1CgWRcoZ0bsa/zu3NjDU76PePj7nsuRk8/fVqlm3J1dBVqRd0SqdIBSMzWpMcH80H8zfzzYptfLZkKwB90hIZd2EmyfHRPlco4h1NcyFShQ279vPp4i38c9ISmsZH8dylA2if3NDvskQOiaa5EKkhrRMbcNGgdF777VHsLyjmzMen/uziSyKhQqEgUk29UhJ55+pjaNowiguems47s9frOIOEHIWCyCFITYrlrauOJiMtkRtfm8vpY7/mf/M2UVyicJDQoFAQOUSJsVG8/JuB3H9WL/blF3PNK7MY+tAUxn+3lqLiEr/LEzkiCgWRwxAZHsa5/VP5+KbBPH5+X+Kiw7nt7fmcPvYbZq3d6Xd5IofN01Aws2FmttTMVpjZbQd4Ps3MPjez2WY2z8yGe1mPSE0LDzOG92zJxDHH8vj5fdm5t4CznpjK7W/P18yrEpQ8G5JqZuHAMuBkYD0wAxjtnFtUrs04YLZz7gkz6wZMcs6lV7ZeDUmVumxPfhEPf7yMZ6euIbFBJMN6tKBH6wR6tEqgU4uGuu63+Ka6Q1K9PHltALDCObcqUNB4YCSwqFwbB/xwJfUEYKOH9Yh4rmF0BH/6ZTfO7JvCA5OX8N6cjbw8fS0AEWHG+QPTuHNEd10jWuosL0OhNbCu3OP1wMAKbe4EPjKza4E4YKiH9YjUmm6tGvHspQMoKXGs27mPBRty+HTJFp7/9nsax0Vxw9BOfpcockB+T3MxGnjOOfcvMxsEvGhmPZxzPxnCYWZXAlcCpKWl+VCmyOEJCzPaNImjTZM4hvdsQZgZD3+ynJTGsZzdL8Xv8kR+xssDzRuA1HKPUwLLyrsceB3AOfctEAM0rbgi59w451ymcy4zOTnZo3JFvGVm3HNGT47t0JTb3prH18u3+V2SyM94GQozgI5m1tbMooBRwHsV2qwFTgIws66UhkK2hzWJ+CoqIozHL+hLh2YNueqlmSzZnON3SSI/4VkoOOeKgDHAZGAx8LpzbqGZ3WVmIwLNbgauMLO5wKvAJU7zBkiIaxQTybOX9icuOoLLnp3Bzr0auip1h2ZJFfHJ/PW7OeuJqRzXsSlPXZypEUniKc2SKlLH9UxJ4I7hXfh0yVae+mq13+WIAAoFEV9dfHQ6p3Zvzn0fLtH0GFInKBREfGRm3H9Wb1okxHDtK7PZva/Q75KknlMoiPgsITaS/4zuw5acPG55c66u0SC+UiiI1AF90hpz22ld+HjRFh7/YqXf5Ug9plAQqSMuP7YtIzNa8cDkpXy0cLPf5Ug9pVAQqSPMjPvO6kWvlARufG0OSzfn+l2S1EMKBZE6JCYynHEXZhIbHcFvXtCJbVL7FAoidUyLhBjGXdiPLTn5XP3yLAp1iU+pRQoFkTqoT1pj7j2zJ9+u2s6f312gEUlSa/yeOltEDuLMvimsyt7L2M9XkNK4AWNO7Oh3SVIPKBRE6rCbT+nExl37efCjZbRKbMCZfXUNBvGWQkGkDjMz7j2rF5tz8vj9m/No3iiGYzr87JIjIjVGxxRE6rioiDCevLAf7ZMb8rsXZ7Joo67BIN5RKIgEgfLXYDjnyal8MH+T3yVJiFIoiASJVokNeOeao+nUIp6rXp7FvR8soUjDVaWGKRREgkjLhAaMv/Io
zh+YxpNTVnLxs9+xQye4SQ1SKIgEmeiIcO4+oyf3n9WLGWt2MnrcNAqKtMUgNUOhIBKkzu2fyuO/7svSLbn83xTNrCo1Q6EgEsSGdmvOL3q15D+frWBl9h6/y5EQoFAQCXJ/Pb0b0ZFh3PH2fE2HIUdMoSAS5JrFx3DH8K5MX72DN7LW+12OBDmFgkgIOC8zlQHpSdw9aTHZufl+lyNBTKEgEgLCwox7zuzB/oJi7np/kd/lSBBTKIiEiA7N4rnmhA5MnLuRJzUaSQ6TJsQTCSFjTuzA8q253PvBEpJiozi3f6rfJUmQUSiIhJDwMOOhczPYvb+Q296eR0JsJKd2b+F3WRJEtPtIJMRERYTxfxf2o3dqIte+OptvV273uyQJIgoFkRAUGxXBs5f0p01SLFe8kMU3K7b5XZIECYWCSIhKjI3ixcsH0ioxhgufns5/v1ylk9ukSgoFkRDWIiGGd64+hmE9WnD3pMVcN34O+wqK/C5L6jCFgkiIi4uO4LFf9+UPw7rwv3kbOfPxqWzYtd/vsqSOUiiI1ANmxlVD2vPcpQNYv3M/d7w93++SpI5SKIjUI8d3SuaGoR2ZsiybL5dl+12O1EEKBZF65sJBbUhNasA9kxZTXKIDz/JTCgWReiY6Ipw/DOvCks25vD1Ls6rKTykUROqhX/RsSUZqIg9+tJT9BcV+lyN1iEJBpB4yM/70i65sycnnqa9W+V2O1CGehoKZDTOzpWa2wsxuO8Dz/zazOYHbMjPb5WU9IvKjzPQkTuvRgiemrGRrbp7f5Ugd4VkomFk48BhwGtANGG1m3cq3cc7d6JzLcM5lAP8B3vaqHhH5ud8P60JBUQn//ni536VIHeHllsIAYIVzbpVzrgAYD4yspP1o4FUP6xGRCto2jeOiQemMn7GWBRt2+12O1AFehkJrYF25x+sDy37GzNoAbYHPPKxHRA7g+qEdSYqN4s73FmpuJKkzB5pHAW865w44DMLMrjSzLDPLys7WCTciNSmhQSR/GNaFrO93MmHORr/LEZ95GQobgPKXfUoJLDuQUVSy68g5N845l+mcy0xOTq7BEkUE4Ox+KfRKSeCeSYvZk68J8+ozL0NhBtDRzNqaWRSlX/zvVWxkZl2AxsC3HtYiIpUICzP+NqI7W3PzGfvZCr/LER9VKxTMrL2ZRQfuDzGz68wssbLXOOeKgDHAZGAx8LpzbqGZ3WVmI8o1HQWMd9qZKeKrPmmNObtfCk9/vYrV2/b6XY74xKrzXWxmc4BMIB2YBEwAujvnhnta3QFkZma6rKys2n5bkXpha24eJz44hf7pjXn20gF+lyM1yMxmOucyq2pX3d1HJYFf/mcA/3HO3Qq0PJICRaTuaRYfw3UndeDzpdm6hGc9Vd1QKDSz0cDFwPuBZZHelCQifrpoUDqtExtw7wdLKNEsqvVOdUPhUmAQcLdzbrWZtQVe9K4sEfFLTGQ4N5/SifkbdvP+/E1+lyO1rFqh4Jxb5Jy7zjn3qpk1BuKdc/d5XJuI+GRkRmu6tIjnwclLKSgq8bscqUXVHX30hZk1MrMkYBbwXzN7yNvSRMQv4WHGbad1Ye2Ofbw8/Xu/y5FaVN3dRwnOuRzgTOAF59xAYKh3ZYmI3wZ3SmZQuyb857MV5OYV+l2O1JLqhkKEmbUEzuXHA80iEsLMjNuHd2HH3gLGfalrLtQX1Q2Fuyg9CW2lc26GmbUDNNeuSIjrlZLIL3u15KmvVrN5t665UB9U90DzG865Xs65qwKPVznnzvK2NBGpC35/aheKnePuSYv9LkVqQXUPNKeY2TtmtjVwe8vMUrwuTkT8l9YklqsGt2fi3I1M1QltIa+6u4+epXQyu1aB28TAMhGpB64a0p7UpAb85b2FGqIa4qobCsnOuWedc0WB23OA5rAWqSdiIsP524jurNi6h2e+We13OeKh6obCdjO7wMzCA7cLgO1eFiYidcuJXZoztGtzHv10ORt37fe7HPFIdUPhMkqH
o24GNgFnA5d4VJOI1FF/Pb0bxSWOu/+ng86hqrqjj753zo1wziU755o5534FaPSRSD2TmhTLmBM68L/5m5iyTJfGDUVHcuW1m2qsChEJGlcc3472yXHc/tY8nekcgo4kFKzGqhCRoBETGc4D5/Rmc04e90xa4nc5UsOOJBQ00bpIPdU3rTFXHNeOV79by1fLtRsplFQaCmaWa2Y5B7jlUnq+gojUUzee3In2yXH84U3tRgollYaCcy7eOdfoALd451xEbRUpInXPT3cjaTRSqDiS3UciUs/9uBtpHV9qNFJIUCiIyBH5YTfSbW/NI0e7kYKeQkFEjkhMZDgP/rAbSSe1BT2FgogcsT5pjbni+HaMn6HdSMFOoSAiNeLGoZ3o0KyhdiMFOYWCiNQI7UYKDQoFEakxGamJXHl8e8bPWKe5kYKUQkFEatQNQzvSsVlDbnljLpt2a4rtYKNQEJEaFRMZzthf92V/QTGXP5fF3vwiv0uSQ6BQEJEa17lFPGN/3Yclm3O4fvxsiks0VVqwUCiIiCeGdG7GnSO688nirfxT02AEDc1fJCKeuWhQOquy9/LU16tpmxzH+QPb+F2SVEFbCiLiqT//shsndE7mLxMWMmvtTr/LkSooFETEU+FhxiOj+9AyIYbrx8/WNNt1nEJBRDzXKCaSh8/LYMPO/fx1wkK/y5FKKBREpFZkpidx7YkdeXv2BibM2eB3OXIQCgURqTXXntiBfm0a86d3FrBuxz6/y5EDUCiISK2JCA/j4fMyALjxtTkUFZf4XJFU5GkomNkwM1tqZivM7LaDtDnXzBaZ2UIze8XLekTEf6lJsfz9Vz3I+n4nj32+0u9ypALPQsHMwoHHgNOAbsBoM+tWoU1H4HbgGOdcd+AGr+oRkbrjV31aMzKjFY9+tpw563b5XY6U4+WWwgBghXNulXOuABgPjKzQ5grgMefcTgDn3FYP6xGROuSukT1oHh/NDeNna36kOsTLUGgNrCv3eH1gWXmdgE5m9o2ZTTOzYR7WIyJ1SEKDSB46L4Pvd+zjH/9b5Hc5EuD3geYIoCMwBBgN/NfMEis2MrMrzSzLzLKyszVHu0ioOKpdE357fHte/W4dHy3c7Hc5grehsAFILfc4JbCsvPXAe865QufcamAZpSHxE865cc65TOdcZnJysmcFi0jtu+nkTnRr2Yjb3p7P1tw8v8up97wMhRlARzNra2ZRwCjgvQpt3qV0KwEza0rp7qRVHtYkInVMVEQYj47OYG9+EWNemU1BkYap+smzUHDOFQFjgMnAYuB159xCM7vLzEYEmk0GtpvZIuBz4Fbn3HavahKRuqlDs3juO6sX363ewd8mahoMP3k6dbZzbhIwqcKyv5S774CbAjcRqcd+1ac1Szbn8uSUlXRpEc+Fg9L9Lqle8vtAs4hImVtP7cxJXZpx58RFTF2xze9y6iWFgojUGeFhxsOjMmjXNI6rX5nF99v3+l1SvaNQEJE6JT4mkqcuzgTgyhdmsr+g2OeK6heFgojUOW2axPHoqD4s25rLnycs8LucekWhICJ10vGdkrn2xI68OXM9r2etq/oFUiMUCiJSZ11/UkeO6dCEP7+7gMWbcvwup15QKIhInRUeZjx8Xh8SGkRyzcuzdH3nWqBQEJE6LTk+mv+M7sOa7Xu5/e35lJ7eJF5RKIhInTewXRNuPqUz78/bxMR5m/wuJ6QpFEQkKPxucHt6pyby1wkL2LYn3+9yQpZCQUSCQniY8eDZvdibX8xfJ2h+JK8oFEQkaHRsHs/1Qzvyv/mb+GC+diN5QaEgIkHlyuPb0aN1I/48YQE79xb4XU7IUSiISFCJDA/jgbN7s2tfoabZ9oBCQUSCTteWjbjmhA68O2cjb85c73c5IUWhICJBacyJHTimQxNuf3se01bp2lw1RaEgIkEpMjyMx3/dj7SkWH774kxWZe/xu6SQoFAQkaCVEBvJs5cMIDzMuOy5GTrwXAMUCiIS1NKaxPLf
i/qxcXcev31xJvlFuv7CkVAoiEjQ69cmiQfO7sV3a3bwz0lL/C4nqCkURCQkjMxozcWD2vDc1DV8u1IHng+XQkFEQsYfTutCepNYbn1zLnvyi/wuJygpFEQkZMRGRfDgOb3ZsGs/90xa7Hc5QUmhICIhJTM9iSuOa8cr09cyZVm23+UEHYWCiIScm07uRIdmDfnDm/PYvV9XazsUCgURCTkxkeH865zeZO/J5y8TFuhqbYdAoSAiIal3aiI3nNSRCXM28tL0tX6XEzQUCiISsq45oQNDOidz18SFzFm3y+9ygoJCQURCVliY8fB5GTSLj+Hql2ayQ9NgVEmhICIhLTE2iicv6Me2PQVcP342xSU6vlAZhYKIhLyeKQncOaI7Xy3fxu1vz2Pi3I3MWLODdTv2UVBU4nd5dUqE3wWIiNSG0QNSWbhxNy9PX8vrWT9emCc+JoJnLulP//QkH6urOyzYhmplZma6rKwsv8sQkSC1e18hm3Py2JyTx5bdeTw5ZSXb9uTzxu+OpnOLeL/L84yZzXTOZVbVTruPRKReSYiNpHOLeAZ3Subc/qk8f9kAYiLDueiZ6azfuc/v8nynUBCRei01KZbnLxvAvoJiLnrmu3o/QkmhICL1XteWjXj64v6s37mfS5+bwb6C+jvDqkJBRAQY0DaJsaP7MH/9Lm5+fS4l9XToqkJBRCTglO4tuGN4Vz5YsJlHP1vudzm+0JBUEZFyLj+2LYs35fLwJ8vp0iKeYT1alj1XUuJ4f/4mIsOM03q2rGQtwcvTLQUzG2ZmS81shZnddoDnLzGzbDObE7j9xst6RESqYmbcfUYPMlITufG1uSzamAPAtFXbGfnYN1z36myueWUWWWt2+FypNzw7T8HMwoFlwMnAemAGMNo5t6hcm0uATOfcmOquV+cpiEht2JqTx+ljvyYiLIyuLRvxyeIttEqI4YahnRj7+QqKSxyTrj+OhAaRfpdaLXXhPIUBwArn3CrnXAEwHhjp4fuJiNSYZo1iGHdhJtl78pm2aju3ntqZz24Zwrn9U3lkVAabc/L44zvzQ+5aDV4eU2gNrCv3eD0w8ADtzjKz4yndqrjRObeuYgMzuxK4EiAtLc2DUkVEfq53aiKTbzieRjERNGkYXba8T1pjbhzakQc/WsaQzs04u1+Kj1XWLL9HH00E0p1zvYCPgecP1Mg5N845l+mcy0xOTq7VAkWkfmvbNO4ngfCDq4Z0YEDbJP4yYQFrtu31oTJveBkKG4DUco9TAsvKOOe2O+fyAw+fAvp5WI+ISI0JD1yrISLMuPbV2ezJD40T3rwMhRlARzNra2ZRwCjgvfINzKz8mK4RwGIP6xERqVGtEhvw4Dm9WbQph/Ofms7OEJgiw7NQcM4VAWOAyZR+2b/unFtoZneZ2YhAs+vMbKGZzQWuAy7xqh4RES+c0r0FT5zfl8Wbcjhv3Ldsycnzu6QjoqmzRURqwNSV27ji+SySGkbx0uUDaRYfw+y1O5m2ajuz1+3iwqPacEr3Fr7VV90hqQoFEZEaMnfdLi5+9juKih0FRSUUFJcQZtCoQSRFxY73rz2W9KZxvtSmUBAR8cHyLaVTZKQ0bsBR7ZrQL70xuXlFDH/kK1KTGvDWVUcTHRFe63UpFERE6pCPF23hiheyuOTodO4c0b3W378unNEsIiIBJ3drzqXHpPPc1DVMXrjZ73IOSqEgIlJLbjutCz1bJ3DrG3Pr7KU/FQoiIrUkOiKcsb/uQ4mDc578lrdmrq/WxXxKShwPTl5aK8NdFQoiIrWoTZM4Xrh8AE0bRnPzG3MRGLqBAAAIoklEQVQ5fezXfLNi20HbO+e46/1FjP18BR/Vwm4nhYKISC3rm9aYCdccwyOjMti1r5Dzn5rO5c/NOOCWwGOfr+C5qWv4zbFtueCoNp7XplAQEfFBWJgxMqM1n948mNtP68I3K7dxyr+/5P15G8vavPrdWh78aBln9GnNHcO7Ymae16XLcYqI+Cgm
MpzfDm7Pyd2ac+Prcxnzymw+XrSF4zsm88d35jO4UzL3n92LsDDvAwF0noKISJ1RVFzC41+s5JFPl1Nc4shITeSVKwYSG3Xkv9+re56CthREROqIiPAwrjupI0M6J/P2rA1cf1LHGgmEQ6qhVt9NRESq1CslkV4pib68tw40i4hIGYWCiIiUUSiIiEgZhYKIiJRRKIiISBmFgoiIlFEoiIhIGYWCiIiUCbppLswsG/i+wuIEYHcVyyp7/MP98suaAgefz7ZqB6rpUNqoT9W7fyR9qk5/KmtXnf5UXFbV/dr4jCprF6x9qo2/pfL3g7FPHZ1zCVVW5pwL+hswrqpllT3+4X6FZVk1XdOhtFGfqn3/sPtUnf5U1q46/TnUPtXGZxSKfaqNv6VQ6lNlt1DZfTSxGssqezzxIG2ORHXWVVkb9al6949EdddzsHbV6U/FZerToasrf0vVraU6/O7TQQXd7qPaYmZZrhozCgYT9anuC7X+gPoUbEJlS8EL4/wuwAPqU90Xav0B9SmoaEtBRETKaEtBRETK1ItQMLNnzGyrmS04jNf2M7P5ZrbCzB61chdJNbNrzWyJmS00s/trtupKa6rx/pjZnWa2wczmBG7Da77ySuvy5DMKPH+zmTkza1pzFVerLi8+p7+b2bzAZ/SRmbWq+corrcuLPj0Q+DuaZ2bvmFmtXkjAoz6dE/heKDGz4Dr2cCTDqoLlBhwP9AUWHMZrvwOOAgz4ADgtsPwE4BMgOvC4WZD3507gllD6jALPpQKTKT23pWmw9wloVK7NdcCTIdCnU4CIwP37gPtCoE9dgc7AF0BmbfbnSG/1YkvBOfclsKP8MjNrb2YfmtlMM/vKzLpUfJ2ZtaT0j3CaK/2kXwB+FXj6KuBe51x+4D22etuLH3nUH1952Kd/A78Hav3gmRd9cs7llGsaRy33y6M+feScKwo0nQakeNuLn/KoT4udc0tro/6aVi9C4SDGAdc65/oBtwCPH6BNa2B9ucfrA8sAOgHHmdl0M5tiZv09rbZqR9ofgDGBTfhnzKyxd6VW2xH1ycxGAhucc3O9LvQQHPHnZGZ3m9k64HzgLx7WWl018X/vB5dR+ovbbzXZp6BSL6/RbGYNgaOBN8rtfo4+xNVEAEmUbjr2B143s3aBXwy1qob68wTwd0p/ef4d+Belf6C+ONI+mVkscAeluybqhBr6nHDO/RH4o5ndDowB/lpjRR6imupTYF1/BIqAl2umusNTk30KRvUyFCjdQtrlnMsov9DMwoGZgYfvUfpFWX5TNgXYELi/Hng7EALfmVkJpfOhZHtZ+EEccX+cc1vKve6/wPteFlwNR9qn9kBbYG7gDzsFmGVmA5xzmz2u/WBq4v9deS8Dk/AxFKihPpnZJcAvgZP8+GFVQU1/TsHF74MatXUD0il3IAmYCpwTuG9A74O8ruKBpOGB5b8D7grc7wSsI3DeR5D2p2W5NjcC44P9M6rQZg21fKDZo8+pY7k21wJvhkCfhgGLgOTa7ovX//cIwgPNvhdQSx/4q8AmoJDSX/iXU/or8kNgbuA/5F8O8tpMYAGwEhj7wxc/EAW8FHhuFnBikPfnRWA+MI/SX0Eta6s/XvWpQptaDwWPPqe3AsvnUTqXTesQ6NMKSn9UzQncantElRd9OiOwrnxgCzC5Nvt0JDed0SwiImXq8+gjERGpQKEgIiJlFAoiIlJGoSAiImUUCiIiUkahIEHPzPbU8vs9ZWbdamhdxYEZTxeY2cSqZgg1s0Qzu7om3lvkQDQkVYKeme1xzjWswfVFuB8naPNU+drN7HlgmXPu7krapwPvO+d61EZ9Uv9oS0FCkpklm9lbZjYjcDsmsHyAmX1rZrPNbKqZdQ4sv8TM3jOzz4BPzWyImX1hZm8G5vp/udxc+V/8MEe+me0JTFA318ymmVnzwPL2gcfzzewf1dya+ZYfJ/NraGafmtmswDpGBtrcC7QPbF08EGh7a6CP88zsbzX4zyj1
kEJBQtUjwL+dc/2Bs4CnAsuXAMc55/pQOsPoPeVe0xc42zk3OPC4D3AD0A1oBxxzgPeJA6Y553oDXwJXlHv/R5xzPfnpTJoHFJhX5yRKzyYHyAPOcM71pfTaHf8KhNJtwErnXIZz7lYzOwXoCAwAMoB+ZnZ8Ve8ncjD1dUI8CX1DgW7lZrlsFJj9MgF43sw6UjojbGS513zsnCs/r/53zrn1AGY2h9L5cb6u8D4F/Dh54Ezg5MD9Qfx4XYdXgAcPUmeDwLpbA4uBjwPLDbgn8AVfEni++QFef0rgNjvwuCGlIfHlQd5PpFIKBQlVYcBRzrm88gvNbCzwuXPujMD++S/KPb23wjryy90v5sB/L4XuxwNzB2tTmf3OuYzAVN+TgWuARym9VkIy0M85V2hma4CYA7zegH865/7vEN9X5IC0+0hC1UeUziIKgJn9MA1yAj9Ob3yJh+8/jdLdVgCjqmrsnNtH6eU1bzazCErr3BoIhBOANoGmuUB8uZdOBi4LbAVhZq3NrFkN9UHqIYWChIJYM1tf7nYTpV+wmYGDr4soneoc4H7gn2Y2G2+3lG8AbjKzeUAHYHdVL3DOzaZ09tPRlF4rIdPM5gMXUXosBOfcduCbwBDWB5xzH1G6e+rbQNs3+WloiBwSDUkV8UBgd9B+55wzs1HAaOfcyKpeJ+I3HVMQ8UY/YGxgxNAufLy0qcih0JaCiIiU0TEFEREpo1AQEZEyCgURESmjUBARkTIKBRERKaNQEBGRMv8PqUdaE0MCJTwAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = create_cnn(data, models.resnet18, metrics=accuracy)\n", "learn.lr_find()\n", "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_losses[source]

\n", "\n", "> plot_losses(`last`:`int`=`None`)\n", "\n", "Plot training and validation losses. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:15\n", "epoch train_loss valid_loss accuracy\n", "1 0.110780 0.047797 0.981845 (00:07)\n", "2 0.046339 0.038317 0.987733 (00:07)\n", "\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(2)\n", "learn.recorder.plot_losses()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_lr[source]

\n", "\n", "> plot_lr(`show_moms`=`False`)\n", "\n", "Plot learning rate, `show_moms` to include momentum. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_metrics[source]

\n", "\n", "> plot_metrics()\n", "\n", "Plot metrics collected during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYQAAAD8CAYAAAB3u9PLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3Xl4VPX5/vH3Q0hYw74vYd8CRMQAAipKXXAXsK3WutRWbKvfLv6sgFtxK0itrVZbixYrbd0KiEEFRAVBRSWoJCRsYSfsIPua5Pn9MSd2pCgBJjOT5H5dVy7PnPM5M885THLPWebR3B0REZFKsS5ARETigwJBREQABYKIiAQUCCIiAigQREQkoEAQERFAgSAiIoESBYKZDTKzpWaWZ2YjjrG8lZm9a2ZZZjbbzFqELRtrZjlmttjMnjQzC+Zfa2bZwTrTzaxB5DZLRERO1HEDwcwSgKeBi4FU4FozSz1q2GPABHdPAx4ERgfr9gP6A2lAN6AXMMDMKgNPAOcF62QBt0dki0RE5KRULsGY3kCeu68EMLOXgSuB3LAxqcAdwfQsYEow7UBVIAkwIBHYHEwbUMPMtgO1gLzjFdKgQQNv3bp1CUoWEZFiCxYs2ObuDY83riSB0BxYF/Z4PdDnqDELgSGEPvUPBpLNrL67zzOzWcBGQgHwlLsvBjCznwHZwD5gOXDbsV7czIYBwwBSUlLIzMwsQckiIlLMzNaUZFykLirfSehU0OfAACAfKDSz9kAXoAWhYBloZmebWSLwM+B0oBmhU0Yjj/XE7j7O3dPdPb1hw+MGnIiInKSSHCHkAy3DHrcI5n3F3TcQOkLAzGoCQ919p5ndAnzs7nuDZdOAvsDBYL0VwfxXgf+5WC0iItFTkiOE+UAHM2tjZknANUBG+AAza2Bmxc81EhgfTK8luIgcHBUMABYTCpRUMyv+yH9BMF9ERGLkuEcI7l5gZrcDM4AEYLy755jZg0Cmu2cA5wKjzcyBOfz3esBEYCChawUOTHf3qQBm9gAwx8yOAGuAmyK5YSIicmKsLP3/ENLT010XlUVEToyZLXD39OON0zeVRUQEUCCIiEhAgSAiEseWbtrD2OlLiMbp/ZLcdioiIlF2uKCIv8zO4+lZeSRXTeSHZ7aiWZ1qpfqaCgQRkTizcN1O7pqYxdLNe7iyRzPuvyyV+jWrlPrrKhBEROLEgcOFPD5zKX//YBWNkqvy9xvT+U6XxlF7fQWCiEgc+GjFNkZMymbtjv38oE8KIy7uTK2qiVGtQYEgIhJDuw8eYfRbS3jp07W0ql+dl245k77t6sekFgWCiEiMvJO7mXumZLN1zyGGndOWX5/fkWpJCTGrR4EgIhJl2/ce4oGpuWQs3EDnJsmMuz6d01rWiXVZCgQRkWhxdzIWbmBURg57DxVwxwUd+emAdiRVjo+vhCkQRESiYMPOA9w7ZRHvLdlCj5Z1GHt1Gh
0bJ8e6rK9RIIiIlKKiIuel+WsZ/dYSCouc+y5L5aZ+rUmoZLEu7X8oEERESsmqbfsYMSmLT1btoH/7+owenEZK/eqxLusbKRBERCKsoLCI8R+u4g9vLyOpciUeHdqd76W3xCz+jgrCKRBERCJo8cbdDJ+URdb6XVyQ2piHr+pG41pVY11WiSgQREQi4FBBIU+/l8dfZq+gTvVEnv5BTy7p3iTujwrCKRBERE7RZ2u/ZPjELJZv2cuQ05tz32Wp1K2RFOuyTpgCQUTkJO0/XMBjM5bx/EeraFqrKs//qBfndWoU67JOmgJBROQkfJi3jRGTs1i34wDXn9mKuwZ1IjnKzegiTYEgInICdh04wu/eXMwrmeto06AGrww7kz5tY9OMLtIUCCIiJfR2zibunbKI7fsO89MB7fjV+R2omhi7ZnSRpkAQETmOrXsOMWpqDm9mbaRL01r8/cZedG9RO9ZlRZwCQUTkG7g7r32ez4Nv5LL/UCG/uagTw85pS2JCfDSji7QSBYKZDQKeABKA59x9zFHLWwHjgYbADuCH7r4+WDYWuBSoBMwEfgnUBOaGPUUL4F/u/qtT2hoRkQjJ33mAe17LZvbSrfRMCTWja98ovprRRdpxA8HMEoCngQuA9cB8M8tw99ywYY8BE9z9BTMbCIwGrjezfkB/IC0Y9wEwwN1nAz3CXmMBMDkC2yMickqKipx/f7KGMdOW4MCoy1O5vm98NqOLtJIcIfQG8tx9JYCZvQxcCYQHQipwRzA9C5gSTDtQFUgCDEgENoc/uZl1BBrx9SMGEZGoW7l1LyMmZfPp6h2c3aEBvxvcnZb14rcZXaSVJBCaA+vCHq8H+hw1ZiEwhNBppcFAspnVd/d5ZjYL2EgoEJ5y98VHrXsN8Iq7+8lsgIjIqSooLOLZuav44zvLqFq5Er+/Oo2rz2hRptpOREKkLirfCTxlZjcBc4B8oNDM2gNdCF0jAJhpZme7e/jRwDXA9d/0xGY2DBgGkJKSEqFyRURCcjbsYvikLBbl72ZQ1yY8eFVXGiWXjWZ0kVaSQMgHWoY9bhHM+4q7byB0hICZ1QSGuvtOM7sF+Njd9wbLpgF9CU4PmdlpQGV3X/BNL+7u44BxAOnp6TqKEJGIOHikkD+/t5xn3l9J3epJ/PW6nlzcvWmsy4qpktw7NR/oYGZtzCyJ0Cf6jPABZtbAzIqfayShO44A1gIDzKyymSUCA4DwU0bXAi+dygaIiJyoBWt2cOmTc3l61gqu6tGcd+44p8KHAZTgCMHdC8zsdmAGodtOx7t7jpk9CGS6ewZwLjDazJzQKaPbgtUnAgOBbEIXmKe7+9Swp/8ecEmkNkZE5NvsO1TA72cs5YV5q2lWuxov3NybAR0bxrqsuGFl6Vpuenq6Z2ZmxroMESmD5izbysjJ2WzYdYAb+7bmzos6UbNKxfhurpktcPf0442rGHtDRCqsnfsP8/Cbi5m4YD1tG9bgP7f2Jb11vViXFZcUCCJSbk3L3sh9r+fw5f7D3HZeO/5vYPlqRhdpCgQRKXe27DnIb1/PYdqiTXRtVosXbu5F12blrxldpCkQRKTccHcmLljPw28u5sCRQu4a1Ilbzi6/zegiTYEgIuXCuh37ufu1bOYu30av1nUZMzSNdg1rxrqsMkWBICJlWlGRM2HeasbOWIoBD13Zlev6tKJSBWhGF2kKBBEps/K27GH4pGwWrPmSAR0b8sjgbrSoW3Ga0UWaAkFEypwjhUWMm7OSJ95ZTvUqCTz+vdMYfHrzCteMLtIUCCJSpizK38VdE7PI3bibS7s3ZdQVXWmYXCXWZZULCgQRKRMOHinkiXeXM27OSurVSOKZH57BoG5NYl1WuaJAEJG4N3/1DoZPzGLltn18L70F91ySSu3qibEuq9xRIIhI3Np7qICx05cwYd4aWtStxr9+3IezOjSIdVnllgJBROLSrKVbuGdyNht3H+Tm/m2486KOVE/Sn6zSpL0rInHly32Hee
iNXCZ/nk/7RjWZ+NN+nNGqbqzLqhAUCCISF9ydt7I38duMRezcf4RfDGzPbQPbU6WymtFFiwJBRGJuy+6D3DtlEW/nbqZ789pMuLkPqc1qxbqsCkeBICIx4+78J3M9D72Zy+GCIkZe3Jkfn9WGympGFxMKBBGJibXb9zPytSw+zNtO7zb1eHRoGm0a1Ih1WRWaAkFEoqqwyPnHR6t5bMZSEioZD1/VjR/0TlEzujigQBCRqFm+eQ93Tcri87U7Oa9TQx4Z3J1mdarFuiwJKBBEpNQdLijimfdX8NR7edSoksCfvt+DK3s0UzO6OKNAEJFSlbV+J3dNzGLJpj1cflozfnt5Kg1qqhldPFIgiEipOHC4kD+9s4xn566kYXIVnr0hnQtSG8e6LPkWCgQRibiPV25nxKQsVm/fz7W9WzLyki7UqqpmdPFOgSAiEbPn4BHGTFvCvz9ZS0q96rz4kz70a69mdGVFib79YWaDzGypmeWZ2YhjLG9lZu+aWZaZzTazFmHLxppZjpktNrMnLbiKZGZJZjbOzJaZ2RIzGxq5zRKRaHtvyWYu/OMcXvp0LT85qw0zfnWOwqCMOe4RgpklAE8DFwDrgflmluHuuWHDHgMmuPsLZjYQGA1cb2b9gP5AWjDuA2AAMBu4B9ji7h3NrBJQL0LbJCJRtGPfYR6cmsOULzbQsXFN/nJdP05PUTO6sqgkp4x6A3nuvhLAzF4GrgTCAyEVuCOYngVMCaYdqAokAQYkApuDZTcDnQHcvQjYdtJbISJR5+5MzdrIqIwc9hw8wi+/04HbzmtPUmW1nSirSvIv1xxYF/Z4fTAv3EJgSDA9GEg2s/ruPo9QQGwMfma4+2IzqxOMfcjMPjOz/5jZMW8/MLNhZpZpZplbt24t4WaJSGnatOsgt0xYwC9e+pyWdasx9f/O4tcXdFQYlHGR+te7ExhgZp8TOiWUDxSaWXugC9CCUIgMNLOzCR2ZtAA+cveewDxCp53+h7uPc/d0d09v2LBhhMoVkZPh7rz06VouePx9Psjbyr2XdmHyz/vTuYk6k5YHJTlllA+0DHvcIpj3FXffQHCEYGY1gaHuvtPMbgE+dve9wbJpQF9C1xL2A5ODp/gP8ONT2A4RKWVrtu9jxKRs5q3cTt+29RkztDut6qsZXXlSkiOE+UAHM2tjZknANUBG+AAzaxBcGAYYCYwPptcSOnKobGaJhI4eFru7A1OBc4Nx3+Hr1yREJE4UFjnPzV3JRX+aw6L8XYwe0p0Xb+mjMCiHjnuE4O4FZnY7MANIAMa7e46ZPQhkunsGoT/so83MgTnAbcHqE4GBQDahC8zT3X1qsGw48E8z+xOwFfhR5DZLRCJh6aZQM7qF63ZyfpdGPHxVd5rUrhrrsqSUWOjDetmQnp7umZmZsS5DpNw7XFDE07Py+MvsPJKrJjLqiq5cntZUzejKKDNb4O7pxxunbyqLyNd8sW4nd01cyLLNe7mqRzPuv7wr9WokxbosiQIFgogAoWZ0f3h7KeM/XEXjWlUZf1M6AzurGV1FokAQET5asY0Rk7JZu2M/1/VJYcTFnUlWM7oKR4EgUoHtPniE0W8t5qVP19G6fnVeHnYmZ7atH+uyJEYUCCIV1Du5m7lnSjZb9xzi1nPa8qvzO1ItKSHWZUkMKRBEKphtew/xwNRcpi7cQOcmyTx7QzppLeocf0Up9xQIIhWEu/P6Fxt4YGoOew8VcMcFHfnpgHbqPyRfUSCIVAAbdh7g3imLeG/JFk5PqcPYoWl0aJwc67IkzigQRMqxoiLnxU/XMmbaEgqLnPsvS+XGfq1JqKQvmMn/UiCIlFOrtu1jxKQsPlm1g/7t6zN6cBop9avHuiyJYwoEkXKmoLCIv3+wisdnLiOpciXGDk3ju+kt1HZCjkuBIFKO5G7YzfBJWWTn7+LC1MY8dFU3GtdSMzopGQWCSDlwqKCQp97L46+zV1CneiJP/6Anl3RvoqMCOSEKBJEybsGaLx
k+KYu8LXsZ0rM5912aSl01o5OToEAQKaP2Hy7g9zOW8o+PVtO0VlWe/1EvzuvUKNZlSRmmQBApgz5Yvo0Rk7NY/+UBbujbirsGdaZmFf06y6nRO0ikDNm1/wiPvJXLq5nradOgBq/e2pfeberFuiwpJxQIImXE9EWbuO/1RezYd5ifnduOX36nA1UT1YxOIkeBIBLntu45xKiMHN7M3khq01o8f1MvujWvHeuypBxSIIjEKXdn8mf5PPhGLgcOF/Kbizox7Jy2JCaoGZ2UDgWCSBzK33mAuydn8/6yrZzRqi6PDk2jfaOasS5LyjkFgkgcKSpy/vXJGh6dtgQHRl2eyg19W1NJzegkChQIInFixda9jJiUxfzVX3J2hwb8bnB3WtZTMzqJHgWCSIwdKSzi2bkr+dM7y6mWmMBj3z2NoT2bq+2ERF2JAsHMBgFPAAnAc+4+5qjlrYDxQENgB/BDd18fLBsLXApUAmYCv3R3N7PZQFPgQPA0F7r7llPeIpEyZFH+LoZPyiJnw24u7taEB67sSqNkNaOT2DhuIJhZAvA0cAGwHphvZhnunhs27DFggru/YGYDgdHA9WbWD+gPpAXjPgAGALODx9e5e2ZEtkSkDDl4pJA/v7ecZ95fSd3qSfz1up5c3L1prMuSCq4kRwi9gTx3XwlgZi8DVwLhgZAK3BFMzwKmBNMOVAWSAAMSgc2nXrZI2ZW5egd3Tcpi5dZ9XH1GC+69tAt1qqsZncReSQKhObAu7PF6oM9RYxYCQwidVhoMJJtZfXefZ2azgI2EAuEpd18ctt7zZlYITAIednc/ye0QiXv7DoWa0b0wbzXNaldjws29Oadjw1iXJfKVSF1UvhN4ysxuAuYA+UChmbUHugAtgnEzzexsd59L6HRRvpklEwqE64EJRz+xmQ0DhgGkpKREqFyR6Hp/2VbunpzNhl0HuLFva35zUSdqqBmdxJmSvCPzgZZhj1sE877i7hsIHSFgZjWBoe6+08xuAT52973BsmlAX2Cuu+cH6+4xsxcJnZr6n0Bw93HAOID09HQdQUiZsnP/YR56YzGTPltPu4Y1+M+tfUlvrWZ0Ep9K8h34+UAHM2tjZknANUBG+AAza2Bmxc81ktAdRwBrgQFmVtnMEgldUF4cPG4QrJsIXAYsOvXNEYkf07I3cv7jc5jyRT63n9eeN39xtsJA4tpxjxDcvcDMbgdmELrtdLy755jZg0Cmu2cA5wKjzcwJnTK6LVh9IjAQyCZ0gXm6u081sxrAjCAMEoB3gGcju2kisbFl90Hufz2H6Tmb6NqsFi/c3IuuzdSMTuKflaXruOnp6Z6ZqbtUJT65OxMXrOehN3I5WFDEr8/vyC1nt6GymtFJjJnZAndPP944XdUSiYB1O/Zz92vZzF2+jV6t6zJmaBrtGqoZnZQtCgSRU1BY5EyYt5rfz1iKAQ9d2ZXr+rRSMzopkxQIIicpb8sehk/KZsGaLxnQsSG/G9Kd5nWqxboskZOmQBA5QUcKi/jb+yt48t08qldJ4PHvncbg09WMTso+BYLICViUv4vfTMxi8cbdXJrWlFGXd6VhcpVYlyUSEQoEkRI4eKSQP72znGfnrqRejST+dv0ZXNS1SazLEokoBYLIcXyycjsjJmezats+vp/ekrsv6ULt6omxLksk4hQIIt9gz8EjjJ2+lH9+vIaW9arxrx/34awODWJdlkipUSCIHMOspVu4Z3I2G3cf5Ob+bbjzoo5UT9Kvi5RveoeLhPly32EeeiOXyZ/n06FRTSb9rB89U+rGuiyRqFAgiBBqO/Fm9kZ++3oOuw4c4RcD23PbwPZUqZwQ69JEokaBIBXe5t0HuXfKImbmbqZ789r86yd96NK0VqzLEok6BYJUWO7Oq5nrePjNxRwuKGLkxZ358VlqRicVlwJBKqS12/czYnIWH63YTp829RgzNI02DWrEuiyRmFIgSIVSWOT846PVPDZjKQmVjEcGd+PaXilqRieCAk
EqkGWb93DXxCy+WLeTgZ0b8cjgbjStrWZ0IsUUCFLuHS4o4q+zV/DUrOXUrFKZJ67pwRWnNVMzOpGjKBCkXFu4bifDJ2WxZNMeLj+tGaMuT6V+TTWjEzkWBYKUSwcOF/LHd5bx3NyVNEyuwrM3pHNBauNYlyUS1xQIUu7MW7GdkZOzWL19P9f2TmHkJZ2pVVXN6ESOR4Eg5cbug0cYM20JL36yllb1q/PiLX3o107N6ERKSoEg5cJ7SzZz9+RFbNlzkFvObsMdF3SiWpLaToicCAWClGnb9x7iwTdyef2LDXRqnMwz159Bj5Z1Yl2WSJmkQJAyyd3JWLiBB6bmsufgEX51fgd+fm57kiqr7YTIyVIgSJmzcdcB7n1tEe8u2cJpLeswdmganZokx7oskTKvRB+nzGyQmS01szwzG3GM5a3M7F0zyzKz2WbWImzZWDPLMbPFZvakHfVtIDPLMLNFp74pUt4VFTkvfrKWCx+fw4crtnHvpV2Y/LN+CgORCDnuEYKZJQBPAxcA64H5Zpbh7rlhwx4DJrj7C2Y2EBgNXG9m/YD+QFow7gNgADA7eO4hwN4IbYuUY6u37WPE5Cw+XrmDvm3rM2Zod1rVVzM6kUgqySmj3kCeu68EMLOXgSuB8EBIBe4IpmcBU4JpB6oCSYABicDm4HlqBusMA149pa2QcqugsIjnP1zNH2YuJbFSJcYM6c73e7VU2wmRUlCSQGgOrAt7vB7oc9SYhcAQ4AlgMJBsZvXdfZ6ZzQI2EgqEp9x9cbDOQ8AfgP2nUL+UY0s27Wb4xCwWrt/F+V0a8fBV3WlSu2qsyxIptyJ1UflO4CkzuwmYA+QDhWbWHugCFF9TmGlmZwN7gHbu/msza/1tT2xmwwgdRZCSkhKhciWeHSoo5OlZK/jLrDxqV0vkz9eezmVpTXVUIFLKShII+UDLsMctgnlfcfcNhI4Qik8FDXX3nWZ2C/Cxu+8Nlk0D+hIKhHQzWx3U0MjMZrv7uUe/uLuPA8YBpKen+wltnZQ5n6/9kuGTsli2eS9X9WjG/Zd3pV6NpFiXJVIhlOQuo/lABzNrY2ZJwDVARvgAM2tgZsXPNRIYH0yvBQaYWWUzSyR0QXmxu//V3Zu5e2vgLGDZscJAKo79hwt46I1chvz1I/YcLGD8Ten86ZrTFQYiUXTcIwR3LzCz24EZQAIw3t1zzOxBINPdM4BzgdFm5oROGd0WrD4RGAhkE7rAPN3dp0Z+M6Qs+yhvGyMmZ7N2x35+eGYKwwd1JlnN6ESiztzLzlmY9PR0z8zMjHUZEiG7Dhxh9FuLeXn+OlrXr86YoWmc2bZ+rMsSKXfMbIG7px9vnL6pLDHxds4m7p2yiG17D3HrgLb8+vyOVE1UMzqRWFIgSFRt23uIURk5vJG1kc5NknnuxnTSWqgZnUg8UCBIVLg7U77I54Gpuew/VMj/u6Ajtw5op2Z0InFEgSClbsPOA9zzWjazlm7l9JRQM7oOjdV/SCTeKBCk1BQVOf/+dC2PTltCYZFz/2Wp3NivNQmV9AUzkXikQJBSsXLrXkZMzubTVTs4q30DRg/pTst61WNdloh8CwWCRFRBYRHPfbCKP85cRlLlSowdmsZ301uo7YRIGaBAkIjJ3bCbuyYtZFH+bi5MbcxDV3WjcS01oxMpKxQIcsoOFRTy1Ht5/HX2CupUT+Qv1/Xk4m5NdFQgUsYoEOSULFgTakaXt2UvQ3o2575LU6mr/kMiZZICQU7KvkMFPPb2Uv7x0Wqa1a7GP37Ui3M7NYp1WSJyChQIcsLmLt/KyMnZrP/yADf0bcVdgzpTs4reSiJlnX6LpcR27T/Cw2/m8p8F62nboAav3tqX3m3qxbosEYkQBYKUyPRFm7jv9UXs2HeYn5/bjl98p4Oa0YmUMwoE+VZb9hxkVEYOb2VvIrVpLZ6/qRfdmteOdVkiUgoUCHJM7s7kz/J58I1cDhwp5DcXdWLYOW1JTFAzOpHySoEg/2P9l/
u5+7VFzFm2lTNa1eXRoWm0b1Qz1mWJSClTIMhXioqcf368hkenLwHggSu6cv2ZraikZnQiFYICQQBYsXUvwydmkbnmS87u0IDfDVYzOpGKRoFQwR0pLGLcnJU88e5yqiUm8Nh3T2Noz+ZqOyFSASkQKrBF+bsYPimLnA27uaR7E0Zd0ZVGyWpGJ1JRKRAqoINHCnny3eX8bc5K6lZP4pkf9mRQt6axLktEYkyBUMHMX72D4ROzWLltH989owX3XppK7eqJsS5LROKAAqGC2HuogLHTlzBh3hqa16nGhJt7c07HhrEuS0TiSIkCwcwGAU8ACcBz7j7mqOWtgPFAQ2AH8EN3Xx8sGwtcClQCZgK/dHc3s+lA06CGucBt7l4Yka2Sr3l/2VbunpzNhl0HuKlfa35zUSdqqBmdiBzluF87NbME4GngYiAVuNbMUo8a9hgwwd3TgAeB0cG6/YD+QBrQDegFDAjW+Z67nxbMbwh895S3Rr5m5/7D3PHqF9w4/lOqJlZi4k/7MuqKrgoDETmmkvxl6A3kuftKADN7GbgSyA0bkwrcEUzPAqYE0w5UBZIAAxKBzQDuvjushqRgrETIW9kbuf/1Rezcf4Tbz2vP7QPbqxmdiHyrkjSmaQ6sC3u8PpgXbiEwJJgeDCSbWX13n0coIDYGPzPcfXHxSmY2A9gC7AEmntQWyNds2X2QW/+Zyc///RlNalfl9dv7c+dFnRQGInJckepUdicwwMw+J3RKKB8oNLP2QBegBaEQGWhmZxev5O4XEbqOUAUYeKwnNrNhZpZpZplbt26NULnlj7vzauY6zn/8fWYt3crwQZ2Z8vP+dG2mzqQiUjIlOWWUD7QMe9wimPcVd99AcIRgZjWBoe6+08xuAT52973BsmlAX0IXkYvXPWhmrxM6DTXz6Bd393HAOID09HSdVjqGdTv2M3JyNh/kbaN363qMGdqdtg3VjE5ETkxJjhDmAx3MrI2ZJQHXABnhA8ysgZkVP9dIQnccAawldORQ2cwSCR09LDazmmbWNFi3MqG7kJac+uZULIVFzvMfruLCP87h87Vf8tBV3Xh52JkKAxE5Kcc9QnD3AjO7HZhB6LbT8e6eY2YPApnungGcC4w2MwfmALcFq08kdCoom9BF4+nuPtXMGgMZZlaFUCjNAp6J7KaVb3lb9nDXxCw+W7uTczs15JHB3Wlep1qsyxKRMszcy85ZmPT0dM/MzIx1GTF1pLCIZ2av4M/v5VG9SgK/vTyVq3qoGZ2IfDMzW+Du6ccbpxvSy5Ds9bv4zcSFLNm0h0vTmvLAFV1pULNKrMsSkXJCgVAGHDxSyB/fWcazc1bSoGYV/nb9GVzUtUmsyxKRckaBEOc+WbmdEZOzWbVtH99Pb8ndl3ahdjU1oxORyFMgxKk9B4/w6PQl/OvjtbSsV41//6QP/ds3iHVZIlKOKRDi0KwlW7jntWw27j7Ij89qw/+7sCPVk/RPJSKlS39l4siOfYd56I1cXvs8nw6NajLpZ/3omVI31mWJSAWhQIgD7s4bWRsZlZHDrgNH+MV3OnDbee2oUln9h0QkehQIMbZ590HueW0R7yzeTFqL2vzrJ33o0rRWrMsSkQpIgRAj7s4r89fxyFuLOVy2vXzBAAAK30lEQVRQxN2XdObm/m2onBCpfoMiIidGgRADa7fvZ8TkLD5asZ0+berx6NA0WjeoEeuyRKSCUyBEUXEzusfeXkrlSpX43eDuXNOrJZUqqe2EiMSeAiFKlm7aw/BJWXyxbicDOzfikcHdaFpbzehEJH4oEErZ4YIi/jI7j6dn5ZFcNZEnrunBFac1UzM6EYk7CoRStHDdTu6amMXSzXu44rRm/PbyVOqrGZ2IxCkFQik4cLiQx2cu5e8frKJRclWeuyGd81Mbx7osEZFvpUCIsHkrtjNichZrtu/nB31SGHFxZ2pVVTM6EYl/CoQI2X3wCKPfWsJLn66lVf3qvHhLH/q1UzM6ESk7FAgR8E
7uZu6Zks3WPYcYdk5bfn1+R6olqe2EiJQtCoRTsH3vIR6YmkvGwg10apzM365Pp0fLOrEuS0TkpCgQToK7k7FwA6Mycth7qIBfn9+Rn53bjqTKajshImWXAuEEbdx1gHtfW8S7S7bQo2Udxl6dRsfGybEuS0TklCkQSqioyHlp/lpGv7WEgqIi7r20Cz/q34YEtZ0QkXJCgVACq7ftY8TkLD5euYN+7eozZkgaKfWrx7osEZGIUiB8i4LCIsZ/uIo/vL2MpIRKjBnSne/3aqm2EyJSLikQvsHijbsZPimLrPW7OL9LYx6+qhtNaleNdVkiIqWmRLfFmNkgM1tqZnlmNuIYy1uZ2btmlmVms82sRdiysWaWY2aLzexJC6luZm+a2ZJg2ZhIbtSpOFRQyOMzl3H5nz8g/8sD/Pna03n2hjMUBiJS7h33CMHMEoCngQuA9cB8M8tw99ywYY8BE9z9BTMbCIwGrjezfkB/IC0Y9wEwAPgUeMzdZ5lZEvCumV3s7tMitmUn4bO1XzJ8YhbLt+xl8OnNue+yVOrVSIplSSIiUVOSU0a9gTx3XwlgZi8DVwLhgZAK3BFMzwKmBNMOVAWSAAMSgc3uvj8Yh7sfNrPPgBbEyP7DBfzh7WWM/3AVTWpV5fmbenFe50axKkdEJCZKEgjNgXVhj9cDfY4asxAYAjwBDAaSzay+u88zs1nARkKB8JS7Lw5f0czqAJcH60bdh3nbGDE5i3U7DvDDM1MYPqgzyWpGJyIVUKQuKt8JPGVmNwFzgHyg0MzaA13476f/mWZ2trvPBTCzysBLwJPFRyBHM7NhwDCAlJSUCJULuw4c4XdvLuaVzHW0aVCDV4adSZ+29SP2/CIiZU1JAiEfaBn2uEUw7yvuvoHQEQJmVhMY6u47zewW4GN33xssmwb0BeYGq44Dlrv7n77pxd19XDCO9PR0L8lGHc/bOZu4d8oitu09xK0DQs3oqiaqGZ2IVGwluctoPtDBzNoEF4CvATLCB5hZAzMrfq6RwPhgei0wwMwqm1kioQvKi4N1HgZqA7869c0oma17DnHbi58x7J8LqFcjiSm39WfkxV0UBiIilOAIwd0LzOx2YAaQAIx39xwzexDIdPcM4FxgtJk5oVNGtwWrTwQGAtmELjBPd/epwW2p9wBLgM+CL3o95e7PRXTr/rsNTPkinwem5rL/UCF3XtiRWwe0IzFBzehERIqZe0TOwkRFenq6Z2ZmntA6RwqLGDYhk1lLt9IzJdSMrn0jNaMTkYrDzBa4e/rxxpX7byonJlSibcOanNOxITf0ba1mdCIi36DcBwLAfZelxroEEZG4p5PoIiICKBBERCSgQBAREUCBICIiAQWCiIgACgQREQkoEEREBFAgiIhIoEy1rjCzrcCaUnyJBsC2Unz+UxXP9am2k6PaTl481xdvtbVy94bHG1SmAqG0mVlmSfp9xEo816faTo5qO3nxXF881/ZtdMpIREQABYKIiAQUCF83LtYFHEc816faTo5qO3nxXF881/aNdA1BREQAHSGIiEigQgWCmbU0s1lmlmtmOWb2y2B+PTObaWbLg//WDeabmT1pZnlmlmVmPWNQ2+/NbEnw+q+ZWZ1gfmszO2BmXwQ/z8SgtlFmlh9WwyVh64wM9ttSM7soBrW9ElbXajP7Ipgftf0WvF5VM/vUzBYG9T0QzG9jZp8E++iV4P9XjplVCR7nBctbx6C2fwf/bovMbHzw/0PHzM41s11h++7+GNT2DzNbFVZDj2B+NH9Xv6m2uWF1bTCzKcH8qO23U+buFeYHaAr0DKaTgWVAKjAWGBHMHwE8GkxfAkwDDDgT+CQGtV0IVA7mPxpWW2tgUYz32yjgzmOMTwUWAlWANsAKICGatR015g/A/dHeb8HrGVAzmE4EPgneS68C1wTznwF+Fkz/HHgmmL4GeCUGtV0SLDPgpbDazgXeiPF++wdw9THGR/N39Zi1HTVmEnBDtPfbqf5UqCMEd9
/o7p8F03uAxUBz4ErghWDYC8BVwfSVwAQP+RioY2ZNo1mbu7/t7gXBsI+BFqXx+idT27esciXwsrsfcvdVQB7QOxa1mZkB3yP0hy3qgvfO3uBhYvDjwEBgYjD/6Pdc8XtxIvCdYBuiVpu7vxUsc+BTYvOe+6b99k2i+bv6rbWZWS1C/75TSuP1S1OFCoRwwaH46YTSvbG7bwwWbQIaB9PNgXVhq63n2/8QlkZt4W4m9CmoWBsz+9zM3jezs0u7rm+o7fbgEH28BafaiK/9djaw2d2Xh82L6n4zs4TglNUWYCahI6adYUEfvn++2nfB8l1A/WjV5u6fhC1LBK4Hpoet0jc4VTLNzLqWVl3Hqe2R4D33RzOrEsyL6nvu2/YboXB/1913h82L2n47FRUyEMysJqFDul8d9Y9G8KkoZrdefVNtZnYPUAD8O5i1EUhx99OBO4AXg08m0aztr0A7oEdQzx9K8/VPsLZi1/L1o4Oo7zd3L3T3HoQ+afcGOpfm652Io2szs25hi/8CzHH3ucHjzwi1QDgN+DOl/An4G2obSWj/9QLqAcNLs4YTrK3Y0e+5qO63U1HhAiH41DMJ+Le7Tw5mby4+vAz+uyWYnw+0DFu9RTAvmrVhZjcBlwHXBYFFcDpmezC9gNCnzo7RrM3dNwe/GEXAs/z3tFC87LfKwBDgleJ50d5v4dx9JzAL6EvolEblYFH4/vlq3wXLawPbo1jboOC1fws0JBSaxWN2F58qcfe3gEQzaxDN2oJThO7uh4DnidF77li1AQT7ozfwZtiYmOy3k1GhAiE4F/t3YLG7Px62KAO4MZi+EXg9bP4NwR0MZwK7wk4tRaU2MxsE3AVc4e77w+Y3NLOEYLot0AFYGeXaws/RDgYWBdMZwDUWumOmTVDbp9GsLXA+sMTd14eNj9p+C3u94jvDqgEXELrOMQu4Ohh29Huu+L14NfBe8YeAKNW2xMx+AlwEXBuEffH4JsXXM8ysN6G/H6USVt9SW/EHNyN0aib8PRet39Vj1hYsvprQBeSDYeOjtt9OmUf5KnYsf4CzCJ0OygK+CH4uIXSO9l1gOfAOUM//ezfB04Q+RWYD6TGoLY/QudHiecV3oAwFcoJ5nwGXx6C2fwb7JYvQL2TTsHXuCfbbUuDiaNcWLPsH8NOjxkdtvwWvlwZ8HtS3iP/e7dSWUEjmAf8BqgTzqwaP84LlbWNQW0Hwb1e8P4vn3x7su4WEbnDoF4Pa3gvec4uAf/Hfu32i+bt6zNqCZbMJHcmEj4/afjvVH31TWUREgAp2ykhERL6ZAkFERAAFgoiIBBQIIiICKBBERCSgQBAREUCBICIiAQWCiIgA8P8BqX9Iu2aZ+RQAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_metrics()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`callback`](/callback.html#callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_backward_begin[source]

\n", "\n", "> on_backward_begin(`smooth_loss`:`Tensor`, `kwargs`:`Any`)\n", "\n", "Record the loss before any other callback has a chance to modify it. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_backward_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_begin[source]

\n", "\n", "> on_batch_begin(`train`, `kwargs`:`Any`)\n", "\n", "Record learning rate and momentum at beginning of batch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_batch_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(`epoch`:`int`, `num_batch`:`int`, `smooth_loss`:`Tensor`, `last_metrics`=`'Collection'`, `kwargs`:`Any`) → `bool`\n", "\n", "Save epoch info: num_batch, smooth_loss, metrics. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(`pbar`:`PBar`, `metrics_names`:`StrList`, `kwargs`:`Any`)\n", "\n", "Initialize recording status at beginning of training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_train_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Module functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generally you'll want to use a [`Learner`](/basic_train.html#Learner) to train your model, since it provides a lot of functionality and makes things easier. However, for ultimate flexibility, you can call the same underlying functions that [`Learner`](/basic_train.html#Learner) calls behind the scenes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit[source]

\n", "\n", "> fit(`epochs`:`int`, `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `loss_func`:`LossFunction`, `opt`:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), `data`:[`DataBunch`](/basic_data.html#DataBunch), `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `metrics`:`OptMetrics`=`None`)\n", "\n", "Fit the `model` on `data` and learn using `loss` and `opt`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that you have to create the `Optimizer` yourself if you call this function, whereas [`Learner.fit`](/basic_train.html#Learner.fit) creates it for you automatically." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

train_epoch[source]

\n", "\n", "> train_epoch(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `opt`:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), `loss_func`:`LossFunction`)\n", "\n", "Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(train_epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You won't generally need to call this yourself - it's what [`fit`](/basic_train.html#fit) calls for each epoch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

validate[source]

\n", "\n", "> validate(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `loss_func`:`OptLossFunc`=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`, `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`, `average`=`True`, `n_batch`:`Optional`\\[`int`\\]=`None`) → `Iterator`\\[`Tuple`\\[`IntOrTensor`, `Ellipsis`\\]\\]\n", "\n", "Calculate loss and metrics for the validation set. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(validate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is what [`fit`](/basic_train.html#fit) calls after each epoch. You can also call it manually if you want to evaluate a model on a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_preds[source]

\n", "\n", "> get_preds(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`, `activ`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=`None`, `loss_func`:`OptLossFunc`=`None`, `n_batch`:`Optional`\\[`int`\\]=`None`) → `List`\\[`Tensor`\\]\n", "\n", "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(get_preds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

loss_batch[source]

\n", "\n", "> loss_batch(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `xb`:`Tensor`, `yb`:`Tensor`, `loss_func`:`OptLossFunc`=`None`, `opt`:`OptOptimizer`=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`) → `Tuple`\\[`Union`\\[`Tensor`, `int`, `float`, `str`\\]\\]\n", "\n", "Calculate loss and metrics for a batch, call out to callbacks as necessary. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(loss_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You won't generally need to call this yourself - it's what [`fit`](/basic_train.html#fit) and [`validate`](/basic_train.html#validate) call for each batch. It only does a backward pass if you set `opt`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class LearnerCallback[source]

\n", "\n", "> LearnerCallback(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`Callback`](/callback.html#Callback)\n", "\n", "Base class for creating callbacks for a [`Learner`](/basic_train.html#Learner). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LearnerCallback, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class RecordOnCPU[source]

\n", "\n", "> RecordOnCPU() :: [`Callback`](/callback.html#Callback)\n", "\n", "Stores the `input` and `target` going through the model on the CPU. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(RecordOnCPU, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_tta_only[source]

\n", "\n", "> _tta_only(`learn`:[`Learner`](/basic_train.html#Learner), `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `scale`:`float`=`1.35`) → `Iterator`\\[`List`\\[`Tensor`\\]\\]\n", "\n", "Computes the outputs for several augmented inputs for TTA " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.tta_only)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_preds[source]

\n", "\n", "> get_preds(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `with_loss`:`bool`=`False`, `n_batch`:`Optional`\\[`int`\\]=`None`, `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`) → `List`\\[`Tensor`\\]\n", "\n", "Return predictions and targets on the valid, train, or test set, depending on `ds_type`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.get_preds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_TTA[source]

\n", "\n", "> _TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `with_loss`:`bool`=`False`) → `Tensors`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.TTA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

format_stats[source]

\n", "\n", "> format_stats(`stats`:`MetricsList`)\n", "\n", "Format stats before printing. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.format_stats)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

add_metrics[source]

\n", "\n", "> add_metrics(`metrics`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.add_metrics)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

add_metric_names[source]

\n", "\n", "> add_metric_names(`names`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.add_metric_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

on_batch_begin[source]

\n", "\n", "> on_batch_begin(`last_input`, `last_target`, `kwargs`)\n", "\n", "Set HP before the step is done. Returns xb, yb (which can allow us to modify the input at that step if needed). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(RecordOnCPU.on_batch_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Learner class and training loop", "title": "basic_train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }