{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp learner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.data.all import *\n", "from local.optimizer import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Learner\n", "\n", "> Basic class for handling the training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use the following for testing purposes (a basic linear regression problem):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import TensorDataset\n", "\n", "def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):\n", " def get_data(n):\n", " x = torch.randn(int(bs*n))\n", " return TensorDataset(x, a*x + b + 0.1*torch.randn(int(bs*n)))\n", " train_ds = get_data(n_train)\n", " valid_ds = get_data(n_valid)\n", " tfms = Cuda() if cuda else None\n", " train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, after_batch=tfms, num_workers=0)\n", " valid_dl = TfmdDL(valid_ds, bs=bs, after_batch=tfms, num_workers=0)\n", " return DataBunch(train_dl, valid_dl)\n", "\n", "class RegModel(Module):\n", " def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " def forward(self, x): return x*self.a + self.b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callback - " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Callback(GetAttr):\n", " \"Basic class handling tweaks of the training loop by changing a `Learner` in various events\"\n", " _default,learn = 'learn',None\n", " def __repr__(self): return type(self).__name__\n", "\n", " def __call__(self, event_name):\n", " \"Call `self.{event_name}` if it's defined\"\n", " getattr(self, event_name, noop)()\n", "\n", " @property\n", " def name(self):\n", " \"Name of the `Callback`, camel-cased and with '*Callback*' removed\"\n", " return class2attr(self, 'Callback')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training loop is defined in `Learner` a bit below and consists in a minimal set of instructions: looping through the data we:\n", "- compute the output of the model from the input\n", "- calculate a loss between this output and the desired target\n", "- compute the gradients of this loss with respect to all the model parameters\n", "- update the parameters accordingly\n", "- zero all the gradients\n", "\n", "Any tweak of this training loop is defined in a `Callback` to avoid over-complicating the code of the training loop, and to make it easy to mix and match different techniques (since they'll be defined in different callbacks). 
A callback can implement actions on the following events:\n", "\n", "- `begin_fit`: called before doing anything, ideal for initial setup.\n", "- `begin_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.\n", "- `begin_train`: called at the beginning of the training part of an epoch.\n", "- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes in the model (change of the input with techniques like mixup for instance).\n", "- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.\n", "- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).\n", "- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).\n", "- `after_step`: called after the step and before the gradients are zeroed.\n", "- `after_batch`: called at the end of a batch, for any clean-up before the next one.\n", "- `after_train`: called at the end of the training phase of an epoch.\n", "- `begin_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.\n", "- `after_validate`: called at the end of the validation part of an epoch.\n", "- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.\n", "- `after_fit`: called at the end of training, for final clean-up." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Callback.__call__[source]

\n", "\n", "> Callback.__call__(**`event_name`**)\n", "\n", "Call `self.{event_name}` if it's defined" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.__call__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst_cb = Callback()\n", "tst_cb.call_me = lambda: print(\"maybe\")\n", "test_stdout(lambda: tst_cb(\"call_me\"), \"maybe\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GetAttr.__getattr__[source]

\n", "\n", "> GetAttr.__getattr__(**`k`**)\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.__getattr__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a shortcut to avoid having to write `self.learn.bla` for any `bla` attribute we seek, and just write `self.bla`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mk_class('TstLearner', 'a')\n", "\n", "class TstCallback(Callback):\n", " def batch_begin(self): print(self.a)\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "test_stdout(lambda: cb('batch_begin'), \"1\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that it only works to get the value of the attribute, if you want to change it, you have to manually access it with `self.learn.bla`. In the example below, `self.a += 1` creates an `a` attribute of 2 in the callback instead of setting the `a` of the learner to 2:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TstCallback(Callback):\n", " def batch_begin(self): self.a += 1\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "cb('batch_begin')\n", "test_eq(cb.a, 2)\n", "test_eq(cb.learn.a, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A proper version needs to write `self.learn.a = self.a + 1`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TstCallback(Callback):\n", " def batch_begin(self): self.learn.a = self.a + 1\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "cb('batch_begin')\n", "test_eq(cb.learn.a, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Callback.name[source]

\n", "\n", "Name of the [`Callback`](/learner.html#Callback), camel-cased and with '*Callback*' removed" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.name, name='Callback.name')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(TstCallback().name, 'tst')\n", "class ComplicatedNameCallback(Callback): pass\n", "test_eq(ComplicatedNameCallback().name, 'complicated_name')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TrainEvalCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TrainEvalCallback(Callback):\n", " \"`Callback` that tracks the number of iterations done and properly sets training/eval mode\"\n", " def begin_fit(self):\n", " \"Set the iter and epoch counters to 0, put the model and the right device\"\n", " self.learn.train_iter,self.learn.pct_train = 0,0.\n", " self.model.to(self.dbunch.device)\n", "\n", " def after_batch(self):\n", " \"Update the iter counter (in training mode)\"\n", " if not self.training: return\n", " self.learn.pct_train += 1./(self.n_iter*self.n_epoch)\n", " self.learn.train_iter += 1\n", "\n", " def begin_train(self):\n", " \"Set the model in training mode\"\n", " self.learn.pct_train=self.epoch/self.n_epoch\n", " self.model.train()\n", " self.learn.training=True\n", "\n", " def begin_validate(self):\n", " \"Set the model in validation mode\"\n", " self.model.eval()\n", " self.learn.training=False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class TrainEvalCallback[source]

\n", "\n", "> TrainEvalCallback() :: [`Callback`](/learner.html#Callback)\n", "\n", "[`Callback`](/learner.html#Callback) that tracks the number of iterations done and properly sets training/eval mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This `Callback` is automatically added in every `Learner` at initialization." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test of the TrainEvalCallback below in Learner.fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_fit[source]

\n", "\n", "> TrainEvalCallback.begin_fit()\n", "\n", "Set the iter and epoch counters to 0, put the model and the right device" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_fit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.after_batch[source]

\n", "\n", "> TrainEvalCallback.after_batch()\n", "\n", "Update the iter counter (in training mode)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.after_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_train[source]

\n", "\n", "> TrainEvalCallback.begin_train()\n", "\n", "Set the model in training mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_validate[source]

\n", "\n", "> TrainEvalCallback.begin_validate()\n", "\n", "Set the model in validation mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_validate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GatherPredsCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class GatherPredsCallback(Callback):\n", " \"`Callback` that saves the predictions and targets, optionally `with_loss`\"\n", " def __init__(self, with_input=False, with_loss=False): store_attr(self, \"with_input,with_loss\")\n", "\n", " def begin_batch(self):\n", " if self.with_input: self.inputs.append((to_detach(self.xb)))\n", "\n", " def begin_validate(self):\n", " \"Initialize containers\"\n", " self.preds,self.targets = [],[]\n", " if self.with_input: self.inputs=[]\n", " if self.with_loss: self.losses = []\n", "\n", " def after_batch(self):\n", " \"Save predictions, targets and potentially losses\"\n", " self.preds.append(to_detach(self.pred))\n", " self.targets.append(to_detach(self.yb))\n", " if self.with_loss: \n", " bs = find_bs(self.yb)\n", " loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)\n", " self.losses.append(to_detach(loss))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class GatherPredsCallback[source]

\n", "\n", "> GatherPredsCallback(**`with_input`**=*`False`*, **`with_loss`**=*`False`*) :: [`Callback`](/learner.html#Callback)\n", "\n", "[`Callback`](/learner.html#Callback) that saves the predictions and targets, optionally `with_loss`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GatherPredsCallback.begin_validate[source]

\n", "\n", "> GatherPredsCallback.begin_validate()\n", "\n", "Initialize containers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback.begin_validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GatherPredsCallback.after_batch[source]

\n", "\n", "> GatherPredsCallback.after_batch()\n", "\n", "Save predictions, targets and potentially losses" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback.after_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callbacks control flow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It happens that we may want to skip some of the steps of the training loop: in gradient accumulation, we don't aways want to do the step/zeroing of the grads for instance. During an LR finder test, we don't want to do the validation phase of an epoch. Or if we're training with a strategy of early stopping, we want to be able to completely interrupt the training loop.\n", "\n", "This is made possible by raising specific exceptions the training loop will look for (and properly catch)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_ex_docs = dict(\n", " CancelFitException=\"Skip the rest of this batch and go to `after_batch`\",\n", " CancelEpochException=\"Skip the rest of the training part of the epoch and go to `after_train`\",\n", " CancelTrainException=\"Skip the rest of the validation part of the epoch and go to `after_validate`\",\n", " CancelValidException=\"Skip the rest of this epoch and go to `after_epoch`\",\n", " CancelBatchException=\"Interrupts training and go to `after_fit`\")\n", "\n", "for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelBatchException[source]

\n", "\n", "> CancelBatchException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Interrupts training and go to `after_fit`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelBatchException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelTrainException[source]

\n", "\n", "> CancelTrainException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of the validation part of the epoch and go to `after_validate`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelTrainException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelValidException[source]

\n", "\n", "> CancelValidException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of this epoch and go to `after_epoch`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelValidException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelEpochException[source]

\n", "\n", "> CancelEpochException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of the training part of the epoch and go to `after_train`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelEpochException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelFitException[source]

\n", "\n", "> CancelFitException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of this batch and go to `after_batch`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelFitException, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can detect one of those exceptions occurred and add code that executes right after with the following events:\n", "- `after_cancel_batch`: reached imediately after a `CancelBatchException` before proceeding to `after_batch`\n", "- `after_cancel_train`: reached imediately after a `CancelTrainException` before proceeding to `after_epoch`\n", "- `after_cancel_valid`: reached imediately after a `CancelValidException` before proceeding to `after_epoch`\n", "- `after_cancel_epoch`: reached imediately after a `CancelEpochException` before proceeding to `after_epoch`\n", "- `after_cancel_fit`: reached imediately after a `CancelFitException` before proceeding to `after_fit`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "_events = L.split('begin_fit begin_epoch begin_train begin_batch after_pred after_loss \\\n", " after_backward after_step after_cancel_batch after_batch after_cancel_train \\\n", " after_train begin_validate after_cancel_validate after_validate after_cancel_epoch \\\n", " after_epoch after_cancel_fit after_fit')\n", "\n", "mk_class('event', **_events.map_dict(),\n", " doc=\"All possible events as attributes to get tab-completion and typo-proofing\")\n", "\n", "_before_epoch = [event.begin_fit, event.begin_epoch]\n", "_after_epoch = [event.after_epoch, event.after_fit]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "_all_ = ['event']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class event[source]

\n", "\n", "> event(**\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", "All possible events as attributes to get tab-completion and typo-proofing" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(event, name='event', title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(event.after_backward, 'after_backward')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's the full list: *begin_fit begin_epoch begin_train begin_batch after_pred after_loss after_backward after_step after_cancel_batch after_batch after_cancel_train after_train begin_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit*." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Full test of the control flow below, after the Learner class" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learner -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "defaults.lr = slice(3e-3)\n", "defaults.wd = 1e-2\n", "defaults.callbacks = [TrainEvalCallback]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def replacing_yield(o, attr, val):\n", " \"Context manager to temporarily replace an attribute\"\n", " old = getattr(o,attr)\n", " try: yield setattr(o,attr,val)\n", " finally: setattr(o,attr,old)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def mk_metric(m):\n", " \"Convert `m` to an `AvgMetric`, unless it's already a `Metric`\"\n", " return m if isinstance(m, Metric) else AvgMetric(m)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def save_model(file, model, opt, with_opt=True):\n", " \"Save `model` to `file` along with `opt` (if available, and if `with_opt`)\"\n", " if opt is None: with_opt=False\n", " state = get_model(model).state_dict()\n", " if with_opt: state = {'model': state, 'opt':opt.state_dict()}\n", " torch.save(state, file)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def load_model(file, model, opt, with_opt=None, device=None, strict=True):\n", " \"Load `model` from `file` along with `opt` (if available, and if `with_opt`)\"\n", " if isinstance(device, int): device = torch.device('cuda', device)\n", " elif device is None: device = 'cpu'\n", " state = torch.load(file, map_location=device)\n", " hasopt = set(state)=={'model', 'opt'}\n", " model_state = state['model'] if hasopt else state\n", " get_model(model).load_state_dict(model_state, strict=strict)\n", " if hasopt and ifnone(with_opt,True):\n", " try: opt.load_state_dict(state['opt'])\n", " except:\n", " if with_opt: warn(\"Could not load the optimizer state.\")\n", " elif with_opt: warn(\"Saved filed doesn't contain an optimizer state.\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def _try_concat(o):\n", " try:\n", " return torch.cat(o)\n", " except:\n", " return sum([L(o_[i,:] for i in range_of(o_)) for o_ in o], L())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Learner():\n", " def __init__(self, dbunch, model, loss_func=None, opt_func=Adam, lr=defaults.lr, 
splitter=trainable_params, cbs=None,\n", " cb_funcs=None, metrics=None, path=None, model_dir='models', wd_bn_bias=False, train_bn=True):\n", " store_attr(self, \"dbunch,model,opt_func,lr,splitter,model_dir,wd_bn_bias,train_bn,metrics\")\n", " self.training,self.logger,self.opt,self.cbs = False,print,None,L()\n", " #TODO: infer loss_func from data\n", " if loss_func is None:\n", " loss_func = getattr(dbunch.train_ds, 'loss_func', None)\n", " assert loss_func is not None, \"Could not infer loss function from the data, please pass a loss function.\"\n", " self.loss_func = loss_func\n", " self.path = path if path is not None else getattr(dbunch, 'path', Path('.'))\n", " self.add_cbs(cbf() for cbf in L(defaults.callbacks)+L(cb_funcs))\n", " self.add_cbs(cbs)\n", " self.model.to(self.dbunch.device)\n", "\n", " @property\n", " def metrics(self): return self._metrics\n", " @metrics.setter\n", " def metrics(self,v): self._metrics = L(v).map(mk_metric)\n", " \n", " def add_cbs(self, cbs): L(cbs).map(self.add_cb)\n", " def remove_cbs(self, cbs): L(cbs).map(self.remove_cb)\n", " def add_cb(self, cb):\n", " old = getattr(self, cb.name, None)\n", " assert not old or isinstance(old, type(cb)), f\"self.{cb.name} already registered\"\n", " cb.learn = self\n", " setattr(self, cb.name, cb)\n", " self.cbs.append(cb)\n", " return self\n", "\n", " def remove_cb(self, cb):\n", " cb.learn = None\n", " if hasattr(self, cb.name): delattr(self, cb.name)\n", " if cb in self.cbs: self.cbs.remove(cb)\n", "\n", " @contextmanager\n", " def added_cbs(self, cbs):\n", " self.add_cbs(cbs)\n", " yield\n", " self.remove_cbs(cbs)\n", " \n", " def ordered_cbs(self, cb_func:str): return [cb for cb in sort_by_run(self.cbs) if hasattr(cb, cb_func)]\n", "\n", " def __call__(self, event_name): L(event_name).map(self._call_one)\n", " def _call_one(self, event_name):\n", " assert hasattr(event, event_name)\n", " [cb(event_name) for cb in sort_by_run(self.cbs)]\n", "\n", " def _bn_bias_state(self, with_bias): return bn_bias_params(self.model, with_bias).map(self.opt.state)\n", " def create_opt(self):\n", " self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)\n", " if not self.wd_bn_bias:\n", " for p in self._bn_bias_state(False): p['do_wd'] = False\n", " if self.train_bn:\n", " for p in self._bn_bias_state(True ): p['force_train'] = True\n", "\n", " def _split(self, b):\n", " i = getattr(self.dbunch, 'n_inp', 1 if len(b)==1 else len(b)-1)\n", " self.xb,self.yb = b[:i],b[i:]\n", "\n", " def all_batches(self):\n", " self.n_iter = len(self.dl)\n", " for o in enumerate(self.dl): self.one_batch(*o)\n", "\n", " def one_batch(self, i, b):\n", " self.iter = i\n", " try:\n", " self._split(b); self('begin_batch')\n", " self.pred = self.model(*self.xb); self('after_pred')\n", " if len(self.yb) == 0: return\n", " self.loss = self.loss_func(self.pred, *self.yb); self('after_loss')\n", " if not self.training: return\n", " self.loss.backward(); self('after_backward')\n", " self.opt.step(); self('after_step')\n", " self.opt.zero_grad()\n", " except CancelBatchException: self('after_cancel_batch')\n", " finally: self('after_batch')\n", "\n", " def _do_begin_fit(self, n_epoch):\n", " self.n_epoch,self.loss = n_epoch,tensor(0.); self('begin_fit')\n", "\n", " def _do_epoch_train(self):\n", " try:\n", " self.dl = self.dbunch.train_dl; self('begin_train')\n", " self.all_batches()\n", " except CancelTrainException: self('after_cancel_train')\n", " finally: self('after_train')\n", "\n", " def _do_epoch_validate(self, ds_idx=1, dl=None):\n", " if dl 
is None: dl = self.dbunch.dls[ds_idx]\n", " try:\n", " self.dl = dl; self('begin_validate')\n", " with torch.no_grad(): self.all_batches()\n", " except CancelValidException: self('after_cancel_validate')\n", " finally: self('after_validate')\n", "\n", " def fit(self, n_epoch, lr=None, wd=defaults.wd, cbs=None, reset_opt=False):\n", " with self.added_cbs(cbs):\n", " if reset_opt or not self.opt: self.create_opt()\n", " self.opt.set_hypers(wd=wd, lr=self.lr if lr is None else lr)\n", "\n", " try:\n", " self._do_begin_fit(n_epoch)\n", " for epoch in range(n_epoch):\n", " try:\n", " self.epoch=epoch; self('begin_epoch')\n", " self._do_epoch_train()\n", " self._do_epoch_validate()\n", " except CancelEpochException: self('after_cancel_epoch')\n", " finally: self('after_epoch')\n", "\n", " except CancelFitException: self('after_cancel_fit')\n", " finally: self('after_fit')\n", "\n", " def validate(self, ds_idx=1, dl=None, cbs=None):\n", " self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)\n", " if dl is None: dl = self.dbunch.dls[ds_idx]\n", " with self.added_cbs(cbs), self.no_logging():\n", " self(_before_epoch)\n", " self._do_epoch_validate(ds_idx, dl)\n", " self(_after_epoch)\n", " return self.recorder.values[-1]\n", "\n", " def get_preds(self, ds_idx=1, dl=None, with_input=False, with_loss=False, with_decoded=False, act=None):\n", " self.epoch,self.n_epoch,self.loss = 0,1,tensor(0.)\n", " cb = GatherPredsCallback(with_input=with_input, with_loss=with_loss)\n", " with self.no_logging(), self.added_cbs(cb), self.loss_not_reduced():\n", " self(_before_epoch)\n", " self._do_epoch_validate(ds_idx, dl)\n", " self(_after_epoch)\n", " if act is None: act = getattr(self.loss_func, 'activation', noop)\n", " preds = act(torch.cat(cb.preds))\n", " res = (preds, detuplify(tuple(torch.cat(o) for o in zip(*cb.targets))))\n", " if with_decoded: res = res + (getattr(self.loss_func, 'decodes', noop)(preds),)\n", " if with_input: res = (tuple(_try_concat(o) for o in zip(*cb.inputs)),) + res\n", " if with_loss: res = res + (torch.cat(cb.losses),)\n", " return res\n", "\n", " def predict(self, item, rm_type_tfms=0):\n", " dl = test_dl(self.dbunch, [item], rm_type_tfms=rm_type_tfms)\n", " inp,preds,_ = self.get_preds(dl=dl, with_input=True)\n", " dec_preds = getattr(self.loss_func, 'decodes', noop)(preds)\n", " i = getattr(self.dbunch, 'n_inp', -1)\n", " full_dec = self.dbunch.decode_batch((*inp,dec_preds))[0][i:]\n", " return detuplify(full_dec),dec_preds[0],preds[0]\n", "\n", " def show_results(self, ds_idx=0, dl=None, max_n=10, **kwargs):\n", " if dl is None: dl = self.dbunch.dls[ds_idx]\n", " b = dl.one_batch()\n", " _,_,preds = self.get_preds(dl=[b], with_decoded=True)\n", " self.dbunch.show_results(b, preds, max_n=max_n, **kwargs)\n", " \n", " def show_training_loop(self):\n", " loop = ['Start Fit', 'begin_fit', 'Start Epoch Loop', 'begin_epoch', 'Start Train', 'begin_train', \n", " 'Start Batch Loop', 'begin_batch', 'after_pred', 'after_loss', 'after_backward', \n", " 'after_step', 'after_cancel_batch', 'after_batch','End Batch Loop','End Train', \n", " 'after_cancel_train', 'after_train', 'Start Valid', 'begin_validate','Start Batch Loop', \n", " '**CBs same as train batch**', 'End Batch Loop', 'End Valid', 'after_cancel_validate', \n", " 'after_validate', 'End Epoch Loop', 'after_cancel_epoch', 'after_epoch', 'End Fit', \n", " 'after_cancel_fit', 'after_fit']\n", " indent = 0\n", " for s in loop:\n", " if s.startswith('Start'): print(f'{\" \"*indent}{s}'); indent += 2\n", " elif s.startswith('End'): 
indent -= 2; print(f'{\" \"*indent}{s}')\n", " else: print(f'{\" \"*indent} - {s:15}:', self.ordered_cbs(s))\n", "\n", " @contextmanager\n", " def no_logging(self): return replacing_yield(self, 'logger', noop)\n", "\n", " @contextmanager\n", " def loss_not_reduced(self):\n", " if hasattr(self.loss_func, 'reduction'): return replacing_yield(self.loss_func, 'reduction', 'none')\n", " else: return replacing_yield(self, 'loss_func', partial(self.loss_func, reduction='none'))\n", "\n", " def save(self, file, with_opt=True):\n", " if rank_distrib(): return # don't save if slave proc\n", " file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n", " save_model(file, self.model, getattr(self,'opt',None), with_opt)\n", "\n", " def load(self, file, with_opt=None, device=None, strict=True):\n", " if device is None: device = self.dbunch.device\n", " if self.opt is None: self.create_opt()\n", " file = join_path_file(file, self.path/self.model_dir, ext='.pth')\n", " load_model(file, self.model, self.opt, with_opt=with_opt, device=device, strict=strict)\n", " return self\n", "\n", "Learner.x,Learner.y = add_props(lambda i,x: detuplify((x.xb,x.yb)[i]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "add_docs(Learner, \"Group together a `model`, some `dbunch` and a `loss_func` to handle training\",\n", " add_cbs=\"Add `cbs` to the list of `Callback` and register `self` as their learner\",\n", " add_cb=\"Add `cb` to the list of `Callback` and register `self` as their learner\",\n", " remove_cbs=\"Remove `cbs` from the list of `Callback` and deregister `self` as their learner\",\n", " remove_cb=\"Remove `cb` from the list of `Callback` and deregister `self` as their learner\",\n", " added_cbs=\"Context manager that temporarily adds `cbs`\",\n", " ordered_cbs=\"Return a list of `Callback` for one step `cb_func` in the training loop\",\n", " create_opt=\"Create an optimizer with `lr`\",\n", " one_batch=\"Train or evaluate `self.model` on batch `(xb,yb)`\",\n", " all_batches=\"Train or evaluate `self.model` on all batches of `self.dl`\",\n", " fit=\"Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.\",\n", " validate=\"Validate on `dl` with potential new `cbs`.\",\n", " get_preds=\"Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`\",\n", " predict=\"Return the prediction on `item`, fully decoded, loss function decoded and probabilities\",\n", " show_results=\"Show some predictions on `ds_idx`-th dbunchset or `dl`\",\n", " show_training_loop=\"Show each step in the training loop\",\n", " no_logging=\"Context manager to temporarily remove `logger`\",\n", " loss_not_reduced=\"A context manager to evaluate `loss_func` with reduction set to none.\",\n", " save=\"Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`\",\n", " load=\"Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using `device`\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`opt_func` will be used to create an optimizer when `Learner.fit` is called, with `lr` as a learning rate. `splitter` is a function that takes `self.model` and returns a list of parameter groups (or just one parameter group if there are no different parameter groups). 
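For example, a splitter that puts all the biases of the model in their own parameter group could be sketched as follows (`_bias_splitter` is an illustrative name, not a library function):\n", "\n", "```python\n", "def _bias_splitter(m):\n", " #Hypothetical splitter: one group for the biases, one for everything else\n", " biases = [p for n,p in m.named_parameters() if n.endswith('bias')]\n", " others = [p for n,p in m.named_parameters() if not n.endswith('bias')]\n", " return [others, biases]\n", "```\n", "\n", "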
The default is `trainable_params`, which returns all trainable parameters of the model.\n", "\n", "`cbs` is one or a list of `Callback`s to pass to the `Learner`, and `cb_funcs` is one or a list of functions returning a `Callback` that will be called at init. Each `Callback` is registered as an attribute of `Learner` (in snake_case, via `Callback.name`). At creation, all the callbacks in `defaults.callbacks` (`TrainEvalCallback` and `Recorder`) are associated to the `Learner`.\n", "\n", "`metrics` is an optional list of metrics, that can be either functions or `Metric`s (see below)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Training loop" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test init with callbacks\n", "def synth_learner(n_train=10, n_valid=2, cuda=False, lr=defaults.lr, **kwargs):\n", " data = synth_dbunch(n_train=n_train,n_valid=n_valid, cuda=cuda)\n", " return Learner(data, RegModel(), loss_func=MSELossFlat(), lr=lr, **kwargs)\n", "\n", "tst_learn = synth_learner()\n", "test_eq(len(tst_learn.cbs), 1)\n", "assert isinstance(tst_learn.cbs[0], TrainEvalCallback)\n", "assert hasattr(tst_learn, ('train_eval'))\n", "\n", "tst_learn = synth_learner(cbs=TstCallback())\n", "test_eq(len(tst_learn.cbs), 2)\n", "assert isinstance(tst_learn.cbs[1], TstCallback)\n", "assert hasattr(tst_learn, ('tst'))\n", "\n", "tst_learn = synth_learner(cb_funcs=TstCallback)\n", "test_eq(len(tst_learn.cbs), 2)\n", "assert isinstance(tst_learn.cbs[1], TstCallback)\n", "assert hasattr(tst_learn, ('tst'))\n", "\n", "#A name that becomes an existing attribute of the Learner will throw an exception (here add_cb)\n", "class AddCbCallback(Callback): pass\n", "test_fail(lambda: synth_learner(cbs=AddCbCallback()))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.fit[source]

\n", "\n", "> Learner.fit(**`n_epoch`**, **`lr`**=*`None`*, **`wd`**=*`0.01`*, **`cbs`**=*`None`*, **`reset_opt`**=*`False`*)\n", "\n", "Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.fit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Training a few epochs should make the model better\n", "learn = synth_learner(cb_funcs=TstCallback, lr=1e-2)\n", "xb,yb = learn.dbunch.one_batch()\n", "init_loss = learn.loss_func(learn.model(xb), yb)\n", "learn.fit(2)\n", "assert learn.loss < init_loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test of TrainEvalCallback\n", "class TestTrainEvalCallback(Callback):\n", " run_after=TrainEvalCallback\n", " def begin_fit(self): \n", " test_eq([self.pct_train,self.train_iter], [0., 0])\n", " self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter\n", " \n", " def begin_batch(self): test_eq(next(self.model.parameters()).device, find_device(self.xb))\n", " \n", " def after_batch(self):\n", " if self.training:\n", " test_eq(self.pct_train , self.old_pct_train+1/(self.n_iter*self.n_epoch))\n", " test_eq(self.train_iter, self.old_train_iter+1)\n", " self.old_pct_train,self.old_train_iter = self.pct_train,self.train_iter\n", " \n", " def begin_train(self):\n", " assert self.training and self.model.training\n", " test_eq(self.pct_train, self.epoch/self.n_epoch)\n", " self.old_pct_train = self.pct_train\n", " \n", " def begin_validate(self):\n", " assert not self.training and not self.model.training\n", " \n", "learn = synth_learner(cb_funcs=TestTrainEvalCallback)\n", "learn.fit(1)\n", "#Check order is properly taken into account\n", "learn.cbs = L(reversed(learn.cbs))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#cuda\n", "#Check model is put on the GPU if needed\n", "learn = synth_learner(cb_funcs=TestTrainEvalCallback, cuda=True)\n", "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Check wd is not applied on bn/bias when option wd_bn_bias=False\n", "class _TstModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))\n", " self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3) \n", " def forward(self, x): return x * self.a + self.b\n", " \n", "class _PutGrad(Callback):\n", " def after_backward(self):\n", " for p in self.learn.model.tst.parameters():\n", " p.grad = torch.ones_like(p.data)\n", " \n", "learn = synth_learner(n_train=5, opt_func = partial(SGD, wd=1, decouple_wd=True), cb_funcs=_PutGrad)\n", "learn.model = _TstModel()\n", "init = [p.clone() for p in learn.model.tst.parameters()]\n", "learn.fit(1, lr=1e-2)\n", "end = list(learn.model.tst.parameters())\n", "assert not torch.allclose(end[0]-init[0], -0.05 * torch.ones_like(end[0]))\n", "for i in [1,2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.one_batch[source]

\n", "\n", "> Learner.one_batch(**`i`**, **`b`**)\n", "\n", "Train or evaluate `self.model` on batch `(xb,yb)`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.one_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is an internal method called by `Learner.fit`. If passed, `i` is the index of this iteration in the epoch. In training method, this does a full training step on the batch (compute predictions, loss, gradients, update the model parameters and zero the gradients). In validation mode, it stops at the loss computation." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class VerboseCallback(Callback):\n", " \"Callback that prints the name of each event called\"\n", " def __call__(self, event_name):\n", " print(event_name)\n", " super().__call__(event_name)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class TestOneBatch(VerboseCallback):\n", " def __init__(self, xb, yb, i):\n", " self.save_xb,self.save_yb,self.i = xb,yb,i\n", " self.old_pred,self.old_loss = None,tensor(0.)\n", " \n", " def begin_batch(self):\n", " self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()\n", " test_eq(self.iter, self.i)\n", " test_eq(self.save_xb, *self.xb)\n", " test_eq(self.save_yb, *self.yb)\n", " if hasattr(self.learn, 'pred'): test_eq(self.pred, self.old_pred)\n", " \n", " def after_pred(self):\n", " self.old_pred = self.pred\n", " test_eq(self.pred, self.model.a.data * self.x + self.model.b.data)\n", " test_eq(self.loss, self.old_loss)\n", " \n", " def after_loss(self):\n", " self.old_loss = self.loss\n", " test_eq(self.loss, self.loss_func(self.old_pred, self.save_yb))\n", " for p in self.model.parameters(): \n", " if not hasattr(p, 'grad') or p.grad is not None: test_eq(p.grad, tensor([0.]))\n", " \n", " def after_backward(self):\n", " self.grad_a = (2 * self.x * (self.pred.data - self.y)).mean()\n", " self.grad_b = 2 * (self.pred.data - self.y).mean()\n", " test_close(self.model.a.grad.data, self.grad_a)\n", " test_close(self.model.b.grad.data, self.grad_b)\n", " test_eq(self.model.a.data, self.old_a)\n", " test_eq(self.model.b.data, self.old_b)\n", " \n", " def after_step(self):\n", " test_close(self.model.a.data, self.old_a - self.lr * self.grad_a)\n", " test_close(self.model.b.data, self.old_b - self.lr * self.grad_b)\n", " self.old_a,self.old_b = self.model.a.data.clone(),self.model.b.data.clone()\n", " test_close(self.model.a.grad.data, self.grad_a)\n", " test_close(self.model.b.grad.data, self.grad_b)\n", " \n", " def after_batch(self):\n", " for p in self.model.parameters(): test_eq(p.grad, tensor([0.]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner()\n", "b = learn.dbunch.one_batch()\n", "learn = synth_learner(cbs=TestOneBatch(*b, 42), lr=1e-2)\n", "#Remove train/eval\n", "learn.cbs = learn.cbs[1:]\n", "#Setup\n", "learn.loss,learn.training = tensor(0.),True\n", "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n", "learn.model.train()\n", "batch_events = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']\n", "test_stdout(lambda: learn.one_batch(42, b), '\\n'.join(batch_events))\n", "test_stdout(lambda: learn.one_batch(42, b), '\\n'.join(batch_events)) #Check it works for a second batch" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.all_batches[source]

\n", "\n", "> Learner.all_batches()\n", "\n", "Train or evaluate `self.model` on all batches of `self.dl`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.all_batches)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner(n_train=5, cbs=VerboseCallback())\n", "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n", "with redirect_stdout(io.StringIO()): \n", " learn._do_begin_fit(1)\n", " learn.epoch,learn.dl = 0,learn.dbunch.train_dl\n", " learn('begin_epoch')\n", " learn('begin_train')\n", "test_stdout(learn.all_batches, '\\n'.join(batch_events * 5))\n", "test_eq(learn.train_iter, 5)\n", "\n", "valid_events = ['begin_batch', 'after_pred', 'after_loss', 'after_batch']\n", "with redirect_stdout(io.StringIO()): \n", " learn.dl = learn.dbunch.valid_dl\n", " learn('begin_validate')\n", "test_stdout(learn.all_batches, '\\n'.join(valid_events * 2))\n", "test_eq(learn.train_iter, 5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner(n_train=5, cbs=VerboseCallback())\n", "test_stdout(lambda: learn._do_begin_fit(42), 'begin_fit')\n", "test_eq(learn.n_epoch, 42)\n", "test_eq(learn.loss, tensor(0.))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn.opt = SGD(learn.model.parameters(), lr=learn.lr)\n", "learn.epoch = 0\n", "test_stdout(lambda: learn._do_epoch_train(), '\\n'.join(['begin_train'] + batch_events * 5 + ['after_train']))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_stdout(learn._do_epoch_validate, '\\n'.join(['begin_validate'] + valid_events * 2+ ['after_validate']))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Serializing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.save[source]

\n", "\n", "> Learner.save(**`file`**, **`with_opt`**=*`True`*)\n", "\n", "Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.save)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`file` can be a `Path`, a `string` or a buffer." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.load[source]

\n", "\n", "> Learner.load(**`file`**, **`with_opt`**=*`None`*, **`device`**=*`None`*, **`strict`**=*`True`*)\n", "\n", "Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.load)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`file` can be a `Path`, a `string` or a buffer. Use `device` to load the model/optimizer state on a device different from the one it was saved." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))\n", "xb,yb = learn.dbunch.one_batch()\n", "init_loss = learn.loss_func(learn.model(xb), yb)\n", "learn.fit(1)\n", "learn.save('tmp')\n", "assert (Path.cwd()/'models/tmp.pth').exists()\n", "\n", "learn1 = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))\n", "learn1 = learn1.load('tmp')\n", "test_eq(learn.model.a, learn1.model.a)\n", "test_eq(learn.model.b, learn1.model.b)\n", "test_eq(learn.opt.state_dict(), learn1.opt.state_dict())\n", "\n", "learn.save('tmp1', with_opt=False)\n", "learn1 = synth_learner(cb_funcs=TstCallback, opt_func=partial(SGD, mom=0.9))\n", "learn1 = learn1.load('tmp1')\n", "test_eq(learn.model.a, learn1.model.a)\n", "test_eq(learn.model.b, learn1.model.b)\n", "test_ne(learn.opt.state_dict(), learn1.opt.state_dict())\n", "\n", "shutil.rmtree('models')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback handling" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.__call__[source]

\n", "\n", "> Learner.__call__(**`event_name`**)\n", "\n", "Call self as a function." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.__call__)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.add_cb[source]

\n", "\n", "> Learner.add_cb(**`cb`**)\n", "\n", "Add `cb` to the list of [`Callback`](/learner.html#Callback) and register `self` as their learner" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.add_cb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner()\n", "learn.add_cb(TestTrainEvalCallback())\n", "test_eq(len(learn.cbs), 2)\n", "assert isinstance(learn.cbs[1], TestTrainEvalCallback)\n", "test_eq(learn.train_eval.learn, learn)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.add_cbs[source]

\n", "\n", "> Learner.add_cbs(**`cbs`**)\n", "\n", "Add `cbs` to the list of [`Callback`](/learner.html#Callback) and register `self` as their learner" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.add_cbs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.add_cbs([TestTrainEvalCallback(), TestTrainEvalCallback()])\n", "test_eq(len(learn.cbs), 4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.remove_cb[source]

\n", "\n", "> Learner.remove_cb(**`cb`**)\n", "\n", "Add `cb` from the list of [`Callback`](/learner.html#Callback) and deregister `self` as their learner" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.remove_cb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = learn.cbs[1]\n", "learn.remove_cb(learn.cbs[1])\n", "test_eq(len(learn.cbs), 3)\n", "assert cb.learn is None\n", "assert not getattr(learn,'test_train_eval',None)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Learner.remove_cbs[source]

\n", "\n", "> Learner.remove_cbs(**`cbs`**)\n", "\n", "Remove `cbs` from the list of [`Callback`](/learner.html#Callback) and deregister `self` as their learner" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.remove_cbs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "cb = learn.cbs[1]\n", "learn.remove_cbs(learn.cbs[1:])\n", "test_eq(len(learn.cbs), 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When writing a callback, the following attributes of `Learner` are available:\n", "- `model`: the model used for training/validation\n", "- `data`: the underlying `DataBunch`\n", "- `loss_func`: the loss function used\n", "- `opt`: the optimizer used to udpate the model parameters\n", "- `opt_func`: the function used to create the optimizer\n", "- `cbs`: the list containing all `Callback`s\n", "- `dl`: current `DataLoader` used for iteration\n", "- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.\n", "- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.\n", "- `pred`: last predictions from `self.model` (potentially modified by callbacks)\n", "- `loss`: last computed loss (potentially modified by callbacks)\n", "- `n_epoch`: the number of epochs in this training\n", "- `n_iter`: the number of iterations in the current `self.dl`\n", "- `epoch`: the current epoch index (from 0 to `n_epoch-1`)\n", "- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)\n", "\n", "The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:\n", "\n", "- `train_iter`: the number of training iterations done since the beginning of this training\n", "- `pct_train`: from 0. 
to 1., the percentage of training iterations completed\n", "- `training`: flag to indicate if we're in training mode or not\n", "\n", "The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:\n", "\n", "- `smooth_loss`: an exponentially-averaged version of the training loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Control flow testing" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "batch_events = ['begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step', 'after_batch']\n", "batchv_events = ['begin_batch', 'after_pred', 'after_loss', 'after_batch']\n", "train_events = ['begin_train'] + batch_events + ['after_train']\n", "valid_events = ['begin_validate'] + batchv_events + ['after_validate']\n", "epoch_events = ['begin_epoch'] + train_events + valid_events + ['after_epoch']\n", "cycle_events = ['begin_fit'] + epoch_events + ['after_fit']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner(n_train=1, n_valid=1)\n", "test_stdout(lambda: learn.fit(1, cbs=VerboseCallback()), '\\n'.join(cycle_events))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class TestCancelCallback(VerboseCallback):\n", " def __init__(self, cancel_at=event.begin_batch, exception=CancelBatchException, train=None):\n", " def _interrupt(): \n", " if train is None or train == self.training: raise exception()\n", " setattr(self, cancel_at, _interrupt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test cancel batch\n", "for i,e in enumerate(batch_events[:-1]):\n", " be = batch_events[:i+1] + ['after_cancel_batch', 'after_batch']\n", " bev = be if i <3 else batchv_events\n", " cycle = cycle_events[:3] + be + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(cancel_at=e)), '\\n'.join(cycle))\n", "\n", "#CancelBatchException not caught if thrown in any other event\n", "for e in cycle_events:\n", " if e not in batch_events[:-1]:\n", " with redirect_stdout(io.StringIO()):\n", " cb = TestCancelCallback(cancel_at=e)\n", " test_fail(lambda: learn.fit(1, cbs=cb))\n", " learn.remove_cb(cb) #Have to remove it manually" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test cancel train\n", "for i,e in enumerate(['begin_train'] + batch_events):\n", " be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else []) \n", " be += ['after_cancel_train', 'after_train']\n", " cycle = cycle_events[:3] + be + ['begin_validate'] + batchv_events + cycle_events[-3:]\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelTrainException, True)), '\\n'.join(cycle))\n", "\n", "#CancelTrainException not caught if thrown in any other event\n", "for e in cycle_events:\n", " if e not in ['begin_train'] + batch_events[:-1]:\n", " with redirect_stdout(io.StringIO()):\n", " cb = TestCancelCallback(e, CancelTrainException)\n", " test_fail(lambda: learn.fit(1, cbs=cb))\n", " learn.remove_cb(cb) #Have to remove it manually " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test cancel valid\n", "for i,e in enumerate(['begin_validate'] + batchv_events):\n", " bev = 
batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else []) + ['after_cancel_validate']\n", " cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev + cycle_events[-3:]\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelValidException, False)), '\n'.join(cycle))\n", " \n", "#CancelValidException not caught if thrown in any other event\n", "for e in cycle_events:\n", " if e not in ['begin_validate'] + batch_events[:3]:\n", " with redirect_stdout(io.StringIO()):\n", " cb = TestCancelCallback(e, CancelValidException)\n", " test_fail(lambda: learn.fit(1, cbs=cb))\n", " learn.remove_cb(cb) #Have to remove it manually " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test cancel epoch\n", "#In train\n", "for i,e in enumerate(['begin_train'] + batch_events):\n", " be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else [])\n", " be += ['after_train', 'after_cancel_epoch']\n", " cycle = cycle_events[:3] + be + cycle_events[-2:]\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, True)), '\n'.join(cycle))\n", "#In valid\n", "for i,e in enumerate(['begin_validate'] + batchv_events):\n", " bev = batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else [])\n", " bev += ['after_validate', 'after_cancel_epoch']\n", " cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev + cycle_events[-2:]\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelEpochException, False)), '\n'.join(cycle))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test cancel fit\n", "#In train\n", "for i,e in enumerate(['begin_train'] + batch_events):\n", " be = batch_events[:i] + (['after_batch'] if i >=1 and i < len(batch_events) else [])\n", " be += ['after_train', 'after_epoch', 'after_cancel_fit']\n", " cycle = cycle_events[:3] + be + ['after_fit']\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, True)), '\n'.join(cycle))\n", "#In valid\n", "for i,e in enumerate(['begin_validate'] + batchv_events):\n", " bev = batchv_events[:i] + (['after_batch'] if i >=1 and i < len(batchv_events) else [])\n", " bev += ['after_validate', 'after_epoch', 'after_cancel_fit']\n", " cycle = cycle_events[:3] + batch_events + ['after_train', 'begin_validate'] + bev + ['after_fit']\n", " test_stdout(lambda: learn.fit(1, cbs=TestCancelCallback(e, CancelFitException, False)), '\n'.join(cycle))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Metrics -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Metric():\n", " \"Blueprint for defining a metric\"\n", " def reset(self): pass\n", " def accumulate(self, learn): pass\n", " @property\n", " def value(self): raise NotImplementedError\n", " @property\n", " def name(self): return class2attr(self, 'Metric')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "class Metric[source]\n", "\n", "> Metric()\n", "\n", "Blueprint for defining a metric" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Metric, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Metrics can be simple averages (like accuracy) but sometimes their computation is a little bit more complex and can't be averaged over batches (like precision or recall), which is why we need a special class for them. For simple functions that can be computed as averages over batches, we can use the class `AvgMetric`, otherwise you'll need to implement the following methods.\n", "\n", "> Note: If your `Metric` has state depending on tensors, don't forget to store it on the CPU to avoid any potential memory leaks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Metric.reset[source]

\n", "\n", "> Metric.reset()\n", "\n", "Reset inner state to prepare for new computation" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Metric.reset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Metric.accumulate[source]

\n", "\n", "> Metric.accumulate(**`learn`**)\n", "\n", "Use `learn` to update the state with new results" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Metric.accumulate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Metric.value\"><code>Metric.value</code>[source]</h4>
\n", "\n", "The value of the metric" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Metric.value, name='Metric.value')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Metric.name\"><code>Metric.name</code>[source]</h4>
\n", "\n", "Name of the [`Metric`](/learner.html#Metric), camel-cased and with Metric removed" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Metric.name, name='Metric.name')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _maybe_reduce(val):\n", " if num_distrib()>1:\n", " val = val.clone()\n", " torch.distributed.all_reduce(val, op=torch.distributed.ReduceOp.SUM)\n", " val /= num_distrib()\n", " return val" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AvgMetric(Metric):\n", " \"Average the values of `func` taking into account potential different batch sizes\"\n", " def __init__(self, func): self.func = func\n", " def reset(self): self.total,self.count = 0.,0\n", " def accumulate(self, learn):\n", " bs = find_bs(learn.yb)\n", " self.total += to_detach(_maybe_reduce(self.func(learn.pred, *learn.yb)))*bs\n", " self.count += bs\n", " @property\n", " def value(self): return self.total/self.count if self.count != 0 else None\n", " @property\n", " def name(self): return self.func.__name__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _maybe_reduce(val):\n", " if num_distrib()>1:\n", " val = val.clone()\n", " torch.distributed.all_reduce(val, op=torch.distributed.ReduceOp.SUM)\n", " val /= num_distrib()\n", " return val" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AvgMetric(Metric):\n", " \"Average the values of `func` taking into account potential different batch sizes\"\n", " def __init__(self, func): self.func = func\n", " def reset(self): self.total,self.count = 0.,0\n", " def accumulate(self, learn):\n", " bs = find_bs(learn.yb)\n", " self.total += to_detach(_maybe_reduce(self.func(learn.pred, *learn.yb)))*bs\n", " self.count += bs\n", " @property\n", " def value(self): return self.total/self.count if self.count != 0 else None\n", " @property\n", " def name(self): return self.func.__name__" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "<h3 id=\"AvgMetric\"><code>class</code> <code>AvgMetric</code>[source]</h3>
\n", "\n", "> AvgMetric(**`func`**) :: [`Metric`](/learner.html#Metric)\n", "\n", "Average the values of `func` taking into account potential different batch sizes" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(AvgMetric, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner()\n", "tst = AvgMetric(lambda x,y: (x-y).abs().mean())\n", "t,u = torch.randn(100),torch.randn(100)\n", "tst.reset()\n", "for i in range(0,100,25): \n", " learn.pred,learn.yb = t[i:i+25],(u[i:i+25],)\n", " tst.accumulate(learn)\n", "test_close(tst.value, (t-u).abs().mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#With varying batch size\n", "tst.reset()\n", "splits = [0, 30, 50, 60, 100]\n", "for i in range(len(splits )-1): \n", " learn.pred,learn.yb = t[splits[i]:splits[i+1]],(u[splits[i]:splits[i+1]],)\n", " tst.accumulate(learn)\n", "test_close(tst.value, (t-u).abs().mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AvgLoss(Metric):\n", " \"Average the losses taking into account potential different batch sizes\"\n", " def reset(self): self.total,self.count = 0.,0\n", " def accumulate(self, learn):\n", " bs = find_bs(learn.yb)\n", " self.total += to_detach(_maybe_reduce(learn.loss.mean()))*bs\n", " self.count += bs\n", " @property\n", " def value(self): return self.total/self.count if self.count != 0 else None\n", " @property\n", " def name(self): return \"loss\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h3 id=\"AvgLoss\"><code>class</code> <code>AvgLoss</code>[source]</h3>
\n", "\n", "> AvgLoss() :: [`Metric`](/learner.html#Metric)\n", "\n", "Average the losses taking into account potential different batch sizes" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(AvgLoss, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = AvgLoss()\n", "t = torch.randn(100)\n", "tst.reset()\n", "for i in range(0,100,25): \n", " learn.yb,learn.loss = t[i:i+25],t[i:i+25].mean()\n", " tst.accumulate(learn)\n", "test_close(tst.value, t.mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#With varying batch size\n", "tst.reset()\n", "splits = [0, 30, 50, 60, 100]\n", "for i in range(len(splits )-1): \n", " learn.yb,learn.loss = t[splits[i]:splits[i+1]],t[splits[i]:splits[i+1]].mean()\n", " tst.accumulate(learn)\n", "test_close(tst.value, t.mean())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class AvgSmoothLoss(Metric):\n", " \"Smooth average of the losses (exponentially weighted with `beta`)\"\n", " def __init__(self, beta=0.98): self.beta = beta\n", " def reset(self): self.count,self.val = 0,tensor(0.)\n", " def accumulate(self, learn):\n", " self.count += 1\n", " self.val = torch.lerp(to_detach(learn.loss.mean()), self.val, self.beta)\n", " @property\n", " def value(self): return self.val/(1-self.beta**self.count)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h3 id=\"AvgSmoothLoss\"><code>class</code> <code>AvgSmoothLoss</code>[source]</h3>
\n", "\n", "> AvgSmoothLoss(**`beta`**=*`0.98`*) :: [`Metric`](/learner.html#Metric)\n", "\n", "Smooth average of the losses (exponentially weighted with `beta`)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(AvgSmoothLoss, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = AvgSmoothLoss()\n", "t = torch.randn(100)\n", "tst.reset()\n", "val = tensor(0.)\n", "for i in range(4): \n", " learn.loss = t[i*25:(i+1)*25].mean()\n", " tst.accumulate(learn)\n", " val = val*0.98 + t[i*25:(i+1)*25].mean()*(1-0.98)\n", " test_close(val/(1-0.98**(i+1)), tst.value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Recorder --" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastprogress.fastprogress import format_time\n", "\n", "def _maybe_item(t):\n", " t = t.value\n", " return t.item() if isinstance(t, Tensor) and t.numel()==1 else t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Recorder(Callback):\n", " \"Callback that registers statistics (lr, loss and metrics) during training\"\n", " run_after = TrainEvalCallback\n", "\n", " def __init__(self, add_time=True, train_metrics=False, beta=0.98):\n", " self.add_time,self.train_metrics = add_time,train_metrics\n", " self.loss,self.smooth_loss = AvgLoss(),AvgSmoothLoss(beta=beta)\n", "\n", " def begin_fit(self):\n", " \"Prepare state for training\"\n", " self.lrs,self.iters,self.losses,self.values = [],[],[],[]\n", " names = self._valid_mets.attrgot('name')\n", " if self.train_metrics: names = names.map('train_{}') + names.map('valid_{}')\n", " else: names = L('train_loss', 'valid_loss') + names[1:]\n", " if self.add_time: names.append('time')\n", " self.metric_names = 'epoch'+names\n", " self.smooth_loss.reset()\n", "\n", " def after_batch(self):\n", " \"Update all metrics and records lr and smooth loss in training\"\n", " if len(self.yb) == 0: return\n", " mets = self._train_mets if self.training else self._valid_mets\n", " for met in mets: met.accumulate(self.learn)\n", " if not self.training: return\n", " self.lrs.append(self.opt.hypers[-1]['lr'])\n", " self.losses.append(self.smooth_loss.value)\n", " self.learn.smooth_loss = self.smooth_loss.value\n", "\n", " def begin_epoch(self):\n", " \"Set timer if `self.add_time=True`\"\n", " self.cancel_train,self.cancel_valid = False,False\n", " if self.add_time: self.start_epoch = time.time()\n", " self.log = L(getattr(self, 'epoch', 0))\n", "\n", " def begin_train (self): self._train_mets[1:].map(Self.reset())\n", " def begin_validate(self): self._valid_mets.map(Self.reset())\n", " def after_train (self): self.log += self._train_mets.map(_maybe_item)\n", " def after_validate(self): self.log += self._valid_mets.map(_maybe_item)\n", " def after_cancel_train(self): self.cancel_train = True\n", " def after_cancel_validate(self): self.cancel_valid = True\n", "\n", " def after_epoch(self):\n", " \"Store and log the loss/metric values\"\n", " self.values.append(self.log[1:].copy())\n", " if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))\n", " self.logger(self.log)\n", " self.iters.append(self.smooth_loss.count)\n", "\n", " @property\n", " def _train_mets(self):\n", " if getattr(self, 'cancel_train', False): return L()\n", " return L(self.smooth_loss) + (self.metrics if self.train_metrics else L())\n", "\n", " 
@property\n", " def _valid_mets(self):\n", " if getattr(self, 'cancel_valid', False): return L()\n", " return L(self.loss) + self.metrics\n", "\n", " def plot_loss(self, skip_start=5, with_valid=True): \n", " plt.plot(self.losses[skip_start:], label='train')\n", " if with_valid:\n", " plt.plot(self.iters, L(self.values).itemgot(1), label='valid')\n", " plt.legend()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "add_docs(Recorder,\n", " begin_train = \"Reset loss and metrics state\",\n", " after_train = \"Log loss and metric values on the training set (if `self.training_metrics=True`)\",\n", " begin_validate = \"Reset loss and metrics state\",\n", " after_validate = \"Log loss and metric values on the validation set\",\n", " after_cancel_train = \"Ignore training metrics for this epoch\",\n", " after_cancel_validate = \"Ignore validation metrics for this epoch\",\n", " plot_loss = \"Plot the losses from `skip_start` and onward\")\n", "\n", "defaults.callbacks = [TrainEvalCallback, Recorder]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, metrics are computed on the validation set only, although that can be changed with `training_metrics=True`. `beta` is the weight used to compute the exponentially weighted average of the losses (which gives the `smooth_loss` attribute to `Learner`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test printed output\n", "def tst_metric(out, targ): return F.mse_loss(out, targ)\n", "learn = synth_learner(n_train=5, metrics=tst_metric)\n", "pat = r\"[tensor\\(\\d.\\d*\\), tensor\\(\\d.\\d*\\), tensor\\(\\d.\\d*\\), 'dd:dd']\"\n", "test_stdout(lambda: learn.fit(1), pat, regex=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "class TestRecorderCallback(Callback):\n", " run_after=Recorder\n", " \n", " def begin_fit(self): \n", " self.train_metrics,self.add_time = self.recorder.train_metrics,self.recorder.add_time\n", " self.beta = self.recorder.smooth_loss.beta\n", " for m in self.metrics: assert isinstance(m, Metric)\n", " test_eq(self.recorder.smooth_loss.val, 0.)\n", " #To test what the recorder logs, we use a custom logger function.\n", " self.learn.logger = self.test_log\n", " self.old_smooth,self.count = tensor(0.),0\n", " \n", " def after_batch(self):\n", " if self.training:\n", " self.count += 1\n", " test_eq(len(self.recorder.lrs), self.count)\n", " test_eq(self.recorder.lrs[-1], self.opt.hypers[-1]['lr'])\n", " test_eq(len(self.recorder.losses), self.count)\n", " smooth = (1 - self.beta**(self.count-1)) * self.old_smooth * self.beta + self.loss * (1-self.beta)\n", " smooth /= 1 - self.beta**self.count\n", " test_close(self.recorder.losses[-1], smooth, eps=1e-4)\n", " test_close(self.smooth_loss, smooth, eps=1e-4)\n", " self.old_smooth = self.smooth_loss\n", " self.bs += find_bs(self.yb)\n", " if not self.training: test_eq(self.recorder.loss.count, self.bs)\n", " if self.train_metrics or not self.training: \n", " for m in self.metrics: test_eq(m.count, self.bs)\n", " self.losses.append(self.loss.detach().cpu())\n", " \n", " def begin_epoch(self): \n", " if self.add_time: self.start_epoch = time.time()\n", " self.log = [self.epoch]\n", " \n", " def begin_train(self):\n", " self.bs = 0\n", " self.losses = []\n", " for m in self.recorder._train_mets: test_eq(m.count, self.bs)\n", " \n", " def after_train(self):\n", " mean = tensor(self.losses).mean()\n", 
" self.log += [self.smooth_loss, mean] if self.train_metrics else [self.smooth_loss]\n", " test_eq(self.log, self.recorder.log)\n", " self.losses = []\n", " \n", " def begin_validate(self):\n", " self.bs = 0\n", " self.losses = []\n", " for m in [self.recorder.loss] + self.metrics: test_eq(m.count, self.bs)\n", " \n", " def test_log(self, log):\n", " res = tensor(self.losses).mean()\n", " self.log += [res, res]\n", " if self.add_time: self.log.append(format_time(time.time() - self.start_epoch))\n", " test_eq(log, self.log)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)\n", "learn.fit(1)\n", "test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric', 'time'])\n", "\n", "learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)\n", "learn.recorder.train_metrics=True\n", "learn.fit(1)\n", "test_eq(learn.recorder.metric_names, \n", " ['epoch', 'train_loss', 'train_tst_metric', 'valid_loss', 'valid_tst_metric', 'time'])\n", "\n", "learn = synth_learner(n_train=5, metrics = tst_metric, cb_funcs = TestRecorderCallback)\n", "learn.recorder.add_time=False\n", "learn.fit(1)\n", "test_eq(learn.recorder.metric_names, ['epoch', 'train_loss', 'valid_loss', 'tst_metric'])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#5) [0,10.249631881713867,9.148826599121094,9.148827075958252,00:00]\n" ] } ], "source": [ "#hide\n", "#Test numpy metric\n", "def tst_metric_np(out, targ): return F.mse_loss(out, targ).numpy()\n", "learn = synth_learner(n_train=5, metrics=tst_metric_np)\n", "learn.fit(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback internals" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.begin_fit\"><code>Recorder.begin_fit</code>[source]</h4>
\n", "\n", "> Recorder.begin_fit()\n", "\n", "Prepare state for training" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.begin_fit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.begin_epoch\"><code>Recorder.begin_epoch</code>[source]</h4>
\n", "\n", "> Recorder.begin_epoch()\n", "\n", "Set timer if `self.add_time=True`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.begin_epoch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.begin_validate\"><code>Recorder.begin_validate</code>[source]</h4>
\n", "\n", "> Recorder.begin_validate()\n", "\n", "Reset loss and metrics state" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.begin_validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.after_batch\"><code>Recorder.after_batch</code>[source]</h4>
\n", "\n", "> Recorder.after_batch()\n", "\n", "Update all metrics and records lr and smooth loss in training" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.after_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.after_epoch\"><code>Recorder.after_epoch</code>[source]</h4>
\n", "\n", "> Recorder.after_epoch()\n", "\n", "Store and log the loss/metric values" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.after_epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plotting tools" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Recorder.plot_loss\"><code>Recorder.plot_loss</code>[source]</h4>
\n", "\n", "> Recorder.plot_loss(**`skip_start`**=*`5`*, **`with_valid`**=*`True`*)\n", "\n", "Plot the losses from `skip_start` and onward" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_loss)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de3yU5Z338c8vJ5JAQiCEY8IZEgqEABGhCKIoB22tXanFavusbsuitlVbFdrdbevTXeuhB7EuWqq126dUbVHXtgoqVqRqUYNAOIQzSCIC4QzhHH7PH0kphAATmMmdmfm+X6+8ZOa+ZuY7L+u3F3eu+7rN3RERkeiXEHQAEREJDxW6iEiMUKGLiMQIFbqISIxQoYuIxAgVuohIjDhnoZtZvpktPulnr5ndWWfMjWZWWvvzrpkNiFxkERGpjzVkHbqZJQIfAxe7+0cnPf9poMzdd5nZeOAH7n5x2NOKiMgZJTVw/Ghg3cllDuDu7570cAGQe6HBRESkYRpa6BOBZ84x5l+A2fUdMLNJwCSA5s2bDy4oKGjgx4uIxLeFCxdud/ec+o6FfMrFzFKAzUBfd996hjGXAdOBS9x9x9ner7i42EtKSkL6bBERqWFmC929uL5jDZmhjwc+PEuZFwJPAuPPVeYiIhJ+DVm2eANnON1iZp2BF4Avu/vqcAQTEZGGCWmGbmbpwJXAv5703GQAd38C+B6QDUw3M4BjZ/orgYiIREZIhe7uB6gp7JOfe+KkP38V+Gp4o4mInOro0aNUVFRw6NChoKNEXGpqKrm5uSQnJ4f8moauchERCUxFRQUZGRl07dqV2rMBMcnd2bFjBxUVFXTr1i3k1+nSfxGJGocOHSI7OzumyxzAzMjOzm7w30RU6CISVWK9zP/ufL5n1BX6ph0HuO9PyzlafTzoKCIiTUrUFfqabft4+p2NPPtBedBRRCTO7N69m+nTpzf4dVdddRW7d++OQKJTRV2hX17QliHdWjNt7hqqDh8LOo6IxJEzFXp1dfVZX/fKK6+QlZUVqVgnRF2hmxnfGV/A9v2H+eVf1wcdR0TiyNSpU1m3bh1FRUVcdNFFXHbZZXzpS1+if//+AFx77bUMHjyYvn37MmPGjBOv69q1K9u3b2fjxo306dOHr33ta/Tt25cxY8Zw8ODBsOWLymWLAzu3Yny/9syYv54bL+5CTkazoCOJSCO770/LWbF5b1jf81MdM/n+Z/ue8fgDDzzAsmXLWLx4MfPmzePqq69m2bJlJ5YW/upXv6J169YcPHiQiy66iOuuu47s7FMu4WHNmjU888wz/PKXv+T666/n+eef56abbgpL/qibof/dPWPzOXzsOI++sSboKCISp4YMGXLKOvFHH32UAQMGMHToUMrLy1mz5vR+6tatG0VFRQAMHjyYjRs3hi1PVM7QAbrntOCGIXk88/4mbrmkG93aNA86kog0orPNpBtL8+b/6J158+Yxd+5c/va3v5Gens6oUaPqXUferNk/zigkJiaG9ZRL1M7QAe4Y3ZuUpAQefnVl0FFEJA5kZGSwb9++eo/t2bOHVq1akZ6ezsqVK1mwYEEjp4viGTpATkYzvjaiO9PeWMOiTbsY2LlV0JFEJIZlZ2czfPhw+vXrR1paGu3atTtxbNy4cTzxxBMUFhaSn5/P0KFDGz1fg+4pGk7husHF/sPHGPXwm3TPacFzk4bGzVVkIvGorKyMPn36BB2j0dT3fc92g4uoPuUC0KJZEneM7sX7G3by5qptQccREQlM1Bc6wMQhnenWpjkPzl5F9fFg/sYhIhK0mCj05MQE7hmbz6qt+3j+w4qg44iIBCImCh1gfL/2FOVl8bPXV3Po6NkvwxURiUXnLHQzyzezxSf97DWzO+uMMTN71MzWmlmpmQ2KXOQz5uQ74wv4ZM8hnn5nY2N/vIhI4M5Z6O6+yt2L3L0IGAwcAF6sM2w80Kv2ZxLweLiDhuLi7tmMLmjL9Hlr2VV1JIgIIiKBaegpl9HAOnf/qM7znwN+4zUWAFlm1iEsCRtoyvgCqg4f47/fXBvEx4uInNCiRQsANm/ezIQJE+odM2rUKMKxhBsaXugTgWfqeb4TcPIG5RW1zzW63u0ymDA4l9/87SPKdx4IIoKIyCk6duzIrFmzIv45IRe6maUA1wB/qO9wPc+dtn7QzCaZWYmZlVRWVoaesoHuurI3ZvDT11dH7DNEJP5MmTLllP3Qf/CDH3DfffcxevRoBg0aRP/+/XnppZdOe93GjRvp168fAAcPHmTixIkUFhbyxS9+MbDtc8cDH7r71nqOVQB5Jz3OBTbXHeTuM4AZUHOlaAM+u0E6tEzjlku68cRb6/jqiG707dgyUh8lIkGZPRW2LA3ve7bvD+MfOOPhiRMncuedd3LbbbcB8Pvf/545c+Zw1113kZmZyfbt2xk6dCjXXHPNGa9af/zxx0lPT6e0tJTS0lIGDQrfGpKGnHK5gfpPtwD8EfhK7WqXocAed//kgtNdgMmX9qBlWjIPzNbGXSISHgMHDmTbtm1s3ryZJUuW0KpVKzp06MB3v/tdCgsLueKKK/j444/ZurW+eW+N+fPnn9j/vLCwkMLCwrDlC2mGbmbpwJXAv5703GQAd38CeAW4ClhLzSqYm8OW8Dy1TEvm65f15D9fLuPtNdu5pFeboCOJSDidZSYdSRMmTGDWrFls2bKFiRMnMnPmTCorK1m4cCHJycl07dq13m1zTxapPadCmqG7+wF3z3b3PSc990RtmVO7uuV2d+/h7v3dPTy/sr1AXx7WhU5ZafxodhnHtSWAiITBxIkTefbZZ5k1axYTJkxgz549tG3bluTkZN58800++qjuIsBTjRw5kpkzZwKwbNkySktLw5YtZq4UrU+zpETuHtub5Zv38qfS007pi4g0WN++fdm3bx+dOnWiQ4cO3HjjjZSUlFBcXMzMmTMpKCg46+tvvfVW9u/fT2FhIQ899BBDhgwJW7ao3z73XI4fdz7z87fZe+gob3z7UpolJUb8M0UkMrR9boxvn3suCQnG1PEFVOw6yG8XbAo6johIxMR8oQOM7J3DJT3b8Nhf1rD30NGg44iIRERcFDrA1PEF7DpwlCfmrQs6iohcgKBOEze28/mecVPo/Tq15HNFHfnVOxvYsufsS4pEpGlKTU1lx
44dMV/q7s6OHTtITU1t0Oui+ibRDXX3mHxmL93Cz15fzYMTwreYX0QaR25uLhUVFURy65CmIjU1ldzc3Aa9Jq4KPa91OjcN7cKv393AV0d0o1e7jKAjiUgDJCcn061bt6BjNFlxc8rl775+eU+apyTx4JxVQUcREQmruCv01s1TmDyqB3PLtvLBxp1BxxERCZu4K3SAW4Z3o31mKve/Uhbzv1wRkfgRl4WelpLIXVf2YtGm3by6fEvQcUREwiIuCx3gukG59GrbgofmrOJo9fGg44iIXLC4LfSkxASmjCtg/fYqnvug/NwvEBFp4uK20AFG92nLkK6teWTuGqoOHws6jojIBYnrQjczpl5VwPb9h3nyrxuCjiMickHiutABBnVuxfh+7Zkxfx3b9x8OOo6IyHmL+0IHuGdsPoeOHefRN9YEHUVE5LyFVOhmlmVms8xspZmVmdmwOsdbmtmfzGyJmS03s8DvKdoQ3XNacMOQPH733iY2bK8KOo6IyHkJdYY+DZjj7gXAAKCszvHbgRXuPgAYBfzEzFLClrIRfHN0L1KSEvjxq9oSQESi0zkL3cwygZHAUwDufsTdd9cZ5kCG1dzKugWwE4iqZSNtM1L56ojuvLz0ExaX1/16IiJNXygz9O5AJfC0mS0ysyfNrHmdMY8BfYDNwFLgDnc/7WodM5tkZiVmVtIUt7+cNLI7bVqk8CNtCSAiUSiUQk8CBgGPu/tAoAqYWmfMWGAx0BEoAh6rndmfwt1nuHuxuxfn5ORcWPIIaNEsiW+O7sV7G3by5qptQccREWmQUAq9Aqhw9/dqH8+ipuBPdjPwgtdYC2wACsIXs/HcMKQzXbPTeXD2KqqPa5YuItHjnIXu7luAcjPLr31qNLCizrBNtc9jZu2AfGB9GHM2muTEBO4ZW8Cqrft4/sOKoOOIiIQs1FUu3wBmmlkpNadU7jezyWY2ufb4D4FPm9lS4A1girtvD3/cxnFV//YMyMviZ6+v5tDR6qDjiIiEJKRb0Ln7YqC4ztNPnHR8MzAmjLkCZWZ8Z3wBE2cs4Ol3NnLrqB5BRxIROSddKXoGQ7tnc3lBW6bPW8uuqiNBxxEROScV+llMGVdA1eFjTJ+3NugoIiLnpEI/i/z2GVw3KJf/efcjKnYdCDqOiMhZqdDP4VtjemMGP31tddBRRETOSoV+Dh1apnHz8G68uPhjVmzeG3QcEZEzUqGH4NZRPWiZlswDc1YGHUVE5IxU6CFomZbM1y/ryfzVlbyzNmqX14tIjFOhh+jLw7rQKSuNH80u47i2BBCRJkiFHqJmSYncPbY3yz7ey59KNwcdR0TkNCr0BvjcgE58qkMmP35tFYePaUsAEWlaVOgNkJBgTB1fQPnOg8xcsCnoOCIip1ChN9CIXm0Y3jObn/9lDXsPHQ06jojICSr0BjIzpo7rw64DR/nFW+uCjiMicoIK/Tz0z23JNQM68tTbG9iy51DQcUREABX6ebtnbD7Vx51H5mpLABFpGlTo5ymvdTo3De3C70vKWbN1X9BxRERU6BfiG5f3onlKEg/OWRV0FBGR0ArdzLLMbJaZrTSzMjMbVs+YUWa22MyWm9lb4Y/a9LRunsLkUT2YW7aVDzbuDDqOiMS5UGfo04A57l4ADADKTj5oZlnAdOAad+8LfCGsKZuwW4Z3o11mM+5/pQx3bQkgIsE5Z6GbWSYwEngKwN2PuPvuOsO+BLzg7ptqx2wLd9CmKi0lkbuu6M2iTbt5dfmWoOOISBwLZYbeHagEnjazRWb2pJk1rzOmN9DKzOaZ2UIz+0p9b2Rmk8ysxMxKKisrLzB60zFhcC4927bgoTmrOFp9POg4IhKnQin0JGAQ8Li7DwSqgKn1jBkMXA2MBf7DzHrXfSN3n+Huxe5enJOTc2HJm5CkxASmjCtg/fYqnvugPOg4IhKnQin0CqDC3d+rfTyLmoKvO2aOu1e5+3ZgPjXn2uPGFX3aclHXVjwydw1Vh48FHUdE4tA5C93dtwDlZpZf+9RoYEWdYS8BI8wsyczSgYup84vTWGdmTB3fh+37D/PU2xuCjiMicSjUVS7fAGaaWSlQBNxvZpPNbDKAu5cBc4BS4H3gSXdfFonATdngLq0Y17c9v3hrHdv3Hw46jojEGQtqqV1xcbGXlJQE8tmRtK5yP2N+Np+bLu7MfZ/rF3QcEYkxZrbQ3YvrO6YrRcOsR04LJl6Ux8z3NrFxe1XQcUQkjqjQI+COK3qRkpTAw69pSwARaTwq9Ahom5HKV0d05+XST1hSXvcaLBGRyFChR8ikkd3Jbp7Cj2ZrSwARaRwq9Ahp0SyJb47uxYL1O5m3KnauihWRpkuFHkE3DOlMl+x0Hpi9kurjmqWLSGSp0CMoJSmBe8bms2rrPl74sCLoOCIS41ToEXZ1/w4MyG3JT19fzaGj1UHHEZEYpkKPsL9vCfDJnkP8+t2NQccRkRimQm8Ew3pkc1l+DtPfXMvuA0eCjiMiMUqF3kimjC9g3+Fj/Peba4OOIiIxSoXeSAraZ3LdoFz+592PqNh1IOg4IhKDVOiN6FtX9sYMfvra6qCjiEgMUqE3oo5Zafzz8K68uPhjVmzeG3QcEYkxKvRGdtulPclMTeaBOSuDjiIiMUaF3shapifz9ct6Mn91Je+s3R50HBGJISr0AHx5WBc6ZaXxwOyVHNeWACISJiEVupllmdksM1tpZmVmNuwM4y4ys2ozmxDemLElNTmRb4/pzdKP9/DnpZ8EHUdEYkSoM/RpwBx3LwAGUM8NoM0sEXgQeDV88WLXtUWd6NMhkx+/uoojx44HHUdEYsA5C93MMoGRwFMA7n7E3eu7a8M3gOeBbWFNGKMSEoyp4wvYtPMAM9/7KOg4IhIDQpmhdwcqgafNbJGZPWlmzU8eYGadgM8DT0QgY8wa2asNw3tm8/O/rGXfoaNBxxGRKBdKoScBg4DH3X0gUAVMrTPmEWCKu591O0Ezm2RmJWZWUlmpmz6YGVPH9WFn1RF+8db6oOOISJQLpdArgAp3f6/28SxqCv5kxcCzZrYRmABMN7Nr676Ru89w92J3L87JybmA2LGjf25LrhnQkSffXs/WvYeCjiMiUeyche7uW4ByM8uvfWo0sKLOmG7u3tXdu1JT+Le5+/+GO2ysuntMPtXHnUfmaksAETl/oa5y+QYw08xKgSLgfjObbGaTIxctfnTOTufGi7vw3AflrN22L+g4IhKlLKg70hcXF3tJSUkgn90U7dh/mEsfnsewHtn88ivFQccRkSbKzBa6e70loStFm4jsFs2YfGl3Xl+xlZKNO4OOIyJRSIXehNxySTfaZjTj/lfKCOpvTiISvVToTUh6ShJ3XdmbDzft5tXlW4OOIyJRRoXexHxhcC49cprz0KsrOVatLQFEJHQq9CYmKTGBKeMKWF9ZxXMl5UHHEZEookJvgq78VDuKu7TikblrOHDkWNBxRCRKqNCbIDPjO1cVULnvME/+dUPQcUQkSqjQm6jBXVoztm87fvHWOrbvPxx0HBGJAir0JuzecQUcOnacn7+xJugoIhIF
VOhNWI+cFnzxojxmvreJjdurgo4jIk2cCr2Ju3N0L5ITE3j4tVVBRxGRJk6F3sS1zUzlayO783LpJ7y5UjeDEpEzU6FHgdtG9aCgfQb3zFpC5T79glRE6qdCjwKpyYlMmziQvYeOce+sJdrnRUTqpUKPEvntM/i3q/rw5qpK/t8C3VRaRE6nQo8iXxnWhcvyc/ivl8tYvVU3whCRU6nQo4iZ8dCEAbRolsQ3n1nEoaNnvSe3iMQZFXqUycloxsNfKGTlln08/KqWMorIP4RU6GaWZWazzGylmZWZ2bA6x280s9Lan3fNbEBk4grA5QXt+MqwLjz19gbmr64MOo6INBGhztCnAXPcvQAYAJTVOb4BuNTdC4EfAjPCF1Hq892r+tCrbQu+/Ycl7Kw6EnQcEWkCzlnoZpYJjASeAnD3I+6+++Qx7v6uu++qfbgAyA13UDnV35cy7jlwlHtnlWopo4iENEPvDlQCT5vZIjN70syan2X8vwCz6ztgZpPMrMTMSiordargQn2qYyb3jstnbtlWfvf+pqDjiEjAQin0JGAQ8Li7DwSqgKn1DTSzy6gp9Cn1HXf3Ge5e7O7FOTk55xlZTnbL8G6M6NWGH/55BWu37Q86jogEKJRCrwAq3P292sezqCn4U5hZIfAk8Dl33xG+iHI2CQnGj78wgLTkRO54dhFHjuk+pCLx6pyF7u5bgHIzy699ajSw4uQxZtYZeAH4sruvDntKOat2mak8eF0hyzfv5SfalVEkbiWFOO4bwEwzSwHWAzeb2WQAd38C+B6QDUw3M4Bj7l4cgbxyBmP6tudLF3fmF/PXM7J3DsN7tgk6kog0MgtqdURxcbGXlJQE8tmx6sCRY3zm529z4HA1s+8YQavmKUFHEpEwM7OFZ5ow60rRGJKeksSjEweyo+ow331xqZYyisQZFXqM6depJXePyWf2si38oaQi6Dgi0ohU6DHoayO68+ke2fzgT8vZoHuRisQNFXoMSkgwfnL9AJITE7jj2UUcrdZSRpF4oEKPUR1apvHAP/WntGIPj8zVSlKReKBCj2Hj+3fg+uJcps9bx3vrda2XSKxToce473+2L11ap3PXc4vZc+Bo0HFEJIJU6DGuebMkpk0cyLZ9h/nu/2opo0gsU6HHgQF5Wdx1ZW9eLv2EFz78OOg4IhIhKvQ4MfnSHgzp2prvvbSMj3ZoKaNILFKhx4nEBONnE4tISDDufG4xx7SUUSTmqNDjSKesNO7/fH8WbdrNo39ZG3QcEQkzFXqc+eyAjvzToE489pc1lGzcGXQcEQkjFXocuu+avnRqlcadzy1m7yEtZRSJFSr0OJSRmswjXxzIJ3sO8f2XlgcdR0TCRIUepwZ3acU3L+/Fi4s+5qXFWsooEgtU6HHs9st6MLhLK/79xWWU7zwQdBwRuUAhFbqZZZnZLDNbaWZlZjasznEzs0fNbK2ZlZrZaTeRlqYnKTGBR75YhAN3aSmjSNQLdYY+DZjj7gXAAKCszvHxQK/an0nA42FLKBGV1zqdH17bl5KPdvH4vHVBxxGRC3DOQjezTGAk8BSAux9x9911hn0O+I3XWABkmVmHsKeViLi2qBPXDOjII2+sYdGmXUHHEZHzFMoMvTtQCTxtZovM7Ekza15nTCeg/KTHFbXPncLMJplZiZmVVFZWnndoCS8z44fX9qN9Zip3PLuY/YePBR1JRM5DKIWeBAwCHnf3gUAVMLXOGKvndadt6+fuM9y92N2Lc3JyGhxWIqdlWjI/+2IRFbsO8IM/aimjSDQKpdArgAp3f6/28SxqCr7umLyTHucCmy88njSmId1ac/tlPZm1sII/l+pfn0i0OWehu/sWoNzM8mufGg2sqDPsj8BXale7DAX2uPsn4Y0qjeGbo3tRlJfFd19YyubdB4OOIyINEOoql28AM82sFCgC7jezyWY2ufb4K8B6YC3wS+C2sCeVRpGcmMC0iUVUH3fuem4x1cd1QwyRaGFB3cGmuLjYS0pKAvlsObc/lJRzz6xS7h2Xz22jegYdR0RqmdlCdy+u75iuFJV6TRicy9X9O/DT11ZTWlF3laqINEUqdKmXmfFfn+9HTkYz7nh2MQeOaCmjSFOnQpczykpP4afXF7FxRxU//HPd34OLSFOjQpezGtYjm8mX9uCZ98uZs2xL0HFE5CxU6HJOd13Rm/6dWjL1hVK27DkUdBwROQMVupxTSlICj0ws4vDR43z7D4s5rqWMIk2SCl1C0iOnBd/77Kd4Z+0Onnp7Q9BxRKQeKnQJ2cSL8hjbtx0PvbqSZR/vCTqOiNShQpeQmRkP/FMhrZuncMezizh4pDroSCJyEhW6NEir5in85AtFrKus4r9e0VJGkaZEhS4NdkmvNnxtRDd+u2ATc1dsDTqOiNRSoct5uXtsPp/qkMm9z5eybZ+WMoo0BSp0OS/NkhJ59IYiqg4f4+4/lGopo0gToEKX89azbQb//plPMX91Jb9+d2PQcUTingpdLshNF3dmdEFbHpi9krJP9gYdRySuqdDlgpgZD04oJDMtmTufXcyho1rKKBIUFbpcsDYtmvHjLxSyaus+Hpi9Mug4InErpEI3s41mttTMFpvZabcZMrOWZvYnM1tiZsvN7ObwR5WmbFR+W24e3pVfv7uRN1duCzqOSFxqyAz9MncvOsOtj24HVrj7AGAU8BMzSwlHQIkeU8YVkN8ug3tmLWH7/sNBxxGJO+E65eJAhpkZ0ALYCegWN3EmNTmRaTcUsffQMe6dVUpQ96sViVehFroDr5nZQjObVM/xx4A+wGZgKXCHux+vO8jMJplZiZmVVFZWnndoaboK2mfynfEF/GXlNn674KOg44jElVALfbi7DwLGA7eb2cg6x8cCi4GOQBHwmJll1n0Td5/h7sXuXpyTk3MhuaUJ++dPd+XS3jn858tlrN66L+g4InEjpEJ39821/9wGvAgMqTPkZuAFr7EW2AAUhDOoRA8z4+EvFNKiWRLffGYRh49pKaNIYzhnoZtZczPL+PufgTHAsjrDNgGja8e0A/KB9eGNKtGkbUYqD00oZOWWfTw8Z1XQcUTiQigz9HbA22a2BHgfeNnd55jZZDObXDvmh8CnzWwp8AYwxd23RyayRIvRfdrx5aFdePLtDfx1jX5nIhJpFtRKhOLiYi8pOW1Ju8SYg0eq+exjb7P34FHm3DmS1s21mlXkQpjZwjMsH9eVohJZaSmJTJtYxO4DR5nyvJYyikSSCl0irm/Hltw7Lp/XV2zlmffLg44jErNU6NIobhnejUt6tuH//nk5a7ftDzqOSExSoUujSEgwfnL9AFKTE7nzuUUcOXbadWcCuvG2XBAVujSadpmpPHhdIcs+3stPXtdSxoNHqnl/w05mzF/HbTMXMuxHb3DLrz8IOpZEsaSgA0h8Gdu3PTcM6cyM+eu5tFcOn+7ZJuhIjeL4cWf99ioWbdrF4vLdLC7fzcot+6iuvXVfXus0iru2Zmj31gEnlWimQpdG9x+f6cN763fwrd8vYc6dI8hKj72ljDurjrC4fBeLN+1mUW2B7ztUs19
dRrMkBuRlceulPSjKy6KocxZtWjQLOLHEAhW6NLr0lCSmTRzI56e/w3deWMr0GwdRs1FndDp8rJoVm/eemHkv2rSbTTsPAJBgkN8+k88O6EhRXhYD87LokdOChITo/b7SdKnQJRD9c1vy7TH5PDhnJX9YWMH1xXlBRwqJu1O+8yCLynexaFNNga/YvJcj1TW/5G2fmUpRXhZfurgzA/Oy6J/bkvQU/WcmjUP/S5PATBrZnbdWb+MHf1zORV1b061N86AjnWbPwaOUVuw+Ud6Ly3ezs+oIAGnJifTPbcnNw7ueOHXSoWVawIklnunSfwnU5t0HGT/tr3TNTmfWrZ8mOTG4hVfHqo+zcsu+k06d7GJdZdWJ4z3btmBgbXEX5WWR3y6DpADzSnw626X/mqFLoDpmpXH/5/tz++8+ZNrcNdw9Nr/RPvuTPQf/MfPetJvSj3dz6GjNqZPs5ikU5WXx+YGdKMprRWFeSzJTkxstm8j5UKFL4K4u7MC8Vbn897y1jOjVhou7Z4f9M6oOH2Ppx3tqC7xm6eDWvTX3PU1JTKBvp0xuGNKZorwsBnVuRW6rtKj+Ra3EJxW6NAnfv6Yv72/cybd+v4RX7hhBy7Tznw0fP+6srdx/Ysngok27WL11H7VLvumSnc6w7tm1571b0adDBs2SEsP0TUSCo0KXJqFFs5qljNc9/i7/9uJSfn7DwJBnyJX7Dtee966ZeS8p38P+wzVrvjNTa9Z8j+nbnoF5WQzIy9IWvhKzVOjSZBTlZXHXFb348WurubygLf80KPe0MYeOVrN8895Trris2HUQgKQEo6BDBreymhUAAASmSURBVNcO7EhRXisGds6iW3ZzrfmWuKFClybl1lE9eWt1Jd97aTnFXVpT7X7KFZdln+zlaHXNuZOOLVMZ2LkV/2dYV4o6Z9GvY0vSUnTqROJXSMsWzWwjsA+oBo7Vt2TGzEYBjwDJwHZ3v/Rs76lli3ImFbsOMP6Rv3LgaPWJvU7SUxIpzG3JwM6tTlxx2TYzNeCkIo0vXMsWLzvTfULNLAuYDoxz901m1vY8cooAkNsqncduHMSry7fQv1NLBnbOolfbDBJ16kTkrMJ1yuVLwAvuvgnA3beF6X0lTl3aO4dLe+cEHUMkqoR6mZsDr5nZQjObVM/x3kArM5tXO+Yr4YsoIiKhCHWGPtzdN9eeSnndzFa6+/w67zMYGA2kAX8zswXuvvrkN6n9P4NJAJ07d77w9CIickJIM3R331z7z23Ai8CQOkMqgDnuXlV7nn0+MKCe95nh7sXuXpyTo79Oi4iE0zkL3cyam1nG3/8MjAGW1Rn2EjDCzJLMLB24GCgLd1gRETmzUE65tANerL1qLwn4nbvPMbPJAO7+hLuXmdkcoBQ4Djzp7nVLX0REIkjb54qIRJGzrUPXZs4iIjFChS4iEiMCO+ViZpXAR+f58jZAvVetxjB95/ig7xwfLuQ7d3H3epcJBlboF8LMSs50DilW6TvHB33n+BCp76xTLiIiMUKFLiISI6K10GcEHSAA+s7xQd85PkTkO0flOXQRETldtM7QRUSkDhW6iEiMiLpCN7NxZrbKzNaa2dSg80Samf3KzLaZWdzsjWNmeWb2ppmVmdlyM7sj6EyRZmapZva+mS2p/c73BZ2pMZhZopktMrM/B52lMZjZRjNbamaLzSzse59E1Tl0M0sEVgNXUrNl7wfADe6+ItBgEWRmI4H9wG/cvV/QeRqDmXUAOrj7h7U7fS4Ero3xf88GNHf3/WaWDLwN3OHuCwKOFlFm9i2gGMh0988EnSfSau/PXHym23leqGiboQ8B1rr7enc/AjwLfC7gTBFVeyORnUHnaEzu/om7f1j7533UbMXcKdhUkeU19tc+TK79iZ7Z1nkws1zgauDJoLPEimgr9E5A+UmPK4jx/9DjnZl1BQYC7wWbJPJqTz8sBrYBr7t7rH/nR4B7qdlyO16c63aeFyTaCr2+277H9CwmnplZC+B54E533xt0nkhz92p3LwJygSFmFrOn2MzsM8A2d18YdJZGNtzdBwHjgdtrT6mGTbQVegWQd9LjXGBzQFkkgmrPIz8PzHT3F4LO05jcfTcwDxgXcJRIGg5cU3tO+VngcjP7bbCRIi+E23lekGgr9A+AXmbWzcxSgInAHwPOJGFW+wvCp4Ayd/9p0Hkag5nlmFlW7Z/TgCuAlcGmihx3/46757p7V2r+O/6Lu98UcKyICvF2nhckqgrd3Y8BXwdepeYXZb939+XBpoosM3sG+BuQb2YVZvYvQWdqBMOBL1Mza1tc+3NV0KEirAPwppmVUjNxed3d42IpXxxpB7xtZkuA94GX3X1OOD8gqpYtiojImUXVDF1ERM5MhS4iEiNU6CIiMUKFLiISI1ToIiIxQoUuIhIjVOgiIjHi/wMu6L2AfbaH9AAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "#hide\n", "learn.recorder.plot_loss(skip_start=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Learner.no_logging\"><code>Learner.no_logging</code>[source]</h4>
\n", "\n", "> Learner.no_logging()\n", "\n", "Context manager to temporarily remove `logger`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.no_logging)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = synth_learner(n_train=5, metrics=tst_metric)\n", "with learn.no_logging():\n", " test_stdout(lambda: learn.fit(1), '')\n", "test_eq(learn.logger, print)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Learner.validate\"><code>Learner.validate</code>[source]</h4>
\n", "\n", "> Learner.validate(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`cbs`**=*`None`*)\n", "\n", "Validate on `dl` with potential new `cbs`." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test result\n", "learn = synth_learner(n_train=5, metrics=tst_metric)\n", "res = learn.validate()\n", "test_eq(res[0], res[1])\n", "x,y = learn.dbunch.valid_ds.tensors\n", "test_close(res[0], F.mse_loss(learn.model(x), y))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test other dl\n", "res = learn.validate(dl=learn.dbunch.train_dl)\n", "test_eq(res[0], res[1])\n", "x,y = learn.dbunch.train_ds.tensors\n", "test_close(res[0], F.mse_loss(learn.model(x), y))\n", "\n", "#Test additional callback is executed.\n", "cycle = cycle_events[:2] + ['begin_validate'] + batchv_events * 2 + cycle_events[-3:]\n", "test_stdout(lambda: learn.validate(cbs=VerboseCallback()), '\\n'.join(cycle))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Learner.loss_not_reduced\"><code>Learner.loss_not_reduced</code>[source]</h4>
\n", "\n", "> Learner.loss_not_reduced()\n", "\n", "A context manager to evaluate `loss_func` with reduction set to none." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.loss_not_reduced)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(learn.loss_func.reduction, 'mean')\n", "with learn.loss_not_reduced():\n", " test_eq(learn.loss_func.reduction, 'none')\n", " x,y = learn.dbunch.one_batch()\n", " p = learn.model(x)\n", " losses = learn.loss_func(p, y)\n", " test_eq(losses.shape, y.shape)\n", " test_eq(losses, F.mse_loss(p,y, reduction='none'))\n", "test_eq(learn.loss_func.reduction, 'mean')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Learner.get_preds\"><code>Learner.get_preds</code>[source]</h4>
\n", "\n", "> Learner.get_preds(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`with_input`**=*`False`*, **`with_loss`**=*`False`*, **`with_decoded`**=*`False`*, **`act`**=*`None`*)\n", "\n", "Get the predictions and targets on the `ds_idx`-th dbunchset or `dl`, optionally `with_input` and `with_loss`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.get_preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Depending on the `loss_func` attribute of `Learner`, an activation function will be picked automatically so that the predictions make sense. For instance if the loss is a case of cross-entropy, a softmax will be applied, or if the loss is binary cross entropy with logits, a sigmoid will be applied. If you want to make sure a certain activation function is applied, you can pass it with `act`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Note: If you want to use the option `with_loss=True` on a custom loss function, make sure you have implemented a `reduction` attribute that supports 'none' " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test result\n", "learn = synth_learner(n_train=5, metrics=tst_metric)\n", "preds,targs = learn.get_preds()\n", "x,y = learn.dbunch.valid_ds.tensors\n", "test_eq(targs, y)\n", "test_close(preds, learn.model(x))\n", "\n", "preds,targs = learn.get_preds(act = torch.sigmoid)\n", "test_eq(targs, y)\n", "test_close(preds, torch.sigmoid(learn.model(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Test get_preds work with ds not evenly dividble by bs\n", "learn = synth_learner(n_train=2.5, metrics=tst_metric)\n", "preds,targs = learn.get_preds(ds_idx=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test other dataset\n", "x = torch.randn(16*5)\n", "y = 2*x + 3 + 0.1*torch.randn(16*5)\n", "dl = TfmdDL(TensorDataset(x, y), bs=16)\n", "preds,targs = learn.get_preds(dl=dl)\n", "test_eq(targs, y)\n", "test_close(preds, learn.model(x))\n", "\n", "#Test with loss\n", "preds,targs,losses = learn.get_preds(dl=dl, with_loss=True)\n", "test_eq(targs, y)\n", "test_close(preds, learn.model(x))\n", "test_close(losses, F.mse_loss(preds, targs, reduction='none'))\n", "\n", "#Test with inputs\n", "inps,preds,targs = learn.get_preds(dl=dl, with_input=True)\n", "test_eq(*inps,x)\n", "test_eq(targs, y)\n", "test_close(preds, learn.model(x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test with no target\n", "learn = synth_learner(n_train=5)\n", "x = torch.randn(16*5)\n", "dl = TfmdDL(TensorDataset(x), bs=16)\n", "preds,targs = learn.get_preds(dl=dl)\n", "assert targs is None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test with targets that are tuples\n", "def _fake_loss(x,y,z,reduction=None): return F.mse_loss(x,y)\n", "\n", "learn = synth_learner(n_train=5)\n", "x = torch.randn(16*5)\n", "y = 2*x + 3 + 0.1*torch.randn(16*5)\n", "learn.dbunch.n_inp=1\n", "learn.loss_func = _fake_loss\n", "dl = TfmdDL(TensorDataset(x, y, y), bs=16)\n", "preds,targs = learn.get_preds(dl=dl)\n", "test_eq(targs, [y,y])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test with inputs that are tuples\n", "class _TupleModel(Module):\n", " def __init__(self, 
model): self.model=model\n", " def forward(self, x1, x2): return self.model(x1)\n", "\n", "learn = synth_learner(n_train=5)\n", "#learn.dbunch.n_inp=2\n", "x = torch.randn(16*5)\n", "y = 2*x + 3 + 0.1*torch.randn(16*5)\n", "learn.model = _TupleModel(learn.model)\n", "learn.dbunch = DataBunch(TfmdDL(TensorDataset(x, x, y), bs=16),TfmdDL(TensorDataset(x, x, y), bs=16))\n", "inps,preds,targs = learn.get_preds(ds_idx=0, with_input=True)\n", "test_eq(inps, [x,x])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#Test auto activation function is picked\n", "learn = synth_learner(n_train=5)\n", "learn.loss_func = BCEWithLogitsLossFlat()\n", "x = torch.randn(16*5)\n", "y = 2*x + 3 + 0.1*torch.randn(16*5)\n", "dl = TfmdDL(TensorDataset(x, y), bs=16)\n", "preds,targs = learn.get_preds(dl=dl)\n", "test_close(preds, torch.sigmoid(learn.model(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
<h4 id=\"Learner.predict\"><code>Learner.predict</code>[source]</h4>
\n", "\n", "> Learner.predict(**`item`**, **`rm_type_tfms`**=*`0`*)\n", "\n", "Return the prediction on `item`, fully decoded, loss function decoded and probabilities" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It returns a tuple of three elements with, in reverse order,\n", "- the prediction from the model, potentially passed through the activation of the loss function (if it has one)\n", "- the decoded prediction, using the poential `decodes` method from it\n", "- the fully decoded prediction, using the transforms used to buil the `DataSource`/`DataBunch`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _FakeLossFunc(Module):\n", " reduction = 'none'\n", " def forward(self, x, y): return F.mse_loss(x,y)\n", " def activation(self, x): return x+1\n", " def decodes(self, x): return 2*x\n", "\n", "class _Add1(Transform):\n", " def encodes(self, x): return x+1\n", " def decodes(self, x): return x-1\n", " \n", "learn = synth_learner(n_train=5)\n", "dl = TfmdDL(DataSource(torch.arange(50), tfms = [L(), [_Add1()]]))\n", "learn.dbunch = DataBunch(dl, dl)\n", "learn.loss_func = _FakeLossFunc()\n", "\n", "inp = tensor([2.])\n", "out = learn.model(inp).detach()+1 #applying model + activation\n", "dec = 2*out #decodes from loss function\n", "full_dec = dec-1 #decodes from _Add1\n", "test_eq(learn.predict(tensor([2.])), [full_dec, dec, out])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transfer learning" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def freeze_to(self:Learner, n):\n", " if self.opt is None: self.create_opt()\n", " self.opt.freeze_to(n)\n", "\n", "@patch\n", "def freeze(self:Learner): self.freeze_to(-1)\n", "\n", "@patch\n", "def unfreeze(self:Learner): self.freeze_to(0)\n", "\n", "add_docs(Learner,\n", " freeze_to=\"Freeze parameter groups up to `n`\",\n", " freeze=\"Freeze up to last parameter group\",\n", " unfreeze=\"Unfreeze the entire model\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,10.893106460571289,7.8781023025512695,00:00]\n" ] } ], "source": [ "#hide\n", "class _TstModel(nn.Module):\n", " def __init__(self):\n", " super().__init__()\n", " self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " self.tst = nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(3))\n", " self.tst[0].bias.data,self.tst[1].bias.data = torch.randn(5),torch.randn(3) \n", " def forward(self, x): return x * self.a + self.b\n", " \n", "class _PutGrad(Callback):\n", " def after_backward(self):\n", " for p in self.learn.model.tst.parameters():\n", " if p.requires_grad: p.grad = torch.ones_like(p.data)\n", "\n", "def _splitter(m): return [list(m.tst[0].parameters()), list(m.tst[1].parameters()), [m.a,m.b]]\n", " \n", "learn = synth_learner(n_train=5, opt_func = partial(SGD), cb_funcs=_PutGrad, splitter=_splitter, lr=1e-2)\n", "learn.model = _TstModel()\n", "learn.freeze()\n", "init = [p.clone() for p in learn.model.tst.parameters()]\n", "learn.fit(1)\n", "end = list(learn.model.tst.parameters())\n", "#linear was not trained\n", "for i in [0,1]: test_close(end[i],init[i])\n", "#bn was trained even frozen since `train_bn=True` by default\n", "for i in [2,3]: test_close(end[i]-init[i], -0.05 * 
torch.ones_like(end[i]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(#4) [0,30.35057258605957,27.175193786621094,00:00]\n", "(#4) [0,23.77756690979004,21.27766227722168,00:00]\n", "(#4) [0,18.555871963500977,16.66706085205078,00:00]\n" ] } ], "source": [ "#hide\n", "learn = synth_learner(n_train=5, opt_func = partial(SGD), cb_funcs=_PutGrad, splitter=_splitter, train_bn=False, lr=1e-2)\n", "learn.model = _TstModel()\n", "learn.freeze()\n", "init = [p.clone() for p in learn.model.tst.parameters()]\n", "learn.fit(1)\n", "end = list(learn.model.tst.parameters())\n", "#linear and bn were not trained\n", "for i in range(4): test_close(end[i],init[i])\n", "\n", "learn.freeze_to(-2)\n", "init = [p.clone() for p in learn.model.tst.parameters()]\n", "learn.fit(1)\n", "end = list(learn.model.tst.parameters())\n", "#linear was not trained\n", "for i in [0,1]: test_close(end[i],init[i])\n", "#bn was trained \n", "for i in [2,3]: test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]))\n", " \n", "learn.unfreeze()\n", "init = [p.clone() for p in learn.model.tst.parameters()]\n", "learn.fit(1)\n", "end = list(learn.model.tst.parameters())\n", "#linear and bn were trained\n", "for i in range(4): test_close(end[i]-init[i], -0.05 * torch.ones_like(end[i]), 1e-3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exporting a `Learner`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@patch\n", "def export(self:Learner, fname='export.pkl'):\n", " \"Export the content of `self` without the items and the optimizer state for inference\"\n", " if rank_distrib(): return # don't export if slave proc\n", " old_dbunch = self.dbunch\n", " self.dbunch = dbunch.new_empty()\n", " state = self.opt.state_dict()\n", " self.opt = None\n", " with warnings.catch_warnings():\n", " #To avoid the warning that come from PyTorch about model not being checked\n", " warnings.simplefilter(\"ignore\")\n", " torch.save(self, open(self.path/fname, 'wb'))\n", " self.create_opt()\n", " self.opt.load_state_dict(state)\n", " self.dbunch = old_dbunch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_test.ipynb.\n", "Converted 01_core.ipynb.\n", "Converted 01a_utils.ipynb.\n", "Converted 01b_dispatch.ipynb.\n", "Converted 01c_transform.ipynb.\n", "Converted 02_script.ipynb.\n", "Converted 03_torch_core.ipynb.\n", "Converted 03a_layers.ipynb.\n", "Converted 04_dataloader.ipynb.\n", "Converted 05_data_core.ipynb.\n", "Converted 06_data_transforms.ipynb.\n", "Converted 07_data_block.ipynb.\n", "Converted 08_vision_core.ipynb.\n", "Converted 09_vision_augment.ipynb.\n", "Converted 09a_vision_data.ipynb.\n", "Converted 10_pets_tutorial.ipynb.\n", "Converted 11_vision_models_xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_learner.ipynb.\n", "Converted 13a_metrics.ipynb.\n", "Converted 14_callback_schedule.ipynb.\n", "Converted 14a_callback_data.ipynb.\n", "Converted 15_callback_hook.ipynb.\n", "Converted 15a_vision_models_unet.ipynb.\n", "Converted 16_callback_progress.ipynb.\n", "Converted 17_callback_tracker.ipynb.\n", "Converted 18_callback_fp16.ipynb.\n", "Converted 19_callback_mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 
21_vision_learner.ipynb.\n", "Converted 22_tutorial_imagenette.ipynb.\n", "Converted 23_tutorial_transfer_learning.ipynb.\n", "Converted 30_text_core.ipynb.\n", "Converted 31_text_data.ipynb.\n", "Converted 32_text_models_awdlstm.ipynb.\n", "Converted 33_text_models_core.ipynb.\n", "Converted 34_callback_rnn.ipynb.\n", "Converted 35_tutorial_wikitext.ipynb.\n", "Converted 36_text_models_qrnn.ipynb.\n", "Converted 37_text_learner.ipynb.\n", "Converted 38_tutorial_ulmfit.ipynb.\n", "Converted 40_tabular_core.ipynb.\n", "Converted 41_tabular_model.ipynb.\n", "Converted 42_tabular_rapids.ipynb.\n", "Converted 50_data_block_examples.ipynb.\n", "Converted 60_medical_imaging.ipynb.\n", "Converted 65_medical_text.ipynb.\n", "Converted 70_callback_wandb.ipynb.\n", "Converted 90_notebook_core.ipynb.\n", "Converted 91_notebook_export.ipynb.\n", "Converted 92_notebook_showdoc.ipynb.\n", "Converted 93_notebook_export2html.ipynb.\n", "Converted 94_notebook_test.ipynb.\n", "Converted 95_index.ipynb.\n", "Converted 96_data_external.ipynb.\n", "Converted 97_utils_test.ipynb.\n", "Converted notebook2jekyll.ipynb.\n" ] } ], "source": [ "#hide\n", "from local.notebook.export import notebook2script\n", "notebook2script(all_fs=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }