{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basic training functionality"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [],
"source": [
"from fastai.basic_train import *\n",
"from fastai.gen_doc.nbdoc import *\n",
"from fastai.vision import *\n",
"from fastai.distributed import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. Here the basic training loop is defined for the [`fit`](/basic_train.html#fit) method. The [`Learner`](/basic_train.html#Learner) object is the entry point of most of the [`Callback`](/callback.html#Callback) objects that will customize this training loop in different ways. Some of the most commonly used customizations are available through the [`train`](/train.html#train) module, notably:\n",
"\n",
" - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that will help you select a good learning rate.\n",
" - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will train your model with the 1cycle policy, which helps it train faster.\n",
" - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision and let you train in mixed precision."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> Learner(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\\[`Callable`\\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`str`=***`'models'`***, **`callback_fns`**:`Collection`\\[`Callable`\\]=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***``***, **`layer_groups`**:`ModuleList`=***`None`***)\n",
"\n",
"Trainer for `model` using `data` to minimize `loss_func` with optimizer `opt_func`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner, title_level=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The main purpose of [`Learner`](/basic_train.html#Learner) is to train `model` using [`Learner.fit`](/basic_train.html#Learner.fit). After every epoch, all *metrics* will be printed and also made available to callbacks.\n",
"\n",
"The default weight decay will be `wd`, which will be handled using the method from [Fixing Weight Decay Regularization in Adam](https://arxiv.org/abs/1711.05101) if `true_wd` is set (otherwise it's L2 regularization). If `bn_wd` is `False`, then weight decay will be removed from batchnorm layers, as recommended in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). If `train_bn`, batchnorm layer learnable params are trained even for frozen layer groups.\n",
"\n",
"To use [discriminative layer training](#Discriminative-layer-training), pass a list of [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) as `layer_groups`; each [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) will be used to customize the optimization of the corresponding layer group.\n",
"\n",
"If `path` is provided, all the model files created will be saved in `path`/`model_dir`; if not, then they will be saved in `data.path`/`model_dir`.\n",
"\n",
"You can pass a list of [`callback`](/callback.html#callback)s that you have already created, or (more commonly) simply pass a list of callback functions to `callback_fns` and each function will be called (passing `self`) on object initialization, with the results stored as callback objects. For a walk-through, see the [training overview](/training.html) page. You may also want to use an [application](applications.html)-specific model. For example, if you are dealing with a vision dataset such as MNIST, you might want to use the [`create_cnn`](/vision.learner.html#create_cnn) method:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [],
"source": [
"path = untar_data(URLs.MNIST_SAMPLE)\n",
"data = ImageDataBunch.from_folder(path)\n",
"learn = create_cnn(data, models.resnet18, metrics=accuracy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Model fitting methods"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> fit(**`epochs`**:`int`, **`lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`wd`**:`Floats`=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`None`***)\n",
"\n",
"Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.fit)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Uses [discriminative layer training](#Discriminative-layer-training) if multiple learning rates or weight decay values are passed. To control training behaviour, use the [`callback`](/callback.html#callback) system or one or more of the pre-defined [`callbacks`](/callbacks.html#callbacks)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> predict(**`item`**:[`ItemBase`](/core.html#ItemBase), **\\*\\*`kwargs`**)\n",
"\n",
"Return predicted class, label and probabilities for `item`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.predict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`predict` can be used to get a single prediction from the trained learner on one specific piece of data you are interested in."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Image (3, 28, 28), Category 3)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.data.train_ds[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each element of the dataset is a tuple, where the first element is the data itself, while the second element is the target label. So to get the data, we need to index one more time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = learn.data.train_ds[0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vc/wDgnb+xNcft+ftLWPwEk+NfhT4eaaNNuNU17xf4xv1gtNPsLcKZnBYqHkww2oWQHnLKATXhlfoV+z5/wSO/Ya8EfCzwl+0p/wAFKv8Agq58NfB/hvxHZWup23w9+F923iPxPPayosnkzJAjCxl2kgny51jJAY7soADQ+IP/AAQ5/Zii+FPx0+K/7PP/AAVn8DfE1Pgx4Wm8Qy6b4X8FX0sd1aCZ4oopr/zPskE8jKqokckxYlmA2ruP5y1+oP8AwWN/aB+Hvh/9i/4cfBL/AIJXPpmhfseeJNYv1kbSre+t9Y8R+KbIWzXg15rxVlnZFmtpYQpaIo6DgwJHD+X1AH2L/wAE3/2W/wDgkp+0B8OtZ1D9vn/gopr3wc8WWmtSR6Votn4Gnv7W604QwFbg3McbgSNK8yeUQCFhDc7uPr7w98Bv+DQj9mayfWPiR+1v8WPjtqCWCv8A2Lpmm39nBPJkAiPyLW02NlSQslzgK/JJwa/HyigD6/8A+CnP/BUXRP20NB8Lfs1fs2/s9aH8JPgH8Nbm4f4e+ArCCOe8WWYnzr27u2BkeaUksyByoJ+ZpWHmn5AoooA//9k=\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAlhJREFUSIntVj1rKlEQPUZDUoiF0SR2FomWIRGRRLCwDhK2EdE/YJE/INhaWiVNSCWI1TZRi4BVTBMswoJFwEZsJGyz2gjKvec1j+Vp1u/3AoE3MLB775w5d2bP7qwNAPGNtvOdZD+P8ODgABcXF2thHJuSxWIxFAoFhEIh3N3dQdd17O3tIZVKAQBUVUU+n7fEcl2/vb3lYDCgEMJ0KeXU/fPzsyV2owrPzs7gdDoBAJ1OB09PT9B1Haqq4vLyEqVSCZ+fn5bYjQhfX18xGo1QqVTQarUwmUwAAPv7+8hkMgCAdrs9F792S63c5XKx2+1SSsn393ceHR3Ni92OyOFw0OPxsNvtUghBTdPo8/kWYTYn8/l8LBaLpmg0TePx8fEy3GZEmUyGhmGYqnx4eFhW2foqTafTUBQFV1dXODw8nNr7+PhAv99fKc9SMdzf31NKaTpJ9no9NptN1mo1qqpKkszlcqt0aHFAJBLheDymEILD4ZDNZpPX19f0er1TcalUisPhkCcnJ9s/w3g8zkQisTBZNBqlEIKJROLvi8bK397eKIRgNpul3W6fG7f1eAqHw1BVFeFwGABwfn6OnZ3FaVeuwu/3MxqN8vHxkS8vL+aXRUrJwWDAbDa7NIft9wUA4ObmBoFA4MuJFEWBy+WC1+uF2+2GzWYDacJgGAaSySQajcZKXTHZZ0eM1diZXSuXy3S73St3aerFr1arCAaDOD09tTxZp9OBEAK6rqNer6PRaEDTNEgpV6oMAKZaCgC7u7umAGbtz1G0qX0h/Nf2s/7a/hNa2S/tek2pzxDJXQAAAABJRU5ErkJggg==\n",
"text/plain": [
"Image (3, 28, 28)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(Category 3, tensor(0), tensor([9.9979e-01, 2.0649e-04]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pred = learn.predict(data)\n",
"pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first two elements of the tuple are, respectively, the predicted class and label. Label here is essentially an internal representation of each class, since class name is a string and cannot be used in computation. To check what each label corresponds to, run:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['3', '7']"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.data.classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So category 0 is 3 while category 1 is 7."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"probs = pred[2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last element in the tuple is the predicted probabilities. For a categorization dataset, the number of probabilities returned is the same as the number of classes; `probs[i]` is the probability that the `item` belongs to `learn.data.classes[i]`."
]
},
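{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of that indexing rule, the cell below pairs each probability with its class name and picks the most likely one. It reuses the class list and probability values printed above as plain Python literals; the names `example_classes`, `example_probs` and `best_idx` are illustrative and are not part of fastai."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Values copied from the outputs shown above (illustrative literals, not a fastai API)\n",
"example_classes = ['3', '7']\n",
"example_probs = [9.9979e-01, 2.0649e-04]\n",
"\n",
"# example_probs[i] is the probability of example_classes[i]\n",
"for cls, p in zip(example_classes, example_probs):\n",
"    print(f'class {cls}: {p:.4f}')\n",
"\n",
"# The predicted class is the one with the highest probability\n",
"best_idx = max(range(len(example_probs)), key=example_probs.__getitem__)\n",
"print('most likely class:', example_classes[best_idx])"
]
},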
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vSvC/7Gf7YHjjwBF8V/BX7KXxK1jwtO4SDxLpfgXULjT5GIJAW4jhMZOFY4Dfwn0rz7RdQt9J1m01S70e21GK2uY5ZdPvTIIblVYExSeU6PsYDadjK2CcMDgj9d9T+GP8Awd8+HfHui+J/DV58Wre21sQPoFr4N8VabN4asraSONYUS3tZ3sLS2SN0A3qkaBS2flZgAfkHdWtzY3Mlle28kM0MhSaGVCrIwOCpB5BB4INR19xf8HB/7Q/gv9ob/goEkvhrxF4f8S6x4L+HWg+FfHnj7wysQtPF/iK0tydQ1NTCqo376U24KjaVtV2/Livh2gAr7N/4IZftP/tK/Cz/AIKW/BX4bfC/x1rdx4f8Z/ELTPDXi3wi081zp2o6LqFzFbX6TWmTG6rbs8m4r+7MSyfwV8ZV7r+zr/wUA+KP7Knhu8T4HfDbwJo3jG40S60iz+KMegyHxDptncpLHOtrL532eGV4ppIjciA3ARsCUYGADiv2svBHgX4Z/tT/ABL+HHwvvRc+GfD/AMQNZ03w7crdLOJbGC+mit38xflkzGiHcOGzkda8/pWZnYu7EknJJPWkoA//2Q==\n",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAdJJREFUSIntVbGq4kAUvW9ZUZCQIqjBzlohhWAsLEQLwdZUNn6AkEbQD7AJthZWFnYWWgj+gKRIpRZiK1hpIYoY1BSZ86q12SUm8T2L5V04zcyZOffMvZf5ICLQG+PXO8V+BH8EfcVvN6RCoUC1Wo1yuRwlEgmaTqd0Op1ot9vRZDIhwzA8icIJqqrCsizYtg3GGGzbfoAxBsuysNlsIMuy4z1/8EEOg8/zPK3XaxqPx9Tv9x/roVCIKpUKxeNxKpVKJAgCzedzymQyrznUNA35fN4xY0mSHo7dOHz6pM8gSRIYY9B13RX/5S4tFosEgDqdjuszjhlxHAdRFBEMBv/a43ke+/0e2+0WHMe5cvh0LGazGUmSRIvFgo7HI41GI7rf70REpGkaRSIRqtfrdLlcvsZht9sFY+yfOBwOUBTFa92dCYFAALFYDNlsFo1GA+12G6ZpwrZt9Ho9P43mvTNlWcb1esX5fIYgCN8vSERoNptgjKFarb5HkIjAGMNyufR05uU5BOCJ////h0Q+69dqtQDgPU0TjUZxu91gGAbC4fDXCyaTSaiqClEUUS6Xoes6TNNEOp328zruiIqiYDAYYLVaYTgcIpVK+SqF44//HfEJXkMk1eKRk3QAAAAASUVORK5CYII=\n",
"text/plain": [
"Image (3, 28, 28)"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.data.valid_ds[0][0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can always check for yourself whether the given probabilities make sense."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> show_results(**`ds_type`**=***``***, **`rows`**:`int`=***`5`***, **\\*\\*`kwargs`**)\n",
"\n",
"Show `rows` result of predictions on `ds_type` dataset. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.show_results)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the text on top is the ground truth (the target label), the text in the middle is the prediction, and the image at the bottom is the image data itself."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`with_loss`**:`bool`=***`False`***) → `Tensors`\n",
"\n",
"Applies TTA to predict on `ds_type` dataset. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.TTA, full_name = 'TTA')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Applies Test Time Augmentation to `learn` on the dataset `ds_type`. We take the average of our regular predictions (with a weight `beta`) with the average of predictions obtained through augmented versions of the training set (with a weight `1-beta`). The transforms decided for the training set are applied with a few changes: `scale` controls the scale for zoom (which isn't random), and the cropping isn't random either, since we make sure to get the four corners of the image. Flipping isn't random but is applied once on each of those corner images (so that makes 8 augmented versions in total)."
]
},
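{
"cell_type": "markdown",
"metadata": {},
"source": [
"The weighting described above can be sketched with NumPy, assuming the blend is a simple convex combination of the two sets of predictions. `blend_tta`, `plain_preds` and `aug_preds` are hypothetical names chosen for illustration; this is not fastai's implementation, and the numbers are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def blend_tta(plain_preds, aug_preds, beta=0.4):\n",
"    # Average over the augmented versions (first axis), then blend:\n",
"    # weight beta for the regular predictions, 1-beta for the augmented average\n",
"    avg_aug = np.mean(aug_preds, axis=0)\n",
"    return beta * plain_preds + (1 - beta) * avg_aug\n",
"\n",
"# Made-up probabilities for 2 items and 2 classes\n",
"plain_preds = np.array([[0.9, 0.1], [0.3, 0.7]])\n",
"# 8 augmented versions of each item (here simply noisy copies, for illustration only)\n",
"rng = np.random.default_rng(0)\n",
"aug_preds = np.clip(plain_preds + rng.normal(0, 0.05, size=(8, 2, 2)), 0, 1)\n",
"\n",
"blend_tta(plain_preds, aug_preds)"
]
},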
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Gradient clipping"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> distributed(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cuda_id`**:`int`, **`cache_dir`**:`PathOrStr`=***`'tmp'`***)\n",
"\n",
"Put `learn` on distributed training with `cuda_id`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.distributed, full_name='distributed')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Discriminative layer training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each *layer group* (i.e. the parameters of each module in `self.layer_groups`). See the [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146) paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a [`Learner`](/basic_train.html#Learner) on which you've called `split`, you can set hyperparameters in four ways:\n",
"\n",
"1. `param = [val1, val2 ..., valn]` (n = number of layer groups)\n",
"2. `param = val`\n",
"3. `param = slice(start,end)`\n",
"4. `param = slice(end)`\n",
"\n",
"If we choose to set it in way 1, we must specify a number of values exactly equal to the number of layer groups. If we choose to set it in way 2, the chosen value will be repeated for all layer groups. See [`Learner.lr_range`](/basic_train.html#Learner.lr_range) for an explanation of the `slice` syntax.\n",
"\n",
"Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call [`Learner.split`](/basic_train.html#Learner.split) in this case, since fastai uses this exact function as the default split for `resnet18`; this is just to show how to customize it):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# creates 3 layer groups\n",
"learn.split(lambda m: (m[0][6], m[1]))\n",
"# only randomly initialized head now trainable\n",
"learn.freeze()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> lr_range(**`lr`**:`Union`\\[`float`, `slice`\\]) → `ndarray`\n",
"\n",
"Build differential learning rates from `lr`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.lr_range)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Rather than manually setting an LR for every group, it's often easier to use [`Learner.lr_range`](/basic_train.html#Learner.lr_range). This is a convenience method that returns one learning rate for each layer group. If you pass `slice(start,end)` then the first group's learning rate is `start`, the last is `end`, and the remaining are evenly geometrically spaced.\n",
"\n",
"If you pass just `slice(end)` then the last group's learning rate is `end`, and all the other groups are `end/10`. For instance (for our learner that has 3 layer groups):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))"
]
},
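{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two `slice` rules can be sketched in plain NumPy as follows. `lr_range_sketch` is a hypothetical stand-in written from the description above, not fastai's code, and `n_groups` is passed explicitly because the sketch has no learner to inspect. Repeating a single value for every group is an assumption taken from way 2 in the list of hyperparameter forms."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def lr_range_sketch(lr, n_groups):\n",
"    # A single value is repeated for every layer group (assumption, see note above)\n",
"    if not isinstance(lr, slice):\n",
"        return np.array([lr] * n_groups)\n",
"    # slice(start, end): geometrically spaced from start to end\n",
"    if lr.start:\n",
"        return np.geomspace(lr.start, lr.stop, n_groups)\n",
"    # slice(end): end/10 for every group except the last, which gets end\n",
"    return np.array([lr.stop / 10] * (n_groups - 1) + [lr.stop])\n",
"\n",
"lr_range_sketch(slice(1e-5, 1e-3), 3), lr_range_sketch(slice(1e-3), 3)"
]
},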
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> split(**`split_on`**:`SplitFuncOrIdxList`)\n",
"\n",
"Split the model at `split_on`. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Learner.split)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A convenience method that sets `layer_groups` based on the result of [`split_model`](/torch_core.html#split_model). If `split_on` is a function, it calls that function and passes the result to [`split_model`](/torch_core.html#split_model) (see above for example)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Saving and loading models"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Simply call [`Learner.save`](/basic_train.html#Learner.save) and [`Learner.load`](/basic_train.html#Learner.load) to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the `path`/`model_dir` directory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": false
},
"outputs": [
{
"data": {
"text/markdown": [
"\n",
"\n",
"> Recorder(**`learn`**:[`Learner`](/basic_train.html#Learner)) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"A [`LearnerCallback`](/basic_train.html#LearnerCallback) that records epoch, loss, opt and metric data during training. "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"show_doc(Recorder, title_level=2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A [`Learner`](/basic_train.html#Learner) creates a [`Recorder`](/basic_train.html#Recorder) object automatically - you do not need to explicitly pass it to `callback_fns` - because other callbacks rely on it being available. It stores the smoothed loss, hyperparameter values, and metrics for each batch, and provides plotting methods for each. Note that [`Learner`](/basic_train.html#Learner) automatically sets an attribute with the snake-cased name of each callback, so you can access this through `Learner.recorder`, as shown below."
]
},
{
"cell_type": "markdown",
"metadata": {
"hide_input": true
},
"source": [
"### Plotting methods"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"hide_input": true
},
"outputs": [
{
"data": {
"text/markdown": [
"