{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Tracking Callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.callbacks import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This module regroups the callbacks that track one of the metrics computed at the end of each epoch to take some decision about training. To show examples of use, we'll use our sample of MNIST and a simple cnn model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TerminateOnNaNCallback[source][test]

\n", "\n", "> TerminateOnNaNCallback() :: [`Callback`](/callback.html#Callback)\n", "\n", "
×

No tests found for TerminateOnNaNCallback. To contribute a test please refer to this guide and this discussion.

\n", "\n", "A [`Callback`](/callback.html#Callback) that terminates training if loss is NaN. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes, training diverges and the loss goes to nan. In that case, there's no point continuing, so this callback stops the training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
1nannan0.49558400:02
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])\n", "learn.fit_one_cycle(1,1e4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using it prevents that situation to happen." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 0.00% [0/2 00:00<00:00]\n", "
\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime

\n", "\n", "

\n", " \n", " \n", " Interrupted\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch/Batch (0/5): Invalid loss, terminating training.\n", "Epoch/Batch (0/6): Invalid loss, terminating training.\n" ] } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy], callbacks=[TerminateOnNaNCallback()])\n", "learn.fit(2,1e4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end[source][test]

\n", "\n", "> on_batch_end(**`last_loss`**, **`epoch`**, **`num_batch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_batch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Test if `last_loss` is NaN and interrupts training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback.on_batch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Called at the end of an epoch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TerminateOnNaNCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class EarlyStoppingCallback[source][test]

\n", "\n", "> EarlyStoppingCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'valid_loss'`***, **`mode`**:`str`=***`'auto'`***, **`min_delta`**:`int`=***`0`***, **`patience`**:`int`=***`0`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "
×

No tests found for EarlyStoppingCallback. To contribute a test please refer to this guide and this discussion.

\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that terminates training when monitored quantity stops improving. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will stop training after `patience` epochs if the quantity hasn't improved by `min_delta`. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " 8.00% [4/50 00:09<01:51]\n", "
\n", " \n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
10.6777290.6775140.56035300:02
20.6798200.6775140.56035300:02
30.6794340.6775140.56035300:02
40.6798930.6775140.56035300:02

\n", "\n", "

\n", " \n", " \n", " 100.00% [32/32 00:00<00:00]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 5: early stopping\n" ] }, { "ename": "AttributeError", "evalue": "'bool' object has no attribute 'items'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m learn = Learner(data, model, metrics=[accuracy], \n\u001b[1;32m 3\u001b[0m callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1e-42\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, epochs, lr, wd, callbacks)\u001b[0m\n\u001b[1;32m 180\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdefaults\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextra_callbacks\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mcallbacks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mdefaults\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextra_callbacks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m fit(epochs, self.model, self.loss_func, opt=self.opt, data=self.data, metrics=self.metrics,\n\u001b[0;32m--> 182\u001b[0;31m callbacks=self.callbacks+callbacks)\n\u001b[0m\u001b[1;32m 183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 184\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcreate_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/fastai/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, model, loss_func, opt, data, callbacks, metrics)\u001b[0m\n\u001b[1;32m 97\u001b[0m cb_handler=cb_handler, pbar=pbar)\n\u001b[1;32m 98\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mval_loss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_epoch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mexception\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/fastai/fastai/callback.py\u001b[0m in \u001b[0;36mon_epoch_end\u001b[0;34m(self, val_loss)\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'last_metrics'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mval_loss\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mval_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'epoch'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch_end'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcall_mets\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_loss\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 290\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'stop_training'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/fastai/fastai/callback.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, cb_name, call_mets, **kwargs)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcall_mets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmet\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_and_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_and_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mset_dl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdl\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/fastai/fastai/callback.py\u001b[0m in \u001b[0;36m_call_and_update\u001b[0;34m(self, cb, cb_name, **kwargs)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[0;34m\"Call `cb_name` on `cb` and update the inner state.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 214\u001b[0m \u001b[0mnew\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mifnone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf'on_{cb_name}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 215\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mnew\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 216\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mk\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mException\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"{k} isn't a valid key in the state of the callbacks.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAttributeError\u001b[0m: 'bool' object has no attribute 'items'" ] } ], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy], \n", " callback_fns=[partial(EarlyStoppingCallback, monitor='accuracy', min_delta=0.01, patience=3)])\n", "learn.fit(50,1e-42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source][test]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_train_begin. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Initialize inner arguments. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**`epoch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Compare the value monitored to its best score and maybe stop training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(EarlyStoppingCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class SaveModelCallback[source][test]

\n", "\n", "> SaveModelCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'valid_loss'`***, **`mode`**:`str`=***`'auto'`***, **`every`**:`str`=***`'improvement'`***, **`name`**:`str`=***`'bestmodel'`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "
×

No tests found for SaveModelCallback. To contribute a test please refer to this guide and this discussion.

\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that saves the model when monitored quantity is best. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will save the model in `name` whenever determined by `every` ('improvement' or 'epoch'). Loads the best model at the end of training is `every='improvement'`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = simple_cnn((3,16,16,2))\n", "learn = Learner(data, model, metrics=[accuracy])\n", "learn.fit_one_cycle(5,1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy', name='model')])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Choosing `every='epoch'` saves an individual model at the end of each epoch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(5,1e-4, callbacks=[SaveModelCallback(learn, every='improvement', monitor='accuracy', name='best')])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Choosing `every='improvement'` saves the single best model out of all epochs during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**`epoch`**:`int`, **\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Compare the value monitored to its best score and maybe save the model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_end[source][test]

\n", "\n", "> on_train_end(**\\*\\*`kwargs`**)\n", "\n", "
×

No tests found for on_train_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Load the best model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(SaveModelCallback.on_train_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class ReduceLROnPlateauCallback[source][test]

\n", "\n", "> ReduceLROnPlateauCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'valid_loss'`***, **`mode`**:`str`=***`'auto'`***, **`patience`**:`int`=***`0`***, **`factor`**:`float`=***`0.2`***, **`min_delta`**:`int`=***`0`***) :: [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback)\n", "\n", "
×

No tests found for ReduceLROnPlateauCallback. To contribute a test please refer to this guide and this discussion.

\n", "\n", "A [`TrackerCallback`](/callbacks.tracker.html#TrackerCallback) that reduces learning rate when a metric has stopped improving. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This callback tracks the quantity in `monitor` during the training of `learn`. `mode` can be forced to 'min' or 'max' but will automatically try to determine if the quantity should be the lowest possible (validation loss) or the highest possible (accuracy). Will reduce the learning rate by `factor` after `patience` epochs if the quantity hasn't improved by `min_delta`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source][test]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_train_begin. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Initialize inner arguments. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback.on_train_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**`epoch`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_epoch_end. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Compare the value monitored to its best and maybe reduce lr. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(ReduceLROnPlateauCallback.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TrackerCallback[source][test]

\n", "\n", "> TrackerCallback(**`learn`**:[`Learner`](/basic_train.html#Learner), **`monitor`**:`str`=***`'valid_loss'`***, **`mode`**:`str`=***`'auto'`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
×

No tests found for TrackerCallback. To contribute a test please refer to this guide and this discussion.

\n", "\n", "A [`LearnerCallback`](/basic_train.html#LearnerCallback) that keeps track of the best value in `monitor`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_monitor_value[source][test]

\n", "\n", "> get_monitor_value()\n", "\n", "
×

No tests found for get_monitor_value. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Pick the monitored value. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback.get_monitor_value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source][test]

\n", "\n", "> on_train_begin(**\\*\\*`kwargs`**:`Any`)\n", "\n", "
×

No tests found for on_train_begin. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Initializes the best value. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrackerCallback.on_train_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Callbacks that take decisions depending on the evolution of metrics during training", "title": "callbacks.tracker" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }