{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Rate Finder" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate finder plots the loss against the learning rate for a [`Learner`](/basic_train.html#Learner). The idea is to reduce the guesswork in picking a good starting learning rate.\n", "\n", "**Overview:** \n", "1. Run the finder with `learn.lr_find()`\n", "2. Plot the loss against the learning rate with `learn.recorder.plot()`\n", "3. Pick a learning rate from before the point where the loss diverges, then start training\n", "\n", "**Technical Details:** (first [described](https://arxiv.org/abs/1506.01186) by Leslie Smith) \n", ">Train [`Learner`](/basic_train.html#Learner) over a few iterations. Start with a very low `start_lr` and increase it at each mini-batch until it reaches a very high `end_lr`. [`Recorder`](/basic_train.html#Recorder) will record the loss at each iteration. Plot those losses against the learning rate to find a good value, just before the point where the loss diverges." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Choosing a good learning rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more intuitive explanation, please check out [Sylvain Gugger's post](https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html)." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we run this command to launch the search:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
`lr_find`(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. "
],
"text/plain": [
"class `LRFinder`(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`float`=`1e-07`, `end_lr`:`float`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"Causes `learn` to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations. Training is interrupted if the loss diverges. Weight changes are reverted after the run completes. "
],
"text/plain": [
"`on_train_end`(`kwargs`:`Any`)\n",
"\n",
"Clean up the `learn` model weights disturbed during the LR finder exploration. "
],
"text/plain": [
"`on_batch_end`(`iteration`:`int`, `smooth_loss`:`TensorOrNumber`, `kwargs`:`Any`)\n",
"\n",
"Determine if the loss has run away and we should stop. "
],
"text/plain": [
"`on_train_begin`(`pbar`, `kwargs`:`Any`)\n",
"\n",
"Initialize optimizer and learner hyperparameters. "
],
"text/plain": [
"`on_epoch_end`(`kwargs`:`Any`)\n",
"\n",
"Tell Learner if we need to stop. "
],
"text/plain": [
"