{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning Rate Finder" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.callbacks import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate finder plots the relationship between learning rate and loss for a [`Learner`](/basic_train.html#Learner). The idea is to reduce the guesswork involved in picking a good starting learning rate.\n", "\n", "**Overview:** \n", "1. Run the finder: `learn.lr_find()`\n", "2. Plot the loss against the learning rate: `learn.recorder.plot()`\n", "3. Pick a learning rate below the point where the loss diverges, then start training\n", "\n", "**Technical Details:** (first [described](https://arxiv.org/abs/1506.01186) by Leslie Smith) \n", ">Train the [`Learner`](/basic_train.html#Learner) over a few iterations. Start with a very low `start_lr` and increase it at each mini-batch until it reaches a very high `end_lr`. The [`Recorder`](/basic_train.html#Recorder) will record the loss at each iteration. Plot those losses against the learning rate to find a good value before the loss diverges." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Choosing a good learning rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more intuitive explanation, please check out [Sylvain Gugger's post](https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we run this command to launch the search:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
`lr_find`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n",
"\n",
epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.127434 | 0.070243 | 0.973013 | 00:02
2 | 0.050703 | 0.039493 | 0.984789 | 00:02

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.727221 | 0.693147 | 0.495584 | 00:02
2 | 0.693826 | 0.693147 | 0.495584 | 00:02

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.152897 | 0.134366 | 0.950932 | 00:02
2 | 0.120961 | 0.117550 | 0.960746 | 00:02

epoch | train_loss | valid_loss | accuracy | time
---|---|---|---|---
1 | 0.109475 | 0.081607 | 0.970559 | 00:02
2 | 0.070303 | 0.050977 | 0.982826 | 00:02
class `LRFinder`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`float`=***`1e-07`***, **`end_lr`**:`float`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
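To make the roles of `start_lr`, `end_lr` and `num_it` concrete, here is a minimal sketch of a learning rate sweep in plain Python. It assumes the rates are spaced geometrically (evenly on a log scale) between the two bounds; the helper name `exp_lr_schedule` is hypothetical and is not part of fastai.

```python
def exp_lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Return num_it learning rates spaced geometrically from start_lr to end_lr.

    Hypothetical helper for illustration only; the geometric spacing is an
    assumption about how the sweep moves between the two bounds.
    """
    if num_it < 2:
        raise ValueError("num_it must be at least 2")
    ratio = end_lr / start_lr
    # Iteration i uses start_lr * ratio**(i / (num_it - 1)), so the first
    # rate is start_lr and the last one is end_lr.
    return [start_lr * ratio ** (i / (num_it - 1)) for i in range(num_it)]


lrs = exp_lr_schedule()
print(lrs[0], lrs[len(lrs) // 2], lrs[-1])
```

With the default arguments the sweep covers eight orders of magnitude in 100 mini-batches, which is why the resulting plot is usually read on a logarithmic x-axis.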
`on_train_begin`(**`pbar`**, **\\*\\*`kwargs`**:`Any`)\n",
"\n",
`on_batch_end`(**`iteration`**:`int`, **`smooth_loss`**:`TensorOrNumber`, **\\*\\*`kwargs`**:`Any`)\n",
"\n",
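The `stop_div` flag together with the `smooth_loss` argument suggests that the sweep can end early once the loss blows up. Below is a minimal sketch of such a check in plain Python. The function name `run_until_divergence` and the threshold `factor=4.0` are assumptions made for illustration, not values confirmed by this page.

```python
import math


def run_until_divergence(losses, stop_div=True, factor=4.0):
    """Return how many iterations run before the sweep stops.

    Hypothetical illustration: tracks the best (lowest) loss seen so far and
    stops once the current loss is NaN or exceeds factor * best.
    """
    best = math.inf
    for i, loss in enumerate(losses):
        if loss < best:
            best = loss
        if stop_div and (math.isnan(loss) or loss > factor * best):
            return i + 1  # stop right after the diverging iteration
    return len(losses)


smoothed = [1.0, 0.8, 0.5, 0.6, 2.5, 9.0]
print(run_until_divergence(smoothed))                  # stops early
print(run_until_divergence(smoothed, stop_div=False))  # runs the full sweep
```

Comparing against the best loss seen so far, rather than the previous loss, keeps ordinary batch-to-batch noise from ending the sweep prematurely.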
`on_epoch_end`(**\\*\\*`kwargs`**:`Any`)\n",
"\n",
`on_train_end`(**\\*\\*`kwargs`**:`Any`)\n",
"\n",