{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Training modules overview" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks import *" ] }, { "cell_type": "markdown", "metadata": { "hide_input": false }, "source": [ "The fastai library structures its training process around the [`Learner`](/basic_train.html#Learner) class, which binds together a PyTorch model, a dataset, an optimizer, and a loss function; the resulting learner object is what we use to launch training.\n", "\n", "[`basic_train`](/basic_train.html#basic_train) defines this [`Learner`](/basic_train.html#Learner) class, along with the wrapper around the PyTorch optimizer that the library uses. It also defines the basic training loop that is used each time you call the [`fit`](/basic_train.html#fit) method (or one of its variants) in fastai. This training loop is very bare-bones and takes only a few lines of code; you can customize it by supplying an optional [`Callback`](/callback.html#Callback) argument to the [`fit`](/basic_train.html#fit) method.\n", "\n", "[`callback`](/callback.html#callback) defines the [`Callback`](/callback.html#Callback) class and the [`CallbackHandler`](/callback.html#CallbackHandler) class that is responsible for the communication between the training loop and the [`Callback`](/callback.html#Callback)'s methods. The [`CallbackHandler`](/callback.html#CallbackHandler) maintains a state dictionary that gives each [`Callback`](/callback.html#Callback) object access to all the information about the training loop it belongs to, putting any imaginable tweak of the training loop within your reach.\n", "\n", "[`callbacks`](/callbacks.html#callbacks) implements each predefined [`Callback`](/callback.html#Callback) class of the fastai library in a separate module. 
Some modules deal with scheduling the hyperparameters, like [`callbacks.one_cycle`](/callbacks.one_cycle.html#callbacks.one_cycle), [`callbacks.lr_finder`](/callbacks.lr_finder.html#callbacks.lr_finder) and [`callbacks.general_sched`](/callbacks.general_sched.html#callbacks.general_sched). Others allow special kinds of training like [`callbacks.fp16`](/callbacks.fp16.html#callbacks.fp16) (mixed precision) and [`callbacks.rnn`](/callbacks.rnn.html#callbacks.rnn). The [`Recorder`](/basic_train.html#Recorder) and [`callbacks.hooks`](/callbacks.hooks.html#callbacks.hooks) are useful for saving internal data generated in the training loop.\n", "\n", "[`train`](/train.html#train) then uses these callbacks to implement useful helper functions. Lastly, [`metrics`](/metrics.html#metrics) contains all the functions and classes you might want to use to evaluate your training results; simpler metrics are implemented as functions, while more complicated ones are implemented as subclasses of [`Callback`](/callback.html#Callback). For more details on implementing metrics as a [`Callback`](/callback.html#Callback), please refer to [creating your own metrics](/metrics.html#Creating-your-own-metric)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Walk-through of key functionalities\n", "\n", "We'll do a quick overview of the key pieces of fastai's training modules. See the separate module docs for details on each. 
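
Before the walk-through, the callback mechanism described in the overview can be sketched in plain Python. This is a toy illustration under stated assumptions, not fastai's implementation: the classes `Callback`, `LossRecorder` and `CallbackHandler` below, the `fit` function, and the keys of the state dictionary are all hypothetical stand-ins that only mimic the idea of a handler passing shared state to each callback.

```python
# Toy sketch only: none of these classes are fastai's own.
class Callback:
    # Base class: every event is a no-op unless a subclass overrides it.
    def on_train_begin(self, state): pass
    def on_batch_end(self, state): pass
    def on_epoch_end(self, state): pass

class LossRecorder(Callback):
    # Hypothetical callback that stores the loss seen after each batch.
    def on_train_begin(self, state):
        self.losses = []
    def on_batch_end(self, state):
        self.losses.append(state['last_loss'])

class CallbackHandler:
    # Owns the shared state dictionary and forwards events to callbacks.
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.state = {'epoch': 0, 'iteration': 0, 'last_loss': None}
    def __call__(self, event_name, **updates):
        self.state.update(updates)
        for cb in self.callbacks:
            getattr(cb, event_name)(self.state)

def fit(epochs, batches, loss_fn, callbacks=()):
    # Bare-bones loop: all customisation happens through the handler.
    handler = CallbackHandler(callbacks)
    handler('on_train_begin')
    for epoch in range(epochs):
        for batch in batches:
            loss = loss_fn(batch)
            handler('on_batch_end', last_loss=loss,
                    iteration=handler.state['iteration'] + 1)
        handler('on_epoch_end', epoch=epoch + 1)
    return handler.state

recorder = LossRecorder()
final_state = fit(2, [1.0, 2.0, 3.0], lambda b: b * 0.5, callbacks=[recorder])
```

The loop itself never needs to know what a callback does; it only announces events, which is why new behaviour can be added without editing the loop.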
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup\n", "Import required [modules](/index.html#imports) and prepare [data](/basic_data.html#Get-your-data-ready-for-training): " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from fastai.vision import *\n", "\n", "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*`URLs.MNIST_SAMPLE` is a small subset of the classic MNIST dataset containing the images of just 3's and 7's for the purpose of demo and documentation here. Common [`datasets`](/datasets.html#datasets) can be downloaded with [`untar_data`](/datasets.html#untar_data) - which we will use to create an [`ImageDataBunch`](/vision.data.html#ImageDataBunch) object*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic training with [`Learner`](/basic_train.html#Learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can create a minimal CNN using [`simple_cnn`](/layers.html#simple_cnn) (see [`models`](/vision.models.html#vision.models) for details on creating models):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = simple_cnn((3,16,16,2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [`Learner`](/basic_train.html#Learner) class plays a central role in training models; when you create a [`Learner`](/basic_train.html#Learner) you need to specify at the very minimum the [`data`](/vision.data.html#vision.data) and `model` to use." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(data, model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are enough to create a [`Learner`](/basic_train.html#Learner) object and then use it to train a model using its [`fit`](/basic_train.html#fit) method. 
If you have a CUDA-enabled GPU, it will be used automatically. To call the [`fit`](/basic_train.html#fit) method, you have to at least specify how many epochs to train for." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:03

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th></tr>
<tr><td>1</td><td>0.124981</td><td>0.097195</td></tr></table>
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Viewing metrics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To see how our training is going, we can request that it report various kinds of [`metrics`](/metrics.html#metrics) after each epoch. You can pass them to the constructor, or set them later. Note that metrics are always calculated on the validation set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:02

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
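<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>
<tr><td>1</td><td>0.081563</td><td>0.062798</td><td>0.976938</td></tr></table>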
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.metrics=[accuracy]\n", "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Extending training with callbacks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can use [`callback`](/callback.html#callback)s to modify training in almost any way you can imagine. For instance, we've provided a callback to implement Leslie Smith's 1cycle training method." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:02

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
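<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>
<tr><td>1</td><td>0.055955</td><td>0.045469</td><td>0.984298</td></tr></table>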
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "cb = OneCycleScheduler(learn, lr_max=0.01)\n", "learn.fit(1, callbacks=cb)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The [`Recorder`](/basic_train.html#Recorder) callback is automatically added for you, and you can use it to see what happened in your training, e.g.:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Extending [`Learner`](/basic_train.html#Learner) with [`train`](/train.html#train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Many of the callbacks can be used more easily by taking advantage of the [`Learner`](/basic_train.html#Learner) extensions in [`train`](/train.html#train). For instance, instead of creating OneCycleScheduler manually as above, you can simply call [`Learner.fit_one_cycle`](/train.html#fit_one_cycle):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:03

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
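<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>
<tr><td>1</td><td>0.040535</td><td>0.035062</td><td>0.986752</td></tr></table>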
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Applications" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that if you're training a model for one of our supported *applications*, there's a lot of help available to you in the application modules:\n", "\n", "- [`vision`](/vision.html#vision)\n", "- [`text`](/text.html#text)\n", "- [`tabular`](/tabular.html#tabular)\n", "- [`collab`](/collab.html#collab)\n", "\n", "For instance, let's use [`cnn_learner`](/vision.learner.html#cnn_learner) (from [`vision`](/vision.html#vision)) to quickly fine-tune a pre-trained Imagenet model for MNIST (not a very practical approach, of course, since MNIST is handwriting and our model is pre-trained on photos!)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 02:06

\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
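<table><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th></tr>
<tr><td>1</td><td>0.163659</td><td>0.112767</td><td>0.958783</td></tr></table>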
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn = cnn_learner(data, models.resnet18, metrics=accuracy)\n", "learn.fit_one_cycle(1)" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Overview of fastai training modules, including Learner, metrics, and callbacks", "title": "training" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }