from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
set_seed(2) Simply replacing `BN` with `GN` leads to sub-optimal results." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai2.vision.all import *\n", "from nbdev.showdoc import *\n", "import glob\n", "import albumentations\n", "from torchvision import models\n", "from albumentations.pytorch.transforms import ToTensorV2\n", "set_seed(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# `Resnet` Implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We copy the implementation of `Weight Standardization` from the official repository [here](https://github.com/joe-siyuan-qiao/WeightStandardization) and also copy the implementation of `ResNet` from TorchVision." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.hub import load_state_dict_from_url" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Checkpoint URLs looked up by `_resnet` when `pretrained=True`.\n", "model_urls = {\n", " 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n", " 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n", " 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n", " 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n", " 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n", " 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n", " 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n", " 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n", " 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We replace the `Convolution` layer inside `ResNet` with the 
standardized version as in the `Weight Standardization` research paper. Everything else remains the same." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Conv2d_WS(nn.Conv2d):\n", " \"\"\"`nn.Conv2d` that applies Weight Standardization to its kernel on every forward pass.\n", "\n", " Each output filter is shifted to zero mean and divided by its standard deviation\n", " (+1e-5), both taken over all dims except the output-channel dim, and the result is\n", " passed to `F.conv2d`. The stored `self.weight` parameter itself is left unchanged.\n", " \"\"\"\n", " def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n", " padding=0, dilation=1, groups=1, bias=True):\n", " super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,\n", " padding, dilation, groups, bias)\n", "\n", " def forward(self, x):\n", " weight = self.weight\n", " weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,\n", " keepdim=True).mean(dim=3, keepdim=True)\n", " weight = weight - weight_mean\n", " std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5\n", " weight = weight / std.expand_as(weight)\n", " return F.conv2d(x, weight, self.bias, self.stride,\n", " self.padding, self.dilation, self.groups)\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return Conv2d_WS(in_planes, out_planes, kernel_size=3, stride=stride,\n", " padding=dilation, groups=groups, bias=False, dilation=dilation)\n", "\n", "\n", "def conv1x1(in_planes, out_planes, stride=1):\n", " \"\"\"1x1 convolution\"\"\"\n", " return Conv2d_WS(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class BasicBlock(nn.Module):\n", " expansion = 1\n", "\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(BasicBlock, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = 
nn.BatchNorm2d\n", " if groups != 1 or base_width != 64:\n", " raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n", " if dilation > 1:\n", " raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n", " # Both 
self.conv1 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = conv3x3(inplanes, planes, stride)\n", " self.bn1 = norm_layer(planes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv2 = conv3x3(planes, planes)\n", " self.bn2 = norm_layer(planes)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class Bottleneck(nn.Module):\n", " expansion = 4\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(Bottleneck, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " width = int(planes * (base_width / 64.)) * groups\n", " # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = conv1x1(inplanes, width)\n", " self.bn1 = norm_layer(width)\n", " self.conv2 = conv3x3(width, width, stride, groups, dilation)\n", " self.bn2 = norm_layer(width)\n", " self.conv3 = conv1x1(width, planes * self.expansion)\n", " self.bn3 = norm_layer(planes * self.expansion)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " 
out += identity\n", " out = self.relu(out)\n", "\n", " return out" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class ResNet(nn.Module):\n", "\n", " def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n", " groups=1, width_per_group=64, replace_stride_with_dilation=None,\n", " norm_layer=None):\n", " super(ResNet, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " self._norm_layer = norm_layer\n", "\n", " self.inplanes = 64\n", " self.dilation = 1\n", " if replace_stride_with_dilation is None:\n", " replace_stride_with_dilation = [False, False, False]\n", " if len(replace_stride_with_dilation) != 3:\n", " raise ValueError(\"replace_stride_with_dilation should be None \"\n", " \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n", " self.groups = groups\n", " self.base_width = width_per_group\n", " self.conv1 = Conv2d_WS(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n", " bias=False)\n", " self.bn1 = norm_layer(self.inplanes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", " self.layer1 = self._make_layer(block, 64, layers[0])\n", " self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n", " dilate=replace_stride_with_dilation[0])\n", " self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n", " dilate=replace_stride_with_dilation[1])\n", " self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n", " dilate=replace_stride_with_dilation[2])\n", " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.fc = nn.Linear(512 * block.expansion, num_classes)\n", "\n", " for m in self.modules():\n", " if isinstance(m, nn.Conv2d):\n", " nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n", " nn.init.constant_(m.weight, 1)\n", " nn.init.constant_(m.bias, 0)\n", "\n", " # 
Zero-initialize the last norm layer (BN/GN) in each residual branch,\n", " # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n", " # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n", " if zero_init_residual:\n", " for m in self.modules():\n", " if isinstance(m, Bottleneck):\n", " nn.init.constant_(m.bn3.weight, 0)\n", " elif isinstance(m, BasicBlock):\n", " nn.init.constant_(m.bn2.weight, 0)\n", "\n", " def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n", " norm_layer = self._norm_layer\n", " downsample = None\n", " previous_dilation = self.dilation\n", " if dilate:\n", " self.dilation *= stride\n", " stride = 1\n", " if stride != 1 or self.inplanes != planes * block.expansion:\n", " downsample = nn.Sequential(\n", " conv1x1(self.inplanes, planes * block.expansion, stride),\n", " norm_layer(planes * block.expansion),\n", " )\n", "\n", " layers = []\n", " layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n", " self.base_width, previous_dilation, norm_layer))\n", " self.inplanes = planes * block.expansion\n", " for _ in range(1, blocks):\n", " layers.append(block(self.inplanes, planes, groups=self.groups,\n", " base_width=self.base_width, dilation=self.dilation,\n", " norm_layer=norm_layer))\n", "\n", " return nn.Sequential(*layers)\n", "\n", " def _forward_impl(self, x):\n", " # See note [TorchScript super()]\n", " x = self.conv1(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", " x = self.maxpool(x)\n", "\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", "\n", " x = self.avgpool(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc(x)\n", "\n", " return x\n", "\n", " def forward(self, x):\n", " return self._forward_impl(x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def _resnet(arch, block, layers, pretrained, progress, **kwargs):\n", " # NOTE(review): 
model_urls point at the standard download.pytorch.org checkpoints, presumably\n", " # trained with plain convolutions + BatchNorm. Conv2d_WS standardizes whatever weights it\n", " # holds, so loaded kernels will not compute what they did originally, and with a GroupNorm\n", " # norm_layer the strict load_state_dict may reject BN-only keys -- confirm before pretrained=True.\n", " model = 
ResNet(block, layers, **kwargs)\n", " if pretrained:\n", " state_dict = load_state_dict_from_url(model_urls[arch],\n", " progress=progress)\n", " model.load_state_dict(state_dict)\n", " return model\n", "\n", "\n", "def resnet18(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-18 model from\n", " `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet34(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-34 model from\n", " `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet50(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-50 model from\n", " `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet101(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-101 model from\n", " `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the 
download to stderr\n", " \"\"\"\n", " return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n", " 
**kwargs)\n", "\n", "\n", "def resnet152(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-152 model from\n", " `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnext50_32x4d(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNeXt-50 32x4d model from\n", " `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['groups'] = 32\n", " kwargs['width_per_group'] = 4\n", " return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def resnext101_32x8d(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNeXt-101 32x8d model from\n", " `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['groups'] = 32\n", " kwargs['width_per_group'] = 8\n", " return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def wide_resnet50_2(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"Wide ResNet-50-2 model from\n", " `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n", "\n", " The model is the same as ResNet except for the bottleneck number of channels\n", " which is twice larger in every block. 
The number of channels in outer 1x1\n", " convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048\n", " channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['width_per_group'] = 64 * 2\n", " return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def wide_resnet101_2(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"Wide ResNet-101-2 model from\n", " `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_\n", "\n", " The model is the same as ResNet except for the bottleneck number of channels\n", " which is twice larger in every block. The number of channels in outer 1x1\n", " convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n", " channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['width_per_group'] = 64 * 2\n", " return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n", " pretrained, progress, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# `Pets` Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we use the wonderful [fastai library](https://github.com/fastai/fastai2) to get the `Pets` dataset." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "bs = 4" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5903, 1475)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_img_paths = image_paths[:-nval]\n", "val_img_paths = image_paths[-nval:]\n", "assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)\n", "len(trn_img_paths), len(val_img_paths)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "# NOTE(review): both splits reuse the same `tfms`; if it contains random augmentations,\n", "# validation images are augmented too -- confirm `tfms` is deterministic or use a separate val pipeline.\n", "trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)\n", "val_dataset = PetsDataset(val_img_paths, transforms=tfms)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# use same `label2int` dictionary as in `dataset` for consistency across train and val\n", "trn_dataset.setup(label2int=dataset.label2int)\n", "val_dataset.setup(label2int=dataset.label2int)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)\n", "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make sure everything works so far\n", "next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `Model`" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "Now, we define the resnet34 from the `torchvision` repo with `pretrained=False` as we do not have pretrained weights for the `GroupNorm` layer." 
] }, { "cell_type": "code", "execution_count": 29, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, 
affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (5): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, 
momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=37, bias=True)\n", ")" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Vanilla resnet with `BatchNorm`\n", "resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int), pretrained=False)\n", "resnet34_bn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a `GroupNorm_32` class that defaults to 32 groups, as in the `Group Normalization` research paper [here](https://arxiv.org/abs/1803.08494)." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "class GroupNorm_32(torch.nn.GroupNorm):\n", " \"\"\"`nn.GroupNorm` with a BatchNorm-style constructor: `GroupNorm_32(num_channels)`.\n", "\n", " Taking `num_channels` first (with `num_groups` defaulting to 32) lets the class be\n", " passed as `norm_layer` to the ResNet above, which builds norm layers as `norm_layer(planes)`.\n", " \"\"\"\n", " def __init__(self, num_channels, num_groups=32, **kwargs):\n", " super().__init__(num_groups, num_channels, **kwargs)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "ResNet(\n", " (conv1): Conv2d_WS(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, 
eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, 
eps=1e-05, affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 
256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (4): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (5): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): 
GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=37, bias=True)\n", ")" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# resnet34 with `GroupNorm` and `Standardized Weights`\n", "# `conv2d` replaced with `Conv2d_WS` and `BatchNorm` replaced with `GroupNorm`\n", "resnet34_gn = resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))\n", "resnet34_gn" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 37])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make sure we are able to make forward pass\n", "resnet34_gn(next(iter(trn_loader))[0]).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training using `PytorchLightning`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we use [PytorchLightning](https://github.com/PyTorchLightning/pytorch-lightning) for training the model. 
" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning import LightningModule, Trainer" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "class Model(LightningModule):\n", " def __init__(self, base):\n", " super().__init__()\n", " self.base = base\n", "\n", " def forward(self, x):\n", " return self.base(x)\n", "\n", " def configure_optimizers(self):\n", " return torch.optim.Adam(self.parameters(), lr=1e-3)\n", "\n", " def step(self, batch):\n", " x, y = batch\n", " y_hat = self(x)\n", " loss = nn.CrossEntropyLoss()(y_hat, y)\n", " return loss, y, y_hat\n", "\n", " def training_step(self, batch, batch_nb):\n", " loss, _, _ = self.step(batch)\n", " return {'loss': loss}\n", "\n", " def validation_step(self, batch, batch_nb):\n", " loss, y, y_hat = self.step(batch)\n", " return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}\n", "\n", " def validation_epoch_end(self, outputs):\n", " avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n", " acc = self.get_accuracy(outputs)\n", " print(f\"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}\")\n", " return {'loss': avg_loss}\n", " \n", " def get_accuracy(self, outputs):\n", " from sklearn.metrics import accuracy_score\n", " y = torch.cat([x['y'] for x in outputs])\n", " y_hat = torch.cat([x['y_hat'] for x in outputs])\n", " preds = y_hat.argmax(1)\n", " return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# define PL versions \n", "model_bn = Model(resnet34_bn)\n", "model_gn = Model(resnet34_gn)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n" ] } ], "source": [ "debug = False\n", "gpus 
"markdown", "metadata": {}, "source": [ "- The model with `GroupNorm` + `Standardised Weights` achieved performance similar to `BatchNorm`. Thus, `GroupNorm` can be considered an alternative to `BatchNorm`. \n", "\n", "- Contrary to what the paper reports (on `Imagenet`), `GroupNorm` did not necessarily outperform `BatchNorm` at smaller batch sizes on the `Pets` dataset. \n", "\n", "- The research paper uses the `Imagenet` dataset, whereas this experiment uses the `Pets` dataset because training on `Imagenet` requires more compute than was available.\n", "\n", "- For more details, refer to my blog post.\n", "\n", "- For `bs=1`, `GroupNorm` performs significantly better than `BatchNorm`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }