{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# TrainingPhase and General scheduler" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Creates a scheduler that lets you train a model by following a sequence of different [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase) objects, each with its own learning rate and momentum schedule." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.general_sched import * \n", "from fastai import *\n", "from fastai.vision import *" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "## class `TrainingPhase`\n", "\n", "> TrainingPhase(`length`:`int`, `lrs`:`Floats`, `moms`:`Floats`, `lr_anneal`:`AnnealFunc`=`None`, `mom_anneal`:`AnnealFunc`=`None`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainingPhase, doc_string=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Creates a phase for training a model for `length` iterations, following the schedule given by `lrs` and `lr_anneal` for the learning rate, and by `moms` and `mom_anneal` for the momentum. More specifically, the phase makes the learning rate (or momentum) vary from the first value of `lrs` (or `moms`) to the second, following `lr_anneal` (or `mom_anneal`). If an annealing function is specified but `lrs` (or `moms`) is a single float, the parameter decays from that value to 0. If no annealing function is specified, the default is linear annealing when `lrs` (or `moms`) is a tuple, and a constant value when it is a float." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "## class `GeneralScheduler`\n", "\n", "> GeneralScheduler(`learn`:[`Learner`](/basic_train.html#Learner), `phases`:`Collection`\\[[`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase)\\]) :: [`Callback`](/callback.html#Callback)\n", "\n", "Schedule multiple [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase) for a [`Learner`](/basic_train.html#Learner). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "### `on_batch_end`\n", "\n", "> on_batch_end(`train`, `kwargs`:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_batch_end, doc_string=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Takes a step in the current phase and prepares the hyperparameters for the next batch." ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "### `on_train_begin`\n", "\n", "> on_train_begin(`n_epochs`:`int`, `kwargs`:`Any`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_train_begin, doc_string=False)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Initializes the hyperparameters to the start values of the first phase." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "As an example, let's use this to implement [SGD with warm restarts](https://arxiv.org/abs/1608.03983): each cycle is one phase in which the learning rate is cosine-annealed from `lr` down to 0 before restarting, and each new cycle lasts `cycle_mult` times longer than the previous one." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):\n", "    n = len(learn.data.train_dl)  # number of batches (iterations) per epoch\n", "    # One phase per cycle: cycle i lasts cycle_len * cycle_mult**i epochs, with the learning rate\n", "    # cosine-annealed from lr to 0 and the momentum kept constant at mom\n", "    phases = [TrainingPhase(n * (cycle_len * cycle_mult**i), lr, mom, lr_anneal=annealing_cos) for i in range(n_cycles)]\n", "    sched = GeneralScheduler(learn, phases)\n", "    learn.callbacks.append(sched)\n", "    # Total number of epochs: cycle_len * (1 + cycle_mult + ... + cycle_mult**(n_cycles-1))\n", "    if cycle_mult != 1:\n", "        total_epochs = int(cycle_len * (1 - cycle_mult**n_cycles) / (1 - cycle_mult))\n", "    else:\n", "        total_epochs = n_cycles * cycle_len\n", "    learn.fit(total_epochs)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(HBox(children=(IntProgress(value=0, max=7), HTML(value='0.00% [0/7 00:00<00:00]'))), HTML(value…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:16\n", "epoch train loss valid loss\n", "0 0.203685 0.176289 (00:02)\n", "1 0.139156 0.147694 (00:02)\n", "2 0.132314 0.131610 (00:02)\n", "3 0.118946 0.118343 (00:02)\n", "4 0.116849 0.105648 (00:02)\n", "5 0.105146 0.105442 (00:02)\n", "6 0.099159 0.102690 (00:02)\n", "\n" ] } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = Learner(data, simple_cnn((3,16,16,2)))\n", "# 3 cycles with lr=1e-3 and mom=0.9; the first cycle lasts 1 epoch and each following one is twice as long (1 + 2 + 4 = 7 epochs)\n", 
"fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implementation of a flexible training API", "title": "callbacks.general_sched" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }