{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"#skip\n",
"! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# default_exp optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"from fastai.torch_basics import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from nbdev.showdoc import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Optimizers\n",
"\n",
"> Define the general fastai optimizer and the variants"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `_BaseOptimizer` -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class _BaseOptimizer():\n",
" \"Common functionality between `Optimizer` and `OptimWrapper`\"\n",
" def all_params(self, n=slice(None), with_grad=False):\n",
" res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_lists[n],self.hypers[n]) for p in pg)\n",
" return L(o for o in res if hasattr(o[0], 'grad') and o[0].grad is not None) if with_grad else res\n",
"\n",
" def _set_require_grad(self, rg, p,pg,state,h): p.requires_grad_(rg or state.get('force_train', False))\n",
" def freeze_to(self, n):\n",
" self.frozen_idx = n if n >= 0 else len(self.param_lists) + n\n",
" if self.frozen_idx >= len(self.param_lists):\n",
" warn(f\"Freezing {self.frozen_idx} groups; model has {len(self.param_lists)}; whole model is frozen.\")\n",
" for o in self.all_params(slice(n, None)): self._set_require_grad(True, *o)\n",
" for o in self.all_params(slice(None, n)): self._set_require_grad(False, *o)\n",
"\n",
" def freeze(self):\n",
" assert(len(self.param_lists)>1)\n",
" self.freeze_to(-1)\n",
"\n",
" def set_freeze(self, n, rg, ignore_force_train=False):\n",
" for p in self.param_lists[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))\n",
"\n",
" def unfreeze(self): self.freeze_to(0)\n",
" def set_hypers(self, **kwargs): L(kwargs.items()).starmap(self.set_hyper)\n",
" def _set_hyper(self, k, v):\n",
" for v_,h in zip(v, self.hypers): h[k] = v_\n",
"\n",
" def set_hyper(self, k, v):\n",
" if isinstance(v, slice):\n",
" if v.start: v = even_mults(v.start, v.stop, len(self.param_lists))\n",
" else: v = [v.stop/10]*(len(self.param_lists)-1) + [v.stop]\n",
" v = L(v, use_list=None)\n",
" if len(v)==1: v = v*len(self.param_lists)\n",
" assert len(v) == len(self.hypers), f\"Trying to set {len(v)} values for {k} but there are {len(self.param_lists)} parameter groups.\"\n",
" self._set_hyper(k, v)\n",
"\n",
" @property\n",
" def param_groups(self): return [{**{'params': pg}, **hp} for pg,hp in zip(self.param_lists, self.hypers)]\n",
" @param_groups.setter\n",
" def param_groups(self, v):\n",
" for pg,v_ in zip(self.param_lists,v): pg = v_['params']\n",
" for hyper,v_ in zip(self.hypers,v):\n",
" for k,t in v_.items():\n",
" if k != 'params': hyper[k] = t"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"add_docs(_BaseOptimizer,\n",
"         all_params=\"List of param_groups, parameters, and hypers\",\n",
"         freeze_to=\"Freeze parameter groups up to `n`\",\n",
"         freeze=\"Freeze up to last parameter group\",\n",
"         set_freeze=\"Set `rg` for parameter group `n` only\",\n",
"         unfreeze=\"Unfreeze the entire model\",\n",
"         set_hypers=\"`set_hyper` for all `kwargs`\",\n",
"         set_hyper=\"Set the value(s) in `v` for hyper-parameter `k`\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def _update(state, new=None):\n",
" if new is None: return state\n",
" if isinstance(new, dict): state.update(new)\n",
" return state"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## `Optimizer` -"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# export\n",
"class Optimizer(_BaseOptimizer):\n",
" \"Base optimizer class for the fastai library, updating `params` with `cbs`\"\n",
" _keep_on_clear = ['force_train', 'do_wd']\n",
" def __init__(self, params, cbs, train_bn=True, **defaults):\n",
" params = L(params)\n",
" self.cbs,self.state,self.train_bn = L(cbs),defaultdict(dict),train_bn\n",
" defaults = merge(*self.cbs.attrgot('defaults'), defaults)\n",
" self.param_lists = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])\n",
" self.hypers = L({} for _ in range_of(self.param_lists))\n",
" self.set_hypers(**defaults)\n",
" self.frozen_idx = 0\n",
"\n",
" def zero_grad(self):\n",
" for p,*_ in self.all_params(with_grad=True):\n",
" p.grad.detach_()\n",
" p.grad.zero_()\n",
"\n",
" def step(self):\n",
" for p,pg,state,hyper in self.all_params(with_grad=True):\n",
" for cb in self.cbs: state = _update(state, cb(p, **{**state, **hyper}))\n",
" self.state[p] = state\n",
"\n",
" def clear_state(self):\n",
" for p,pg,state,hyper in self.all_params():\n",
" self.state[p] = {k: state[k] for k in self._keep_on_clear if k in state}\n",
"\n",
" def state_dict(self):\n",
" state = [self.state[p] for p,*_ in self.all_params()]\n",
" return {'state': state, 'hypers': self.hypers}\n",
"\n",
" def load_state_dict(self, sd):\n",
" assert len(sd[\"hypers\"]) == len(self.param_lists)\n",
" assert len(sd[\"state\"]) == sum([len(pg) for pg in self.param_lists])\n",
" self.hypers = sd['hypers']\n",
" self.state = {p: s for p,s in zip(self.all_params().itemgot(0), sd['state'])}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"add_docs(Optimizer,\n",
"         zero_grad=\"Standard PyTorch API: Zero all the grad attributes of the parameters\",\n",
"         step=\"Standard PyTorch API: Update the stats and execute the steppers on all parameters that have a grad\",\n",
"         state_dict=\"Return the state of the optimizer in a dictionary\",\n",
"         load_state_dict=\"Load the content of `sd`\",\n",
"         clear_state=\"Reset the state of the optimizer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Initializing an Optimizer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`params` will be used to create the `param_groups` of the optimizer. If it's a collection (or a generator) of parameters, it will be a `L` containing one `L` with all the parameters. To define multiple parameter groups `params` should be passed as a collection (or a generator) of `L`s.\n",
"\n",
"> Note: In PyTorch, model.parameters()
returns a generator with all the parameters, that you can directly pass to Optimizer
."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = Optimizer([1,2,3], noop)\n",
"test_eq(opt.param_lists, [[1,2,3]])\n",
"opt = Optimizer(range(3), noop)\n",
"test_eq(opt.param_lists, [[0,1,2]])\n",
"opt = Optimizer([[1,2],[3]], noop)\n",
"test_eq(opt.param_lists, [[1,2],[3]])\n",
"opt = Optimizer(([o,o+1] for o in range(0,4,2)), noop)\n",
"test_eq(opt.param_lists, [[0,1],[2,3]])"
]
},
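{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small added illustration (not one of the original tests), a generator such as `model.parameters()` is consumed into a single parameter group; the two-layer `nn.Sequential` below is just an arbitrary example model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = nn.Sequential(nn.Linear(2, 4), nn.Linear(4, 1))\n",
"opt = Optimizer(model.parameters(), noop)\n",
"# the generator is wrapped into one group holding all four parameter tensors\n",
"test_eq(len(opt.param_lists), 1)\n",
"test_eq(len(opt.param_lists[0]), len(list(model.parameters())))"
]
},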
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`cbs` is a list of functions that will be composed when applying the step. For instance, you can compose a function making the SGD step, with another one applying weight decay. Additionally, each `cb` can have a `defaults` attribute that contains hyper-parameters and their default value. Those are all gathered at initialization, and new values can be passed to override those defaults with the `defaults` kwargs. The steppers will be called by `Optimizer.step` (which is the standard PyTorch name), and gradients can be cleared with `Optimizer.zero_grad` (also a standard PyTorch name).\n",
"\n",
"Once the defaults have all been pulled off, they are copied as many times as there are `param_groups` and stored in `hypers`. To apply different hyper-parameters to different groups (differential learning rates, or no weight decay for certain layers for instance), you will need to adjust those values after the init. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tst_arg(p, lr=0, **kwargs): return p\n",
"tst_arg.defaults = dict(lr=1e-2)\n",
"\n",
"def tst_arg2(p, lr2=0, **kwargs): return p\n",
"tst_arg2.defaults = dict(lr2=1e-3)\n",
"\n",
"def tst_arg3(p, mom=0, **kwargs): return p\n",
"tst_arg3.defaults = dict(mom=0.9)\n",
"\n",
"def tst_arg4(p, **kwargs): return p\n",
"\n",
"opt = Optimizer([1,2,3], [tst_arg,tst_arg2, tst_arg3])\n",
"test_eq(opt.hypers, [{'lr2': 1e-3, 'mom': 0.9, 'lr': 1e-2}])\n",
"opt = Optimizer([1,2,3], tst_arg, lr=0.1)\n",
"test_eq(opt.hypers, [{'lr': 0.1}])\n",
"opt = Optimizer([[1,2],[3]], tst_arg)\n",
"test_eq(opt.hypers, [{'lr': 1e-2}, {'lr': 1e-2}])\n",
"opt = Optimizer([[1,2],[3]], tst_arg, lr=0.1)\n",
"test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.1}])"
]
},
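{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short sketch of that last point (added for illustration, reusing the `tst_arg` stepper defined above): after the init, `set_hyper` can assign a different value to each parameter group, e.g. for differential learning rates."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = Optimizer([[1,2],[3]], tst_arg)\n",
"opt.set_hyper('lr', [1e-3, 1e-2])  # one value per parameter group\n",
"test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-2}])"
]
},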
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each hyper-parameter, you can pass a slice or a collection to set them, if there are multiple parameter groups. A slice will be converted to a log-uniform collection from its beginning to its end, or if it only has an end `e`, to a collection of as many values as there are parameter groups that are `...,e/10,e/10,e`.\n",
"\n",
"Setting an hyper-parameter with a collection that has a different number of elements than the optimizer has parameter groups will raise an error."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt = Optimizer([[1,2],[3]], tst_arg, lr=[0.1,0.2])\n",
"test_eq(opt.hypers, [{'lr': 0.1}, {'lr': 0.2}])\n",
"opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-2))\n",
"test_eq(opt.hypers, [{'lr': 1e-3}, {'lr': 1e-3}, {'lr': 1e-2}])\n",
"opt = Optimizer([[1,2],[3],[4]], tst_arg, lr=slice(1e-4,1e-2))\n",
"test_eq(opt.hypers, [{'lr': 1e-4}, {'lr': 1e-3}, {'lr': 1e-2}])\n",
"test_eq(opt.param_groups, [{'params': [1,2], 'lr': 1e-4}, {'params': [3], 'lr': 1e-3}, {'params': [4], 'lr': 1e-2}])\n",
"test_fail(lambda: Optimizer([[1,2],[3],[4]], tst_arg, lr=np.array([0.1,0.2])))"
]
},
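{
"cell_type": "markdown",
"metadata": {},
"source": [
"The log-uniform values for a `slice(start, stop)` are produced by `even_mults` (used inside `set_hyper`); a quick check of the spacing, added here for illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# three groups between 1e-4 and 1e-2 get log-evenly spaced learning rates\n",
"test_close(even_mults(1e-4, 1e-2, 3), [1e-4, 1e-3, 1e-2])"
]
},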
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic steppers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To be able to give examples of optimizer steps, we will need some steppers, like the following:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def sgd_step(p, lr, **kwargs):\n",
" p.data.add_(p.grad.data, alpha=-lr)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tst_param(val, grad=None):\n",
" \"Create a tensor with `val` and a gradient of `grad` for testing\"\n",
" res = tensor([val]).float()\n",
" res.grad = tensor([val/10 if grad is None else grad]).float()\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p = tst_param(1., 0.1)\n",
"sgd_step(p, 1.)\n",
"test_eq(p, tensor([0.9]))\n",
"test_eq(p.grad, tensor([0.1]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def weight_decay(p, lr, wd, do_wd=True, **kwargs):\n",
" \"Weight decay as decaying `p` with `lr*wd`\"\n",
" if do_wd and wd!=0: p.data.mul_(1 - lr*wd)\n",
"\n",
"weight_decay.defaults = dict(wd=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p = tst_param(1., 0.1)\n",
"weight_decay(p, 1., 0.1)\n",
"test_eq(p, tensor([0.9]))\n",
"test_eq(p.grad, tensor([0.1]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"def l2_reg(p, lr, wd, do_wd=True, **kwargs):\n",
" \"L2 regularization as adding `wd*p` to `p.grad`\"\n",
" if do_wd and wd!=0: p.grad.data.add_(p.data, alpha=wd)\n",
"\n",
"l2_reg.defaults = dict(wd=0.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p = tst_param(1., 0.1)\n",
"l2_reg(p, 1., 0.1)\n",
"test_eq(p, tensor([1.]))\n",
"test_eq(p.grad, tensor([0.2]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Warning: Weight decay and L2 regularization is the same thing for basic SGD, but for more complex optimizers, they are very different."
]
},
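{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the warning concrete, here is a small check added for illustration: with a plain SGD step, the two formulations land on the same value (stateful optimizers such as Adam treat them differently)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"p1,p2 = tst_param(1., 0.1),tst_param(1., 0.1)\n",
"# true weight decay: shrink the weights, then take the gradient step\n",
"weight_decay(p1, lr=0.1, wd=0.01)\n",
"sgd_step(p1, lr=0.1)\n",
"# L2 regularization: add wd*p to the gradient, then take the gradient step\n",
"l2_reg(p2, lr=0.1, wd=0.01)\n",
"sgd_step(p2, lr=0.1)\n",
"test_close(p1, p2)"
]
},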
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Making the step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"
Optimizer.step
[source]Optimizer.step
()\n",
"\n",
"Standard PyTorch API: Update the stats and execute the steppers in on all parameters that have a grad"
],
"text/plain": [
"Optimizer.zero_grad
[source]Optimizer.zero_grad
()\n",
"\n",
"Standard PyTorch API: Zero all the grad attributes of the parameters"
],
"text/plain": [
"Optimizer.freeze
[source]Optimizer.freeze
()\n",
"\n",
"Freeze up to last parameter group"
],
"text/plain": [
"Optimizer.freeze_to
[source]Optimizer.freeze_to
(**`n`**)\n",
"\n",
"Freeze parameter groups up to `n`"
],
"text/plain": [
"Optimizer.unfreeze
[source]Optimizer.unfreeze
()\n",
"\n",
"Unfreeze the entire model"
],
"text/plain": [
"Optimizer.state_dict
[source]Optimizer.state_dict
()\n",
"\n",
"Return the state of the optimizer in a dictionary"
],
"text/plain": [
"Optimizer.load_state_dict
[source]Optimizer.load_state_dict
(**`sd`**)\n",
"\n",
"Load the content of `sd`"
],
"text/plain": [
"Optimizer.clear_state
[source]Optimizer.clear_state
()\n",
"\n",
"Reset the state of the optimizer"
],
"text/plain": [
"