{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Computer Vision Learner" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[`vision.learner`](/vision.learner.html#vision.learner) is the module that defines the [`cnn_learner`](/vision.learner.html#cnn_learner) method, which makes it easy to get a model suitable for transfer learning." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [], "source": [ "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import *\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Transfer learning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transfer learning is a technique where you use a model trained on a very large dataset (usually [ImageNet](http://image-net.org/) in computer vision) and then adapt it to your own dataset. The idea is that the model has learned to recognize many features from all of this data, and that you will benefit from this knowledge compared to starting from a randomly initialized model, especially if your dataset is small. [This article](https://arxiv.org/abs/1805.08974) shows on a wide range of tasks that transfer learning nearly always gives better results.\n", "\n", "In practice, you need to change the last part of your model to adapt it to your own number of classes. Most convolutional models end with a few linear layers (a part we will call the head). The last convolutional layer will have analyzed features in the image that went through the model, and the job of the head is to convert those into predictions for each of our classes.
In transfer learning we will keep all the convolutional layers (called the body or the backbone of the model) with their weights pretrained on ImageNet but will define a new head initialized randomly.\n", "\n", "Then we will train the model we obtain in two phases: first we freeze the body weights and only train the head (to convert those analyzed features into predictions for our own data), then we unfreeze the layers of the backbone (gradually if necessary) and fine-tune the whole model (possibly using differential learning rates).\n", "\n", "The [`cnn_learner`](/vision.learner.html#cnn_learner) factory method helps you to automatically get a pretrained model from a given architecture with a custom head that is suitable for your data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
cnn_learner
(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`base_arch`**:`Callable`, **`cut`**:`Union`\\[`int`, `Callable`\\]=***`None`***, **`pretrained`**:`bool`=***`True`***, **`lin_ftrs`**:`Optional`\\[`Collection`\\[`int`\\]\\]=***`None`***, **`ps`**:`Floats`=***`0.5`***, **`custom_head`**:`Optional`\\[[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)\\]=***`None`***, **`split_on`**:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=***`None`***, **`bn_final`**:`bool`=***`False`***, **`init`**=***`'kaiming_normal_'`***, **`concat_pool`**:`bool`=***`True`***, **\\*\\*`kwargs`**:`Any`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
| epoch | train_loss | valid_loss | accuracy | time |
|---|---|---|---|---|
| 0 | 0.132899 | 0.069354 | 0.978901 | 00:06 |
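The second training phase described in the transfer learning section fine-tunes the whole model, possibly with differential learning rates: earlier layer groups get smaller rates than the head. As a minimal sketch of that idea, assuming rates are spaced geometrically across layer groups, the helper below computes one rate per group. The function name and the three group names are hypothetical and only serve the illustration; they are not part of the library.

```python
def spread_learning_rates(lowest, highest, n_groups):
    """Return n_groups learning rates spaced geometrically from lowest to highest."""
    if n_groups < 1:
        raise ValueError("n_groups must be at least 1")
    if n_groups == 1:
        # A single group simply trains at the highest rate.
        return [highest]
    # Constant multiplicative step between consecutive groups.
    ratio = (highest / lowest) ** (1 / (n_groups - 1))
    return [lowest * ratio ** i for i in range(n_groups)]


# Hypothetical split: early backbone layers, late backbone layers, new head.
group_names = ["backbone_early", "backbone_late", "head"]
rates = spread_learning_rates(1e-5, 1e-3, len(group_names))
for name, lr in zip(group_names, rates):
    print(f"{name}: {lr:.1e}")
```

The early layers hold generic features that need little adjustment, so they receive the smallest rate, while the randomly initialized head receives the largest. In fastai such a range is usually written as `slice(1e-5, 1e-3)` and passed to `fit_one_cycle` after calling `learn.unfreeze()`.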
unet_learner
(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`arch`**:`Callable`, **`pretrained`**:`bool`=***`True`***, **`blur_final`**:`bool`=***`True`***, **`norm_type`**:`Optional`\\[[`NormType`](/layers.html#NormType)\\]=***`'NormType'`***, **`split_on`**:`Union`\\[`Callable`, `Collection`\\[`ModuleList`\\], `NoneType`\\]=***`None`***, **`blur`**:`bool`=***`False`***, **`self_attention`**:`bool`=***`False`***, **`y_range`**:`OptRange`=***`None`***, **`last_cross`**:`bool`=***`True`***, **`bottle`**:`bool`=***`False`***, **`cut`**:`Union`\\[`int`, `Callable`\\]=***`None`***, **\\*\\*`learn_kwargs`**:`Any`) → [`Learner`](/basic_train.html#Learner)\n",
"\n",
predict
(**`item`**:[`ItemBase`](/core.html#ItemBase), **`return_x`**:`bool`=***`False`***, **`batch_first`**:`bool`=***`True`***, **`with_dropout`**:`bool`=***`False`***, **\\*\\*`kwargs`**)\n",
"\n",
"create_body
(**`arch`**:`Callable`, **`pretrained`**:`bool`=***`True`***, **`cut`**:`Union`\\[`int`, `Callable`, `NoneType`\\]=***`None`***)\n",
"\n",
"create_head
(**`nf`**:`int`, **`nc`**:`int`, **`lin_ftrs`**:`Optional`\\[`Collection`\\[`int`\\]\\]=***`None`***, **`ps`**:`Floats`=***`0.5`***, **`concat_pool`**:`bool`=***`True`***, **`bn_final`**:`bool`=***`False`***)\n",
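To illustrate how the `nf`, `nc` and `lin_ftrs` arguments in the signature above could determine the shape of the head, here is a minimal sketch that lists the input and output size of each linear layer. It assumes that `lin_ftrs` gives the hidden layer widths and that a single hidden layer of 512 units is used when `lin_ftrs` is `None`; that default is an assumption of this sketch and is not stated by the signature. The helper name `head_layer_sizes` is hypothetical.

```python
def head_layer_sizes(nf, nc, lin_ftrs=None):
    """List (in_features, out_features) for each linear layer of a head.

    nf: number of features coming out of the body (after pooling).
    nc: number of classes to predict.
    lin_ftrs: widths of the hidden linear layers.
    """
    # Assumed default for this sketch: one hidden layer of 512 units.
    hidden = [512] if lin_ftrs is None else list(lin_ftrs)
    sizes = [nf] + hidden + [nc]
    # Pair each size with the next one to get per-layer shapes.
    return list(zip(sizes[:-1], sizes[1:]))


default_head = head_layer_sizes(nf=1024, nc=10)
custom_head = head_layer_sizes(nf=1024, nc=10, lin_ftrs=[256, 64])
print(default_head)
print(custom_head)
```

Whatever the hidden widths are, the last layer always maps to `nc` outputs, which is what lets a pretrained body be reused for a dataset with a different number of classes.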
"\n",
"class
ClassificationInterpretation
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`preds`**:`Tensor`, **`y_true`**:`Tensor`, **`losses`**:`Tensor`, **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
Tests found for ClassificationInterpretation:
pytest -sv tests/test_vision_train.py::test_ClassificationInterpretation
Some other tests where ClassificationInterpretation is used:
pytest -sv tests/test_tabular_train.py::test_confusion_tabular
pytest -sv tests/test_vision_train.py::test_interp
from_learner
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
interpret
(**`learn`**:[`Learner`](/basic_train.html#Learner), **`ds_type`**:[`DatasetType`](/basic_data.html#DatasetType)=***`
plot_top_losses
(**`k`**, **`largest`**=***`True`***, **`figsize`**=***`(12, 12)`***, **`heatmap`**:`bool`=***`None`***, **`heatmap_thresh`**:`int`=***`16`***, **`return_fig`**:`bool`=***`None`***) → `Optional`\\[`Figure`\\]\n",
"\n",