{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with three simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
`fit_one_cycle`(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n",
"\n",
"Fit a model following the 1cycle policy. "
],
"text/plain": [
"`lr_find`(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore learning rates from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div` is `True`, stop when the loss diverges. "
],
"text/plain": [
"`to_fp16`(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Convert `learn` to FP16 precision. "
],
"text/plain": [
"`mixup`(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Add mixup (https://arxiv.org/abs/1710.09412) to `learn`. "
],
"text/plain": [
"`TTA`(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=
class `ShowGraph`(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Update a graph of learner stats and metrics after each epoch. "
],
"text/plain": [
"`on_epoch_end`(`n_epochs`:`int`, `last_metrics`:`MetricsList`, `kwargs`) → `bool`"
],
"text/plain": [
"class `GradientClipping`(`learn`:[`Learner`](/basic_train.html#Learner), `clip`:`float`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Clip gradients during training. "
],
"text/plain": [
"`on_backward_end`(`kwargs`)"
],
"text/plain": [
"class `BnFreeze`(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Freeze moving average statistics in all non-trainable batchnorm layers. "
],
"text/plain": [
"`on_epoch_begin`(`kwargs`:`Any`)"
],
"text/plain": [
"`one_cycle_scheduler`(`lr_max`:`float`, `kwargs`:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)"
],
"text/plain": [
"