{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.distributed import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) together with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. This module also defines the basic training loop used by the [`fit`](/basic_train.html#fit) method. The [`Learner`](/basic_train.html#Learner) object is the entry point for most of the [`Callback`](/callback.html#Callback) objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the [`train`](/train.html#train) module, notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) launches an LR range test to help you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) trains with the 1cycle policy, which helps you train your model faster.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) converts your model to half precision so you can train in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class Learner(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\\[`Callable`\\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`PathOrStr`=***`'models'`***, **`callback_fns`**:`Collection`\\[`Callable`\\]=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`

lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)

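To make the LR range test concrete, here is a plain-Python sketch of what `lr_find`'s parameters suggest: learning rates swept exponentially from `start_lr` to `end_lr` over `num_it` steps, with an early stop once the loss diverges when `stop_div` is set. This is not fastai's implementation. The exponential spacing, the moving-average smoothing (`beta=0.98`) and the divergence threshold (4× the best smoothed loss) are assumptions, and `lr_sweep`, `run_range_test` and the `loss_at` callable are hypothetical stand-ins for a real training step.

```python
import math

def lr_sweep(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Exponentially spaced learning rates from start_lr to end_lr (both included)."""
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]

def run_range_test(loss_at, start_lr=1e-7, end_lr=10.0, num_it=100,
                   stop_div=True, beta=0.98, div_factor=4.0):
    """Record (lr, smoothed loss) pairs and stop once the loss blows up.

    loss_at: hypothetical callable standing in for one training step at a given lr.
    beta, div_factor: assumed smoothing factor and divergence threshold.
    """
    history, avg, best = [], 0.0, math.inf
    for i, lr in enumerate(lr_sweep(start_lr, end_lr, num_it), start=1):
        avg = beta * avg + (1 - beta) * loss_at(lr)   # exponential moving average
        smoothed = avg / (1 - beta ** i)              # bias correction for early steps
        history.append((lr, smoothed))
        best = min(best, smoothed)
        if stop_div and (smoothed > div_factor * best or math.isnan(smoothed)):
            break                                     # loss diverged: end the sweep
    return history
```

With a toy loss that jumps from 1 to 1e6 once the learning rate reaches 1e-2, the sweep stops at the first learning rate at or above 1e-2 instead of running all 100 steps. You would then pick a learning rate from the region where the recorded loss was still falling.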
fit(**`epochs`**:`int`, **`lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`wd`**:`Floats`=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`None`***)

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.135343 | 0.083190 | 0.972031 | 00:05

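The intro describes `fit` as the basic training loop that [`Callback`](/callback.html#Callback) objects customize. The toy sketch below shows that shape in plain Python: an outer loop over epochs, an inner loop over batches, and callback hooks fired at fixed points. The hook names mirror some of fastai v1's callback events (`on_train_begin`, `on_epoch_begin`, `on_batch_end`, `on_epoch_end`, `on_train_end`), but the model (a single weight `w` fitted to `y = w * x` by SGD), `ToyRecorder` and `toy_fit` are hypothetical illustrations, not the library's code.

```python
class Callback:
    """Base class: every hook is a no-op unless a subclass overrides it."""
    def on_train_begin(self): pass
    def on_epoch_begin(self, epoch): pass
    def on_batch_end(self, loss): pass
    def on_epoch_end(self, epoch): pass
    def on_train_end(self): pass

class ToyRecorder(Callback):
    """Hypothetical recorder: keeps every batch loss and the last batch loss of each epoch."""
    def __init__(self):
        self.losses, self.epoch_losses = [], []

    def on_batch_end(self, loss):
        self.losses.append(loss)

    def on_epoch_end(self, epoch):
        self.epoch_losses.append(self.losses[-1])

def toy_fit(epochs, lr, data, w=0.0, callbacks=()):
    """Fit y = w * x with squared error and plain SGD, firing callback hooks."""
    for cb in callbacks: cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks: cb.on_epoch_begin(epoch)
        for x, y in data:                 # one (x, y) pair per "batch"
            pred = w * x
            loss = (pred - y) ** 2        # loss before this step's update
            grad = 2 * (pred - y) * x     # d(loss)/dw
            w -= lr * grad                # optimizer step
            for cb in callbacks: cb.on_batch_end(loss)
        for cb in callbacks: cb.on_epoch_end(epoch)
    for cb in callbacks: cb.on_train_end()
    return w
```

For example, `toy_fit(5, 0.05, [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)], callbacks=[ToyRecorder()])` drives `w` close to 2. Keeping the loop fixed and pushing all customization into callbacks is what lets features such as recording, schedules or mixed precision be added without rewriting `fit`.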
fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.075838 | 0.061869 | 0.979882 | 00:05

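The 1cycle policy behind `fit_one_cycle` can be sketched as a function of the iteration number. The version below is an assumption-laden illustration, not fastai's code. It assumes cosine interpolation in both phases, a warm-up covering a fraction `pct_start` of the iterations, a starting learning rate of `max_lr / div_factor`, and a final one of `max_lr / final_div`, with `final_div` taken to be `div_factor * 1e4` when it is `None`. Momentum is assumed to move in the opposite direction between the two values of `moms`. `cos_anneal` and `one_cycle` are hypothetical names.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(it, total_its, max_lr=3e-3, moms=(0.95, 0.85),
              div_factor=25.0, pct_start=0.3, final_div=None):
    """Return (lr, momentum) at iteration `it` (0..total_its) of a 1cycle schedule."""
    if final_div is None:
        final_div = div_factor * 1e4                  # assumed default
    warmup = max(1, int(total_its * pct_start))       # length of phase 1
    if it <= warmup:                                  # phase 1: lr rises, momentum falls
        pct = it / warmup
        return (cos_anneal(max_lr / div_factor, max_lr, pct),
                cos_anneal(moms[0], moms[1], pct))
    pct = (it - warmup) / (total_its - warmup)        # phase 2: lr falls, momentum rises
    return (cos_anneal(max_lr, max_lr / final_div, pct),
            cos_anneal(moms[1], moms[0], pct))
```

With 100 iterations and the defaults shown in the signature, the learning rate starts at `3e-3 / 25`, peaks at `3e-3` at iteration 30 (where momentum bottoms out at 0.85), and ends far below the starting value. Lowering momentum while the learning rate is high is the usual motivation given for pairing the two schedules.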
predict(**`item`**:[`ItemBase`](/core.html#ItemBase), **\\*\\*`kwargs`**)

get_preds(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

validate(**`dl`**=***`None`***, **`callbacks`**=***`None`***, **`metrics`**=***`None`***)

show_results(**`ds_type`**=***`

pred_batch(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

interpret(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)

TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

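`TTA` takes a `beta` weight. One plausible reading, stated here as an assumption rather than a description of fastai's code, is that the final prediction blends two things: `beta` times the prediction on the original input, plus `1 - beta` times the mean of the predictions on augmented copies. The plain-Python sketch below shows that arithmetic for one sample's class probabilities; `tta_blend` is a hypothetical helper.

```python
def tta_blend(plain, augmented, beta=0.4):
    """Blend plain-input probabilities with the mean over augmented copies.

    plain: class probabilities predicted for the original input.
    augmented: list of class-probability lists, one per augmented copy.
    """
    n = len(augmented)
    avg = [sum(p[c] for p in augmented) / n for c in range(len(plain))]
    return [beta * p + (1 - beta) * a for p, a in zip(plain, avg)]
```

For instance, `tta_blend([1.0, 0.0], [[0.0, 1.0], [0.5, 0.5]])` averages the augmented predictions to `[0.25, 0.75]` and returns about `[0.55, 0.45]`. Because it is a convex combination, the blended probabilities still sum to 1.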
clip_grad(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.1`***) → [`Learner`](/basic_train.html#Learner)

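`clip_grad` adds gradient clipping with threshold `clip`. The sketch below assumes clipping by global L2 norm, the scheme PyTorch implements in `torch.nn.utils.clip_grad_norm_`. If the norm of all gradients taken together exceeds `clip`, every gradient is scaled down by the same factor, so the update direction is preserved. It works on a flat list of floats, and `clip_by_global_norm` is a hypothetical helper.

```python
import math

def clip_by_global_norm(grads, clip=0.1, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is at most about `clip`.

    Returns (clipped_grads, norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = clip / (total_norm + eps)      # eps avoids division by zero
    if coef < 1.0:                        # only ever shrink gradients, never grow them
        return [g * coef for g in grads], total_norm
    return list(grads), total_norm
```

Gradients `[3.0, 4.0]` have norm 5, so with `clip=0.1` they become roughly `[0.06, 0.08]`. Gradients already inside the threshold are returned unchanged.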
to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`False`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***) → [`Learner`](/basic_train.html#Learner)

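In mixed-precision training the loss is multiplied by a loss scale before backpropagation so that small FP16 gradients do not underflow. `to_fp16` exposes `loss_scale`, `dynamic` and `max_noskip`. The sketch below shows one common dynamic loss-scaling rule, assuming it is roughly what `dynamic=True` enables: on overflow, skip the optimizer step and halve the scale; after `max_noskip` consecutive clean steps, double it. `DynamicLossScaler` is a hypothetical class, not part of fastai, and gradients are plain floats.

```python
import math

class DynamicLossScaler:
    """Toy dynamic loss scaler (assumed policy: halve on overflow, grow after max_noskip clean steps)."""

    def __init__(self, loss_scale=2.0 ** 16, max_noskip=1000, factor=2.0):
        self.loss_scale = loss_scale
        self.max_noskip = max_noskip
        self.factor = factor
        self.noskip = 0  # consecutive steps without overflow

    def step(self, scaled_grads):
        """Return unscaled gradients if the optimizer should step, or None to skip it."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.loss_scale /= self.factor    # overflow: shrink the scale, skip this step
            self.noskip = 0
            return None
        self.noskip += 1
        grads = [g / self.loss_scale for g in scaled_grads]  # unscale before the update
        if self.noskip >= self.max_noskip:
            self.loss_scale *= self.factor    # stable for a while: try a larger scale
            self.noskip = 0
        return grads
```

The design trades a few skipped steps for keeping the scale as large as possible, which maximizes the precision left in the FP16 gradients.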
to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))

distributed(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cuda_id`**:`int`, **`cache_dir`**:`PathOrStr`=***`'tmp'`***)

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.059613 | 0.054604 | 0.981845 | 00:05

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.026379 | 0.008763 | 0.998037 | 00:07

lr_range(**`lr`**:`Union`\\[`float`, `slice`\\]) → `ndarray`

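`lr_range` turns the `lr` argument into one learning rate per layer group. This is how a slice such as `fit`'s default `lr=slice(None, 0.003, None)` becomes discriminative learning rates. The plain-Python sketch below assumes three rules for `n_groups` layer groups:

- a plain float is returned unchanged;
- `slice(start, stop)` gives rates spaced geometrically from `start` to `stop`;
- `slice(None, stop)` gives `stop / 10` to every group except the last, which gets `stop`.

fastai returns an `ndarray`, whereas this sketch returns a list. `n_groups` stands in for the learner's layer groups.

```python
def lr_range(lr, n_groups):
    """Map a float or slice to per-layer-group learning rates (plain-Python sketch)."""
    if not isinstance(lr, slice):
        return lr                                    # same lr for every group
    if lr.start:
        if n_groups == 1:                            # guard added for this sketch
            return [lr.stop]
        step = (lr.stop / lr.start) ** (1 / (n_groups - 1))
        return [lr.start * step ** i for i in range(n_groups)]  # geometric spacing
    return [lr.stop / 10] * (n_groups - 1) + [lr.stop]
```

With three layer groups, `slice(1e-5, 1e-3)` maps to about `[1e-5, 1e-4, 1e-3]`. The early groups (usually pretrained layers) get the smallest rates, and the head gets the largest.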
unfreeze()

freeze()

freeze_to(**`n`**:`int`)

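`freeze_to(n)` controls which layer groups are trained. The sketch below models the layer groups as a list and uses Python slicing. It assumes groups `[:n]` are frozen and groups `[n:]` are trainable, that `freeze()` behaves like `freeze_to(-1)` (only the last group trains), and that `unfreeze()` behaves like `freeze_to(0)`. It ignores refinements such as batchnorm layers staying trainable when `train_bn=True`. `trainable_flags` and the two wrappers are hypothetical helpers.

```python
def trainable_flags(n_groups, n):
    """True for each layer group that would receive gradient updates after freeze_to(n)."""
    groups = list(range(n_groups))
    return [False] * len(groups[:n]) + [True] * len(groups[n:])

def freeze(n_groups):
    """Assumed equivalent to freeze_to(-1): train only the last group."""
    return trainable_flags(n_groups, -1)

def unfreeze(n_groups):
    """Assumed equivalent to freeze_to(0): train every group."""
    return trainable_flags(n_groups, 0)
```

With three layer groups, `freeze(3)` gives `[False, False, True]` and `trainable_flags(3, 1)` gives `[False, True, True]`. Using slicing means negative `n` counts from the end, just like list indexing.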
split(**`split_on`**:`SplitFuncOrIdxList`)

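`split` divides the model into the layer groups that `freeze_to` and discriminative learning rates act on. When `split_on` is a list of indices, one natural interpretation (assumed here) is to cut the flattened list of layers at those indices. The sketch uses layer names in place of modules, and `split_at` is a hypothetical helper.

```python
def split_at(layers, idxs):
    """Cut a flat list of layers into consecutive groups at the given indices."""
    idxs = list(idxs)
    if not idxs or idxs[0] != 0:
        idxs = [0] + idxs                 # first group starts at layer 0
    if idxs[-1] != len(layers):
        idxs.append(len(layers))          # last group runs to the end
    return [layers[i:j] for i, j in zip(idxs[:-1], idxs[1:])]
```

For example, `split_at(['conv1', 'conv2', 'conv3', 'head'], [2])` yields two groups, `['conv1', 'conv2']` and `['conv3', 'head']`.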
save(**`name`**:`PathOrStr`, **`return_path`**:`bool`=***`False`***, **`with_opt`**:`bool`=***`True`***)

load(**`name`**:`PathOrStr`, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`strict`**:`bool`=***`True`***, **`with_opt`**:`bool`=***`None`***, **`purge`**:`bool`=***`True`***, **`remove_module`**:`bool`=***`False`***)

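`save` and `load` both take a `with_opt` flag. One way to picture it, assumed for illustration rather than taken from fastai's source, is a checkpoint that holds either the model weights alone or a dict with `'model'` and `'opt'` entries when the optimizer state is included. When loading, `with_opt=None` is taken to mean "restore the optimizer state if it is there". The sketch uses `pickle` and plain dicts in place of `torch.save` and real state dicts, and `save_state` and `load_state` are hypothetical.

```python
import io
import pickle

def save_state(f, model_state, opt_state=None, with_opt=True):
    """Write the bare model state, or {'model': ..., 'opt': ...} when with_opt is set."""
    if with_opt and opt_state is not None:
        state = {'model': model_state, 'opt': opt_state}
    else:
        state = model_state
    pickle.dump(state, f)

def load_state(f, with_opt=None):
    """Return (model_state, opt_state); opt_state is None if absent or not requested."""
    state = pickle.load(f)
    if isinstance(state, dict) and set(state) == {'model', 'opt'}:
        opt = state['opt'] if with_opt in (None, True) else None
        return state['model'], opt
    return state, None
```

Saving with the optimizer state lets you resume training with the same optimizer state, at the cost of a larger file. Saving without it is enough for inference.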
export(**`fname`**:`PathOrStr`=***`'export.pkl'`***, **`destroy`**=***`False`***)

load_learner(**`path`**:`PathOrStr`, **`fname`**:`PathOrStr`=***`'export.pkl'`***, **`test`**:[`ItemList`](/data_block.html#ItemList)=***`None`***, **\\*\\*`db_kwargs`**)

purge(**`clear_opt`**:`bool`=***`True`***)

destroy()

init(**`init`**)

mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)

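`mixup` trains on convex combinations of pairs of examples. In the mixup formulation, a weight λ is drawn from Beta(`alpha`, `alpha`) and inputs and targets are both mixed as λ·a + (1 − λ)·b. The sketch below assumes that formulation. It also assumes λ is folded to max(λ, 1 − λ) so the first example always dominates, which is an unverified assumption about fastai's callback. The `stack_x`/`stack_y` options appear to control how inputs and targets are combined, and that path is not modeled here. `sample_lambda` and `mix` are hypothetical helpers, and targets are one-hot lists.

```python
import random

def sample_lambda(alpha=0.4, rng=random):
    """Draw a mixing weight from Beta(alpha, alpha), folded into [0.5, 1]."""
    lam = rng.betavariate(alpha, alpha)
    return max(lam, 1 - lam)              # assumed: keep the first example dominant

def mix(a, b, lam):
    """Elementwise convex combination lam * a + (1 - lam) * b."""
    return [lam * x + (1 - lam) * y for x, y in zip(a, b)]
```

With λ = 0.75, mixing inputs `[1.0, 0.0]` and `[0.0, 1.0]` gives `[0.75, 0.25]`. Mixing the one-hot targets `[0, 0, 1]` and `[1, 0, 0]` gives the soft label `[0.25, 0.0, 0.75]`.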
backward(**`item`**)

create_opt(**`lr`**:`Floats`, **`wd`**:`Floats`=***`0.0`***)

dl(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

class Recorder(**`learn`**:[`Learner`](/basic_train.html#Learner), **`add_time`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)

plot(**`skip_start`**:`int`=***`10`***, **`skip_end`**:`int`=***`5`***, **`suggestion`**:`bool`=***`False`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]

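`Recorder.plot` draws the loss against the learning rates recorded during `lr_find`, dropping `skip_start` points at the beginning and `skip_end` points at the end. With `suggestion=True` it also marks a suggested learning rate. The sketch below assumes the suggestion is the point where the loss falls fastest, that is, the most negative finite-difference slope of the trimmed loss curve taken over the recorded points. It uses central differences inside the curve and one-sided differences at the ends, as `numpy.gradient` does. The code is plain Python, and `trim`, `slope` and `suggest_lr` are hypothetical helpers.

```python
def trim(vals, skip_start=10, skip_end=5):
    """Drop skip_start leading and skip_end trailing values."""
    return vals[skip_start:-skip_end] if skip_end > 0 else vals[skip_start:]

def slope(ys):
    """Finite-difference gradient with respect to the index (one-sided at the ends)."""
    n = len(ys)
    if n < 2:
        return [0.0] * n
    g = [ys[1] - ys[0]]
    g += [(ys[i + 1] - ys[i - 1]) / 2 for i in range(1, n - 1)]
    g.append(ys[-1] - ys[-2])
    return g

def suggest_lr(lrs, losses, skip_start=10, skip_end=5):
    """Learning rate at the steepest descent of the trimmed loss curve."""
    lrs = trim(lrs, skip_start, skip_end)
    losses = trim(losses, skip_start, skip_end)
    g = slope(losses)
    return lrs[min(range(len(g)), key=g.__getitem__)]
```

Trimming matters because the first few points are noisy and the last ones usually show the loss exploding. Either could otherwise dominate both the plot and the suggested value.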
plot_losses(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.228524 | 0.122285 | 0.958783 | 00:05
2 | 0.118838 | 0.075222 | 0.971050 | 00:05
3 | 0.066715 | 0.054920 | 0.981354 | 00:05
4 | 0.048155 | 0.048612 | 0.983317 | 00:05
5 | 0.037535 | 0.046014 | 0.982336 | 00:05