{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Tracking Callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.callbacks import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module regroups the callbacks that track one of the metrics computed at the end of each epoch to take some decision about training. To show examples of use, we'll use our sample of MNIST and a simple cnn model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TerminateOnNaNCallback[source]

\n", "\n", "> TerminateOnNaNCallback() :: [`Callback`](/callback.html#Callback)\n", "\n", "A [`Callback`](/callback.html#Callback) that terminates training if loss is NaN. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes, training diverges and the loss goes to nan. In that case, there's no point continuing, so this callback stops the training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
1      nan         nan         0.495584  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])\n", "learn.fit_one_cycle(1,1e4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using it prevents that situation to happen." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 0.00% [0/2 00:00<00:00]\n", "
\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time

\n", "\n", "

\n", " \n", " \n", " Interrupted\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch/Batch (0/5): Invalid loss, terminating training.\n" ] } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy], callbacks=[TerminateOnNaNCallback()])\n", "learn.fit(2,1e4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end[source]

\n", "\n", "> on_batch_end(**`last_loss`**, **`epoch`**, **`num_batch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "Test if `last_loss` is NaN and interrupts training. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback.on_batch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(**\\*\\*`kwargs`**:`Any`)\n", "\n", "Stop the training if necessary. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class EarlyStoppingCallback[source]

\n", "\n", "> EarlyStoppingCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'val_loss'`***, **`mode`**:`str`=***`'auto'`***, **`min_delta`**:`int`=***`0`***, **`patience`**:`int`=***`0`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that terminates training when monitored quantity stops improving. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will stop training after `patience` epochs if the quantity hasn't improved by `min_delta`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 8.00% [4/50 00:09<01:52]\n", "
\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
1      0.696629    0.696266    0.362610  00:02
2      0.696333    0.696266    0.362610  00:02
3      0.696307    0.696266    0.362610  00:02
4      0.696467    0.696266    0.362610  00:02

\n", "\n", "

\n", " \n", " \n", " 100.00% [32/32 00:00<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5: early stopping\n" ] } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy], \n", " callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])\n", "learn.fit(50,1e-42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "Initialize inner arguments. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(**`epoch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "Compare the value monitored to its best score and maybe stop training. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class SaveModelCallback[source]

\n", "\n", "> SaveModelCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'val_loss'`***, **`mode`**:`str`=***`'auto'`***, **`every`**:`str`=***`'improvement'`***, **`name`**:`str`=***`'bestmodel'`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that saves the model when monitored quantity is best. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will save the model in `name` whenever determined by `every` ('improvement' or 'epoch'). Loads the best model at the end of training is `every='improvement'`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
1      0.678338    0.666926    0.659470  00:02
2      0.563476    0.515598    0.907753  00:02
3      0.370079    0.337353    0.933268  00:02
4      0.281564    0.272560    0.936212  00:02
5      0.260385    0.263720    0.936703  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])\n", "learn.fit_one_cycle(5,1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy', name='model')])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Choosing `every='epoch'` saves an individual model at the end of each epoch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "bestmodel_1.pth model_1.pth model_4.pth stage-1.pth\r\n", "bestmodel_2.pth model_2.pth model_5.pth tmp.pth\r\n", "bestmodel_3.pth model_3.pth one_epoch.pth trained_model.pth\r\n" ] } ], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epoch  train_loss  valid_loss  accuracy  time
1      0.238711    0.226684    0.939156  00:02
2      0.181980    0.176078    0.940628  00:02
3      0.159314    0.163088    0.942100  00:02
4      0.160453    0.159423    0.943081  00:02
5      0.159717    0.159017    0.943081  00:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Better model found at epoch 1 with accuracy value: 0.9391560554504395.\n", "Better model found at epoch 2 with accuracy value: 0.9406280517578125.\n", "Better model found at epoch 3 with accuracy value: 0.9421001076698303.\n", "Better model found at epoch 4 with accuracy value: 0.9430814385414124.\n" ] } ], "source": [ "learn.fit_one_cycle(5,1e-4, callbacks=[SaveModelCallback(learn, every='improvement', monitor='accuracy', name='best')])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Choosing `every='improvement'` saves the single best model out of all epochs during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "best.pth\t bestmodel_3.pth model_3.pth one_epoch.pth trained_model.pth\r\n", "bestmodel_1.pth model_1.pth\t model_4.pth stage-1.pth\r\n", "bestmodel_2.pth model_2.pth\t model_5.pth tmp.pth\r\n" ] } ], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(**`epoch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "Compare the value monitored to its best score and maybe save the model. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end[source]

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "Load the best model. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback.on_train_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ReduceLROnPlateauCallback[source]

\n", "\n", "> ReduceLROnPlateauCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'val_loss'`***, **`mode`**:`str`=***`'auto'`***, **`patience`**:`int`=***`0`***, **`factor`**:`float`=***`0.2`***, **`min_delta`**:`int`=***`0`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that reduces learning rate when a metric has stopped improving. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will reduce the learning rate by `factor` after `patience` epochs if the quantity hasn't improved by `min_delta`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "Initialize inner arguments. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source]

\n", "\n", "> on_epoch_end(**`epoch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "Compare the value monitored to its best and maybe reduce lr. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TrackerCallback[source]

\n", "\n", "> TrackerCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'val_loss'`***, **`mode`**:`str`=***`'auto'`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "A [`LearnerCallback`](/basic_train.html#LearnerCallback) that keeps track of the best value in `monitor`. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_monitor_value[source]

\n", "\n", "> get_monitor_value()\n", "\n", "Pick the monitored value. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback.get_monitor_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "Initializes the best value. \n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback.on_train_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Callbacks that take decisions depending on the evolution of metrics during training", "title": "callbacks.tracker" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }