{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.distributed import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. Here the basic training loop is defined for the [`fit`](/basic_train.html#fit) method. The [`Learner`](/basic_train.html#Learner) object is the entry point of most of the [`Callback`](/callback.html#Callback) objects that will customize this training loop in different ways. Some of the most commonly used customizations are available through the [`train`](/train.html#train) module, notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that will help you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will launch a training using the 1cycle policy to help you train your model faster.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision and help you launch a training in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class Learner(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\\[`Callable`\\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`PathOrStr`=***`'models'`***, **`callback_fns`**:`Collection`\\[`Callable`\\]=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`
lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n",
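The `lr_find` defaults above (`start_lr`, `end_lr`, `num_it`, `stop_div`) describe an LR range test. The following is a rough, library-free sketch of that idea and not fastai's implementation: the learning rate is swept exponentially from `start_lr` to `end_lr`, the loss is recorded at each step, and the sweep stops early once the loss diverges. The `loss_at` callable, the `toy_loss` function, the use of the raw (unsmoothed) loss and the "4x the best loss" divergence threshold are all assumptions made for illustration.

```python
import math

def lr_range_test(loss_at, start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True):
    """Sweep the learning rate exponentially and record (lr, loss) pairs.

    `loss_at` is a hypothetical stand-in for "run one training step at this
    learning rate and return the loss". Requires num_it >= 2.
    """
    history, best = [], math.inf
    for i in range(num_it):
        # Exponential interpolation: start_lr at i=0, end_lr at i=num_it-1.
        lr = start_lr * (end_lr / start_lr) ** (i / (num_it - 1))
        loss = loss_at(lr)
        history.append((lr, loss))
        best = min(best, loss)
        # Assumed divergence rule: stop once the loss is NaN or far above the best seen.
        if stop_div and (math.isnan(loss) or loss > 4 * best):
            break
    return history

def toy_loss(lr):
    """Made-up loss curve: improves as lr grows, then blows up for large lr."""
    return 1.0 / (1.0 + 1000.0 * lr) + 10.0 * lr ** 2

history = lr_range_test(toy_loss)
best_lr, best_loss = min(history, key=lambda pair: pair[1])
print(f"stopped after {len(history)} steps; lowest loss {best_loss:.4f} at lr={best_lr:.2e}")
```

In practice one usually picks a learning rate somewhat below the point where the recorded loss is lowest, since the loss is still falling steeply there.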
"\n",
"fit[source][test]fit(**`epochs`**:`int`, **`lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`wd`**:`Floats`=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`None`***)\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.135343 | 0.083190 | 0.972031 | 00:05 |
"
fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.075838 | 0.061869 | 0.979882 | 00:05 |
"
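To make the `fit_one_cycle` parameters above more concrete, here is a minimal pure-Python sketch of a 1cycle-style schedule. It is an illustration under stated assumptions, not the library's code: it assumes cosine annealing in both phases, a warm-up that starts at `max_lr / div_factor` and lasts `pct_start` of the iterations, momentum moving in the opposite direction between the two `moms` values, and a `final_div` that falls back to `div_factor * 1e4` when it is `None`. The helper names `annealing_cos` and `one_cycle_schedule` are chosen here for illustration.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation: returns `start` at pct=0 and `end` at pct=1."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_schedule(n_iter, max_lr=3e-3, moms=(0.95, 0.85),
                       div_factor=25.0, pct_start=0.3, final_div=None):
    """Return one (lr, momentum) pair per training iteration."""
    if final_div is None:
        final_div = div_factor * 1e4  # assumed fallback, see lead-in
    n_up = int(n_iter * pct_start)    # iterations spent warming up
    n_down = n_iter - n_up            # iterations spent annealing down
    low_lr = max_lr / div_factor
    sched = []
    for i in range(n_iter):
        if i < n_up:
            # Phase 1: lr rises to max_lr while momentum falls.
            pct = i / n_up
            lr = annealing_cos(low_lr, max_lr, pct)
            mom = annealing_cos(moms[0], moms[1], pct)
        else:
            # Phase 2: lr decays towards max_lr / final_div while momentum recovers.
            pct = (i - n_up) / n_down
            lr = annealing_cos(max_lr, max_lr / final_div, pct)
            mom = annealing_cos(moms[1], moms[0], pct)
        sched.append((lr, mom))
    return sched

sched = one_cycle_schedule(100)
for i in (0, 15, 30, 65, 99):
    lr, mom = sched[i]
    print(f"iter {i:3d}: lr={lr:.6f} mom={mom:.3f}")
```

With the defaults, the learning rate peaks at iteration 30 of 100, exactly where momentum reaches its lower value.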
predict(**`item`**:[`ItemBase`](/core.html#ItemBase), **\\*\\*`kwargs`**)\n",
"\n",
"get_preds(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
validate(**`dl`**=***`None`***, **`callbacks`**=***`None`***, **`metrics`**=***`None`***)\n",
"\n",
"show_results(**`ds_type`**=***`
pred_batch(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
interpret(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)\n",
"\n",
"TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
clip_grad(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.1`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`False`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n",
"\n",
"to_distributed(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cuda_id`**:`int`, **`cache_dir`**:`PathOrStr`=***`'tmp'`***)\n",
"\n",
"to_parallel(**`learn`**:[`Learner`](/basic_train.html#Learner))\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.059613 | 0.054604 | 0.981845 | 00:05 |

| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.026379 | 0.008763 | 0.998037 | 00:07 |
"
lr_range(**`lr`**:`Union`\\[`float`, `slice`\\]) → `ndarray`\n",
"\n",
"unfreeze()\n",
"\n",
"freeze()\n",
"\n",
"freeze_to(**`n`**:`int`)\n",
"\n",
"split(**`split_on`**:`SplitFuncOrIdxList`)\n",
"\n",
"
save(**`name`**:`PathOrStr`, **`return_path`**:`bool`=***`False`***, **`with_opt`**:`bool`=***`True`***)\n",
"\n",
"load(**`name`**:`PathOrStr`, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`strict`**:`bool`=***`True`***, **`with_opt`**:`bool`=***`None`***, **`purge`**:`bool`=***`True`***, **`remove_module`**:`bool`=***`False`***)\n",
"\n",
"export(**`fname`**:`PathOrStr`=***`'export.pkl'`***, **`destroy`**=***`False`***)\n",
"\n",
"load_learner(**`path`**:`PathOrStr`, **`fname`**:`PathOrStr`=***`'export.pkl'`***, **`test`**:[`ItemList`](/data_block.html#ItemList)=***`None`***, **\\*\\*`db_kwargs`**)\n",
"\n",
"purge(**`clear_opt`**:`bool`=***`True`***)\n",
"\n",
"destroy()\n",
"\n",
"init(**`init`**)\n",
"\n",
"mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"backward(**`item`**)\n",
"\n",
"create_opt(**`lr`**:`Floats`, **`wd`**:`Floats`=***`0.0`***)\n",
"\n",
"dl(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
class Recorder(**`learn`**:[`Learner`](/basic_train.html#Learner), **`add_time`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"plot(**`skip_start`**:`int`=***`10`***, **`skip_end`**:`int`=***`5`***, **`suggestion`**:`bool`=***`False`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n",
"\n",
"plot_losses(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 1 | 0.228524 | 0.122285 | 0.958783 | 00:05 |
| 2 | 0.118838 | 0.075222 | 0.971050 | 00:05 |
| 3 | 0.066715 | 0.054920 | 0.981354 | 00:05 |
| 4 | 0.048155 | 0.048612 | 0.983317 | 00:05 |
| 5 | 0.037535 | 0.046014 | 0.982336 | 00:05 |
"
plot_lr(**`show_moms`**=***`False`***, **`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",
"plot_metrics(**`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",
"on_backward_begin(**`smooth_loss`**:`Tensor`, **\\*\\*`kwargs`**:`Any`)\n",
"\n",
"on_batch_begin(**`train`**, **\\*\\*`kwargs`**:`Any`)\n",
"\n",
"on_epoch_end(**`epoch`**:`int`, **`num_batch`**:`int`, **`smooth_loss`**:`Tensor`, **`last_metrics`**=***`typing.Collection[typing.Union[torch.Tensor, numbers.Number]]`***, **\\*\\*`kwargs`**:`Any`) → `bool`\n",
"\n",
"on_train_begin(**`pbar`**:`PBar`, **`metrics_names`**:`StrList`, **\\*\\*`kwargs`**:`Any`)\n",
"\n",
"add_metric_names(**`names`**)\n",
"\n",
"format_stats(**`stats`**:`MetricsList`)\n",
"\n",
"
fit(**`epochs`**:`int`, **`learn`**:[`BasicLearner`](/basic_train.html#BasicLearner), **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`metrics`**:`OptMetrics`=***`None`***)\n",
"\n",
"train_epoch(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`opt`**:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), **`loss_func`**:`LossFunction`)\n",
"\n",
"validate(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`loss_func`**:`OptLossFunc`=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***, **`pbar`**:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=***`None`***, **`average`**=***`True`***, **`n_batch`**:`Optional`\\[`int`\\]=***`None`***) → `Iterator`\\[`Tuple`\\[`IntOrTensor`, `Ellipsis`\\]\\]\n",
"\n",
"get_preds(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`pbar`**:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`loss_func`**:`OptLossFunc`=***`None`***, **`n_batch`**:`Optional`\\[`int`\\]=***`None`***) → `List`\\[`Tensor`\\]\n",
"\n",
"loss_batch(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`xb`**:`Tensor`, **`yb`**:`Tensor`, **`loss_func`**:`OptLossFunc`=***`None`***, **`opt`**:`OptOptimizer`=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***) → `Tuple`\\[`Union`\\[`Tensor`, `int`, `float`, `str`\\]\\]\n",
"\n",
"class LearnerCallback(**`learn`**) :: [`Callback`](/callback.html#Callback)\n",
"\n",
"class RecordOnCPU() :: [`Callback`](/callback.html#Callback)\n",
"\n",
"_tta_only(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
_TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
on_batch_begin(**`last_input`**, **`last_target`**, **\\*\\*`kwargs`**)\n",
"\n",
"