{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Learning Rate Finder" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The learning rate finder plots the relationship between learning rate and loss for a [`Learner`](/basic_train.html#Learner). The idea is to reduce the guesswork in picking a good starting learning rate.\n", "\n", "**Overview:** \n", "1. Run the finder: `learn.lr_find()`\n", "2. Plot the loss against the learning rate: `learn.recorder.plot()`\n", "3. Pick a learning rate from before the point where the loss diverges, then start training\n", "\n", "**Technical Details:** (first [described](https://arxiv.org/abs/1506.01186) by Leslie Smith) \n", ">Train the [`Learner`](/basic_train.html#Learner) over a few iterations. Start with a very low `start_lr` and increase it at each mini-batch until it reaches a very high `end_lr`. The [`Recorder`](/basic_train.html#Recorder) will record the loss at each iteration. Plot those losses against the learning rate to find a good value before the loss diverges." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Choosing a good learning rate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For a more intuitive explanation, please check out [Sylvain Gugger's post](https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "def simple_learner(): return Learner(data, simple_cnn((3,16,16,2)), metrics=[accuracy])\n", "learn = simple_learner()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we run this command to launch the search:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`lr_find`

\n", "\n", "> lr_find(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`Floats`=`1e-07`, `end_lr`:`Floats`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`, `kwargs`:`Any`)\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss explodes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.lr_find)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LR Finder complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find(stop_div=False, num_it=200)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we plot the loss versus the learning rates. We're interested in finding a good order of magnitude of learning rate, so we plot with a log scale. " ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, we choose a value that is approximately in the middle of the sharpest downward slope. In this case, training with 3e-2 looks like it should work well:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:03\n", "epoch train_loss valid_loss accuracy\n", "1 0.070224 0.039051 0.986752 (00:01)\n", "2 0.038105 0.043696 0.985280 (00:01)\n", "\n" ] } ], "source": [ "simple_learner().fit(2, 3e-2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Don't just pick the minimum value from the plot!:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:03\n", "epoch train_loss valid_loss accuracy\n", "1 0.724437 0.693147 0.495584 (00:01)\n", "2 0.693758 0.693147 0.495584 (00:01)\n", "\n" ] } ], "source": [ "learn = simple_learner()\n", "simple_learner().fit(2, 1e-0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Picking a value before the downward slope results in slow training:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total time: 00:03\n", "epoch train_loss valid_loss accuracy\n", "1 0.184354 0.168152 0.940137 (00:01)\n", "2 0.146272 0.143661 0.946516 (00:01)\n", "\n" ] } ], "source": [ "learn = simple_learner()\n", "simple_learner().fit(2, 1e-3)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class `LRFinder`

\n", "\n", "> LRFinder(`learn`:[`Learner`](/basic_train.html#Learner), `start_lr`:`float`=`1e-07`, `end_lr`:`float`=`10`, `num_it`:`int`=`100`, `stop_div`:`bool`=`True`) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "Runs a mock training of `learn` from `start_lr` to `end_lr` for `num_it` iterations. Training is interrupted if the loss diverges. Weight changes are reverted once the run completes. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LRFinder)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_train_end`

\n", "\n", "> on_train_end(`kwargs`:`Any`)\n", "\n", "Clean up the `learn` model weights disturbed during the `LRFinder` exploration. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LRFinder.on_train_end)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_batch_end`

\n", "\n", "> on_batch_end(`iteration`:`int`, `smooth_loss`:`TensorOrNumber`, `kwargs`:`Any`)\n", "\n", "Determine whether the loss has run away and we should stop. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LRFinder.on_batch_end)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_train_begin`

\n", "\n", "> on_train_begin(`pbar`, `kwargs`:`Any`)\n", "\n", "Initialize optimizer and learner hyperparameters. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LRFinder.on_train_begin)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

`on_epoch_end`

\n", "\n", "> on_epoch_end(`kwargs`:`Any`)\n", "\n", "Tell Learner if we need to stop. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LRFinder.on_epoch_end)" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Implementation of the LR Range test from Leslie Smith", "title": "callbacks.lr_finder" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }