{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# default_exp layers\n", "# default_cls_lvl 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "from local.core.all import *\n", "from local.torch_imports import *\n", "from local.torch_core import *\n", "from local.test import *\n", "from torch.nn.utils import weight_norm, spectral_norm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from local.notebook.showdoc import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Layers\n", "> Custom fastai layers and basic functions to grab them." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic manipulations and resize" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class Identity(Module):\n", " \"Do nothing at all\"\n", " def forward(self,x): return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Lambda(Module):\n", " \"An easy way to create a pytorch layer for a simple `func`\"\n", " def __init__(self, func): self.func=func\n", "\n", " def forward(self, x): return self.func(x)\n", " def __repr__(self): return f'{self.__class__.__name__}({self.func})'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "> Warning: In the tests below, we use lambda functions for convenience, but you shouldn't do this when building a real modules as it would make models that won't pickle (so you won't be able to save/export them)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = Lambda(lambda x:x+2)\n", "x = torch.randn(10,20)\n", "test_eq(tst(x), x+2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PartialLambda(Lambda):\n", " \"Layer that applies `partial(func, **kwargs)`\"\n", " def __init__(self, func, **kwargs):\n", " super().__init__(partial(func, **kwargs))\n", " self.repr = f'{func.__name__}, {kwargs}'\n", "\n", " def forward(self, x): return self.func(x)\n", " def __repr__(self): return f'{self.__class__.__name__}({self.repr})'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def test_func(a,b=2): return a+b\n", "tst = PartialLambda(test_func, b=5)\n", "test_eq(tst(x), x+5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class View(Module):\n", " \"Reshape `x` to `size`\"\n", " def __init__(self, *size): self.size = size\n", " def forward(self, x): return x.view(self.size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = View(10,5,4)\n", "test_eq(tst(x).shape, [10,5,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ResizeBatch(Module):\n", " \"Reshape `x` to `size`, keeping batch dim the same size\"\n", " def __init__(self, *size): self.size = size\n", " def forward(self, x): return x.view((x.size(0),) + self.size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = ResizeBatch(5,4)\n", "test_eq(tst(x).shape, [10,5,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Flatten(Module):\n", " \"Flatten `x` to a single dimension, often used at the end of a model. 
`full` for rank-1 tensor\"\n", " def __init__(self, full=False): self.full = full\n", " def forward(self, x): return x.view(-1) if self.full else x.view(x.size(0), -1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = Flatten()\n", "x = torch.randn(10,5,4)\n", "test_eq(tst(x).shape, [10,20])\n", "tst = Flatten(full=True)\n", "test_eq(tst(x).shape, [200])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Debugger(nn.Module):\n", " \"A module to debug inside a model.\"\n", " def forward(self,x):\n", " set_trace()\n", " return x" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def sigmoid_range(x, low, high):\n", " \"Sigmoid function with range `(low, high)`\"\n", " return torch.sigmoid(x) * (high - low) + low" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "test = tensor([-10.,0.,10.])\n", "assert torch.allclose(sigmoid_range(test, -1, 2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)\n", "assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)\n", "assert torch.allclose(sigmoid_range(test, 2, 4), tensor([2., 3., 4.]), atol=1e-4, rtol=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SigmoidRange(Module):\n", " \"Sigmoid module with range `(low, high)`\"\n", " def __init__(self, low, high): self.low,self.high = low,high\n", " def forward(self, x): return sigmoid_range(x, self.low, self.high)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = SigmoidRange(-1, 2)\n", "assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pooling layers" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "# export\n", "class AdaptiveConcatPool2d(nn.Module):\n", " \"Layer that concats `AdaptiveAvgPool2d` and `AdaptiveMaxPool2d`\"\n", " def __init__(self, size=None):\n", " super().__init__()\n", " self.size = size or 1\n", " self.ap = nn.AdaptiveAvgPool2d(self.size)\n", " self.mp = nn.AdaptiveMaxPool2d(self.size)\n", " def forward(self, x): return torch.cat([self.mp(x), self.ap(x)], 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If the input is `bs x nf x h x h`, the output will be `bs x 2*nf x 1 x 1` if no size is passed or `bs x 2*nf x size x size`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = AdaptiveConcatPool2d()\n", "x = torch.randn(10,5,4,4)\n", "test_eq(tst(x).shape, [10,10,1,1])\n", "max1 = torch.max(x, dim=2, keepdim=True)[0]\n", "maxp = torch.max(max1, dim=3, keepdim=True)[0]\n", "test_eq(tst(x)[:,:5], maxp)\n", "test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))\n", "tst = AdaptiveConcatPool2d(2)\n", "test_eq(tst(x).shape, [10,10,2,2])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "mk_class('PoolType', **{o:o for o in 'Avg Max Cat'.split()})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "_all_ = ['PoolType']" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def pool_layer(pool_type):\n", " return nn.AdaptiveAvgPool2d if pool_type=='Avg' else nn.AdaptiveMaxPool2d if pool_type=='Max' else AdaptiveConcatPool2d" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PoolFlatten(nn.Sequential):\n", " \"Combine `nn.AdaptiveAvgPool2d` and `Flatten`.\"\n", " def __init__(self, pool_type=PoolType.Avg): super().__init__(pool_layer(pool_type)(1), Flatten())" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = PoolFlatten()\n", "test_eq(tst(x).shape, [10,5])\n", "test_eq(tst(x), x.mean(dim=[2,3]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## BatchNorm layers" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def BatchNorm(nf, norm_type=NormType.Batch, ndim=2, **kwargs):\n", " \"BatchNorm layer with `nf` features and `ndim` initialized depending on `norm_type`.\"\n", " assert 1 <= ndim <= 3\n", " bn = getattr(nn, f\"BatchNorm{ndim}d\")(nf, **kwargs)\n", " bn.bias.data.fill_(1e-3)\n", " bn.weight.data.fill_(0. if norm_type==NormType.BatchZero else 1.)\n", " return bn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`kwargs` are passed to `nn.BatchNorm` and can be `eps`, `momentum`, `affine` and `track_running_stats`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = BatchNorm(15)\n", "assert isinstance(tst, nn.BatchNorm2d)\n", "test_eq(tst.weight, torch.ones(15))\n", "tst = BatchNorm(15, norm_type=NormType.BatchZero)\n", "test_eq(tst.weight, torch.zeros(15))\n", "tst = BatchNorm(15, ndim=1)\n", "assert isinstance(tst, nn.BatchNorm1d)\n", "tst = BatchNorm(15, ndim=3)\n", "assert isinstance(tst, nn.BatchNorm3d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class BatchNorm1dFlat(nn.BatchNorm1d):\n", " \"`nn.BatchNorm1d`, but first flattens leading dimensions\"\n", " def forward(self, x):\n", " if x.dim()==2: return super().forward(x)\n", " *f,l = x.shape\n", " x = x.contiguous().view(-1,l)\n", " return super().forward(x).view(*f,l)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = BatchNorm1dFlat(15)\n", "x = torch.randn(32, 64, 15)\n", "y = tst(x)\n", "mean = x.mean(dim=[0,1])\n", "test_close(tst.running_mean, 0*0.9 + mean*0.1)\n", "var = (x-mean).pow(2).mean(dim=[0,1])\n", "test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)\n", "test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class LinBnDrop(nn.Sequential):\n", " \"Module grouping `BatchNorm1d`, `Dropout` and `Linear` layers\"\n", " def __init__(self, n_in, n_out, bn=True, p=0., act=None, lin_first=False):\n", " layers = [BatchNorm(n_out if lin_first else n_in, ndim=1)] if bn else []\n", " if p != 0: layers.append(nn.Dropout(p))\n", " lin = [nn.Linear(n_in, n_out, bias=not bn)]\n", " if act is not None: lin.append(act)\n", " layers = lin+layers if lin_first else layers+lin\n", " super().__init__(*layers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `BatchNorm` layer is skipped if `bn=False`, as 
is the dropout if `p=0.`. Optionally, you can add an activation after the linear layer with `act`. Pass `lin_first=True` to put the linear layer (and its activation) before the `BatchNorm` and dropout layers, in which case the `BatchNorm` layer is built with `n_out` features." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = LinBnDrop(10, 20)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 2)\n", "assert isinstance(mods[0], nn.BatchNorm1d)\n", "assert isinstance(mods[1], nn.Linear)\n", "\n", "tst = LinBnDrop(10, 20, p=0.1)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 3)\n", "assert isinstance(mods[0], nn.BatchNorm1d)\n", "assert isinstance(mods[1], nn.Dropout)\n", "assert isinstance(mods[2], nn.Linear)\n", "\n", "tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 3)\n", "assert isinstance(mods[0], nn.Linear)\n", "assert isinstance(mods[1], nn.ReLU)\n", "assert isinstance(mods[2], nn.BatchNorm1d)\n", "\n", "tst = LinBnDrop(10, 20, bn=False)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 1)\n", "assert isinstance(mods[0], nn.Linear)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convolutions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def init_default(m, func=nn.init.kaiming_normal_):\n", " \"Initialize `m` weights with `func` and set `bias` to 0.\"\n", " if func and hasattr(m, 'weight'): func(m.weight)\n", " with torch.no_grad():\n", " if getattr(m, 'bias', None) is not None: m.bias.fill_(0.)\n", " return m" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _conv_func(ndim=2, transpose=False):\n", " \"Return the proper conv `ndim` function, potentially `transposed`.\"\n", " assert 1 <= ndim <=3\n", " return getattr(nn, f'Conv{\"Transpose\" if transpose else \"\"}{ndim}d')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "test_eq(_conv_func(ndim=1),torch.nn.modules.conv.Conv1d)\n", 
"test_eq(_conv_func(ndim=2),torch.nn.modules.conv.Conv2d)\n", "test_eq(_conv_func(ndim=3),torch.nn.modules.conv.Conv3d)\n", "test_eq(_conv_func(ndim=1, transpose=True),torch.nn.modules.conv.ConvTranspose1d)\n", "test_eq(_conv_func(ndim=2, transpose=True),torch.nn.modules.conv.ConvTranspose2d)\n", "test_eq(_conv_func(ndim=3, transpose=True),torch.nn.modules.conv.ConvTranspose3d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "defaults.activation=nn.ReLU" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class ConvLayer(nn.Sequential):\n", " \"Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and `norm_type` layers.\"\n", " def __init__(self, ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=NormType.Batch, bn_1st=True,\n", " act_cls=defaults.activation, transpose=False, init=nn.init.kaiming_normal_, xtra=None, **kwargs):\n", " if padding is None: padding = ((ks-1)//2 if not transpose else 0)\n", " bn = norm_type in (NormType.Batch, NormType.BatchZero)\n", " if bias is None: bias = not bn\n", " conv_func = _conv_func(ndim, transpose=transpose)\n", " conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding, **kwargs), init)\n", " if norm_type==NormType.Weight: conv = weight_norm(conv)\n", " elif norm_type==NormType.Spectral: conv = spectral_norm(conv)\n", " layers = [conv]\n", " act_bn = []\n", " if act_cls is not None: act_bn.append(act_cls())\n", " if bn: act_bn.append(BatchNorm(nf, norm_type=norm_type, ndim=ndim))\n", " if bn_1st: act_bn.reverse()\n", " layers += act_bn\n", " if xtra: layers.append(xtra)\n", " super().__init__(*layers)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The convolution uses `ks` (kernel size) `stride`, `padding` and `bias`. 
`padding` will default to the appropriate value (`(ks-1)//2` if it's not a transposed conv, `0` otherwise) and `bias` will default to `True` if the `norm_type` is `Spectral`, `Weight` or `None`, and to `False` if it's `Batch` or `BatchZero`. Note that if you don't want any normalization, you should pass `norm_type=None`.\n", "\n", "This defines a conv layer with `ndim` (1, 2 or 3) that will be a ConvTranspose if `transpose=True`. `act_cls` is the class of the activation function to use (instantiated inside). Pass `act_cls=None` if you don't want an activation function. The default value of `act_cls` is taken from `defaults.activation` when `ConvLayer` is defined; since Python evaluates default arguments only once, reassigning `defaults.activation` afterwards does not change that default, so pass `act_cls` explicitly to use another activation. With the default `bn_1st=True`, the `BatchNorm` layer comes before the activation; pass `bn_1st=False` to put it after.\n", "\n", "`init` is used to initialize the weights (the biases are initialized to 0) and `xtra` is an optional layer to add at the end." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = ConvLayer(16, 32)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 3)\n", "test_eq(mods[1].weight, torch.ones(32))\n", "test_eq(mods[0].padding, (1,1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = torch.randn(64, 16, 8, 8)#.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# tst = tst.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Padding is selected to make the shape the same if stride=1\n", "test_eq(tst(x).shape, [64,32,8,8])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Padding is selected to make the shape half if stride=2\n", "tst = ConvLayer(16, 32, stride=2)\n", "test_eq(tst(x).shape, [64,32,4,4])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#But you can always pass your own padding if you want\n", "tst = ConvLayer(16, 32, padding=0)\n", "test_eq(tst(x).shape, [64,32,6,6])" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "#No bias by default for Batch NormType\n", "assert mods[0].bias is None\n", "#But can be overriden with `bias=True`\n", "tst = ConvLayer(16, 32, bias=True)\n", "test_eq(list(tst.children())[0].bias, torch.zeros(32))\n", "#For no norm, or spectral/weight, bias is True by default\n", "for t in [None, NormType.Spectral, NormType.Weight]:\n", " tst = ConvLayer(16, 32, norm_type=t)\n", " test_eq(list(tst.children())[0].bias, torch.zeros(32))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#Various n_dim/tranpose\n", "tst = ConvLayer(16, 32, ndim=3)\n", "assert isinstance(list(tst.children())[0], nn.Conv3d)\n", "tst = ConvLayer(16, 32, ndim=1, transpose=True)\n", "assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#No activation/leaky\n", "tst = ConvLayer(16, 32, ndim=3, act_cls=None)\n", "mods = list(tst.children())\n", "test_eq(len(mods), 2)\n", "tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))\n", "mods = list(tst.children())\n", "test_eq(len(mods), 3)\n", "assert isinstance(mods[2], nn.LeakyReLU)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.nn.modules.pooling.MaxPool2d" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nn.MaxPool2d" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def AdaptiveAvgPool(sz=1, ndim=2):\n", " \"nn.AdaptiveAvgPool layer for `ndim`\"\n", " assert 1 <= ndim <= 3\n", " return getattr(nn, f\"AdaptiveAvgPool{ndim}d\")(sz)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def MaxPool(ks=2, stride=None, padding=0, ndim=2):\n", " \"nn.MaxPool layer for 
`ndim`\"\n", " assert 1 <= ndim <= 3\n", " return getattr(nn, f\"MaxPool{ndim}d\")(ks, stride=stride, padding=padding)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False):\n", " \"nn.AvgPool layer for `ndim`\"\n", " assert 1 <= ndim <= 3\n", " return getattr(nn, f\"AvgPool{ndim}d\")(ks, stride=stride, padding=padding, ceil_mode=ceil_mode)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## fastai loss functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following class is the base class to wrap a loss function. It provides several added functionalities:\n", "- it flattens the tensors before trying to take the losses since it's more convenient (with a potential transpose to put `axis` at the end)\n", "- it has a potential `activation` method that tells the library if there is an activation fused in the loss (useful for inference and methods such as `Learner.get_preds` or `Learner.predict`)\n", "- it has a potential `decodes` method that is used on predictions in inference (for instance, an argmax in classification)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([[0.4297, 0.7398, 0.6388, 0.5635, 0.3743],\n", " [0.5245, 1.0219, 0.9097, 0.9432, 0.8671],\n", " [0.5316, 0.2025, 0.7467, 2.2852, 1.8450],\n", " [0.5773, 0.8334, 1.0466, 0.9615, 0.1678]])" ] }, "execution_count": null, "metadata": {}, "output_type": "execute_result" } ], "source": [ "F.binary_cross_entropy_with_logits(torch.randn(4,5), torch.randint(0, 2, (4,5)).float(), reduction='none')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@funcs_kwargs\n", "class BaseLoss():\n", " \"Same as `loss_cls`, but flattens input and target.\"\n", " activation=decodes=noops\n", " _methods = \"activation decodes\".split()\n", " def 
__init__(self, loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs):\n", " store_attr(self, \"axis,flatten,floatify,is_2d\")\n", " self.func = loss_cls(*args,**kwargs)\n", " functools.update_wrapper(self, self.func)\n", "\n", " def __repr__(self): return f\"FlattenedLoss of {self.func}\"\n", " @property\n", " def reduction(self): return self.func.reduction\n", " @reduction.setter\n", " def reduction(self, v): self.func.reduction = v\n", "\n", " def __call__(self, inp, targ, **kwargs):\n", " inp = inp .transpose(self.axis,-1).contiguous()\n", " targ = targ.transpose(self.axis,-1).contiguous()\n", " if self.floatify and targ.dtype!=torch.float16: targ = targ.float()\n", " if targ.dtype in [torch.int8, torch.int16, torch.int32]: targ = targ.long() \n", " if self.flatten: inp = inp.view(-1,inp.shape[-1]) if self.is_2d else inp.view(-1)\n", " return self.func.__call__(inp, targ.view(-1) if self.flatten else targ, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `args` and `kwargs` will be passed to `loss_cls` during the initialization to instantiate a loss function. `axis` is put at the end for losses like softmax that are often performed on the last axis. If `floatify=True`, the targets will be converted to float unless they are already `float16` (useful for losses that only accept float targets like `BCEWithLogitsLoss`); when `floatify=False`, integer targets of type `int8`, `int16` or `int32` are converted to `long`. If `flatten=True` (the default), the target is completely flattened and `is_2d` determines if the input is flattened to two dimensions while keeping its last dimension (the one `axis` was moved to) or completely flattened. We want the first for losses like Cross Entropy, and the second for pretty much anything else."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@delegates(keep=True)\n", "class CrossEntropyLossFlat(BaseLoss):\n", " \"Same as `nn.CrossEntropyLoss`, but flattens input and target.\"\n", " y_int = True\n", " def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)\n", " def decodes(self, x): return x.argmax(dim=self.axis)\n", " def activation(self, x): return F.softmax(x, dim=self.axis)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = CrossEntropyLossFlat()\n", "output = torch.randn(32, 5, 10)\n", "target = torch.randint(0, 10, (32,5))\n", "#nn.CrossEntropy would fail with those two tensors, but not our flattened version.\n", "_ = tst(output, target)\n", "test_fail(lambda x: nn.CrossEntropyLoss()(output,target))\n", "\n", "#Associated activation is softmax\n", "test_eq(tst.activation(output), F.softmax(output, dim=-1))\n", "#This loss function has a decodes which is argmax\n", "test_eq(tst.decodes(output), output.argmax(dim=-1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#In a segmentation task, we want to take the softmax over the channel dimension\n", "tst = CrossEntropyLossFlat(axis=1)\n", "output = torch.randn(32, 5, 128, 128)\n", "target = torch.randint(0, 5, (32, 128, 128))\n", "_ = tst(output, target)\n", "\n", "test_eq(tst.activation(output), F.softmax(output, dim=1))\n", "test_eq(tst.decodes(output), output.argmax(dim=1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "@delegates(keep=True)\n", "class BCEWithLogitsLossFlat(BaseLoss):\n", " \"Same as `nn.CrossEntropyLoss`, but flattens input and target.\"\n", " def __init__(self, *args, axis=-1, floatify=True, thresh=0.5, **kwargs):\n", " super().__init__(nn.BCEWithLogitsLoss, *args, axis=axis, floatify=floatify, 
is_2d=False, **kwargs)\n", " self.thresh = thresh\n", " \n", " def decodes(self, x): return x>self.thresh\n", " def activation(self, x): return torch.sigmoid(x)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = BCEWithLogitsLossFlat()\n", "output = torch.randn(32, 5, 10)\n", "target = torch.randn(32, 5, 10)\n", "#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.\n", "_ = tst(output, target)\n", "test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))\n", "output = torch.randn(32, 5)\n", "target = torch.randint(0,2,(32, 5))\n", "#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.\n", "_ = tst(output, target)\n", "test_fail(lambda x: nn.BCEWithLogitsLoss()(output,target))\n", "\n", "#Associated activation is sigmoid\n", "test_eq(tst.activation(output), torch.sigmoid(output))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def BCELossFlat(*args, axis=-1, floatify=True, **kwargs):\n", " \"Same as `nn.BCELoss`, but flattens input and target.\"\n", " return BaseLoss(nn.BCELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = BCELossFlat()\n", "output = torch.sigmoid(torch.randn(32, 5, 10))\n", "target = torch.randint(0,2,(32, 5, 10))\n", "_ = tst(output, target)\n", "test_fail(lambda x: nn.BCELoss()(output,target))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def MSELossFlat(*args, axis=-1, floatify=True, **kwargs):\n", " \"Same as `nn.MSELoss`, but flattens input and target.\"\n", " return BaseLoss(nn.MSELoss, *args, axis=axis, floatify=floatify, is_2d=False, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = MSELossFlat()\n", "output = 
torch.sigmoid(torch.randn(32, 5, 10))\n", "target = torch.randint(0,2,(32, 5, 10))\n", "_ = tst(output, target)\n", "test_fail(lambda x: nn.MSELoss()(output,target))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#hide\n", "#cuda\n", "#Test losses work in half precision\n", "output = torch.sigmoid(torch.randn(32, 5, 10)).half().cuda()\n", "target = torch.randint(0,2,(32, 5, 10)).half().cuda()\n", "for tst in [BCELossFlat(), MSELossFlat()]: _ = tst(output, target)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class LabelSmoothingCrossEntropy(Module):\n", " y_int = True\n", " def __init__(self, eps:float=0.1, reduction='mean'): self.eps,self.reduction = eps,reduction\n", "\n", " def forward(self, output, target):\n", " c = output.size()[-1]\n", " log_preds = F.log_softmax(output, dim=-1)\n", " if self.reduction=='sum': loss = -log_preds.sum()\n", " else:\n", " loss = -log_preds.sum(dim=-1)\n", " if self.reduction=='mean': loss = loss.mean()\n", " return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)\n", "\n", " def activation(self, out): return F.softmax(out, dim=-1)\n", " def decodes(self, out): return out.argmax(dim=-1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "On top of the formula we define:\n", "- a `reduction` attribute, that will be used when we call `Learner.get_preds`\n", "- an `activation` function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling `Learner.get_preds` or `Learner.predict`\n", "- a `decodes` function that converts the output of the model to a format similar to the target (here indices). 
This is used in `Learner.predict` and `Learner.show_results` to decode the predictions." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Embeddings" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def trunc_normal_(x, mean=0., std=1.):\n", " \"Truncated normal initialization (approximation)\"\n", " # From https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/12\n", " return x.normal_().fmod_(2).mul_(std).add_(mean)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class Embedding(nn.Embedding):\n", " \"Embedding layer with truncated normal initialization\"\n", " def __init__(self, ni, nf):\n", " super().__init__(ni, nf)\n", " trunc_normal_(self.weight.data, std=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Truncated normal initialization bounds the distribution to avoid large values. For a given `mean` and standard deviation `std`, the values lie between roughly `mean-2*std` and `mean+2*std` (so between `-0.02` and `0.02` for `Embedding`, which uses `std=0.01`). This is an approximation: `fmod_(2)` folds the tails of the standard normal sample back into `(-2, 2)` instead of resampling them."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = Embedding(10, 30)\n", "assert tst.weight.min() > -0.02\n", "assert tst.weight.max() < 0.02\n", "test_close(tst.weight.mean(), 0, 1e-2)\n", "test_close(tst.weight.std(), 0.01, 0.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Self attention" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SelfAttention(nn.Module):\n", " \"Self attention layer for `n_channels`.\"\n", " def __init__(self, n_channels):\n", " super().__init__()\n", " self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels)]\n", " self.gamma = nn.Parameter(tensor([0.]))\n", "\n", " def _conv(self,n_in,n_out):\n", " return ConvLayer(n_in, n_out, ks=1, ndim=1, norm_type=NormType.Spectral, act_cls=None, bias=False)\n", "\n", " def forward(self, x):\n", " #Notation from the paper.\n", " size = x.size()\n", " x = x.view(*size[:2],-1)\n", " f,g,h = self.query(x),self.key(x),self.value(x)\n", " beta = F.softmax(torch.bmm(f.transpose(1,2), g), dim=1)\n", " o = self.gamma * torch.bmm(h, beta) + x\n", " return o.view(*size).contiguous()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Self-attention layer as introduced in [Self-Attention Generative Adversarial Networks](https://arxiv.org/abs/1805.08318).\n", "\n", "Initially, no change is done to the input: this is controlled by a trainable parameter named `gamma`, initialized to 0, as we return `x + gamma * out`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = SelfAttention(16)\n", "x = torch.randn(32, 16, 8, 8)\n", "test_eq(tst(x),x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then during training `gamma` will probably change since it's a trainable parameter. Let's see what's happening when it gets a nonzero value."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst.gamma.data.fill_(1.)\n", "y = tst(x)\n", "test_eq(y.shape, [32,16,8,8])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). As in the paper, we denote the results of those multiplications by `f`, `g` and `h`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data\n", "test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])\n", "f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])\n", "test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The key part of the attention layer is to compute attention weights for each location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We take the product of `f` and the transpose of `g` (to get something of size bs by 64 by 64), then apply a softmax over dimension 1 (to get positive numbers that sum up to 1 along that dimension). The result can then be multiplied with `h` transposed to get an output of size bs by channels by 64, which can then be viewed as an output the same size as the original input.\n", "\n", "The final result is then `x + gamma * out` as we saw before."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)\n", "test_eq(beta.shape, [32, 64, 64])\n", "out = torch.bmm(h.transpose(1,2), beta)\n", "test_eq(out.shape, [32, 16, 64])\n", "test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PooledSelfAttention2d(nn.Module):\n", "    \"Pooled self attention layer for 2d.\"\n", "    def __init__(self, n_channels):\n", "        super().__init__()\n", "        self.n_channels = n_channels\n", "        self.query,self.key,self.value = [self._conv(n_channels, c) for c in (n_channels//8,n_channels//8,n_channels//2)]\n", "        self.out = self._conv(n_channels//2, n_channels)\n", "        self.gamma = nn.Parameter(tensor([0.]))\n", "\n", "    def _conv(self,n_in,n_out):\n", "        return ConvLayer(n_in, n_out, ks=1, norm_type=NormType.Spectral, act_cls=None, bias=False)\n", "\n", "    def forward(self, x):\n", "        n_ftrs = x.shape[2]*x.shape[3]\n", "        f = self.query(x).view(-1, self.n_channels//8, n_ftrs)\n", "        # presumably expects even height and width, so that the 2x2 max-pooled maps have `n_ftrs//4` positions\n", "        g = F.max_pool2d(self.key(x), [2,2]).view(-1, self.n_channels//8, n_ftrs//4)\n", "        h = F.max_pool2d(self.value(x), [2,2]).view(-1, self.n_channels//2, n_ftrs//4)\n", "        beta = F.softmax(torch.bmm(f.transpose(1, 2), g), -1)\n", "        o = self.out(torch.bmm(h, beta.transpose(1,2)).view(-1, self.n_channels//2, x.shape[2], x.shape[3]))\n", "        return self.gamma * o + x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Self-attention layer used in the [Big GAN paper](https://arxiv.org/abs/1809.11096).\n", "\n", "It uses the same attention as in `SelfAttention` but adds a max pooling of stride 2 before computing the matrices `g` and `h`: each position of the feature map attends to the positions of the 2x2 max-pooled feature map (four times fewer) instead of all the positions of the whole feature map. There is also a final matrix product (a 1x1 convolution) applied to the output, before returning `gamma * out + x`." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "def _conv1d_spect(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):\n", "    \"Create and initialize a `nn.Conv1d` layer with spectral normalization.\"\n", "    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)\n", "    nn.init.kaiming_normal_(conv.weight)\n", "    if bias: conv.bias.data.zero_()\n", "    return spectral_norm(conv)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export\n", "class SimpleSelfAttention(Module):\n", "    \"Self attention for `n_in` channels: `x + gamma * (x @ x^T) @ conv(x)`, computed per sample on `x` flattened to `(bs, n_in, -1)`.\"\n", "    def __init__(self, n_in:int, ks=1, sym=False):\n", "        self.sym,self.n_in = sym,n_in\n", "        self.conv = _conv1d_spect(n_in, n_in, ks, padding=ks//2, bias=False)\n", "        self.gamma = nn.Parameter(tensor([0.]))\n", "\n", "    def forward(self,x):\n", "        if self.sym:\n", "            # NOTE(review): `view(self.n_in,self.n_in)` only works for `ks=1`, and `spectral_norm` recomputes `conv.weight`\n", "            # from `conv.weight_orig` when `self.conv` is called below, so this symmetric weight looks overwritten - confirm\n", "            c = self.conv.weight.view(self.n_in,self.n_in)\n", "            c = (c + c.t())/2\n", "            self.conv.weight = c.view(self.n_in,self.n_in,1)\n", "\n", "        size = x.size()\n", "        x = x.view(*size[:2],-1)\n", "\n", "        convx = self.conv(x)\n", "        xxT = torch.bmm(x,x.permute(0,2,1).contiguous())\n", "        o = torch.bmm(xxT, convx)\n", "        o = self.gamma * o + x\n", "        return o.view(*size).contiguous()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PixelShuffle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "PixelShuffle was introduced in [this article](https://arxiv.org/pdf/1609.05158.pdf) to avoid checkerboard artifacts when upsampling images. If we want an output with `ch_out` filters, we use a convolution with `ch_out * (r**2)` filters, where `r` is the upsampling factor. 
Then we reorganize those filters like in the picture below:\n", "\n", "\"Pixelshuffle\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):\n", "    \"ICNR init of `x`, with `scale` and `init` function\"\n", "    ni,nf,h,w = x.shape\n", "    ni2 = int(ni/(scale**2))\n", "    k = init(x.new_zeros([ni2,nf,h,w])).transpose(0, 1)\n", "    k = k.contiguous().view(ni2, nf, -1)\n", "    k = k.repeat(1, 1, scale**2)\n", "    return k.contiguous().view([nf,ni,h,w]).transpose(0, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ICNR init was introduced in [this article](https://arxiv.org/abs/1707.02937). It suggests initializing the convolution that will be used in PixelShuffle so that the `r**2` channels that end up in the same `r` by `r` window get the same weight (so that in the picture above, the 9 colors in a 3 by 3 window are initially the same).\n", "\n", "> Note: This is done on the first dimension because PyTorch stores the weights of a convolutional layer in this format: `ch_out x ch_in x ks x ks`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tst = torch.randn(16*4, 32, 1, 1)\n", "tst = icnr_init(tst)\n", "for i in range(0,16*4,4):\n", "    test_eq(tst[i],tst[i+1])\n", "    test_eq(tst[i],tst[i+2])\n", "    test_eq(tst[i],tst[i+3])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class PixelShuffle_ICNR(nn.Sequential):\n", "    \"Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`.\"\n", "    def __init__(self, ni, nf=None, scale=2, blur=False, norm_type=NormType.Weight, act_cls=defaults.activation):\n", "        nf = ifnone(nf, ni)\n", "        layers = [ConvLayer(ni, nf*(scale**2), ks=1, norm_type=norm_type, act_cls=act_cls),\n", "                  nn.PixelShuffle(scale)]\n", "        conv = layers[0][0]\n", "        # `icnr_init` must group output channels by the same `scale` as `nn.PixelShuffle`\n", "        if norm_type == NormType.Weight:\n", "            # weight norm recomputes `weight` from `weight_g` and `weight_v` before every forward pass, so an init\n", "            # written to `weight` is lost: init the direction `weight_v` and set the magnitude `weight_g` to its norm\n", "            conv.weight_v.data.copy_(icnr_init(conv.weight_v.data, scale=scale))\n", "            conv.weight_g.data.copy_(conv.weight_v.data.flatten(1).norm(2, dim=1).view_as(conv.weight_g))\n", "        else: conv.weight.data.copy_(icnr_init(conv.weight.data, scale=scale))\n", "        if blur: layers += [nn.ReplicationPad2d((1,0,1,0)), nn.AvgPool2d(2, stride=1)]\n", "        super().__init__(*layers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#With the default weight norm, ICNR init still makes every 2x2 window (stride 2) have the same elements\n", "y = PixelShuffle_ICNR(16)(torch.randn(4, 16, 8, 8))\n", "test_eq(y.shape, [4, 16, 16, 16])\n", "for i,j in [(1,0),(0,1),(1,1)]: test_eq(y[:,:,0::2,0::2], y[:,:,i::2,j::2])\n", "#`icnr_init` uses the same `scale` as `nn.PixelShuffle`, so every 3x3 window (stride 3) has the same elements\n", "y = PixelShuffle_ICNR(8, scale=3, norm_type=None)(torch.randn(4, 8, 5, 5))\n", "test_eq(y.shape, [4, 8, 15, 15])\n", "for i in range(3):\n", "    for j in range(3): test_eq(y[:,:,0::3,0::3], y[:,:,i::3,j::3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The convolutional layer is initialized with `icnr_init` (using the same `scale` as the pixel shuffle) and passed `act_cls` and `norm_type` (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments). With weight normalization, the ICNR init is written to `weight_v` and `weight_g` is set to its norm, so the effective weight still starts as the ICNR init. \n", "\n", "The `blur` option comes from [Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts](https://arxiv.org/abs/1806.02658) where the authors add a little bit of blur to completely get rid of checkerboard artifacts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "psfl = PixelShuffle_ICNR(16, norm_type=None) #No weight norm here, so the conv weight is exactly the ICNR init\n", "x = torch.randn(64, 16, 8, 8)\n", "y = psfl(x)\n", "test_eq(y.shape, [64, 16, 16, 16])\n", "#ICNR init makes every 2x2 window (stride 2) have the same elements\n", "for i in range(0,16,2):\n", "    for j in range(0,16,2):\n", "        test_eq(y[:,:,i,j],y[:,:,i+1,j])\n", "        test_eq(y[:,:,i,j],y[:,:,i  ,j+1])\n", "        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sequential extensions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SequentialEx(Module):\n", "    \"Like `nn.Sequential`, but with ModuleList semantics, and can access module input\"\n", "    def __init__(self, *layers): self.layers = nn.ModuleList(layers)\n", "\n", "    def forward(self, x):\n", "        res = x\n", "        for l in self.layers:\n", "            res.orig = x\n", "            nres = l(res)\n", "            # We have to remove res.orig to avoid hanging refs and therefore memory leaks\n", "            res.orig = None\n", "            res = nres\n", "        return res\n", "\n", "    def __getitem__(self,i): return self.layers[i]\n", "    def append(self,l): return self.layers.append(l)\n", "    def extend(self,l): return self.layers.extend(l)\n", "    def insert(self,i,l): return self.layers.insert(i,l)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is useful for writing layers that need to remember their input (like a resnet block) in a sequential way: each layer receives the previous output, with the original input available as its `orig` attribute." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class MergeLayer(Module):\n", "    \"Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.\"\n", "    def __init__(self, dense:bool=False): self.dense=dense\n", "    def forward(self, x): return torch.cat([x,x.orig], dim=1) if self.dense else (x+x.orig)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))\n", "res_block.append(MergeLayer()) # just to test append - normally it would be in init params\n", "x = torch.randn(32, 16, 8, 8)\n", "y = res_block(x)\n", "test_eq(y.shape, [32, 16, 8, 8])\n", "test_eq(y, x + res_block[1](res_block[0](x)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Concat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Equivalent to `keras.layers.Concatenate`, it concatenates the outputs of a `ModuleList` over a given dimension (by default the filter dimension)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#export \n", "class Cat(nn.ModuleList):\n", "    \"Concatenate layers outputs over a given dim\"\n", "    def __init__(self, layers, dim=1):\n", "        # initialize the `nn.ModuleList` first, then store plain attributes\n", "        super().__init__(layers)\n", "        self.dim=dim\n", "    def forward(self, x): return torch.cat([l(x) for l in self], dim=self.dim)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)] \n", "x = torch.rand(1,2,8,8) \n", "cat = Cat(layers) \n", "test_eq(cat(x).shape, [1,12,8,8]) \n", "test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Ready-to-go models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# export\n", "class SimpleCNN(nn.Sequential):\n", "    \"Create a 
simple CNN with `filters`.\"\n", " def __init__(self, filters, kernel_szs=None, strides=None, bn=True):\n", " nl = len(filters)-1\n", " kernel_szs = ifnone(kernel_szs, [3]*nl)\n", " strides = ifnone(strides , [2]*nl)\n", " layers = [ConvLayer(filters[i], filters[i+1], kernel_szs[i], stride=strides[i],\n", " norm_type=(NormType.Batch if bn and i