{ "cells": [
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export \n",
  "from fastai2.basics import *\n",
  "from fastai2.vision.core import *\n",
  "from fastai2.vision.data import *\n",
  "from fastai2.vision.augment import *\n",
  "from fastai2.vision import models"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#default_exp vision.learner" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "from nbdev.showdoc import *" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "# Learner for the vision applications\n",
  "\n",
  "> All the functions necessary to build a `Learner` suitable for transfer learning in computer vision"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [
  "The most important functions of this module are `cnn_learner` and `unet_learner`. They will help you define a `Learner` using a pretrained model. See the [vision tutorial](http://dev.fast.ai/tutorial.vision) for examples of use."
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Cut a pretrained model" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# export\n",
  "def _is_pool_type(l):\n",
  "    \"Return the regex match if the class name of `l` ends with `Pool1d`, `Pool2d` or `Pool3d`, `None` otherwise\"\n",
  "    return re.search(r'Pool[123]d$', type(l).__name__)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
  "test_eq([bool(_is_pool_type(c)) for c in m.children()], [True,False,False,True])"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "By default, the fastai library cuts a pretrained model at the pooling layer. The function below helps detect it. 
" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# export\n",
  "def has_pool_type(m):\n",
  "    \"Return `True` if `m` is a pooling layer or has one in its children\"\n",
  "    if _is_pool_type(m): return True\n",
  "    for l in m.children():\n",
  "        if has_pool_type(l): return True\n",
  "    return False"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "m = nn.Sequential(nn.AdaptiveAvgPool2d(5), nn.Linear(2,3), nn.Conv2d(2,3,1), nn.MaxPool3d(5))\n",
  "assert has_pool_type(m)\n",
  "test_eq([has_pool_type(m_) for m_ in m.children()], [True,False,False,True])"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "# export\n",
  "def _get_first_layer(m):\n",
  "    \"Return `(layer, parent, name)` for the module that owns the first parameter of `m`\"\n",
  "    c,p,n = m,None,None  # child, parent, name\n",
  "    # walk the dotted path of the first parameter (e.g. `conv1.weight`), dropping the final `weight`/`bias` part\n",
  "    for n in next(m.named_parameters())[0].split('.')[:-1]:\n",
  "        p,c = c,getattr(c,n)\n",
  "    return c,p,n"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def _load_pretrained_weights(new_layer, previous_layer):\n",
  "    \"Load pretrained weights based on number of input channels\"\n",
  "    n_in = new_layer.in_channels\n",
  "    if n_in==1:\n",
  "        # collapse the pretrained filters into a single input channel by summing over dim 1\n",
  "        new_layer.weight.data = previous_layer.weight.data.sum(dim=1, keepdim=True)\n",
  "    elif n_in==2:\n",
  "        # keep the first 2 input channels scaled by 1.5 (=3/2), presumably to keep activations on a similar scale\n",
  "        new_layer.weight.data = previous_layer.weight.data[:,:2] * 1.5\n",
  "    else:\n",
  "        # keep the pretrained weights of the first 3 channels and zero the extra ones\n",
  "        new_layer.weight.data[:,:3] = previous_layer.weight.data\n",
  "        new_layer.weight.data[:,3:].zero_()\n",
  "    # the bias only depends on the output channels, so the pretrained one is reused as is\n",
  "    if previous_layer.bias is not None and new_layer.bias is not None:\n",
  "        new_layer.bias.data.copy_(previous_layer.bias.data)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def _update_first_layer(model, n_in, pretrained):\n",
  "    \"Replace (in place) the first `Conv2d` of `model` by one taking `n_in` channels; no-op when `n_in==3`\"\n",
  "    if n_in == 3: return\n",
  "    first_layer, parent, name = _get_first_layer(model)\n",
  "    assert isinstance(first_layer, nn.Conv2d), f'Change of input channels only supported with Conv2d, found {first_layer.__class__.__name__}'\n",
  "    assert first_layer.in_channels == 3, f'Unexpected number of input channels, found {first_layer.in_channels} while expecting 3'\n",
  "    # rebuild the conv with the same hyper-parameters, only `in_channels` changes\n",
  "    params = {attr:getattr(first_layer, attr) for attr in 'out_channels kernel_size stride padding dilation groups padding_mode'.split()}\n",
  "    params['bias'] = first_layer.bias is not None\n",
  "    params['in_channels'] = n_in\n",
  "    new_layer = nn.Conv2d(**params)\n",
  "    if pretrained: _load_pretrained_weights(new_layer, first_layer)\n",
  "    setattr(parent, name, new_layer)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def create_body(arch, n_in=3, pretrained=True, cut=None):\n",
  "    \"Cut off the body of a typically pretrained `arch` as determined by `cut`\"\n",
  "    model = arch(pretrained=pretrained)\n",
  "    _update_first_layer(model, n_in, pretrained)\n",
  "    if cut is None:\n",
  "        # default: cut right before the last child module that contains a pooling layer\n",
  "        ll = list(enumerate(model.children()))\n",
  "        cut = next((i for i,o in reversed(ll) if has_pool_type(o)), None)\n",
  "        if cut is None: raise ValueError(\"No pooling layer found among the model's children; pass `cut` explicitly\")\n",
  "    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])\n",
  "    elif callable(cut): return cut(model)\n",
  "    else: raise TypeError(\"cut must be either integer or a function\")"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "no_pool = lambda pretrained: nn.Sequential(nn.Conv2d(3,5,3), nn.BatchNorm2d(5))\n",
  "test_fail(lambda: create_body(no_pool), contains='pooling')\n",
  "test_fail(lambda: create_body(no_pool, cut='a'), contains='cut must be')"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "`cut` can either be an integer, in which case the body keeps the model's children up to (but not including) index `cut`, or a function, in which case `create_body` returns `cut(model)`. If `cut` is not passed, it defaults to the index of the last child that contains a pooling layer, so that child and everything after it are removed (a `ValueError` is raised if no child contains pooling)." 
] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "tst = lambda pretrained : nn.Sequential(nn.Conv2d(3,5,3), nn.BatchNorm2d(5), nn.AvgPool2d(1), nn.Linear(3,4))\n",
  "m = create_body(tst)\n",
  "test_eq(len(m), 2)\n",
  "\n",
  "m = create_body(tst, cut=3)\n",
  "test_eq(len(m), 3)\n",
  "\n",
  "m = create_body(tst, cut=noop)\n",
  "test_eq(len(m), 4)\n",
  "\n",
  "for n in range(1,5):\n",
  "    m = create_body(tst, n_in=n)\n",
  "    test_eq(_get_first_layer(m)[0].in_channels, n)"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Head and model" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "def create_head(nf, n_out, lin_ftrs=None, ps=0.5, concat_pool=True, bn_final=False, lin_first=False, y_range=None):\n",
  "    \"Model head that takes `nf` features, runs through `lin_ftrs`, and outputs `n_out` classes.\"\n",
  "    lin_ftrs = [nf, 512, n_out] if lin_ftrs is None else [nf] + list(lin_ftrs) + [n_out]\n",
  "    # `L(a_list)` wraps the caller's list without copying it, so take a private copy:\n",
  "    # otherwise `ps.pop(0)` below would mutate the `ps` argument passed by the caller.\n",
  "    ps = list(L(ps))\n",
  "    # a single value `p` becomes `p/2` for every block but the last, which gets `p`\n",
  "    if len(ps) == 1: ps = [ps[0]/2] * (len(lin_ftrs)-2) + ps\n",
  "    actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs)-2) + [None]\n",
  "    pool = AdaptiveConcatPool2d() if concat_pool else nn.AdaptiveAvgPool2d(1)\n",
  "    layers = [pool, Flatten()]\n",
  "    # with `lin_first`, the first dropout leads the head; when `ps` holds one value per block, the shorter\n",
  "    # `ps` makes `zip` stop one block early and the last layer is added as a plain `Linear` below\n",
  "    if lin_first: layers.append(nn.Dropout(ps.pop(0)))\n",
  "    for ni,no,p,actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):\n",
  "        layers += LinBnDrop(ni, no, bn=True, p=p, act=actn, lin_first=lin_first)\n",
  "    if lin_first: layers.append(nn.Linear(lin_ftrs[-2], n_out))\n",
  "    if bn_final: layers.append(nn.BatchNorm1d(lin_ftrs[-1], momentum=0.01))\n",
  "    if y_range is not None: layers.append(SigmoidRange(*y_range))\n",
  "    return nn.Sequential(*layers)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "# `create_head` must not mutate the `ps` list it receives\n",
  "tst_ps = [0.25, 0.5]\n",
  "create_head(5, 10, ps=tst_ps, lin_first=True)\n",
  "test_eq(tst_ps, [0.25, 0.5])"
 ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "The head begins with fastai's `AdaptiveConcatPool2d` if `concat_pool=True`; otherwise, it uses traditional average pooling. 
Then it uses a `Flatten` layer before going through blocks of `BatchNorm`, `Dropout` and `Linear` layers (if `lin_first=True`, those are `Linear`, `BatchNorm`, `Dropout`, and the head also starts with an extra `Dropout` and ends with a plain `Linear` layer).\n",
  "\n",
  "Those blocks start at `nf`, go through every element of `lin_ftrs` (defaults to `[512]`) and end at `n_out`. `ps` is a list of probabilities used for the dropouts (if you only pass one value, half of it is used for every block but the last, which gets the full value).\n",
  "\n",
  "If `bn_final=True`, a final `BatchNorm` layer is added. If `y_range` is passed, the function adds a `SigmoidRange` to that range."
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [
  "Sequential(\n",
  "  (0): AdaptiveConcatPool2d(\n",
  "    (ap): AdaptiveAvgPool2d(output_size=1)\n",
  "    (mp): AdaptiveMaxPool2d(output_size=1)\n",
  "  )\n",
  "  (1): Flatten(full=False)\n",
  "  (2): BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  "  (3): Dropout(p=0.25, inplace=False)\n",
  "  (4): Linear(in_features=5, out_features=512, bias=False)\n",
  "  (5): ReLU(inplace=True)\n",
  "  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  "  (7): Dropout(p=0.5, inplace=False)\n",
  "  (8): Linear(in_features=512, out_features=10, bias=False)\n",
  ")"
 ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [
  "tst = create_head(5, 10)\n",
  "tst"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#hide\n",
  "mods = [*tst.children()]\n",
  "test_eq(len(mods), 9)\n",
  "assert isinstance(mods[2], nn.BatchNorm1d)\n",
  "assert isinstance(mods[-1], nn.Linear)\n",
  "\n",
  "tst = create_head(5, 10, lin_first=True)\n",
  "mods = [*tst.children()]\n",
  "test_eq(len(mods), 8)\n",
  "assert isinstance(mods[2], nn.Dropout)"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "from fastai2.callback.hook import num_features_model"
 ] },
 {
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
  "#export\n",
  "@delegates(create_head)\n",
  "def create_cnn_model(arch, n_out, cut=None, pretrained=True, n_in=3, init=nn.init.kaiming_normal_, custom_head=None,\n",
  "                     concat_pool=True, **kwargs):\n",
  "    \"Create custom convnet architecture using `arch`, `n_in` and `n_out`\"\n",
  "    body = create_body(arch, n_in, pretrained, cut)\n",
  "    if custom_head is not None:\n",
  "        head = custom_head\n",
  "    else:\n",
  "        # `AdaptiveConcatPool2d` concatenates two poolings, so the head receives twice the body's features\n",
  "        feat_mult = 2 if concat_pool else 1\n",
  "        nf = num_features_model(nn.Sequential(*body.children())) * feat_mult\n",
  "        head = create_head(nf, n_out, concat_pool=concat_pool, **kwargs)\n",
  "    model = nn.Sequential(body, head)\n",
  "    # `init` is applied to the head (`model[1]`) only; the body's weights are left untouched\n",
  "    if init is not None: apply_init(model[1], init)\n",
  "    return model"
 ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/markdown": [ "
create_cnn_model[source]create_cnn_model(**`arch`**, **`n_out`**, **`cut`**=*`None`*, **`pretrained`**=*`True`*, **`n_in`**=*`3`*, **`init`**=*`'kaiming_normal_'`*, **`custom_head`**=*`None`*, **`concat_pool`**=*`True`*, **`lin_ftrs`**=*`None`*, **`ps`**=*`0.5`*, **`bn_final`**=*`False`*, **`lin_first`**=*`False`*, **`y_range`**=*`None`*)\n",
"\n",
"Create custom convnet architecture using `arch`, `n_in` and `n_out`"
],
"text/plain": [
"