{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic training functionality" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.basic_train import *\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n", "from fastai.distributed import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`basic_train`](/basic_train.html#basic_train) wraps together the data (in a [`DataBunch`](/basic_data.html#DataBunch) object) with a PyTorch model to define a [`Learner`](/basic_train.html#Learner) object. The basic training loop used by the [`fit`](/basic_train.html#fit) method is defined here. The [`Learner`](/basic_train.html#Learner) object is the entry point for most of the [`Callback`](/callback.html#Callback) objects that customize this training loop in different ways. Some of the most commonly used customizations are available through the [`train`](/train.html#train) module, notably:\n", "\n", " - [`Learner.lr_find`](/train.html#lr_find) will launch an LR range test that will help you select a good learning rate.\n", " - [`Learner.fit_one_cycle`](/train.html#fit_one_cycle) will launch a training run using the 1cycle policy to help you train your model faster.\n", " - [`Learner.to_fp16`](/train.html#to_fp16) will convert your model to half precision and help you launch training in mixed precision." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Learner[source][test]

\n", "\n", "> Learner(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`opt_func`**:`Callable`=***`'Adam'`***, **`loss_func`**:`Callable`=***`None`***, **`metrics`**:`Collection`\\[`Callable`\\]=***`None`***, **`true_wd`**:`bool`=***`True`***, **`bn_wd`**:`bool`=***`True`***, **`wd`**:`Floats`=***`0.01`***, **`train_bn`**:`bool`=***`True`***, **`path`**:`str`=***`None`***, **`model_dir`**:`PathOrStr`=***`'models'`***, **`callback_fns`**:`Collection`\\[`Callable`\\]=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***``***, **`layer_groups`**:`ModuleList`=***`None`***, **`add_time`**:`bool`=***`True`***, **`silent`**:`bool`=***`None`***, **`cb_fns_registered`**:`bool`=***`False`***)\n", "\n", "
Tests found for Learner:

Some other tests where Learner is used:

  • pytest -sv tests/test_basic_train.py::test_memory [source]

To run tests please refer to this guide.

\n", "\n", "Trainer for `model` using `data` to minimize `loss_func` with optimizer `opt_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The main purpose of [`Learner`](/basic_train.html#Learner) is to train `model` using [`Learner.fit`](/basic_train.html#Learner.fit). After every epoch, all *metrics* will be printed and also made available to callbacks.\n", "\n", "The default weight decay will be `wd`, which will be handled using the method from [Fixing Weight Decay Regularization in Adam](https://arxiv.org/abs/1711.05101) if `true_wd` is set (otherwise it's L2 regularization). If `true_wd` is set it will affect all optimizers, not only Adam. If `bn_wd` is `False`, then weight decay will be removed from batchnorm layers, as recommended in [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). If `train_bn`, batchnorm layer learnable params are trained even for frozen layer groups.\n", "\n", "To use [discriminative layer training](#Discriminative-layer-training), pass a list of [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) as `layer_groups`; each [`nn.Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) will be used to customize the optimization of the corresponding layer group.\n", "\n", "If `path` is provided, all the model files created will be saved in `path`/`model_dir`; if not, then they will be saved in `data.path`/`model_dir`.\n", "\n", "You can pass a list of [`callback`](/callback.html#callback)s that you have already created, or (more commonly) simply pass a list of callback functions to `callback_fns` and each function will be called (passing `self`) on object initialization, with the results stored as callback objects. For a walk-through, see the [training overview](/training.html) page. 
You may also want to use an [application](applications.html)-specific model. For example, if you are working with a vision dataset such as the MNIST sample used here, you might want to use the [`cnn_learner`](/vision.learner.html#cnn_learner) function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": false }, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = cnn_learner(data, models.resnet18, metrics=accuracy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model fitting methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_find[source][test]

\n", "\n", "> lr_find(**`learn`**:[`Learner`](/basic_train.html#Learner), **`start_lr`**:`Floats`=***`1e-07`***, **`end_lr`**:`Floats`=***`10`***, **`num_it`**:`int`=***`100`***, **`stop_div`**:`bool`=***`True`***, **`wd`**:`float`=***`None`***)\n", "\n", "
Tests found for lr_find:

  • pytest -sv tests/test_train.py::test_lr_find [source]
  • pytest -sv tests/test_vision_train.py::test_lrfind [source]

To run tests please refer to this guide.

\n", "\n", "Explore lr from `start_lr` to `end_lr` over `num_it` iterations in `learn`. If `stop_div`, stops when loss diverges. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.lr_find)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Runs the learning rate finder defined in [`LRFinder`](/callbacks.lr_finder.html#LRFinder), as discussed in [Cyclical Learning Rates for Training Neural Networks](https://arxiv.org/abs/1506.01186). " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] } ], "source": [ "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit[source][test]

\n", "\n", "> fit(**`epochs`**:`int`, **`lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`wd`**:`Floats`=***`None`***, **`callbacks`**:`Collection`\\[[`Callback`](/callback.html#Callback)\\]=***`None`***)\n", "\n", "
Tests found for fit:

  • pytest -sv tests/test_train.py::test_fit [source]

Some other tests where fit is used:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]

To run tests please refer to this guide.

\n", "\n", "Fit the model on this learner with `lr` learning rate, `wd` weight decay for `epochs` with `callbacks`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.fit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses [discriminative layer training](#Discriminative-layer-training) if multiple learning rates or weight decay values are passed. To control training behaviour, use the [`callback`](/callback.html#callback) system or one or more of the pre-defined [`callbacks`](/callbacks.html#callbacks)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
10.1353430.0831900.97203100:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit_one_cycle[source][test]

\n", "\n", "> fit_one_cycle(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cyc_len`**:`int`, **`max_lr`**:`Union`\\[`float`, `Collection`\\[`float`\\], `slice`\\]=***`slice(None, 0.003, None)`***, **`moms`**:`Point`=***`(0.95, 0.85)`***, **`div_factor`**:`float`=***`25.0`***, **`pct_start`**:`float`=***`0.3`***, **`final_div`**:`float`=***`None`***, **`wd`**:`float`=***`None`***, **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`tot_epochs`**:`int`=***`None`***, **`start_epoch`**:`int`=***`None`***)\n", "\n", "
Tests found for fit_one_cycle:

  • pytest -sv tests/test_train.py::test_fit_one_cycle [source]

Some other tests where fit_one_cycle is used:

  • pytest -sv tests/test_tabular_train.py::test_empty_cont [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_if_split_fn_provided [source]
  • pytest -sv tests/test_text_train.py::test_qrnn_works_with_no_split [source]

To run tests please refer to this guide.

\n", "\n", "Fit a model following the 1cycle policy. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.fit_one_cycle)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use cycle length `cyc_len`, a per-cycle maximum learning rate `max_lr`, momentum `moms`, division factor `div_factor`, weight decay `wd`, and optional callbacks [`callbacks`](/callbacks.html#callbacks). Uses the [`OneCycleScheduler`](/callbacks.one_cycle.html#OneCycleScheduler) callback. Please refer to [What is 1-cycle](/callbacks.one_cycle.html#What-is-1cycle?) for the conceptual background of the 1cycle training policy and more technical details on what the method's arguments do." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
10.0758380.0618690.97988200:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### See results" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

predict[source][test]

\n", "\n", "> predict(**`item`**:[`ItemBase`](/core.html#ItemBase), **`return_x`**:`bool`=***`False`***, **`batch_first`**:`bool`=***`True`***, **`with_dropout`**:`bool`=***`False`***, **\\*\\*`kwargs`**)\n", "\n", "
Tests found for predict:

  • pytest -sv tests/test_vision_train.py::test_models_meta [source]
  • pytest -sv tests/test_vision_train.py::test_preds [source]

To run tests please refer to this guide.

\n", "\n", "Return predicted class, label and probabilities for `item`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.predict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`predict` can be used to get a single prediction from the trained learner on one specific piece of data you are interested in." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Image (3, 28, 28), )" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.data.train_ds[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each element of the dataset is a tuple, where the first element is the data itself, while the second element is the target label. So to get the data, we need to index one more time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = learn.data.train_ds[0][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/jpeg": 
"/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vtr4ff8G5v/BaH4o/D/Q/ih4L/AGH9Tn0bxHDHNpM114t0W1mkikyUkeCe9SaFGA3BpEUbSrZwwJ+Ja/fD/g13+Dtzq/w6l/b6/ae/4Kna9o+gaP4jTSNL+GUnxZezs3a32LGNVWeYfuzJND5VsuFdSu8usnl0AfjV+2R+wp+1l/wT9+J8Pwc/a/8Ag1feDPEFzp6X1nbXF5bXcN1buSBJFcWsssEoBBVtjkqwKtg8V5JX6a/8HYmq/tMt/wAFVb7w5+0L8RtB1nT7bwta3XgDTPDaSxwaPo000yxwTRykkXbNC0krbmDl1ZdqFY0/MqgAr9b/APgkv+xl8Hf+Cb/7Jeuf8Fwv+CjHw1g1KbRV8v4BfC3Xo2t7nV9XWVEj1FoZk5VZGTy5MOI1WSbaWWGvyQr2f9rD/goV+2T+3DpPhHw/+1H8cb7xTp3gPR10zwlpp0+0s7bT7dVVRiK0iiR5CqqpmcNIyqoLEAAAGX+2p+2F8Zf29f2mPFP7Vfx71G1uPEvim8WS4jsbcRW9rDGixQW0S8kRxRIkalizEICzMSSfLKKKAP/Z\n", "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAbBJREFUSInt1rGKwkAQBuBfuUJMRK0FQbYxoGBpLQimDAg+goVPIFrY29iYykoQfAEraxFDhBRW2lmJVqaV7Fxl0BPNrp42dwtTJOzOl1lmk4QAED44wp/E/sE/AubzebRaLXDOwTkHEWGxWMCyLP8e5xyNRkMYpUfRbrfJ87zAOB6PVKvVHuYCQF9BT2PbNvr9vn9dKpWgadrNPFVVwRh7vcKfwRijYrFI1Wr1qsL5fE6pVEokhxwIgDRNo9VqdQUWCgXR9eJQLBYj0zRpu9360G63o2w2S5FI5PfBwWBw0yyu61K32xXdTnGw0+nQ6XS626Wz2Uwoj/DBVxQF4fD96fF4XChP4LE4j2aziWQy6Sfu9XrI5XIwTRMAYFmWaCr5LgVAiqKQbdvv7dJzZDKZK8xxHEqn0+8BE4kELZdLH1uv18QYk8khB14eDcdxZDE50DAMOhwO5HkebTYbmW2UB3VdJ9d1/eoMw3iq2YRAVVVpNBr52GQyoWg0+h5Q13Uaj8c+NhwOqVKpPIsFg5cf4P1+T+Vy+RVM/NUGAPV6HdPpVGbJzQidy/zU+Phf2zcGKKJjocPzywAAAABJRU5ErkJggg==\n", "text/plain": [ "Image (3, 28, 28)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(, tensor(0), tensor([0.5748, 0.4252]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pred = learn.predict(data)\n", "pred" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first two elements of the tuple are, respectively, the predicted class and label. Label here is essentially an internal representation of each class, since class name is a string and cannot be used in computation. To check what each label corresponds to, run:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['3', '7']" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.data.classes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So category 0 is 3 while category 1 is 7." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "probs = pred[2]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last element in the tuple is the predicted probabilities. For a categorization dataset, the number of probabilities returned is the same as the number of classes; `probs[i]` is the probability that the `item` belongs to `learn.data.classes[i]`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vSvC/7Gf7YHjjwBF8V/BX7KXxK1jwtO4SDxLpfgXULjT5GIJAW4jhMZOFY4Dfwn0rz7RdQt9J1m01S70e21GK2uY5ZdPvTIIblVYExSeU6PsYDadjK2CcMDgj9d9T+GP8Awd8+HfHui+J/DV58Wre21sQPoFr4N8VabN4asraSONYUS3tZ3sLS2SN0A3qkaBS2flZgAfkHdWtzY3Mlle28kM0MhSaGVCrIwOCpB5BB4INR19xf8HB/7Q/gv9ob/goEkvhrxF4f8S6x4L+HWg+FfHnj7wysQtPF/iK0tydQ1NTCqo376U24KjaVtV2/Livh2gAr7N/4IZftP/tK/Cz/AIKW/BX4bfC/x1rdx4f8Z/ELTPDXi3wi081zp2o6LqFzFbX6TWmTG6rbs8m4r+7MSyfwV8ZV7r+zr/wUA+KP7Knhu8T4HfDbwJo3jG40S60iz+KMegyHxDptncpLHOtrL532eGV4ppIjciA3ARsCUYGADiv2svBHgX4Z/tT/ABL+HHwvvRc+GfD/AMQNZ03w7crdLOJbGC+mit38xflkzGiHcOGzkda8/pWZnYu7EknJJPWkoA//2Q==\n", "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAdJJREFUSIntVbGq4kAUvW9ZUZCQIqjBzlohhWAsLEQLwdZUNn6AkEbQD7AJthZWFnYWWgj+gKRIpRZiK1hpIYoY1BSZ86q12SUm8T2L5V04zcyZOffMvZf5ICLQG+PXO8V+BH8EfcVvN6RCoUC1Wo1yuRwlEgmaTqd0Op1ot9vRZDIhwzA8icIJqqrCsizYtg3GGGzbfoAxBsuysNlsIMuy4z1/8EEOg8/zPK3XaxqPx9Tv9x/roVCIKpUKxeNxKpVKJAgCzedzymQyrznUNA35fN4xY0mSHo7dOHz6pM8gSRIYY9B13RX/5S4tFosEgDqdjuszjhlxHAdRFBEMBv/a43ke+/0e2+0WHMe5cvh0LGazGUmSRIvFgo7HI41GI7rf70REpGkaRSIRqtfrdLlcvsZht9sFY+yfOBwOUBTFa92dCYFAALFYDNlsFo1GA+12G6ZpwrZt9Ho9P43mvTNlWcb1esX5fIYgCN8vSERoNptgjKFarb5HkIjAGMNyufR05uU5BOCJ////h0Q+69dqtQDgPU0TjUZxu91gGAbC4fDXCyaTSaiqClEUUS6Xoes6TNNEOp328zruiIqiYDAYYLVaYTgcIpVK+SqF44//HfEJXkMk1eKRk3QAAAAASUVORK5CYII=\n", "text/plain": [ "Image (3, 28, 28)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.data.valid_ds[0][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You could always check yourself if the probabilities given make sense." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_preds[source][test]

\n", "\n", "> get_preds(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`with_loss`**:`bool`=***`False`***, **`n_batch`**:`Optional`\\[`int`\\]=***`None`***, **`pbar`**:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=***`None`***) → `List`\\[`Tensor`\\]\n", "\n", "
Tests found for get_preds:

  • pytest -sv tests/test_basic_train.py::test_get_preds [source]

To run tests please refer to this guide.

\n", "\n", "Return predictions and targets on `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.get_preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Runs inference with the learner on all the data in the `ds_type` dataset and returns the predictions together with the targets. If `n_batch` is specified, only that many batches are processed; otherwise every batch is processed, each at the data's default batch size. If `with_loss` is `True`, the loss on each prediction is returned as well." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is how you check the default batch size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "64" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.data.batch_size" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[9.9925e-01, 7.4895e-04],\n", " [9.8333e-01, 1.6672e-02],\n", " [9.9996e-01, 3.8919e-05],\n", " ...,\n", " [1.6180e-04, 9.9984e-01],\n", " [2.5164e-02, 9.7484e-01],\n", " [1.8179e-02, 9.8182e-01]]), tensor([0, 0, 0, ..., 1, 1, 1])]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = learn.get_preds()\n", "preds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first element of the returned list is a tensor that contains all the predictions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[9.9925e-01, 7.4895e-04],\n", " [9.8333e-01, 1.6672e-02],\n", " [9.9996e-01, 3.8919e-05],\n", " ...,\n", " [1.6180e-04, 9.9984e-01],\n", " [2.5164e-02, 9.7484e-01],\n", " [1.8179e-02, 9.8182e-01]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The second element of the returned list is a tensor that contains all the target labels." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([0, 0, 0, ..., 1, 1, 1])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds[1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor(0)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds[1][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For more details about what each number means, refer to the documentation of [`predict`](/basic_train.html#predict).\n", "\n", "Since [`get_preds`](/basic_train.html#get_preds) gets predictions on all the data in the `ds_type` dataset, the number of predictions here equals the number of items in the validation dataset." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2038" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(learn.data.valid_ds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2038, 2038)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(preds[0]), len(preds[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To get predictions on the entire training dataset, simply set the `ds_type` argument accordingly." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[9.9801e-01, 1.9876e-03],\n", " [1.7900e-06, 1.0000e+00],\n", " [1.3191e-03, 9.9868e-01],\n", " ...,\n", " [9.9991e-01, 8.6866e-05],\n", " [1.6420e-04, 9.9984e-01],\n", " [2.2937e-03, 9.9771e-01]]), tensor([0, 1, 1, ..., 0, 1, 1])]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.get_preds(ds_type=DatasetType.Train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To also get prediction loss along with the predictions and the targets, set `with_loss=True` in the arguments." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[tensor([[9.9925e-01, 7.4895e-04],\n", " [9.8333e-01, 1.6672e-02],\n", " [9.9996e-01, 3.8919e-05],\n", " ...,\n", " [1.6180e-04, 9.9984e-01],\n", " [2.5164e-02, 9.7484e-01],\n", " [1.8179e-02, 9.8182e-01]]),\n", " tensor([0, 0, 0, ..., 1, 1, 1]),\n", " tensor([7.4911e-04, 1.6813e-02, 3.8624e-05, ..., 1.6165e-04, 2.5486e-02,\n", " 1.8347e-02])]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.get_preds(with_loss=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the third tensor in the output tuple contains the losses." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

validate[source][test]

\n", "\n", "> validate(**`dl`**=***`None`***, **`callbacks`**=***`None`***, **`metrics`**=***`None`***)\n", "\n", "
Tests found for validate:

  • pytest -sv tests/test_collab_train.py::test_val_loss [source]
  • pytest -sv tests/test_text_train.py::test_val_loss [source]

To run tests please refer to this guide.

\n", "\n", "Validate on `dl` with potential `callbacks` and `metrics`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.validate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Return the calculated loss and the metrics of the current model on the given data loader `dl`. The default data loader `dl` is the validation dataloader." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can check the default metrics of the learner using:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'[]'" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "str(learn.metrics)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.061868817, tensor(0.9799)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.validate()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.061868817, tensor(0.9799)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.validate(learn.data.valid_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[0.036164965, tensor(0.9887)]" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.validate(learn.data.train_dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

show_results[source][test]

\n", "\n", "> show_results(**`ds_type`**=***``***, **`rows`**:`int`=***`5`***, **\\*\\*`kwargs`**)\n", "\n", "
No tests found for show_results. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Show `rows` result of predictions on `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.show_results)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the number printed on top is the ground truth (the target label), the one in the middle is the prediction, and the digit shown at the bottom is the image data itself." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.show_results()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.show_results(ds_type=DatasetType.Train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

pred_batch[source][test]

\n", "\n", "> pred_batch(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`batch`**:`Tuple`=***`None`***, **`reconstruct`**:`bool`=***`False`***, **`with_dropout`**:`bool`=***`False`***) → `List`\\[`Tensor`\\]\n", "\n", "
No tests found for pred_batch. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Return output of the model on one batch from `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.pred_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the number of predictions returned equals the batch size." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "64" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.data.batch_size" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "64" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds = learn.pred_batch()\n", "len(preds)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since there are too many predictions to display in full, we will only look at the first ten." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[9.9925e-01, 7.4895e-04],\n", " [9.8333e-01, 1.6672e-02],\n", " [9.9996e-01, 3.8919e-05],\n", " [9.9998e-01, 1.7812e-05],\n", " [9.9993e-01, 6.8040e-05],\n", " [9.9533e-01, 4.6744e-03],\n", " [9.9838e-01, 1.6157e-03],\n", " [1.0000e+00, 1.4298e-06],\n", " [9.9942e-01, 5.8188e-04],\n", " [9.9999e-01, 1.2754e-05]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preds[:10]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/jpeg": 
"/9j/4AAQSkZJRgABAQEAZABkAAD/2wBDAAIBAQEBAQIBAQECAgICAgQDAgICAgUEBAMEBgUGBgYFBgYGBwkIBgcJBwYGCAsICQoKCgoKBggLDAsKDAkKCgr/2wBDAQICAgICAgUDAwUKBwYHCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgr/wAARCAAcABwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD+f+vtr4ff8G5v/BaH4o/D/Q/ih4L/AGH9Tn0bxHDHNpM114t0W1mkikyUkeCe9SaFGA3BpEUbSrZwwJ+Ja/fD/g13+Dtzq/w6l/b6/ae/4Kna9o+gaP4jTSNL+GUnxZezs3a32LGNVWeYfuzJND5VsuFdSu8usnl0AfjV+2R+wp+1l/wT9+J8Pwc/a/8Ag1feDPEFzp6X1nbXF5bXcN1buSBJFcWsssEoBBVtjkqwKtg8V5JX6a/8HYmq/tMt/wAFVb7w5+0L8RtB1nT7bwta3XgDTPDaSxwaPo000yxwTRykkXbNC0krbmDl1ZdqFY0/MqgAr9b/APgkv+xl8Hf+Cb/7Jeuf8Fwv+CjHw1g1KbRV8v4BfC3Xo2t7nV9XWVEj1FoZk5VZGTy5MOI1WSbaWWGvyQr2f9rD/goV+2T+3DpPhHw/+1H8cb7xTp3gPR10zwlpp0+0s7bT7dVVRiK0iiR5CqqpmcNIyqoLEAAAGX+2p+2F8Zf29f2mPFP7Vfx71G1uPEvim8WS4jsbcRW9rDGixQW0S8kRxRIkalizEICzMSSfLKKKAP/Z\n", "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAYAAAByDd+UAAAABHNCSVQICAgIfAhkiAAAAbBJREFUSInt1rGKwkAQBuBfuUJMRK0FQbYxoGBpLQimDAg+goVPIFrY29iYykoQfAEraxFDhBRW2lmJVqaV7Fxl0BPNrp42dwtTJOzOl1lmk4QAED44wp/E/sE/AubzebRaLXDOwTkHEWGxWMCyLP8e5xyNRkMYpUfRbrfJ87zAOB6PVKvVHuYCQF9BT2PbNvr9vn9dKpWgadrNPFVVwRh7vcKfwRijYrFI1Wr1qsL5fE6pVEokhxwIgDRNo9VqdQUWCgXR9eJQLBYj0zRpu9360G63o2w2S5FI5PfBwWBw0yyu61K32xXdTnGw0+nQ6XS626Wz2Uwoj/DBVxQF4fD96fF4XChP4LE4j2aziWQy6Sfu9XrI5XIwTRMAYFmWaCr5LgVAiqKQbdvv7dJzZDKZK8xxHEqn0+8BE4kELZdLH1uv18QYk8khB14eDcdxZDE50DAMOhwO5HkebTYbmW2UB3VdJ9d1/eoMw3iq2YRAVVVpNBr52GQyoWg0+h5Q13Uaj8c+NhwOqVKpPIsFg5cf4P1+T+Vy+RVM/NUGAPV6HdPpVGbJzQidy/zU+Phf2zcGKKJjocPzywAAAABJRU5ErkJggg==\n", "text/plain": [ "Image (3, 28, 28)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "item = learn.data.train_ds[0][0]\n", "item" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]],\n", " \n", " [[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]],\n", " \n", " [[0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.],\n", " [0., 0., 0., ..., 0., 0., 0.]]]], device='cuda:0'),\n", " tensor([0], device='cuda:0'))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "batch = learn.data.one_item(item)\n", "batch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": 
[ "tensor([[0.5748, 0.4252]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.pred_batch(batch=batch)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

interpret[source][test]

\n", "\n", "> interpret(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`tta`**=***`False`***)\n", "\n", "
×

Tests found for _learner_interpret:

  • pytest -sv tests/test_vision_train.py::test_interp_shortcut [source]

To run tests please refer to this guide.

\n", "\n", "Create a [`ClassificationInterpretation`](/train.html#ClassificationInterpretation) object from `learner` on `ds_type` with `tta`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.interpret, full_name='interpret')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
Note: This function only works in the vision application.
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "jekyll_note('This function only works in the vision application.')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For more details, refer to [ClassificationInterpretation](/vision.learner.html#ClassificationInterpretation)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model summary" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

model_summary[source][test]

\n", "\n", "> model_summary(**`m`**:[`Learner`](/basic_train.html#Learner), **`n`**:`int`=***`70`***)\n", "\n", "
×

Tests found for model_summary:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_collab [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_tabular [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_text [source]
  • pytest -sv tests/test_callbacks_hooks.py::test_model_summary_vision [source]

To run tests please refer to this guide.

\n", "\n", "Print a summary of `m` using a output text width of `n` chars " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test time augmentation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

TTA[source][test]

\n", "\n", "> TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`with_loss`**:`bool`=***`False`***) → `Tensors`\n", "\n", "
×

No tests found for _TTA. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Applies TTA to predict on `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.TTA, full_name = 'TTA')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Applies Test Time Augmentation to `learn` on the dataset `ds_type`. We take the average of our regular predictions (with a weight `beta`) with the average of predictions obtained through augmented versions of that dataset (with a weight `1-beta`). The transforms decided for the training set are applied with a few changes: `scale` controls the scale for zoom (which isn't random), and the cropping isn't random either, but we make sure to get the four corners of the image. Flipping isn't random but is applied once on each of those corner images (so that makes 8 augmented versions in total)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Gradient clipping" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

clip_grad[source][test]

\n", "\n", "> clip_grad(**`learn`**:[`Learner`](/basic_train.html#Learner), **`clip`**:`float`=***`0.1`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
×

No tests found for clip_grad. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Add gradient clipping of `clip` during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.clip_grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Mixed precision training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp16[source][test]

\n", "\n", "> to_fp16(**`learn`**:[`Learner`](/basic_train.html#Learner), **`loss_scale`**:`float`=***`None`***, **`max_noskip`**:`int`=***`1000`***, **`dynamic`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **`flat_master`**:`bool`=***`False`***, **`max_scale`**:`float`=***`16777216`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
×

No tests found for to_fp16. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put `learn` in FP16 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.to_fp16)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses the [`MixedPrecision`](/callbacks.fp16.html#MixedPrecision) callback to train in mixed precision (i.e. forward and backward passes using fp16, with weight updates using fp32), using all [NVIDIA recommendations](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html) for ensuring speed and accuracy." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_fp32[source][test]

\n", "\n", "> to_fp32(**`learn`**:[`Learner`](/basic_train.html#Learner))\n", "\n", "
×

No tests found for to_fp32. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put `learn` back to FP32 precision mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.to_fp32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Distributed training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you want to use distributed training or [`torch.nn.DataParallel`](https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel), these methods will directly wrap the model for you." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_distributed[source][test]

\n", "\n", "> to_distributed(**`learn`**:[`Learner`](/basic_train.html#Learner), **`cuda_id`**:`int`, **`cache_dir`**:`PathOrStr`=***`'tmp'`***)\n", "\n", "
×

No tests found for _learner_distributed. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Put `learn` on distributed training with `cuda_id`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.to_distributed, full_name='to_distributed')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

to_parallel[source][test]

\n", "\n", "> to_parallel(**`learn`**:[`Learner`](/basic_train.html#Learner))\n", "\n", "
×

No tests found for _learner_parallel. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Use nn.DataParallel when training and remove when done " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.to_parallel, full_name='to_parallel')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Discriminative layer training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When fitting a model you can pass a list of learning rates (and/or weight decay amounts), which will apply a different rate to each *layer group* (i.e. the parameters of each module in `self.layer_groups`). See the [Universal Language Model Fine-tuning for Text Classification](https://arxiv.org/abs/1801.06146) paper for details and experimental results in NLP (we also frequently use them successfully in computer vision, but have not published a paper on this topic yet). When working with a [`Learner`](/basic_train.html#Learner) on which you've called `split`, you can set hyperparameters in four ways:\n", "\n", "1. `param = [val1, val2 ..., valn]` (n = number of layer groups)\n", "2. `param = val`\n", "3. `param = slice(start,end)`\n", "4. `param = slice(end)`\n", "\n", "If you choose way 1, you must specify a number of values exactly equal to the number of layer groups. If you choose way 2, the chosen value will be repeated for all layer groups. 
See [`Learner.lr_range`](/basic_train.html#Learner.lr_range) for an explanation of the `slice` syntax.\n", "\n", "Here's an example of how to use discriminative learning rates (note that you don't actually need to manually call [`Learner.split`](/basic_train.html#Learner.split) in this case, since fastai uses this exact function as the default split for `resnet18`; this is just to show how to customize it):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# creates 3 layer groups\n", "learn.split(lambda m: (m[0][6], m[1]))\n", "# only randomly initialized head now trainable\n", "learn.freeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
10.0596130.0546040.98184500:05
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
epochtrain_lossvalid_lossaccuracytime
10.0263790.0087630.99803700:07
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# all layers now trainable\n", "learn.unfreeze()\n", "# optionally, separate LR and WD for each group\n", "learn.fit_one_cycle(1, max_lr=(1e-4, 1e-3, 1e-2), wd=(1e-4,1e-4,1e-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

lr_range[source][test]

\n", "\n", "> lr_range(**`lr`**:`Union`\\[`float`, `slice`\\]) → `ndarray`\n", "\n", "
×

No tests found for lr_range. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Build differential learning rates from `lr`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.lr_range)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rather than manually setting an LR for every group, it's often easier to use [`Learner.lr_range`](/basic_train.html#Learner.lr_range). This is a convenience method that returns one learning rate for each layer group. If you pass `slice(start,end)` then the first group's learning rate is `start`, the last is `end`, and the remaining are evenly geometrically spaced.\n", "\n", "If you pass just `slice(end)` then the last group's learning rate is `end`, and all the other groups are `end/10`. For instance (for our learner that has 3 layer groups):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(array([1.e-05, 1.e-04, 1.e-03]), array([0.0001, 0.0001, 0.001 ]))" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.lr_range(slice(1e-5,1e-3)), learn.lr_range(slice(1e-3))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

unfreeze[source][test]

\n", "\n", "> unfreeze()\n", "\n", "
×

Tests found for unfreeze:

  • pytest -sv tests/test_basic_train.py::test_unfreeze [source]

To run tests please refer to this guide.

\n", "\n", "Unfreeze entire model. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.unfreeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sets every layer group to *trainable* (i.e. `requires_grad=True`)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

freeze[source][test]

\n", "\n", "> freeze()\n", "\n", "
×

Tests found for freeze:

  • pytest -sv tests/test_basic_train.py::test_freeze [source]

To run tests please refer to this guide.

\n", "\n", "Freeze up to last layer group. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.freeze)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sets every layer group except the last to *untrainable* (i.e. `requires_grad=False`).\n", "\n", "What does '**the last layer group**' mean? \n", "\n", "In the case of transfer learning, such as `learn = cnn_learner(data, models.resnet18, metrics=error_rate)`, `learn.model` will print out two large groups of layers: (0) Sequential and (1) Sequential in the following structure. We can consider the last conv layer as the break line between the two groups.\n", "```\n", "Sequential(\n", " (0): Sequential(\n", " (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU(inplace)\n", " ...\n", " \n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " )\n", " (1): Sequential(\n", " (0): AdaptiveConcatPool2d(\n", " (ap): AdaptiveAvgPool2d(output_size=1)\n", " (mp): AdaptiveMaxPool2d(output_size=1)\n", " )\n", " (1): Flatten()\n", " (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (3): Dropout(p=0.25)\n", " (4): Linear(in_features=1024, out_features=512, bias=True)\n", " (5): ReLU(inplace)\n", " (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (7): Dropout(p=0.5)\n", " (8): Linear(in_features=512, out_features=12, bias=True)\n", " )\n", ")\n", "```\n", "\n", "`learn.freeze` freezes the first group and keeps the second (last) group free to train, including the multiple layers inside it (which is why it is called a 'group'), as you can see in the `learn.summary()` output. 
How to read the table below, please see [model summary docs](/callbacks.hooks.html#model_summary).\n", "\n", "```\n", "======================================================================\n", "Layer (type) Output Shape Param # Trainable \n", "======================================================================\n", "...\n", "...\n", "...\n", "______________________________________________________________________\n", "Conv2d [1, 512, 4, 4] 2,359,296 False \n", "______________________________________________________________________\n", "BatchNorm2d [1, 512, 4, 4] 1,024 True \n", "______________________________________________________________________\n", "AdaptiveAvgPool2d [1, 512, 1, 1] 0 False \n", "______________________________________________________________________\n", "AdaptiveMaxPool2d [1, 512, 1, 1] 0 False \n", "______________________________________________________________________\n", "Flatten [1, 1024] 0 False \n", "______________________________________________________________________\n", "BatchNorm1d [1, 1024] 2,048 True \n", "______________________________________________________________________\n", "Dropout [1, 1024] 0 False \n", "______________________________________________________________________\n", "Linear [1, 512] 524,800 True \n", "______________________________________________________________________\n", "ReLU [1, 512] 0 False \n", "______________________________________________________________________\n", "BatchNorm1d [1, 512] 1,024 True \n", "______________________________________________________________________\n", "Dropout [1, 512] 0 False \n", "______________________________________________________________________\n", "Linear [1, 12] 6,156 True \n", "______________________________________________________________________\n", "\n", "Total params: 11,710,540\n", "Total trainable params: 543,628\n", "Total non-trainable params: 11,166,912\n", "```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, 
"outputs": [ { "data": { "text/markdown": [ "

freeze_to[source][test]

\n", "\n", "> freeze_to(**`n`**:`int`)\n", "\n", "
×

Tests found for freeze_to:

  • pytest -sv tests/test_basic_train.py::test_freeze_to [source]

To run tests please refer to this guide.

\n", "\n", "Freeze layers up to layer group `n`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.freeze_to)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From above we know what a layer group is, but **what exactly does `freeze_to` do behind the scenes**?\n", "\n", "The `freeze_to` source code can be understood as the following pseudo-code:\n", "```python\n", "def freeze_to(self, n:int)->None:\n", " for g in self.layer_groups[:n]: freeze \n", " for g in self.layer_groups[n:]: unfreeze\n", "``` \n", "In other words, `freeze_to(1)` freezes layer group 0 and unfreezes the remaining layer groups, and `freeze_to(3)` freezes layer groups 0, 1, and 2 and unfreezes the remaining layer groups (if there are any left). \n", "\n", "Both `freeze` and `unfreeze` [sources](https://github.com/fastai/fastai/blob/master/fastai/basic_train.py#L216) are defined using `freeze_to`: \n", "- When we say `freeze`, we mean that in the specified layer groups the [`requires_grad`](/torch_core.html#requires_grad) of all layers with weights (except BatchNorm layers) is set to `False`, so the layer weights won't be updated during training. \n", "- When we say `unfreeze`, we mean that in the specified layer groups the [`requires_grad`](/torch_core.html#requires_grad) of all layers with weights (except BatchNorm layers) is set to `True`, so the layer weights will be updated during training." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

split[source][test]

\n", "\n", "> split(**`split_on`**:`SplitFuncOrIdxList`)\n", "\n", "
×

No tests found for split. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Split the model at `split_on`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.split)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A convenience method that sets `layer_groups` based on the result of [`split_model`](/torch_core.html#split_model). If `split_on` is a function, it calls that function and passes the result to [`split_model`](/torch_core.html#split_model) (see above for example)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Saving and loading models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Simply call [`Learner.save`](/basic_train.html#Learner.save) and [`Learner.load`](/basic_train.html#Learner.load) to save and load models. Only the parameters are saved, not the actual architecture (so you'll need to create your model in the same way before loading weights back in). Models are saved to the `path`/`model_dir` directory." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

save[source][test]

\n", "\n", "> save(**`file`**:`PathLikeOrBinaryStream`=***`None`***, **`return_path`**:`bool`=***`False`***, **`with_opt`**:`bool`=***`True`***)\n", "\n", "
×

Tests found for save:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

Some other tests where save is used:

  • pytest -sv tests/test_basic_train.py::test_model_load_mem_leak [source]

To run tests please refer to this guide.

\n", "\n", "Save model and optimizer state (if `with_opt`) with `file` to `self.model_dir`. `file` can be file-like (file or buffer) " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.save)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the `file` argument is a pathlib object that's an absolute path, it overrides the default base directory (`learn.path`); otherwise the model is saved in a file relative to `learn.path`/`learn.model_dir`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.save(\"trained_model\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/ubuntu/.fastai/data/mnist_sample/models/trained_model.pth')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.save(\"trained_model\", return_path=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

load[source][test]

\n", "\n", "> load(**`file`**:`PathLikeOrBinaryStream`=***`None`***, **`device`**:[`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch-device)=***`None`***, **`strict`**:`bool`=***`True`***, **`with_opt`**:`bool`=***`None`***, **`purge`**:`bool`=***`False`***, **`remove_module`**:`bool`=***`False`***)\n", "\n", "
×

Tests found for load:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

Some other tests where load is used:

  • pytest -sv tests/test_basic_train.py::test_model_load_mem_leak [source]

To run tests please refer to this guide.

\n", "\n", "Load model and optimizer state (if `with_opt`) `file` from `self.model_dir` using `device`. `file` can be file-like (file or buffer) " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.load)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This method only works after `save` (don't confuse it with the `export`/[`load_learner`](/basic_train.html#load_learner) pair).\n", "\n", "If the `purge` argument is `True` (it defaults to `False` in the signature above), `load` internally calls `purge` with `clear_opt=False` to preserve `learn.opt`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = learn.load(\"trained_model\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Deploying your model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When you are ready to put your model in production, export the minimal state of your [`Learner`](/basic_train.html#Learner) with:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

export[source][test]

\n", "\n", "> export(**`file`**:`PathLikeOrBinaryStream`=***`'export.pkl'`***, **`destroy`**=***`False`***)\n", "\n", "
×

Tests found for export:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]

To run tests please refer to this guide.

\n", "\n", "Export the state of the [`Learner`](/basic_train.html#Learner) in `self.path/file`. `file` can be file-like (file or buffer) " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.export)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the `file` argument is a pathlib object that's an absolute path, it overrides the default base directory (`learn.path`); otherwise the model is saved in a file relative to `learn.path`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Passing `destroy=True` will destroy the [`Learner`](/basic_train.html#Learner), freeing most of its memory consumption. For specifics see [`Learner.destroy`](/basic_train.html#Learner.destroy).\n", "\n", "This method only works with a [`Learner`](/basic_train.html#Learner) whose [`data`](/vision.data.html#vision.data) was created through the [data block API](/data_block.html).\n", "\n", "Otherwise, you will have to create a [`Learner`](/basic_train.html#Learner) yourself at inference and load the model with [`Learner.load`](/basic_train.html#Learner.load)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.export()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.export('trained_model.pkl')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "PosixPath('/home/ubuntu/.fastai/data/mnist_sample')" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = learn.path\n", "path" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

load_learner[source][test]

\n", "\n", "> load_learner(**`path`**:`PathOrStr`, **`file`**:`PathLikeOrBinaryStream`=***`'export.pkl'`***, **`test`**:[`ItemList`](/data_block.html#ItemList)=***`None`***, **\\*\\*`db_kwargs`**)\n", "\n", "
×

Tests found for load_learner:

  • pytest -sv tests/test_basic_train.py::test_export_load_learner [source]

To run tests please refer to this guide.

\n", "\n", "Load a [`Learner`](/basic_train.html#Learner) object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer) " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(load_learner)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function only works after `export` (don't confuse it with the `save`/`load` pair).\n", "\n", "The `db_kwargs` will be passed to the call to `databunch` so you can specify a `bs` for the test set, or `num_workers`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = load_learner(path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = load_learner(path, 'trained_model.pkl')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "WARNING: If you used any customized classes when creating your learner, you must define these classes before executing [`load_learner`](/basic_train.html#load_learner).\n", "\n", "You can find more information and multiple examples in [this tutorial](/tutorial.inference.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Freeing memory\n", "\n", "If you want to be able to do more without needing to restart your notebook, the following methods are designed to free memory when it's no longer needed. \n", "\n", "Refer to [this tutorial](/tutorial.resources.html) to learn how and when to use these methods." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

purge[source][test]

\n", "\n", "> purge(**`clear_opt`**:`bool`=***`True`***)\n", "\n", "
×

Tests found for purge:

  • pytest -sv tests/test_basic_train.py::test_memory [source]
  • pytest -sv tests/test_basic_train.py::test_purge [source]
  • pytest -sv tests/test_basic_train.py::test_save_load [source]

To run tests please refer to this guide.

\n", "\n", "Purge the [`Learner`](/basic_train.html#Learner) of all cached attributes to release some GPU memory. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.purge)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If `learn.path` is read-only, you can set the `model_dir` attribute of the Learner to a full `pathlib` path that is writable (by setting `learn.model_dir` or passing the `model_dir` argument to the [`Learner`](/basic_train.html#Learner) constructor)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

destroy[source][test]

\n", "\n", "> destroy()\n", "\n", "
×

Tests found for destroy:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]
  • pytest -sv tests/test_basic_train.py::test_memory [source]

To run tests please refer to this guide.

\n", "\n", "Free the Learner internals, leaving just an empty shell that consumes no memory " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.destroy)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you need to free the memory consumed by the [`Learner`](/basic_train.html#Learner) object, call this method. \n", "\n", "It can also be automatically invoked through [`Learner.export`](/basic_train.html#Learner.export) via its `destroy=True` argument." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Other methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

init[source][test]

\n", "\n", "> init(**`init`**)\n", "\n", "
×

No tests found for init. To contribute a test please refer to this guide and this discussion.

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.init)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initializes all weights (except batchnorm) using function `init`, which will often be from PyTorch's [`nn.init`](https://pytorch.org/docs/stable/nn.html#torch-nn-init) module." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

mixup[source][test]

\n", "\n", "> mixup(**`learn`**:[`Learner`](/basic_train.html#Learner), **`alpha`**:`float`=***`0.4`***, **`stack_x`**:`bool`=***`False`***, **`stack_y`**:`bool`=***`True`***) → [`Learner`](/basic_train.html#Learner)\n", "\n", "
×

No tests found for mixup. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Add mixup https://arxiv.org/abs/1710.09412 to `learn`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.mixup)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Uses [`MixUpCallback`](/callbacks.mixup.html#MixUpCallback)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

backward[source][test]

\n", "\n", "> backward(**`item`**)\n", "\n", "
×

No tests found for backward. To contribute a test please refer to this guide and this discussion.

\n", "\n", "Pass `item` through the model and computes the gradient. Useful if `backward_hooks` are attached. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.backward)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

create_opt[source][test]

\n", "\n", "> create_opt(**`lr`**:`Floats`, **`wd`**:`Floats`=***`0.0`***)\n", "\n", "

\n", "\n", "Create optimizer with `lr` learning rate and `wd` weight decay. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.create_opt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You generally won't need to call this yourself - it's used to create the [`optim`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) optimizer before fitting the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

dl[source][test]

\n", "\n", "> dl(**`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***)\n", "\n", "

\n", "\n", "Return DataLoader for DatasetType `ds_type`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.dl)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceDataLoader(dl=, device=device(type='cuda'), tfms=[], collate_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.dl()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceDataLoader(dl=, device=device(type='cuda'), tfms=[], collate_fn=)" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "learn.dl(DatasetType.Train)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class Recorder[source][test]

\n", "\n", "> Recorder(**`learn`**:[`Learner`](/basic_train.html#Learner), **`add_time`**:`bool`=***`True`***, **`silent`**:`bool`=***`False`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n", "\n", "
Tests found for Recorder:

  • pytest -sv tests/test_vision_train.py::test_1cycle_lrs [source]
  • pytest -sv tests/test_vision_train.py::test_1cycle_moms [source]

To run tests please refer to this guide.

\n", "\n", "A [`LearnerCallback`](/basic_train.html#LearnerCallback) that records epoch, loss, opt and metric data during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder, title_level=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A [`Learner`](/basic_train.html#Learner) creates a [`Recorder`](/basic_train.html#Recorder) object automatically - you do not need to explicitly pass it to `callback_fns` - because other callbacks rely on it being available. It stores the smoothed loss and hyperparameter values for each batch and the validation loss and metrics for each epoch, and provides plotting methods for each. Note that [`Learner`](/basic_train.html#Learner) automatically sets an attribute with the snake-cased name of each callback, so you can access this through `Learner.recorder`, as shown below." ] }, { "cell_type": "markdown", "metadata": { "hide_input": true }, "source": [ "### Plotting methods" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot[source][test]

\n", "\n", "> plot(**`skip_start`**:`int`=***`10`***, **`skip_end`**:`int`=***`5`***, **`suggestion`**:`bool`=***`False`***, **`return_fig`**:`bool`=***`None`***, **\\*\\*`kwargs`**) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Plot learning rate and losses, trimmed between `skip_start` and `skip_end`. Optionally plot and return min gradient " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is mainly used with the learning rate finder, since it shows a scatterplot of loss vs learning rate." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.MNIST_SAMPLE)\n", "data = ImageDataBunch.from_folder(path)\n", "learn = cnn_learner(data, models.resnet18, metrics=accuracy)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.lr_find()\n", "learn.recorder.plot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_losses[source][test]

\n", "\n", "> plot_losses(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Plot training and validation losses. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_losses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that validation losses are only calculated once per epoch, whereas training losses are calculated after every batch." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
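<table>
<thead><tr><th>epoch</th><th>train_loss</th><th>valid_loss</th><th>accuracy</th><th>time</th></tr></thead>
<tbody>
<tr><td>1</td><td>0.228524</td><td>0.122285</td><td>0.958783</td><td>00:05</td></tr>
<tr><td>2</td><td>0.118838</td><td>0.075222</td><td>0.971050</td><td>00:05</td></tr>
<tr><td>3</td><td>0.066715</td><td>0.054920</td><td>0.981354</td><td>00:05</td></tr>
<tr><td>4</td><td>0.048155</td><td>0.048612</td><td>0.983317</td><td>00:05</td></tr>
<tr><td>5</td><td>0.037535</td><td>0.046014</td><td>0.982336</td><td>00:05</td></tr>
</tbody>
</table>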
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.fit_one_cycle(5)\n", "learn.recorder.plot_losses()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_lr[source][test]

\n", "\n", "> plot_lr(**`show_moms`**=***`False`***, **`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Plot learning rate, `show_moms` to include momentum. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_lr)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_lr(show_moms=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

plot_metrics[source][test]

\n", "\n", "> plot_metrics(**`skip_start`**:`int`=***`0`***, **`skip_end`**:`int`=***`0`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n", "\n", "

\n", "\n", "Plot metrics collected during training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.plot_metrics)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that metrics are only collected at the end of each epoch, so you'll need to train at least two epochs to have anything to show here." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "learn.recorder.plot_metrics()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Callback methods" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You don't call these yourself - they're called by fastai's [`Callback`](/callback.html#Callback) system automatically to enable the class's functionality. Refer to [`Callback`](/callback.html#Callback) for more details." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_backward_begin[source][test]

\n", "\n", "> on_backward_begin(**`smooth_loss`**:`Tensor`, **\\*\\*`kwargs`**:`Any`)\n", "\n", "

\n", "\n", "Record the loss before any other callback has a chance to modify it. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_backward_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_begin[source][test]

\n", "\n", "> on_batch_begin(**`train`**, **\\*\\*`kwargs`**:`Any`)\n", "\n", "

\n", "\n", "Record learning rate and momentum at beginning of batch. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_batch_begin)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_epoch_end[source][test]

\n", "\n", "> on_epoch_end(**`epoch`**:`int`, **`num_batch`**:`int`, **`smooth_loss`**:`Tensor`, **`last_metrics`**=***`typing.Collection[typing.Union[torch.Tensor, numbers.Number]]`***, **\\*\\*`kwargs`**:`Any`) → `bool`\n", "\n", "

\n", "\n", "Save epoch info: num_batch, smooth_loss, metrics. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_epoch_end)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_train_begin[source][test]

\n", "\n", "> on_train_begin(**`pbar`**:`PBar`, **`metrics_names`**:`StrList`, **\\*\\*`kwargs`**:`Any`)\n", "\n", "

\n", "\n", "Initialize recording status at beginning of training. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.on_train_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Inner functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following functions are used along the way by the [`Recorder`](/basic_train.html#Recorder) or can be called by other callbacks." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

add_metric_names[source][test]

\n", "\n", "> add_metric_names(**`names`**)\n", "\n", "

\n", "\n", "Add `names` to the inner metric names. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.add_metric_names)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

format_stats[source][test]

\n", "\n", "> format_stats(**`stats`**:`MetricsList`)\n", "\n", "

\n", "\n", "Format stats before printing. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Recorder.format_stats)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Module functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generally you'll want to use a [`Learner`](/basic_train.html#Learner) to train your model, since they provide a lot of functionality and make things easier. However, for ultimate flexibility, you can call the same underlying functions that [`Learner`](/basic_train.html#Learner) calls behind the scenes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

fit[source][test]

\n", "\n", "> fit(**`epochs`**:`int`, **`learn`**:[`BasicLearner`](/basic_train.html#BasicLearner), **`callbacks`**:`Optional`\\[`Collection`\\[[`Callback`](/callback.html#Callback)\\]\\]=***`None`***, **`metrics`**:`OptMetrics`=***`None`***)\n", "\n", "
Tests found for fit:

Some other tests where fit is used:

  • pytest -sv tests/test_basic_train.py::test_destroy [source]

To run tests please refer to this guide.

\n", "\n", "Fit the `model` on `data` and learn using `loss_func` and `opt`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(fit)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that you have to create the [`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer) yourself if you call this function, whereas [`Learner.fit`](/basic_train.html#fit) creates it for you automatically." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

train_epoch[source][test]

\n", "\n", "> train_epoch(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`opt`**:[`Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), **`loss_func`**:`LossFunction`)\n", "\n", "

\n", "\n", "Simple training of `model` for 1 epoch of `dl` using optim `opt` and loss function `loss_func`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(train_epoch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You won't generally need to call this yourself - it's what [`fit`](/basic_train.html#fit) calls for each epoch." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

validate[source][test]

\n", "\n", "> validate(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`loss_func`**:`OptLossFunc`=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***, **`pbar`**:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=***`None`***, **`average`**=***`True`***, **`n_batch`**:`Optional`\\[`int`\\]=***`None`***) → `Iterator`\\[`Tuple`\\[`IntOrTensor`, `Ellipsis`\\]\\]\n", "\n", "
Tests found for validate:

  • pytest -sv tests/test_tabular_train.py::test_accuracy [source]

To run tests please refer to this guide.

\n", "\n", "Calculate `loss_func` of `model` on `dl` in evaluation mode. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(validate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is what [`fit`](/basic_train.html#fit) calls after each epoch. You can call it if you want to run inference on a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) manually." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

get_preds[source][test]

\n", "\n", "> get_preds(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`dl`**:[`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader), **`pbar`**:`Union`\\[`MasterBar`, `ProgressBar`, `NoneType`\\]=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`loss_func`**:`OptLossFunc`=***`None`***, **`n_batch`**:`Optional`\\[`int`\\]=***`None`***) → `List`\\[`Tensor`\\]\n", "\n", "
Tests found for get_preds:

Some other tests where get_preds is used:

  • pytest -sv tests/test_basic_train.py::test_get_preds [source]

To run tests please refer to this guide.

\n", "\n", "Tuple of predictions and targets, and optional losses (if `loss_func`) using `dl`, max batches `n_batch`. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(get_preds)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

loss_batch[source][test]

\n", "\n", "> loss_batch(**`model`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`xb`**:`Tensor`, **`yb`**:`Tensor`, **`loss_func`**:`OptLossFunc`=***`None`***, **`opt`**:`OptOptimizer`=***`None`***, **`cb_handler`**:`Optional`\\[[`CallbackHandler`](/callback.html#CallbackHandler)\\]=***`None`***) → `Tuple`\\[`Union`\\[`Tensor`, `int`, `float`, `str`\\]\\]\n", "\n", "

\n", "\n", "Calculate loss and metrics for a batch, call out to callbacks as necessary. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(loss_batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You won't generally need to call this yourself - it's what [`fit`](/basic_train.html#fit) and [`validate`](/basic_train.html#validate) call for each batch. It only does a backward pass if you set `opt`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Other classes" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class LearnerCallback[source][test]

\n", "\n", "> LearnerCallback(**`learn`**) :: [`Callback`](/callback.html#Callback)\n", "\n", "

\n", "\n", "Base class for creating callbacks for a [`Learner`](/basic_train.html#Learner). " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(LearnerCallback, title_level=3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

class RecordOnCPU[source][test]

\n", "\n", "> RecordOnCPU() :: [`Callback`](/callback.html#Callback)\n", "\n", "

\n", "\n", "Store the `input` and `target` going through the model on the CPU. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(RecordOnCPU, title_level=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Undocumented Methods - Methods moved below this line will intentionally be hidden" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_tta_only[source][test]

\n", "\n", "> _tta_only(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`scale`**:`float`=***`1.35`***) → `Iterator`\\[`List`\\[`Tensor`\\]\\]\n", "\n", "

\n", "\n", "Computes the outputs for several augmented inputs for TTA " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.tta_only)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

_TTA[source][test]

\n", "\n", "> _TTA(**`learn`**:[`Learner`](/basic_train.html#Learner), **`beta`**:`float`=***`0.4`***, **`scale`**:`float`=***`1.35`***, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***``***, **`activ`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`with_loss`**:`bool`=***`False`***) → `Tensors`\n", "\n", "

\n", "\n", "Applies TTA to predict on `ds_type` dataset. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(Learner.TTA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "

on_batch_begin[source][test]

\n", "\n", "> on_batch_begin(**`last_input`**, **`last_target`**, **\\*\\*`kwargs`**)\n", "\n", "

\n", "\n", "Set HP before the output and loss are computed. " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "show_doc(RecordOnCPU.on_batch_begin)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## New Methods - Please document or move to the undocumented section" ] } ], "metadata": { "jekyll": { "keywords": "fastai", "summary": "Learner class and training loop", "title": "basic_train" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }