{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.basics import *\n", "from fastai.callback.progress import *\n", "from fastai.callback.fp16 import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *\n", "from fastai.test_utils import *\n", "from fastai.vision.all import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training callbacks\n", "\n", "> Various callbacks to customize training behavior" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ShortEpochCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ShortEpochCallback(Callback):\n", " \"Fit just `pct` of an epoch, then stop\"\n", " def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid\n", " def after_batch(self):\n", " if self.iter/self.n_iter < self.pct: return\n", " if self.training: raise CancelTrainException\n", " if self.short_valid: raise CancelValidException" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
014.86797500:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback(short_valid=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GradientAccumulation -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class GradientAccumulation(Callback):\n", " \"Accumulate gradients before updating weights\"\n", " order,run_valid = MixedPrecision.order-4,False\n", " def __init__(self, n_acc=32): store_attr()\n", " def before_fit(self): self.count=0\n", " def after_loss(self): self.learn.loss_grad /= self.n_acc/find_bs(self.learn.yb)\n", " def before_step(self):\n", " \"Skip weight update if we have not seen enough items\"\n", " self.learn.loss_grad *= self.n_acc/find_bs(self.learn.yb) # log correct loss\n", " self.count += find_bs(self.learn.yb)\n", " if self.count\n", " \n", " \n", " epoch\n", " train_loss\n", " valid_loss\n", " time\n", " \n", " \n", " \n", " \n", " 0\n", " 0.834062\n", " 0.295950\n", " 00:00\n", " \n", " \n", "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.8245500.29595000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "class GetGrads(Callback):\n", " run_valid,order = False,GradientAccumulation.order+1\n", " def before_step(self): self.grads=to_detach(L([p.grad.clone() for p in self.model.parameters()]))\n", "\n", "def _test_acc(bs,n,cbs=None,cuda=False):\n", " with no_random(99): \n", " db=synth_dbunch(bs=bs,n_train=n,n_valid=n,cuda=cuda)\n", " learn = synth_learner(data=db,cbs=[GetGrads]+L(cbs))\n", " learn.fit(1, lr=0.01)\n", " train,valid = learn.recorder.values[-1]\n", " return train,valid,learn.get_grads.grads\n", "\n", "acc_cb = GradientAccumulation(n_acc=8)\n", "\n", "train1,valid1,grads1 = _test_acc(8,1)\n", "train2,valid2,grads2 = _test_acc(1,8,acc_cb)\n", "\n", "#grads should be same, valid loss same, train loss different\n", "test_close(grads2,grads1)\n", "test_close(valid2, valid1)\n", "test_ne(train2, train1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.8340620.29595000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.8245500.29595000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#cuda\n", "fp16_cb = MixedPrecision(init_scale=1024)\n", "train1,valid1,grads1 = _test_acc(8,1, fp16_cb, cuda=True)\n", "train2,valid2,grads2 = _test_acc(1,8, [acc_cb,fp16_cb], cuda=True)\n", "test_close(grads2,grads1, eps=0.01)\n", "test_close(valid2, valid1)\n", "test_ne(train2, train1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the number of steps per accumulation is higher than the number of batches, the parameters (and therefore validation loss) don't change at all:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
010.94116810.28042800:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, lr=0.01, cbs=GradientAccumulation(n_acc=1000))\n", "# ensure valid_loss didn't change\n", "assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GradientClip -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class GradientClip(Callback):\n", " \"Clip norm of gradients\"\n", " order=MixedPrecision.order+1\n", " def __init__(self,max_norm:float=1., norm_type:float=2.0): store_attr()\n", " def before_step(self): nn.utils.clip_grad_norm_(self.parameters(), self.max_norm, self.norm_type)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normally, if we use a learning rate that is too high, our training will diverge. This happens even with mixed precision training: dynamic loss scaling avoids infinities, but the loss still diverges:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fp16 = MixedPrecision()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
038.21416925.26901200:00
1377.146088890.01178000:00
2839.3919079965.71289100:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "set_seed(99)\n", "learn = synth_learner(lr=1.1, cuda=True)\n", "learn.fit(3, cbs=fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By adding the `GradientClip` callback, the gradient `norm_type` (default:2) norm is clipped to at most `max_norm` (default:1) using `nn.utils.clip_grad_norm_`, which can avoid loss divergence:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
02.0394272.37218300:00
11.4024240.30072400:00
21.0135510.33266800:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "set_seed(99)\n", "learn = synth_learner(lr=1.1, cuda=True)\n", "learn.fit(3, cbs=[GradientClip,fp16])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BnFreeze" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n", "\n", "def set_bn_eval(m:nn.Module, use_eval=True)->None:\n", "    \"Recursively set bn layers of `m` with frozen parameters to eval mode (train mode if `use_eval=False`).\"\n", "    for l in m.children():\n", "        if isinstance(l, bn_types) and not next(l.parameters()).requires_grad:\n", "            if use_eval: l.eval()\n", "            else: l.train()\n", "        set_bn_eval(l, use_eval) # pass `use_eval` down so nested bn layers get the same mode\n", "\n", "class BnFreeze(Callback):\n", "    \"Freeze moving average statistics in all non-trainable batchnorm layers.\"\n", "    run_after=TrainEvalCallback\n", "    def before_train(self):\n", "        set_bn_eval(self.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.
\n", "\n", "`Learner.freeze()` doesn't suffice here, as the `BatchNorm` layers are trainable by default and the running mean and std of batches are still tracked. For feature extractors to fully match, you need to set `train_bn=False` and these stats need to be frozen as well, which is precisely the function of `BnFreeze`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "path = untar_data(URLs.MNIST_TINY)\n", "dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "...and grab the first `BatchNorm` layer, and store its running mean: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#slow\n", "m = learn1.model[0][1].running_mean.clone()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can see that the running mean has now changed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
01.1527010.46889200:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#slow\n", "learn1.fit(1, lr=0.02)\n", "test_ne(to_detach(learn1.model[0][1].running_mean), m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
00.4886340.27768300:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#slow\n", "learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)\n", "m = learn1.model[0][1].running_mean.detach().clone()\n", "learn1.fit(1, lr=0.02)\n", "test_eq(to_detach(learn1.model[0][1].running_mean), m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 01a_losses.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 10b_tutorial.albumentations.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 18b_callback.preds.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 
24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted dev-setup.ipynb.\n", "Converted index.ipynb.\n", "Converted quick_start.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }