{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Classes for callback implementors" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callback import * \n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai provides a powerful *callback* system, which is documented on the [`callbacks`](/callbacks.html#callbacks) page; look there if you only want to use existing callbacks. If you want to create your own, you'll need the classes discussed below.\n", "\n", "A key motivation for the callback system is that a piece of additional functionality can be implemented entirely in a single callback, which keeps it easy to read. Each training method lives in its own callback, and that callback states clearly every intervention the method makes during training. For instance, the [`LRFinder`](/callbacks.lr_finder.html#LRFinder) callback, on top of running the fit function with exponentially growing learning rates, needs to handle some preparation and clean-up. All of this code lives in the same callback, so we know exactly what it does and where to look if we need to change something.\n", "\n", "The callback system also keeps our [`fit`](/basic_train.html#fit) function clean and simple, yet easily extended. So far, in implementing a number of recent papers, we haven't come across any situation where we had to modify the training loop's source code; we've been able to use callbacks every time." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
class Callback()\n",
"\n",
"Base class for callbacks that want to record values, dynamically change learner params, etc. "
],
"text/plain": [
"on_train_begin(`kwargs`:`Any`)\n",
"\n",
"To initialize constants in the callback. "
],
"text/plain": [
"on_epoch_begin(`kwargs`:`Any`)\n",
"\n",
"At the beginning of each epoch. "
],
"text/plain": [
"on_batch_begin(`kwargs`:`Any`)\n",
"\n",
"Set hyper-parameters before the step is done. Returns `xb`, `yb` (which allows us to modify the input at that step if needed). "
],
"text/plain": [
"on_loss_begin(`kwargs`:`Any`)\n",
"\n",
"Called after the forward pass but before the loss has been computed. Returns the output (which allows us to modify it). "
],
"text/plain": [
"on_backward_begin(`kwargs`:`Any`)\n",
"\n",
"Called after the forward pass and the loss have been computed, but before backprop. Returns the loss (which allows us to modify it, for instance to apply a regularization function). "
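As a rough sketch of modifying the loss in `on_backward_begin`, the hypothetical `L2Penalty` class below adds an L2 term to the loss it receives. It works on plain floats and does not subclass fastai's `Callback`; the keyword name `last_loss` and the extra `weights` argument are assumptions made only for this example.

```python
from typing import Any, List


class L2Penalty:
    """Hypothetical callback-like class: adds an L2 term to the loss before backprop."""

    def __init__(self, wd: float) -> None:
        self.wd = wd  # regularization strength

    def on_backward_begin(self, last_loss: float, weights: List[float], **kwargs: Any) -> float:
        # Return the modified loss: original loss + wd * sum of squared weights.
        return last_loss + self.wd * sum(w * w for w in weights)


cb = L2Penalty(wd=0.01)
print(cb.on_backward_begin(last_loss=0.5, weights=[1.0, -2.0, 3.0]))
```

Here the penalty is `0.01 * (1 + 4 + 9) = 0.14`, so the returned loss is about 0.64.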
],
"text/plain": [
"on_backward_end(`kwargs`:`Any`)\n",
"\n",
"Called after backprop but before the optimizer step. Useful for true (decoupled) weight decay, as in AdamW. "
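"True" weight decay here means decoupled weight decay, as in AdamW: the weights are shrunk directly by a factor of `1 - lr * wd` instead of adding `wd * w` to the gradient. The plain-Python sketch below uses a hypothetical helper, `step_with_true_wd`, with plain SGD standing in for the optimizer, to show that arithmetic for one update; it is not fastai's implementation.

```python
from typing import List


def step_with_true_wd(weights: List[float], grads: List[float], lr: float, wd: float) -> List[float]:
    """Hypothetical helper: decoupled ("true") weight decay followed by a plain SGD step."""
    updated = []
    for w, g in zip(weights, grads):
        w = w * (1 - lr * wd)  # decay applied to the weight itself, independent of the gradient
        w = w - lr * g         # ordinary gradient step (plain SGD, for simplicity)
        updated.append(w)
    return updated


print(step_with_true_wd([1.0, -2.0], [0.5, 0.5], lr=0.1, wd=0.1))
```

With plain SGD, as here, the result is identical to classic L2 regularization with the same coefficient; the two only diverge with adaptive optimizers such as Adam, which is why the decoupling matters for AdamW.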
],
"text/plain": [
"on_step_end(`kwargs`:`Any`)\n",
"\n",
"Called after the step of the optimizer but before the gradients are zeroed. "
],
"text/plain": [
"on_batch_end(`kwargs`:`Any`)\n",
"\n",
"Called at the end of the batch. "
],
"text/plain": [
"on_epoch_end(`kwargs`:`Any`) → `bool`\n",
"\n",
"Called at the end of an epoch. "
],
"text/plain": [
"on_train_end(`kwargs`:`Any`)\n",
"\n",
"Useful for cleaning up things and saving files/models. "
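To make the order of these events concrete, here is a toy, framework-free training loop. It is only a sketch: `MiniCallback`, `run_toy_fit` and `EventLogger` are hypothetical stand-ins, not fastai's `fit` or `CallbackHandler`, and the keyword names passed to each hook are assumptions. It shows two behaviours described above: a value returned from `on_batch_begin` replaces the batch, and, mirroring the `bool` returned by `on_epoch_end`, a `True` return stops this loop early.

```python
from typing import Any, List, Optional, Tuple


class MiniCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""
    def on_train_begin(self, **kwargs: Any) -> None: pass
    def on_epoch_begin(self, **kwargs: Any) -> None: pass
    def on_batch_begin(self, **kwargs: Any) -> Optional[Tuple[float, float]]: return None
    def on_loss_begin(self, **kwargs: Any) -> None: pass
    def on_backward_begin(self, **kwargs: Any) -> None: pass
    def on_backward_end(self, **kwargs: Any) -> None: pass
    def on_step_end(self, **kwargs: Any) -> None: pass
    def on_batch_end(self, **kwargs: Any) -> None: pass
    def on_epoch_end(self, **kwargs: Any) -> bool: return False
    def on_train_end(self, **kwargs: Any) -> None: pass


def run_toy_fit(epochs: int, batches: List[Tuple[float, float]], cb: MiniCallback) -> None:
    """Hypothetical loop that fires the hooks in the documented order (no real model)."""
    cb.on_train_begin(epochs=epochs)
    for epoch in range(epochs):
        cb.on_epoch_begin(epoch=epoch)
        for xb, yb in batches:
            res = cb.on_batch_begin(xb=xb, yb=yb)
            if res is not None:
                xb, yb = res            # the callback replaced the batch
            out = xb                    # "forward pass" of an identity model
            cb.on_loss_begin(out=out)
            loss = (out - yb) ** 2      # squared error
            cb.on_backward_begin(loss=loss)
            cb.on_backward_end()        # backprop would run just before this hook
            cb.on_step_end()            # the optimizer step would run just before this hook
            cb.on_batch_end(loss=loss)
        if cb.on_epoch_end(epoch=epoch):
            break                       # a True return stops training early
    cb.on_train_end(exception=False)


class EventLogger(MiniCallback):
    """Logs some events, doubles every input, and stops after `max_epochs` epochs."""
    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs
        self.events: List[str] = []
        self.losses: List[float] = []

    def on_train_begin(self, **kwargs: Any) -> None:
        self.events.append("train_begin")

    def on_epoch_begin(self, **kwargs: Any) -> None:
        self.events.append("epoch_begin")

    def on_batch_begin(self, xb: float, yb: float, **kwargs: Any) -> Tuple[float, float]:
        self.events.append("batch_begin")
        return xb * 2, yb               # modify the input before the forward pass

    def on_batch_end(self, loss: float, **kwargs: Any) -> None:
        self.losses.append(loss)

    def on_epoch_end(self, epoch: int, **kwargs: Any) -> bool:
        self.events.append("epoch_end")
        return epoch + 1 >= self.max_epochs

    def on_train_end(self, **kwargs: Any) -> None:
        self.events.append("train_end")


rec = EventLogger(max_epochs=1)
run_toy_fit(epochs=3, batches=[(1.0, 2.0)], cb=rec)
print(rec.events)
print(rec.losses)
```

With `max_epochs=1` the logger asks to stop after the first epoch even though `epochs=3` was requested, and because it doubled the input, the recorded loss is `(2.0 - 2.0) ** 2 = 0.0`.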
],
"text/plain": [
"annealing_cos(`start`:`Number`, `end`:`Number`, `pct`:`float`) → `Number`\n",
"\n",
"Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0. "
],
"text/plain": [
"annealing_exp(`start`:`Number`, `end`:`Number`, `pct`:`float`) → `Number`\n",
"\n",
"Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0. "
],
"text/plain": [
"annealing_linear(`start`:`Number`, `end`:`Number`, `pct`:`float`) → `Number`\n",
"\n",
"Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0. "
],
"text/plain": [
"annealing_no(`start`:`Number`, `end`:`Number`, `pct`:`float`) → `Number`\n",
"\n",
"No annealing, always return `start`. "
],
"text/plain": [
"annealing_poly(`degree`:`Number`) → `Number`\n",
"\n",
"Anneal polynomially from `start` to `end` as pct goes from 0.0 to 1.0. "
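A standalone sketch of these schedules, written to match the behaviour the docstrings describe. The exact formulas (cosine as half a period, exponential as geometric interpolation, polynomial as `(1 - pct) ** degree`) are assumptions, so check the fastai source before relying on exact values; `Number` is aliased to `float` here. Following the signatures above, `annealing_poly` takes only `degree` and returns a schedule function built on `do_annealing_poly`.

```python
import functools
import math
from typing import Callable

Number = float  # stand-in for the Number type used in the signatures


def annealing_no(start: Number, end: Number, pct: float) -> Number:
    """No annealing: always return `start`."""
    return start


def annealing_linear(start: Number, end: Number, pct: float) -> Number:
    """Straight line from `start` (pct=0) to `end` (pct=1)."""
    return start + pct * (end - start)


def annealing_exp(start: Number, end: Number, pct: float) -> Number:
    """Geometric interpolation; assumes `start` and `end` are both positive."""
    return start * (end / start) ** pct


def annealing_cos(start: Number, end: Number, pct: float) -> Number:
    """Half a cosine period: flat near both ends, steepest in the middle."""
    cos_out = math.cos(math.pi * pct) + 1  # goes from 2 down to 0
    return end + (start - end) / 2 * cos_out


def do_annealing_poly(start: Number, end: Number, pct: float, degree: Number) -> Number:
    """Polynomial decay from `start` to `end` with exponent `degree`."""
    return end + (start - end) * (1 - pct) ** degree


def annealing_poly(degree: Number) -> Callable[[Number, Number, float], Number]:
    """Fix `degree` and return a (start, end, pct) schedule function."""
    return functools.partial(do_annealing_poly, degree=degree)


for f in (annealing_no, annealing_linear, annealing_exp, annealing_cos, annealing_poly(2)):
    print([round(f(1.0, 0.01, p), 4) for p in (0.0, 0.5, 1.0)])
```

Apart from `annealing_no`, each schedule returns `start` at `pct=0.0` and `end` at `pct=1.0`; they differ only in the path taken between the two.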
],
"text/plain": [
"class CallbackHandler(`callbacks`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`None`, `metrics`:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=`None`, `beta`:`float`=`0.98`)\n",
"\n",
"Manage all of the registered callback objects, smoothing loss by momentum `beta`. "
],
"text/plain": [
"on_backward_begin(`loss`:`Tensor`)\n",
"\n",
"Handle gradient calculation on `loss`. "
],
"text/plain": [
"on_backward_end()\n",
"\n",
"Handle end of gradient calculation. "
],
"text/plain": [
"on_batch_begin(`xb`:`Tensor`, `yb`:`Tensor`, `train`:`bool`=`True`)\n",
"\n",
"Handle new batch `xb`, `yb`. "
],
"text/plain": [
"on_batch_end(`loss`:`Tensor`)\n",
"\n",
"Handle end of processing one batch with `loss`. "
],
"text/plain": [
"on_epoch_begin()\n",
"\n",
"Handle new epoch. "
],
"text/plain": [
"on_epoch_end(`val_loss`:`Tensor`) → `bool`\n",
"\n",
"Epoch is done, process `val_loss`. "
],
"text/plain": [
"on_loss_begin(`out`:`Tensor`)\n",
"\n",
"Handle start of loss calculation with model output `out`. "
],
"text/plain": [
"on_step_end()\n",
"\n",
"Handle end of optimization step. "
],
"text/plain": [
"on_train_begin(`epochs`:`int`, `pbar`:`PBar`, `metrics`:`MetricFuncList`)\n",
"\n",
"About to start learning. "
],
"text/plain": [
"on_train_end(`exception`:`Union`\\[`bool`, `Exception`\\])\n",
"\n",
"Handle end of training; `exception` is an `Exception`, or `False` if no exception was raised during training. "
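The sketch below illustrates, in plain Python, how a handler can fan events out to several callbacks. `MiniHandler`, `AddConstant` and `StopAbove` are hypothetical, and the rules used (each non-`None` return from `on_backward_begin` replaces the loss seen by the next callback; training stops if any `on_epoch_end` returns `True`) are assumptions consistent with the signatures above, not a copy of fastai's `CallbackHandler`. Loss smoothing by `beta` is left out.

```python
from typing import Any, List


class MiniHandler:
    """Hypothetical, simplified handler that fans each event out to its callbacks."""

    def __init__(self, callbacks: List[Any]) -> None:
        self.callbacks = callbacks

    def on_backward_begin(self, loss: float) -> float:
        # Each callback may return a replacement loss; later callbacks see the updated value.
        for cb in self.callbacks:
            if hasattr(cb, "on_backward_begin"):
                new_loss = cb.on_backward_begin(last_loss=loss)
                if new_loss is not None:
                    loss = new_loss
        return loss

    def on_epoch_end(self, val_loss: float) -> bool:
        # Training should stop if any callback returns True (None counts as False).
        results = [cb.on_epoch_end(val_loss=val_loss)
                   for cb in self.callbacks if hasattr(cb, "on_epoch_end")]
        return any(results)


class AddConstant:
    """Hypothetical callback: adds a fixed penalty to the loss."""
    def __init__(self, c: float) -> None:
        self.c = c

    def on_backward_begin(self, last_loss: float, **kwargs: Any) -> float:
        return last_loss + self.c


class StopAbove:
    """Hypothetical callback: asks to stop once the validation loss exceeds `limit`."""
    def __init__(self, limit: float) -> None:
        self.limit = limit

    def on_epoch_end(self, val_loss: float, **kwargs: Any) -> bool:
        return val_loss > self.limit


handler = MiniHandler([AddConstant(0.5), AddConstant(0.25), StopAbove(limit=1.0)])
print(handler.on_backward_begin(1.0))
print(handler.on_epoch_end(val_loss=0.8), handler.on_epoch_end(val_loss=1.2))
```

Because each callback sees the loss returned by the previous one, the two penalties accumulate to `1.0 + 0.5 + 0.25 = 1.75`.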
],
"text/plain": [
"class OptimWrapper(`opt`:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), `wd`:`Floats`=`0.0`, `true_wd`:`bool`=`False`, `bn_wd`:`bool`=`True`)\n",
"\n",
"Basic wrapper around an optimizer to simplify HP changes. "
],
"text/plain": [
"create(`opt_func`:`Callable`, `lr`:`Union`\\[`float`, `Tuple`, `List`\\], `layer_groups`:`ModuleList`, `kwargs`:`Any`) → [`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer)\n",
"\n",
"Create an `optim.Optimizer` from `opt_func` with `lr`. Set `lr` on `layer_groups`. "
],
"text/plain": [
"read_defaults()\n",
"\n",
"Read the values inside the optimizer for the hyper-parameters. "
],
"text/plain": [
"read_val(`key`:`str`) → `Union`\\[`List`\\[`float`\\], `Tuple`\\[`List`\\[`float`\\], `List`\\[`float`\\]\\]\\]\n",
"\n",
"Read a hyperparameter key in the optimizer dictionary. "
],
"text/plain": [
"set_val(`key`:`str`, `val`:`Any`, `bn_groups`:`bool`=`True`) → `Any`\n",
"\n",
"Set the value of `key` inside the optimizer dictionary. "
],
"text/plain": [
"step()\n",
"\n",
"Set weight decay and step optimizer. "
],
"text/plain": [
"zero_grad()\n",
"\n",
"Clear optimizer gradients. "
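The hyper-parameter helpers read and write one key across every parameter group of the wrapped optimizer. A minimal plain-Python sketch, assuming a hypothetical `FakeOptimizer` that only exposes a `param_groups` list of dicts (mirroring the attribute of the same name on PyTorch optimizers), might look like this; the `bn_groups` handling and weight-decay logic of the real `OptimWrapper` are left out.

```python
from typing import Any, Dict, List, Tuple, Union


class FakeOptimizer:
    """Hypothetical stand-in for a torch optimizer: only the `param_groups` list of dicts."""
    def __init__(self, param_groups: List[Dict[str, Any]]) -> None:
        self.param_groups = param_groups


class MiniOptimWrapper:
    """Sketch of reading/writing one hyper-parameter across all parameter groups."""

    def __init__(self, opt: FakeOptimizer) -> None:
        self.opt = opt

    def read_val(self, key: str) -> Union[List[float], Tuple[List[float], List[float]]]:
        vals = [pg[key] for pg in self.opt.param_groups]
        if isinstance(vals[0], tuple):  # e.g. Adam's `betas` = (beta1, beta2)
            return [v[0] for v in vals], [v[1] for v in vals]
        return vals

    def set_val(self, key: str, val: Any) -> Any:
        # Accept one value for every group, or a list with one value per group.
        vals = val if isinstance(val, list) else [val] * len(self.opt.param_groups)
        for pg, v in zip(self.opt.param_groups, vals):
            pg[key] = v
        return val


opt = FakeOptimizer([{"lr": 1e-3, "betas": (0.9, 0.99)}, {"lr": 1e-2, "betas": (0.9, 0.99)}])
wrapper = MiniOptimWrapper(opt)
wrapper.set_val("lr", [1e-4, 1e-3])
print(wrapper.read_val("lr"))
print(wrapper.read_val("betas"))
```

Splitting tuple-valued keys into two lists matches the `Tuple[List[float], List[float]]` return type shown for `read_val`.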
],
"text/plain": [
"class SmoothenValue(`beta`:`float`)\n",
"\n",
"Create a smooth moving average for a value (loss, etc.). "
],
"text/plain": [
"add_value(`val`:`float`)\n",
"\n",
"Add current value to calculate updated smoothed value. "
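A sketch of the smoothing, assuming an exponential moving average with momentum `beta` that is divided by `1 - beta ** n` so that early values are not biased toward the initial zero. This is a re-implementation for illustration, not the library code.

```python
class SmoothenValue:
    """Sketch: debiased exponential moving average of a stream of values."""

    def __init__(self, beta: float) -> None:
        self.beta, self.n, self.mov_avg, self.smooth = beta, 0, 0.0, 0.0

    def add_value(self, val: float) -> None:
        self.n += 1
        # Exponential moving average, weighted by `beta`.
        self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
        # Divide by (1 - beta^n) to remove the bias toward the initial 0.
        self.smooth = self.mov_avg / (1 - self.beta ** self.n)


sv = SmoothenValue(beta=0.9)
for v in (2.0, 1.0, 0.5):
    sv.add_value(v)
    print(round(sv.smooth, 4))
```

Because of the bias correction, the first smoothed value equals the first raw value (2.0 here) instead of `0.1 * 2.0`.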
],
"text/plain": [
"class Stepper(`vals`:`StartOptEnd`, `n_iter`:`int`, `func`:`Optional`\\[`AnnealFunc`\\]=`None`)\n",
"\n",
"Used to \"step\" from start, end (`vals`) over `n_iter` iterations on a schedule defined by `func`. "
],
"text/plain": [
"step() → `Number`\n",
"\n",
"Return next value along annealed schedule. "
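A sketch of how such a stepper can work, assuming `vals` is either a single number (held constant) or a `(start, end)` tuple, which is our reading of `StartOptEnd`, and assuming a linear schedule when `func` is `None`. Each call to `step` advances one iteration and evaluates `func(start, end, n / n_iter)`; the helper `linear_schedule` is hypothetical and the whole block is illustrative rather than fastai's exact code.

```python
from typing import Callable, Optional, Tuple, Union

AnnealFunc = Callable[[float, float, float], float]


def linear_schedule(start: float, end: float, pct: float) -> float:
    """Assumed default schedule: straight line from `start` to `end`."""
    return start + pct * (end - start)


class Stepper:
    """Sketch: step from `start` to `end` over `n_iter` iterations following `func`."""

    def __init__(self, vals: Union[float, Tuple[float, float]], n_iter: int,
                 func: Optional[AnnealFunc] = None) -> None:
        # A tuple means (start, end); a single number is kept constant.
        self.start, self.end = vals if isinstance(vals, tuple) else (vals, vals)
        self.n_iter = max(1, n_iter)  # guard against division by zero
        self.func = func if func is not None else linear_schedule
        self.n = 0

    def step(self) -> float:
        # Advance one iteration and return the value at pct = n / n_iter.
        self.n += 1
        return self.func(self.start, self.end, self.n / self.n_iter)


s = Stepper((1.0, 0.0), n_iter=4)
print([s.step() for _ in range(4)])
```

In this sketch the first call already returns the value at `pct = 1 / n_iter`, and the final call lands exactly on `end`.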
],
"text/plain": [
"class AverageMetric(`func`) :: [`Callback`](/callback.html#Callback)\n",
"\n",
"Wrap a `func` in a callback for metrics computation. "
],
"text/plain": [
"do_annealing_poly(`start`:`Number`, `end`:`Number`, `pct`:`float`, `degree`:`Number`) → `Number`\n",
"\n",
"Helper function for `annealing_poly`. "
],
"text/plain": [
"on_epoch_begin(`kwargs`)\n",
"\n",
"At the beginning of each epoch. "
],
"text/plain": [
"on_batch_end(`last_output`, `last_target`, `train`, `kwargs`)\n",
"\n",
"Called at the end of the batch. "
],
"text/plain": [
"on_epoch_end(`kwargs`)\n",
"\n",
"Called at the end of an epoch. "
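A plain-Python sketch of the idea, with a hypothetical toy `accuracy` function on lists: the wrapper accumulates `batch_size * func(output, target)` over an epoch and divides by the total number of samples, so batches of different sizes are weighted correctly. Skipping training batches via `train` is an assumption about how that flag is used, and this is not fastai's implementation.

```python
from typing import Any, Callable, List


def accuracy(output: List[int], target: List[int]) -> float:
    """Toy metric: fraction of positions where prediction equals target."""
    return sum(o == t for o, t in zip(output, target)) / len(target)


class AverageMetric:
    """Sketch: wrap a per-batch `func` and average it over an epoch, weighted by batch size."""

    def __init__(self, func: Callable[[Any, Any], float]) -> None:
        self.func, self.name = func, func.__name__

    def on_epoch_begin(self, **kwargs: Any) -> None:
        self.val, self.count = 0.0, 0  # reset the running totals

    def on_batch_end(self, last_output: Any, last_target: Any, train: bool, **kwargs: Any) -> None:
        if train:
            return  # assumption: only validation batches are accumulated
        bs = len(last_target)
        self.count += bs
        self.val += bs * self.func(last_output, last_target)

    def on_epoch_end(self, **kwargs: Any) -> None:
        self.metric = self.val / self.count


m = AverageMetric(accuracy)
m.on_epoch_begin()
m.on_batch_end([1, 0, 1, 1], [1, 1, 1, 1], train=False)  # 3 of 4 correct
m.on_batch_end([0, 0], [0, 0], train=False)              # 2 of 2 correct
m.on_epoch_end()
print(m.name, m.metric)
```

Weighting by batch size gives 5/6 (about 0.833), the accuracy over all six samples; a plain average of the two batch accuracies would give 0.875.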
],
"text/plain": [
"