{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## TrainingPhase and General scheduler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creates a scheduler that lets you train a model by following a sequence of [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase) objects, each with its own learning rate and momentum schedule." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.general_sched import * \n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TrainingPhase[source]

\n", "\n", "> TrainingPhase(**`length`**:`int`, **`lrs`**:`Floats`, **`moms`**:`Floats`, **`lr_anneal`**:`AnnealFunc`=***`None`***, **`mom_anneal`**:`AnnealFunc`=***`None`***)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainingPhase, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create a phase for training a model for `length` iterations, following the schedule given by `lrs` and `lr_anneal` for the learning rate and by `moms` and `mom_anneal` for the momentum. More specifically, the phase makes the learning rate (or momentum) vary from the first value of `lrs` (or `moms`) to the second, following `lr_anneal` (or `mom_anneal`). If an annealing function is specified but `lrs` (or `moms`) is a float, the value decays to 0. If no annealing function is specified, the default is linear annealing if `lrs` (or `moms`) is a tuple, and a constant value if it's a float." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Note: If you want to use discriminative learning rates, you can pass a numpy array of learning rates (or a tuple\n", "of them for start and stop).
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_note(\"\"\"If you want to use discriminative learning rates, you can pass a numpy array of learning rates (or a tuple\n", "of them for start and stop).\"\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As an example, let's use this to implement [SGD with warm restarts](https://arxiv.org/abs/1608.03983)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):\n", "    n = len(learn.data.train_dl)  # number of iterations (batches) per epoch\n", "    # One phase per cycle; cycle i lasts cycle_len * cycle_mult**i epochs and anneals lr to 0 with a cosine\n", "    phases = [TrainingPhase(n * (cycle_len * cycle_mult**i), lr, mom, lr_anneal=annealing_cos) for i in range(n_cycles)]\n", "    sched = GeneralScheduler(learn, phases)\n", "    learn.callbacks.append(sched)\n", "    # Total epochs is the sum of the geometric series of cycle lengths\n", "    if cycle_mult != 1:\n", "        total_epochs = int(cycle_len * (1 - (cycle_mult)**n_cycles)/(1-cycle_mult))\n", "    else: total_epochs = n_cycles * cycle_len\n", "    learn.fit(total_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "Total time: 00:16

\n", "
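| epoch | train_loss | valid_loss | accuracy |
|-------|------------|------------|----------|
| 1 | 0.185262 | 0.164344 | 0.945044 |
| 2 | 0.140157 | 0.129574 | 0.954367 |
| 3 | 0.124761 | 0.123591 | 0.958292 |
| 4 | 0.109466 | 0.107876 | 0.964671 |
| 5 | 0.099668 | 0.091696 | 0.966143 |
| 6 | 0.087345 | 0.085187 | 0.970069 |
| 7 | 0.085803 | 0.084836 | 0.971050 |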
\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class GeneralScheduler[source]

\n", "\n", "> GeneralScheduler(**`learn`**:[`Learner`](/basic_train.html#Learner), **`phases`**) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "Schedule multiple [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase) for a [`Learner`](/basic_train.html#Learner). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end[source]

\n", "\n", "> on_batch_end(**`train`**, **\\*\\*`kwargs`**:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_batch_end, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Takes a step in the current phase and prepares the hyperparameters for the next batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source]

\n", "\n", "> on_train_begin(**`n_epochs`**:`int`, **\\*\\*`kwargs`**:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_train_begin, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializes the hyperparameters to the start values of the first phase." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implementation of a flexible training API", "title": "callbacks.general_sched" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }