{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.tracker" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.basics import *\n", "from local.callback.progress import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *\n", "from local.test_utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tracking callbacks\n", "\n", "> Callbacks that make decisions depending on how a monitored metric/loss behaves" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ShortEpochCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class ShortEpochCallback(Callback):\n", "    \"Fit just `pct` of an epoch, then stop\"\n", "    def __init__(self,pct=0.01,short_valid=True): self.pct,self.short_valid = pct,short_valid\n", "    def after_batch(self):\n", "        \"Once `pct` of the current epoch's iterations have run, cancel training (or validation, if `short_valid`)\"\n", "        if self.iter/self.n_iter < self.pct: return\n", "        if self.training: raise CancelTrainException\n", "        if self.short_valid: raise CancelValidException" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
000:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
04.10348900:00
" ], "text/plain": [ "" ], "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(1, cbs=ShortEpochCallback(short_valid=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TerminateOnNaNCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TerminateOnNaNCallback(Callback):\n", "    \"A `Callback` that terminates training if loss is NaN or infinite.\"\n", "    run_before=Recorder\n", "\n", "    def after_batch(self):\n", "        \"Test if the batch `loss` is NaN or infinite and, if so, interrupt the whole fit.\"\n", "        # runs before `Recorder` (see `run_before`) so that a non-finite loss is not logged once the fit is cancelled\n", "        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
02862784155043519244231962471774027776.00000000:00
" ], "text/plain": [ "" ], "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner()\n", "learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert len(learn.recorder.losses) < 10 * len(learn.dbunch.train_dl)\n", "for l in learn.recorder.losses:\n", "    assert not torch.isinf(l) and not torch.isnan(l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TrackerCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TrackerCallback(Callback):\n", "    \"A `Callback` that keeps track of the best value in `monitor`.\"\n", "    run_after=Recorder\n", "\n", "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):\n", "        if comp is None: comp = np.less if 'loss' in monitor else np.greater\n", "        # flip the sign for `np.less` so that `comp(val - min_delta, best)` only passes when `val` beats `best` by more than `min_delta`\n", "        if comp == np.less: min_delta *= -1\n", "        self.monitor,self.comp,self.min_delta = monitor,comp,min_delta\n", "\n", "    def begin_fit(self):\n", "        \"Prepare the monitored value\"\n", "        # disable tracking when an `lr_finder` or `gather_preds` attribute is reachable (LR finder / prediction runs); `after_fit` re-enables it\n", "        self.run = not hasattr(self, \"lr_finder\") and not hasattr(self, \"gather_preds\")\n", "        self.best = float('inf') if self.comp == np.less else -float('inf')\n", "        # the first entry of `metric_names` is skipped, presumably the `epoch` column of the progress table\n", "        assert self.monitor in self.recorder.metric_names[1:]\n", "        self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n", "\n", "    def after_epoch(self):\n", "        \"Compare the last value to the best up to now\"\n", "        val = self.recorder.values[-1][self.idx]\n", "        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True\n", "        else: self.new_best = False\n", "\n", "    def after_fit(self): self.run=True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes.\n", "\n", "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by more than that amount." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class FakeRecords(Callback):\n", "    \"Test helper: overwrite the recorded `monitor` value of the current epoch with `values[epoch]`\"\n", "    run_after=Recorder\n", "    run_before=TrackerCallback\n", "\n", "    def __init__(self, monitor, values): self.monitor,self.values = monitor,values\n", "\n", "    def begin_fit(self): self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n", "    def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]\n", "\n", "class TestTracker(Callback):\n", "    \"Test helper: log the tracker's `best` and `new_best` after every epoch\"\n", "    run_after=TrackerCallback\n", "    def begin_fit(self): self.bests,self.news = [],[]\n", "    def after_epoch(self):\n", "        self.bests.append(self.tracker.best)\n", "        self.news.append(self.tracker.new_best)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "learn = synth_learner(n_trn=2, cbs=TestTracker())\n", "cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n", "test_eq(learn.test_tracker.news, [True,True])\n", "\n", "#With a min_delta\n", "cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n", "test_eq(learn.test_tracker.news, [True,False])" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#By default metrics have to be bigger at each epoch.\n", "def tst_metric(out,targ): return F.mse_loss(out,targ)\n", "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n", "cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n", "test_eq(learn.test_tracker.news, [True,False])\n", "\n", "#This can be overwritten by passing `comp=np.less`.\n", "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n", "cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n", "test_eq(learn.test_tracker.news, [True,True])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "#hide\n", "#A tracker callback is not run during an lr_find\n", "from local.callback.schedule import *\n", "learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)\n", "learn.lr_find(num_it=5, show_plot=False)\n", "#`new_best` is set on the callback itself (reachable as `learn.tracker`), not on the learner\n", "assert not hasattr(learn.tracker, 'new_best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EarlyStoppingCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class EarlyStoppingCallback(TrackerCallback):\n", "    \"A `TrackerCallback` that terminates training when monitored quantity stops improving.\"\n", "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):\n", "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", "        self.patience = patience\n", "\n", "    def begin_fit(self): self.wait = 0; super().begin_fit()\n", "    def after_epoch(self):\n", "        \"Compare the value monitored to its best score and maybe stop training.\"\n", "        super().after_epoch()\n", "        if self.new_best: self.wait = 0\n", "        else:\n", "            self.wait += 1\n", "            if self.wait >= self.patience:\n", "                # `wait` consecutive epochs (ending with this one) brought no improvement, so the last best was at `epoch-wait`\n", "                print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')\n", "                raise CancelFitException()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by more than that amount. `patience` is the number of epochs you're willing to wait without improvement." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
08.9695697.55027100:00
18.9640477.55025000:00
28.9639977.55022000:00
" ], "text/plain": [ "" ], "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "No improvement since epoch 0: early stopping\n" ] } ], "source": [ "learn = synth_learner(n_trn=2)\n", "learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(len(learn.recorder.values), 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SaveModelCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SaveModelCallback(TrackerCallback):\n", "    \"A `TrackerCallback` that saves the model's best during training and loads it at the end.\"\n", "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):\n", "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", "        store_attr(self, 'fname,every_epoch,add_save,with_opt')\n", "\n", "    def _save(self, name):\n", "        \"Save the model (with the optimizer state if `with_opt`) as `name`, and also into `add_save` when it is given.\"\n", "        self.learn.save(name, with_opt=self.with_opt)\n", "        # `add_save` is presumably a `Path`: it is opened for writing and the model is saved into that file object\n", "        if self.add_save is not None:\n", "            with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)\n", "\n", "    def after_epoch(self):\n", "        \"Compare the value monitored to its best score and save if best.\"\n", "        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')\n", "        else: #every improvement\n", "            super().after_epoch()\n", "            if self.new_best: self._save(f'{self.fname}')\n", "\n", "    def after_fit(self, **kwargs):\n", "        \"Load the best model.\"\n", "        # hook `after_fit` (the end-of-training event the other callbacks here use), not the v1-style `on_train_end`;\n", "        # skip the reload when tracking was disabled for this fit (`self.run` is False, e.g. during an LR finder run)\n", "        if self.run and not self.every_epoch: self.learn.load(f'{self.fname}')\n", "        super().after_fit()\n", "\n", "    on_train_end = after_fit # old method name, kept for backward compatibility" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to improve on the current best (in the direction given by `comp`) by more than that amount. The model is saved in `learn.path/learn.model_dir/fname.pth` at each improvement of the monitored quantity (and reloaded at the end of training), or as `fname_{epoch}.pth` after every epoch if `every_epoch=True`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
012.3522319.43710000:00
112.2476259.18876200:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
011.8022918.84779400:00
111.5624498.42846400:00
" ], "text/plain": [ "" ], "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')\n", "learn.fit(n_epoch=2, cbs=SaveModelCallback())\n", "assert (Path.cwd()/'tmp/models/model.pth').exists()\n", "learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))\n", "for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()\n", "shutil.rmtree(Path.cwd()/'tmp')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ReduceLROnPlateau" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ReduceLROnPlateau(TrackerCallback):\n", "    \"A `TrackerCallback` that reduces learning rate when a metric has stopped improving.\"\n", "    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0):\n", "        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", "        self.patience,self.factor,self.min_lr = patience,factor,min_lr\n", "\n", "    def begin_fit(self): self.wait = 0; super().begin_fit()\n", "    def after_epoch(self):\n", "        \"Compare the value monitored to its best score and reduce LR by `factor` (never below `min_lr`) if no improvement.\"\n", "        super().after_epoch()\n", "        if self.new_best: self.wait = 0\n", "        else:\n", "            self.wait += 1\n", "            if self.wait >= self.patience:\n", "                # divide the lr of each entry in `opt.hypers` by `factor`, floored at `min_lr` (default 0: positive lrs are never clipped)\n", "                for h in self.opt.hypers: h['lr'] = max(h['lr'] / self.factor, self.min_lr)\n", "                self.wait = 0\n", "                print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1][\"lr\"]}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_losstime
013.37230411.86617900:00
113.37704311.86615400:00
213.39539511.86611800:00
313.40098811.86611400:00
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 2: reducing lr to 1e-08\n" ] } ], "source": [ "learn = synth_learner(n_trn=2)\n", "learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(learn.opt.hypers[-1]['lr'], 1e-8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core_foundation.ipynb.\n", "Converted 01a_core_utils.ipynb.\n", "Converted 01b_core_dispatch.ipynb.\n", "Converted 01c_core_transform.ipynb.\n", "Converted 02_core_script.ipynb.\n", "Converted 03_torchcore.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_data_load.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 09b_vision_utils.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 
23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 71_callback_tensorboard.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n", "Converted xse_resnext.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }