{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from exp.nb_04 import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initial setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=7013)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_train,y_train,x_valid,y_valid = get_data()\n", "train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)\n", "nh,bs = 50,512\n", "c = y_train.max().item()+1\n", "loss_func = F.cross_entropy" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def create_learner(model_func, loss_func, data):\n", "    return Learner(*model_func(data), loss_func, data)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train: [0.6664653125, tensor(0.8075)]\n", "valid: [0.302250390625, tensor(0.9146)]\n", "train: [0.291615546875, tensor(0.9162)]\n", "valid: [0.2376760986328125, tensor(0.9324)]\n", "train: [0.23417873046875, tensor(0.9323)]\n", "valid: [0.2117640869140625, tensor(0.9397)]\n" ] } ], "source": [ "learn = create_learner(get_model, loss_func, data)\n", "run = Runner([AvgStatsCallback([accuracy])])\n", "\n", "run.fit(3, learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train: [0.762880078125, tensor(0.7988)]\n", "valid: [0.36490234375, tensor(0.8969)]\n", "train: 
[0.3508654296875, tensor(0.9002)]\n", "valid: [0.30942646484375, tensor(0.9107)]\n", "train: [0.30202353515625, tensor(0.9126)]\n", "valid: [0.26613701171875, tensor(0.9218)]\n" ] } ], "source": [ "learn = create_learner(partial(get_model, lr=0.3), loss_func, data)\n", "run = Runner([AvgStatsCallback([accuracy])])\n", "\n", "run.fit(3, learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def get_model_func(lr=0.5): return partial(get_model, lr=lr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Annealing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define two new callbacks: the `Recorder` to keep track of the loss and our scheduled learning rate, and a `ParamScheduler` that can schedule any hyperparameter as long as it's registered in the state_dict of the optimizer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=7202)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Recorder(Callback):\n", "    def begin_fit(self): self.lrs,self.losses = [],[]\n", "\n", "    def after_batch(self):\n", "        if not self.in_train: return\n", "        self.lrs.append(self.opt.param_groups[-1]['lr'])\n", "        self.losses.append(self.loss.detach().cpu())\n", "\n", "    def plot_lr (self): plt.plot(self.lrs)\n", "    def plot_loss(self): plt.plot(self.losses)\n", "\n", "class ParamScheduler(Callback):\n", "    _order=1\n", "    def __init__(self, pname, sched_func): self.pname,self.sched_func = pname,sched_func\n", "\n", "    def set_param(self):\n", "        for pg in self.opt.param_groups:\n", "            pg[self.pname] = self.sched_func(self.n_epochs/self.epochs)\n", "\n", "    def begin_batch(self):\n", "        if self.in_train: self.set_param()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's start with a simple linear schedule going from start to end. 
It returns a function that takes a `pos` argument (going from 0 to 1) such that this function goes from `start` (at `pos=0`) to `end` (at `pos=1`) in a linear fashion." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=7431)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sched_lin(start, end):\n", "    def _inner(start, end, pos): return start + pos*(end-start)\n", "    return partial(_inner, start, end)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can refactor this with a decorator." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=7526)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def annealer(f):\n", "    def _inner(start, end): return partial(f, start, end)\n", "    return _inner\n", "\n", "@annealer\n", "def sched_lin(start, end, pos): return start + pos*(end-start)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# shift-tab works too, in Jupyter!\n", "# sched_lin()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.3" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "f = sched_lin(1,2)\n", "f(0.3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here are other scheduler functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@annealer\n", "def sched_cos(start, end, pos): return start + (1 + math.cos(math.pi*(1-pos))) * (end-start) / 2\n", "@annealer\n", "def sched_no(start, end, pos): return start\n", "@annealer\n", "def sched_exp(start, end, pos): return start * (end/start) ** pos\n", "\n", "def cos_1cycle_anneal(start, high, end):\n", 
"    return [sched_cos(start, high), sched_cos(high, end)]\n", "\n", "# This monkey-patch lets matplotlib plot tensors\n", "torch.Tensor.ndim = property(lambda x: len(x.shape))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Jump_to lesson 9 video](https://course19.fast.ai/videos/?lesson=9&t=7730)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "annealings = \"NO LINEAR COS EXP\".split()\n", "\n", "a = torch.arange(0, 100)\n", "p = torch.linspace(0.01,1,100)\n", "\n", "fns = [sched_no, sched_lin, sched_cos, sched_exp]\n", "for fn, t in zip(fns, annealings):\n", "    f = fn(2, 1e-2)\n", "    plt.plot(a, [f(o) for o in p], label=t)\n", "plt.legend();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In practice, we'll often want to combine different schedulers. The following function does that: it uses `scheds[i]` for `pcts[i]` of the training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def combine_scheds(pcts, scheds):\n", "    assert sum(pcts) == 1.\n", "    pcts = tensor([0] + listify(pcts))\n", "    assert torch.all(pcts >= 0)\n", "    pcts = torch.cumsum(pcts, 0)\n", "    def _inner(pos):\n", "        idx = (pos >= pcts).nonzero().max()\n", "        idx = min(int(idx), len(scheds)-1)  # at pos == 1, stay on the last schedule\n", "        actual_pos = (pos-pcts[idx]) / (pcts[idx+1]-pcts[idx])\n", "        return scheds[idx](actual_pos)\n", "    return _inner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example: use 30% of the budget to go from 0.3 to 0.6 following a cosine, then the last 70% of the budget to go from 0.6 to 0.2, still following a cosine." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sched = combine_scheds([0.3, 0.7], [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(a, [sched(o) for o in p])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use it for training quite easily..." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cbfs = [Recorder,\n", " partial(AvgStatsCallback,accuracy),\n", " partial(ParamScheduler, 'lr', sched)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = create_learner(get_model_func(0.3), loss_func, data)\n", "run = Runner(cb_funcs=cbfs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "train: [0.788337578125, tensor(0.7860)]\n", "valid: [0.3378462158203125, tensor(0.8999)]\n", "train: [0.29788216796875, tensor(0.9127)]\n", "valid: [0.2410066162109375, tensor(0.9319)]\n", "train: [0.23813974609375, tensor(0.9312)]\n", "valid: [0.2132032958984375, tensor(0.9400)]\n" ] } ], "source": [ "run.fit(3, learn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "... then check with our recorder if the learning rate followed the right schedule." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "run.recorder.plot_lr()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "run.recorder.plot_loss()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 05_anneal.ipynb to nb_05.py\r\n" ] } ], "source": [ "!./notebook2script.py 05_anneal.ipynb" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }