{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#skip\n", "! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai.torch_basics import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Optimizers\n", "\n", "> Define the general fastai optimizer and the variants" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `_BaseOptimizer` -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _BaseOptimizer():\n", " \"Common functionality between `Optimizer` and `OptimWrapper`\"\n", " def all_params(self, n=slice(None), with_grad=False):\n", " res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_lists[n],self.hypers[n]) for p in pg)\n", " return L(o for o in res if hasattr(o[0], 'grad') and o[0].grad is not None) if with_grad else res\n", "\n", " def _set_require_grad(self, rg, p,pg,state,h): p.requires_grad_(rg or state.get('force_train', False))\n", " def freeze_to(self, n):\n", " self.frozen_idx = n if n >= 0 else len(self.param_lists) + n\n", " if self.frozen_idx >= len(self.param_lists):\n", " warn(f\"Freezing {self.frozen_idx} groups; model has {len(self.param_lists)}; whole model is frozen.\")\n", " for o in self.all_params(slice(n, None)): self._set_require_grad(True, *o)\n", " for o in self.all_params(slice(None, n)): self._set_require_grad(False, *o)\n", "\n", " def freeze(self):\n", " assert(len(self.param_lists)>1)\n", " self.freeze_to(-1)\n", "\n", " def set_freeze(self, n, rg, ignore_force_train=False):\n", " for p in self.param_lists[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))\n", "\n", " def unfreeze(self): self.freeze_to(0)\n", " def set_hypers(self, **kwargs): L(kwargs.items()).starmap(self.set_hyper)\n", " def _set_hyper(self, k, v):\n", " for v_,h in zip(v, self.hypers): h[k] = v_\n", "\n", " def set_hyper(self, k, v):\n", " if isinstance(v, slice):\n", " if v.start: v = even_mults(v.start, v.stop, len(self.param_lists))\n", " else: v = [v.stop/10]*(len(self.param_lists)-1) + [v.stop]\n", " v = L(v, use_list=None)\n", " if len(v)==1: v = v*len(self.param_lists)\n", " assert len(v) == len(self.hypers), f\"Trying to set {len(v)} values for {k} but there are {len(self.param_lists)} parameter groups.\"\n", " self._set_hyper(k, v)\n", "\n", " @property\n", " def param_groups(self): return [{**{'params': pg}, **hp} for pg,hp in zip(self.param_lists, self.hypers)]\n", " @param_groups.setter\n", " def param_groups(self, v):\n", " for pg,v_ in zip(self.param_lists,v): pg = v_['params']\n", " for hyper,v_ in zip(self.hypers,v):\n", " for k,t in v_.items():\n", " if k != 'params': hyper[k] = t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "add_docs(_BaseOptimizer, \n", " all_params=\"List of param_groups, parameters, and hypers\",\n", " freeze_to=\"Freeze parameter groups up to `n`\",\n", " freeze=\"Freeze up to last parameter group\",\n", " set_freeze=\"Set `rg` for parameter group `n` only\",\n", " unfreeze=\"Unfreeze the entire 
model\",\n", " set_hypers=\"`set_hyper` for all `kwargs`\",\n", " set_hyper=\"Set the value(s) in `v` for hyper-parameter `k`\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _update(state, new=None):\n", " if new is None: return state\n", " if isinstance(new, dict): state.update(new)\n", " return state" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `Optimizer` -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Optimizer(_BaseOptimizer):\n", " \"Base optimizer class for the fastai library, updating `params` with `cbs`\"\n", " _keep_on_clear = ['force_train', 'do_wd']\n", " def __init__(self, params, cbs, train_bn=True, **defaults):\n", " params = L(params)\n", " self.cbs,self.state,self.train_bn = L(cbs),defaultdict(dict),train_bn\n", " defaults = merge(*self.cbs.attrgot('defaults'), defaults)\n", " self.param_lists = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])\n", " self.hypers = L({} for _ in range_of(self.param_lists))\n", " self.set_hypers(**defaults)\n", " self.frozen_idx = 0\n", "\n", " def zero_grad(self):\n", " for p,*_ in self.all_params(with_grad=True):\n", " p.grad.detach_()\n", " p.grad.zero_()\n", "\n", " def step(self):\n", " for p,pg,state,hyper in self.all_params(with_grad=True):\n", " for cb in self.cbs: state = _update(state, cb(p, **{**state, **hyper}))\n", " self.state[p] = state\n", "\n", " def clear_state(self):\n", " for p,pg,state,hyper in self.all_params():\n", " self.state[p] = {k: state[k] for k in self._keep_on_clear if k in state}\n", "\n", " def state_dict(self):\n", " state = [self.state[p] for p,*_ in self.all_params()]\n", " return {'state': state, 'hypers': self.hypers}\n", "\n", " def load_state_dict(self, sd):\n", " assert len(sd[\"hypers\"]) == len(self.param_lists)\n", " assert len(sd[\"state\"]) == sum([len(pg) for pg in self.param_lists])\n", " self.hypers = sd['hypers']\n", " self.state = {p: s for p,s in zip(self.all_params().itemgot(0), sd['state'])}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "add_docs(Optimizer, \n", " zero_grad=\"Standard PyTorch API: Zero all the grad attributes of the parameters\",\n", " step=\"Standard PyTorch API: Update the stats and execute the steppers in on all parameters that have a grad\",\n", " state_dict=\"Return the state of the optimizer in a dictionary\",\n", " load_state_dict=\"Load the content of `sd`\",\n", " clear_state=\"Reset the state of the optimizer\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Initializing an Optimizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`params` will be used to create the `param_groups` of the optimizer. If it's a collection (or a generator) of parameters, it will be a `L` containing one `L` with all the parameters. To define multiple parameter groups `params` should be passed as a collection (or a generator) of `L`s.\n", "\n", "> Note: In PyTorch, model.parameters() returns a generator with all the parameters, that you can directly pass to Optimizer." 
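,
    "\n",
    "\n",
    "As a minimal sketch (the two-layer model here is purely hypothetical, and `noop` is used as the stepper just like in the tests below): passing `model.parameters()` gives a single parameter group, while passing a list of parameter lists gives one group per list.\n",
    "```python\n",
    "model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))   # hypothetical model\n",
    "opt_single = Optimizer(model.parameters(), noop)           # one parameter group\n",
    "groups = [list(model[0].parameters()), list(model[1].parameters())]\n",
    "opt_groups = Optimizer(groups, noop)                       # two parameter groups\n",
    "len(opt_single.param_lists), len(opt_groups.param_lists)   # (1, 2)\n",
    "```"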
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt = Optimizer([1,2,3], noop)\n", "test_eq(opt.param_lists, [[1,2,3]])\n", "opt = Optimizer(range(3), noop)\n", "test_eq(opt.param_lists, [[0,1,2]])\n", "opt = Optimizer([[1,2],[3]], noop)\n", "test_eq(opt.param_lists, [[1,2],[3]])\n", "opt = Optimizer(([o,o+1] for o in range(0,4,2)), noop)\n", "test_eq(opt.param_lists, [[0,1],[2,3]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`cbs` is a list of functions that will be composed when applying the step. For instance, you can compose a function making the SGD step, with another one applying weight decay. Additionally, each `cb` can have a `defaults` attribute that contains hyper-parameters and their default value. Those are all gathered at initialization, and new values can be passed to override those defaults with the `defaults` kwargs. The steppers will be called by `Optimizer.step` (which is the standard PyTorch name), and gradients can be cleared with `Optimizer.zero_grad` (also a standard PyTorch name).\n", "\n", "Once the defaults have all been pulled off, they are copied as many times as there are `param_groups` and stored in `hypers`. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers for instance), you will need to adjust those values after the init. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tst_arg(p, lr=0, **kwargs): return p\n", "tst_arg.defaults = dict(lr=1e-2)\n", "\n", "def tst_arg2(p, lr2=0, **kwargs): return p\n", "tst_arg2.defaults = dict(lr2=1e-3)\n", "\n", "def tst_arg3(p, mom=0, **kwargs): return p\n", "tst_arg3.defaults = dict(mom=0.9)\n", "\n", "def tst_arg4(p, **kwargs): return p\n", "\n", "opt = Optimizer([1,2,3], [tst_arg,tst_arg2, tst_arg3])\n", "test_eq(opt.hypers, [{'lr2': 1e-3, 'mom': 0.9, 'lr': 1e-2}])\n", "opt = Optimizer([1,2,3], tst_arg, lr=0.1)\n", "test_eq(opt.hypers, [{'lr': 0.1}])\n", "opt = Optimizer([[1,2],[3]], tst_arg)\n", "test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])\n", "opt = Optimizer([[1,2],[3]], tst_arg, lr=0.1)\n", "test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For each hyper-parameter, you can pass a slice or a collection to set them, if there are multiple parameter groups. A slice will be converted to a log-uniform collection from its beginning to its end, or if it only has an end `e`, to a collection of as many values as there are parameter groups that are `...,e/10,e/10,e`.\n", "\n", "Setting an hyper-parameter with a collection that has a different number of elements than the optimizer has parameter groups will raise an error." 
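,
    "\n",
    "\n",
    "For instance, with three parameter groups, `lr=slice(1e-4,1e-2)` produces learning rates spaced evenly on a log scale, which is what `even_mults` computes; `np.geomspace` is shown here only as an equivalent illustration:\n",
    "```python\n",
    "import numpy as np\n",
    "np.geomspace(1e-4, 1e-2, 3)   # array([1.e-04, 1.e-03, 1.e-02]) -> one value per group\n",
    "```"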
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "opt = Optimizer([[1,2],[3]], tst_arg, lr=[0.1,0.2])\n", "test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.2}])\n", "opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-2))\n", "test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-3}, {'lr': 1e-2}])\n", "opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-4,1e-2))\n", "test_eq(opt.hypers, [{'lr': 1e-4}, {'lr': 1e-3}, {'lr': 1e-2}])\n", "test_eq(opt.param_groups, [{'params': [1,2], 'lr': 1e-4}, {'params': [3], 'lr': 1e-3}, {'params': [4], 'lr': 1e-2}])\n", "test_fail(lambda: Optimizer([[1,2],[3],[4]], tst_arg, lr=np.array([0.1,0.2])))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Basic steppers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To be able to give examples of optimizer steps, we will need some steppers, like the following:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def sgd_step(p, lr, **kwargs):\n", " p.data.add_(p.grad.data, alpha=-lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tst_param(val, grad=None):\n", " \"Create a tensor with `val` and a gradient of `grad` for testing\"\n", " res = tensor([val]).float()\n", " res.grad = tensor([val/10 if grad is None else grad]).float()\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param(1., 0.1)\n", "sgd_step(p, 1.)\n", "test_eq(p, tensor([0.9]))\n", "test_eq(p.grad, tensor([0.1]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def weight_decay(p, lr, wd, do_wd=True, **kwargs):\n", " \"Weight decay as decaying `p` with `lr*wd`\"\n", " if do_wd and wd!=0: p.data.mul_(1 - lr*wd)\n", "\n", "weight_decay.defaults = dict(wd=0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param(1., 0.1)\n", "weight_decay(p, 1., 0.1)\n", "test_eq(p, tensor([0.9]))\n", "test_eq(p.grad, tensor([0.1]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def l2_reg(p, lr, wd, do_wd=True, **kwargs):\n", " \"L2 regularization as adding `wd*p` to `p.grad`\"\n", " if do_wd and wd!=0: p.grad.data.add_(p.data, alpha=wd)\n", "\n", "l2_reg.defaults = dict(wd=0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param(1., 0.1)\n", "l2_reg(p, 1., 0.1)\n", "test_eq(p, tensor([1.]))\n", "test_eq(p.grad, tensor([0.2]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: Weight decay and L2 regularization is the same thing for basic SGD, but for more complex optimizers, they are very different." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making the step" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.step\" class=\"doc_header\"><code>Optimizer.step</code> [source]</h4>
\n", "\n", "> Optimizer.step()\n", "\n", "Standard PyTorch API: Update the stats and execute the steppers in on all parameters that have a grad" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.step)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This method will loop over all param groups, then all parameters for which `grad` is not None and call each function in `stepper`, passing it the parameter `p` with the hyper-parameters in the corresponding dict in `hypers`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test basic step\n", "r = L.range(4)\n", "def tst_params(): return r.map(tst_param)\n", "\n", "params = tst_params()\n", "opt = Optimizer(params, sgd_step, lr=0.1)\n", "opt.step()\n", "test_close([p.item() for p in params], r.map(mul(0.99)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test two steps\n", "params = tst_params()\n", "opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)\n", "opt.step()\n", "test_close([p.item() for p in params], r.map(mul(0.98)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test None gradients are ignored\n", "params = tst_params()\n", "opt = Optimizer(params, sgd_step, lr=0.1)\n", "params[-1].grad = None\n", "opt.step()\n", "test_close([p.item() for p in params], [0., 0.99, 1.98, 3.])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#test discriminative lrs\n", "params = tst_params()\n", "opt = Optimizer([params[:2], params[2:]], sgd_step, lr=0.1)\n", "opt.hypers[0]['lr'] = 0.01\n", "opt.step()\n", "test_close([p.item() for p in params], [0., 0.999, 1.98, 2.97])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.zero_grad\" class=\"doc_header\"><code>Optimizer.zero_grad</code> [source]</h4>
\n", "\n", "> Optimizer.zero_grad()\n", "\n", "Standard PyTorch API: Zero all the grad attributes of the parameters" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.zero_grad)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_params()\n", "opt = Optimizer(params, [weight_decay, sgd_step], lr=0.1, wd=0.1)\n", "opt.zero_grad()\n", "[test_eq(p.grad, tensor([0.])) for p in params];" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the `Optimizer` `cbs` can be functions updating the state associated with a parameter. That state can then be used by any stepper. The best example is a momentum calculation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def tst_stat(p, **kwargs): \n", " s = kwargs.get('sum', torch.zeros_like(p)) + p.data\n", " return {'sum': s}\n", "tst_stat.defaults = {'mom': 0.9}\n", "\n", "#Test Optimizer init\n", "opt = Optimizer([1,2,3], tst_stat)\n", "test_eq(opt.hypers, [{'mom': 0.9}])\n", "opt = Optimizer([1,2,3], tst_stat, mom=0.99)\n", "test_eq(opt.hypers, [{'mom': 0.99}])\n", "\n", "#Test stat\n", "x = torch.randn(4,5)\n", "state = tst_stat(x)\n", "assert 'sum' in state\n", "test_eq(x, state['sum'])\n", "state = tst_stat(x, **state)\n", "test_eq(state['sum'], 2*x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Statistics" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):\n", " \"Keeps track of the avg grads of `p` in `state` with `mom`.\"\n", " if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)\n", " damp = 1-mom if dampening else 1.\n", " grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)\n", " return {'grad_avg': grad_avg}\n", "\n", "average_grad.defaults = dict(mom=0.9)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`dampening=False` gives the classical formula for momentum in SGD: \n", "```\n", "new_val = old_val * mom + grad\n", "```\n", "whereas `dampening=True` makes it an exponential moving average:\n", "```\n", "new_val = old_val * mom + grad * (1-mom)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param([1,2,3], [4,5,6])\n", "state = {}\n", "state = average_grad(p, mom=0.9, **state)\n", "test_eq(state['grad_avg'], p.grad)\n", "state = average_grad(p, mom=0.9, **state)\n", "test_eq(state['grad_avg'], p.grad * 1.9)\n", "\n", "#Test dampening\n", "state = {}\n", "state = average_grad(p, mom=0.9, dampening=True, **state)\n", "test_eq(state['grad_avg'], 0.1*p.grad)\n", "state = average_grad(p, mom=0.9, dampening=True, **state)\n", "test_close(state['grad_avg'], (0.1*0.9+0.1)*p.grad)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def average_sqr_grad(p, sqr_mom, dampening=True, sqr_avg=None, **kwargs):\n", " if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)\n", " damp = 1-sqr_mom if dampening else 1.\n", " sqr_avg.mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=damp)\n", " return {'sqr_avg': sqr_avg}\n", "\n", "average_sqr_grad.defaults = dict(sqr_mom=0.99)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`dampening=False` gives the classical formula for momentum in SGD: \n", "```\n", "new_val = old_val * mom + grad**2\n", "```\n", "whereas 
`dampening=True` makes it an exponential moving average:\n", "```\n", "new_val = old_val * mom + (grad**2) * (1-mom)\n", "```" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param([1,2,3], [4,5,6])\n", "state = {}\n", "state = average_sqr_grad(p, sqr_mom=0.99, dampening=False, **state)\n", "test_eq(state['sqr_avg'], p.grad.pow(2))\n", "state = average_sqr_grad(p, sqr_mom=0.99, dampening=False, **state)\n", "test_eq(state['sqr_avg'], p.grad.pow(2) * 1.99)\n", "\n", "#Test dampening\n", "state = {}\n", "state = average_sqr_grad(p, sqr_mom=0.99, **state)\n", "test_close(state['sqr_avg'], 0.01*p.grad.pow(2))\n", "state = average_sqr_grad(p, sqr_mom=0.99, **state)\n", "test_close(state['sqr_avg'], (0.01*0.99+0.01)*p.grad.pow(2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Freezing part of the model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.freeze\" class=\"doc_header\"><code>Optimizer.freeze</code> [source]</h4>
\n", "\n", "> Optimizer.freeze()\n", "\n", "Freeze up to last parameter group" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.freeze, name=\"Optimizer.freeze\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.freeze_to\" class=\"doc_header\"><code>Optimizer.freeze_to</code> [source]</h4>
\n", "\n", "> Optimizer.freeze_to(**`n`**)\n", "\n", "Freeze parameter groups up to `n`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.freeze_to, name=\"Optimizer.freeze_to\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.unfreeze\" class=\"doc_header\"><code>Optimizer.unfreeze</code> [source]</h4>
\n", "\n", "> Optimizer.unfreeze()\n", "\n", "Unfreeze the entire model" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.unfreeze, name=\"Optimizer.unfreeze\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Freezing the first layer\n", "params = [tst_params(), tst_params(), tst_params()]\n", "opt = Optimizer(params, sgd_step, lr=0.1)\n", "opt.freeze_to(1)\n", "req_grad = Self.requires_grad()\n", "test_eq(L(params[0]).map(req_grad), [False]*4)\n", "for i in {1,2}: test_eq(L(params[i]).map(req_grad), [True]*4)\n", " \n", "#Unfreezing\n", "opt.unfreeze()\n", "for i in range(2): test_eq(L(params[i]).map(req_grad), [True]*4)\n", "\n", "#TODO: test warning\n", "# opt.freeze_to(3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parameters such as batchnorm weights/bias can be marked to always be in training mode, just put `force_train=true` in their state." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = [tst_params(), tst_params(), tst_params()]\n", "opt = Optimizer(params, sgd_step, lr=0.1)\n", "for p in L(params[1])[[1,3]]: opt.state[p] = {'force_train': True}\n", "opt.freeze()\n", "test_eq(L(params[0]).map(req_grad), [False]*4)\n", "test_eq(L(params[1]).map(req_grad), [False, True, False, True])\n", "test_eq(L(params[2]).map(req_grad), [True]*4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Serializing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.state_dict\" class=\"doc_header\"><code>Optimizer.state_dict</code> [source]</h4>
\n", "\n", "> Optimizer.state_dict()\n", "\n", "Return the state of the optimizer in a dictionary" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.state_dict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.load_state_dict\" class=\"doc_header\"><code>Optimizer.load_state_dict</code> [source]</h4>
\n", "\n", "> Optimizer.load_state_dict(**`sd`**)\n", "\n", "Load the content of `sd`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.load_state_dict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param([1,2,3], [4,5,6])\n", "opt = Optimizer(p, average_grad)\n", "opt.step()\n", "test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))\n", "\n", "sd = opt.state_dict()\n", "p1 = tst_param([10,20,30], [40,50,60])\n", "opt = Optimizer(p1, average_grad, mom=0.99)\n", "test_eq(opt.hypers[0]['mom'], 0.99)\n", "test_eq(opt.state, {})\n", "\n", "opt.load_state_dict(sd)\n", "test_eq(opt.hypers[0]['mom'], 0.9)\n", "test_eq(opt.state[p1]['grad_avg'], tensor([[4., 5., 6.]]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Optimizer.clear_state\" class=\"doc_header\"><code>Optimizer.clear_state</code> [source]</h4>
\n", "\n", "> Optimizer.clear_state()\n", "\n", "Reset the state of the optimizer" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Optimizer.clear_state)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param([1,2,3], [4,5,6])\n", "opt = Optimizer(p, average_grad)\n", "opt.state[p] = {'force_train': True}\n", "opt.step()\n", "test_eq(opt.state[p]['grad_avg'], tensor([[4., 5., 6.]]))\n", "\n", "opt.clear_state()\n", "test_eq(opt.state[p], {'force_train': True})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimizers" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SGD with momentum" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def momentum_step(p, lr, grad_avg, **kwargs):\n", " \"Step for SGD with momentum with `lr`\"\n", " p.data.add_(grad_avg, alpha=-lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def SGD(params, lr, mom=0., wd=0., decouple_wd=True):\n", " \"A `Optimizer` for SGD with `lr` and `mom` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " if mom != 0: cbs.append(average_grad)\n", " cbs.append(sgd_step if mom==0 else momentum_step)\n", " return Optimizer(params, cbs, lr=lr, mom=mom, wd=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Vanilla SGD\n", "params = tst_params()\n", "opt = SGD(params, lr=0.1)\n", "opt.step()\n", "test_close([p.item() for p in params], [i*0.99 for i in range(4)])\n", "opt.step()\n", "[p.item() for p in params]\n", "test_close([p.item() for p in params], [i*0.98 for i in range(4)])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#SGD with momentum\n", "params = tst_params()\n", "opt = SGD(params, lr=0.1, mom=0.9)\n", "assert isinstance(opt, Optimizer)\n", "opt.step()\n", "test_close([p.item() for p in params], [i*0.99 for i in range(4)])\n", "opt.step()\n", "[p.item() for p in params]\n", "test_close([p.item() for p in params], [i*(1 - 0.1 * (0.1 + 0.1*1.9)) for i in range(4)])\n", "for i,p in enumerate(params): test_close(opt.state[p]['grad_avg'].item(), i*0.19)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Test weight decay, notice how we can see that L2 regularization is different from weight decay even for simple SGD with momentum." 
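,
    "\n",
    "\n",
    "As a sketch (pseudocode matching the steppers defined above), the reason is that with `l2_reg` the `wd*p` term is folded into `grad_avg` and then carried along by the momentum history, whereas `weight_decay` shrinks the weights directly, independently of the momentum buffer:\n",
    "```\n",
    "# decoupled weight decay (weight_decay stepper)\n",
    "p        = p * (1 - lr*wd)\n",
    "grad_avg = grad_avg*mom + grad\n",
    "p        = p - lr*grad_avg\n",
    "\n",
    "# L2 regularization (l2_reg stepper)\n",
    "grad     = grad + wd*p\n",
    "grad_avg = grad_avg*mom + grad   # the penalty now lives in the momentum buffer\n",
    "p        = p - lr*grad_avg\n",
    "```"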
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_params()\n", "#Weight decay\n", "opt = SGD(params, lr=0.1, mom=0.9, wd=0.1)\n", "opt.step()\n", "test_close([p.item() for p in params], [i*0.98 for i in range(4)])\n", "#L2 reg\n", "opt = SGD(params, lr=0.1, mom=0.9, wd=0.1, decouple_wd=False)\n", "opt.step()\n", "#TODO: fix cause this formula was wrong\n", "#test_close([p.item() for p in params], [i*0.97 for i in range(4)])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RMSProp" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def rms_prop_step(p, lr, sqr_avg, eps, grad_avg=None, **kwargs):\n", " \"Step for SGD with momentum with `lr`\"\n", " denom = sqr_avg.sqrt().add_(eps)\n", " p.data.addcdiv_((grad_avg if grad_avg is not None else p.grad), denom, value=-lr)\n", "\n", "rms_prop_step.defaults = dict(eps=1e-8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RMSProp(params, lr, sqr_mom=0.99, mom=0., wd=0., decouple_wd=True):\n", " \"A `Optimizer` for RMSProp with `lr`, `sqr_mom`, `mom` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " cbs += ([average_sqr_grad] if mom==0. else [average_grad, average_sqr_grad])\n", " cbs.append(rms_prop_step)\n", " return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, wd=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RMSProp was introduced by Geoffrey Hinton in his [course](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf). What is named `sqr_mom` here is the `alpha` in the course. Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients)." 
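,
    "\n",
    "\n",
    "Putting the steppers above together, one RMSProp update without momentum is, as a sketch:\n",
    "```\n",
    "sqr_avg = sqr_avg*sqr_mom + (1-sqr_mom)*grad**2\n",
    "p       = p - lr * grad / (sqrt(sqr_avg) + eps)\n",
    "```\n",
    "With `mom != 0`, the raw gradient in the last line is replaced by the running average `grad_avg`."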
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Without momentum\n", "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = RMSProp(params, lr=0.1)\n", "opt.step()\n", "test_close(params[0], tensor([0.,1.,2.]))\n", "opt.step()\n", "step = - 0.1 * 0.1 / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)\n", "test_close(params[0], tensor([step, 1+step, 2+step]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#With momentum\n", "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = RMSProp(params, lr=0.1, mom=0.9)\n", "opt.step()\n", "test_close(params[0], tensor([0.,1.,2.]))\n", "opt.step()\n", "step = - 0.1 * (0.1 + 0.9*0.1) / (math.sqrt((0.01*0.99+0.01) * 0.1**2) + 1e-8)\n", "test_close(params[0], tensor([step, 1+step, 2+step]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Adam" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def step_stat(p, step=0, **kwargs):\n", " \"Register the number of steps done in `state` for `p`\"\n", " step += 1\n", " return {'step' : step}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "p = tst_param(1,0.1)\n", "state = {}\n", "state = step_stat(p, **state)\n", "test_eq(state['step'], 1)\n", "for _ in range(5): state = step_stat(p, **state)\n", "test_eq(state['step'], 6)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def debias(mom, damp, step): return damp * (1 - mom**step) / (1-mom)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def adam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):\n", " \"Step for Adam with `lr` on `p`\"\n", " debias1 = debias(mom, 1-mom, step)\n", " debias2 = debias(sqr_mom, 1-sqr_mom, step)\n", " p.data.addcdiv_(grad_avg, (sqr_avg/debias2).sqrt() + eps, value = -lr / debias1)\n", " return p\n", "\n", "adam_step._defaults = dict(eps=1e-5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0.01, decouple_wd=True):\n", " \"A `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]\n", " return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adam was introduced by Diederik P. Kingma and Jimmy Ba in [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980). For consistency across optimizers, we renamed `beta1` and `beta2` in the paper to `mom` and `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.\n", "\n", "Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients).\n", "\n", "> Note: Don't forget that `eps` is an hyper-parameter you can change. Some models won't train without a very high `eps` like 0.1 (intuitively, the higher `eps` is, the closer we are to normal SGD). 
The usual default of 1e-8 is often too extreme in the sense we don't manage to get as good results as with SGD. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = Adam(params, lr=0.1, wd=0)\n", "opt.step()\n", "step = -0.1 * 0.1 / (math.sqrt(0.1**2) + 1e-8)\n", "test_close(params[0], tensor([1+step, 2+step, 3+step]))\n", "opt.step()\n", "test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### RAdam" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "RAdam (for rectified Adam) was introduced by Zhang et al. in [On the Variance of the Adaptive Learning Rate and Beyond](https://arxiv.org/abs/1907.08610) to slightly modify the Adam optimizer to be more stable at the beginning of training (and thus not require a long warmup). They use an estimate of the variance of the moving average of the squared gradients (the term in the denominator of traditional Adam) and rescale this moving average by this term before performing the update.\n", "\n", "This version also incorporates [SAdam](https://arxiv.org/abs/1908.00700); set `beta` to enable this (definition same as in the paper)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def radam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, beta, **kwargs):\n", " \"Step for RAdam with `lr` on `p`\"\n", " debias1 = debias(mom, 1-mom, step)\n", " debias2 = debias(sqr_mom, 1-sqr_mom, step)\n", " r_inf = 2/(1-sqr_mom) - 1\n", " r = r_inf - 2*step*sqr_mom**step/(1-sqr_mom**step)\n", " if r > 5:\n", " v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))\n", " denom = (sqr_avg/debias2).sqrt()\n", " if eps: denom += eps\n", " if beta: denom = F.softplus(denom, beta)\n", " p.data.addcdiv_(grad_avg, denom, value = -lr*v / debias1)\n", " else: p.data.add_(grad_avg, alpha=-lr / debias1)\n", " return p\n", "\n", "radam_step._defaults = dict(eps=1e-5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def RAdam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., beta=0., decouple_wd=True):\n", " \"A `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, radam_step]\n", " return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd, beta=beta)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the effective correction reported to the adam step for 500 iterations in RAdam. We can see how it goes from 0 to 1, mimicking the effect of a warm-up." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/d3fzzAAAACXBIWXMAAAsTAAALEwEAmpwYAAAev0lEQVR4nO3deXhV1b3/8feXjGSETEwJhCGMCioBHKriCA5V29pWrdrai1xtaX8db21729v+eu/Twd62t/dWrT+l1utAtWqLFqu2zlqZZIwQCARICCEJIfN8zvr9kQPGGCDACTtnn8/rec6TPeWc7zo8fFisvfbe5pxDREQi3xCvCxARkfBQoIuI+IQCXUTEJxToIiI+oUAXEfGJWK8+OCsry+Xn53v18SIiEWnt2rU1zrnsvvZ5Fuj5+fmsWbPGq48XEYlIZrb7SPs05CIi4hPHDHQzW2pmVWa2+Qj7zcx+bWYlZrbRzM4Kf5kiInIs/emhPwQsPMr+K4CC0GsxcO/JlyUiIsfrmIHunHsdqD3KIdcCD7tu7wDDzGxUuAoUEZH+CccY+higrMd6eWibiIicQuEIdOtjW593/DKzxWa2xszWVFdXh+GjRUTkkHAEejmQ12M9F6jo60Dn3P3OuULnXGF2dp/TKEVE5ASFYx76cmCJmS0D5gH1zrl9YXhfEZGIEgg6mju6aGrroqm9i8bQz+71zsPrs8cN5/yC8HdqjxnoZvY4MB/IMrNy4N+AOADn3H3ACuBKoARoAW4Le5UiIqdAZyBIfWvnB14Nh5Zb3t/W1P5+YDe2dR4O7eaOQL8+5875E70JdOfcjcfY74Avhq0iEZGT5Jyjoa2Lg80dHGjuoLa5o8dyO3UtfYf2sQI5KT6G9KFxpCbGkpIQS9rQOMYMG0pKQiwpoW2H9n1wPe7wenJ8DLExA3NNp2eX/ouIHI+2zgBVDe1UNbZR3dhOTVM7tc2d1Da3Hw7tQ6+DLR10Bvp+GltC7BCGJ8WTPjSO9KFx5A5PYtiYuMPrPV9pvdbjYwf3xfUKdBHxjHOOupZOqhrbqW7sDuv3l9upamijuqmd6oZ2Gtu7+nyPtMRYMlMSyEiOJy8jiVm5w8hIiSczOZ7hSfGHlzNCr6R4/8aef1smIp5raOtkX10bFfWt7KtrY199KxWhn/vqu3+2dQY/9HtJ8THkpCaQnZrAtJFpXFDQvZyTmkBOWiLZKQlkpXYHdtwADV9EIgW6iJwQ5xy1zR3sqW2h7GArZbUtlB9sYW9dG/vqugO7qVeveojBiLRERqUnMn10GpdOy2Fk+tDuoD4U1qkJpCQomk6EvjUROaK2zgBltS3doV3bwp7aVsoOdi+X1bZ86CRiZnI8o4cNZXxWMudNymJUeiKjhw1l9LBERoWCe6BOCIoCXSTqBYKOirpWdtY0s7O6idKaZnZWN1Na08zeutYPHDs0Loa8jKGMzUji7AmZjM1IYmxGEnkZSeRlDPX1+HQk0LcvEiVaOrooqWpi2/4mdlQ3UVrdzM6aJnYdaKGj6/1x7NSEWMZnJzMnfzifysojPyuJ3OHdwZ2VEo9ZX3f7kMFAgS7iM+1dAXZWN7NtfyPb9jdSXNnEtv2NlB1swYVm8sUOMcZmJjEhK4X5U3KYkJXM+KxkJmSnKLQjmAJdJEI556hqbKeoop6ivQ1sqWyguLKRXQdaCAS7kzt2iDE+K5nTc9O5fnYuk0ekUDAilXEZSRrL9iEFukgECAYde2pbKKpooKiins0VDbxXUU9NU8fhY8ZlJjFlRCpXnDaKySNTmTIilfFZyYP+YhgJHwW6yCDjnGNvXSsbyupZX3aQDeX1vFfRcHgKYOwQo2BEKvOn5DBjdBozRqczbVQqqYlxHlcuXlOgi3issa2TjeX1rC+rY92eOtaX1VHT1A5AfOwQZoxO42Nnjjkc3pNHppAQG+Nx1TIYKdBFTiHnHOUHW1lVWsvqXbWs3X2QkuqmwycrJ2Qlc0FBFmeMHcYZecOYOjJNQybSbwp0kQEUDDpKqptYWVrL6lCI76tvAyA1MZbZ44Zz1cxRnDl2OLNy0xmWFO9xxRLJFOgiYRQMOor3N/JWSQ3v7Kxlze5a6lo6AchJTWDO+Azm5mcwJz+DKSNTiRmi6YESPgp0kZNUfrCFt0pqeLPkAG+X1HCguXvmybjMJC6bNuJwiI/LTNL8bhlQCnSR41Tf2snbJTW8WVLDWyU17DrQAkB2agIXTM7m3ImZnDcpi9HDhnpcqUQbBbrIMTjn2La/iZe3VvFKcRVrdx8kEHQkx8dw9oRMbj0nn48UZFGQk6IeuHhKgS7Sh9aOAG/vqOHlrVW8Wlx9+CZV00elcceFE5g/JYcz8obpXtwyqCjQRUJqmzv425b9vLC5kjdKaujoCpIUH8NHJmXxpYsnMX9KDiPTE70uU+SIFOgS1fbVt/Ji0X5eKKpkZWktgaBjzLChfGbeWC6ZOoI544frIh6JGAp0iTp7DrSwYvM+/rq5kvVldQAU5KRw54UTWXjaSGaMTtNYuEQkBbpEhaqGNp7buI/lGyoOh/jM3HS+uWAKC2aMZFJOircFioSBAl18q76lk78WdYf4P3YcIOi6T2redcVUrp45itzhSV6XKBJWCnTxlc5AkFeLq3lyTRmvFlfTEQiSn5nEkosmcc0Zo5mUk+p1iSIDRoEuvrBtfyNPrinjmXV7qWnqICslgZvPHse1Z4xmZm66xsQlKijQJWLVt3by7IYKnlxbzoayOmKHGBdPzeFThXlcOCVbc8Ql6ijQJeJsKKvjf9/ZzbMbKmjvCjJ1ZCr/etU0rjtzDFkpCV6XJ+IZBbpEhNaOAM9uqOB/39nNpr31JMfHcP3sXD49J4/Tx2hIRQQU6DLI7axu4tGVe3hyTRkNbV1MHpHCj66dwXVnjtEj10R6UaDLoOOc443tNTzwZimvb6smdoix8LSR3HL2OOaOz1BvXOQIFOgyaLR3BVi+voIH3yxla2Uj2akJfO2yydwwN4+cVN1DReRYFOjiuYPNHTy2ag8Pvb2L6sZ2poxI5e7rZ3LNGaN1HxWR46BAF89U1rfx29d3sGxVGa2dAc4vyOI/PzmL8wuyNKwicgL6FehmthD4LyAGeMA595Ne+9OBR4Cxoff8uXPud2GuVXyirLaFe1/bwR/XlBNwjuvOGMPtF4xn6sg0r0sTiWjHDHQziwF+A1wGlAOrzWy5c+69Hod9EXjPOfdRM8sGis3sUedcx4BULRFpZ3UT97y6g2fW7SXGjE8W5nLHhRPJy9A9VUTCoT899LlAiXNuJ4CZLQOuBXoGugNSrfv/ySlALdAV5lolQpVUNfHrv2/nuY0VxMUM4dZzxrH4ggmMStczN0XCqT+BPgYo67FeDszrdcz/AMuBCiAV+LRzLtj7jcxsMbAYYOzYsSdSr0SQ8oMt/NfftvPUu+UkxsVw+wUTWPS
RCWSn6mpOkYHQn0Dv6+yU67W+AFgPXAxMBF4yszeccw0f+CXn7gfuBygsLOz9HuITVY1t3PPKDh5duRsz4/PnjefO+RPJ1GX5IgOqP4FeDuT1WM+luyfe023AT5xzDigxs1JgKrAqLFVKRKhv7eS3r+3gd2/toiMQ5FOFeXz5kkkaWhE5RfoT6KuBAjMbD+wFbgBu6nXMHuAS4A0zGwFMAXaGs1AZvDoDQR5buYdf/W0bda2dXDNrNF+9dDL5WclelyYSVY4Z6M65LjNbArxA97TFpc65IjO7I7T/PuBHwENmtonuIZpvOedqBrBuGQScc7y8tYr/WLGFndXNnDsxk+9eNY0Zo9O9Lk0kKvVrHrpzbgWwote2+3osVwCXh7c0Gczeq2jgP1a8x1slB5iQlcwDtxZyybQcXRAk4iFdKSrH5WBzBz97oZhlq/eQPjSOH3x0Op85e5weJiEyCCjQpV+CQceTa8v4yfNbaWjr4nPn5vOVSyaTnqRb2IoMFgp0Oaaiinq+96fNvLunjjn5w/m/157GtFG6TF9ksFGgyxE1tHXyixe38fA/djE8KZ6ff3IWnzhrjMbJRQYpBbr06cWiSv71T5upbmrn5nnj+MblUzS8IjLIKdDlA2qa2vnB8iKe27iPqSNTeeCzhczMHeZ1WSLSDwp0AbrnlP95fQU/fLaI5vYAX79sMnfMn6jZKyIRRIEu7G9o49tPb+LlrVWcOXYYP/vETApGpHpdlogcJwV6lPvLxn1855lNtHcF+N7V0/ncufnEDNFJT5FIpECPUvWtnfxgeRHPrNvLrLxh/PJTs5iQneJ1WSJyEhToUejtHTV844kN7G9s56uXTuaLF00kVmPlIhFPgR5FOrqC/PzFYv7fGzvJz0zmqTvP5Yy8YV6XJSJhokCPEmW1LSx5fB0byur4zLyxfPeqaSTF649fxE/0NzoKvFBUyTef3IBzcO9nzuKK00d5XZKIDAAFuo91dAX58fNb+N1bu5iZm87/3HgWYzOTvC5LRAaIAt2nympb+OJj77KxvJ7bzsvnriumkhAb43VZIjKAFOg+9Ob2GpY8/i6BoOO3t8xmwYyRXpckIqeAAt1HnHM88EYpP35+C5NyUrj/lkI911MkiijQfaK1I8C3ntrI8g0VXHn6SO6+fhbJCfrjFYkm+hvvA+UHW7j94bVsrWzgmwum8IX5E3XPcpEopECPcOv2HOT2h9fQ3hVk6efmcNGUHK9LEhGPKNAj2F827uNrT6xnRFoiyxbPYVKO7sUiEs0U6BHIOcc9r+7g7heKmT1uOPffMpvMlASvyxIRjynQI0xnIMh3nt7Ek2vLuWbWaH52/UwS4zS/XEQU6BGlpaOLOx95l9e2VfPlSwr46qUFOvkpIocp0CPEweYObntoNRvL6/jJx0/nhrljvS5JRAYZBXoEqKhr5dalq9hT28K9N+vKTxHpmwJ9kCupauSWB1fR1NbFw5+fy9kTMr0uSUQGKQX6ILapvJ5bl64kZsgQlv3z2cwYne51SSIyiCnQB6l1ew5y69JVpCXG8djt8xiXqXuyiMjRKdAHobW7a/ns0tVkJMfz+OKzGTNsqNcliUgE0JOBB5mVOw9w64OryE5N4A//rDAXkf5ToA8ib5fU8LnfrWZkeiLLFp/NqHSFuYj0nwJ9kPjHjgPc9tBq8jKGsmzxOYxIS/S6JBGJMP0KdDNbaGbFZlZiZncd4Zj5ZrbezIrM7LXwlulv6/YcZNHvV5OXkcRjt59NdqruyyIix++YJ0XNLAb4DXAZUA6sNrPlzrn3ehwzDLgHWOic22NmuodrP71X0cBnl64iKzWBRxfNI0s32RKRE9SfHvpcoMQ5t9M51wEsA67tdcxNwNPOuT0Azrmq8JbpTyVVTdzy4EpSEmJ5dNE8DbOIyEnpT6CPAcp6rJeHtvU0GRhuZq+a2Vozu7WvNzKzxWa2xszWVFdXn1jFPlFW28LND6zEzHhk0Txyhyd5XZKIRLj+BHpft/NzvdZjgdnAVcAC4HtmNvlDv+Tc/c65QudcYXZ29nEX6xe1zR3cunQVrZ0BHlk0lwnZejCFiJy8/lxYVA7k9VjPBSr6OKbGOdcMNJvZ68AsYFtYqvSRlo4uPv/QairqWnl00TymjkzzuiQR8Yn+9NBXAwVmNt7M4oEbgOW9jvkzcL6ZxZpZEjAP2BLeUiNfVyDIlx5bx8byOn5945kU5md4XZKI+Mgxe+jOuS4zWwK8AMQAS51zRWZ2R2j/fc65LWb2V2AjEAQecM5tHsjCI41zju/9eTN/31rFv193mm6BKyJh1697uTjnVgArem27r9f63cDd4SvNX/775RIeX1XGkosmcfPZ47wuR0R8SFeKngLPbazgFy9t4+NnjuHrl3/oXLGISFgo0AfYpvJ6vvHkBgrHDefHnzhdzwAVkQGjQB9A+xvaWPTwajKTE7jvltkkxMZ4XZKI+Jjuhz5A2joDLH54DY1tXTx157m6pF9EBpwCfQA45/iXP25k4956fnvzbKaN0lxzERl4GnIZAA++WcryDRV84/IpXK7piSJyiijQw2xVaS0/fn4rl08fwRfmT/S6HBGJIgr0MKpqbGPJY++SN3woP//ULM1oEZFTSmPoYXLosv6Gtk5+//m5pCXGeV2SiEQZBXqY/PzFbawsreWXn56lk6Ai4gkNuYTB69uque+1Hdw0bywfOzPX63JEJEop0E9STVM7X3tiA5NHpPD9q6d7XY6IRDENuZyEYNDxjSc30NDWySOL5pIYpytBRcQ76qGfhN+9vYtXi6v516um6UEVIuI5BfoJKqqo56fPb+XSaSO4RbfDFZFBQIF+Ajq6gnz9iQ2kJ8Xxs+tnar65iAwKGkM/Ab/++3a2VjbywK2FZCTHe12OiAigHvpx21BWx72v7eATZ+Vy6fQRXpcjInKYAv04tHUG+MaTG8hOSeD7H9UURREZXDTkchx+9bftbK9q4qHb5pA+VJf2i8jgoh56P20sr+P+13dww5w85k/J8bocEZEPUaD3Q1cgyHee2URmSgLfuWqa1+WIiPRJgd4PD/9jN5v3NvD9q6frLooiMmgp0I+hsr6N/3yxmAsmZ3P1zFFelyMickQK9GP44bNFdAUd/37tabqASEQGNQX6Uby8dT/Pb67ky5cUMDYzyetyRESOSoF+BG2dAf5teRGTclK4/fwJXpcjInJMmod+BA++WUpZbSuPLZpHfKz+3RORwU9J1YeqhjbueaWEy6aP4NxJWV6XIyLSLwr0Ptz9QjEdgSDfvVJzzkUkcijQe9lUXs8f3y3ntvPGk5+V7HU5IiL9pkDvwTnHj557j4ykeJZcPMnrckREjosCvYeXt1axalctX7lssq4IFZGI069AN7OFZlZsZiVmdtdRjptjZgEzuz58JZ4awaDj7heKGZeZxA1z8rwuR0TkuB0z0M0sBvgNcAUwHbjRzD50M/DQcT8FXgh3kafCsxsr2FrZyNcum0xcjP7jIiKRpz/JNRcocc7tdM51AMuAa/s47kvAU0BVGOs7JToDQX7x0jamjUrjoz
NHe12OiMgJ6U+gjwHKeqyXh7YdZmZjgI8B9x3tjcxssZmtMbM11dXVx1vrgPnD6jJ2H2jhmwsmM2SI7tciIpGpP4HeV8K5Xuu/Ar7lnAsc7Y2cc/c75wqdc4XZ2dn9LHFgtXUG+O+Xt1M4bjgX6cEVIhLB+nPpfznQ8yxhLlDR65hCYFnoboRZwJVm1uWc+1M4ihxIT64pY39DO7/89Bm6m6KIRLT+BPpqoMDMxgN7gRuAm3oe4Jwbf2jZzB4CnouEMO8MBLnvtZ3MHjeccyZkel2OiMhJOeaQi3OuC1hC9+yVLcATzrkiM7vDzO4Y6AIH0jPr9rK3rpUlF09S71xEIl6/7rbonFsBrOi1rc8ToM65z518WQMvEHTc++oOThuTxvzJg2M8X0TkZETthOu/bNpHaU0zX5yv3rmI+ENUBnow6LjnlRIm5aSwYMZIr8sREQmLqAz017ZXs7WykS/Mn6h55yLiG1EZ6EvfLCUnNYGrdVWoiPhI1AX6tv2NvLG9hlvPGadHy4mIr0Rdoi19s5SE2CHcNG+c16WIiIRVVAX6gaZ2nl63l4+flUtGcrzX5YiIhFVUBfrjq/bQ0RXk8+fle12KiEjYRU2gB4KOx1eVce7ETApGpHpdjohI2EVNoL++vZq9da3cNG+s16WIiAyIqAn0x1buITM5nsun60IiEfGnqAj0yvo2Xt5axfWFuZqqKCK+FRXp9sSaMgJBx41zNNwiIv7l+0APBh1/WF3GRyZlkZ+V7HU5IiIDxveB/k7pAfbWtfLJwlyvSxERGVC+D/Q/rdtLcnyMToaKiO/5OtDbOgM8v6mShaeNYmh8jNfliIgMKF8H+stbq2hs7+JjZ47xuhQRkQHn60B/Zt1eclITOGeiHgAtIv7n20A/2NzBq8VVXDNrNDF6iIWIRAHfBvqKzfvoDDiu03CLiEQJ3wb6XzdXMj4rmRmj07wuRUTklPBloNe1dPCPHQdYeNpIzDTcIiLRwZeB/vctVXQFHQtnaO65iEQPXwb6X4sqGZWeyMzcdK9LERE5ZXwX6M3tXby+rZoFMzTcIiLRxXeB/sb2atq7gizQcIuIRBnfBforW6tJS4xlTv5wr0sRETmlfBXozjle21bN+QXZxMb4qmkiIsfkq9TbWtlIZUMbF07O9roUEZFTzleB/mpxNQAXTlGgi0j08VmgVzFtVBoj0hK9LkVE5JTzTaA3tnWydvdB5qt3LiJRyjeBvqq0lq6g4/yCLK9LERHxRL8C3cwWmlmxmZWY2V197P+MmW0Mvd42s1nhL/XoVpbWEh8zhLPGarqiiESnYwa6mcUAvwGuAKYDN5rZ9F6HlQIXOudmAj8C7g93oceysrSWWXnpJMbpUXMiEp3600OfC5Q453Y65zqAZcC1PQ9wzr3tnDsYWn0HyA1vmUfX1N7F5r31zBuvJxOJSPTqT6CPAcp6rJeHth3JPwHP97XDzBab2RozW1NdXd3/Ko9h7e6DBIKOeRMywvaeIiKRpj+B3tcdrlyfB5pdRHegf6uv/c65+51zhc65wuzs8M1GWbnzALFDjNnjNH4uItErth/HlAN5PdZzgYreB5nZTOAB4Arn3IHwlNc/K0trOT03naT4/jRHRMSf+tNDXw0UmNl4M4sHbgCW9zzAzMYCTwO3OOe2hb/MI2vrDLCxvE7j5yIS9Y7ZpXXOdZnZEuAFIAZY6pwrMrM7QvvvA74PZAL3hO5B3uWcKxy4st9XVNFAZ8Bx1thhp+LjREQGrX6NUTjnVgArem27r8fyImBReEvrnw1ldQCckTfMi48XERk0Iv5K0Q3ldYxKTyRH928RkSgX+YFeVses3GFelyEi4rmIDvS6lg52HWhhZp4eBi0iEtGBvmlvPYB66CIiRHigb93XCMC0UWkeVyIi4r2IDvQtlQ3kpCaQkRzvdSkiIp6L6EAvrmxkqnrnIiJABAd6VyDI9v1NTBuZ6nUpIiKDQsQGemlNMx2BIFMU6CIiQAQH+tbK7hOiU0dqyEVEBCI40Lfvb2SIwcScZK9LEREZFCI20HcdaCF3eBIJsXrknIgIRHSgNzMuM8nrMkREBo2IDHTnHKU1zeRnarhFROSQiAz0upZOGtu61EMXEekhIgN914FmAPXQRUR6iMhA332gBYD8LPXQRUQOichA33WgGTPIHa5AFxE5JCIDffeBFkanDyUxTlMWRUQOichAL6ttIXf4UK/LEBEZVCIy0Csb2hiVrmeIioj0FHGB7pyjqqGdEQp0EZEPiLhAr23uoCMQZGSaAl1EpKeIC/TKhjYABbqISC+RF+j13YGuIRcRkQ+KuEBPHxrHghkjNMtFRKSXWK8LOF6F+RkU5md4XYaIyKATcT10ERHpmwJdRMQnFOgiIj6hQBcR8QkFuoiITyjQRUR8QoEuIuITCnQREZ8w55w3H2xWDew+wV/PAmrCWE6kULujRzS2GaKz3cfb5nHOuey+dngW6CfDzNY45wq9ruNUU7ujRzS2GaKz3eFss4ZcRER8QoEuIuITkRro93tdgEfU7ugRjW2G6Gx32NockWPoIiLyYZHaQxcRkV4U6CIiPhFxgW5mC82s2MxKzOwur+sJJzNbamZVZra5x7YMM3vJzLaHfg7vse/boe+h2MwWeFP1yTGzPDN7xcy2mFmRmf2f0HbfttvMEs1slZltCLX5h6Htvm1zT2YWY2brzOy50Lrv221mu8xsk5mtN7M1oW3hb7dzLmJeQAywA5gAxAMbgOle1xXG9l0AnAVs7rHtZ8BdoeW7gJ+GlqeH2p8AjA99LzFet+EE2jwKOCu0nApsC7XNt+0GDEgJLccBK4Gz/dzmXu3/GvAY8Fxo3fftBnYBWb22hb3dkdZDnwuUOOd2Ouc6gGXAtR7XFDbOudeB2l6brwV+H1r+PXBdj+3LnHPtzrlSoITu7yeiOOf2OefeDS03AluAMfi43a5bU2g1LvRy+LjNh5hZLnAV8ECPzb5v9xGEvd2RFuhjgLIe6+WhbX42wjm3D7rDD8gJbffdd2Fm+cCZdPdYfd3u0LDDeqAKeMk55/s2h/wK+Bcg2GNbNLTbAS+a2VozWxzaFvZ2R9pDoq2PbdE679JX34WZpQBPAV9xzjWY9dW87kP72BZx7XbOBYAzzGwY8IyZnXaUw33RZjO7Gqhyzq01s/n9+ZU+tkVcu0POc85VmFkO8JKZbT3KsSfc7kjroZcDeT3Wc4EKj2o5Vfab2SiA0M+q0HbffBdmFkd3mD/qnHs6tNn37QZwztUBrwIL8X+bzwOuMbNddA+XXmxmj+D/duOcqwj9rAKeoXsIJeztjrRAXw0UmNl4M4sHbgCWe1zTQFsOfDa0/Fngzz2232BmCWY2HigAVnlQ30mx7q74g8AW59wveuzybbvNLDvUM8fMhgKXAlvxcZsBnHPfds7lOufy6f67+7Jz7mZ83m4zSzaz1EPLwOXAZgai3V6f/T2Bs8VX0j0TYgfwXa/rCXPbHgf2AZ10/yv9T0Am8Hdge+hnRo/jvxv6HoqBK7yu/wTb/BG6/zu5EVgfel3p53YDM4F1o
TZvBr4f2u7bNvfxHczn/Vkuvm433bPyNoReRYdyayDarUv/RUR8ItKGXERE5AgU6CIiPqFAFxHxCQW6iIhPKNBFRHxCgS4i4hMKdBERn/j/PMyNWsqWzZcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "beta = 0.99\n", "r_inf = 2/(1-beta) - 1\n", "rs = np.array([r_inf - 2*s*beta**s/(1-beta**s) for s in range(5,500)])\n", "v = np.sqrt(((rs-4) * (rs-2) * r_inf)/((r_inf-4)*(r_inf-2)*rs))\n", "plt.plot(v);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = RAdam(params, lr=0.1)\n", "#The r factor is lower than 5 during the first 5 steps so updates use the average of gradients (all the same)\n", "r_inf = 2/(1-0.99) - 1\n", "for i in range(5): \n", " r = r_inf - 2*(i+1)*0.99**(i+1)/(1-0.99**(i+1))\n", " assert r <= 5\n", " opt.step()\n", "p = tensor([0.95, 1.9, 2.85])\n", "test_close(params[0], p)\n", "\n", "#The r factor is greater than 5 for the sixth step so we update with RAdam\n", "r = r_inf - 2*6*0.99**6/(1-0.99**6)\n", "assert r > 5\n", "opt.step()\n", "v = math.sqrt(((r-4) * (r-2) * r_inf)/((r_inf-4)*(r_inf-2)*r))\n", "step = -0.1*0.1*v/(math.sqrt(0.1**2) + 1e-8)\n", "test_close(params[0], p+step)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### QHAdam" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "QHAdam (for Quasi-Hyperbolic Adam) was introduced by Ma & Yarats in [Quasi-Hyperbolic Momentum and Adam for Deep Learning](https://arxiv.org/pdf/1810.06801.pdf) as a *\"computationally cheap, intuitive to interpret, and simple to implement\"* optimizer. Additional code can be found in their [qhoptim repo](https://github.com/facebookresearch/qhoptim). QHAdam is based on QH-Momentum, which introduces the immediate discount factor `nu`, encapsulating plain SGD (`nu = 0`) and momentum (`nu = 1`). QH-Momentum is defined below, where g_t+1 is the update of the moment. An interpretation of QHM is as a nu-weighted average of the momentum update step and the plain SGD update step.\n", "\n", "> θ_t+1 ← θ_t − lr * [(1 − nu) · ∇L_t(θ_t) + nu · g_t+1]\n", "\n", "QHAdam takes the concept behind QHM above and applies it to Adam, replacing both of Adam’s moment estimators with quasi-hyperbolic terms. \n", "\n", "The paper's suggested default parameters are `mom = 0.999`, `sqr_mom = 0.999`, `nu_1 = 0.7` and `and nu_2 = 1.0`. When training is not stable, it is possible that setting `nu_2 < 1` can improve stability by imposing a tighter step size bound. Note that QHAdam recovers Adam when `nu_1 = nu_2 = 1.0`. QHAdam recovers RMSProp (Hinton et al., 2012) when `nu_1 = 0` and `nu_2 = 1`, and NAdam (Dozat, 2016) when `nu_1 = mom` and `nu_2 = 1`.\n", "\n", "Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def qhadam_step(p, lr, mom, sqr_mom, sqr_avg, nu_1, nu_2, step, grad_avg, eps, **kwargs):\n", " debias1 = debias(mom, 1-mom, step)\n", " debias2 = debias(sqr_mom, 1-sqr_mom, step)\n", " p.data.addcdiv_(((1-nu_1) * p.grad.data) + (nu_1 * (grad_avg / debias1)),\n", " (((1 - nu_2) * (p.grad.data)**2) + (nu_2 * (sqr_avg / debias2))).sqrt() + eps,\n", " value = -lr)\n", " return p\n", "\n", "qhadam_step._defaults = dict(eps=1e-8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def QHAdam(params, lr, mom=0.999, sqr_mom=0.999, nu_1=0.7, nu_2 = 1.0, eps=1e-8, wd=0., decouple_wd=True):\n", " \"An `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `nus`, eps` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " cbs += [partial(average_grad, dampening=True), partial(average_sqr_grad, dampening=True), step_stat, qhadam_step]\n", " return Optimizer(params, cbs, lr=lr, nu_1=nu_1, nu_2=nu_2 ,\n", " mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = QHAdam(params, lr=0.1)\n", "opt.step()\n", "step = -0.1 * (((1-0.7) * 0.1) + (0.7 * 0.1)) / (\n", " math.sqrt(((1-1.0) * 0.1**2) + (1.0 * 0.1**2)) + 1e-8) \n", "test_close(params[0], tensor([1+step, 2+step, 3+step]))\n", "opt.step()\n", "test_close(params[0], tensor([1+2*step, 2+2*step, 3+2*step]), eps=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LARS/LARC" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def larc_layer_lr(p, lr, trust_coeff, wd, eps, clip=True, **kwargs):\n", " \"Computes the local lr before weight decay is applied\"\n", " p_norm,g_norm = torch.norm(p.data),torch.norm(p.grad.data)\n", " local_lr = lr*trust_coeff * (p_norm) / (g_norm + p_norm * wd + eps)\n", " return {'local_lr': min(lr, local_lr) if clip else local_lr}\n", "\n", "larc_layer_lr.defaults = dict(trust_coeff=0.02, wd=0., eps=1e-8)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def larc_step(p, local_lr, grad_avg=None, **kwargs):\n", " \"Step for LARC `local_lr` on `p`\"\n", " p.data.add_(p.grad.data if grad_avg is None else grad_avg, alpha = -local_lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Larc(params, lr, mom=0.9, clip=True, trust_coeff=0.02, eps=1e-8, wd=0., decouple_wd=True):\n", " \"A `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " if mom!=0.: cbs.append(average_grad)\n", " cbs += [partial(larc_layer_lr, clip=clip), larc_step]\n", " return Optimizer(params, cbs, lr=lr, mom=mom, trust_coeff=trust_coeff, eps=eps, wd=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The LARS optimizer was first introduced in [Large Batch Training of Convolutional Networks](https://arxiv.org/abs/1708.03888) then refined in its LARC variant (original LARS is with `clip=False`). 
A learning rate is computed for each individual layer with a certain `trust_coefficient`, then clipped to be always less than `lr`.\n", "\n", "Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]\n", "opt = Larc(params, lr=0.1)\n", "opt.step()\n", "#First param local lr is 0.02 < lr so it's not clipped\n", "test_close(opt.state[params[0]]['local_lr'], 0.02)\n", "#Second param local lr is 0.2 > lr so it's clipped\n", "test_eq(opt.state[params[1]]['local_lr'], 0.1)\n", "test_close(params[0], tensor([0.998,1.996,2.994]))\n", "test_close(params[1], tensor([0.999,1.998,2.997]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = [tst_param([1,2,3], [0.1,0.2,0.3]), tst_param([1,2,3], [0.01,0.02,0.03])]\n", "opt = Larc(params, lr=0.1, clip=False)\n", "opt.step()\n", "#No clipping\n", "test_close(opt.state[params[0]]['local_lr'], 0.02)\n", "test_close(opt.state[params[1]]['local_lr'], 0.2)\n", "test_close(params[0], tensor([0.998,1.996,2.994]))\n", "test_close(params[1], tensor([0.998,1.996,2.994]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### LAMB" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def lamb_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):\n", " \"Step for LAMB with `lr` on `p`\"\n", " debias1 = debias(mom, 1-mom, step)\n", " debias2 = debias(sqr_mom, 1-sqr_mom, step)\n", " r1 = p.data.pow(2).mean().sqrt()\n", " step = (grad_avg/debias1) / ((sqr_avg/debias2).sqrt()+eps)\n", " r2 = step.pow(2).mean().sqrt()\n", " q = 1 if r1 == 0 or r2 == 0 else min(r1/r2,10)\n", " p.data.add_(step, alpha = -lr * q)\n", "\n", "lamb_step._defaults = dict(eps=1e-6, wd=0.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def Lamb(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True):\n", " \"A `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`\"\n", " cbs = [weight_decay] if decouple_wd else [l2_reg]\n", " cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, lamb_step]\n", " return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "LAMB was introduced in [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962). Intuitively, it's LARC applied to Adam. As in `Adam`, we renamed `beta1` and `beta2` in the paper to `mom` and `sqr_mom`. Note that our defaults also differ from the paper (0.99 for `sqr_mom` or `beta2`, 1e-5 for `eps`). Those values seem to be better from our experiments in a wide range of situations.\n", "\n", "Optional weight decay of `wd` is applied, as true weight decay (decay the weights directly) if `decouple_wd=True` else as L2 regularization (add the decay to the gradients)." 
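,
    "\n",
    "\n",
    "The layer-wise trust ratio used in `lamb_step` above can be sketched as:\n",
    "```\n",
    "r1        = sqrt(mean(p**2))                      # norm of the layer's weights\n",
    "adam_step = (grad_avg/debias1) / (sqrt(sqr_avg/debias2) + eps)\n",
    "r2        = sqrt(mean(adam_step**2))              # norm of the Adam update\n",
    "p         = p - lr * min(r1/r2, 10) * adam_step   # ratio capped at 10 (and set to 1 if either norm is 0)\n",
    "```"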
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "opt = Lamb(params, lr=0.1)\n", "opt.step()\n", "test_close(params[0], tensor([0.7840,1.7840,2.7840]), eps=1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lookahead -" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lookahead was introduced by Zhang et al. in [Lookahead Optimizer: k steps forward, 1 step back](https://arxiv.org/abs/1907.08610). It can be run on top of any optimizer and consists in having the final weights of the model be a moving average. In practice, we update our model using the internal optimizer but keep a copy of old weights that and every `k` steps, we change the weights by a moving average of the *fast weights* (the ones updated by the inner optimizer) with the *slow weights* (the copy of old weights). Those *slow weights* act like a stability mechanism." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Lookahead(Optimizer, GetAttr):\n", " \"Wrap `opt` in a lookahead optimizer\"\n", " _default='opt'\n", " def __init__(self, opt, k=6, alpha=0.5):\n", " store_attr('opt,k,alpha')\n", " self._init_state()\n", "\n", " def step(self):\n", " if self.slow_weights is None: self._copy_weights()\n", " self.opt.step()\n", " self.count += 1\n", " if self.count%self.k != 0: return\n", " for slow_pg,fast_pg in zip(self.slow_weights,self.param_lists):\n", " for slow_p,fast_p in zip(slow_pg,fast_pg):\n", " slow_p.data.add_(fast_p.data-slow_p.data, alpha=self.alpha)\n", " fast_p.data.copy_(slow_p.data)\n", "\n", " def clear_state(self):\n", " self.opt.clear_state()\n", " self._init_state()\n", "\n", " def state_dict(self):\n", " state = self.opt.state_dict()\n", " state.update({'count': self.count, 'slow_weights': self.slow_weights})\n", " return state\n", "\n", " def load_state_dict(self, sd):\n", " self.count = sd.pop('count')\n", " self.slow_weights = sd.pop('slow_weights')\n", " self.opt.load_state_dict(sd)\n", "\n", " def _init_state(self): self.count,self.slow_weights = 0,None\n", " def _copy_weights(self): self.slow_weights = L(L(p.clone().detach() for p in pg) for pg in self.param_lists)\n", "\n", " @property\n", " def param_lists(self): return self.opt.param_lists\n", " @param_lists.setter\n", " def param_lists(self, v): self.opt.param_lists = v" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = tst_param([1,2,3], [0.1,0.2,0.3])\n", "p,g = params[0].data.clone(),tensor([0.1,0.2,0.3])\n", "opt = Lookahead(SGD(params, lr=0.1))\n", "for k in range(5): opt.step()\n", "#first 5 steps are normal SGD steps\n", "test_close(params[0], p - 0.5*g)\n", "#Since k=6, sixth step is a moving average of the 6 SGD steps with the initial weight\n", "opt.step()\n", "test_close(params[0], p * 0.5 + (p-0.6*g) * 0.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@delegates(RAdam)\n", "def ranger(p, lr, mom=0.95, wd=0.01, eps=1e-6, **kwargs):\n", " \"Convenience method for `Lookahead` with `RAdam`\"\n", " return Lookahead(RAdam(p, lr=lr, mom=mom, wd=wd, eps=eps, **kwargs))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## OptimWrapper -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def detuplify_pg(d):\n", " res = {}\n", " for k,v in d.items():\n", " if k == 
'params': continue\n", " if is_listy(v): res.update(**{f'{k}__{i}': v_ for i,v_ in enumerate(v)})\n", " else: res[k] = v\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = {'lr': 1e-2, 'mom': 0.9, 'params':[0,1,2]}\n", "test_eq(detuplify_pg(tst), {'lr': 1e-2, 'mom': 0.9})\n", "tst = {'lr': 1e-2, 'betas': (0.9,0.999), 'params':[0,1,2]}\n", "test_eq(detuplify_pg(tst), {'lr': 1e-2, 'betas__0': 0.9, 'betas__1': 0.999})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def set_item_pg(pg, k, v):\n", " if '__' not in k: pg[k] = v\n", " else:\n", " name,idx = k.split('__')\n", " pg[name] = tuple(v if i==int(idx) else pg[name][i] for i in range_of(pg[name]))\n", " return pg" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = {'lr': 1e-2, 'mom': 0.9, 'params':[0,1,2]}\n", "test_eq(set_item_pg(tst, 'lr', 1e-3), {'lr': 1e-3, 'mom': 0.9, 'params':[0,1,2]})\n", "tst = {'lr': 1e-2, 'betas': (0.9,0.999), 'params':[0,1,2]}\n", "test_eq(set_item_pg(tst, 'betas__0', 0.95), {'lr': 1e-2, 'betas': (0.95,0.999), 'params':[0,1,2]})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "pytorch_hp_map = {'momentum': 'mom', 'weight_decay': 'wd', 'alpha': 'sqr_mom', 'betas__0': 'mom', 'betas__1': 'sqr_mom'}" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class OptimWrapper(_BaseOptimizer, GetAttr):\n", " _xtra=['zero_grad', 'step', 'state_dict', 'load_state_dict']\n", " _default='opt'\n", " def __init__(self, opt, hp_map=None):\n", " self.opt = opt\n", " if hp_map is None: hp_map = pytorch_hp_map\n", " self.fwd_map = {k: hp_map[k] if k in hp_map else k for k in detuplify_pg(opt.param_groups[0]).keys()}\n", " self.bwd_map = {v:k for k,v in self.fwd_map.items()}\n", " self.state = defaultdict(dict, {})\n", " self.frozen_idx = 0\n", "\n", " @property\n", " def hypers(self):\n", " return [{self.fwd_map[k]:v for k,v in detuplify_pg(pg).items() if k != 'params'} for pg in self.opt.param_groups]\n", "\n", " def _set_hyper(self, k, v):\n", " for pg,v_ in zip(self.opt.param_groups,v): pg = set_item_pg(pg, self.bwd_map[k], v_)\n", "\n", " def clear_state(self): self.opt.state = defaultdict(dict, {})\n", "\n", " @property\n", " def param_lists(self): return [pg['params'] for pg in self.opt.param_groups]\n", " @param_lists.setter\n", " def param_lists(self, v):\n", " for pg,v_ in zip(self.opt.param_groups,v): pg['params'] = v_" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sgd = SGD([tensor([1,2,3])], lr=1e-3, mom=0.9, wd=1e-2)\n", "tst_sgd = OptimWrapper(torch.optim.SGD([tensor([1,2,3])], lr=1e-3, momentum=0.9, weight_decay=1e-2))\n", "#Access to param_groups\n", "test_eq(tst_sgd.param_lists, sgd.param_lists)\n", "#Set param_groups\n", "tst_sgd.param_lists = [[tensor([4,5,6])]]\n", "test_eq(tst_sgd.opt.param_groups[0]['params'], [tensor(4,5,6)])\n", "#Access to hypers\n", "test_eq(tst_sgd.hypers, [{**sgd.hypers[0], 'dampening': 0., 'nesterov': False}])\n", "#Set hypers\n", "tst_sgd.set_hyper('mom', 0.95)\n", "test_eq(tst_sgd.opt.param_groups[0]['momentum'], 0.95)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_sgd = OptimWrapper(torch.optim.SGD([{'params': [tensor([1,2,3])], 'lr': 1e-3}, \n", " {'params': [tensor([4,5,6])], 
'lr': 1e-2}], momentum=0.9, weight_decay=1e-2))\n", "sgd = SGD([[tensor([1,2,3])], [tensor([4,5,6])]], lr=[1e-3, 1e-2], mom=0.9, wd=1e-2)\n", "#Access to param_groups\n", "test_eq(tst_sgd.param_lists, sgd.param_lists)\n", "#Set param_groups\n", "tst_sgd.param_lists = [[tensor([4,5,6])], [tensor([1,2,3])]]\n", "test_eq(tst_sgd.opt.param_groups[0]['params'], [tensor(4,5,6)])\n", "test_eq(tst_sgd.opt.param_groups[1]['params'], [tensor(1,2,3)])\n", "#Access to hypers\n", "test_eq(tst_sgd.hypers, [{**sgd.hypers[i], 'dampening': 0., 'nesterov': False} for i in range(2)])\n", "#Set hypers\n", "tst_sgd.set_hyper('mom', 0.95)\n", "test_eq([pg['momentum'] for pg in tst_sgd.opt.param_groups], [0.95,0.95])\n", "tst_sgd.set_hyper('lr', [1e-4,1e-3])\n", "test_eq([pg['lr'] for pg in tst_sgd.opt.param_groups], [1e-4,1e-3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#check it works with tuply hp names like in Adam\n", "tst_adam = OptimWrapper(torch.optim.Adam([tensor([1,2,3])], lr=1e-2, betas=(0.9, 0.99)))\n", "test_eq(tst_adam.hypers, [{'lr': 0.01, 'mom': 0.9, 'sqr_mom': 0.99, 'eps': 1e-08, 'wd': 0, 'amsgrad': False}])\n", "tst_adam.set_hyper('mom', 0.95)\n", "test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.99))\n", "tst_adam.set_hyper('sqr_mom', 0.9)\n", "test_eq(tst_adam.opt.param_groups[0]['betas'], (0.95, 0.9))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def _mock_train(m, x, y, opt):\n", " m.train()\n", " for i in range(0, 100, 25):\n", " z = m(x[i:i+25])\n", " loss = F.mse_loss(z, y[i:i+25])\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Linear(4,5)\n", "x = torch.randn(100, 3, 4)\n", "y = torch.randn(100, 3, 5)\n", "try:\n", " torch.save(m.state_dict(), 'tmp.pth')\n", " wgt,bias = m.weight.data.clone(),m.bias.data.clone()\n", "\n", " m.load_state_dict(torch.load('tmp.pth'))\n", " opt1 = OptimWrapper(torch.optim.AdamW(m.parameters(), betas=(0.9, 0.99), eps=1e-5, weight_decay=1e-2))\n", " _mock_train(m, x.clone(), y.clone(), opt1)\n", " wgt1,bias1 = m.weight.data.clone(),m.bias.data.clone()\n", "\n", " m.load_state_dict(torch.load('tmp.pth'))\n", " opt2 = Adam(m.parameters(), 1e-3, wd=1e-2)\n", " _mock_train(m, x.clone(), y.clone(), opt2)\n", " wgt2,bias2 = m.weight.data.clone(),m.bias.data.clone()\n", " \n", " test_close(wgt1,wgt2,eps=1e-3)\n", " test_close(bias1,bias2,eps=1e-3)\n", "finally: os.remove('tmp.pth')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "m = nn.Linear(4,5)\n", "x = torch.randn(100, 3, 4)\n", "y = torch.randn(100, 3, 5)\n", "try:\n", " torch.save(m.state_dict(), 'tmp.pth')\n", " wgt,bias = m.weight.data.clone(),m.bias.data.clone()\n", "\n", " m.load_state_dict(torch.load('tmp.pth'))\n", " opt1 = OptimWrapper(torch.optim.Adam(m.parameters(), betas=(0.9, 0.99), eps=1e-5, weight_decay=1e-2))\n", " _mock_train(m, x.clone(), y.clone(), opt1)\n", " wgt1,bias1 = m.weight.data.clone(),m.bias.data.clone()\n", "\n", " m.load_state_dict(torch.load('tmp.pth'))\n", " opt2 = Adam(m.parameters(), 1e-3, wd=1e-2, decouple_wd=False)\n", " _mock_train(m, x.clone(), y.clone(), opt2)\n", " wgt2,bias2 = m.weight.data.clone(),m.bias.data.clone()\n", " \n", " test_close(wgt1,wgt2,eps=1e-3)\n", " test_close(bias1,bias2,eps=1e-3)\n", "finally: os.remove('tmp.pth')" ] }, { "cell_type": "markdown", "metadata": {}, 
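"source": [ "A quick usage note (just a sketch for illustration, not exported with the library): since `OptimWrapper` inherits `set_hyper` from `_BaseOptimizer`, the usual fastai `slice` syntax for discriminative learning rates should also work on a wrapped PyTorch optimizer, spreading the values across its parameter groups. The optimizer and values below are purely illustrative." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#A minimal sketch: slice-based discriminative lrs through OptimWrapper (illustrative values)\n", "tst = OptimWrapper(torch.optim.SGD([{'params': [tensor([1,2,3])]},\n", "                                    {'params': [tensor([4,5,6])]}], lr=1e-3, momentum=0.9))\n", "tst.set_hyper('lr', slice(1e-4, 1e-2))\n", "#The first group gets the lower lr, the last one the higher lr\n", "test_close([pg['lr'] for pg in tst.opt.param_groups], [1e-4, 1e-2])" ] }, { "cell_type": "markdown", "metadata": {},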
"source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", "Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 18b_callback.preds.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted dev-setup.ipynb.\n", "Converted index.ipynb.\n", "Converted quick_start.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import *\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }