{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. This is where the basic training loop for the [`fit`](/basic_train.html#fit) function is defined. The [`Learner`](/basic_train.html#Learner) object is also the entry point for most of the [`Callback`](/callback.html#Callback) functions that customize this training loop in different ways (made available through the [`train`](/train.html#train) module), notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that helps you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will launch training using the 1cycle policy, to help you train your model quickly.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision so you can train in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class Learner[source]Learner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `opt_func`:`Callable`=`'Adam'`, `loss_func`:`Callable`=`None`, `metrics`:`Collection`\\[`Callable`\\]=`None`, `true_wd`:`bool`=`True`, `bn_wd`:`bool`=`True`, `wd`:`Floats`=`0.01`, `train_bn`:`bool`=`True`, `path`:`str`=`None`, `model_dir`:`str`=`'models'`, `callback_fns`:`Collection`\\[`Callable`\\]=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`fit[source]fit(`epochs`:`int`, `lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `wd`:`Floats`=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`None`)\n",
"\n",
"Fit the model on this learner for `epochs` epochs with learning rate `lr`, weight decay `wd` and `callbacks`. "
],
"text/plain": [
"fit_one_cycle(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n",
"\n",
"Fit a model following the 1cycle policy. "
],
"text/plain": [
"lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore learning rates from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stop when the loss diverges. "
],
"text/plain": [
"get_preds[source]get_preds(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`validate[source]validate(`dl`=`None`, `callbacks`=`None`, `metrics`=`None`)\n",
"\n",
"Validate on `dl` with potential `callbacks` and `metrics`. "
],
"text/plain": [
"TTA[source]TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`to_fp16[source]to_fp16(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Convert `learn` to FP16 precision. "
],
"text/plain": [
"lr_range(`lr`:`Union`\\[`float`, `slice`\\]) → `ndarray`\n",
"\n",
"Build differential learning rates. "
],
"text/plain": [
"unfreeze()\n",
"\n",
"Unfreeze the entire model. "
],
"text/plain": [
"freeze()\n",
"\n",
"Freeze up to the last layer. "
],
"text/plain": [
"freeze_to(`n`:`int`)\n",
"\n",
"Freeze layers up to layer `n`. "
],
"text/plain": [
"split(`split_on`:`SplitFuncOrIdxList`)\n",
"\n",
"Split the model at `split_on`. "
],
"text/plain": [
"load(`name`:`PathOrStr`, `device`:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=`None`)\n",
"\n",
"Load model `name` from `self.model_dir` using `device`, defaulting to `self.data.device`. "
],
"text/plain": [
"save(`name`:`PathOrStr`, `return_path`:`bool`=`False`) → `Union`\\[`NoneType`, `str`\\]\n",
"\n",
"Save model with `name` to `self.model_dir`, and return path if `return_path`. "
],
"text/plain": [
"show_results[source]show_results(`ds_type`=`predict[source]predict(`img`:[`ItemBase`](/core.html#ItemBase), `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`, `kwargs`)\n",
"\n",
"Return predicted class, label and probabilities for `img`. "
],
"text/plain": [
"validate(`dl`=`None`, `callbacks`=`None`, `metrics`=`None`)\n",
"\n",
"Validate on `dl` with potential `callbacks` and `metrics`. "
],
"text/plain": [
"Learner_create_unet(`data`:[`DataBunch`](/basic_data.html#DataBunch), `arch`:`Callable`, `pretrained`:`bool`=`True`, `split_on`:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=`None`, `kwargs`:`Any`)"
],
"text/plain": [
"init(`init`)"
],
"text/plain": [
"mixup(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Add mixup https://arxiv.org/abs/1710.09412 to `learn`. "
],
"text/plain": [
"pred_batch[source]pred_batch(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`create_opt[source]create_opt(`lr`:`Floats`, `wd`:`Floats`=`0.0`)\n",
"\n",
"Create optimizer with `lr` learning rate and `wd` weight decay. "
],
"text/plain": [
"dl[source]dl(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`class Recorder[source]Recorder(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"A [`LearnerCallback`](/basic_train.html#LearnerCallback) that records epoch, loss, opt and metric data during training. "
],
"text/plain": [
"plot(`skip_start`:`int`=`10`, `skip_end`:`int`=`5`)\n",
"\n",
"Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. "
],
"text/plain": [
"plot_losses(`last`:`int`=`None`)\n",
"\n",
"Plot training and validation losses. "
],
"text/plain": [
"plot_lr(`show_moms`=`False`)\n",
"\n",
"Plot the learning rate; set `show_moms` to also plot momentum. "
],
"text/plain": [
"plot_metrics()\n",
"\n",
"Plot metrics collected during training. "
],
"text/plain": [
"on_backward_begin(`smooth_loss`:`Tensor`, `kwargs`:`Any`)\n",
"\n",
"Record the loss before any other callback has a chance to modify it. "
],
"text/plain": [
"on_batch_begin(`train`, `kwargs`:`Any`)\n",
"\n",
"Record learning rate and momentum at beginning of batch. "
],
"text/plain": [
"on_epoch_end(`epoch`:`int`, `num_batch`:`int`, `smooth_loss`:`Tensor`, `last_metrics`=`'Collection'`, `kwargs`:`Any`) → `bool`\n",
"\n",
"Save epoch info: num_batch, smooth_loss, metrics. "
],
"text/plain": [
"on_train_begin(`pbar`:`PBar`, `metrics_names`:`StrList`, `kwargs`:`Any`)\n",
"\n",
"Initialize recording status at beginning of training. "
],
"text/plain": [
"fit(`epochs`:`int`, `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `loss_func`:`LossFunction`, `opt`:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), `data`:[`DataBunch`](/basic_data.html#DataBunch), `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `metrics`:`OptMetrics`=`None`)\n",
"\n",
"Fit the `model` on `data` and learn using `loss_func` and `opt`. "
],
"text/plain": [
"train_epoch(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `opt`:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), `loss_func`:`LossFunction`)\n",
"\n",
"Simple training of `model` for one epoch of `dl` using optimizer `opt` and loss function `loss_func`. "
],
"text/plain": [
"validate(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `loss_func`:`OptLossFunc`=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`, `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`, `average`=`True`, `n_batch`:`Optional`\\[`int`\\]=`None`) → `Iterator`\\[`Tuple`\\[`IntOrTensor`, `Ellipsis`\\]\\]\n",
"\n",
"Calculate loss and metrics for the validation set. "
],
"text/plain": [
"get_preds(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `dl`:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), `pbar`:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`, `activ`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=`None`, `loss_func`:`OptLossFunc`=`None`, `n_batch`:`Optional`\\[`int`\\]=`None`) → `List`\\[`Tensor`\\]\n",
"\n",
"Return predictions and targets, plus losses if `loss_func` is given, computed on `dl` for at most `n_batch` batches. "
],
"text/plain": [
"loss_batch(`model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `xb`:`Tensor`, `yb`:`Tensor`, `loss_func`:`OptLossFunc`=`None`, `opt`:`OptOptimizer`=`None`, `cb_handler`:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=`None`) → `Tuple`\\[`Union`\\[`Tensor`, `int`, `float`, `str`\\]\\]\n",
"\n",
"Calculate loss and metrics for a batch, call out to callbacks as necessary. "
],
"text/plain": [
"class LearnerCallback(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`Callback`](/callback.html#Callback)\n",
"\n",
"Base class for creating callbacks for a [`Learner`](/basic_train.html#Learner). "
],
"text/plain": [
"_tta_only[source]_tta_only(`learn`:[`Learner`](/basic_train.html#Learner), `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`get_preds[source]get_preds(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`_TTA[source]_TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`format_stats[source]format_stats(`stats`:`MetricsList`)\n",
"\n",
"Format stats before printing. "
],
"text/plain": [
"add_metrics(`metrics`)"
],
"text/plain": [
"add_metric_names(`names`)"
],
"text/plain": [
"