{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with four simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)\n", "- [`AccumulateScheduler`](/train.html#AccumulateScheduler)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks without requiring you to create them manually." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n",
"\n",
"Tests found for fit_one_cycle:

- `pytest -sv tests/test_train.py::test_fit_one_cycle`

Some other tests where fit_one_cycle is used:

- `pytest -sv tests/test_tabular_train.py::test_empty_cont`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided`
- `pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split`

To run the tests, please refer to this guide.
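The `fit_one_cycle` signature exposes the knobs of the one-cycle policy: `max_lr`, `div_factor`, `pct_start`, `final_div` and the momentum range `moms`. As a rough, framework-free sketch (not fastai's implementation), the per-iteration learning rate and momentum could be computed as below. It assumes cosine interpolation in both phases and that `final_div=None` falls back to `div_factor * 1e4`; both are assumptions to check against the fastai source, and `one_cycle_schedule` is a hypothetical helper.

```python
import math

def annealing_cos(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_schedule(n_iter, max_lr, moms=(0.95, 0.85), div_factor=25.0,
                       pct_start=0.3, final_div=None):
    """Return a list of (lr, mom) pairs, one per training iteration.

    Hypothetical helper. Phase 1 warms the learning rate up from
    max_lr / div_factor to max_lr while momentum goes from moms[0] to moms[1].
    Phase 2 anneals the learning rate down to max_lr / final_div while
    momentum returns to moms[0].
    """
    if final_div is None:
        final_div = div_factor * 1e4  # assumed default, see the lead-in
    n_warm = int(n_iter * pct_start)
    n_anneal = n_iter - n_warm
    schedule = []
    for i in range(n_iter):
        if i < n_warm:
            pct = i / n_warm
            lr = annealing_cos(max_lr / div_factor, max_lr, pct)
            mom = annealing_cos(moms[0], moms[1], pct)
        else:
            # pct reaches exactly 1.0 on the last iteration
            pct = (i - n_warm) / max(n_anneal - 1, 1)
            lr = annealing_cos(max_lr, max_lr / final_div, pct)
            mom = annealing_cos(moms[1], moms[0], pct)
        schedule.append((lr, mom))
    return schedule
```

For example, `one_cycle_schedule(100, max_lr=3e-3)` starts at `3e-3 / 25`, peaks at `3e-3` after 30 iterations (because `pct_start=0.3`), and ends far below the starting value, while momentum moves in the opposite direction: low when the learning rate is high.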
one_cycle_scheduler(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n",
"\n",
lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n",
"\n",
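`lr_find` sweeps the learning rate from `start_lr` to `end_lr` over `num_it` iterations and, with `stop_div=True`, stops once the loss diverges. A minimal sketch of that loop, assuming geometric (exponential) spacing between the two endpoints and a divergence rule of "loss is NaN or exceeds 4 times the best loss seen so far". The factor 4, the `div_threshold` parameter and both helper names are illustrative assumptions, and `loss_at` stands in for a real training step.

```python
import math

def exp_lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Geometrically spaced learning rates from start_lr to end_lr inclusive."""
    if num_it == 1:
        return [start_lr]
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

def run_lr_sweep(loss_at, start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True,
                 div_threshold=4.0):
    """Record (lr, loss) pairs; `loss_at(lr)` is a hypothetical training step.

    Stops early when stop_div is set and the loss becomes NaN or exceeds
    div_threshold times the best loss recorded so far.
    """
    history, best = [], float("inf")
    for lr in exp_lr_schedule(start_lr, end_lr, num_it):
        loss = loss_at(lr)
        history.append((lr, loss))
        best = min(best, loss)
        if stop_div and (math.isnan(loss) or loss > div_threshold * best):
            break
    return history
```

With a toy loss such as `lambda lr: abs(math.log10(lr) + 3) + 1` (smallest near `1e-3`), the sweep stops shortly after the learning rate passes 1, well before the 100th iteration; the usable learning rate is then read off the recorded curve.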
to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
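With `dynamic=True`, mixed-precision training adjusts the loss scale on the fly. The plain-Python sketch below illustrates the general dynamic-loss-scaling idea under these assumptions: the scale starts at `max_scale` (16777216 = 2**24), is halved and the optimizer step skipped whenever a scaled gradient overflows, and is doubled again after `max_noskip` consecutive overflow-free steps, never exceeding `max_scale`. `DynamicLossScaler` is a hypothetical class; the exact rules used by fastai's mixed-precision callback should be checked in its source.

```python
import math

class DynamicLossScaler:
    """Hypothetical helper illustrating dynamic loss scaling."""

    def __init__(self, max_scale=2.0 ** 24, max_noskip=1000):
        self.max_scale = max_scale
        self.max_noskip = max_noskip
        self.scale = max_scale  # start high and back off on overflow
        self.noskip = 0         # consecutive steps without overflow

    def update(self, grads):
        """Inspect scaled gradients; return True if the optimizer should step."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            self.scale /= 2     # shrink the scale so fp16 values fit again
            self.noskip = 0
            return False        # skip this step, the gradients are unusable
        self.noskip += 1
        if self.noskip >= self.max_noskip and self.scale < self.max_scale:
            self.scale *= 2     # long run without overflow: try a larger scale
            self.noskip = 0
        return True
```

The design trades a few skipped steps for keeping the loss scale as large as possible, which preserves small gradient values that would otherwise underflow in half precision.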
to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n",
"\n",
mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
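mixup trains on convex combinations of pairs of examples and of their targets. A minimal plain-Python sketch, assuming the mixing weight is drawn from a Beta(`alpha`, `alpha`) distribution and then folded to `max(lam, 1 - lam)` so that the first example always dominates. The folding step and the `mixup_pair` name are assumptions for illustration, and the `stack_x`/`stack_y` options of the real method are not modeled here.

```python
import random

def mixup_pair(x1, y1, x2, y2, alpha=0.4, rng=random):
    """Mix two examples (lists of floats) and their one-hot targets.

    Hypothetical helper: lam ~ Beta(alpha, alpha), folded so that lam >= 0.5.
    """
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1 - lam)
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam
```

With a small `alpha` such as the default 0.4, the Beta distribution puts most of its mass near 0 and 1, so most mixed examples stay close to one of the two originals.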
class Interpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…)
from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…)

Tests found for from_learner:

Some other tests where from_learner is used:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`
- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`
- `pytest -sv tests/test_vision_train.py::test_interp`

To run the tests, please refer to this guide.
top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
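`top_losses` returns the `k` largest (or, with `largest=False`, smallest) per-item losses together with their indices, which is how the worst predictions are located. A plain-Python sketch of that selection, assuming `k=None` means "all items, sorted". The library works on tensors, so `sketch_top_losses` is a hypothetical stand-in for illustration only.

```python
def sketch_top_losses(losses, k=None, largest=True):
    """Return (values, indices) of the k largest (or smallest) losses."""
    if k is None:
        k = len(losses)  # assumed: no k means every item, ranked
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=largest)[:k]
    return [losses[i] for i in order], order
```

The returned indices point back into the dataset that produced the losses, so they can be used to fetch and display the offending items.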
class ClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…)

Tests found for ClassificationInterpretation:

- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`

Some other tests where ClassificationInterpretation is used:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`
- `pytest -sv tests/test_vision_train.py::test_interp`

To run the tests, please refer to this guide.
top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n",
"\n",
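With `normalize=True`, each row of the confusion matrix can be divided by its total, so that entries read as the fraction of each actual class assigned to each predicted class; `norm_dec` sets how many decimals are shown. A sketch of that arithmetic on a nested-list matrix, assuming rows index the actual class and columns the predicted class; `normalize_rows` is a hypothetical helper.

```python
def normalize_rows(cm, norm_dec=2):
    """Divide each row by its sum and round to norm_dec decimals for display."""
    out = []
    for row in cm:
        total = sum(row)
        # an empty row (class absent from the data) is shown as all zeros
        out.append([round(v / total, norm_dec) if total else 0.0 for v in row])
    return out
```

Row normalization makes classes with very different support directly comparable, at the cost of hiding the absolute counts.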
confusion_matrix(**`slice_size`**:`int`=***`1`***)\n",
"\n",
most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n",
"\n",
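`confusion_matrix` counts (actual, predicted) pairs, and `most_confused` lists the off-diagonal cells with at least `min_val` occurrences, matching its return type `Collection[Tuple[str, str, int]]`. A plain-Python sketch of both, assuming rows index the actual class, columns the predicted class, and results sorted by descending count. `build_confusion_matrix` and `most_confused_pairs` are hypothetical names, and `slice_size` is not modeled.

```python
def build_confusion_matrix(actual, predicted, n_classes):
    """cm[a][p] counts items of actual class a that were predicted as class p."""
    cm = [[0] * n_classes for _ in range(n_classes)]
    for a, p in zip(actual, predicted):
        cm[a][p] += 1
    return cm

def most_confused_pairs(cm, classes, min_val=1):
    """Off-diagonal (actual, predicted, count) triples with count >= min_val."""
    pairs = [(classes[a], classes[p], cm[a][p])
             for a in range(len(classes)) for p in range(len(classes))
             if a != p and cm[a][p] >= min_val]
    # most frequent confusions first; ties keep row-major order
    return sorted(pairs, key=lambda t: t[2], reverse=True)
```

For example, with classes `["cat", "dog", "bird"]`, two cats predicted as dogs would appear first as `("cat", "dog", 2)`.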
class MultiLabelClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…)
class ShowGraph(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
on_epoch_end(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n",
"\n",
class GradientClipping(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.140471 | 0.079571 | 0.971541 | 00:08 |
on_backward_end(**\\*\\*`kwargs`**)\n",
"\n",
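`GradientClipping` applies its `clip` value in `on_backward_end`, after the gradients are computed and before the optimizer step. Assuming it clips by the global L2 norm of all gradients, in the spirit of PyTorch's `torch.nn.utils.clip_grad_norm_`, the arithmetic is: if the norm exceeds `clip`, scale every gradient by `clip / norm`. The plain-Python sketch below works on a flat list of gradient values and treats the default `clip=0.0` as "no clipping", which is also an assumption; `clip_grad_norm` here is a hypothetical helper.

```python
import math

def clip_grad_norm(grads, clip):
    """Rescale grads so their global L2 norm is at most `clip`.

    Returns (possibly rescaled grads, norm before clipping).
    A falsy clip (e.g. 0.0) leaves the gradients untouched.
    """
    norm = math.sqrt(sum(g * g for g in grads))
    if clip and norm > clip:
        scale = clip / norm  # shrink uniformly, preserving the direction
        grads = [g * scale for g in grads]
    return grads, norm
```

Scaling all gradients by the same factor keeps the update direction unchanged and only limits its size, which is why norm clipping is usually preferred over clipping each value independently.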
class BnFreeze(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.147409 | 0.081370 | 0.972031 | 00:05 |
on_epoch_begin(**\\*\\*`kwargs`**:`Any`)\n",
"\n",
class AccumulateScheduler(**`learn`**:[`Learner`](/basic_train.html#Learner), **`n_step`**:`int`=***`1`***, **`drop_last`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
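`AccumulateScheduler` performs an optimizer step only every `n_step` batches, letting the gradients of several small batches add up before the update; `drop_last` decides what happens to a leftover partial group at the end of an epoch. The sketch below simulates that bookkeeping in plain Python, assuming each batch contributes the gradient of a summed loss (as with `reduction='sum'`) and that the accumulated gradient is averaged over the accumulated sample count at each step. `accumulated_steps` and its inputs are hypothetical, and a single float stands in for a gradient tensor.

```python
def accumulated_steps(batch_grad_sums, batch_sizes, n_step=1, drop_last=False):
    """Return the averaged gradient applied at each optimizer step.

    batch_grad_sums[i] is the gradient of the summed loss over batch i;
    batch_sizes[i] is the number of samples in that batch.
    """
    steps, acc_grad, acc_samples = [], 0.0, 0
    for i, (g, n) in enumerate(zip(batch_grad_sums, batch_sizes), start=1):
        acc_grad += g        # gradients accumulate across batches
        acc_samples += n
        if i % n_step == 0:  # time to step: average over accumulated samples
            steps.append(acc_grad / acc_samples)
            acc_grad, acc_samples = 0.0, 0
    if acc_samples and not drop_last:
        steps.append(acc_grad / acc_samples)  # step on the leftover partial group
    return steps
```

Averaging over samples rather than batches keeps the effective step size the same as one large batch of the combined size, even when the last batch of an epoch is smaller.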
_cl_int_plot_top_losses(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`False`***, **`heatmap_thresh`**:`int`=***`16`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",
_cl_int_from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=…)

top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n",
"\n",
confusion_matrix(**`slice_size`**:`int`=***`1`***)\n",
"\n",
most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n",
"\n",
plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n",
"\n",
_cl_int_plot_multi_top_losses(**`samples`**:`int`=***`3`***, **`figsize`**:`Tuple`\\[`int`, `int`\\]=***`(8, 8)`***, **`save_misclassified`**:`bool`=***`False`***)\n",
"\n",