{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## TrainingPhase and General scheduler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Creates a scheduler that lets you train a model through a sequence of training phases, each described by a [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.callbacks.general_sched import * \n", "from fastai.vision import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class TrainingPhase

\n", "\n", "> TrainingPhase(**`length`**:`int`)\n", "\n", "

\n", "\n", "Schedule hyper-parameters for a phase of `length` iterations. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainingPhase)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can then schedule any hyper-parameter you want by using the following method." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

schedule_hp

\n", "\n", "> schedule_hp(**`name`**, **`vals`**, **`anneal`**=***`None`***)\n", "\n", "

\n", "\n", "Adds a schedule for `name` between `vals` using `anneal`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainingPhase.schedule_hp)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The phase makes the hyper-parameter vary from the first value in `vals` to the second, following `anneal`. If an annealing function is specified but `vals` is a single float, the hyper-parameter decays from that value to 0. If no annealing function is specified, the default is linear annealing when `vals` is a tuple and a constant value when it is a float. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Note: If you want to use discriminative values, you can pass a NumPy array in `vals` (or a tuple\n", "of two arrays for the start and end values).
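The rules for `vals` and `anneal` can be made concrete with a small pure-Python sketch. It is not fastai's implementation: the helper names (`anneal_no`, `anneal_linear`, `anneal_cos`, `make_schedule`) are hypothetical, and it only handles scalar values, not the NumPy arrays used for discriminative values. Following the text, a float `vals` is treated as start at that value and end at 0, the default is linear annealing for a tuple and a constant value for a float, and the cosine curve is a standard half-cosine shape, assumed here to resemble what `annealing_cos` does.

```python
import math

# Hypothetical annealing helpers: each maps (start, end, pct) to a value, with pct in [0, 1].
def anneal_no(start, end, pct):
    # Constant schedule: always return the start value
    return start

def anneal_linear(start, end, pct):
    # Straight line from start (pct=0) to end (pct=1)
    return start + pct * (end - start)

def anneal_cos(start, end, pct):
    # Half cosine wave from start (pct=0) to end (pct=1)
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def make_schedule(vals, n_iter, anneal=None):
    # Hypothetical sketch of the vals/anneal rules described above
    is_pair = isinstance(vals, tuple)
    start, end = vals if is_pair else (vals, 0.0)  # a float decays to 0 if annealed
    if anneal is None:
        anneal = anneal_linear if is_pair else anneal_no
    # Sample the schedule at pct = 0, 1/n_iter, ..., 1
    return [anneal(start, end, i / n_iter) for i in range(n_iter + 1)]

print(make_schedule((1e-3, 1e-5), 4))      # linear from 1e-3 to 1e-5
print(make_schedule(0.9, 4))               # constant: [0.9, 0.9, 0.9, 0.9, 0.9]
print(make_schedule(1e-3, 4, anneal_cos))  # cosine decay from 1e-3 to 0.0
```

During training the phase produces these values one batch at a time over its `length` iterations; returning a list here is only a convenience for inspecting the shape of each schedule.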
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_note(\"\"\"If you want to use discriminative values, you can pass an numpy array in `vals` (or a tuple\n", "of them for start and stop).\"\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The basic hyper-parameters are named:\n", "- 'lr' for learning rate\n", "- 'mom' for momentum (or beta1 in Adam)\n", "- 'beta' for the beta2 in Adam or the alpha in RMSprop\n", "- 'wd' for weight decay\n", "\n", "You can also add any hyper-parameter that is in your optimizer (even if it's custom or a [`GeneralOptimizer`](/general_optimizer.html#GeneralOptimizer)), like 'eps' if you're using Adam. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's make an example by using this to code [SGD with warm restarts](https://arxiv.org/abs/1608.03983)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def fit_sgd_warm(learn, n_cycles, lr, mom, cycle_len, cycle_mult):\n", " n = len(learn.data.train_dl)\n", " phases = [(TrainingPhase(n * (cycle_len * cycle_mult**i))\n", " .schedule_hp('lr', lr, anneal=annealing_cos)\n", " .schedule_hp('mom', mom)) for i in range(n_cycles)]\n", " sched = GeneralScheduler(learn, phases)\n", " learn.callbacks.append(sched)\n", " if cycle_mult != 1:\n", " total_epochs = int(cycle_len * (1 - (cycle_mult)**n_cycles)/(1-cycle_mult)) \n", " else: total_epochs = n_cycles * cycle_len\n", " learn.fit(total_epochs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
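| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.162146 | 0.153532 | 0.942100 | 00:02 |
| 1 | 0.126112 | 0.117267 | 0.960255 | 00:02 |
| 2 | 0.112045 | 0.110586 | 0.962218 | 00:02 |
| 3 | 0.097603 | 0.090838 | 0.967615 | 00:02 |
| 4 | 0.086883 | 0.081375 | 0.973013 | 00:02 |
| 5 | 0.083673 | 0.076160 | 0.973994 | 00:02 |
| 6 | 0.084835 | 0.076211 | 0.973994 | 00:02 |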
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = Learner(data, simple_cnn((3,16,16,2)), metrics=accuracy)\n", "fit_sgd_warm(learn, 3, 1e-3, 0.9, 1, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
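The `total_epochs` expression in `fit_sgd_warm` above is the closed form of a geometric series: cycle `i` lasts `cycle_len * cycle_mult**i` epochs, so `n_cycles` cycles last `cycle_len * (1 - cycle_mult**n_cycles) / (1 - cycle_mult)` epochs in total. The standalone check below repeats that arithmetic in plain Python; the function names are hypothetical and no fastai import is needed. With the arguments used in the example (3 cycles, `cycle_len=1`, `cycle_mult=2`) the cycles last 1, 2 and 4 epochs, which is why the training table has 7 rows (epochs 0 to 6).

```python
# Hypothetical helpers reproducing the epoch arithmetic of fit_sgd_warm.
def cycle_epochs(n_cycles, cycle_len, cycle_mult):
    # Length of each warm-restart cycle, in epochs
    return [cycle_len * cycle_mult**i for i in range(n_cycles)]

def total_epochs(n_cycles, cycle_len, cycle_mult):
    # Same closed form as fit_sgd_warm; int() truncates any fractional part
    if cycle_mult != 1:
        return int(cycle_len * (1 - cycle_mult**n_cycles) / (1 - cycle_mult))
    return n_cycles * cycle_len

cycles = cycle_epochs(3, 1, 2)
print(cycles, sum(cycles), total_epochs(3, 1, 2))             # [1, 2, 4] 7 7
print(sum(cycle_epochs(3, 1, 1.5)), total_epochs(3, 1, 1.5))  # 4.75 4
```

The second line shows a side effect of the truncation: with a non-integer `cycle_mult` such as 1.5 the cycles add up to 4.75 epochs, but `fit` would only be called for 4, so the last cycle would be cut short.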

class GeneralScheduler

\n", "\n", "> GeneralScheduler(**`learn`**:[`Learner`](/basic_train.html#Learner), **`phases`**:`Collection`\\[[`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase)\\], **`start_epoch`**:`int`=***`None`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "

\n", "\n", "Schedule multiple [`TrainingPhase`](/callbacks.general_sched.html#TrainingPhase) for a [`Learner`](/basic_train.html#Learner). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_end

\n", "\n", "> on_batch_end(**`train`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_batch_end, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Takes a step in the current phase and prepares the hyper-parameters for the next batch." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin

\n", "\n", "> on_train_begin(**`epoch`**:`int`, **\\*\\*`kwargs`**:`Any`)\n", "\n", "

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GeneralScheduler.on_train_begin, doc_string=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializes the hyper-parameters to the start values of the first phase. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implementation of a flexible training API", "title": "callbacks.general_sched" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.2" } }, "nbformat": 4, "nbformat_minor": 2 }