{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp learner" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.test import *\n", "from local.data.all import *\n", "from local.optimizer import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "_all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Learner\n", "\n", "> Basic class for handling the training loop" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll use the following for testing purposes (a basic linear regression problem):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import TensorDataset\n", "\n", "def synth_dbunch(a=2, b=3, bs=16, n_train=10, n_valid=2, cuda=False):\n", " def get_data(n):\n", " x = torch.randn(int(bs*n))\n", " return TensorDataset(x, a*x + b + 0.1*torch.randn(int(bs*n)))\n", " train_ds = get_data(n_train)\n", " valid_ds = get_data(n_valid)\n", " tfms = Cuda() if cuda else None\n", " train_dl = TfmdDL(train_ds, bs=bs, shuffle=True, after_batch=tfms, num_workers=0)\n", " valid_dl = TfmdDL(valid_ds, bs=bs, after_batch=tfms, num_workers=0)\n", " return DataBunch(train_dl, valid_dl)\n", "\n", "class RegModel(Module):\n", " def __init__(self): self.a,self.b = nn.Parameter(torch.randn(1)),nn.Parameter(torch.randn(1))\n", " def forward(self, x): return x*self.a + self.b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Callback - " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"#export\n", "class Callback(GetAttr):\n", "    \"Basic class handling tweaks of the training loop by changing a `Learner` in various events\"\n", "    _default,learn = 'learn',None\n", "    def __repr__(self): return type(self).__name__\n", "\n", "    def __call__(self, event_name):\n", "        \"Call `self.{event_name}` if it's defined\"\n", "        getattr(self, event_name, noop)()\n", "\n", "    @property\n", "    def name(self):\n", "        \"Name of the `Callback`, camel-cased and with '*Callback*' removed\"\n", "        return class2attr(self, 'Callback')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training loop is defined in `Learner` a bit below and consists of a minimal set of instructions; looping through the data, we:\n", "- compute the output of the model from the input\n", "- calculate a loss between this output and the desired target\n", "- compute the gradients of this loss with respect to all the model parameters\n", "- update the parameters accordingly\n", "- zero all the gradients\n", "\n", "Any tweak of this training loop is defined in a `Callback`, which keeps the code of the training loop simple and makes it easy to mix and match different techniques (since they are defined in separate callbacks). A callback can implement actions on the following events:\n", "\n", "- `begin_fit`: called before doing anything, ideal for initial setup.\n", "- `begin_epoch`: called at the beginning of each epoch, useful for any behavior you need to reset at each epoch.\n", "- `begin_train`: called at the beginning of the training part of an epoch.\n", "- `begin_batch`: called at the beginning of each batch, just after drawing said batch. It can be used to do any setup necessary for the batch (like hyper-parameter scheduling) or to change the input/target before it goes into the model (changing the input with techniques like mixup, for instance).\n", "- `after_pred`: called after computing the output of the model on the batch. It can be used to change that output before it's fed to the loss.\n", "- `after_loss`: called after the loss has been computed, but before the backward pass. It can be used to add any penalty to the loss (AR or TAR in RNN training, for instance).\n", "- `after_backward`: called after the backward pass, but before the update of the parameters. It can be used to make any change to the gradients before said update (gradient clipping, for instance).\n", "- `after_step`: called after the step and before the gradients are zeroed.\n", "- `after_batch`: called at the end of a batch, for any clean-up before the next one.\n", "- `after_train`: called at the end of the training phase of an epoch.\n", "- `begin_validate`: called at the beginning of the validation phase of an epoch, useful for any setup needed specifically for validation.\n", "- `after_validate`: called at the end of the validation part of an epoch.\n", "- `after_epoch`: called at the end of an epoch, for any clean-up before the next one.\n", "- `after_fit`: called at the end of training, for final clean-up." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
#### `Callback.__call__`\n",
"\n",
"> `Callback.__call__`(**`event_name`**)\n",
"\n",
"Call `self.{event_name}` if it's defined\n",
"\n",
"#### `GetAttr.__getattr__`\n",
"\n",
"> `GetAttr.__getattr__`(**`k`**)\n",
"\n",
"#### `Callback.name`\n",
"\n",
"Name of the `Callback`, camel-cased and with '*Callback*' removed\n",
"\n",
"#### class `TrainEvalCallback`\n",
"\n",
"> `TrainEvalCallback`() :: [`Callback`](/learner.html#Callback)\n",
"\n",
"[`Callback`](/learner.html#Callback) that tracks the number of iterations done and properly sets training/eval mode\n",
"\n",
"#### `TrainEvalCallback.begin_fit`\n",
"\n",
"> `TrainEvalCallback.begin_fit`()\n",
"\n",
"Set the iter and epoch counters to 0, put the model on the right device\n",
"\n",
"#### `TrainEvalCallback.after_batch`\n",
"\n",
"> `TrainEvalCallback.after_batch`()\n",
"\n",
"Update the iter counter (in training mode)\n",
"\n",
"#### `TrainEvalCallback.begin_train`\n",
"\n",
"> `TrainEvalCallback.begin_train`()\n",
"\n",
"Set the model in training mode\n",
"\n",
"#### `TrainEvalCallback.begin_validate`\n",
"\n",
"> `TrainEvalCallback.begin_validate`()\n",
"\n",
"Set the model in validation mode\n",
"\n",
"#### class `GatherPredsCallback`\n",
"\n",
"> `GatherPredsCallback`(**`with_input`**=*`False`*, **`with_loss`**=*`False`*) :: [`Callback`](/learner.html#Callback)\n",
"\n",
"[`Callback`](/learner.html#Callback) that saves the predictions and targets, optionally `with_loss`\n",
"\n",
"#### `GatherPredsCallback.begin_validate`\n",
"\n",
"> `GatherPredsCallback.begin_validate`()\n",
"\n",
"Initialize containers\n",
"\n",
"#### `GatherPredsCallback.after_batch`\n",
"\n",
"> `GatherPredsCallback.after_batch`()\n",
"\n",
"Save predictions, targets and potentially losses"
],
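The event-dispatch mechanism above is tiny: `Callback.__call__` looks up a method named after the event and falls back to a no-op when it is not defined. A standalone sketch in plain Python (no fastai imports; `CounterCallback` is a hypothetical example callback):

```python
# Sketch of the event-dispatch pattern used by `Callback.__call__`: calling
# the callback with an event name runs the matching method if it exists,
# and otherwise silently does nothing (the `noop` fallback).
class MiniCallback:
    def __call__(self, event_name):
        getattr(self, event_name, lambda: None)()

class CounterCallback(MiniCallback):
    "Counts how many batches it has seen (hypothetical example callback)."
    def begin_fit(self):    self.n_batches = 0
    def after_batch(self):  self.n_batches += 1

cb = CounterCallback()
cb('begin_fit')
for _ in range(3): cb('after_batch')
cb('after_loss')     # not defined -> falls back to the no-op
print(cb.n_batches)  # → 3
```

Because undefined events fall through to the no-op, a callback only has to implement the handful of events it actually cares about.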
"text/plain": [
"#### class `CancelBatchException`\n",
"\n",
"> `CancelBatchException`(**\*`args`**, **\*\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the rest of this batch and go to `after_batch`\n",
"\n",
"#### class `CancelTrainException`\n",
"\n",
"> `CancelTrainException`(**\*`args`**, **\*\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the rest of the training part of the epoch and go to `after_train`\n",
"\n",
"#### class `CancelValidException`\n",
"\n",
"> `CancelValidException`(**\*`args`**, **\*\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the rest of the validation part of the epoch and go to `after_validate`\n",
"\n",
"#### class `CancelEpochException`\n",
"\n",
"> `CancelEpochException`(**\*`args`**, **\*\*`kwargs`**) :: `Exception`\n",
"\n",
"Skip the rest of this epoch and go to `after_epoch`\n",
"\n",
"#### class `CancelFitException`\n",
"\n",
"> `CancelFitException`(**\*`args`**, **\*\*`kwargs`**) :: `Exception`\n",
"\n",
"Interrupt training and go to `after_fit`"
],
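These exceptions act as control flow: a callback raises one, and the training loop catches it at the matching level, then runs the corresponding clean-up event. A minimal sketch of the batch-level case only, assuming nothing from fastai (the `one_batch` helper and `log` list here are purely illustrative):

```python
# Sketch (not the fastai implementation) of exception-based control flow at
# the batch level: raising CancelBatchException inside a batch abandons the
# remaining work and jumps straight to the `after_batch` clean-up step.
class CancelBatchException(Exception): pass

log = []

def one_batch(i, skip=False):
    try:
        log.append(f'begin_batch {i}')
        if skip: raise CancelBatchException()
        log.append(f'step {i}')            # skipped when the batch is cancelled
    except CancelBatchException:
        pass                               # fall through to after_batch
    log.append(f'after_batch {i}')

one_batch(0)
one_batch(1, skip=True)
print(log)
# → ['begin_batch 0', 'step 0', 'after_batch 0', 'begin_batch 1', 'after_batch 1']
```

The other four exceptions follow the same pattern at wider scopes (train phase, validation phase, epoch, whole fit).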
"text/plain": [
"#### class `event`\n",
"\n",
"> `event`(**\*`args`**, **\*\*`kwargs`**)\n",
"\n",
"All possible events as attributes to get tab-completion and typo-proofing\n",
"\n",
"#### `Learner.fit`\n",
"\n",
"> `Learner.fit`(**`n_epoch`**, **`lr`**=*`None`*, **`wd`**=*`0.01`*, **`cbs`**=*`None`*, **`reset_opt`**=*`False`*)\n",
"\n",
"Fit `self.model` for `n_epoch` using `cbs`. Optionally `reset_opt`.\n",
"\n",
"#### `Learner.one_batch`\n",
"\n",
"> `Learner.one_batch`(**`i`**, **`b`**)\n",
"\n",
"Train or evaluate `self.model` on batch `(xb,yb)`\n",
"\n",
"#### `Learner.all_batches`\n",
"\n",
"> `Learner.all_batches`()\n",
"\n",
"Train or evaluate `self.model` on all batches of `self.dl`\n",
"\n",
"#### `Learner.save`\n",
"\n",
"> `Learner.save`(**`file`**, **`with_opt`**=*`True`*)\n",
"\n",
"Save model and optimizer state (if `with_opt`) to `self.path/self.model_dir/file`\n",
"\n",
"#### `Learner.load`\n",
"\n",
"> `Learner.load`(**`file`**, **`with_opt`**=*`None`*, **`device`**=*`None`*, **`strict`**=*`True`*)\n",
"\n",
"Load model and optimizer state (if `with_opt`) from `self.path/self.model_dir/file` using [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)\n",
"\n",
"#### `Learner.__call__`\n",
"\n",
"> `Learner.__call__`(**`event_name`**)\n",
"\n",
"Call self as a function.\n",
"\n",
"#### `Learner.add_cb`\n",
"\n",
"> `Learner.add_cb`(**`cb`**)\n",
"\n",
"Add `cb` to the list of [`Callback`](/learner.html#Callback) and register `self` as their learner\n",
"\n",
"#### `Learner.add_cbs`\n",
"\n",
"> `Learner.add_cbs`(**`cbs`**)\n",
"\n",
"Add `cbs` to the list of [`Callback`](/learner.html#Callback) and register `self` as their learner\n",
"\n",
"#### `Learner.remove_cb`\n",
"\n",
"> `Learner.remove_cb`(**`cb`**)\n",
"\n",
"Remove `cb` from the list of [`Callback`](/learner.html#Callback) and deregister `self` as their learner\n",
"\n",
"#### `Learner.remove_cbs`\n",
"\n",
"> `Learner.remove_cbs`(**`cbs`**)\n",
"\n",
"Remove `cbs` from the list of [`Callback`](/learner.html#Callback) and deregister `self` as their learner\n",
"\n",
"#### class `Metric`\n",
"\n",
"> `Metric`()\n",
"\n",
"Blueprint for defining a metric\n",
"\n",
"#### `Metric.reset`\n",
"\n",
"> `Metric.reset`()\n",
"\n",
"Reset inner state to prepare for new computation\n",
"\n",
"#### `Metric.accumulate`\n",
"\n",
"> `Metric.accumulate`(**`learn`**)\n",
"\n",
"Use `learn` to update the state with new results\n",
"\n",
"#### `Metric.value`\n",
"\n",
"#### `Metric.name`\n",
"\n",
"#### class `AvgMetric`\n",
"\n",
"> `AvgMetric`(**`func`**) :: [`Metric`](/learner.html#Metric)\n",
"\n",
"Average the values of `func`, taking into account potentially different batch sizes"
],
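The `Metric` protocol above (`reset`, `accumulate`, `value`) makes batch-size-weighted averaging straightforward. A simplified sketch in the spirit of `AvgMetric`, which for clarity takes plain numbers per batch rather than reading them off a `Learner` (the class name and `accumulate` signature are this sketch's assumptions):

```python
# Sketch of the Metric protocol (reset / accumulate / value) computing a
# batch-size-weighted average, in the spirit of `AvgMetric`.
class SimpleAvgMetric:
    def reset(self):
        self.total, self.count = 0.0, 0
    def accumulate(self, batch_value, bs):
        # weight each batch's value by its size so uneven batches average correctly
        self.total += batch_value * bs
        self.count += bs
    @property
    def value(self):
        return self.total / self.count if self.count else None

m = SimpleAvgMetric()
m.reset()
m.accumulate(0.5, bs=16)   # full batch
m.accumulate(1.0, bs=4)    # smaller final batch
print(m.value)  # → (0.5*16 + 1.0*4) / 20 = 0.6
```

Weighting by batch size matters because the last batch of an epoch is often smaller; an unweighted mean of per-batch values would over-count it.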
"text/plain": [
"#### class `AvgLoss`\n",
"\n",
"> `AvgLoss`() :: [`Metric`](/learner.html#Metric)\n",
"\n",
"Average the losses, taking into account potentially different batch sizes\n",
"\n",
"#### class `AvgSmoothLoss`\n",
"\n",
"> `AvgSmoothLoss`(**`beta`**=*`0.98`*) :: [`Metric`](/learner.html#Metric)\n",
"\n",
"Smooth average of the losses (exponentially weighted with `beta`)\n",
"\n",
"#### `Recorder.begin_fit`\n",
"\n",
"> `Recorder.begin_fit`()\n",
"\n",
"Prepare state for training\n",
"\n",
"#### `Recorder.begin_epoch`\n",
"\n",
"> `Recorder.begin_epoch`()\n",
"\n",
"Set timer if `self.add_time=True`\n",
"\n",
"#### `Recorder.begin_validate`\n",
"\n",
"> `Recorder.begin_validate`()\n",
"\n",
"Reset loss and metrics state\n",
"\n",
"#### `Recorder.after_batch`\n",
"\n",
"> `Recorder.after_batch`()\n",
"\n",
"Update all metrics and record lr and smooth loss in training\n",
"\n",
"#### `Recorder.after_epoch`\n",
"\n",
"> `Recorder.after_epoch`()\n",
"\n",
"Store and log the loss/metric values\n",
"\n",
"#### `Recorder.plot_loss`\n",
"\n",
"> `Recorder.plot_loss`(**`skip_start`**=*`5`*, **`with_valid`**=*`True`*)\n",
"\n",
"Plot the losses from `skip_start` onward"
],
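The exponentially weighted average behind `AvgSmoothLoss` is easy to reproduce: each new loss is mixed in with weight `1-beta`, and the running value is divided by `1-beta**count` so the early readings are not biased toward the zero initialization. A standalone sketch (the class name and the debiasing step are this sketch's assumptions, matching the docstring's description rather than quoting the implementation):

```python
# Sketch of an exponentially weighted (beta) moving average with bias
# correction: val <- beta*val + (1-beta)*loss, reported as val / (1 - beta**count).
class SmoothValue:
    def __init__(self, beta=0.98):
        self.beta, self.val, self.count = beta, 0.0, 0
    def accumulate(self, loss):
        self.count += 1
        self.val = self.beta * self.val + (1 - self.beta) * loss
    @property
    def value(self):
        # debias: without this, early values would be pulled toward 0
        return self.val / (1 - self.beta ** self.count)

s = SmoothValue(beta=0.9)
for loss in [2.0, 2.0, 2.0]:
    s.accumulate(loss)
print(s.value)  # → 2.0 (a constant series reports its own value after debiasing)
```

A larger `beta` gives a smoother but laggier curve, which is why the smoothed training loss plotted by `Recorder` looks less noisy than the raw per-batch losses.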
"text/plain": [
"