{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Contribution from @fredguth, https://github.com/fredguth/fastai_playground" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.torch_core import *\n", "from fastai.callback import *\n", "from fastai.basic_train import Learner, LearnerCallback" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import * \n", "from fastai.vision import * " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# __all__ = ['TerminateOnNaN', 'EarlyStopping', 'SaveModel']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "___\n", "\n", "The callbacks bellow are based on Keras Callbacks of same name:\n", "https://github.com/keras-team/keras/blob/master/keras/callbacks.py\n", "\n", "___" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Problem: Loss is Nan" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])\n", "learn.fit(2,1e4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Solution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Callback bellow is *very* influenced by Keras Callback of same name." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TerminateOnNaN(Callback):\n", " \"A `LearnerCallback` that terminates training if loss is NaN.\"\n", " \n", " def __init__(self):\n", " self.stop = False\n", " \n", " def on_batch_end(self, last_loss, epoch, num_batch, **kwargs:Any)->None:\n", " if self.stop: return True #to skip validation after stopping during traning\n", " if torch.isnan(last_loss):\n", " print (f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')\n", " self.stop = True\n", " return True\n", " \n", " def on_epoch_end(self, **kwargs:Any)->None:\n", " return self.stop " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(2,1e4, callbacks=[TerminateOnNaN()])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Problem: Metric does not improve" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(3,1e-42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Callback bellow is basically a simplified port of Keras Early Stopping callback to fastai/pytorch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class TrackerCallback(LearnerCallback):\n", " \"A `LearnerCallback` that keeps track of the best value in `monitor`.\"\n", " monitor:str='val_loss'\n", " mode:str='auto'\n", " \n", " def __post_init__(self):\n", " if self.mode not in ['auto', 'min', 'max']:\n", " warn(f'{self.__name__} mode {self.mode} is invalid, falling back to \"auto\" mode.')\n", " self.mode = 'auto'\n", " mode_dict = {'min': np.less, 'max':np.greater}\n", " mode_dict['auto'] = np.less if 'loss' in self.monitor else np.greater\n", " self.operator = mode_dict[self.mode]\n", " \n", " def on_train_begin(self, **kwargs:Any)->None:\n", " self.best = float('inf') if self.operator == np.less else -float('inf')\n", " \n", " def get_monitor_value(self):\n", " values = {'trn_loss':self.learn.recorder.losses[-1:][0].cpu().numpy(),\n", " 'val_loss':self.learn.recorder.val_losses[-1:][0]}\n", " for i, name in enumerate(self.learn.recorder.names[3:]):\n", " values[name]=learn.recorder.metrics[-1:][0][i] \n", " if values.get(self.monitor) is None:\n", " warn(f'{self.__name__} conditioned on metric `{self.monitor}` which is not available. Available metrics are: {\", \".join(map(str, self.learn.recorder.names[1:]))}') \n", " return values.get(self.monitor)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class EarlyStopping(TrackerCallback):\n", " \"A `LearnerCallback` that terminates training when monitored quantity stops improving.\"\n", " min_delta:int=0\n", " patience:int=0\n", " \n", " def __post_init__(self):\n", " super().__post_init__()\n", " if self.operator == np.less: self.min_delta *= -1\n", " \n", " def on_train_begin(self, **kwargs:Any)->None:\n", " self.wait = 0\n", " super().on_train_begin(**kwargs)\n", "\n", " def on_epoch_end(self, epoch, **kwargs:Any)->None:\n", " current = self.get_monitor_value()\n", " if current is None: return\n", " if self.operator(current - self.min_delta, self.best):\n", " self.best,self.wait = current,0\n", " else:\n", " self.wait += 1\n", " if self.wait >= self.patience:\n", " print(f'Epoch {epoch}: early stopping')\n", " return True" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])\n", "learn.callback_fns.append(partial(EarlyStopping, monitor='accuracy', min_delta=0.01, patience=3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(50,1e-42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### The Problem: best result is not in the last epoch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy] )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,1e-42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Best epoch is #1. But model is in #4." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class SaveModel(TrackerCallback):\n", " \"A `LearnerCallback` that saves the model when monitored quantity is best.\"\n", " every:str='improvement'\n", " name:str='bestmodel'\n", " def __post_init__(self):\n", " if self.every not in ['improvement', 'epoch']:\n", " warn(f'SaveModel every {every} is invalid, falling back to \"improvement\".')\n", " self.every = 'improvement'\n", " super().__post_init__()\n", " \n", " def on_epoch_end(self, epoch, **kwargs:Any)->None:\n", " if self.every==\"epoch\": learn.save(f'{self.name}_{epoch}')\n", " else: #every=\"improvement\"\n", " current = self.get_monitor_value()\n", " if current is not None and self.operator(current, self.best):\n", " self.best = current\n", " learn.save(f'{self.name}')\n", " \n", " def on_train_end(self, **kwargs):\n", " if self.every==\"improvement\": learn.load(f'{self.name}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])\n", "learn.callback_fns.append(partial(SaveModel, every='epoch'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,1e-42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls {path}/models/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])\n", "learn.callback_fns.append(partial(SaveModel, monitor='accuracy'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(5,1e-2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls {path}/models/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "validate(learn.model, learn.data.valid_dl, learn.loss_fn, metrics=[accuracy])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }