{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#default_exp callback.tracker"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from local.test import *\n",
"from local.basics import *\n",
"from local.callback.progress import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from local.notebook.showdoc import *\n",
"from local.test_utils import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking callbacks\n",
"\n",
 Callbacks that make">
"> Callbacks that make decisions depending on how a monitored metric or loss behaves"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ShortEpochCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class ShortEpochCallback(Callback):\n",
" \"Fit just `pct` of an epoch, then stop\"\n",
" def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid\n",
" def after_batch(self):\n",
" if self.iter/self.n_iter < self.pct: return\n",
" if self.training: raise CancelTrainException\n",
" if self.short_valid: raise CancelValidException"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 4.103489 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(1, cbs=ShortEpochCallback(short_valid=False))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TerminateOnNaNCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class TerminateOnNaNCallback(Callback):\n",
" \"A `Callback` that terminates training if loss is NaN.\"\n",
" run_before=Recorder\n",
"\n",
" def after_batch(self):\n",
" \"Test if `last_loss` is NaN and interrupts training.\"\n",
" if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 2862784155043519244231962471774027776.000000 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner()\n",
"learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert len(learn.recorder.losses) < 10 * len(learn.dbunch.train_dl)\n",
"for l in learn.recorder.losses:\n",
" assert not torch.isinf(l) and not torch.isnan(l) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TrackerCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class TrackerCallback(Callback):\n",
" \"A `Callback` that keeps track of the best value in `monitor`.\"\n",
" run_after=Recorder\n",
"\n",
" def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):\n",
" if comp is None: comp = np.less if 'loss' in monitor else np.greater\n",
" if comp == np.less: min_delta *= -1\n",
" self.monitor,self.comp,self.min_delta = monitor,comp,min_delta\n",
"\n",
" def begin_fit(self):\n",
" \"Prepare the monitored value\"\n",
" self.run = not hasattr(self, \"lr_finder\") and not hasattr(self, \"gather_preds\")\n",
" self.best = float('inf') if self.comp == np.less else -float('inf')\n",
" assert self.monitor in self.recorder.metric_names[1:]\n",
" self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n",
"\n",
" def after_epoch(self):\n",
" \"Compare the last value to the best up to know\"\n",
" val = self.recorder.values[-1][self.idx]\n",
" if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True\n",
" else: self.new_best = False\n",
" \n",
" def after_fit(self): self.run=True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. \n",
"\n",
"`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount."
]
},
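{
"cell_type": "markdown",
"metadata": {},
"source": [
"For instance, here is a minimal sketch of such a subclass (`PrintBestCallback` is a hypothetical example, not part of the library): it only relies on the `best` and `new_best` attributes set by `TrackerCallback.after_epoch`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#A minimal, hypothetical subclass of `TrackerCallback`\n",
"class PrintBestCallback(TrackerCallback):\n",
"    \"Print a message each time the monitored value reaches a new best.\"\n",
"    def after_epoch(self):\n",
"        super().after_epoch() #updates `self.best` and `self.new_best`\n",
"        if self.new_best: print(f'New best {self.monitor}: {self.best}')"
]
},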
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"class FakeRecords(Callback):\n",
" run_after=Recorder\n",
" run_before=TrackerCallback\n",
" \n",
" def __init__(self, monitor, values): self.monitor,self.values = monitor,values\n",
" \n",
" def begin_fit(self): self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n",
" def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]\n",
" \n",
"class TestTracker(Callback):\n",
" run_after=TrackerCallback\n",
" def begin_fit(self): self.bests,self.news = [],[]\n",
" def after_epoch(self): \n",
" self.bests.append(self.tracker.best)\n",
" self.news.append(self.tracker.new_best)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"learn = synth_learner(n_trn=2, cbs=TestTracker())\n",
"cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]\n",
"with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
"test_eq(learn.test_tracker.bests, [0.2, 0.1])\n",
"test_eq(learn.test_tracker.news, [True,True])\n",
"\n",
"#With a min_delta\n",
"cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]\n",
"with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
"test_eq(learn.test_tracker.bests, [0.2, 0.2])\n",
"test_eq(learn.test_tracker.news, [True,False])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"#By default metrics have to be bigger at each epoch.\n",
"def tst_metric(out,targ): return F.mse_loss(out,targ)\n",
"learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n",
"cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]\n",
"with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
"test_eq(learn.test_tracker.bests, [0.2, 0.2])\n",
"test_eq(learn.test_tracker.news, [True,False])\n",
"\n",
"#This can be overwritten by passing `comp=np.less`.\n",
"learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n",
"cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]\n",
"with learn.no_logging(): learn.fit(2, cbs=cbs)\n",
"test_eq(learn.test_tracker.bests, [0.2, 0.1])\n",
"test_eq(learn.test_tracker.news, [True,True])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#hide\n",
"#A tracker callback is not run during an lr_find\n",
"from local.callback.schedule import *\n",
"learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)\n",
"learn.lr_find(num_it=5, show_plot=False)\n",
"assert not hasattr(learn, 'new_best')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## EarlyStoppingCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class EarlyStoppingCallback(TrackerCallback):\n",
" \"A `TrackerCallback` that terminates training when monitored quantity stops improving.\"\n",
" def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):\n",
" super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n",
" self.patience = patience\n",
"\n",
" def begin_fit(self): self.wait = 0; super().begin_fit()\n",
" def after_epoch(self):\n",
" \"Compare the value monitored to its best score and maybe stop training.\"\n",
" super().after_epoch()\n",
" if self.new_best: self.wait = 0\n",
" else:\n",
" self.wait += 1\n",
" if self.wait >= self.patience:\n",
" print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')\n",
" raise CancelFitException()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 8.969569 | \n",
" 7.550271 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 8.964047 | \n",
" 7.550250 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 8.963997 | \n",
" 7.550220 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"No improvement since epoch 0: early stopping\n"
]
}
],
"source": [
"learn = synth_learner(n_trn=2)\n",
"learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"test_eq(len(learn.recorder.values), 3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SaveModelCallback -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class SaveModelCallback(TrackerCallback):\n",
" \"A `TrackerCallback` that saves the model's best during training and loads it at the end.\"\n",
" def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):\n",
" super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n",
" store_attr(self, 'fname,every_epoch,add_save,with_opt')\n",
"\n",
" def _save(self, name):\n",
" self.learn.save(name, with_opt=self.with_opt)\n",
" if self.add_save is not None:\n",
" with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)\n",
"\n",
" def after_epoch(self):\n",
" \"Compare the value monitored to its best score and save if best.\"\n",
" if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')\n",
" else: #every improvement\n",
" super().after_epoch()\n",
" if self.new_best: self._save(f'{self.fname}')\n",
"\n",
" def on_train_end(self, **kwargs):\n",
" \"Load the best model.\"\n",
" if not self.every_epoch: self.learn.load(f'{self.fname}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. Model will be saved in `learn.path/learn.model_dir/name.pth`, maybe `every_epoch` or at each improvement of the monitored quantity. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 12.352231 | \n",
" 9.437100 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 12.247625 | \n",
" 9.188762 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 11.802291 | \n",
" 8.847794 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 11.562449 | \n",
" 8.428464 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')\n",
"learn.fit(n_epoch=2, cbs=SaveModelCallback())\n",
"assert (Path.cwd()/'tmp/models/model.pth').exists()\n",
"learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))\n",
"for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()\n",
"shutil.rmtree(Path.cwd()/'tmp')"
]
},
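{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch (same `synth_learner` setup as above) showing `with_opt=True`, which also stores the optimizer state in the saved checkpoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#Minimal sketch: `with_opt=True` also saves the optimizer state\n",
"learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')\n",
"learn.fit(n_epoch=2, cbs=SaveModelCallback(with_opt=True))\n",
"assert (Path.cwd()/'tmp/models/model.pth').exists()\n",
"shutil.rmtree(Path.cwd()/'tmp')"
]
},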
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ReduceLROnPlateau"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class ReduceLROnPlateau(TrackerCallback):\n",
" \"A `TrackerCallback` that reduces learning rate when a metric has stopped improving.\"\n",
" def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10.):\n",
" super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n",
" self.patience,self.factor = patience,factor\n",
"\n",
" def begin_fit(self): self.wait = 0; super().begin_fit()\n",
" def after_epoch(self):\n",
" \"Compare the value monitored to its best score and reduce LR by `factor` if no improvement.\"\n",
" super().after_epoch()\n",
" if self.new_best: self.wait = 0\n",
" else:\n",
" self.wait += 1\n",
" if self.wait >= self.patience:\n",
" for h in self.opt.hypers: h['lr'] /= self.factor\n",
" self.wait = 0\n",
" print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1][\"lr\"]}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" \n",
" epoch | \n",
" train_loss | \n",
" valid_loss | \n",
" time | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 13.372304 | \n",
" 11.866179 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 1 | \n",
" 13.377043 | \n",
" 11.866154 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 2 | \n",
" 13.395395 | \n",
" 11.866118 | \n",
" 00:00 | \n",
"
\n",
" \n",
" 3 | \n",
" 13.400988 | \n",
" 11.866114 | \n",
" 00:00 | \n",
"
\n",
" \n",
"
"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2: reducing lr to 1e-08\n"
]
}
],
"source": [
"learn = synth_learner(n_trn=2)\n",
"learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"test_eq(learn.opt.hypers[-1]['lr'], 1e-8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Export -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Converted 00_test.ipynb.\n",
"Converted 01_core_foundation.ipynb.\n",
"Converted 01a_core_utils.ipynb.\n",
"Converted 01b_core_dispatch.ipynb.\n",
"Converted 01c_core_transform.ipynb.\n",
"Converted 02_core_script.ipynb.\n",
"Converted 03_torchcore.ipynb.\n",
"Converted 03a_layers.ipynb.\n",
"Converted 04_data_load.ipynb.\n",
"Converted 05_data_core.ipynb.\n",
"Converted 06_data_transforms.ipynb.\n",
"Converted 07_data_block.ipynb.\n",
"Converted 08_vision_core.ipynb.\n",
"Converted 09_vision_augment.ipynb.\n",
"Converted 09a_vision_data.ipynb.\n",
"Converted 09b_vision_utils.ipynb.\n",
"Converted 10_pets_tutorial.ipynb.\n",
"Converted 11_vision_models_xresnet.ipynb.\n",
"Converted 12_optimizer.ipynb.\n",
"Converted 13_learner.ipynb.\n",
"Converted 13a_metrics.ipynb.\n",
"Converted 14_callback_schedule.ipynb.\n",
"Converted 14a_callback_data.ipynb.\n",
"Converted 15_callback_hook.ipynb.\n",
"Converted 15a_vision_models_unet.ipynb.\n",
"Converted 16_callback_progress.ipynb.\n",
"Converted 17_callback_tracker.ipynb.\n",
"Converted 18_callback_fp16.ipynb.\n",
"Converted 19_callback_mixup.ipynb.\n",
"Converted 20_interpret.ipynb.\n",
"Converted 20a_distributed.ipynb.\n",
"Converted 21_vision_learner.ipynb.\n",
"Converted 22_tutorial_imagenette.ipynb.\n",
"Converted 23_tutorial_transfer_learning.ipynb.\n",
"Converted 30_text_core.ipynb.\n",
"Converted 31_text_data.ipynb.\n",
"Converted 32_text_models_awdlstm.ipynb.\n",
"Converted 33_text_models_core.ipynb.\n",
"Converted 34_callback_rnn.ipynb.\n",
"Converted 35_tutorial_wikitext.ipynb.\n",
"Converted 36_text_models_qrnn.ipynb.\n",
"Converted 37_text_learner.ipynb.\n",
"Converted 38_tutorial_ulmfit.ipynb.\n",
"Converted 40_tabular_core.ipynb.\n",
"Converted 41_tabular_model.ipynb.\n",
"Converted 42_tabular_rapids.ipynb.\n",
"Converted 50_data_block_examples.ipynb.\n",
"Converted 60_medical_imaging.ipynb.\n",
"Converted 65_medical_text.ipynb.\n",
"Converted 70_callback_wandb.ipynb.\n",
"Converted 71_callback_tensorboard.ipynb.\n",
"Converted 90_notebook_core.ipynb.\n",
"Converted 91_notebook_export.ipynb.\n",
"Converted 92_notebook_showdoc.ipynb.\n",
"Converted 93_notebook_export2html.ipynb.\n",
"Converted 94_notebook_test.ipynb.\n",
"Converted 95_index.ipynb.\n",
"Converted 96_data_external.ipynb.\n",
"Converted 97_utils_test.ipynb.\n",
"Converted notebook2jekyll.ipynb.\n",
"Converted xse_resnext.ipynb.\n"
]
}
],
"source": [
"#hide\n",
"from local.notebook.export import notebook2script\n",
"notebook2script(all_fs=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}