{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with three simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
fit_one_cycle[source]fit_one_cycle(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n",
"\n",
"Fit a model following the 1cycle policy. "
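,
"\n",
"The policy warms the learning rate up from `max_lr/div_factor` to `max_lr` over the first `pct_start` fraction of iterations, then anneals it back down (momentum is scheduled in the opposite direction between `moms[0]` and `moms[1]`). A minimal pure-Python sketch of the learning-rate half of the schedule, using linear ramps for illustration rather than the library's exact curves:\n",
"\n",
"```python\n",
"def one_cycle_lr(i, n_iter, max_lr=0.003, div_factor=25.0, pct_start=0.3):\n",
"    # illustrative sketch only: linear warm-up then linear annealing\n",
"    low = max_lr / div_factor\n",
"    up = int(n_iter * pct_start)\n",
"    if i < up:  # warm-up phase\n",
"        return low + (max_lr - low) * i / up\n",
"    # annealing phase back down to the starting rate\n",
"    return max_lr - (max_lr - low) * (i - up) / (n_iter - up)\n",
"```"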
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(fit_one_cycle)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"lr_find[source]lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. "
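,
"\n",
"The sweep is exponential rather than linear, so each step multiplies the learning rate by a fixed ratio. A sketch of the rates such a range test visits (an illustration, not the library code):\n",
"\n",
"```python\n",
"def lr_range(start_lr=1e-07, end_lr=10, num_it=100):\n",
"    # exponentially spaced rates from start_lr to end_lr inclusive\n",
"    ratio = end_lr / start_lr\n",
"    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]\n",
"```\n",
"Plotting loss against these rates and picking a value roughly an order of magnitude below the loss minimum is the usual heuristic."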
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(lr_find)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"to_fp16[source]to_fp16(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Put `learn` in FP16 precision mode. "
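,
"\n",
"Half precision cannot represent very small gradient values, so mixed-precision training multiplies the loss by `loss_scale` before the backward pass (scaling every gradient by the same factor) and divides the gradients back before the optimizer step. A toy arithmetic sketch of that round trip, not the actual callback:\n",
"\n",
"```python\n",
"def scaled_backward(raw_grads, loss_scale=512.0):\n",
"    # stand-in for autograd: gradients of the scaled loss\n",
"    return [g * loss_scale for g in raw_grads]\n",
"\n",
"def unscale_grads(backward_grads, loss_scale=512.0):\n",
"    # backward on (loss * loss_scale) produces gradients that are\n",
"    # loss_scale times too big; undo the scaling before the update\n",
"    return [g / loss_scale for g in backward_grads]\n",
"```\n",
"Because `loss_scale` defaults to a power of two, the scaling is exact: the round trip recovers the original gradients while keeping intermediate values above the FP16 underflow threshold."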
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(to_fp16)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"mixup[source]mixup(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Add [mixup](https://arxiv.org/abs/1710.09412) to `learn`. "
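,
"\n",
"Mixup trains on convex combinations of pairs of examples: with `t` drawn from a Beta(`alpha`, `alpha`) distribution, the input becomes `t*x1 + (1-t)*x2` and the target `t*y1 + (1-t)*y2`. A minimal sketch of mixing one pair (a hypothetical helper, not the fastai callback):\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def mixup_pair(x1, y1, x2, y2, alpha=0.4, t=None):\n",
"    # t ~ Beta(alpha, alpha); pass t explicitly for a deterministic mix\n",
"    if t is None:\n",
"        t = random.betavariate(alpha, alpha)\n",
"    x = [t * a + (1 - t) * b for a, b in zip(x1, x2)]\n",
"    y = t * y1 + (1 - t) * y2\n",
"    return x, y\n",
"```"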
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(mixup)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"TTA[source]TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `is_test`:`bool`=`False`, `with_loss`:`bool`=`False`) → `Tensors`\n",
"\n",
"Compute predictions with Test Time Augmentation: average predictions over augmented versions of the inputs and blend them with the regular predictions, weighted by `beta`. "
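,
"\n",
"Test Time Augmentation averages predictions taken over several augmented versions of each input and blends that average with the prediction on the un-augmented input, weighted by `beta`. A sketch of the blending step (assuming the combination `beta * preds + (1 - beta) * mean(aug_preds)`):\n",
"\n",
"```python\n",
"def tta_blend(preds, aug_preds, beta=0.4):\n",
"    # average the predictions from each augmented pass...\n",
"    n = len(aug_preds)\n",
"    avg_aug = [sum(p[i] for p in aug_preds) / n for i in range(len(preds))]\n",
"    # ...then blend with the plain predictions, weighted by beta\n",
"    return [beta * p + (1 - beta) * a for p, a in zip(preds, avg_aug)]\n",
"```"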
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(TTA)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"class ShowGraph[source]ShowGraph(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Update a graph of learner stats and metrics after each epoch. "
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(ShowGraph)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"on_epoch_end[source]on_epoch_end(`n_epochs`:`int`, `last_metrics`:`MetricsList`, `kwargs`) → `bool`"
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(ShowGraph.on_epoch_end)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"class GradientClipping[source]GradientClipping(`learn`:[`Learner`](/basic_train.html#Learner), `clip`:`float`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Clip gradients to a maximum norm of `clip` during training. "
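,
"\n",
"Clipping rescales the gradient vector whenever its L2 norm exceeds `clip`, so a single bad batch cannot blow up the weights; in the library this is delegated to PyTorch's `nn.utils.clip_grad_norm_`. A pure-Python sketch of the rescaling:\n",
"\n",
"```python\n",
"import math\n",
"\n",
"def clip_grad_norm(grads, clip):\n",
"    # rescale the gradients so their L2 norm is at most clip\n",
"    norm = math.sqrt(sum(g * g for g in grads))\n",
"    if norm > clip:\n",
"        scale = clip / norm\n",
"        grads = [g * scale for g in grads]\n",
"    return grads\n",
"```"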
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(GradientClipping)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"on_backward_end[source]on_backward_end(`kwargs`)"
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(GradientClipping.on_backward_end)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"class BnFreeze[source]BnFreeze(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Freeze moving average statistics in all non-trainable batchnorm layers. "
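,
"\n",
"Batchnorm layers maintain exponential moving averages of activation mean and variance; when the rest of a layer group is frozen, you usually want those statistics frozen too so fine-tuning does not drift them. A toy sketch of what freezing that update means (an illustration, not the callback):\n",
"\n",
"```python\n",
"def update_running_mean(running, batch_mean, momentum=0.1, frozen=False):\n",
"    # batchnorm tracks running statistics via an exponential moving\n",
"    # average; freezing skips the update so the statistics stay fixed\n",
"    if frozen:\n",
"        return running\n",
"    return (1 - momentum) * running + momentum * batch_mean\n",
"```"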
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(BnFreeze)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"on_epoch_begin[source]on_epoch_begin(`kwargs`:`Any`)"
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(BnFreeze.on_epoch_begin)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"one_cycle_scheduler[source]one_cycle_scheduler(`lr_max`:`float`, `kwargs`:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n",
"\n",
"Instantiate a [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) with `lr_max`. "
],
"text/plain": [ "<IPython.core.display.Markdown object>" ]
},
"metadata": {},
"output_type": "display_data"
} ],
"source": [ "show_doc(one_cycle_scheduler)" ]
}, {
"cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [
"