{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp callback.tracker" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.basics import *\n", "from fastai2.callback.progress import *\n", "from fastai2.callback.fp16 import MixedPrecision" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *\n", "from fastai2.test_utils import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Tracking callbacks\n", "\n", "> Callbacks that make decisions depending how a monitored metric/loss behaves" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TerminateOnNaNCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TerminateOnNaNCallback(Callback):\n", " \"A `Callback` that terminates training if loss is NaN.\"\n", " run_before=Recorder\n", "\n", " def after_batch(self):\n", " \"Test if `last_loss` is NaN and interrupts training.\"\n", " if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner()\n", "learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "assert len(learn.recorder.losses) < 10 * len(learn.dls.train)\n", "for l in learn.recorder.losses:\n", " assert not torch.isinf(l) and not torch.isnan(l) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TrackerCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class TrackerCallback(Callback):\n", " \"A `Callback` that keeps track of the best value in `monitor`.\"\n", " remove_on_fetch,run_after = True,Recorder\n", "\n", " def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):\n", " if comp is None: comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater\n", " if comp == np.less: min_delta *= -1\n", " self.monitor,self.comp,self.min_delta = monitor,comp,min_delta\n", "\n", " def begin_fit(self):\n", " \"Prepare the monitored value\"\n", " self.run = not hasattr(self, \"lr_finder\") and not hasattr(self, \"gather_preds\")\n", " self.best = float('inf') if self.comp == np.less else -float('inf')\n", " assert self.monitor in self.recorder.metric_names[1:]\n", " self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n", "\n", " def after_epoch(self):\n", " \"Compare the last value to the best up to know\"\n", " val = self.recorder.values[-1][self.idx]\n", " if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True\n", " else: self.new_best = False\n", "\n", " def after_fit(self): self.run=True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. \n", "\n", "`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class FakeRecords(Callback):\n", " run_after=Recorder\n", " run_before=TrackerCallback\n", " \n", " def __init__(self, monitor, values): self.monitor,self.values = monitor,values\n", " \n", " def begin_fit(self): self.idx = list(self.recorder.metric_names[1:]).index(self.monitor)\n", " def after_epoch(self): self.recorder.values[-1][self.idx] = self.values[self.epoch]\n", " \n", "class TestTracker(Callback):\n", " run_after=TrackerCallback\n", " def begin_fit(self): self.bests,self.news = [],[]\n", " def after_epoch(self): \n", " self.bests.append(self.tracker.best)\n", " self.news.append(self.tracker.new_best)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner(n_trn=2, cbs=TestTracker())\n", "cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n", "test_eq(learn.test_tracker.news, [True,True])\n", "\n", "#With a min_delta\n", "cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n", "test_eq(learn.test_tracker.news, [True,False])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#By default metrics have to be bigger at each epoch.\n", "def tst_metric(out,targ): return F.mse_loss(out,targ)\n", "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n", "cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.2])\n", "test_eq(learn.test_tracker.news, [True,False])\n", "\n", "#This can be overwritten by passing `comp=np.less`.\n", "learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)\n", "cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]\n", "with learn.no_logging(): learn.fit(2, cbs=cbs)\n", "test_eq(learn.test_tracker.bests, [0.2, 0.1])\n", "test_eq(learn.test_tracker.news, [True,True])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#A tracker callback is not run during an lr_find\n", "from fastai2.callback.schedule import *\n", "learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)\n", "learn.lr_find(num_it=15, show_plot=False)\n", "assert not hasattr(learn, 'new_best')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## EarlyStoppingCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args\n", "class EarlyStoppingCallback(TrackerCallback):\n", " \"A `TrackerCallback` that terminates training when monitored quantity stops improving.\"\n", " def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):\n", " super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", " self.patience = patience\n", "\n", " def begin_fit(self): self.wait = 0; super().begin_fit()\n", " def after_epoch(self):\n", " \"Compare the value monitored to its best score and maybe stop training.\"\n", " super().after_epoch()\n", " if self.new_best: self.wait = 0\n", " else:\n", " self.wait += 1\n", " if self.wait >= self.patience:\n", " print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')\n", " raise CancelFitException()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_trn=2, metrics=F.mse_loss)\n", "learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.validate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_trn=2)\n", "learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(len(learn.recorder.values), 3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## SaveModelCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args\n", "class SaveModelCallback(TrackerCallback):\n", " \"A `TrackerCallback` that saves the model's best during training and loads it at the end.\"\n", " def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):\n", " super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", " store_attr(self, 'fname,every_epoch,add_save,with_opt')\n", "\n", " def _save(self, name):\n", " self.learn.save(name, with_opt=self.with_opt)\n", " if self.add_save is not None:\n", " with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)\n", "\n", " def after_epoch(self):\n", " \"Compare the value monitored to its best score and save if best.\"\n", " if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')\n", " else: #every improvement\n", " super().after_epoch()\n", " if self.new_best:\n", " print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')\n", " self._save(f'{self.fname}')\n", "\n", " def after_fit(self, **kwargs):\n", " \"Load the best model.\"\n", " if not self.every_epoch: self.learn.load(f'{self.fname}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`comp` is the comparison operator used to determine if a value is best than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. Model will be saved in `learn.path/learn.model_dir/name.pth`, maybe `every_epoch` or at each improvement of the monitored quantity. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')\n", "learn.fit(n_epoch=2, cbs=SaveModelCallback())\n", "assert (Path.cwd()/'tmp/models/model.pth').exists()\n", "learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))\n", "for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()\n", "shutil.rmtree(Path.cwd()/'tmp')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ReduceLROnPlateau" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args\n", "class ReduceLROnPlateau(TrackerCallback):\n", " \"A `TrackerCallback` that reduces learning rate when a metric has stopped improving.\"\n", " def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0):\n", " super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)\n", " self.patience,self.factor,self.min_lr = patience,factor,min_lr\n", "\n", " def begin_fit(self): self.wait = 0; super().begin_fit()\n", " def after_epoch(self):\n", " \"Compare the value monitored to its best score and reduce LR by `factor` if no improvement.\"\n", " super().after_epoch()\n", " if self.new_best: self.wait = 0\n", " else:\n", " self.wait += 1\n", " if self.wait >= self.patience:\n", " old_lr = self.opt.hypers[-1]['lr']\n", " for h in self.opt.hypers: h['lr'] = max(h['lr'] / self.factor, self.min_lr)\n", " self.wait = 0\n", " if self.opt.hypers[-1][\"lr\"] < old_lr:\n", " print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1][\"lr\"]}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_trn=2)\n", "learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(learn.opt.hypers[-1]['lr'], 1e-8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_trn=2)\n", "learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(learn.opt.hypers[-1]['lr'], 1e-8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }