{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp callback.training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai2.basics import *\n",
"from fastai2.callback.progress import *\n",
"from fastai2.callback.fp16 import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *\n",
"from fastai2.test_utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training callbacks\n",
"\n",
"> Various callbacks to customize training behavior"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ShortEpochCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"@log_args\n",
"class ShortEpochCallback(Callback):\n",
"    \"Fit just `pct` of an epoch, then stop\"\n",
"    def __init__(self, pct=0.01, short_valid=True):\n",
"        # `pct`: threshold compared against `iter/n_iter`; `short_valid`: whether validation is cut short too\n",
"        self.pct = pct\n",
"        self.short_valid = short_valid\n",
"\n",
"    def after_batch(self):\n",
"        \"Cancel the current phase once `iter/n_iter` is no longer below `pct`\"\n",
"        below_quota = self.iter/self.n_iter < self.pct\n",
"        if below_quota: return\n",
"        # Quota reached: training is always cancelled, validation only when `short_valid` is set\n",
"        if self.training: raise CancelTrainException\n",
"        if self.short_valid: raise CancelValidException"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 12.395771 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback(short_valid=False))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GradientAccumulation -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"@log_args\n",
"class GradientAccumulation(Callback):\n",
"    \"Accumulate gradients before updating weights\"\n",
"    toward_end = True\n",
"    run_before = MixedPrecision\n",
"\n",
"    def __init__(self, n_acc=32):\n",
"        # `n_acc`: item count that must be reached before a weight update is let through\n",
"        store_attr(self, 'n_acc')\n",
"\n",
"    def before_fit(self):\n",
"        self.count = 0\n",
"\n",
"    def after_backward(self):\n",
"        # Add what `find_bs` reports for the current targets to the running item count\n",
"        self.count += find_bs(self.learn.yb)\n",
"        seen_too_few = self.count < self.n_acc\n",
"        if seen_too_few: raise CancelBatchException() #skip weight update\n",
"        # Enough items seen: let the batch finish normally and start counting again\n",
"        self.count = 0\n",
"\n",
"    _docs = {\"before_fit\": \"Set counter to 0\",\n",
"             \"after_backward\": \"Skip weight update if we have not seen enough items\"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 10.566907 | \n",
" 3.633753 | \n",
" 00:00 | \n",
"
\n",
" \n",
" | 1 | \n",
" 5.525984 | \n",
" 0.397483 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.476599 | \n",
" 0.397483 | \n",
" 00:00 | \n",
"
\n",
" \n",
" | 1 | \n",
" 0.478213 | \n",
" 0.397483 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"\n",
"learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))\n",
"# ensure train_loss decreased\n",
"assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]\n",
"\n",
"learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))\n",
"# ensure valid_loss didn't change (same weights)\n",
"assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BnFreeze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"# Layer classes treated as batchnorm by `set_bn_eval`\n",
"bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n",
"\n",
"def set_bn_eval(m:nn.Module, use_eval=True)->None:\n",
"    \"Set non-trainable bn layers in eval mode (train mode if `use_eval=False`) for all recursive children of `m`.\"\n",
"    for l in m.children():\n",
"        if isinstance(l, bn_types):\n",
"            # bn layers built with `affine=False` have no parameters: the `None` default avoids\n",
"            # a `StopIteration` and such layers are left in whatever mode they are in\n",
"            p = next(l.parameters(), None)\n",
"            if p is not None and not p.requires_grad:\n",
"                if use_eval: l.eval()\n",
"                else:        l.train()\n",
"        # Pass `use_eval` down, otherwise nested layers would always be switched to eval mode\n",
"        set_bn_eval(l, use_eval)\n",
"\n",
"class BnFreeze(Callback):\n",
"    \"Freeze moving average statistics in all non-trainable batchnorm layers.\"\n",
"    # Must run after `TrainEvalCallback`: its `before_train` puts the whole model (bn layers included)\n",
"    # in train mode, which would undo an eval mode set earlier (e.g. in `before_epoch`)\n",
"    run_after = TrainEvalCallback\n",
"    def before_train(self):\n",
"        \"Put the non-trainable batchnorm layers back in eval mode at the start of each training phase\"\n",
"        set_bn_eval(self.model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.\n",
"\n",
"`Learner.freeze()` doesn't suffice here as the `BatchNorm` layers are trainable by default, and the running mean and standard deviation of batches are tracked. For feature extractors to fully match, you need to set `train_bn=False` and these stats need to be frozen as well, which is precisely the function of `BnFreeze`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"from fastai2.vision.all import *\n",
"\n",
"path = untar_data(URLs.MNIST_TINY)\n",
"dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and grab the first `BatchNorm` layer, and store its running mean: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"m = learn1.model[0][1].running_mean.clone()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that the running mean has now changed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 1.058304 | \n",
" 0.713414 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#slow\n",
"learn1.fit(1, lr=0.02)\n",
"test_ne(learn1.model[0][1].running_mean, m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" | epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" | 0 | \n",
" 0.540841 | \n",
" 0.432421 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#slow\n",
"learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)\n",
"m = learn1.model[0][1].running_mean.clone()\n",
"learn1.fit(1, lr=0.02)\n",
"test_eq(learn1.model[0][1].running_mean, m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_torch_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 02_data.load.ipynb.\n",
"Converted 03_data.core.ipynb.\n",
"Converted 04_data.external.ipynb.\n",
"Converted 05_data.transforms.ipynb.\n",
"Converted 06_data.block.ipynb.\n",
"Converted 07_vision.core.ipynb.\n",
"Converted 08_vision.data.ipynb.\n",
"Converted 09_vision.augment.ipynb.\n",
"Converted 09b_vision.utils.ipynb.\n",
"Converted 09c_vision.widgets.ipynb.\n",
"Converted 10_tutorial.pets.ipynb.\n",
"Converted 11_vision.models.xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_callback.core.ipynb.\n",
"Converted 13a_learner.ipynb.\n",
"Converted 13b_metrics.ipynb.\n",
"Converted 14_callback.schedule.ipynb.\n",
"Converted 14a_callback.data.ipynb.\n",
"Converted 15_callback.hook.ipynb.\n",
"Converted 15a_vision.models.unet.ipynb.\n",
"Converted 16_callback.progress.ipynb.\n",
"Converted 17_callback.tracker.ipynb.\n",
"Converted 18_callback.fp16.ipynb.\n",
"Converted 18a_callback.training.ipynb.\n",
"Converted 19_callback.mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision.learner.ipynb.\n",
"Converted 22_tutorial.imagenette.ipynb.\n",
"Converted 23_tutorial.vision.ipynb.\n",
"Converted 24_tutorial.siamese.ipynb.\n",
"Converted 24_vision.gan.ipynb.\n",
"Converted 30_text.core.ipynb.\n",
"Converted 31_text.data.ipynb.\n",
"Converted 32_text.models.awdlstm.ipynb.\n",
"Converted 33_text.models.core.ipynb.\n",
"Converted 34_callback.rnn.ipynb.\n",
"Converted 35_tutorial.wikitext.ipynb.\n",
"Converted 36_text.models.qrnn.ipynb.\n",
"Converted 37_text.learner.ipynb.\n",
"Converted 38_tutorial.text.ipynb.\n",
"Converted 39_tutorial.transformers.ipynb.\n",
"Converted 40_tabular.core.ipynb.\n",
"Converted 41_tabular.data.ipynb.\n",
"Converted 42_tabular.model.ipynb.\n",
"Converted 43_tabular.learner.ipynb.\n",
"Converted 44_tutorial.tabular.ipynb.\n",
"Converted 45_collab.ipynb.\n",
"Converted 46_tutorial.collab.ipynb.\n",
"Converted 50_tutorial.datablock.ipynb.\n",
"Converted 60_medical.imaging.ipynb.\n",
"Converted 61_tutorial.medical_imaging.ipynb.\n",
"Converted 65_medical.text.ipynb.\n",
"Converted 70_callback.wandb.ipynb.\n",
"Converted 71_callback.tensorboard.ipynb.\n",
"Converted 72_callback.neptune.ipynb.\n",
"Converted 73_callback.captum.ipynb.\n",
"Converted 74_callback.cutmix.ipynb.\n",
"Converted 97_test_utils.ipynb.\n",
"Converted 99_pytorch_doc.ipynb.\n",
"Converted index.ipynb.\n",
"Converted tutorial.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}