{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# The TrainingPhase API" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*This notebook was prepared by Sylvain Gugger - many thanks!*\n", "\n", "Here we show how to use a new API in the fastai library that gives you all the flexibility you might want while training your model.\n", "\n", "All the examples run on cifar10, so be sure to change the path to a directory that contains this dataset, with the usual hierarchy (a train folder and a validation folder, named test in the code below, each containing ten subdirectories, one per class)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.conv_learner import *\n", "PATH = Path(\"../data/cifar10/\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n", "stats = (np.array([ 0.4914 , 0.48216, 0.44653]), np.array([ 0.24703, 0.24349, 0.26159]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function lets us grab data for a given image size and batch size." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_data(sz,bs):\n", "    tfms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)\n", "    return ImageClassifierData.from_paths(PATH, val_name='test', tfms=tfms, bs=bs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size = 32\n", "batch_size = 64" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = get_data(size,batch_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's create a very simple model that we'll train: a neural net with a hidden layer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def SimpleNet(layers):\n", "    list_layers=[Flatten()]\n", "    for i in range(len(layers)-1):\n", "        list_layers.append(nn.Linear(layers[i], layers[i + 1]))\n", "        # ReLU between hidden layers; log-softmax over the class dimension (dim=1) at the end\n", "        if i < len(layers)-2: list_layers.append(nn.ReLU(inplace=True))\n", "        else: list_layers.append(nn.LogSoftmax(dim=1))\n", "    return nn.Sequential(*list_layers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = ConvLearner.from_model_data(SimpleNet([32*32*3, 40,10]), data)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can use our learner object to give examples of training.\n", "\n", "With the new API, you don't use a pre-implemented training schedule; instead, you design your own with objects called TrainingPhase. 
A training phase is just a class that records all the parameters you want to apply during this part of the training loop, specifically:\n", "- a number of epochs for which these settings will be valid (can be a float)\n", "- an optimizer function (SGD, RMSProp, Adam...)\n", "- a learning rate (or array of lrs), or a range of learning rates (or array of lrs) if you want to change the lr\n", "- a learning rate decay method (that will explain how you want to change the lr)\n", "- a momentum (which will be beta1 if you're using Adam), or a range of momentums if you want to change it\n", "- a momentum decay method (that will explain how you want to change the momentum, if applicable)\n", "- optionally a weight decay (or array of wds)\n", "- optionally a beta parameter (which is the RMSProp alpha or the Adam beta2, if you want a value other than the default)\n", "\n", "By combining those blocks as you wish, you can implement pretty much any method of training you could think of." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Basic lr decay" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's begin with something basic and say you want to train with SGD and momentum, with a learning rate of 1e-2 for one epoch, then 1e-3 for two epochs. We'll just create a list of two phases for this." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "phases = [TrainingPhase(epochs=1, opt_fn=optim.SGD, lr=1e-2), TrainingPhase(epochs=2, opt_fn=optim.SGD, lr=1e-3)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we didn't set the momentum parameter because it defaults to 0.9. If you don't want any momentum, you'll have to set it to 0.\n", "\n", "Now that we have created this list of phases, we just have to call fit_opt_sched." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "42579b93c5cf45ada3112d3406e991db", "version_major": 2, "version_minor": 0 }, "text/html": [ "
\n" ], "text/plain": [ "HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "epoch trn_loss val_loss accuracy \n", " 0 3.406514 4.553418 0.1387 \n", " 1 3.396738 4.559127 0.1383 \n", " 2 3.421531 4.541069 0.138 \n" ] }, { "data": { "text/plain": [ "[array([4.54107]), 0.138]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.fit_opt_sched(phases)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we want to see what we did, we can use learn.sched.plot_lr()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "