{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic training functionality" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.distributed import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. It also defines the basic training loop used by the [`fit`](/basic_train.html#fit) method. The [`Learner`](/basic_train.html#Learner) object is the entry point for most of the [`Callback`](/callback.html#Callback) objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the [`train`](/train.html#train) module, notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) launches an LR range test that helps you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) trains with the 1cycle policy, which helps your model train faster.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) converts your model to half precision so you can train in mixed precision." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
## class `Learner`

> `Learner`(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\[`Callable`\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`PathOrStr`=***`'models'`***, **`callback_fns`**:`Collection`\[`Callable`\]=***`None`***, **`callbacks`**:`Collection`\[[`Callback`](/callback.html#Callback)\]=…

#### `lr_find`

> `lr_find`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)
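The `lr_find` parameters suggest an exponential sweep of the learning rate from `start_lr` to `end_lr` over `num_it` iterations, stopping early once the loss diverges when `stop_div` is set. Below is a minimal, framework-free sketch built on those assumptions. `lr_schedule`, `lr_range_test` and the `loss_at` callback are hypothetical names, and the divergence rule (smoothed loss above 4 times the best value seen) is an assumed rule, not a quote of fastai's code.

```python
import math

def lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates spaced evenly on a log scale from start_lr to end_lr (inclusive)."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def lr_range_test(loss_at, start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True, beta=0.98):
    """Toy LR range test: `loss_at(lr)` stands in for one training step at that LR.

    Records a bias-corrected exponential moving average of the loss and, if
    `stop_div` is True, stops once it exceeds 4x the best value seen (assumed rule).
    """
    lrs, losses = [], []
    avg, best = 0.0, float("inf")
    for step, lr in enumerate(lr_schedule(start_lr, end_lr, num_it), start=1):
        avg = beta * avg + (1 - beta) * loss_at(lr)
        smooth = avg / (1 - beta ** step)  # correct the bias toward the initial 0.0
        lrs.append(lr)
        losses.append(smooth)
        best = min(best, smooth)
        if stop_div and (smooth > 4 * best or math.isnan(smooth)):
            break
    return lrs, losses

# Hypothetical loss curve: flat for small LRs, blowing up once the LR gets large.
lrs, losses = lr_range_test(lambda lr: 1.0 + lr ** 2)
```

With this toy curve the sweep stops a few steps before reaching `end_lr`, because the smoothed loss overshoots 4 times its minimum.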
#### `fit`

> `fit`(**`epochs`**:`int`, **`lr`**:`Union`\[`float`, `Collection`\[`float`\], `slice`\]=***`slice(None, 0.003, None)`***, **`wd`**:`Floats`=***`None`***, **`callbacks`**:`Collection`\[[`Callback`](/callback.html#Callback)\]=***`None`***)

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.135343 | 0.083190 | 0.972031 | 00:05 |
#### `fit_one_cycle`

> `fit_one_cycle`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\[`float`, `Collection`\[`float`\], `slice`\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\[`Collection`\[[`Callback`](/callback.html#Callback)\]\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)

Tests found for `fit_one_cycle`:

- `pytest -sv tests/test_train.py::test_fit_one_cycle`

Some other tests where `fit_one_cycle` is used:

- `pytest -sv tests/test_tabular_train.py::test_empty_cont`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split`

To run tests please refer to this guide.

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.075838 | 0.061869 | 0.979882 | 00:05 |
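The `fit_one_cycle` parameters (`max_lr`, `div_factor`, `pct_start`, `final_div`, `moms`) describe a two-phase schedule. This minimal sketch assumes:

- cosine annealing in both phases;
- a starting rate of `max_lr/div_factor`;
- `final_div` defaulting to `div_factor*1e4` when it is `None` (an assumption about fastai v1's default).

Over the first `pct_start` fraction of iterations, the learning rate warms up to `max_lr` while momentum moves from `moms[0]` to `moms[1]`. Both then run back the other way. The helpers `annealing_cos` and `one_cycle_schedule` are written from scratch for this sketch.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation: returns `start` at pct=0 and `end` at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_schedule(n_iter, max_lr, moms=(0.95, 0.85), div_factor=25.0,
                       pct_start=0.3, final_div=None):
    """Return (lrs, moms) lists of length n_iter for a 1cycle-style schedule."""
    if final_div is None:
        final_div = div_factor * 1e4  # assumed default, see lead-in
    n_up = int(n_iter * pct_start)
    n_down = n_iter - n_up
    lrs, ms = [], []
    # Phase 1: LR rises from max_lr/div_factor to max_lr, momentum falls.
    for i in range(n_up):
        pct = i / n_up
        lrs.append(annealing_cos(max_lr / div_factor, max_lr, pct))
        ms.append(annealing_cos(moms[0], moms[1], pct))
    # Phase 2: LR decays from max_lr to max_lr/final_div, momentum rises back.
    for i in range(n_down):
        pct = i / max(n_down - 1, 1)
        lrs.append(annealing_cos(max_lr, max_lr / final_div, pct))
        ms.append(annealing_cos(moms[1], moms[0], pct))
    return lrs, ms

lrs, ms = one_cycle_schedule(100, max_lr=3e-3)
```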
#### `predict`

> `predict`(**`item`**:[`ItemBase`](/core.html#ItemBase), **`return_x`**:`bool`=***`False`***, **`batch_first`**:`bool`=***`True`***, **`with_dropout`**:`bool`=***`False`***, **\*\*`kwargs`**)

#### `get_preds`

> `get_preds`(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…

#### `validate`

> `validate`(**`dl`**=***`None`***, **`callbacks`**=***`None`***, **`metrics`**=***`None`***)

#### `show_results`

> `show_results`(**`ds_type`**=…

#### `pred_batch`

> `pred_batch`(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…

#### `interpret`

> `interpret`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…

#### `model_summary`

> `model_summary`(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)

Tests found for `model_summary`:

- `pytest -sv tests/test_basic_train.py::test_export_load_learner`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text`
- `pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision`

To run tests please refer to this guide.
#### `TTA`

> `TTA`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…
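The `beta` parameter of `TTA` suggests a weighted blend of two things: the ordinary predictions, and the average of predictions on augmented copies of the inputs. `scale` presumably controls the zoom of those augmentations. This framework-free sketch covers only the blend, assuming `final = beta * preds + (1 - beta) * mean(augmented preds)`. `mean_preds` and `tta_blend` are hypothetical helpers, and predictions are plain nested lists with one row of class probabilities per item.

```python
def mean_preds(pred_sets):
    """Element-wise mean over several prediction sets (each: a list of per-item rows)."""
    n = len(pred_sets)
    return [[sum(vals) / n for vals in zip(*rows)] for rows in zip(*pred_sets)]

def tta_blend(preds, aug_pred_sets, beta=0.4):
    """Assumed TTA rule: beta * plain predictions + (1 - beta) * mean of augmented ones."""
    avg = mean_preds(aug_pred_sets)
    return [[beta * p + (1 - beta) * a for p, a in zip(prow, arow)]
            for prow, arow in zip(preds, avg)]

plain = [[1.0, 0.0]]                       # one item, two classes
augmented = [[[0.0, 1.0]], [[0.5, 0.5]]]   # two augmented passes over the same item
blended = tta_blend(plain, augmented)      # approximately [[0.55, 0.45]]
```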
#### `clip_grad`

> `clip_grad`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.1`***) → [`Learner`](/basic_train.html#Learner)
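This sketch assumes `clip_grad` clips the global gradient norm to `clip`, as PyTorch's `torch.nn.utils.clip_grad_norm_` does. Whether fastai clips the norm rather than individual values is an assumption here. Under that assumption, the rule is:

1. Compute the L2 norm over all gradients.
2. If it exceeds the threshold, scale every gradient by the same factor so the update keeps its direction.

`clip_grad_norm` is a hypothetical helper operating on a flat list of floats.

```python
import math

def clip_grad_norm(grads, max_norm=0.1, eps=1e-6):
    """Rescale a flat list of gradients so their global L2 norm is at most `max_norm`.

    Returns the (possibly rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + eps)  # eps guards against division by zero
    if coef < 1:
        return [g * coef for g in grads], total_norm
    return list(grads), total_norm

clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)  # norm 5.0, rescaled to about 1.0
```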
#### `to_fp16`

> `to_fp16`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***, **`loss_fp32`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)
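With `dynamic=True`, the `loss_scale`, `max_noskip` and `max_scale` parameters point to dynamic loss scaling. The loss is multiplied by a scale factor before the backward pass so that small fp16 gradients don't underflow. The gradients are divided by the same factor before the optimizer step, and the factor adapts when overflows occur. This minimal sketch of a common policy assumes:

- the scale starts at 2**16 when `loss_scale` is `None`;
- any inf/NaN gradient halves the scale and skips the step;
- after `max_noskip` consecutive clean steps the scale doubles, without exceeding `max_scale` (whose default, 16777216, is 2**24).

`DynamicLossScaler` and `has_overflow` are hypothetical names, and fastai's mixed-precision callback may differ in detail.

```python
import math

def has_overflow(grads):
    """True if any gradient is inf or NaN."""
    return any(math.isinf(g) or math.isnan(g) for g in grads)

class DynamicLossScaler:
    def __init__(self, loss_scale=None, max_noskip=1000, max_scale=2.0 ** 24):
        self.scale = 2.0 ** 16 if loss_scale is None else loss_scale  # assumed default
        self.max_noskip, self.max_scale = max_noskip, max_scale
        self.noskip = 0  # consecutive steps without overflow

    def unscale(self, grads):
        """Undo the scaling that was applied to the loss before the backward pass."""
        return [g / self.scale for g in grads]

    def update(self, grads):
        """Adjust the scale after a backward pass; return True if the step should be taken."""
        if has_overflow(grads):
            self.scale /= 2
            self.noskip = 0
            return False
        self.noskip += 1
        if self.noskip >= self.max_noskip and self.scale < self.max_scale:
            self.scale *= 2
            self.noskip = 0
        return True

scaler = DynamicLossScaler()
took_step = scaler.update([1.0, float("inf")])  # overflow: step skipped, scale halved
for _ in range(1000):
    scaler.update([1.0, 2.0])                    # 1000 clean steps: scale doubles back
```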
#### `to_fp32`

> `to_fp32`(**`learn`**:[`Learner`](/basic_train.html#Learner))

#### `to_distributed`

> `to_distributed`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cuda_id`**:`int`, **`cache_dir`**:`PathOrStr`=***`'tmp'`***)

#### `to_parallel`

> `to_parallel`(**`learn`**:[`Learner`](/basic_train.html#Learner))
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.059613 | 0.054604 | 0.981845 | 00:05 |

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.026379 | 0.008763 | 0.998037 | 00:07 |
#### `lr_range`

> `lr_range`(**`lr`**:`Union`\[`float`, `slice`\]) → `ndarray`
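`lr_range` turns the `lr` argument into one learning rate per layer group. That argument is a float or a `slice`, as in the `slice(None, 0.003, None)` default of `fit`. This sketch of the conventional discriminative-LR rule assumes:

- a float is returned unchanged;
- `slice(start, stop)` is spread geometrically from `start` to `stop`;
- `slice(None, stop)` gives `stop/10` to every group except the last, which gets `stop`.

The helper names and the explicit `n_groups` argument are hypothetical. The real method presumably takes the group count from the learner, and its signature says it returns an `ndarray` rather than a list.

```python
def even_mults(start, stop, n):
    """n values from start to stop, evenly spaced on a log scale."""
    if n == 1:
        return [stop]  # assumption: a single group just gets the top rate
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

def lr_range(lr, n_groups):
    """Expand a float or slice into per-layer-group learning rates."""
    if not isinstance(lr, slice):
        return lr
    if lr.start:
        return even_mults(lr.start, lr.stop, n_groups)
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]

spread = lr_range(slice(1e-5, 1e-3), 3)   # geometric: 1e-5, 1e-4, 1e-3
top_only = lr_range(slice(None, 3e-3), 3)  # 3e-4 for the body groups, 3e-3 for the head
```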
#### `unfreeze`

> `unfreeze`()

#### `freeze`

> `freeze`()

#### `freeze_to`

> `freeze_to`(**`n`**:`int`)

#### `split`

> `split`(**`split_on`**:`SplitFuncOrIdxList`)
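`freeze`, `unfreeze` and `freeze_to` operate on the layer groups produced by `split`. This framework-free sketch of the usual semantics assumes:

- `freeze_to(n)` makes groups `[:n]` non-trainable and groups `[n:]` trainable;
- `freeze()` is `freeze_to(-1)`, so only the last group trains;
- `unfreeze()` is `freeze_to(0)`;
- BatchNorm layers in frozen groups stay trainable when the learner's `train_bn` flag is `True`.

Layers are modeled as hypothetical `(name, is_batchnorm)` pairs. The functions return a `{name: trainable}` map instead of toggling `requires_grad`.

```python
def freeze_to(layer_groups, n, train_bn=True):
    """Return {layer_name: trainable} after freezing groups [:n] and unfreezing groups [n:]."""
    trainable = {}
    for group in layer_groups[:n]:
        for name, is_bn in group:
            # In frozen groups, BatchNorm layers keep training only when train_bn is set.
            trainable[name] = train_bn and is_bn
    for group in layer_groups[n:]:
        for name, _ in group:
            trainable[name] = True
    return trainable

def freeze(layer_groups, train_bn=True):
    return freeze_to(layer_groups, -1, train_bn)  # everything but the last group

def unfreeze(layer_groups):
    return freeze_to(layer_groups, 0)  # every group trainable

groups = [[("conv1", False), ("bn1", True)], [("conv2", False)], [("head", False)]]
```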
#### `save`

> `save`(**`file`**:`PathLikeOrBinaryStream`=***`None`***, **`return_path`**:`bool`=***`False`***, **`with_opt`**:`bool`=***`True`***)

#### `load`

> `load`(**`file`**:`PathLikeOrBinaryStream`=***`None`***, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`strict`**:`bool`=***`True`***, **`with_opt`**:`bool`=***`None`***, **`purge`**:`bool`=***`False`***, **`remove_module`**:`bool`=***`False`***) → `Learner`

#### `export`

> `export`(**`file`**:`PathLikeOrBinaryStream`=***`'export.pkl'`***, **`destroy`**=***`False`***)

#### `load_learner`

> `load_learner`(**`path`**:`PathOrStr`, **`file`**:`PathLikeOrBinaryStream`=***`'export.pkl'`***, **`test`**:[`ItemList`](/data_block.html#ItemList)=***`None`***, **`tfm_y`**=***`None`***, **\*\*`db_kwargs`**)

#### `purge`

> `purge`(**`clear_opt`**:`bool`=***`True`***)

#### `destroy`

> `destroy`()

#### `init`

> `init`(**`init`**)
#### `mixup`

> `mixup`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)
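`mixup` trains on convex combinations of pairs of examples. This sketch of the standard recipe works on plain lists and assumes:

- each example is paired with a randomly permuted partner from the same batch;
- the mixing weight is drawn from Beta(`alpha`, `alpha`) and replaced by max(λ, 1 - λ), so the original example dominates;
- mirroring `stack_y=True`, the loss is the λ-weighted sum of the losses against both targets.

`mixup_batch` and `mixup_loss` are hypothetical helpers.

```python
import random

def mixup_batch(xs, ys, alpha=0.4, rng=None):
    """Mix each example with a shuffled partner; return mixed inputs and (y1, y2, lam) targets."""
    if rng is None:
        rng = random.Random()
    perm = list(range(len(xs)))
    rng.shuffle(perm)
    mixed_x, targets = [], []
    for i, j in enumerate(perm):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep the original example dominant (assumption)
        mixed_x.append([lam * a + (1 - lam) * b for a, b in zip(xs[i], xs[j])])
        targets.append((ys[i], ys[j], lam))
    return mixed_x, targets

def mixup_loss(loss_fn, preds, targets):
    """Average over the batch of lam * loss(y1) + (1 - lam) * loss(y2)."""
    return sum(lam * loss_fn(p, y1) + (1 - lam) * loss_fn(p, y2)
               for p, (y1, y2, lam) in zip(preds, targets)) / len(preds)

xs = [[0.0], [1.0], [2.0], [3.0]]
ys = [0, 1, 2, 3]  # targets equal indices, so targets[i][1] also indexes the partner in xs
mixed, targets = mixup_batch(xs, ys, alpha=0.4, rng=random.Random(0))
```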
#### `backward`

> `backward`(**`item`**)

#### `create_opt`

> `create_opt`(**`lr`**:`Floats`, **`wd`**:`Floats`=***`0.0`***)

#### `dl`

> `dl`(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…
## class `Recorder`

> `Recorder`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`add_time`**:`bool`=***`True`***, **`silent`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)

#### `plot`

> `plot`(**`skip_start`**:`int`=***`10`***, **`skip_end`**:`int`=***`5`***, **`suggestion`**:`bool`=***`False`***, **`return_fig`**:`bool`=***`None`***, **\*\*`kwargs`**) → `Optional`\[`Figure`\]
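With `suggestion=True`, `plot` proposes a learning rate from the recorded LR range test. This sketch assumes the suggestion is the learning rate where the trimmed loss curve falls most steeply, that is, the minimum of its numerical gradient. The curve is trimmed by dropping `skip_start` points from the start and `skip_end` points from the end. `split_list`, `gradient` and `suggest_lr` are hypothetical helpers. `gradient` mimics `numpy.gradient` with unit spacing: central differences inside, one-sided differences at the ends.

```python
def split_list(vals, skip_start, skip_end):
    """Drop `skip_start` items from the front and `skip_end` from the back."""
    return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]

def gradient(ys):
    """Unit-spacing numerical gradient, like numpy.gradient with edge_order=1."""
    n = len(ys)
    if n < 2:
        raise ValueError("need at least two points")
    g = [0.0] * n
    g[0] = ys[1] - ys[0]
    g[-1] = ys[-1] - ys[-2]
    for i in range(1, n - 1):
        g[i] = (ys[i + 1] - ys[i - 1]) / 2
    return g

def suggest_lr(lrs, losses, skip_start=10, skip_end=5):
    """Learning rate at the steepest descent of the trimmed loss curve."""
    lrs = split_list(lrs, skip_start, skip_end)
    losses = split_list(losses, skip_start, skip_end)
    g = gradient(losses)
    return lrs[min(range(len(g)), key=g.__getitem__)]
```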
#### `plot_losses`

> `plot_losses`(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\[`Figure`\]

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.228524 | 0.122285 | 0.958783 | 00:05 |
| 2 | 0.118838 | 0.075222 | 0.971050 | 00:05 |
| 3 | 0.066715 | 0.054920 | 0.981354 | 00:05 |
| 4 | 0.048155 | 0.048612 | 0.983317 | 00:05 |
| 5 | 0.037535 | 0.046014 | 0.982336 | 00:05 |
#### `plot_lr`

> `plot_lr`(**`show_moms`**=***`False`***, **`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\[`Figure`\]

#### `plot_metrics`

> `plot_metrics`(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\[`Figure`\]
#### `on_backward_begin`

> `on_backward_begin`(**`smooth_loss`**:`Tensor`, **\*\*`kwargs`**:`Any`)
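`on_backward_begin` receives a `smooth_loss` rather than the raw batch loss. This sketch assumes it is a bias-corrected exponential moving average of the batch losses, with weight `beta` on the history (0.98 here, also an assumption). The correction divides by `1 - beta**n` so that early values are not pulled toward the initial zero. `ExpSmoother` is a hypothetical name.

```python
class ExpSmoother:
    """Bias-corrected exponential moving average of a stream of values."""

    def __init__(self, beta=0.98):
        self.beta, self.n, self.mov_avg, self.smooth = beta, 0, 0.0, None

    def add_value(self, val):
        self.n += 1
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)  # debias early steps
        return self.smooth

s = ExpSmoother()
first = s.add_value(2.0)   # equals 2.0 up to rounding, thanks to the correction
second = s.add_value(4.0)  # lies between 2.0 and 4.0
```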
#### `on_batch_begin`

> `on_batch_begin`(**`train`**, **\*\*`kwargs`**:`Any`)

#### `on_epoch_end`

> `on_epoch_end`(**`epoch`**:`int`, **`num_batch`**:`int`, **`smooth_loss`**:`Tensor`, **`last_metrics`**:`MetricsList`, **\*\*`kwargs`**:`Any`) → `bool`

#### `on_train_begin`

> `on_train_begin`(**`pbar`**:`PBar`, **`metrics_names`**:`StrList`, **\*\*`kwargs`**:`Any`)

#### `add_metric_names`

> `add_metric_names`(**`names`**)

#### `format_stats`

> `format_stats`(**`stats`**:`MetricsList`)
#### `fit`

> `fit`(**`epochs`**:`int`, **`learn`**:[`BasicLearner`](/basic_train.html#BasicLearner), **`callbacks`**:`Optional`\[`Collection`\[[`Callback`](/callback.html#Callback)\]\]=***`None`***, **`metrics`**:`OptMetrics`=***`None`***)

#### `train_epoch`

> `train_epoch`(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`opt`**:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), **`loss_func`**:`LossFunction`)

#### `validate`

> `validate`(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`loss_func`**:`OptLossFunc`=***`None`***, **`cb_handler`**:`Optional`\[[`CallbackHandler`](/callback.html#CallbackHandler)\]=***`None`***, **`pbar`**:`Union`\[`MasterBar`, `ProgressBar`, `NoneType`\]=***`None`***, **`average`**=***`True`***, **`n_batch`**:`Optional`\[`int`\]=***`None`***) → `Iterator`\[`Tuple`\[`IntOrTensor`, `Ellipsis`\]\]

#### `get_preds`

> `get_preds`(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`pbar`**:`Union`\[`MasterBar`, `ProgressBar`, `NoneType`\]=***`None`***, **`cb_handler`**:`Optional`\[[`CallbackHandler`](/callback.html#CallbackHandler)\]=***`None`***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`loss_func`**:`OptLossFunc`=***`None`***, **`n_batch`**:`Optional`\[`int`\]=***`None`***) → `List`\[`Tensor`\]

#### `loss_batch`

> `loss_batch`(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`xb`**:`Tensor`, **`yb`**:`Tensor`, **`loss_func`**:`OptLossFunc`=***`None`***, **`opt`**:`OptOptimizer`=***`None`***, **`cb_handler`**:`Optional`\[[`CallbackHandler`](/callback.html#CallbackHandler)\]=***`None`***) → `Tuple`\[`Union`\[`Tensor`, `int`, `float`, `str`\]\]
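The functions `fit`, `train_epoch`, `validate` and `loss_batch` layer the training loop, judging by their signatures:

- `loss_batch` computes the loss on one batch and, presumably when an `opt` is passed, backpropagates and steps.
- `train_epoch` applies it to every training batch.
- `validate` averages the loss over the validation set, assumed here to weight each batch by its size.
- `fit` alternates training and validation for a number of epochs.

This sketch reproduces that structure on a toy problem without PyTorch: a linear model `y = w*x + b` fit with mini-batch SGD on mean squared error. Hand-derived gradients stand in for `loss.backward()`, and an in-place update stands in for `opt.step()`. All names and signatures are simplified stand-ins, not fastai's.

```python
import random

def loss_batch(params, xb, yb, lr=None):
    """MSE of the linear model on one batch; takes an SGD step when `lr` is given."""
    w, b = params
    preds = [w * x + b for x in xb]
    n = len(xb)
    loss = sum((p - y) ** 2 for p, y in zip(preds, yb)) / n
    if lr is not None:
        # Hand-derived gradients of the MSE (stand-in for loss.backward()).
        grad_w = sum(2 * (p - y) * x for p, y, x in zip(preds, yb, xb)) / n
        grad_b = sum(2 * (p - y) for p, y in zip(preds, yb)) / n
        params[0] -= lr * grad_w  # stand-in for opt.step()
        params[1] -= lr * grad_b
    return loss, n

def batches(xs, ys, bs, rng=None):
    """Yield (xb, yb) mini-batches; shuffle the order when an rng is given."""
    idx = list(range(len(xs)))
    if rng is not None:
        rng.shuffle(idx)
    for i in range(0, len(idx), bs):
        chunk = idx[i:i + bs]
        yield [xs[j] for j in chunk], [ys[j] for j in chunk]

def train_epoch(params, train, bs, lr, rng):
    for xb, yb in batches(*train, bs, rng):
        loss_batch(params, xb, yb, lr)

def validate(params, valid, bs):
    """Average validation loss, weighting each batch by its size; no parameter updates."""
    results = [loss_batch(params, xb, yb) for xb, yb in batches(*valid, bs)]
    return sum(loss * n for loss, n in results) / sum(n for _, n in results)

def fit(epochs, params, train, valid, bs=16, lr=0.1, seed=0):
    rng = random.Random(seed)
    history = []
    for _ in range(epochs):
        train_epoch(params, train, bs, lr, rng)
        history.append(validate(params, valid, bs))
    return history

# Toy data from y = 2x + 1; the model starts at w = b = 0.
data_rng = random.Random(42)
xs = [data_rng.uniform(-1, 1) for _ in range(128)]
ys = [2 * x + 1 for x in xs]
params = [0.0, 0.0]
history = fit(30, params, (xs[:96], ys[:96]), (xs[96:], ys[96:]))
```

Keeping the per-batch step separate from the epoch loop is what lets the real `fit` insert callbacks at each level without rewriting the loop.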
## class `LearnerCallback`

> `LearnerCallback`(**`learn`**) :: [`Callback`](/callback.html#Callback)

## class `RecordOnCPU`

> `RecordOnCPU`() :: [`Callback`](/callback.html#Callback)

#### `_tta_only`

> `_tta_only`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…

#### `_TTA`

> `_TTA`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…

#### `on_batch_begin`

> `on_batch_begin`(**`last_input`**, **`last_target`**, **\*\*`kwargs`**)