{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp learner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.data.all import *\n", "from fastai2.optimizer import *\n", "from fastai2.callback.core import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_loop = ['Start Fit', 'before_fit', 'Start Epoch Loop', 'before_epoch', 'Start Train', 'before_train',\n", " 'Start Batch Loop', 'before_batch', 'after_pred', 'after_loss', 'before_backward', 'after_backward',\n", " 'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train',\n", " 'after_cancel_train', 'after_train', 'Start Valid', 'before_validate','Start Batch Loop',\n", " '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate',\n", " 'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit',\n", " 'after_cancel_fit', 'after_fit']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Learner\n", "\n", "> Basic class for handling the training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You probably want to jump directly to the definition of `Learner`." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Utils function" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#For tests\n", "from torch.utils.data import TensorDataset\n", "\n", "def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):\n", " \"A simple dataset where `x` is random and `y = a*x + b` plus some noise.\"\n", " def get_data(n):\n", " x = torch.randn(int(bs*n))\n", " return TensorDataset(x, a*x + b + 0.1*torch.randn(int(bs*n)))\n", " train_ds = get_data(n_train)\n", " valid_ds = get_data(n_valid)\n", " device = default_device() if cuda else None\n", " train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, num_workers=0)\n", " valid_dl = TfmdDL(valid_ds, bs=bs, num_workers=0)\n", " return DataLoaders(train_dl, valid_dl, device=device)\n", "\n", "class RegModel(Module):\n", " \"A r\"\n", " def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " def forward(self, x): return x*self.a + self.b" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "defaults.lr = 1e-3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def replacing_yield(o, attr, val):\n", " \"Context manager to temporarily replace an attribute\"\n", " old = getattr(o,attr)\n", " try: yield setattr(o,attr,val)\n", " finally: setattr(o,attr,old)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _A:\n", " def __init__(self, a): self.a = a\n", " @contextmanager\n", " def a_changed(self, v): return replacing_yield(self, 'a', v)\n", "\n", "a = _A(42)\n", "with a.a_changed(32):\n", " test_eq(a.a, 32)\n", "test_eq(a.a, 42)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mk_metric(m):\n", " \"Convert `m` to an `AvgMetric`, unless it's 
already a `Metric`\"\n", " return m if isinstance(m, Metric) else AvgMetric(m)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the class `Metric` below for more information." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def save_model(file, model, opt, with_opt=True, pickle_protocol=2):\n", " \"Save `model` to `file` along with `opt` (if available, and if `with_opt`)\"\n", " if rank_distrib(): return # don't save if child proc\n", " if opt is None: with_opt=False\n", " state = get_model(model).state_dict()\n", " if with_opt: state = {'model': state, 'opt':opt.state_dict()}\n", " torch.save(state, file, pickle_protocol=pickle_protocol)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`file` can be a `Path` object, a string or an opened file object. `pickle_protocol` is passed along to `torch.save`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def load_model(file, model, opt, with_opt=None, device=None, strict=True):\n", " \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n", " distrib_barrier()\n", " if isinstance(device, int): device = torch.device('cuda', device)\n", " elif device is None: device = 'cpu'\n", " state = torch.load(file, map_location=device)\n", " hasopt = set(state)=={'model', 'opt'}\n", " model_state = state['model'] if hasopt else state\n", " get_model(model).load_state_dict(model_state, strict=strict)\n", " if hasopt and ifnone(with_opt,True):\n", " try: opt.load_state_dict(state['opt'])\n", " except:\n", " if with_opt: warn(\"Could not load the optimizer state.\")\n", " elif with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`file` can be a `Path` object, a string or an opened file object. If a `device` is passed, the model is loaded on it, otherwise it's loaded on the CPU. 
\n", "\n", "If `strict` is `True`, the file must exactly contain weights for every parameter key in `model`, if `strict` is `False`, only the keys that are in the saved model are loaded in `model`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _try_concat(o):\n", " try: return torch.cat(o)\n", " except: return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_before_epoch = [event.before_fit, event.before_epoch]\n", "_after_epoch = [event.after_epoch, event.after_fit]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class _ConstantFunc():\n", " \"Returns a function that returns `o`\"\n", " def __init__(self, o): self.o = o\n", " def __call__(self, *args, **kwargs): return self.o" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learner -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@log_args(but='dls,model,opt_func,cbs')\n", "class Learner():\n", " def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,\n", " metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,\n", " moms=(0.95,0.85,0.95)):\n", " store_attr(self, \"dls,model,opt_func,lr,splitter,model_dir,wd,wd_bn_bias,train_bn,metrics,moms\")\n", " self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()\n", " if loss_func is None:\n", " loss_func = getattr(dls.train_ds, 'loss_func', None)\n", " assert loss_func is not None, \"Could not infer loss function from the data, please pass a loss function.\"\n", " self.loss_func = loss_func\n", " self.path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))\n", " self.add_cbs([(cb() if isinstance(cb, 
type) else cb) for cb in L(defaults.callbacks)+L(cbs)])\n", " self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)\n", "\n", " @property\n", " def metrics(self): return self._metrics\n", " @metrics.setter\n", " def metrics(self,v): self._metrics = L(v).map(mk_metric)\n", "\n", " def _grab_cbs(self, cb_cls): return L(cb for cb in self.cbs if isinstance(cb, cb_cls))\n", " def add_cbs(self, cbs): L(cbs).map(self.add_cb)\n", " def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)\n", " def add_cb(self, cb):\n", " old = getattr(self, cb.name, None)\n", " assert not old or isinstance(old, type(cb)), f\"self.{cb.name} already registered\"\n", " cb.learn = self\n", " setattr(self, cb.name, cb)\n", " self.cbs.append(cb)\n", " return self\n", "\n", " def remove_cb(self, cb):\n", " if isinstance(cb, type): self.remove_cbs(self._grab_cbs(cb))\n", " else:\n", " cb.learn = None\n", " if hasattr(self, cb.name): delattr(self, cb.name)\n", " if cb in self.cbs: self.cbs.remove(cb)\n", "\n", " @contextmanager\n", " def added_cbs(self, cbs):\n", " self.add_cbs(cbs)\n", " try: yield\n", " finally: self.remove_cbs(cbs)\n", "\n", " @contextmanager\n", " def removed_cbs(self, cbs):\n", " self.remove_cbs(cbs)\n", " try: yield self\n", " finally: self.add_cbs(cbs)\n", "\n", " def ordered_cbs(self, event): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, event)]\n", "\n", " def __call__(self, event_name): L(event_name).map(self._call_one)\n", "\n", " def _call_one(self, event_name):\n", " assert hasattr(event, event_name), event_name\n", " [cb(event_name) for cb in sort_by_run(self.cbs)]\n", "\n", " def _bn_bias_state(self, with_bias): return norm_bias_params(self.model, with_bias).map(self.opt.state)\n", " def create_opt(self):\n", " self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)\n", " if not self.wd_bn_bias:\n", " for p in self._bn_bias_state(True ): p['do_wd'] = False\n", " if self.train_bn:\n", " for p in self._bn_bias_state(False): p['force_train'] = True\n", 
"\n", " def _split(self, b):\n", " i = getattr(self.dls, 'n_inp', 1 if len(b)==1 else len(b)-1)\n", " self.xb,self.yb = b[:i],b[i:]\n", "\n", " def _step(self): self.opt.step()\n", " def _backward(self): self.loss.backward()\n", "\n", " def _with_events(self, f, event_type, ex, final=noop):\n", " try: self(f'before_{event_type}') ;f()\n", " except ex: self(f'after_cancel_{event_type}')\n", " finally: self(f'after_{event_type}') ;final()\n", "\n", " def all_batches(self):\n", " self.n_iter = len(self.dl)\n", " for o in enumerate(self.dl): self.one_batch(*o)\n", "\n", " def _do_one_batch(self):\n", " self.pred = self.model(*self.xb); self('after_pred')\n", " if len(self.yb) == 0: return\n", " self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')\n", " if not self.training: return\n", " self('before_backward')\n", " self._backward(); self('after_backward')\n", " self._step(); self('after_step')\n", " self.opt.zero_grad()\n", "\n", " def one_batch(self, i, b):\n", " self.iter = i\n", " self._split(b)\n", " self._with_events(self._do_one_batch, 'batch', CancelBatchException)\n", "\n", " def _do_epoch_train(self):\n", " self.dl = self.dls.train\n", " self._with_events(self.all_batches, 'train', CancelTrainException)\n", "\n", " def _do_epoch_validate(self, ds_idx=1, dl=None):\n", " if dl is None: dl = self.dls[ds_idx]\n", " self.dl = dl;\n", " with torch.no_grad(): self._with_events(self.all_batches, 'validate', CancelValidException)\n", "\n", " def _do_epoch(self):\n", " self._do_epoch_train()\n", " self._do_epoch_validate()\n", "\n", " def _do_fit(self):\n", " for epoch in range(self.n_epoch):\n", " self.epoch=epoch\n", " self._with_events(self._do_epoch, 'epoch', CancelEpochException)\n", "\n", " @log_args(but='cbs')\n", " def fit(self, n_epoch, lr=None, wd=None, cbs=None, reset_opt=False):\n", " with self.added_cbs(cbs):\n", " if reset_opt or not self.opt: self.create_opt()\n", " if wd is None: wd = self.wd\n", " if wd is not None: 
self.opt.set_hypers(wd=wd)\n", " self.opt.set_hypers(lr=self.lr if lr is None else lr)\n", " self.n_epoch,self.loss = n_epoch,tensor(0.)\n", " self._with_events(self._do_fit, 'fit', CancelFitException, self._end_cleanup)\n", "\n", " def _end_cleanup(self): self.dl,self.xb,self.yb,self.pred,self.loss = None,(None,),(None,),None,None\n", " def __enter__(self): self(_before_epoch); return self\n", " def __exit__(self, exc_type, exc_value, tb): self(_after_epoch)\n", "\n", " def validation_context(self, cbs=None, inner=False):\n", " cms = [self.no_logging(),self.no_mbar()]\n", " if cbs: cms.append(self.added_cbs(cbs))\n", " if not inner: cms.append(self)\n", " return ContextManagers(cms)\n", "\n", " def validate(self, ds_idx=1, dl=None, cbs=None):\n", " if dl is None: dl = self.dls[ds_idx]\n", " with self.validation_context(cbs=cbs): self._do_epoch_validate(ds_idx, dl)\n", " return getattr(self, 'final_record', None)\n", "\n", " @delegates(GatherPredsCallback.__init__)\n", " def get_preds(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, with_loss=False, act=None,\n", " inner=False, reorder=True, cbs=None, **kwargs):\n", " if dl is None: dl = self.dls[ds_idx].new(shuffled=False, drop_last=False)\n", " if reorder and hasattr(dl, 'get_idxs'):\n", " idxs = dl.get_idxs()\n", " dl = dl.new(get_idxs = _ConstantFunc(idxs))\n", " cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss, **kwargs)\n", " ctx_mgrs = self.validation_context(cbs=L(cbs)+[cb], inner=inner)\n", " if with_loss: ctx_mgrs.append(self.loss_not_reduced())\n", " with ContextManagers(ctx_mgrs):\n", " self._do_epoch_validate(dl=dl)\n", " if act is None: act = getattr(self.loss_func, 'activation', noop)\n", " res = cb.all_tensors()\n", " pred_i = 1 if with_input else 0\n", " if res[pred_i] is not None:\n", " res[pred_i] = act(res[pred_i])\n", " if with_decoded: res.insert(pred_i+2, getattr(self.loss_func, 'decodes', noop)(res[pred_i]))\n", " if reorder and hasattr(dl, 'get_idxs'): res = 
nested_reorder(res, tensor(idxs).argsort())\n", " return tuple(res)\n", " self._end_cleanup()\n", "\n", " def predict(self, item, rm_type_tfms=None, with_input=False):\n", " dl = self.dls.test_dl([item], rm_type_tfms=rm_type_tfms, num_workers=0)\n", " inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)\n", " i = getattr(self.dls, 'n_inp', -1)\n", " inp = (inp,) if i==1 else tuplify(inp)\n", " dec = self.dls.decode_batch(inp + tuplify(dec_preds))[0]\n", " dec_inp,dec_targ = map(detuplify, [dec[:i],dec[i:]])\n", " res = dec_targ,dec_preds[0],preds[0]\n", " if with_input: res = (dec_inp,) + res\n", " return res\n", "\n", " def show_results(self, ds_idx=1, dl=None, max_n=9, shuffle=True, **kwargs):\n", " if dl is None: dl = self.dls[ds_idx].new(shuffle=shuffle)\n", " b = dl.one_batch()\n", " _,_,preds = self.get_preds(dl=[b], with_decoded=True)\n", " self.dls.show_results(b, preds, max_n=max_n, **kwargs)\n", "\n", " def show_training_loop(self):\n", " indent = 0\n", " for s in _loop:\n", " if s.startswith('Start'): print(f'{\" \"*indent}{s}'); indent += 2\n", " elif s.startswith('End'): indent -= 2; print(f'{\" \"*indent}{s}')\n", " else: print(f'{\" \"*indent} - {s:15}:', self.ordered_cbs(s))\n", "\n", " @contextmanager\n", " def no_logging(self): return replacing_yield(self, 'logger', noop)\n", " @contextmanager\n", " def no_mbar(self): return replacing_yield(self, 'create_mbar', False)\n", "\n", " @contextmanager\n", " def loss_not_reduced(self):\n", " if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')\n", " else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))\n", "\n", " @delegates(save_model)\n", " def save(self, file, **kwargs):\n", " file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n", " save_model(file, self.model, getattr(self,'opt',None), **kwargs)\n", "\n", " @delegates(load_model)\n", " def load(self, file, with_opt=None, 
device=None, **kwargs):\n", " if device is None and hasattr(self.dls, 'device'): device = self.dls.device\n", " if self.opt is None: self.create_opt()\n", " file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n", " load_model(file, self.model, self.opt, device=device, **kwargs)\n", " return self\n", "\n", "Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "add_docs(Learner, \"Group together a `model`, some `dls` and a `loss_func` to handle training\",\n", " add_cbs=\"Add `cbs` to the list of `Callback` and register `self` as their learner\",\n", " add_cb=\"Add `cb` to the list of `Callback` and register `self` as their learner\",\n", " remove_cbs=\"Remove `cbs` from the list of `Callback` and deregister `self` as their learner\",\n", " remove_cb=\"Add `cb` from the list of `Callback` and deregister `self` as their learner\",\n", " added_cbs=\"Context manage that temporarily adds `cbs`\",\n", " removed_cbs=\"Context manage that temporarily removes `cbs`\",\n", " ordered_cbs=\"List of `Callback`s, in order, for an `event` in the training loop\",\n", " create_opt=\"Create an optimizer with default hyper-parameters\",\n", " one_batch=\"Train or evaluate `self.model` on batch `(xb,yb)`\",\n", " all_batches=\"Train or evaluate `self.model` on all the batches of `self.dl`\",\n", " fit=\"Fit `self.model` for `n_epoch` using `cbs`. 
Optionally `reset_opt`.\",\n", " validate=\"Validate on `dl` with potential new `cbs`.\",\n", " get_preds=\"Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`\",\n", " predict=\"Prediction on `item`, fully decoded, loss function decoded and probabilities\",\n", " validation_context=\"A `ContextManagers` suitable for validation, with optional `cbs`\",\n", " show_results=\"Show some predictions on `ds_idx`-th dataset or `dl`\",\n", " show_training_loop=\"Show each step in the training loop\",\n", " no_logging=\"Context manager to temporarily remove `logger`\",\n", " no_mbar=\"Context manager to temporarily prevent the master progress bar from being created\",\n", " loss_not_reduced=\"A context manager to evaluate `loss_func` with reduction set to none.\",\n", " save=\"Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`\",\n", " load=\"Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`\",\n", " __call__=\"Call `event_name` for all `Callback`s in `self.cbs`\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
class
Learner
[source]Learner
(**`dls`**, **`model`**, **`loss_func`**=*`None`*, **`opt_func`**=*`'Adam'`*, **`lr`**=*`0.001`*, **`splitter`**=*`'trainable_params'`*, **`cbs`**=*`None`*, **`metrics`**=*`None`*, **`path`**=*`None`*, **`model_dir`**=*`'models'`*, **`wd`**=*`None`*, **`wd_bn_bias`**=*`False`*, **`train_bn`**=*`True`*, **`moms`**=*`(0.95, 0.85, 0.95)`*)\n",
"\n",
"Group together a `model`, some `dls` and a `loss_func` to handle training"
],
"text/plain": [
"Learner.fit
[source]Learner.fit
(**`n_epoch`**, **`lr`**=*`None`*, **`wd`**=*`None`*, **`cbs`**=*`None`*, **`reset_opt`**=*`False`*)\n",
"\n",
"Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`."
],
"text/plain": [
"Learner.one_batch
[source]Learner.one_batch
(**`i`**, **`b`**)\n",
"\n",
"Train or evaluate `self.model` on batch `(xb,yb)`"
],
"text/plain": [
"Learner.all_batches
[source]Learner.all_batches
()\n",
"\n",
"Train or evaluate `self.model` on all the batches of `self.dl`"
],
"text/plain": [
"Learner.create_opt
[source]Learner.create_opt
()\n",
"\n",
"Create an optimizer with default hyper-parameters"
],
"text/plain": [
"Learner.save
[source]Learner.save
(**`file`**, **`with_opt`**=*`True`*, **`pickle_protocol`**=*`2`*)\n",
"\n",
"Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`"
],
"text/plain": [
"Learner.load
[source]Learner.load
(**`file`**, **`with_opt`**=*`None`*, **`device`**=*`None`*, **`strict`**=*`True`*)\n",
"\n",
"Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`"
],
"text/plain": [
"Learner.__call__
[source]Learner.__call__
(**`event_name`**)\n",
"\n",
"Call `event_name` for all [`Callback`](/callback.core#Callback)s in `self.cbs`"
],
"text/plain": [
"Learner.add_cb
[source]Learner.add_cb
(**`cb`**)\n",
"\n",
"Add `cb` to the list of [`Callback`](/callback.core#Callback) and register `self` as their learner"
],
"text/plain": [
"Learner.add_cbs
[source]Learner.add_cbs
(**`cbs`**)\n",
"\n",
"Add `cbs` to the list of [`Callback`](/callback.core#Callback) and register `self` as their learner"
],
"text/plain": [
"Learner.added_cbs
[source]Learner.added_cbs
(**`cbs`**)\n",
"\n",
"Context manager that temporarily adds `cbs`"
],
"text/plain": [
"Learner.ordered_cbs
[source]Learner.ordered_cbs
(**`event`**)\n",
"\n",
"List of [`Callback`](/callback.core#Callback)s, in order, for an `event` in the training loop"
],
"text/plain": [
"Learner.remove_cb
[source]Learner.remove_cb
(**`cb`**)\n",
"\n",
"Remove `cb` from the list of [`Callback`](/callback.core#Callback) and deregister `self` as their learner"
],
"text/plain": [
"Learner.remove_cbs
[source]Learner.remove_cbs
(**`cbs`**)\n",
"\n",
"Remove `cbs` from the list of [`Callback`](/callback.core#Callback) and deregister `self` as their learner"
],
"text/plain": [
"Learner.removed_cbs
[source]Learner.removed_cbs
(**`cbs`**)\n",
"\n",
"Context manager that temporarily removes `cbs`"
],
"text/plain": [
"Learner.show_training_loop
[source]Learner.show_training_loop
()\n",
"\n",
"Show each step in the training loop"
],
"text/plain": [
"class
Metric
[source]\n",
"\n",
"> Metric
()\n",
"\n",
"Blueprint for defining a metric"
],
"text/plain": [
"Metric
has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"Metric.reset
[source]Metric.reset
()\n",
"\n",
"Reset inner state to prepare for new computation"
],
"text/plain": [
"Metric.accumulate
[source]Metric.accumulate
(**`learn`**)\n",
"\n",
"Use `learn` to update the state with new results"
],
"text/plain": [
"Metric.value
[source]Metric.name
[source]class
AvgMetric
[source]AvgMetric
(**`func`**) :: [`Metric`](/learner#Metric)\n",
"\n",
"Average the values of `func` taking into account potentially different batch sizes"
],
"text/plain": [
"class
AvgLoss
[source]AvgLoss
() :: [`Metric`](/learner#Metric)\n",
"\n",
"Average the losses taking into account potentially different batch sizes"
],
"text/plain": [
"class
AvgSmoothLoss
[source]AvgSmoothLoss
(**`beta`**=*`0.98`*) :: [`Metric`](/learner#Metric)\n",
"\n",
"Smooth average of the losses (exponentially weighted with `beta`)"
],
"text/plain": [
"class
ValueMetric
[source]ValueMetric
(**`func`**, **`metric_name`**=*`None`*) :: [`Metric`](/learner#Metric)\n",
"\n",
"Use to include a pre-calculated metric value (for instance calculated in a [`Callback`](/callback.core#Callback)) and returned by `func`"
],
"text/plain": [
"Recorder.before_fit
[source]\n",
"\n",
"> Recorder.before_fit
()\n",
"\n",
"Prepare state for training"
],
"text/plain": [
"Recorder.before_epoch
[source]Recorder.before_epoch
()\n",
"\n",
"Set timer if `self.add_time=True`"
],
"text/plain": [
"Recorder.before_validate
[source]Recorder.before_validate
()\n",
"\n",
"Reset loss and metrics state"
],
"text/plain": [
"Recorder.after_batch
[source]Recorder.after_batch
()\n",
"\n",
"Update all metrics and records lr and smooth loss in training"
],
"text/plain": [
"Recorder.after_epoch
[source]Recorder.after_epoch
()\n",
"\n",
"Store and log the loss/metric values"
],
"text/plain": [
"Recorder.plot_loss
[source]Recorder.plot_loss
(**`skip_start`**=*`5`*, **`with_valid`**=*`True`*)\n",
"\n",
"Plot the losses from `skip_start` and onward"
],
"text/plain": [
"