{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with four simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)\n", "- [`AccumulateScheduler`](/train.html#AccumulateScheduler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
fit_one_cycle
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n",
"\n",
"Tests found for `fit_one_cycle`:

- `pytest -sv tests/test_train.py::test_fit_one_cycle`

Some other tests where `fit_one_cycle` is used:

- `pytest -sv tests/test_tabular_train.py::test_empty_cont`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split`
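To make `max_lr`, `moms`, `div_factor`, `pct_start` and `final_div` more concrete, here is a minimal pure-Python sketch of the kind of per-iteration schedule they parameterize. It assumes the usual 1cycle shape. The learning rate climbs from `max_lr/div_factor` to `max_lr` over the first `pct_start` fraction of iterations, then anneals down to `max_lr/final_div`. Momentum travels the opposite way between the two values in `moms`. Three details are assumptions of this sketch and not a copy of fastai's scheduler: the cosine interpolation, the hypothetical `one_cycle` helper, and the fallback of `div_factor * 1e4` when `final_div` is `None`.

```python
import math

def annealing_cos(start, end, pct):
    # Cosine interpolation: returns `start` at pct=0 and `end` at pct=1.
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle(n_iter, max_lr, moms=(0.95, 0.85), div_factor=25.0,
              pct_start=0.3, final_div=None):
    # Hypothetical helper: per-iteration (lr, momentum) pairs for a 1cycle schedule.
    if final_div is None:
        final_div = div_factor * 1e4  # assumed fallback, not taken from the docs above
    warmup = int(n_iter * pct_start)  # iterations spent increasing the LR
    anneal = n_iter - warmup          # iterations spent decreasing it
    low_lr = max_lr / div_factor
    schedule = []
    for i in range(n_iter):
        if i < warmup:
            pct = i / warmup
            lr = annealing_cos(low_lr, max_lr, pct)
            mom = annealing_cos(moms[0], moms[1], pct)  # momentum goes down while LR goes up
        else:
            pct = (i - warmup) / max(1, anneal - 1)
            lr = annealing_cos(max_lr, max_lr / final_div, pct)
            mom = annealing_cos(moms[1], moms[0], pct)  # and back up while LR anneals
        schedule.append((lr, mom))
    return schedule

sched = one_cycle(100, max_lr=3e-3)
print(sched[0], sched[30], sched[-1])
```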
one_cycle_scheduler
(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n",
"\n",
"No tests found for `one_cycle_scheduler`.
lr_find
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n",
"\n",
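`lr_find` sweeps the learning rate from `start_lr` up to `end_lr` over at most `num_it` iterations. With `stop_div=True`, the sweep ends early once the loss diverges. The sketch below assumes geometrically spaced learning rates. It treats the loss as diverged when it exceeds `div_thresh` times the best loss seen so far. The threshold of 4, the helper names and the toy `loss_at` function (standing in for a training step) are all hypothetical. The sketch also compares raw losses without any smoothing.

```python
import math

def lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    # Geometrically spaced learning rates from start_lr to end_lr (inclusive).
    ratio = end_lr / start_lr
    return [start_lr * ratio ** (i / max(1, num_it - 1)) for i in range(num_it)]

def run_lr_find(loss_at, start_lr=1e-7, end_lr=10.0, num_it=100,
                stop_div=True, div_thresh=4.0):
    # `loss_at(lr)` stands in for one training step at that learning rate.
    lrs, losses, best = [], [], math.inf
    for lr in lr_schedule(start_lr, end_lr, num_it):
        loss = loss_at(lr)
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if stop_div and (math.isnan(loss) or loss > div_thresh * best):
            break  # the loss has diverged: stop the sweep early
    return lrs, losses

# Toy loss (hypothetical): flat below lr=1, explodes above it.
lrs, losses = run_lr_find(lambda lr: 1.0 if lr < 1.0 else 10.0 * lr)
print(len(lrs), lrs[-1])
```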
"to_fp16
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"No tests found for `to_fp16`.
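With `dynamic=True`, `to_fp16` adjusts the loss scale during training. The `max_noskip` and `max_scale` parameters point to the usual dynamic loss scaling scheme, which works as follows:

- The loss is multiplied by a scale before the backward pass, so that small FP16 gradients do not underflow.
- When gradients overflow, the scale is halved and the step is skipped.
- After `max_noskip` consecutive clean steps, the scale is doubled again, never going above `max_scale` (the default `16777216` is `2**24`).

This pure-Python sketch covers only that bookkeeping. The `DynamicLossScaler` class and its `2**16` starting scale are hypothetical, and no FP16 arithmetic happens here.

```python
class DynamicLossScaler:
    # Hypothetical sketch of dynamic loss scaling bookkeeping (no real FP16 math).
    def __init__(self, init_scale=2.0 ** 16, max_noskip=1000, max_scale=2.0 ** 24):
        self.scale = init_scale
        self.max_noskip = max_noskip
        self.max_scale = max_scale
        self.noskip = 0  # consecutive steps without overflow

    def update(self, overflow):
        # Returns True if the optimizer step should be applied.
        if overflow:
            self.scale = max(self.scale / 2, 1.0)  # back off, but never below 1
            self.noskip = 0
            return False  # skip this step: the gradients contain inf/nan
        self.noskip += 1
        if self.noskip >= self.max_noskip and self.scale < self.max_scale:
            self.scale *= 2  # long run without overflow: try a bigger scale
            self.noskip = 0
        return True

    def unscale(self, grads):
        # Gradients of the scaled loss are divided by the scale before the step.
        return [g / self.scale for g in grads]
```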
to_fp32
(**`learn`**:[`Learner`](/basic_train.html#Learner))\n",
"\n",
"No tests found for `to_fp32`.
mixup
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"No tests found for `mixup`.
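Mixup trains on convex combinations of pairs of examples. Here is a minimal pure-Python sketch, assuming the common recipe:

- Each example draws a weight `lam` from a Beta(`alpha`, `alpha`) distribution.
- `lam` is replaced by `max(lam, 1 - lam)` so the mix stays closer to the original example.
- The partner is a randomly chosen example from the same batch.

Keeping both targets plus the weight, and combining two losses, is our reading of `stack_y=True`. `mixup_batch` and `mixup_loss` are hypothetical helpers, not fastai's callback.

```python
import random

def mixup_batch(xs, ys, alpha=0.4, rng=random):
    # xs: list of feature vectors, ys: list of labels (same length).
    perm = list(range(len(xs)))
    rng.shuffle(perm)  # partner j for each example i
    mixed_x, mixed_y = [], []
    for i, j in enumerate(perm):
        lam = rng.betavariate(alpha, alpha)
        lam = max(lam, 1 - lam)  # keep the mix closer to example i
        mixed_x.append([lam * a + (1 - lam) * b for a, b in zip(xs[i], xs[j])])
        mixed_y.append((ys[i], ys[j], lam))  # both targets plus the mixing weight
    return mixed_x, mixed_y

def mixup_loss(loss_fn, pred, target):
    # Weighted combination of the losses against both original targets.
    y1, y2, lam = target
    return lam * loss_fn(pred, y1) + (1 - lam) * loss_fn(pred, y2)

xs = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
mx, my = mixup_batch(xs, [0, 1, 2], rng=random.Random(0))
```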
class
Interpretation
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

No tests found for `Interpretation`.
from_learner
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

Tests found for `from_learner`:

Some other tests where `from_learner` is used:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`
- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`
- `pytest -sv tests/test_vision_train.py::test_interp`
top_losses
(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
"class
ClassificationInterpretation
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

Tests found for `ClassificationInterpretation`:

- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`

Some other tests where `ClassificationInterpretation` is used:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`
- `pytest -sv tests/test_vision_train.py::test_interp`
top_losses
(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
"plot_confusion_matrix
(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n",
"\n",
"No tests found for `plot_confusion_matrix`.
confusion_matrix
(**`slice_size`**:`int`=***`1`***)\n",
"\n",
"most_confused
(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n",
"\n",
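Here is a pure-Python sketch of what `top_losses`, `confusion_matrix` and `most_confused` compute. It works from true class indices, predicted class indices and per-example losses. It rests on two assumptions:

- The confusion matrix has actual classes as rows and predicted classes as columns.
- `most_confused` lists off-diagonal `(actual, predicted, count)` entries with at least `min_val` occurrences, most frequent first. This fits the `Collection[Tuple[str, str, int]]` return type above.

The helpers only borrow the method names, and the class labels are toy data.

```python
def top_losses(losses, k=None, largest=True):
    # Returns (values, indices) of the k largest (or smallest) losses.
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)
    order = order if k is None else order[:k]
    return [losses[i] for i in order], order

def confusion_matrix(y_true, y_pred, n_classes):
    # cm[actual][predicted] counts how often each pair occurs.
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm

def most_confused(cm, classes, min_val=1):
    # Off-diagonal cells with at least `min_val` hits, largest count first.
    pairs = [(classes[i], classes[j], cm[i][j])
             for i in range(len(classes)) for j in range(len(classes))
             if i != j and cm[i][j] >= min_val]
    return sorted(pairs, key=lambda t: t[2], reverse=True)

classes = ['cat', 'dog', 'fox']
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 0, 2, 1, 1]
cm = confusion_matrix(y_true, y_pred, len(classes))
print(most_confused(cm, classes))
```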
"class
MultiLabelClassificationInterpretation
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

No tests found for `MultiLabelClassificationInterpretation`.
class
ShowGraph
(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for `ShowGraph`.
on_epoch_end
(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n",
"\n",
"No tests found for `on_epoch_end`.
class
GradientClipping
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for `GradientClipping`.

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
0 | 0.140471 | 0.079571 | 0.971541 | 00:08

on_backward_end
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for `on_backward_end`.
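`GradientClipping` takes a single `clip` value applied in `on_backward_end`, that is, before the optimizer step. The sketch below assumes norm-based clipping, the rule PyTorch's `torch.nn.utils.clip_grad_norm_` applies to tensors. It rescales all gradients together so that their global L2 norm is at most `clip`, and leaves them alone when the norm is already small enough. It works on a flat list of floats. Treating `clip=0` as disabled is a choice of the sketch that matches the `0.0` default.

```python
import math

def clip_grad_norm(grads, clip):
    # Rescale `grads` so their global L2 norm is at most `clip`.
    total_norm = math.sqrt(sum(g * g for g in grads))
    if not clip:  # the sketch treats clip=0 as no clipping at all
        return list(grads), total_norm
    coef = clip / (total_norm + 1e-6)  # small epsilon avoids division by zero
    if coef < 1:
        grads = [g * coef for g in grads]  # same factor for every gradient
    return list(grads), total_norm

clipped, norm_before = clip_grad_norm([3.0, 4.0], clip=1.0)
print(norm_before, clipped)
```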
class
BnFreeze
(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for `BnFreeze`.

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
0 | 0.147409 | 0.081370 | 0.972031 | 00:05

on_epoch_begin
(**\\*\\*`kwargs`**:`Any`)\n",
"\n",
"No tests found for `on_epoch_begin`.
class
AccumulateScheduler
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`n_step`**:`int`=***`1`***, **`drop_last`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for `AccumulateScheduler`.
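With `n_step > 1`, `AccumulateScheduler` accumulates gradients over `n_step` batches before each optimizer step. This approximates training with a larger batch. The pure-Python sketch below only plans those steps. For every optimizer step it records the index of the batch that triggers it and how many samples contributed to the summed gradients. It rests on two assumptions:

- The summed gradients are divided by that sample count. This only yields a per-sample average if the loss uses sum reduction.
- `drop_last=True` discards a leftover partial group at the end of the epoch.

`accumulation_plan` is a hypothetical helper.

```python
def accumulation_plan(batch_sizes, n_step=1, drop_last=False):
    # Returns (batch_index, n_samples) for each optimizer step in one epoch.
    steps, acc_samples = [], 0
    for i, bs in enumerate(batch_sizes):
        acc_samples += bs  # this batch's gradients join the running sum
        if (i + 1) % n_step == 0:
            steps.append((i, acc_samples))  # divide summed grads by acc_samples, then step
            acc_samples = 0
    if acc_samples and not drop_last:
        steps.append((len(batch_sizes) - 1, acc_samples))  # leftover partial group
    return steps

print(accumulation_plan([32, 32, 32, 32, 16], n_step=2))
```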
_cl_int_plot_top_losses
(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`None`***, **`heatmap_thresh`**:`int`=***`16`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",
"No tests found for `_cl_int_plot_top_losses`.
_cl_int_from_learner
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`

top_losses
(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
"confusion_matrix
(**`slice_size`**:`int`=***`1`***)\n",
"\n",
"most_confused
(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n",
"\n",
"plot_confusion_matrix
(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n",
"\n",
"No tests found for `plot_confusion_matrix`.
_cl_int_plot_multi_top_losses
(**`samples`**:`int`=***`3`***, **`figsize`**:`Tuple`\\[`int`, `int`\\]=***`(8, 8)`***, **`save_misclassified`**:`bool`=***`False`***)\n",
"\n",
"No tests found for `_cl_int_plot_multi_top_losses`.