{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with three simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`fit_one_cycle`

\n", "\n", "> fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n", "\n", "

\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit_one_cycle)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`one_cycle_scheduler`

\n", "\n", "> one_cycle_scheduler(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "

\n", "\n", "Instantiate a [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) with `lr_max`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(one_cycle_scheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`lr_find`

\n", "\n", "> lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n", "\n", "
Tests found for `lr_find`:

- `pytest -sv tests/test_train.py::test_lr_find`
- `pytest -sv tests/test_vision_train.py::test_lrfind`

\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`to_fp16`

\n", "\n", "> to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "

\n", "\n", "Put `learn` in FP16 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`to_fp32`

\n", "\n", "> to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n", "\n", "

\n", "\n", "Put `learn` back to FP32 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp32)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`mixup`

\n", "\n", "> mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "

\n", "\n", "Add [mixup](https://arxiv.org/abs/1710.09412) to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(mixup)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class `ClassificationInterpretation`

\n", "\n", "> ClassificationInterpretation(**`learn`**:[`Learner`](/basic_train.html#Learner), **`probs`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`DatasetType.Valid`***)\n", "\n", "
Tests found for `ClassificationInterpretation`:

- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`

Some other tests where `ClassificationInterpretation` is used:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`
- `pytest -sv tests/test_vision_train.py::test_interp`

\n", "\n", "Interpretation methods for classification models. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details on `mixup`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Additional callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll show examples below using our MNIST sample. As usual, the `on_something` methods are called directly by the fastai library; you don't need to call them yourself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class `ShowGraph`

\n", "\n", "> ShowGraph(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "

\n", "\n", "Update a graph of learner stats and metrics after each epoch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)\n", "learn.fit(3)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Training graph](imgs/train_graph.gif)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_epoch_end`

\n", "\n", "> on_epoch_end(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n", "\n", "

\n", "\n", "If we have `last_metrics`, plot them in our pbar graph. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class `GradientClipping`

\n", "\n", "> GradientClipping(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "

\n", "\n", "Gradient clipping during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:11

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy |
|-------|------------|------------|----------|
| 1 | 0.131133 | 0.078190 | 0.973013 |
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy,\n", " callback_fns=partial(GradientClipping, clip=0.1))\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_backward_end`

\n", "\n", "> on_backward_end(**\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Clip the gradient before the optimizer step. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping.on_backward_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class `BnFreeze`

\n", "\n", "> BnFreeze(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "

\n", "\n", "Freeze moving average statistics in all non-trainable batchnorm layers. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls `eval` on these layers)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:07

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy |
|-------|------------|------------|----------|
| 1 | 0.132564 | 0.078910 | 0.972031 |
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_epoch_begin`

\n", "\n", "> on_epoch_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "

\n", "\n", "Put bn layers in eval mode just after `model.train()`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze.on_epoch_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`_cl_int_plot_top_losses`

\n", "\n", "> _cl_int_plot_top_losses(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`True`***, **`heatmap_thresh`**:`int`=***`16`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of actual class. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`_cl_int_from_learner`

\n", "\n", "> _cl_int_from_learner(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`DatasetType.Valid`***, **`tta`**=***`False`***)\n", "\n", "
Tests found for `_cl_int_from_learner`:

- `pytest -sv tests/test_vision_train.py::test_interp`

\n", "\n", "Create an instance of [`ClassificationInterpretation`](/train.html#ClassificationInterpretation). `tta` indicates if we want to use Test Time Augmentation. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.from_learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`top_losses`

\n", "\n", "> top_losses(**`k`**:`int`=***`None`***, **`largest`**=***`True`***)\n", "\n", "
Some other tests where `top_losses` is used:

- `pytest -sv tests/test_vision_train.py::test_interp`
- `pytest -sv tests/test_vision_train.py::test_interp_shortcut`

\n", "\n", "`k` largest(/smallest) losses and indexes, defaulting to all losses (sorted by `largest`). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`confusion_matrix`

\n", "\n", "> confusion_matrix(**`slice_size`**:`int`=***`1`***)\n", "\n", "
Tests found for `confusion_matrix`:

- `pytest -sv tests/test_tabular_train.py::test_confusion_tabular`

Some other tests where `confusion_matrix` is used:

- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`

\n", "\n", "Confusion matrix as an `np.ndarray`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`most_confused`

\n", "\n", "> most_confused(**`min_val`**:`int`=***`1`***, **`slice_size`**:`int`=***`1`***) → `Collection`\\[`Tuple`\\[`str`, `str`, `int`\\]\\]\n", "\n", "
Some other tests where `most_confused` is used:

- `pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation`

\n", "\n", "Sorted descending list of largest non-diagonal entries of confusion matrix, presented as actual, predicted, number of occurrences. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.most_confused)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`plot_confusion_matrix`

\n", "\n", "> plot_confusion_matrix(**`normalize`**:`bool`=***`False`***, **`title`**:`str`=***`'Confusion matrix'`***, **`cmap`**:`Any`=***`'Blues'`***, **`slice_size`**:`int`=***`1`***, **`norm_dec`**:`int`=***`2`***, **`plot_txt`**:`bool`=***`True`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Plot the confusion matrix, with `title` and using `cmap`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_confusion_matrix)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`_cl_int_plot_multi_top_losses`

\n", "\n", "> _cl_int_plot_multi_top_losses(**`samples`**:`int`=***`3`***, **`figsize`**:`Tuple`\\[`int`, `int`\\]=***`(8, 8)`***, **`save_misclassified`**:`bool`=***`False`***)\n", "\n", "

\n", "\n", "Show images in `top_losses` along with their prediction, actual, loss, and probability of predicted class in a multilabeled dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ClassificationInterpretation.plot_multi_top_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Extensions to Learner that easily implement Callback", "title": "train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }