{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp callback.core" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from fastai2.data.all import *\n", "from fastai2.optimizer import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from nbdev.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Callback\n", "\n", "> Basic callbacks for Learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Events" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Callbacks can occur at any of these times: *begin_fit begin_epoch begin_train begin_batch after_pred after_loss after_backward after_step after_cancel_batch after_batch after_cancel_train after_train begin_validate after_cancel_validate after_validate after_cancel_epoch after_epoch after_cancel_fit after_fit*." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "_events = L.split('begin_fit begin_epoch begin_train begin_batch after_pred after_loss \\\n", " after_backward after_step after_cancel_batch after_batch after_cancel_train \\\n", " after_train begin_validate after_cancel_validate after_validate after_cancel_epoch \\\n", " after_epoch after_cancel_fit after_fit')\n", "\n", "mk_class('event', **_events.map_dict(),\n", " doc=\"All possible events as attributes to get tab-completion and typo-proofing\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "_all_ = ['event']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class event[source]

\n", "\n", "> event(**\\*`args`**, **\\*\\*`kwargs`**)\n", "\n", "All possible events as attributes to get tab-completion and typo-proofing" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(event, name='event', title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To ensure that you are referring to an event that exists (that is, the name of one of the times when callbacks are called), and to get tab completion of event names, use `event`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(event.after_backward, 'after_backward')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callback - " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_inner_loop = \"begin_batch after_pred after_loss after_backward after_step after_cancel_batch after_batch\".split()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "@funcs_kwargs(as_method=True)\n", "class Callback(GetAttr):\n", " \"Basic class handling tweaks of the training loop by changing a `Learner` in various events\"\n", " _default,learn,run,run_train,run_valid = 'learn',None,True,True,True\n", " _methods = _events\n", " \n", " def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'\n", " def __repr__(self): return type(self).__name__\n", "\n", " def __call__(self, event_name):\n", " \"Call `self.{event_name}` if it's defined\"\n", " _run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or\n", " (self.run_valid and not getattr(self, 'training', False)))\n", " res = None\n", " if self.run and _run: res = getattr(self, event_name, noop)()\n", " if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit\n", " return res\n", "\n", " def __setattr__(self, name, value):\n", " 
if hasattr(self.learn,name):\n", " warn(f\"You are setting an attribute ({name}) that also exists in the learner, so you're not setting it in the learner but in the callback. Use `self.learn.{name}` otherwise.\")\n", " super().__setattr__(name, value)\n", "\n", " @property\n", " def name(self):\n", " \"Name of the `Callback`, camel-cased and with '*Callback*' removed\"\n", " return class2attr(self, 'Callback')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training loop is defined in `Learner` a bit below and consists of a minimal set of instructions. Looping through the data, we:\n", "- compute the output of the model from the input\n", "- calculate a loss between this output and the desired target\n", "- compute the gradients of this loss with respect to all the model parameters\n", "- update the parameters accordingly\n", "- zero all the gradients\n", "\n", "Any tweak of this training loop is defined in a `Callback` to avoid over-complicating the code of the training loop, and to make it easy to mix and match different techniques (since they'll be defined in different callbacks). A callback can implement actions on the following events:\n", "\n", "- `begin_fit`: called before doing anything, ideal for initial setup.\n", "- `begin_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.\n", "- `begin_train`: called at the beginning of the training part of an epoch.\n", "- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (for instance, to apply a technique like mixup).\n", "- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.\n", "- `after_loss`: called after the loss has been computed, but before the backward pass. 
It can be used to add any penalty to the loss (AR or TAR in RNN training for instance).\n", "- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to do any change to the gradients before said update (gradient clipping for instance).\n", "- `after_step`: called after the step and before the gradients are zeroed.\n", "- `after_batch`: called at the end of a batch, for any clean-up before the next one.\n", "- `after_train`: called at the end of the training phase of an epoch.\n", "- `begin_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.\n", "- `after_validate`: called at the end of the validation part of an epoch.\n", "- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.\n", "- `after_fit`: called at the end of training, for final clean-up." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Callback.__call__[source]

\n", "\n", "> Callback.__call__(**`event_name`**)\n", "\n", "Call `self.{event_name}` if it's defined" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.__call__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One way to define callbacks is through subclassing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class _T(Callback):\n", " def call_me(self): return \"maybe\"\n", "test_eq(_T()(\"call_me\"), \"maybe\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another way is by passing the callback function to the constructor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def cb(self): return \"maybe\"\n", "_t = Callback(begin_fit=cb)\n", "test_eq(_t(event.begin_fit), \"maybe\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GetAttr.__getattr__[source]

\n", "\n", "> GetAttr.__getattr__(**`k`**)\n", "\n" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.__getattr__)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a shortcut that lets you write `self.bla` instead of `self.learn.bla` for any attribute `bla` of the learner." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mk_class('TstLearner', 'a')\n", "\n", "class TstCallback(Callback):\n", " def batch_begin(self): print(self.a)\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "test_stdout(lambda: cb('batch_begin'), \"1\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that this only works for reading the value of an attribute. If you want to change it, you have to access it explicitly with `self.learn.bla`. In the example below, `self.a += 1` creates an `a` attribute of 2 in the callback instead of setting the `a` of the learner to 2. It also issues a warning that something is probably wrong:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.a" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jhoward/anaconda3/lib/python3.7/site-packages/ipykernel_launcher.py:22: UserWarning: You are setting an attribute (a) that also exists in the learner, so you're not setting it in the learner but in the callback. 
Use `self.learn.a` otherwise.\n" ] } ], "source": [ "class TstCallback(Callback):\n", " def batch_begin(self): self.a += 1\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "cb('batch_begin')\n", "test_eq(cb.a, 2)\n", "test_eq(cb.learn.a, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A proper version needs to write `self.learn.a = self.a + 1`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TstCallback(Callback):\n", " def batch_begin(self): self.learn.a = self.a + 1\n", "\n", "learn,cb = TstLearner(1),TstCallback()\n", "cb.learn = learn\n", "cb('batch_begin')\n", "test_eq(cb.learn.a, 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

Callback.name[source]

\n", "\n", "Name of the [`Callback`](callback.core#Callback), camel-cased and with '*Callback*' removed" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Callback.name, name='Callback.name')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test_eq(TstCallback().name, 'tst')\n", "class ComplicatedNameCallback(Callback): pass\n", "test_eq(ComplicatedNameCallback().name, 'complicated_name')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### TrainEvalCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class TrainEvalCallback(Callback):\n", " \"`Callback` that tracks the number of iterations done and properly sets training/eval mode\"\n", " run_valid = False\n", " def begin_fit(self):\n", " \"Set the iter and epoch counters to 0, put the model on the right device\"\n", " self.learn.train_iter,self.learn.pct_train = 0,0.\n", " if hasattr(self.dls, 'device'): self.model.to(self.dls.device)\n", " if hasattr(self.model, 'reset'): self.model.reset()\n", "\n", " def after_batch(self):\n", " \"Update the iter counter (in training mode)\"\n", " self.learn.pct_train += 1./(self.n_iter*self.n_epoch)\n", " self.learn.train_iter += 1\n", "\n", " def begin_train(self):\n", " \"Set the model in training mode\"\n", " self.learn.pct_train=self.epoch/self.n_epoch\n", " self.model.train()\n", " self.learn.training=True\n", "\n", " def begin_validate(self):\n", " \"Set the model in validation mode\"\n", " self.model.eval()\n", " self.learn.training=False" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class TrainEvalCallback[source]

\n", "\n", "> TrainEvalCallback(**`begin_fit`**=*`None`*, **`begin_epoch`**=*`None`*, **`begin_train`**=*`None`*, **`begin_batch`**=*`None`*, **`after_pred`**=*`None`*, **`after_loss`**=*`None`*, **`after_backward`**=*`None`*, **`after_step`**=*`None`*, **`after_cancel_batch`**=*`None`*, **`after_batch`**=*`None`*, **`after_cancel_train`**=*`None`*, **`after_train`**=*`None`*, **`begin_validate`**=*`None`*, **`after_cancel_validate`**=*`None`*, **`after_validate`**=*`None`*, **`after_cancel_epoch`**=*`None`*, **`after_epoch`**=*`None`*, **`after_cancel_fit`**=*`None`*, **`after_fit`**=*`None`*) :: [`Callback`](callback.core#Callback)\n", "\n", "[`Callback`](callback.core#Callback) that tracks the number of iterations done and properly sets training/eval mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This `Callback` is automatically added to every `Learner` at initialization." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#test of the TrainEvalCallback below in Learner.fit" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_fit[source]

\n", "\n", "> TrainEvalCallback.begin_fit()\n", "\n", "Set the iter and epoch counters to 0, put the model on the right device" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_fit)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.after_batch[source]

\n", "\n", "> TrainEvalCallback.after_batch()\n", "\n", "Update the iter counter (in training mode)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.after_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_train[source]

\n", "\n", "> TrainEvalCallback.begin_train()\n", "\n", "Set the model in training mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [ { "data": { "text/markdown": [ "

TrainEvalCallback.begin_validate[source]

\n", "\n", "> TrainEvalCallback.begin_validate()\n", "\n", "Set the model in validation mode" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(TrainEvalCallback.begin_validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GatherPredsCallback -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "#TODO: save_targs and save_preds only handle preds/targets that have one tensor, not tuples of tensors.\n", "class GatherPredsCallback(Callback):\n", " \"`Callback` that saves the predictions and targets, optionally `with_loss`\"\n", " def __init__(self, with_input=False, with_loss=False, save_preds=None, save_targs=None, concat_dim=0):\n", " store_attr(self, \"with_input,with_loss,save_preds,save_targs,concat_dim\")\n", "\n", " def begin_batch(self):\n", " if self.with_input: self.inputs.append((to_detach(self.xb)))\n", "\n", " def begin_validate(self):\n", " \"Initialize containers\"\n", " self.preds,self.targets = [],[]\n", " if self.with_input: self.inputs = []\n", " if self.with_loss: self.losses = []\n", "\n", " def after_batch(self):\n", " \"Save predictions, targets and potentially losses\"\n", " preds,targs = to_detach(self.pred),to_detach(self.yb)\n", " if self.save_preds is None: self.preds.append(preds)\n", " else: (self.save_preds/str(self.iter)).save_array(preds)\n", " if self.save_targs is None: self.targets.append(targs)\n", " else: (self.save_targs/str(self.iter)).save_array(targs[0])\n", " if self.with_loss:\n", " bs = find_bs(self.yb)\n", " loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)\n", " self.losses.append(to_detach(loss))\n", "\n", " def after_validate(self):\n", " \"Concatenate all recorded 
tensors\"\n", " if self.with_input: self.inputs = detuplify(to_concat(self.inputs, dim=self.concat_dim))\n", " if not self.save_preds: self.preds = detuplify(to_concat(self.preds, dim=self.concat_dim))\n", " if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))\n", " if self.with_loss: self.losses = to_concat(self.losses)\n", "\n", " def all_tensors(self):\n", " res = [None if self.save_preds else self.preds, None if self.save_targs else self.targets]\n", " if self.with_input: res = [self.inputs] + res\n", " if self.with_loss: res.append(self.losses)\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class GatherPredsCallback[source]

\n", "\n", "> GatherPredsCallback(**`with_input`**=*`False`*, **`with_loss`**=*`False`*, **`save_preds`**=*`None`*, **`save_targs`**=*`None`*, **`concat_dim`**=*`0`*) :: [`Callback`](callback.core#Callback)\n", "\n", "[`Callback`](callback.core#Callback) that saves the predictions and targets, optionally `with_loss`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GatherPredsCallback.begin_validate[source]

\n", "\n", "> GatherPredsCallback.begin_validate()\n", "\n", "Initialize containers" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback.begin_validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GatherPredsCallback.after_batch[source]

\n", "\n", "> GatherPredsCallback.after_batch()\n", "\n", "Save predictions, targets and potentially losses" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback.after_batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

GatherPredsCallback.after_validate[source]

\n", "\n", "> GatherPredsCallback.after_validate()\n", "\n", "Concatenate all recorded tensors" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(GatherPredsCallback.after_validate)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class FetchPredsCallback(Callback):\n", " \"A callback to fetch predictions during the training loop\"\n", " remove_on_fetch = True\n", " def __init__(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, cbs=None):\n", " self.cbs = L(cbs)\n", " store_attr(self, 'ds_idx,dl,with_input,with_decoded')\n", "\n", " def after_validate(self):\n", " to_rm = L(cb for cb in self.learn.cbs if getattr(cb, 'remove_on_fetch', False))\n", " with self.learn.removed_cbs(to_rm + self.cbs) as learn:\n", " self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,\n", " with_input=self.with_input, with_decoded=self.with_decoded, inner=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class FetchPredsCallback[source]

\n", "\n", "> FetchPredsCallback(**`ds_idx`**=*`1`*, **`dl`**=*`None`*, **`with_input`**=*`False`*, **`with_decoded`**=*`False`*, **`cbs`**=*`None`*) :: [`Callback`](callback.core#Callback)\n", "\n", "A callback to fetch predictions during the training loop" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(FetchPredsCallback, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When writing a callback, the following attributes of `Learner` are available:\n", "- `model`: the model used for training/validation\n", "- `dls`: the underlying `DataLoaders`\n", "- `loss_func`: the loss function used\n", "- `opt`: the optimizer used to update the model parameters\n", "- `opt_func`: the function used to create the optimizer\n", "- `cbs`: the list containing all `Callback`s\n", "- `dl`: current `DataLoader` used for iteration\n", "- `x`/`xb`: last input drawn from `self.dl` (potentially modified by callbacks). `xb` is always a tuple (potentially with one element) and `x` is detuplified. You can only assign to `xb`.\n", "- `y`/`yb`: last target drawn from `self.dl` (potentially modified by callbacks). `yb` is always a tuple (potentially with one element) and `y` is detuplified. You can only assign to `yb`.\n", "- `pred`: last predictions from `self.model` (potentially modified by callbacks)\n", "- `loss`: last computed loss (potentially modified by callbacks)\n", "- `n_epoch`: the number of epochs in this training\n", "- `n_iter`: the number of iterations in the current `self.dl`\n", "- `epoch`: the current epoch index (from 0 to `n_epoch-1`)\n", "- `iter`: the current iteration index in `self.dl` (from 0 to `n_iter-1`)\n", "\n", "The following attributes are added by `TrainEvalCallback` and should be available unless you went out of your way to remove that callback:\n", "\n", "- `train_iter`: the number of training iterations done since the beginning of this training\n", "- `pct_train`: from 0. 
to 1., the percentage of training iterations completed\n", "- `training`: flag to indicate if we're in training mode or not\n", "\n", "The following attribute is added by `Recorder` and should be available unless you went out of your way to remove that callback:\n", "\n", "- `smooth_loss`: an exponentially-averaged version of the training loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callbacks control flow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sometimes we may want to skip some of the steps of the training loop: in gradient accumulation, for instance, we don't always want to do the step/zeroing of the gradients. During an LR finder test, we don't want to do the validation phase of an epoch. And if we're training with an early-stopping strategy, we want to be able to interrupt the training loop completely.\n", "\n", "This is made possible by raising specific exceptions that the training loop looks for (and properly catches)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_ex_docs = dict(\n", " CancelFitException=\"Interrupt training and go to `after_fit`\",\n", " CancelEpochException=\"Skip the rest of this epoch and go to `after_epoch`\",\n", " CancelTrainException=\"Skip the rest of the training part of the epoch and go to `after_train`\",\n", " CancelValidException=\"Skip the rest of the validation part of the epoch and go to `after_validate`\",\n", " CancelBatchException=\"Skip the rest of this batch and go to `after_batch`\")\n", "\n", "for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelBatchException[source]

\n", "\n", "> CancelBatchException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of this batch and go to `after_batch`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelBatchException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelTrainException[source]

\n", "\n", "> CancelTrainException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of the training part of the epoch and go to `after_train`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelTrainException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelValidException[source]

\n", "\n", "> CancelValidException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of the validation part of the epoch and go to `after_validate`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelValidException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelEpochException[source]

\n", "\n", "> CancelEpochException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Skip the rest of this epoch and go to `after_epoch`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelEpochException, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "

class CancelFitException[source]

\n", "\n", "> CancelFitException(**\\*`args`**, **\\*\\*`kwargs`**) :: `Exception`\n", "\n", "Interrupt training and go to `after_fit`" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(CancelFitException, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can detect that one of those exceptions occurred and add code that executes right after it with the following events:\n", "- `after_cancel_batch`: reached immediately after a `CancelBatchException` before proceeding to `after_batch`\n", "- `after_cancel_train`: reached immediately after a `CancelTrainException` before proceeding to `after_train`\n", "- `after_cancel_validate`: reached immediately after a `CancelValidException` before proceeding to `after_validate`\n", "- `after_cancel_epoch`: reached immediately after a `CancelEpochException` before proceeding to `after_epoch`\n", "- `after_cancel_fit`: reached immediately after a `CancelFitException` before proceeding to `after_fit`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export -" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Converted 00_torch_core.ipynb.\n", "Converted 01_layers.ipynb.\n", "Converted 02_data.load.ipynb.\n", "Converted 03_data.core.ipynb.\n", "Converted 04_data.external.ipynb.\n", "Converted 05_data.transforms.ipynb.\n", "Converted 06_data.block.ipynb.\n", "Converted 07_vision.core.ipynb.\n", "Converted 08_vision.data.ipynb.\n", "Converted 09_vision.augment.ipynb.\n", "Converted 09b_vision.utils.ipynb.\n", "Converted 09c_vision.widgets.ipynb.\n", "Converted 10_tutorial.pets.ipynb.\n", "Converted 11_vision.models.xresnet.ipynb.\n", "Converted 12_optimizer.ipynb.\n", "Converted 13_callback.core.ipynb.\n", "Converted 13a_learner.ipynb.\n", "Converted 13b_metrics.ipynb.\n", "Converted 14_callback.schedule.ipynb.\n", "Converted 14a_callback.data.ipynb.\n", 
"Converted 15_callback.hook.ipynb.\n", "Converted 15a_vision.models.unet.ipynb.\n", "Converted 16_callback.progress.ipynb.\n", "Converted 17_callback.tracker.ipynb.\n", "Converted 18_callback.fp16.ipynb.\n", "Converted 18a_callback.training.ipynb.\n", "Converted 19_callback.mixup.ipynb.\n", "Converted 20_interpret.ipynb.\n", "Converted 20a_distributed.ipynb.\n", "Converted 21_vision.learner.ipynb.\n", "Converted 22_tutorial.imagenette.ipynb.\n", "Converted 23_tutorial.vision.ipynb.\n", "Converted 24_tutorial.siamese.ipynb.\n", "Converted 24_vision.gan.ipynb.\n", "Converted 30_text.core.ipynb.\n", "Converted 31_text.data.ipynb.\n", "Converted 32_text.models.awdlstm.ipynb.\n", "Converted 33_text.models.core.ipynb.\n", "Converted 34_callback.rnn.ipynb.\n", "Converted 35_tutorial.wikitext.ipynb.\n", "Converted 36_text.models.qrnn.ipynb.\n", "Converted 37_text.learner.ipynb.\n", "Converted 38_tutorial.text.ipynb.\n", "Converted 39_tutorial.transformers.ipynb.\n", "Converted 40_tabular.core.ipynb.\n", "Converted 41_tabular.data.ipynb.\n", "Converted 42_tabular.model.ipynb.\n", "Converted 43_tabular.learner.ipynb.\n", "Converted 44_tutorial.tabular.ipynb.\n", "Converted 45_collab.ipynb.\n", "Converted 46_tutorial.collab.ipynb.\n", "Converted 50_tutorial.datablock.ipynb.\n", "Converted 60_medical.imaging.ipynb.\n", "Converted 61_tutorial.medical_imaging.ipynb.\n", "Converted 65_medical.text.ipynb.\n", "Converted 70_callback.wandb.ipynb.\n", "Converted 71_callback.tensorboard.ipynb.\n", "Converted 72_callback.neptune.ipynb.\n", "Converted 73_callback.captum.ipynb.\n", "Converted 74_callback.cutmix.ipynb.\n", "Converted 97_test_utils.ipynb.\n", "Converted 99_pytorch_doc.ipynb.\n", "Converted index.ipynb.\n", "Converted tutorial.ipynb.\n" ] } ], "source": [ "#hide\n", "from nbdev.export import notebook2script\n", "notebook2script()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { 
"split_at_heading": true }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }