{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. This is where the basic training loop is defined for the [`fit`](/basic_train.html#fit) function. The [`Learner`](/basic_train.html#Learner) object is the entry point for most of the [`Callback`](/callback.html#Callback) functions that customize this training loop in different ways. Many of them are made available as methods through the [`train`](/train.html#train) module, notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that will help you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will launch training with the 1cycle policy, to help you train your model fast.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision so you can train in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class
Learner
Learner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `model`:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), `opt_func`:`Callable`=`'Adam'`, `loss_func`:`Callable`=`None`, `metrics`:`Collection`\\[`Callable`\\]=`None`, `true_wd`:`bool`=`True`, `bn_wd`:`bool`=`True`, `wd`:`Floats`=`0.01`, `train_bn`:`bool`=`True`, `path`:`str`=`None`, `model_dir`:`str`=`'models'`, `callback_fns`:`Collection`\\[`Callable`\\]=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`
epoch | train_loss | valid_loss | accuracy
---|---|---|---
1 | 0.142597 | 0.085823 | 0.968106
fit
fit(`epochs`:`int`, `lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `wd`:`Floats`=`None`, `callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`None`)\n",
"\n",
"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`. "
],
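The `fit` docstring above summarizes the basic loop. The block below is a framework-free sketch of that idea, not fastai's implementation: it fits a one-weight linear model with per-sample SGD, applies decoupled weight decay (`wd`) after each step, and calls an `on_epoch_end` hook on every callback. `ToyRecorder` and `fit_sketch` are hypothetical names chosen for illustration.

```python
# Framework-free sketch of a training loop with callbacks and weight decay.
# ToyRecorder and fit_sketch are hypothetical names, not fastai APIs.

class ToyRecorder:
    'Collects the mean training loss of every epoch.'
    def __init__(self):
        self.losses = []

    def on_epoch_end(self, epoch, loss):
        self.losses.append(loss)


def fit_sketch(xs, ys, epochs, lr, wd=0.0, callbacks=()):
    w = 0.0  # single weight of the model y = w * x
    for epoch in range(epochs):
        total = 0.0
        for x, y in zip(xs, ys):
            pred = w * x
            loss = (pred - y) ** 2
            grad = 2 * (pred - y) * x
            w -= lr * grad          # plain SGD step on the squared error
            w -= lr * wd * w        # decoupled weight decay: shrink the weight
            total += loss
        for cb in callbacks:
            cb.on_epoch_end(epoch, total / len(xs))
    return w


xs = [0.5, 1.0, 1.5, 2.0]
ys = [2 * x for x in xs]  # data generated by y = 2x
rec = ToyRecorder()
w = fit_sketch(xs, ys, epochs=20, lr=0.05, wd=1e-4, callbacks=[rec])
print(round(w, 3), len(rec.losses))
```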
"text/plain": [
"fit_one_cycle
fit_one_cycle(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n",
"\n",
"Fit a model following the 1cycle policy. "
],
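To make `moms`, `div_factor` and `pct_start` concrete, here is a sketch of a 1cycle schedule computed per iteration (`cyc_len` is in epochs, so the iteration count would be epochs times batches per epoch). It assumes cosine interpolation in both phases and a final learning rate of `max_lr / div_factor / 1e4`; fastai's exact annealing functions may differ between versions, and `one_cycle_sketch` is a hypothetical helper.

```python
import math

def cos_anneal(start, end, pct):
    'Cosine interpolation from start (pct=0) to end (pct=1).'
    return end + (start - end) * (1 + math.cos(math.pi * pct)) / 2

def one_cycle_sketch(n_iter, max_lr, moms=(0.95, 0.85), div_factor=25.0, pct_start=0.3):
    # Hypothetical helper: returns per-iteration learning rates and momentums.
    low_lr = max_lr / div_factor
    n_up = int(n_iter * pct_start)   # iterations spent warming up
    n_down = n_iter - n_up           # iterations spent annealing
    lrs, mom_vals = [], []
    for i in range(n_iter):
        if i < n_up:
            # Phase 1: lr rises to max_lr while momentum falls.
            pct = i / n_up
            lrs.append(cos_anneal(low_lr, max_lr, pct))
            mom_vals.append(cos_anneal(moms[0], moms[1], pct))
        else:
            # Phase 2: lr decays far below low_lr while momentum rises back.
            pct = (i - n_up) / n_down
            lrs.append(cos_anneal(max_lr, low_lr / 1e4, pct))
            mom_vals.append(cos_anneal(moms[1], moms[0], pct))
    return lrs, mom_vals

lrs, mom_vals = one_cycle_sketch(n_iter=100, max_lr=3e-3)
print(lrs[0], max(lrs), lrs[-1])
print(mom_vals[0], min(mom_vals), mom_vals[-1])
```

Moving momentum opposite to the learning rate is part of the 1cycle policy as described by its author; the default `moms=(0.95, 0.85)` above is taken from the signature.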
"text/plain": [
"lr_find
lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. "
],
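The range test described above can be sketched without a framework: learning rates grow exponentially from `start_lr` to `end_lr` over `num_it` steps, the loss is smoothed with a bias-corrected moving average, and with `stop_div` the run stops once the smoothed loss exceeds four times the best value seen. The factor of four, the smoothing constant `beta=0.98`, and the `loss_at` callable (standing in for one real training step) are assumptions of this sketch, not documented fastai parameters.

```python
import math

def lr_schedule(start_lr=1e-7, end_lr=10, num_it=100):
    'Exponentially spaced learning rates from start_lr to end_lr (inclusive).'
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def lr_find_sketch(loss_at, start_lr=1e-7, end_lr=10, num_it=100, stop_div=True, beta=0.98):
    'loss_at(lr) is a hypothetical stand-in for one training step at learning rate lr.'
    lrs, losses = [], []
    avg, best = 0.0, math.inf
    for step, lr in enumerate(lr_schedule(start_lr, end_lr, num_it), start=1):
        avg = beta * avg + (1 - beta) * loss_at(lr)
        smooth = avg / (1 - beta ** step)   # bias-corrected moving average
        lrs.append(lr)
        losses.append(smooth)
        best = min(best, smooth)
        # Assumed divergence rule: stop once the smoothed loss is 4x the best so far.
        if stop_div and (smooth > 4 * best or math.isnan(smooth)):
            break
    return lrs, losses

def toy_loss(lr):
    'Made-up loss curve: improves as lr grows, then explodes above lr = 1.'
    penalty = 50 * (lr - 1) ** 2 if lr > 1 else 0.0
    return 1.0 / (1 + 100 * lr) + penalty

lrs, losses = lr_find_sketch(toy_loss)
print(len(lrs), lrs[-1])
```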
"text/plain": [
"get_preds
get_preds(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`
validate
validate(`dl`=`None`, `callbacks`=`None`, `metrics`=`None`)\n",
"\n",
"Validate on `dl` with potential `callbacks` and `metrics`. "
],
"text/plain": [
"show_results
show_results(`ds_type`=`
predict
predict(`img`:[`ItemBase`](/core.html#ItemBase), `kwargs`)\n",
"\n",
"Return predicted class, label and probabilities for `img`. "
],
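The last step of a prediction like the one described for `predict` can be sketched in plain Python: apply a softmax to the raw model outputs, take the index of the largest probability, and look up its class name. The logit values and the `classes` list here are made up, and `predict_sketch` is a hypothetical helper that only mirrors the (class, label, probabilities) triple from the docstring; it does not run a model.

```python
import math

def predict_sketch(logits, classes):
    'Turn raw model outputs into (class name, class index, probabilities).'
    # Softmax, shifted by the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Index of the most probable class.
    idx = max(range(len(probs)), key=probs.__getitem__)
    return classes[idx], idx, probs

label, idx, probs = predict_sketch([0.2, 2.5, -1.0], classes=['cat', 'dog', 'bird'])
print(label, idx, [round(p, 3) for p in probs])
```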
"text/plain": [
"pred_batch
pred_batch(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`
TTA
TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`
clip_grad
clip_grad(`learn`:[`Learner`](/basic_train.html#Learner), `clip`:`float`=`0.1`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Gradient clipping during training. "
],
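We assume `clip_grad` clips by the global gradient norm (through PyTorch's `torch.nn.utils.clip_grad_norm_`), with `clip` as the maximum norm. The pure-Python sketch below reproduces that rule on plain lists: if the L2 norm over all gradients exceeds `max_norm`, every entry is scaled by the same factor. `clip_grad_norm_sketch` is a hypothetical name.

```python
import math

def clip_grad_norm_sketch(grads, max_norm, eps=1e-6):
    'Rescale gradient vectors so their global L2 norm is at most max_norm.'
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = max_norm / (total_norm + eps)
    if scale < 1:
        # One shared factor keeps the gradient direction unchanged.
        grads = [[g * scale for g in vec] for vec in grads]
    return grads, total_norm

grads = [[3.0, 0.0], [0.0, 4.0]]   # global norm is 5.0
clipped, norm_before = clip_grad_norm_sketch(grads, max_norm=0.1)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))
print(norm_before, round(norm_after, 6))
```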
"text/plain": [
"to_fp16
to_fp16(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Transform `learn` to FP16 precision. "
],
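The `loss_scale` parameter of `to_fp16` addresses gradient underflow: very small gradients round to zero in half precision. The NumPy sketch below (an illustration of static loss scaling, not fastai's code) shows a gradient of `1e-8` vanishing in `float16`, surviving once it is multiplied by the default scale of 512, and being recovered in `float32` after the scale is divided back out, which is the role of the FP32 master copy of the weights.

```python
import numpy as np

loss_scale = 512.0      # default shown in the to_fp16 signature
true_grad = 1e-8        # a gradient too small for float16

# Without scaling, the gradient underflows to zero in half precision.
unscaled = np.float16(true_grad)

# With loss scaling, backprop yields gradients multiplied by loss_scale,
# which stay representable in float16; they are then copied to float32
# master weights and divided by loss_scale before the optimizer step.
scaled = np.float16(true_grad * loss_scale)
recovered = np.float32(scaled) / np.float32(loss_scale)

print(unscaled, scaled, recovered)
```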
"text/plain": [
"lr_range
lr_range(`lr`:`Union`\\[`float`, `slice`\\]) → `ndarray`\n",
"\n",
"Build differential learning rates. "
],
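Differential learning rates give each layer group its own rate. As we understand `lr_range`, a float is used unchanged, a `slice(start, stop)` is spread geometrically from `start` (first group) to `stop` (last group), and a `slice(None, stop)` gives `stop` to the last group and `stop / 3` to the others; treat these rules as assumptions. `lr_range_sketch` and its `n_groups` argument are hypothetical (the real method takes the number of groups from the learner).

```python
import numpy as np

def lr_range_sketch(lr, n_groups):
    'Spread a learning rate over n_groups layer groups (hypothetical helper).'
    if not isinstance(lr, slice):
        return lr                      # a single float is used for every group
    if lr.start:
        # Geometric progression from lr.start (first group) to lr.stop (last group).
        return np.geomspace(lr.start, lr.stop, n_groups)
    # Only a stop value: the last group gets lr.stop, earlier groups lr.stop / 3.
    return np.array([lr.stop / 3] * (n_groups - 1) + [lr.stop])

print(lr_range_sketch(slice(1e-5, 1e-3), 3))
print(lr_range_sketch(slice(None, 3e-3), 3))
```

The second call uses the same shape of value as the `fit` default, `slice(None, 0.003, None)`.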
"text/plain": [
"unfreeze
unfreeze()\n",
"\n",
"Unfreeze entire model. "
],
"text/plain": [
"freeze
freeze()\n",
"\n",
"Freeze up to last layer. "
],
"text/plain": [
"freeze_to
freeze_to(`n`:`int`)\n",
"\n",
"Freeze layers up to layer `n`. "
],
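`freeze`, `unfreeze` and `freeze_to` can be read as one rule applied to the model's layer groups: groups before index `n` stop training and the rest keep training. We assume `freeze` is equivalent to `n = -1` and `unfreeze` to `n = 0`. The toy class below sketches that rule with a plain flag per group; `ToyLearner` is hypothetical, and it ignores details such as BatchNorm layers staying trainable when `train_bn=True`.

```python
class ToyLearner:
    'Hypothetical stand-in: each layer group is a name plus a trainable flag.'
    def __init__(self, group_names):
        self.groups = [{'name': n, 'trainable': True} for n in group_names]

    def freeze_to(self, n):
        # Groups before index n are frozen, the rest are trainable.
        for g in self.groups[:n]:
            g['trainable'] = False
        for g in self.groups[n:]:
            g['trainable'] = True

    def freeze(self):
        self.freeze_to(-1)   # everything except the last group

    def unfreeze(self):
        self.freeze_to(0)    # every group trainable

    def trainable(self):
        return [g['name'] for g in self.groups if g['trainable']]


learn = ToyLearner(['body_early', 'body_late', 'head'])
learn.freeze()
print(learn.trainable())
```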
"text/plain": [
"split
split(`split_on`:`SplitFuncOrIdxList`)\n",
"\n",
"Split the model at `split_on`. "
],
"text/plain": [
"load
load(`name`:`PathOrStr`, `device`:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=`None`, `strict`:`bool`=`True`)\n",
"\n",
"Load model `name` from `self.model_dir` using `device`, defaulting to `self.data.device`. "
],
"text/plain": [
"save
save(`name`:`PathOrStr`, `return_path`:`bool`=`False`) → `Union`\\[`NoneType`, `str`\\]\n",
"\n",
"Save model with `name` to `self.model_dir`, and return path if `return_path`. "
],
"text/plain": [
"unet_learner
unet_learner(`data`:[`DataBunch`](/basic_data.html#DataBunch), `arch`:`Callable`, `pretrained`:`bool`=`True`, `all_wn`:`bool`=`False`, `blur_final`:`bool`=`True`, `split_on`:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=`None`, `blur`:`bool`=`False`, `kwargs`:`Any`)"
],
"text/plain": [
"init
init(`init`)"
],
"text/plain": [
"mixup
mixup(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Add mixup https://arxiv.org/abs/1710.09412 to `learn`. "
],
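The mixup paper linked above trains on convex combinations of pairs of examples: with a coefficient λ drawn from Beta(α, α), inputs become λ·x_i + (1 - λ)·x_j and targets are mixed the same way. The NumPy sketch below applies this to one batch, pairing each sample with a random partner from the same batch and using `alpha=0.4` as in the `mixup` signature. It mixes one-hot targets; an implementation may instead mix the two losses, which gives the same result for cross-entropy because cross-entropy is linear in the target. `mixup_batch` is a hypothetical helper, not fastai's callback.

```python
import numpy as np

def mixup_batch(x, y_onehot, alpha=0.4, rng=None):
    'Mix a batch with a shuffled copy of itself, following the mixup paper.'
    rng = np.random.default_rng() if rng is None else rng
    n = x.shape[0]
    lam = rng.beta(alpha, alpha, size=n)       # one mixing coefficient per sample
    perm = rng.permutation(n)                  # random partner for each sample
    # Reshape lam so it broadcasts over every non-batch dimension of x.
    lam_x = lam.reshape(-1, *([1] * (x.ndim - 1)))
    x_mix = lam_x * x + (1 - lam_x) * x[perm]
    y_mix = lam[:, None] * y_onehot + (1 - lam[:, None]) * y_onehot[perm]
    return x_mix, y_mix, lam, perm

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
y = np.eye(2)[[0, 1, 1, 0]]                    # one-hot labels for 4 samples
x_mix, y_mix, lam, perm = mixup_batch(x, y, alpha=0.4, rng=rng)
print(y_mix.sum(axis=1))
```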
"text/plain": [
"backward
backward(`item`)\n",
"\n",
"Pass `item` through the model and compute the gradient. Useful if `backward_hooks` are attached. "
],
"text/plain": [
"create_opt
create_opt(`lr`:`Floats`, `wd`:`Floats`=`0.0`)\n",
"\n",
"Create optimizer with `lr` learning rate and `wd` weight decay. "
],
"text/plain": [
"dl
dl(`ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=`
class
Recorder
Recorder(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"A [`LearnerCallback`](/basic_train.html#LearnerCallback) that records epoch, loss, opt and metric data during training. "
],
"text/plain": [
"plot
plot(`skip_start`:`int`=`10`, `skip_end`:`int`=`5`)\n",
"\n",
"Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. "
],
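Trimming matters in `plot` because the first points of an LR range test tend to be noisy and the last ones usually show the loss blowing up. The sketch below assumes `skip_start` and `skip_end` count points dropped from the beginning and the end of the recorded series; `trim_for_plot` is hypothetical, and the plotting itself (for example with matplotlib) is left out.

```python
def trim_for_plot(lrs, losses, skip_start=10, skip_end=5):
    'Drop the first skip_start and the last skip_end points.'
    # len(...) - skip_end instead of -skip_end, so skip_end=0 keeps everything.
    end = len(lrs) - skip_end
    return lrs[skip_start:end], losses[skip_start:end]

lrs = [1e-7 * 1.2 ** i for i in range(100)]
losses = [float(i) for i in range(100)]
lrs_kept, losses_kept = trim_for_plot(lrs, losses)
print(len(lrs_kept), losses_kept[0], losses_kept[-1])
```

Slicing with `-skip_end` directly would return an empty list when `skip_end=0`, which is why the end index is computed from the length.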
"text/plain": [
"