{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional training functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with three simple callbacks:\n",
"\n",
"- [`ShowGraph`](/train.html#ShowGraph)\n",
"- [`GradientClipping`](/train.html#GradientClipping)\n",
"- [`BnFreeze`](/train.html#BnFreeze)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [],
"source": [
"from fastai.gen_doc.nbdoc import *\n",
"from fastai.train import *\n",
"from fastai.vision import *\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [`Learner`](/basic_train.html#Learner) extension methods"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"
\n",
"\n",
"> fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`1`***)\n",
"\n",
"Fit a model following the 1cycle policy. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(fit_one_cycle)"
]
},
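{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a minimal sketch of a typical call (assuming `data` is a `DataBunch` such as the MNIST sample built below):\n",
"\n",
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy)\n",
"learn.fit_one_cycle(1, max_lr=3e-3)  # one epoch with the 1cycle schedule\n",
"```"
]
},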
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"one_cycle_scheduler[source]
\n",
"\n",
"> one_cycle_scheduler(**`lr_max`**:`float`, **\\*\\*`kwargs`**:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n",
"\n",
"Instantiate a [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) with `lr_max`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(one_cycle_scheduler)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details."
]
},
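{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of direct use (this assumes the returned object is a callback factory that can be passed via `callback_fns`, like other [`LearnerCallback`](/basic_train.html#LearnerCallback)s):\n",
"\n",
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy,\n",
"                   callback_fns=one_cycle_scheduler(lr_max=1e-2))\n",
"learn.fit(1)\n",
"```"
]
},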
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n",
"\n",
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(lr_find)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details."
]
},
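{
"cell_type": "markdown",
"metadata": {},
"source": [
"A typical workflow (a sketch; assumes `data` is a `DataBunch` as below, and uses `learn.recorder.plot` to inspect the result):\n",
"\n",
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy)\n",
"learn.lr_find()        # explore learning rates from start_lr to end_lr\n",
"learn.recorder.plot()  # plot loss vs. learning rate to pick a good max_lr\n",
"```"
]
},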
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`False`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Put `learn` in FP16 precision mode. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(to_fp16)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n",
"\n",
"Put `learn` back to FP32 precision mode. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(to_fp32)"
]
},
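{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of mixed-precision training (assumes a GPU with FP16 support):\n",
"\n",
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy).to_fp16()\n",
"learn.fit_one_cycle(1)\n",
"learn.to_fp32()  # restore full precision, e.g. before exporting the model\n",
"```"
]
},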
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
"Add mixup https://arxiv.org/abs/1710.09412 to `learn`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(mixup)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details."
]
},
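{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of training with mixup (assumes `data` is the MNIST sample built below):\n",
"\n",
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy).mixup(alpha=0.4)\n",
"learn.fit(1)\n",
"```"
]
},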
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"class ClassificationInterpretation[source]\n",
"\n",
"> ClassificationInterpretation(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`probs`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***)\n",
"\n",
"Interpretation methods for classification models. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ClassificationInterpretation)"
]
},
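{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch (assumes a trained `learn`; `from_learner` builds the object from a learner's validation predictions):\n",
"\n",
"```python\n",
"interp = ClassificationInterpretation.from_learner(learn)\n",
"interp.plot_confusion_matrix()   # confusion matrix on the validation set\n",
"interp.most_confused(min_val=2)  # label pairs confused at least twice\n",
"```"
]
},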
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Additional callbacks"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll show examples below using our MNIST sample. As usual the `on_something` methods are directly called by the fastai library, no need to call them yourself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.MNIST_SAMPLE)\n",
"data = ImageDataBunch.from_folder(path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> ShowGraph(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Update a graph of learner stats and metrics after each epoch. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ShowGraph, title_level=3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)\n",
"learn.fit(3)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> on_epoch_end(**`n_epochs`**:`int`, **`last_metrics`**:`MetricsList`, **\\*\\*`kwargs`**) → `bool`\n",
"\n",
"If we have `last_metrics` plot them in our pbar graph "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(ShowGraph.on_epoch_end)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"class GradientClipping[source]
\n",
"\n",
"> GradientClipping(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Gradient clipping during training. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(GradientClipping)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:11 \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.131133 | \n",
" 0.078190 | \n",
" 0.973013 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = create_cnn(data, models.resnet18, metrics=accuracy,\n",
" callback_fns=partial(GradientClipping, clip=0.1))\n",
"learn.fit(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> on_backward_end(**\\*\\*`kwargs`**)\n",
"\n",
"Clip the gradient before the optimizer step. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(GradientClipping.on_backward_end)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> BnFreeze(**`learn`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Freeze moving average statistics in all non-trainable batchnorm layers. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(BnFreeze)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls `eval` on these layers)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:07 \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" accuracy | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.132564 | \n",
" 0.078910 | \n",
" 0.972031 | \n",
"
\n",
"
\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)\n",
"learn.fit(1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> on_epoch_begin(**\\*\\*`kwargs`**:`Any`)\n",
"\n",
"Put bn layers in eval mode just after `model.train()`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(BnFreeze.on_epoch_begin)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Undocumented Methods - Methods moved below this line will intentionally be hidden"
]
}
],
"metadata": {
"jekyll": {
"keywords": "fastai",
"summary": "Extensions to Learner that easily implement Callback",
"title": "train"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}