In [None]:
#hide
#skip
# `/content` only exists on Google Colab, so the pip upgrade runs there and is a no-op elsewhere.
! [ -e /content ] && pip install -Uqq fastai  # upgrade fastai on colab

In [None]:
#default_exp callback.tracker

In [None]:
#export
from fastai.basics import *
from fastai.callback.progress import *
from fastai.callback.fp16 import MixedPrecision

In [None]:
#hide
from nbdev.showdoc import *
from fastai.test_utils import *

# Tracking callbacks

> Callbacks that make decisions depending on how a monitored metric/loss behaves

## TerminateOnNaNCallback -

In [None]:
# export
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN or infinite."
    order=-9
    def after_batch(self):
        "Test if `loss` is NaN or infinite and interrupt training."
        # `isfinite` is False exactly for NaN, +inf and -inf, i.e. when either `isnan` or `isinf` would be True.
        if not torch.isfinite(self.loss): raise CancelFitException

In [None]:
# lr=100 is meant to make the loss blow up to inf/NaN, so the callback should cancel the fit early.
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())

epoch,train_loss,valid_loss,time


In [None]:
# Fewer losses than 10 full epochs' worth of batches were recorded, i.e. training was cut short...
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
# ...and none of the losses that were recorded is inf or NaN.
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l) 

## TrackerCallback -

In [None]:
# export
class TrackerCallback(Callback):
    "A `Callback` that keeps track of the best value in `monitor`."
    order,remove_on_fetch = 60,True
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., reset_on_fit=True):
        # By default a quantity whose name contains 'loss' or 'error' should go down; anything else should go up.
        if comp is None: comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater
        # Flip the sign for `np.less` so that `comp(val - min_delta, best)` in `after_epoch` always demands an
        # improvement of at least `|min_delta|` in the direction of `comp`.
        if comp == np.less: min_delta *= -1
        self.monitor,self.comp,self.min_delta,self.reset_on_fit,self.best= monitor,comp,min_delta,reset_on_fit,None

    def before_fit(self):
        "Prepare the monitored value"
        # Stay inactive for this fit when an `lr_finder` or `gather_preds` callback is present.
        self.run = not hasattr(self, "lr_finder") and not hasattr(self, "gather_preds")
        # Start from the worst possible value, unless `best` should carry over from a previous `fit`.
        if self.reset_on_fit or self.best is None: self.best = float('inf') if self.comp == np.less else -float('inf')
        # Drop the first entry of `metric_names`; `after_epoch` uses the resulting position to index `recorder.values`.
        names = list(self.recorder.metric_names[1:])
        assert self.monitor in names, f"monitor={self.monitor!r} is not a recorded value; choose one of {names}"
        self.idx = names.index(self.monitor)

    def after_epoch(self):
        "Compare the last value to the best up to now"
        val = self.recorder.values[-1][self.idx]
        # `min_delta` was signed in `__init__`, so this single comparison works for both directions.
        if self.comp(val - self.min_delta, self.best): self.best,self.new_best = val,True
        else: self.new_best = False

    def after_fit(self):
        "Re-enable the callback in case `before_fit` switched it off for this fit."
        self.run=True

When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. If you want to maintain `best` over subsequent calls to `fit` (e.g., `Learner.fit_one_cycle`), set `reset_on_fit` = False.

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount.

In [None]:
#hide
class FakeRecords(Callback):
    "Test helper: replace the recorded `monitor` value of each epoch with the scripted entry from `values`."
    order=51
    def __init__(self, monitor, values):
        self.monitor = monitor
        self.values = values

    def before_fit(self):
        recorded = list(self.recorder.metric_names[1:])
        self.idx = recorded.index(self.monitor)

    def after_epoch(self):
        scripted = self.values[self.epoch]
        self.recorder.values[-1][self.idx] = scripted
        
class TestTracker(Callback):
    "Test helper: snapshot the tracker's `best` and `new_best` after every epoch."
    order=61
    def before_fit(self):
        self.bests = []
        self.news = []

    def after_epoch(self):
        tracker = self.tracker
        self.bests.append(tracker.best)
        self.news.append(tracker.new_best)

In [None]:
#hide
learn = synth_learner(n_trn=2, cbs=TestTracker())
# `valid_loss` is overwritten with 0.2 then 0.1; with the default `np.less`, both epochs are a new best.
cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

#With a min_delta
# 0.1 is not at least 0.15 below the best of 0.2, so the second epoch is not a new best.
cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

In [None]:
#hide
#By default metrics have to be bigger at each epoch.
def tst_metric(out,targ): return F.mse_loss(out,targ)
# 'tst_metric' contains neither 'loss' nor 'error', so `comp` defaults to `np.greater`: 0.1 is not a new best.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

#This can be overwritten by passing `comp=np.less`.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

In [None]:
#hide
#Setting reset_on_fit=False will maintain the "best" value over subsequent calls to fit
# NOTE(review): this relies on `tst_metric` (compared with `np.greater` by default) not exceeding `first_best`
# during the second fit; presumably it keeps decreasing as the model trains.
learn = synth_learner(n_val=2, cbs=TrackerCallback(monitor='tst_metric', reset_on_fit=False), metrics=tst_metric)
tracker_cb = learn.cbs.filter(lambda cb: isinstance(cb, TrackerCallback))[0]
with learn.no_logging(): learn.fit(1)
first_best = tracker_cb.best
with learn.no_logging(): learn.fit(1)
test_eq(tracker_cb.best, first_best)

In [None]:
#hide
#A tracker callback is not run during an lr_find
from fastai.callback.schedule import *
learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)
learn.lr_find(num_it=15, show_plot=False)
# `new_best` is only ever set on the tracker itself (in its `after_epoch`), so inspect the callback:
# checking `learn` would pass whether or not the tracker ran.
assert not hasattr(learn.tracker, 'new_best')

## EarlyStoppingCallback -

In [None]:
# export
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when monitored quantity stops improving."
    order=TrackerCallback.order+3
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience = patience

    def before_fit(self):
        "Zero the count of epochs without improvement, then let `TrackerCallback` set itself up."
        self.wait = 0
        super().before_fit()

    def after_epoch(self):
        "Compare the value monitored to its best score and maybe stop training."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        self.wait += 1
        # Keep training until `patience` consecutive epochs have gone by without a new best.
        if self.wait < self.patience: return
        print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
        raise CancelFitException()

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement.

In [None]:
# Monitor the `mse_loss` metric ('loss' is in its name, so lower is better) and stop once 2 epochs in a row
# fail to beat the best by at least 0.1; with such a tiny lr that presumably happens right away.
learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))

epoch,train_loss,valid_loss,mse_loss,time
0,10.651194,14.263412,14.263412,00:00
1,10.655529,14.263385,14.263385,00:00
2,10.675529,14.263347,14.263347,00:00


No improvement since epoch 0: early stopping


In [None]:
# Early stopping does not restore the best weights: these values match the last epoch run, not the best one.
learn.validate()

(#2) [14.263346672058105,14.263346672058105]

In [None]:
# Same as above, but monitoring the validation loss directly.
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))

epoch,train_loss,valid_loss,time
0,26.303347,31.155645,00:00
1,26.319504,31.155575,00:00
2,26.335766,31.155474,00:00


No improvement since epoch 0: early stopping


In [None]:
#hide
# Only 3 of the 200 requested epochs were recorded before early stopping.
test_eq(len(learn.recorder.values), 3)

## SaveModelCallback -

In [None]:
#export
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    _only_train_loop,order = True,TrackerCallback.order+1
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, at_end=False,
                 with_opt=False, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        assert not (every_epoch and at_end), "every_epoch and at_end cannot both be set to True"
        # keep track of file path for loggers
        self.last_saved_path = None
        # Becomes True once this callback has written a checkpoint, so `after_fit` never loads a file it did not
        # produce (e.g. when training is cancelled before the first epoch ends, or no value ever compares as better).
        self._has_saved = False
        store_attr('fname,every_epoch,at_end,with_opt')

    def _save(self, name):
        "Save the model (and the optimizer state if `with_opt`) under `name` and record where it was written."
        self.last_saved_path = self.learn.save(name, with_opt=self.with_opt)
        self._has_saved = True

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best:
                print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
                self._save(f'{self.fname}')

    def after_fit(self, **kwargs):
        "Load the best model (or, with `at_end`, save the final one)."
        if self.at_end: self._save(f'{self.fname}')
        elif not self.every_epoch:
            if self._has_saved: self.learn.load(f'{self.fname}', with_opt=self.with_opt)
            else: print(f'No model was saved by {type(self).__name__}, so `{self.fname}` was not loaded.')

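`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' or 'error' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. By default, the model is saved in `learn.path/learn.model_dir/fname.pth` at each improvement of the monitored quantity, and the best one is loaded back at the end of training. With `every_epoch=True`, it is instead saved as `fname_{epoch}.pth` after every epoch. With `at_end=True`, the final model is saved to `fname.pth` when training ends instead of the best one being loaded.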

In [None]:
tmp_path = Path.cwd()/'tmp'
# Start clean: checkpoints left behind by an earlier (failed) run would otherwise satisfy the `exists()` asserts.
shutil.rmtree(tmp_path, ignore_errors=True)
try:
    learn = synth_learner(n_trn=2, path=tmp_path)
    learn.fit(n_epoch=2, cbs=SaveModelCallback())
    assert (tmp_path/'models/model.pth').exists()
    learn = synth_learner(n_trn=2, path=tmp_path)
    learn.fit(n_epoch=2, cbs=SaveModelCallback(fname='end',at_end=True))
    assert (tmp_path/'models/end.pth').exists()
    learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
    for i in range(2): assert (tmp_path/f'models/model_{i}.pth').exists()
finally:
    # Always remove the scratch directory, even when one of the asserts above fails.
    shutil.rmtree(tmp_path, ignore_errors=True)

epoch,train_loss,valid_loss,time
0,14.472381,14.357326,00:00
1,14.362669,14.045964,00:00


Better model found at epoch 0 with valid_loss value: 14.357325553894043.
Better model found at epoch 1 with valid_loss value: 14.045964241027832.


epoch,train_loss,valid_loss,time
0,10.07456,11.895357,00:00
1,9.999896,11.651211,00:00


Better model found at epoch 0 with valid_loss value: 11.895357131958008.
Better model found at epoch 1 with valid_loss value: 11.65121078491211.


epoch,train_loss,valid_loss,time
0,9.702823,11.311175,00:00
1,9.553051,10.910662,00:00


## ReduceLROnPlateau

In [None]:
# export
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    order=TrackerCallback.order+2
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10., min_lr=0, reset_on_fit=True):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.patience,self.factor,self.min_lr = patience,factor,min_lr

    def before_fit(self):
        "Zero the count of epochs without improvement, then let `TrackerCallback` set itself up."
        self.wait = 0
        super().before_fit()

    def _reduce_lr(self):
        "Divide every param group's lr by `factor` (never below `min_lr`); report if the last group's lr dropped."
        previous = self.opt.hypers[-1]['lr']
        for group in self.opt.hypers:
            group['lr'] = max(group['lr'] / self.factor, self.min_lr)
        current = self.opt.hypers[-1]['lr']
        # Once `min_lr` is reached the lr no longer changes, so stay silent then.
        if current < previous: print(f'Epoch {self.epoch}: reducing lr to {current}')

    def after_epoch(self):
        "Compare the value monitored to its best score and reduce LR by `factor` if no improvement."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        self.wait += 1
        if self.wait < self.patience: return
        self._reduce_lr()
        # Give the new lr a fresh `patience` window.
        self.wait = 0

In [None]:
# After 2 epochs without beating the best by at least 0.1, the lr is divided by `factor` (10 by default).
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))

epoch,train_loss,valid_loss,time
0,6.122743,7.348515,00:00
1,6.119377,7.348499,00:00
2,6.12579,7.348477,00:00
3,6.131386,7.348475,00:00


Epoch 2: reducing lr to 1e-08


In [None]:
#hide
# One reduction happened: 1e-7 / 10.
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)

In [None]:
# `min_lr` clamps the reduction: 5e-8 / 10 = 5e-9 is raised to 1e-8, and later reductions leave the lr
# unchanged, so no further message is printed.
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))

epoch,train_loss,valid_loss,time
0,16.747515,15.265999,00:00
1,16.725756,15.265974,00:00
2,16.735016,15.265943,00:00
3,16.73336,15.265934,00:00
4,16.733513,15.265925,00:00
5,16.730352,15.265915,00:00


Epoch 2: reducing lr to 1e-08


In [None]:
#hide
# The lr stayed at `min_lr` despite further plateaus.
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)

Each of these three derived `TrackerCallback`s (`SaveModelCallback`, `ReduceLROnPlateau`, and `EarlyStoppingCallback`) has an adjusted order so that they can all be used together without interfering with each other. That order is as follows:

> Note: the actual `Callback` order number is given in parentheses

1. `TrackerCallback` (60)
2. `SaveModelCallback` (61)
3. `ReduceLROnPlateau` (62)
4. `EarlyStoppingCallback` (63)

## Export -

In [None]:
#hide
# Regenerate the library modules from the notebooks' `#export` cells (see the "Converted ..." output below).
from nbdev.export import notebook2script
notebook2script()

Converted 00_torch_core.ipynb.
Converted 01_layers.ipynb.
Converted 01a_losses.ipynb.
Converted 02_data.load.ipynb.
Converted 03_data.core.ipynb.
Converted 04_data.external.ipynb.
Converted 05_data.transforms.ipynb.
Converted 06_data.block.ipynb.
Converted 07_vision.core.ipynb.
Converted 08_vision.data.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 09c_vision.widgets.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 10b_tutorial.albumentations.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_callback.core.ipynb.
Converted 13a_learner.ipynb.
Converted 13b_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 18a_callback.training.ipynb.
Converted 18b_callback.preds.ipynb.
Converted 