{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## List of callbacks" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.text import *\n", "from fastai.callbacks import * \n", "from fastai.basic_train import * \n", "from fastai.train import * \n", "from fastai import callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "fastai's training loop is highly extensible, with a rich *callback* system. See the [`callback`](/callback.html#callback) docs if you're interested in writing your own callback. See below for a list of callbacks that are provided with fastai, grouped by the module they're defined in.\n", "\n", "Every callback that is passed to [`Learner`](/basic_train.html#Learner) with the `callback_fns` parameter will be automatically stored as an attribute. The attribute name is snake-cased, so for instance [`ActivationStats`](/callbacks.hooks.html#ActivationStats) will appear as `learn.activation_stats` (assuming your object is named `learn`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`Callback`](/callback.html#Callback)\n", "\n", "This sub-package contains more sophisticated callbacks that each are in their own module. They are (click the link for more details):" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`LRFinder`](/callbacks.lr_finder.html#LRFinder)\n", "\n", "Use Leslie Smith's [learning rate finder](https://www.jeremyjordan.me/nn-learning-rate/) to find a good learning rate for training your model. Let's see an example of use on the MNIST dataset with a simple CNN." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fastai library already has a Learner method called [`lr_find`](/train.html#lr_find) that uses [`LRFinder`](/callbacks.lr_finder.html#LRFinder) to plot the loss as a function of the learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, a learning rate around 2e-2 seems like the right fit." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lr = 2e-2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler)\n", "\n", "Train with Leslie Smith's [1cycle annealing](https://sgugger.github.io/the-1cycle-policy.html) method. Let's train our simple learner using the one cycle policy." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:08

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
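<table>\n",
"  <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>\n",
"  <tr><td>1</td><td>0.111205</td><td>0.056460</td><td>0.979882</td></tr>\n",
"  <tr><td>2</td><td>0.040632</td><td>0.023650</td><td>0.987733</td></tr>\n",
"  <tr><td>3</td><td>0.021217</td><td>0.020044</td><td>0.991659</td></tr>\n",
"</table>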
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(3, lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate and the momentum were changed during the epochs as follows (more info on the [dedicated documentation page](https://docs.fast.ai/callbacks.one_cycle.html))." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)\n", "\n", "Data augmentation using the method from [mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412). It is very simple to add mixup in fastai:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy]).mixup()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`CSVLogger`](/callbacks.csv_logger.html#CSVLogger)\n", "\n", "Log the results of training in a CSV file. Simply pass the `CSVLogger` callback to the `Learner`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, simple_cnn((3, 16, 16, 2)), metrics=[accuracy, error_rate], callback_fns=[CSVLogger])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:08

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
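<table>\n",
"  <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>error_rate</th></tr>\n",
"  <tr><td>1</td><td>0.125326</td><td>0.103473</td><td>0.963690</td><td>0.036310</td></tr>\n",
"  <tr><td>2</td><td>0.077392</td><td>0.059223</td><td>0.977920</td><td>0.022080</td></tr>\n",
"  <tr><td>3</td><td>0.065756</td><td>0.081031</td><td>0.969578</td><td>0.030422</td></tr>\n",
"</table>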
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can then read the csv." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
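<table>\n",
"  <thead>\n",
"    <tr><th></th><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>error_rate</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>1</td><td>0.125326</td><td>0.103473</td><td>0.963690</td><td>0.036310</td></tr>\n",
"    <tr><th>1</th><td>2</td><td>0.077392</td><td>0.059223</td><td>0.977920</td><td>0.022080</td></tr>\n",
"    <tr><th>2</th><td>3</td><td>0.065756</td><td>0.081031</td><td>0.969578</td><td>0.030422</td></tr>\n",
"  </tbody>\n",
"</table>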
\n", "
" ], "text/plain": [ "   epoch  train_loss  valid_loss  accuracy  error_rate\n", "0      1    0.125326    0.103473  0.963690    0.036310\n", "1      2    0.077392    0.059223  0.977920    0.022080\n", "2      3    0.065756    0.081031  0.969578    0.030422" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.csv_logger.read_logged_file()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GeneralScheduler`](/callbacks.general_sched.html#GeneralScheduler)\n", "\n", "Create your own multi-stage annealing schemes with a convenient API. To illustrate, let's implement a two-phase schedule." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_odd_shedule(learn, lr, mom):\n", "    # n is the number of batches in one epoch\n", "    n = len(learn.data.train_dl)\n", "    # Phase 1: one epoch with cosine annealing; phase 2: two epochs with polynomial (degree 2) annealing\n", "    phases = [TrainingPhase(n, lr, mom, lr_anneal=annealing_cos), TrainingPhase(n*2, lr, mom, lr_anneal=annealing_poly(2))]\n", "    sched = GeneralScheduler(learn, phases)\n", "    learn.callbacks.append(sched)\n", "    total_epochs = 3\n", "    learn.fit(total_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:08

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
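<table>\n",
"  <tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>\n",
"  <tr><td>1</td><td>0.178648</td><td>0.161728</td><td>0.944553</td></tr>\n",
"  <tr><td>2</td><td>0.142739</td><td>0.132620</td><td>0.957802</td></tr>\n",
"  <tr><td>3</td><td>0.135239</td><td>0.129183</td><td>0.960255</td></tr>\n",
"</table>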
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "fit_odd_shedule(learn, 1e-3, 0.9)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision)\n", "\n", "Use fp16 to [take advantage of tensor cores](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) on recent NVIDIA GPUs for a 200% or more speedup." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`HookCallback`](/callbacks.hooks.html#HookCallback)\n", "\n", "Convenient wrapper for registering and automatically deregistering [PyTorch hooks](https://pytorch.org/tutorials/beginner/former_torchies/nn_tutorial.html#forward-and-backward-function-hooks). Also contains a pre-defined hook callback: [`ActivationStats`](/callbacks.hooks.html#ActivationStats)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`RNNTrainer`](/callbacks.rnn.html#RNNTrainer)\n", "\n", "Callback taking care of all the tweaks to train an RNN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`TerminateOnNaNCallback`](/callbacks.tracker.html#TerminateOnNaNCallback)\n", "\n", "Stop training if the loss becomes NaN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`EarlyStoppingCallback`](/callbacks.tracker.html#EarlyStoppingCallback)\n", "\n", "Stop training if a given metric/validation loss doesn't improve." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`SaveModelCallback`](/callbacks.tracker.html#SaveModelCallback)\n", "\n", "Save the model at every epoch, or the best model for a given metric/validation loss." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ReduceLROnPlateauCallback`](/callbacks.tracker.html#ReduceLROnPlateauCallback)\n", "\n", "Reduce the learning rate by a certain factor each time a given metric/validation loss doesn't improve." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`PeakMemMetric`](/callbacks.mem.html#PeakMemMetric)\n", "\n", "GPU and general RAM profiling callback." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## [`train`](/train.html#train) and [`basic_train`](/basic_train.html#basic_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`Recorder`](/basic_train.html#Recorder)\n", "\n", "Track per-batch and per-epoch smoothed losses and metrics." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`ShowGraph`](/train.html#ShowGraph)\n", "\n", "Dynamically display a learning chart during training." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`BnFreeze`](/train.html#BnFreeze)\n", "\n", "Freeze batchnorm layer moving average statistics for non-trainable layers." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### [`GradientClipping`](/train.html#GradientClipping)\n", "\n", "Clip gradients during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Callbacks implemented in the fastai library", "title": "callbacks" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }