{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.basics import *\n", "from fastai2.callback.progress import *\n", "from fastai2.callback.fp16 import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *\n", "from fastai2.test_utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tracking callbacks\n", "\n", "> Callbacks that make decisions depending how a monitored metric/loss behaves" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ShortEpochCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@log_args\n", "class ShortEpochCallback(Callback):\n", " \"Fit just `pct` of an epoch, then stop\"\n", " def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid\n", " def after_batch(self):\n", " if self.iter/self.n_iter < self.pct: return\n", " if self.training: raise CancelTrainException\n", " if self.short_valid: raise CancelValidException" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
epoch | train_loss | valid_loss | time
0 | 00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | time
0 | 12.395771 | 00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback(short_valid=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GradientAccumulation -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args\n", "class GradientAccumulation(Callback):\n", " \"Accumulate gradients before updating weights\"\n", " toward_end,run_before=True,MixedPrecision\n", "\n", " def __init__(self, n_acc=32): store_attr(self, 'n_acc')\n", " def before_fit(self): self.count=0\n", "\n", " def after_backward(self):\n", " self.count += find_bs(self.learn.yb)\n", " if self.count < self.n_acc: raise CancelBatchException() #skip weight update\n", " else: self.count=0\n", "\n", " _docs = dict(before_fit=\"Set counter to 0\",\n", " after_backward=\"Skip weight update if we have not seen enough items\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
epoch | train_loss | valid_loss | time
0 | 10.566907 | 3.633753 | 00:00
1 | 5.525984 | 0.397483 | 00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | time
0 | 0.476599 | 0.397483 | 00:00
1 | 0.478213 | 0.397483 | 00:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "\n", "learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))\n", "# ensure train_loss decreased\n", "assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]\n", "\n", "learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))\n", "# ensure valid_loss didn't change (same weights)\n", "assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BnFreeze" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n", "\n", "def set_bn_eval(m:nn.Module, use_eval=True)->None:\n", " \"Set bn layers in eval mode for all recursive children of `m`.\"\n", " for l in m.children():\n", " if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:\n", " if use_eval: l.eval()\n", " else: l.train()\n", " set_bn_eval(l)\n", "\n", "class BnFreeze(Callback):\n", " \"Freeze moving average statistics in all non-trainable batchnorm layers.\"\n", " def before_epoch(self):\n", " set_bn_eval(self.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.
\n", "\n", "`Learner.freeze()` doesn't suffice here as the `BatchNorm` layers are trainable by default, and running mean and sdev of batches are tracked. For feature extractors to fully match, you need to set `train_bn=False` and these stats need to be frozen as well, which is precisely the function of `BnFreeze`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "from fastai2.vision.all import *\n", "\n", "path = untar_data(URLs.MNIST_TINY)\n", "dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and grab the first `BatchNorm` layer, and store its running mean: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "m = learn1.model[0][1].running_mean.clone()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that now that running mean has changed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | time
0 | 1.058304 | 0.713414 | 00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#slow\n", "learn1.fit(1, lr=0.02)\n", "test_ne(learn1.model[0][1].running_mean, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch | train_loss | valid_loss | time
0 | 0.540841 | 0.432421 | 00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#slow\n", "learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)\n", "m = learn1.model[0][1].running_mean.clone()\n", "learn1.fit(1, lr=0.02)\n", "test_eq(learn1.model[0][1].running_mean, m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }