In [None]:
#default_exp callback.tracker

In [None]:
#export
from local.test import *
from local.basics import *
from local.callback.progress import *

In [None]:
from local.notebook.showdoc import *
from local.test_utils import *

# Tracking callbacks

> Callbacks that make decisions depending on how a monitored metric/loss behaves

## ShortEpochCallback -

In [None]:
#export
class ShortEpochCallback(Callback):
    "Fit just `pct` of an epoch, then stop"
    def __init__(self,pct=0.01,short_valid=True):
        self.pct = pct
        self.short_valid = short_valid

    def after_batch(self):
        "Cancel the current phase once `iter/n_iter` is no longer below `pct`"
        still_below_pct = self.iter/self.n_iter < self.pct
        if still_below_pct: return
        # Training is always cut short; validation only when `short_valid` is set
        if self.training: raise CancelTrainException
        elif self.short_valid: raise CancelValidException

In [None]:
# Default settings: both the training and the validation phase are cut short
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback())

epoch,train_loss,valid_loss,time
0,00:00,,


In [None]:
# With `short_valid=False` only the training phase is cut short; validation is not cancelled
learn = synth_learner()
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))

epoch,train_loss,valid_loss,time
0,4.103489,00:00,


## TerminateOnNaNCallback -

In [None]:
# export
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN or infinite."
    run_before=Recorder

    def after_batch(self):
        "Cancel the whole fit if `self.loss` is infinite or NaN."
        loss = self.loss
        if torch.isinf(loss) or torch.isnan(loss):
            raise CancelFitException

In [None]:
# A deliberately huge learning rate, so the loss blows up and the callback cancels the fit
learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())

epoch,train_loss,valid_loss,time
0,2.862784155043519e+36,00:00,


In [None]:
# The fit was cancelled before all 10 epochs' worth of batches were recorded...
assert len(learn.recorder.losses) < 10 * len(learn.dbunch.train_dl)
# ...and none of the losses that were recorded is infinite or NaN
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l) 

## TrackerCallback -

In [None]:
# export
class TrackerCallback(Callback):
    "A `Callback` that keeps track of the best value in `monitor`."
    run_after=Recorder

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):
        # Anything with 'loss' in its name is minimized by default, everything else maximized
        if comp is None:
            comp = np.less if 'loss' in monitor else np.greater
        # When lower is better, the required margin is stored negated so that
        # `val - min_delta` in `after_epoch` makes the candidate value larger
        if comp == np.less: min_delta *= -1
        self.monitor = monitor
        self.comp = comp
        self.min_delta = min_delta

    def begin_fit(self):
        "Prepare the monitored value"
        # Disabled when an `lr_finder` or `gather_preds` attribute is reachable from this callback
        self.run = not (hasattr(self, "lr_finder") or hasattr(self, "gather_preds"))
        if self.comp == np.less: self.best = float('inf')
        else: self.best = -float('inf')
        # The first entry of `metric_names` is skipped when locating `monitor`
        names = self.recorder.metric_names[1:]
        assert self.monitor in names
        self.idx = list(names).index(self.monitor)

    def after_epoch(self):
        "Compare the last value to the best up to now"
        latest = self.recorder.values[-1][self.idx]
        improved = self.comp(latest - self.min_delta, self.best)
        if improved: self.best = latest
        self.new_best = True if improved else False

    def after_fit(self):
        "Re-enable the callback for the next fit"
        self.run = True

When implementing a `Callback` that has behavior that depends on the best value of a metric or loss, subclass this `Callback` and use its `best` (for best value so far) and `new_best` (there was a new best value this epoch) attributes. 

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount.

In [None]:
#hide
class FakeRecords(Callback):
    "Test helper that overwrites the recorded value of `monitor` at each epoch with `values[epoch]`"
    run_after=Recorder
    run_before=TrackerCallback

    def __init__(self, monitor, values):
        self.monitor = monitor
        self.values = values

    def begin_fit(self):
        "Locate the position of `monitor` among the recorded names (first name skipped)"
        names = list(self.recorder.metric_names[1:])
        self.idx = names.index(self.monitor)

    def after_epoch(self):
        "Replace the value in the last recorded row with the fake one for this epoch"
        last_row = self.recorder.values[-1]
        last_row[self.idx] = self.values[self.epoch]
        
class TestTracker(Callback):
    "Test helper that logs the tracker's `best` and `new_best` after every epoch"
    run_after=TrackerCallback

    def begin_fit(self):
        "Start each fit with empty logs"
        self.bests = []
        self.news = []

    def after_epoch(self):
        "Append the tracker's current state to the logs"
        tracker = self.tracker
        self.bests.append(tracker.best)
        self.news.append(tracker.new_best)

In [None]:
#hide
# Fake validation losses of 0.2 then 0.1: each epoch is a new best
learn = synth_learner(n_trn=2, cbs=TestTracker())
cbs=[TrackerCallback(monitor='valid_loss'), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

#With a min_delta of 0.15, the drop from 0.2 to 0.1 is too small to count as a new best
cbs=[TrackerCallback(monitor='valid_loss', min_delta=0.15), FakeRecords('valid_loss', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

In [None]:
#hide
#By default metrics have to be bigger at each epoch.
# (`tst_metric` has no 'loss' in its name, so the tracker defaults to `np.greater`)
def tst_metric(out,targ): return F.mse_loss(out,targ)
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric'), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
# 0.1 is not greater than 0.2, so the second epoch is not a new best
test_eq(learn.test_tracker.bests, [0.2, 0.2])
test_eq(learn.test_tracker.news,  [True,False])

#This can be overwritten by passing `comp=np.less`.
learn = synth_learner(n_trn=2, cbs=TestTracker(), metrics=tst_metric)
cbs=[TrackerCallback(monitor='tst_metric', comp=np.less), FakeRecords('tst_metric', [0.2,0.1])]
with learn.no_logging(): learn.fit(2, cbs=cbs)
test_eq(learn.test_tracker.bests, [0.2, 0.1])
test_eq(learn.test_tracker.news,  [True,True])

In [None]:
#hide
#A tracker callback is not run during an lr_find
from local.callback.schedule import *
learn = synth_learner(n_trn=2, cbs=TrackerCallback(monitor='tst_metric'), metrics=tst_metric)
learn.lr_find(num_it=5, show_plot=False)
# NOTE(review): `new_best` is assigned on the callback in `TrackerCallback.after_epoch`, not on the
# learner -- confirm this check on `learn` can actually fail (checking `learn.tracker` may be intended)
assert not hasattr(learn, 'new_best')

## EarlyStoppingCallback -

In [None]:
# export
class EarlyStoppingCallback(TrackerCallback):
    "A `TrackerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience = patience

    def begin_fit(self):
        "Reset the count of epochs without improvement, then prepare the tracker"
        self.wait = 0
        super().begin_fit()

    def after_epoch(self):
        "Compare the value monitored to its best score and maybe stop training."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        # No improvement this epoch: count it and stop once `patience` is reached
        self.wait += 1
        if self.wait >= self.patience:
            print(f'No improvement since epoch {self.epoch-self.wait}: early stopping')
            raise CancelFitException()

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. `patience` is the number of epochs you're willing to wait without improvement.

In [None]:
# With such a tiny learning rate the loss barely moves, so no epoch beats the first one by
# `min_delta` and training stops once `patience` epochs have passed without improvement
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))

epoch,train_loss,valid_loss,time
0,8.969569,7.550271,00:00
1,8.964047,7.55025,00:00
2,8.963997,7.55022,00:00


No improvement since epoch 0: early stopping


In [None]:
#hide
# Only 3 of the 200 requested epochs were recorded before the early stop
test_eq(len(learn.recorder.values), 3)

## SaveModelCallback -

In [None]:
# export
class SaveModelCallback(TrackerCallback):
    "A `TrackerCallback` that saves the model's best during training and loads it at the end."
    # Whether the current fit has written a best-model checkpoint under `fname`
    _best_saved = False

    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., fname='model', every_epoch=False, add_save=None, with_opt=False):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        store_attr(self, 'fname,every_epoch,add_save,with_opt')

    def begin_fit(self):
        "Prepare the tracker and forget any checkpoint written by a previous fit."
        super().begin_fit()
        self._best_saved = False

    def _save(self, name):
        "Save the learner under `name`, and also into `add_save` when it is given."
        self.learn.save(name, with_opt=self.with_opt)
        if self.add_save is not None:
            # `add_save` must provide `open('wb')` (e.g. a `Path`); the handle is passed to `learn.save`
            with self.add_save.open('wb') as f: self.learn.save(f, with_opt=self.with_opt)

    def after_epoch(self):
        "Compare the value monitored to its best score and save if best."
        if self.every_epoch: self._save(f'{self.fname}_{self.epoch}')
        else: #every improvement
            super().after_epoch()
            if self.new_best:
                self._save(f'{self.fname}')
                self._best_saved = True

    def after_fit(self, **kwargs):
        "Load the best model saved during this fit, if there is one."
        # The callbacks in this module are driven by `begin_fit`/`after_epoch`/`after_fit`;
        # a method called `on_train_end` is not one of those events, so the reload has to
        # be triggered from here for the best model to be restored at the end of training.
        super().after_fit()
        # Skip the reload when no checkpoint was written (e.g. fit cancelled during the first epoch)
        if self._best_saved: self.on_train_end(**kwargs)

    def on_train_end(self, **kwargs):
        "Load the best model (kept under this name for code that calls it directly)."
        if not self.every_epoch: self.learn.load(f'{self.fname}')

`comp` is the comparison operator used to determine if a value is better than another (defaults to `np.less` if 'loss' is in the name passed in `monitor`, `np.greater` otherwise) and `min_delta` is an optional float that requires a new value to go over the current best (depending on `comp`) by at least that amount. The model will be saved in `learn.path/learn.model_dir/{fname}.pth` at each improvement of the monitored quantity, or in `learn.path/learn.model_dir/{fname}_{epoch}.pth` at the end of every epoch if `every_epoch=True`.

In [None]:
# Work in a temporary `tmp` folder so the checkpoints can be removed afterwards
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
# Default mode: one checkpoint named after `fname`, written at each improvement
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
# `every_epoch=True`: one checkpoint per epoch, suffixed with the epoch number
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')

epoch,train_loss,valid_loss,time
0,12.352231,9.4371,00:00
1,12.247625,9.188762,00:00


epoch,train_loss,valid_loss,time
0,11.802291,8.847794,00:00
1,11.562449,8.428464,00:00


## ReduceLROnPlateau

In [None]:
# export
class ReduceLROnPlateau(TrackerCallback):
    "A `TrackerCallback` that reduces learning rate when a metric has stopped improving."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0., patience=1, factor=10.):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta)
        self.patience = patience
        self.factor = factor

    def begin_fit(self):
        "Reset the count of epochs without improvement, then prepare the tracker"
        self.wait = 0
        super().begin_fit()

    def after_epoch(self):
        "Compare the value monitored to its best score and reduce LR by `factor` if no improvement."
        super().after_epoch()
        if self.new_best:
            self.wait = 0
            return
        # No improvement this epoch: count it, and once `patience` is reached divide every
        # hyper-parameter group's `lr` by `factor` and start counting again
        self.wait += 1
        if self.wait >= self.patience:
            for hyper in self.opt.hypers: hyper['lr'] /= self.factor
            self.wait = 0
            print(f'Epoch {self.epoch}: reducing lr to {self.opt.hypers[-1]["lr"]}')

In [None]:
# The tiny learning rate keeps the loss nearly flat, so no epoch improves on the first one by
# `min_delta`; after `patience` such epochs the learning rate is divided by the default `factor`
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))

epoch,train_loss,valid_loss,time
0,13.372304,11.866179,00:00
1,13.377043,11.866154,00:00
2,13.395395,11.866118,00:00
3,13.400988,11.866114,00:00


Epoch 2: reducing lr to 1e-08


In [None]:
#hide
# One reduction by the default factor of 10 happened: 1e-7 -> 1e-8
test_eq(learn.opt.hypers[-1]['lr'], 1e-8)

## Export -

In [None]:
#hide
# Run the project's notebook-to-script export (the output below lists the converted notebooks)
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core_foundation.ipynb.
Converted 01a_core_utils.ipynb.
Converted 01b_core_dispatch.ipynb.
Converted 01c_core_transform.ipynb.
Converted 02_core_script.ipynb.
Converted 03_torchcore.ipynb.
Converted 03a_layers.ipynb.
Converted 04_data_load.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 07_data_block.ipynb.
Converted 08_vision_core.ipynb.
Converted 09_vision_augment.ipynb.
Converted 09a_vision_data.ipynb.
Converted 09b_vision_utils.ipynb.
Converted 10_pets_tutorial.ipynb.
Converted 11_vision_models_xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 13a_metrics.ipynb.
Converted 14_callback_schedule.ipynb.
Converted 14a_callback_data.ipynb.
Converted 15_callback_hook.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 16_callback_progress.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 18_callback_fp16.ipynb.
Converted 19_callback_mixup.ipynb.
Converted 20_interpret.ipynb.
C