{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## List of callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.text import *\n", "from fastai.callbacks import * \n", "from fastai.basic_train import * \n", "from fastai.train import * \n", "from fastai import callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai's training loop is highly extensible, with a rich *callback* system. See the [`callback`](/callback.html#callback) docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.\n", "\n", "Every callback that is passed to [`Learner`](/basic_train.html#Learner) with the `callback_fns` parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance [`ActivationStats`](/callbacks.hooks.html#ActivationStats) will appear as `learn.activation_stats` (assuming your object is named `learn`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Callback`](/callback.html#Callback)\n", "\n", "This sub-package contains more sophisticated callbacks, each of which lives in its own module. They are listed below (follow each link for more details):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`LRFinder`](/callbacks.lr_finder.html#LRFinder)\n", "\n", "Use Leslie Smith's [learning rate finder](https://www.jeremyjordan.me/nn-learning-rate/) to find a good learning rate for training your model. Here is an example on the MNIST sample dataset with a simple CNN." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fastai library already has a `Learner` method called [`lr_find`](/train.html#lr_find) that uses [`LRFinder`](/callbacks.lr_finder.html#LRFinder) to plot the loss as a function of the learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, a learning rate around 2e-2 seems like the right fit." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "Train with Leslie Smith's [1cycle annealing](https://sgugger.github.io/the-1cycle-policy.html) method. Let's train our simple learner using the one cycle policy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy | time |
| --- | --- | --- | --- | --- |
| 0 | 0.119191 | 0.071195 | 0.972522 | 00:02 |
| 1 | 0.057419 | 0.042737 | 0.984298 | 00:02 |
| 2 | 0.031792 | 0.028259 | 0.987733 | 00:02 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate and the momentum were changed during the epochs as follows (more info on the [dedicated documentation page](https://docs.fast.ai/callbacks.one_cycle.html))." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)\n", "\n", "Data augmentation using the method from [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412). It is very simple to add mixup in fastai :" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`CSVLogger`](/callbacks.csv_logger.html#CSVLogger)\n", "\n", "Log the results of training in a csv file. Simply pass the CSVLogger callback to the Learner." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
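| epoch | train_loss | valid_loss | accuracy | error_rate | time |
| --- | --- | --- | --- | --- | --- |
| 0 | 0.119083 | 0.108034 | 0.959274 | 0.040726 | 00:02 |
| 1 | 0.078156 | 0.071208 | 0.973013 | 0.026987 | 00:02 |
| 2 | 0.056985 | 0.045835 | 0.984789 | 0.015211 | 00:02 |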
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can then read the csv." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
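|  | epoch | train_loss | valid_loss | accuracy | error_rate | time |
| --- | --- | --- | --- | --- | --- | --- |
| 0 | 0 | 0.119083 | 0.108034 | 0.959274 | 0.040726 | NaN |
| 1 | 1 | 0.078156 | 0.071208 | 0.973013 | 0.026987 | NaN |
| 2 | 2 | 0.056985 | 0.045835 | 0.984789 | 0.015211 | NaN |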
\n", "
" ], "text/plain": [ " epoch train_loss valid_loss accuracy error_rate time\n", "0 0 0.119083 0.108034 0.959274 0.040726 NaN\n", "1 1 0.078156 0.071208 0.973013 0.026987 NaN\n", "2 2 0.056985 0.045835 0.984789 0.015211 NaN" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.csv_logger.read_logged_file()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GeneralScheduler`](/callbacks.general_sched.html#GeneralScheduler)\n", "\n", "Create your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a 2 phase schedule." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_odd_shedule(learn, lr):\n", " n = len(learn.data.train_dl)\n", " phases = [TrainingPhase(n).schedule_hp('lr', lr, anneal=annealing_cos), \n", " TrainingPhase(n*2).schedule_hp('lr', lr, anneal=annealing_poly(2))]\n", " sched = GeneralScheduler(learn, phases)\n", " learn.callbacks.append(sched)\n", " total_epochs = 3\n", " learn.fit(total_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy | time |
| --- | --- | --- | --- | --- |
| 0 | 0.171962 | 0.154716 | 0.947498 | 00:02 |
| 1 | 0.133720 | 0.132249 | 0.957802 | 00:02 |
| 2 | 0.132928 | 0.129927 | 0.957802 | 00:02 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "fit_odd_shedule(learn, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision)\n", "\n", "Use fp16 to [take advantage of tensor cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) on recent NVIDIA GPUs for a 200% or more speedup." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`HookCallback`](/callbacks.hooks.html#HookCallback)\n", "\n", "Convenient wrapper for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks). Also contains pre-defined hook callback: [`ActivationStats`](/callbacks.hooks.html#ActivationStats)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer)\n", "\n", "Callback taking care of all the tweaks to train an RNN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`TerminateOnNaNCallback`](/callbacks.tracker.html#TerminateOnNaNCallback)\n", "\n", "Stop training if the loss reaches NaN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`EarlyStoppingCallback`](/callbacks.tracker.html#EarlyStoppingCallback)\n", "\n", "Stop training if a given metric/validation loss doesn't improve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`SaveModelCallback`](/callbacks.tracker.html#SaveModelCallback)\n", "\n", "Save the model at every epoch, or the best model for a given metric/validation loss." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
| epoch | train_loss | valid_loss | accuracy | time |
| --- | --- | --- | --- | --- |
| 0 | 0.665244 | 0.642582 | 0.816487 | 00:02 |
| 1 | 0.508492 | 0.471950 | 0.937684 | 00:02 |
| 2 | 0.438286 | 0.435377 | 0.941119 | 00:02 |
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "learn.fit_one_cycle(3,1e-4, callbacks=[SaveModelCallback(learn, every='epoch', monitor='accuracy')])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "best.pth\t bestmodel_2.pth model_1.pth model_4.pth stage-1.pth\r\n", "bestmodel_0.pth bestmodel_3.pth model_2.pth model_5.pth tmp.pth\r\n", "bestmodel_1.pth model_0.pth\t model_3.pth one_epoch.pth trained_model.pth\r\n" ] } ], "source": [ "!ls ~/.fastai/data/mnist_sample/models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ReduceLROnPlateauCallback`](/callbacks.tracker.html#ReduceLROnPlateauCallback)\n", "\n", "Reduce the learning rate each time a given metric/validation loss doesn't improve by a certain factor." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`PeakMemMetric`](/callbacks.mem.html#PeakMemMetric)\n", "\n", "GPU and general RAM profiling callback" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`StopAfterNBatches`](/callbacks.misc.html#StopAfterNBatches)\n", "\n", "Stop training after n batches of the first epoch." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`train`](/train.html#train) and [`basic_train`](/basic_train.html#basic_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`Recorder`](/basic_train.html#Recorder)\n", "\n", "Track per-batch and per-epoch smoothed losses and metrics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ShowGraph`](/train.html#ShowGraph)\n", "\n", "Dynamically display a learning chart during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`BnFreeze`](/train.html#BnFreeze)\n", "\n", "Freeze batchnorm layer moving average statistics for non-trainable layers." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GradientClipping`](/train.html#GradientClipping)\n", "\n", "Clip gradients during training." ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Callbacks implemented in the fastai library", "title": "callbacks" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }