{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Additional training functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`train`](/train.html#train) provides a number of extension methods that are added to [`Learner`](/basic_train.html#Learner) (see below for a list and details), along with three simple callbacks:\n", "\n", "- [`ShowGraph`](/train.html#ShowGraph)\n", "- [`GradientClipping`](/train.html#GradientClipping)\n", "- [`BnFreeze`](/train.html#BnFreeze)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.train import *\n", "from fastai.vision import *\n", "from fastai import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Learner`](/basic_train.html#Learner) extension methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These methods are automatically added to all [`Learner`](/basic_train.html#Learner) objects created after importing this module. They provide convenient access to a number of callbacks, without requiring them to be manually created." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit_one_cycle[source]

\n", "\n", "> fit_one_cycle(`learn`:[`Learner`](/basic_train.html#Learner), `cyc_len`:`int`, `max_lr`:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=`slice(None, 0.003, None)`, `moms`:`Point`=`(0.95, 0.85)`, `div_factor`:`float`=`25.0`, `pct_start`:`float`=`0.3`, `wd`:`float`=`None`, `callbacks`:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=`None`, `kwargs`)\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit_one_cycle)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fit a model with 1cycle training. See [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) for details." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_find[source]

\n", "\n", "> lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`LRFinder`](/callbacks.lr_finder.html#LRFinder) for details." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp16[source]

\n", "\n", "> to_fp16(`learn`:[`Learner`](/basic_train.html#Learner), `loss_scale`:`float`=`512.0`, `flat_master`:`bool`=`False`) → [`Learner`](/basic_train.html#Learner)\n", "\n", "Transform `learn` in FP16 precision. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) for details." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

mixup[source]

\n", "\n", "> mixup(`learn`:[`Learner`](/basic_train.html#Learner), `alpha`:`float`=`0.4`, `stack_x`:`bool`=`False`, `stack_y`:`bool`=`True`) → [`Learner`](/basic_train.html#Learner)\n", "\n", "Add mixup https://arxiv.org/abs/1710.09412 to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(mixup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback) for more details." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A last extension method comes from the module tta." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

TTA[source]

\n", "\n", "> TTA(`learn`:[`Learner`](/basic_train.html#Learner), `beta`:`float`=`0.4`, `scale`:`float`=`1.35`, `ds_type`:[`DatasetType`](/basic_data.html#DatasetType)=``, `with_loss`:`bool`=`False`) → `Tensors`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.TTA, full_name='TTA')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Applies Test Time Augmentation to `learn` on the dataset `ds_type`. We take the average of our regular predictions (with a weight `beta`) with the average of predictions obtained thourh augmented versions of the training set (with a weight `1-beta`). The transforms decided for the training set are applied with a few changes `scale` controls the scale for zoom (which isn't random), the cropping isn't random but we make sure to get the four corners of the image. Flipping isn't random but applied once on each of those corner images (so that makes 8 augmented versions total)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll show examples below using our MNIST sample." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ShowGraph[source]

\n", "\n", "> ShowGraph(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "Update a graph of learner stats and metrics after each epoch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```python\n", "learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=ShowGraph)\n", "learn.fit(3)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![Training graph](imgs/train_graph.gif)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(`n_epochs`:`int`, `last_metrics`:`MetricsList`, `kwargs`) → `bool`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ShowGraph.on_epoch_end, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we have `last_metrics`, plot them in `self.pbar`. Set the size of the graph with `n_epochs`." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class GradientClipping[source]

\n", "\n", "> GradientClipping(`learn`:[`Learner`](/basic_train.html#Learner), `clip`:`float`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "To do gradient clipping during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clips gradient at a maximum absolute value of `clip` during training. For instance:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:11\n", "epoch train loss valid loss accuracy\n", "0 0.086958 0.038721 0.989696 (00:11)\n", "\n" ] } ], "source": [ "learn = create_cnn(data, models.resnet18, metrics=accuracy,\n", " callback_fns=partial(GradientClipping, clip=0.1))\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_backward_end[source]

\n", "\n", "> on_backward_end(`kwargs`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GradientClipping.on_backward_end, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Clip the gradients after they are computed but before the optimizer step." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class BnFreeze[source]

\n", "\n", "> BnFreeze(`learn`:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "Freeze moving average statistics in all non-trainable batchnorm layers. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For batchnorm layers where `requires_grad==False`, you generally don't want to update their moving average statistics, in order to avoid the model's statistics getting out of sync with its pre-trained weights. You can add this callback to automate this freezing of statistics (internally, it calls `eval` on these layers)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=1), HTML(value='0.00% [0/1 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:07\n", "epoch train loss valid loss accuracy\n", "0 0.079278 0.041832 0.985280 (00:07)\n", "\n" ] } ], "source": [ "learn = create_cnn(data, models.resnet18, metrics=accuracy, callback_fns=BnFreeze)\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_begin[source]

\n", "\n", "> on_epoch_begin(`kwargs`:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(BnFreeze.on_epoch_begin, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set back the batchnorm layers on `eval` mode after the model has been set to [`train`](/train.html#train)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

one_cycle_scheduler[source]

\n", "\n", "> one_cycle_scheduler(`lr_max`:`float`, `kwargs`:`Any`) → [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(one_cycle_scheduler)" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Extensions to Learner that easily implement Callback", "title": "train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }