{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Rate Finder" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate finder plots the loss against the learning rate for a [`Learner`](/basic_train.html#Learner). The idea is to reduce the guesswork in picking a good starting learning rate.\n", "\n", "**Overview:** \n", "1. Run the finder: `learn.lr_find()`\n", "2. Plot the loss against the learning rate: `learn.recorder.plot()`\n", "3. Pick a learning rate from the region before the loss diverges, then start training\n", "\n", "**Technical Details:** (first [described](https://arxiv.org/abs/1506.01186) by Leslie Smith) \n", ">Train the [`Learner`](/basic_train.html#Learner) over a few iterations. Start with a very low `start_lr` and increase it at each mini-batch until it reaches a very high `end_lr`. The [`Recorder`](/basic_train.html#Recorder) records the loss at each iteration. Plot those losses against the learning rate to find a good value before the loss diverges."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Choosing a good learning rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more intuitive explanation, please check out [Sylvain Gugger's post](https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we run this command to launch the search:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
`lr_find`(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n",
"\n",
"Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. "
],
"text/plain": [
"