{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"#skip\n",
"! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp callback.training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai.basics import *\n",
"from fastai.callback.progress import *\n",
"from fastai.callback.fp16 import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *\n",
"from fastai.test_utils import *\n",
"from fastai.vision.all import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training callbacks\n",
"\n",
"> Various callbacks to customize training behavior"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ShortEpochCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class ShortEpochCallback(Callback):\n",
"    \"Stop fitting an epoch early, once a fraction `pct` of it has been processed\"\n",
"    def __init__(self, pct=0.01, short_valid=True):\n",
"        self.pct = pct\n",
"        self.short_valid = short_valid\n",
"\n",
"    def after_batch(self):\n",
"        \"Cancel training (and validation too, if `short_valid`) once `iter/n_iter` reaches `pct`\"\n",
"        progress = self.iter/self.n_iter\n",
"        if progress < self.pct: return\n",
"        # Past the threshold: abort whichever phase is currently running\n",
"        if self.training: raise CancelTrainException\n",
"        if self.short_valid: raise CancelValidException"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 14.867975 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback(short_valid=False))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GradientAccumulation -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class GradientAccumulation(Callback):\n",
" \"Accumulate gradients before updating weights\"\n",
" order,run_valid = MixedPrecision.order-4,False\n",
" def __init__(self, n_acc=32): store_attr()\n",
" def before_fit(self): self.count=0\n",
" def after_loss(self): self.learn.loss_grad /= self.n_acc/find_bs(self.learn.yb)\n",
" def before_step(self):\n",
" \"Skip weight update if we have not seen enough items\"\n",
" self.learn.loss_grad *= self.n_acc/find_bs(self.learn.yb) # log correct loss\n",
" self.count += find_bs(self.learn.yb)\n",
" if self.count\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.834062 | \n",
" 0.295950 | \n",
" 00:00 | \n",
"
\n",
" \n",
""
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.824550 | \n",
" 0.295950 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"class GetGrads(Callback):\n",
"    \"Store detached clones of all parameter gradients in `self.grads` during `before_step`\"\n",
"    order = GradientAccumulation.order+1\n",
"    run_valid = False\n",
"    def before_step(self):\n",
"        cloned_grads = [p.grad.clone() for p in self.model.parameters()]\n",
"        self.grads = to_detach(L(cloned_grads))\n",
"\n",
"def _test_acc(bs,n,cbs=None,cuda=False):\n",
" with no_random(99): \n",
" db=synth_dbunch(bs=bs,n_train=n,n_valid=n,cuda=cuda)\n",
" learn = synth_learner(data=db,cbs=[GetGrads]+L(cbs))\n",
" learn.fit(1, lr=0.01)\n",
" train,valid = learn.recorder.values[-1]\n",
" return train,valid,learn.get_grads.grads\n",
"\n",
"acc_cb = GradientAccumulation(n_acc=8)\n",
"\n",
"train1,valid1,grads1 = _test_acc(8,1)\n",
"train2,valid2,grads2 = _test_acc(1,8,acc_cb)\n",
"\n",
"#grads should be same, valid loss same, train loss different\n",
"test_close(grads2,grads1)\n",
"test_close(valid2, valid1)\n",
"test_ne(train2, train1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.834062 | \n",
" 0.295950 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.824550 | \n",
" 0.295950 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"#cuda\n",
"fp16_cb = MixedPrecision(init_scale=1024)\n",
"train1,valid1,grads1 = _test_acc(8,1, fp16_cb, cuda=True)\n",
"train2,valid2,grads2 = _test_acc(1,8, [acc_cb,fp16_cb], cuda=True)\n",
"test_close(grads2,grads1, eps=0.01)\n",
"test_close(valid2, valid1)\n",
"test_ne(train2, train1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the number of items to accumulate (`n_acc`) is higher than the number of items seen during training, no weight update ever happens, so the parameters (and therefore validation loss) don't change at all:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 10.941168 | \n",
" 10.280428 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"# Snapshot the weights so we can verify that no optimizer step was taken\n",
"init_params = [p.detach().cpu().clone() for p in learn.model.parameters()]\n",
"learn.fit(1, lr=0.01, cbs=GradientAccumulation(n_acc=1000))\n",
"# ensure valid_loss didn't change\n",
"assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]\n",
"# with only one epoch recorded, `values[-1]` and `values[0]` are the same row, so also compare the weights directly\n",
"for p_before,p_after in zip(init_params, learn.model.parameters()): test_eq(to_detach(p_after), p_before)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## GradientClip -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class GradientClip(Callback):\n",
"    \"Clip the `norm_type` norm of the gradients to at most `max_norm` in `before_step`\"\n",
"    order = MixedPrecision.order + 1\n",
"    def __init__(self, max_norm:float=1., norm_type:float=2.0):\n",
"        store_attr()\n",
"\n",
"    def before_step(self):\n",
"        nn.utils.clip_grad_norm_(self.parameters(), max_norm=self.max_norm, norm_type=self.norm_type)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Normally, if we use a learning rate that is too high, our training will diverge. This happens even if we use mixed precision training, which avoids infinities by using dynamic loss scaling, but still diverges:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fp16 = MixedPrecision()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 38.214169 | \n",
" 25.269012 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 377.146088 | \n",
" 890.011780 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 839.391907 | \n",
" 9965.712891 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"set_seed(99)\n",
"learn = synth_learner(lr=1.1, cuda=True)\n",
"learn.fit(3, cbs=fp16)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By adding the `GradientClip` callback, the gradient `norm_type` (default:2) norm is clipped to at most `max_norm` (default:1) using `nn.utils.clip_grad_norm_`, which can avoid loss divergence:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2.039427 | \n",
" 2.372183 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 1.402424 | \n",
" 0.300724 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 1.013551 | \n",
" 0.332668 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"set_seed(99)\n",
"learn = synth_learner(lr=1.1, cuda=True)\n",
"learn.fit(3, cbs=[GradientClip,fp16])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BnFreeze"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)\n",
"\n",
"def set_bn_eval(m:nn.Module, use_eval=True)->None:\n",
"    \"Set non-trainable bn layers in eval mode (or train mode if `use_eval` is False) for all recursive children of `m`.\"\n",
"    for l in m.children():\n",
"        if isinstance(l, bn_types):\n",
"            # A batchnorm layer counts as non-trainable when its first parameter is frozen,\n",
"            # or when it has no parameters at all (`next` would otherwise raise `StopIteration`).\n",
"            first_param = next(l.parameters(), None)\n",
"            if first_param is None or not first_param.requires_grad:\n",
"                if use_eval: l.eval()\n",
"                else:        l.train()\n",
"        # Forward `use_eval` so nested children follow the caller's choice instead of the default.\n",
"        set_bn_eval(l, use_eval)\n",
"\n",
"class BnFreeze(Callback):\n",
"    \"Freeze moving average statistics in all non-trainable batchnorm layers.\"\n",
"    # NOTE(review): presumably ordered after `TrainEvalCallback` so eval mode is applied once the model is already in train mode - confirm\n",
"    run_after=TrainEvalCallback\n",
"    def before_train(self):\n",
"        set_bn_eval(self.model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BnFreeze` is useful when you'd like to train two separate models that have a common feature extractor / body. The only part of the model that's different is the head that you attach for transfer learning.\n",
"\n",
"`Learner.freeze()` doesn't suffice here as the `BatchNorm` layers are trainable by default, and running mean and std of batches are tracked. For feature extractors to fully match, you need to set `train_bn=False` and these stats need to be frozen as well, which is precisely the function of `BnFreeze`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"path = untar_data(URLs.MNIST_TINY)\n",
"dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first demonstrate the mismatch of the running stats when using only `train_bn=False`, by creating a `Learner`...:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"...and grab the first `BatchNorm` layer, and store its running mean: "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#slow\n",
"m = learn1.model[0][1].running_mean.clone()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can see that the running mean has now changed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 1.152701 | \n",
" 0.468892 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#slow\n",
"learn1.fit(1, lr=0.02)\n",
"test_ne(to_detach(learn1.model[0][1].running_mean), m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we use the `BnFreeze` callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 0.488634 | \n",
" 0.277683 | \n",
" 00:02 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#slow\n",
"learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)\n",
"m = learn1.model[0][1].running_mean.detach().clone()\n",
"learn1.fit(1, lr=0.02)\n",
"test_eq(to_detach(learn1.model[0][1].running_mean), m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_torch_core.ipynb.\n",
"Converted 01_layers.ipynb.\n",
"Converted 01a_losses.ipynb.\n",
"Converted 02_data.load.ipynb.\n",
"Converted 03_data.core.ipynb.\n",
"Converted 04_data.external.ipynb.\n",
"Converted 05_data.transforms.ipynb.\n",
"Converted 06_data.block.ipynb.\n",
"Converted 07_vision.core.ipynb.\n",
"Converted 08_vision.data.ipynb.\n",
"Converted 09_vision.augment.ipynb.\n",
"Converted 09b_vision.utils.ipynb.\n",
"Converted 09c_vision.widgets.ipynb.\n",
"Converted 10_tutorial.pets.ipynb.\n",
"Converted 10b_tutorial.albumentations.ipynb.\n",
"Converted 11_vision.models.xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_callback.core.ipynb.\n",
"Converted 13a_learner.ipynb.\n",
"Converted 13b_metrics.ipynb.\n",
"Converted 14_callback.schedule.ipynb.\n",
"Converted 14a_callback.data.ipynb.\n",
"Converted 15_callback.hook.ipynb.\n",
"Converted 15a_vision.models.unet.ipynb.\n",
"Converted 16_callback.progress.ipynb.\n",
"Converted 17_callback.tracker.ipynb.\n",
"Converted 18_callback.fp16.ipynb.\n",
"Converted 18a_callback.training.ipynb.\n",
"Converted 18b_callback.preds.ipynb.\n",
"Converted 19_callback.mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision.learner.ipynb.\n",
"Converted 22_tutorial.imagenette.ipynb.\n",
"Converted 23_tutorial.vision.ipynb.\n",
"Converted 24_tutorial.siamese.ipynb.\n",
"Converted 24_vision.gan.ipynb.\n",
"Converted 30_text.core.ipynb.\n",
"Converted 31_text.data.ipynb.\n",
"Converted 32_text.models.awdlstm.ipynb.\n",
"Converted 33_text.models.core.ipynb.\n",
"Converted 34_callback.rnn.ipynb.\n",
"Converted 35_tutorial.wikitext.ipynb.\n",
"Converted 36_text.models.qrnn.ipynb.\n",
"Converted 37_text.learner.ipynb.\n",
"Converted 38_tutorial.text.ipynb.\n",
"Converted 39_tutorial.transformers.ipynb.\n",
"Converted 40_tabular.core.ipynb.\n",
"Converted 41_tabular.data.ipynb.\n",
"Converted 42_tabular.model.ipynb.\n",
"Converted 43_tabular.learner.ipynb.\n",
"Converted 44_tutorial.tabular.ipynb.\n",
"Converted 45_collab.ipynb.\n",
"Converted 46_tutorial.collab.ipynb.\n",
"Converted 50_tutorial.datablock.ipynb.\n",
"Converted 60_medical.imaging.ipynb.\n",
"Converted 61_tutorial.medical_imaging.ipynb.\n",
"Converted 65_medical.text.ipynb.\n",
"Converted 70_callback.wandb.ipynb.\n",
"Converted 71_callback.tensorboard.ipynb.\n",
"Converted 72_callback.neptune.ipynb.\n",
"Converted 73_callback.captum.ipynb.\n",
"Converted 97_test_utils.ipynb.\n",
"Converted 99_pytorch_doc.ipynb.\n",
"Converted dev-setup.ipynb.\n",
"Converted index.ipynb.\n",
"Converted quick_start.ipynb.\n",
"Converted tutorial.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from nbdev.export import notebook2script\n",
"notebook2script()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"jupytext": {
"split_at_heading": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}