{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# `GroupNorm` vs `BatchNorm` on `Pets` Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we implement `GroupNorm` with `Weight Standardization` and compare the results with `BatchNorm`. Simply replacing `BN` with `GN` leads to sub-optimal results." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from fastai2.vision.all import *\n", "from nbdev.showdoc import *\n", "import glob\n", "import albumentations\n", "from torchvision import models\n", "from albumentations.pytorch.transforms import ToTensorV2\n", "set_seed(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# `ResNet` Implementation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We copy the implementation of `Weight Standardization` from the official repository [here](https://github.com/joe-siyuan-qiao/WeightStandardization) and also copy the implementation of `ResNet` from TorchVision." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.hub import load_state_dict_from_url" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "model_urls = {\n", " 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n", " 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n", " 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n", " 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n", " 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n", " 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n", " 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n", " 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n", " 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We replace each convolution layer inside `ResNet` with its weight-standardized version (`Conv2d_WS`), as described in the `Weight Standardization` research paper. Everything else remains the same." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class Conv2d_WS(nn.Conv2d):\n", " def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n", " padding=0, dilation=1, groups=1, bias=True):\n", " super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride,\n", " padding, dilation, groups, bias)\n", "\n", " def forward(self, x):\n", " weight = self.weight\n", " weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,\n", " keepdim=True).mean(dim=3, keepdim=True)\n", " weight = weight - weight_mean\n", " std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5\n", " weight = weight / std.expand_as(weight)\n", " return F.conv2d(x, weight, self.bias, self.stride,\n", " self.padding, self.dilation, self.groups)\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return Conv2d_WS(in_planes, out_planes, kernel_size=3, stride=stride,\n", " padding=dilation, groups=groups, bias=False, dilation=dilation)\n", "\n", "\n", "def conv1x1(in_planes, out_planes, stride=1):\n", " \"\"\"1x1 convolution\"\"\"\n", " return Conv2d_WS(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class BasicBlock(nn.Module):\n", " expansion = 1\n", "\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(BasicBlock, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " if groups != 1 or base_width != 64:\n", " raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n", " if dilation > 1:\n", " raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n", " # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = 
conv3x3(inplanes, planes, stride)\n", " self.bn1 = norm_layer(planes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv2 = conv3x3(planes, planes)\n", " self.bn2 = norm_layer(planes)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class Bottleneck(nn.Module):\n", " expansion = 4\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(Bottleneck, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " width = int(planes * (base_width / 64.)) * groups\n", " # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = conv1x1(inplanes, width)\n", " self.bn1 = norm_layer(width)\n", " self.conv2 = conv3x3(width, width, stride, groups, dilation)\n", " self.bn2 = norm_layer(width)\n", " self.conv3 = conv1x1(width, planes * self.expansion)\n", " self.bn3 = norm_layer(planes * self.expansion)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out" ] }, { "cell_type": "code", 
"execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class ResNet(nn.Module):\n", "\n", " def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n", " groups=1, width_per_group=64, replace_stride_with_dilation=None,\n", " norm_layer=None):\n", " super(ResNet, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " self._norm_layer = norm_layer\n", "\n", " self.inplanes = 64\n", " self.dilation = 1\n", " if replace_stride_with_dilation is None:\n", " replace_stride_with_dilation = [False, False, False]\n", " if len(replace_stride_with_dilation) != 3:\n", " raise ValueError(\"replace_stride_with_dilation should be None \"\n", " \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n", " self.groups = groups\n", " self.base_width = width_per_group\n", " self.conv1 = Conv2d_WS(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n", " bias=False)\n", " self.bn1 = norm_layer(self.inplanes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", " self.layer1 = self._make_layer(block, 64, layers[0])\n", " self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n", " dilate=replace_stride_with_dilation[0])\n", " self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n", " dilate=replace_stride_with_dilation[1])\n", " self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n", " dilate=replace_stride_with_dilation[2])\n", " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.fc = nn.Linear(512 * block.expansion, num_classes)\n", "\n", " for m in self.modules():\n", " if isinstance(m, nn.Conv2d):\n", " nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n", " nn.init.constant_(m.weight, 1)\n", " nn.init.constant_(m.bias, 0)\n", "\n", " # Zero-initialize the last BN in each residual branch,\n", " # so that the residual branch starts 
with zeros, and each residual block behaves like an identity.\n", " # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n", " if zero_init_residual:\n", " for m in self.modules():\n", " if isinstance(m, Bottleneck):\n", " nn.init.constant_(m.bn3.weight, 0)\n", " elif isinstance(m, BasicBlock):\n", " nn.init.constant_(m.bn2.weight, 0)\n", "\n", " def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n", " norm_layer = self._norm_layer\n", " downsample = None\n", " previous_dilation = self.dilation\n", " if dilate:\n", " self.dilation *= stride\n", " stride = 1\n", " if stride != 1 or self.inplanes != planes * block.expansion:\n", " downsample = nn.Sequential(\n", " conv1x1(self.inplanes, planes * block.expansion, stride),\n", " norm_layer(planes * block.expansion),\n", " )\n", "\n", " layers = []\n", " layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n", " self.base_width, previous_dilation, norm_layer))\n", " self.inplanes = planes * block.expansion\n", " for _ in range(1, blocks):\n", " layers.append(block(self.inplanes, planes, groups=self.groups,\n", " base_width=self.base_width, dilation=self.dilation,\n", " norm_layer=norm_layer))\n", "\n", " return nn.Sequential(*layers)\n", "\n", " def _forward_impl(self, x):\n", " # See note [TorchScript super()]\n", " x = self.conv1(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", " x = self.maxpool(x)\n", "\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", "\n", " x = self.avgpool(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc(x)\n", "\n", " return x\n", "\n", " def forward(self, x):\n", " return self._forward_impl(x)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "def _resnet(arch, block, layers, pretrained, progress, **kwargs):\n", " model = ResNet(block, layers, **kwargs)\n", " if pretrained:\n", " state_dict = 
load_state_dict_from_url(model_urls[arch],\n", " progress=progress)\n", " model.load_state_dict(state_dict)\n", " return model\n", "\n", "\n", "def resnet18(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-18 model from\n", " `\"Deep Residual Learning for Image Recognition\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet34(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-34 model from\n", " `\"Deep Residual Learning for Image Recognition\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet50(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-50 model from\n", " `\"Deep Residual Learning for Image Recognition\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet101(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNet-101 model from\n", " `\"Deep Residual Learning for Image Recognition\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnet152(pretrained=False, progress=True, 
**kwargs):\n", " r\"\"\"ResNet-152 model from\n", " `\"Deep Residual Learning for Image Recognition\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n", " **kwargs)\n", "\n", "\n", "def resnext50_32x4d(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNeXt-50 32x4d model from\n", " `\"Aggregated Residual Transformation for Deep Neural Networks\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['groups'] = 32\n", " kwargs['width_per_group'] = 4\n", " return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def resnext101_32x8d(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"ResNeXt-101 32x8d model from\n", " `\"Aggregated Residual Transformation for Deep Neural Networks\" `_\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['groups'] = 32\n", " kwargs['width_per_group'] = 8\n", " return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def wide_resnet50_2(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"Wide ResNet-50-2 model from\n", " `\"Wide Residual Networks\" `_\n", "\n", " The model is the same as ResNet except for the bottleneck number of channels\n", " which is twice larger in every block. The number of channels in outer 1x1\n", " convolutions is the same, e.g. 
last block in ResNet-50 has 2048-512-2048\n", " channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['width_per_group'] = 64 * 2\n", " return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n", " pretrained, progress, **kwargs)\n", "\n", "\n", "def wide_resnet101_2(pretrained=False, progress=True, **kwargs):\n", " r\"\"\"Wide ResNet-101-2 model from\n", " `\"Wide Residual Networks\" `_\n", "\n", " The model is the same as ResNet except for the bottleneck number of channels\n", " which is twice larger in every block. The number of channels in outer 1x1\n", " convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n", " channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n", "\n", " Args:\n", " pretrained (bool): If True, returns a model pre-trained on ImageNet\n", " progress (bool): If True, displays a progress bar of the download to stderr\n", " \"\"\"\n", " kwargs['width_per_group'] = 64 * 2\n", " return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n", " pretrained, progress, **kwargs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# `Pets` Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we use the wonderful [fastai library](https://github.com/fastai/fastai2) to get the `Pets` dataset." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "bs = 4" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Path('/home/ubuntu/.fastai/data/oxford-iiit-pet')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "path = untar_data(URLs.PETS); path" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#7381) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "(path/'images').ls()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `Dataset`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The implementation of `PetsDataset` is heavily inspired by, and partially copied (the regex part) from, the `fastai2` repo [here](https://github.com/fastai/fastai2/blob/master/nbs/course/lesson1-pets.ipynb)." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "class PetsDataset:\n", " def __init__(self, paths, transforms=None):\n", " self.image_paths = paths\n", " self.transforms = transforms\n", " \n", " def __len__(self): \n", " return len(self.image_paths)\n", " \n", " def setup(self, pat=r'(.+)_\\d+.jpg$', label2int=None):\n", " \"adds a label dictionary to `self`\"\n", " self.pat = re.compile(pat)\n", " if label2int is not None:\n", " self.label2int = label2int\n", " self.int2label = {v:i for i,v in self.label2int.items()}\n", " else:\n", " labels = [os.path.basename(self.pat.search(str(p)).group(1))\n", " for p in self.image_paths]\n", " self.labels = set(labels)\n", " self.label2int = {label:i for i,label in enumerate(self.labels)}\n", " self.int2label = {v:i for i,v in self.label2int.items()}\n", "\n", " \n", " def __getitem__(self, idx):\n", " img_path = self.image_paths[idx]\n", " img = Image.open(img_path)\n", " img = np.array(img)\n", " \n", " target = os.path.basename(self.pat.search(str(img_path)).group(1))\n", " target = self.label2int[target]\n", " \n", " if self.transforms:\n", " img = self.transforms(image=img)['image'] \n", " \n", " return img, torch.tensor(target, dtype=torch.long)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "(#7378) 
[Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image_paths = get_image_files(path/'images')\n", "image_paths" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# remove those images that are not 3 channel\n", "from tqdm.notebook import tqdm\n", "run_remove = False\n", "def remove(o):\n", " img = Image.open(o)\n", " img = np.array(img)\n", " if img.shape[2] != 3:\n", " os.remove(o)\n", "if run_remove:\n", " for o in tqdm(image_paths): remove(o)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(#7378) 
[Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image_paths = get_image_files(path/'images')\n", "image_paths" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# augmentations using `albumentations` library\n", "sz = 224\n", "tfms = albumentations.Compose([\n", " albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),\n", " albumentations.OneOf(\n", " [albumentations.Cutout(random.randint(1,8), 16, 16),\n", " albumentations.CoarseDropout(random.randint(1,8), 16, 16)]\n", " ),\n", " albumentations.Normalize(always_apply=True),\n", " ToTensorV2()\n", "])" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "dataset = PetsDataset(image_paths, tfms)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# to setup the `label2int` dictionary\n", "dataset.setup()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(tensor([[[ 0.8618, 0.1597, 0.4166, ..., -0.6452, -0.3198, -0.2171],\n", " [ 1.1872, 0.3481, 0.4166, ..., -0.3027, 0.0912, 0.3138],\n", " [ 0.8104, 
0.6049, 0.0227, ..., -0.3712, -0.1657, -0.1828],\n", " ...,\n", " [ 1.2385, 0.4851, 0.0227, ..., 0.8789, 1.2214, 0.8961],\n", " [ 0.7077, 0.9474, -0.6965, ..., 0.1254, 1.5297, 1.6667],\n", " [ 0.1083, -0.0801, 0.3652, ..., 0.2111, 0.5193, 0.6734]],\n", " \n", " [[ 0.9230, 0.4328, 0.4503, ..., -0.2850, -0.0224, -0.0399],\n", " [ 1.3256, 0.7304, 0.4678, ..., -0.0399, 0.1527, 0.3277],\n", " [ 0.8354, 0.8354, 0.3102, ..., -0.2500, -0.1975, -0.3200],\n", " ...,\n", " [ 1.3606, 1.3431, 0.6078, ..., 0.9755, 1.3957, 1.1331],\n", " [ 0.7654, 1.0455, -0.0574, ..., 0.7654, 1.6232, 1.7458],\n", " [ 0.4153, 0.5903, 0.9230, ..., 0.7654, 0.8529, 1.0980]],\n", " \n", " [[ 0.3393, -0.3578, -0.4275, ..., -0.7936, -0.4624, -0.3578],\n", " [ 0.6531, -0.2358, -0.4973, ..., -0.3753, -0.0615, 0.1128],\n", " [ 0.0431, 0.1128, -1.0201, ..., -0.4101, -0.2707, -0.3578],\n", " ...,\n", " [ 0.7228, 0.3219, -0.5321, ..., 0.4439, 1.0017, 0.7576],\n", " [ 0.2173, 0.4265, -1.1247, ..., -0.0790, 1.1411, 1.2457],\n", " [-0.4450, -0.2881, 0.1302, ..., 0.0082, 0.2696, 0.4439]]]),\n", " tensor(24))" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0]" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([3, 224, 224])" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[0][0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `DataLoaders`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We divide the `image_paths` into train and validation with 20% split." 
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1475" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nval = int(len(image_paths)*0.2)\n", "nval" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5903, 1475)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn_img_paths = image_paths[:-nval]\n", "val_img_paths = image_paths[-nval:]\n", "assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)\n", "len(trn_img_paths), len(val_img_paths)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)\n", "val_dataset = PetsDataset(val_img_paths, transforms=tfms)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "# use same `label2int` dictionary as in `dataset` for consistency across train and val\n", "trn_dataset.setup(label2int=dataset.label2int)\n", "val_dataset.setup(label2int=dataset.label2int)" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)\n", "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# make sure everything works so far\n", "next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## `Model`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we define the resnet34 
from the `torchvision` repo with `pretrained=False` as we do not have pretrained weights for the `GroupNorm` layer." ] }, { "cell_type": "code", "execution_count": 29, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "ResNet(\n", " (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(64, 128, 
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): 
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (4): BasicBlock(\n", " (conv1): Conv2d(256, 256, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (5): BasicBlock(\n", " (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (downsample): Sequential(\n", " (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): 
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=37, bias=True)\n", ")" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Vanilla resnet with `BatchNorm`\n", "resnet34_bn = models.resnet34(num_classes=len(trn_dataset.label2int), pretrained=False)\n", "resnet34_bn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define a `GroupNorm_32` class that defaults to 32 groups, as in the `Group Normalization` research paper [here](https://arxiv.org/abs/1803.08494). The `ResNet` code copied from TorchVision builds each normalization layer as `norm_layer(num_channels)`, whereas `torch.nn.GroupNorm` expects `num_groups` as its first argument, so the subclass takes `num_channels` first and swaps the order." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "class GroupNorm_32(torch.nn.GroupNorm):\n", "    # `num_channels` comes first so the class can be passed directly as `norm_layer`\n", "    def __init__(self, num_channels, num_groups=32, **kwargs):\n", "        super().__init__(num_groups, num_channels, **kwargs)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "ResNet(\n", " (conv1): Conv2d_WS(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", " (layer1): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 
" (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer2): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): 
Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 128, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer3): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (2): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (3): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " 
(bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (4): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " (5): BasicBlock(\n", " (conv1): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 256, eps=1e-05, affine=True)\n", " )\n", " )\n", " (layer4): Sequential(\n", " (0): BasicBlock(\n", " (conv1): Conv2d_WS(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (downsample): Sequential(\n", " (0): Conv2d_WS(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", " (1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " )\n", " (1): BasicBlock(\n", " (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " (2): 
BasicBlock(\n", " (conv1): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn1): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " (relu): ReLU(inplace=True)\n", " (conv2): Conv2d_WS(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", " (bn2): GroupNorm_32(32, 512, eps=1e-05, affine=True)\n", " )\n", " )\n", " (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n", " (fc): Linear(in_features=512, out_features=37, bias=True)\n", ")" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# resnet34 with `GroupNorm` and `Weight Standardization`:\n", "# every `Conv2d` is replaced with `Conv2d_WS` and every `BatchNorm2d` with `GroupNorm_32`\n", "resnet34_gn = resnet34(norm_layer=GroupNorm_32, num_classes=len(trn_dataset.label2int))\n", "resnet34_gn" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([4, 37])" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# sanity check: a forward pass on one batch should return shape (batch_size, num_classes)\n", "resnet34_gn(next(iter(trn_loader))[0]).shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training using `PyTorch Lightning`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we train both models with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). 
" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "from pytorch_lightning import LightningModule, Trainer" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "class Model(LightningModule):\n", " def __init__(self, base):\n", " super().__init__()\n", " self.base = base\n", "\n", " def forward(self, x):\n", " return self.base(x)\n", "\n", " def configure_optimizers(self):\n", " return torch.optim.Adam(self.parameters(), lr=1e-3)\n", "\n", " def step(self, batch):\n", " x, y = batch\n", " y_hat = self(x)\n", " loss = nn.CrossEntropyLoss()(y_hat, y)\n", " return loss, y, y_hat\n", "\n", " def training_step(self, batch, batch_nb):\n", " loss, _, _ = self.step(batch)\n", " return {'loss': loss}\n", "\n", " def validation_step(self, batch, batch_nb):\n", " loss, y, y_hat = self.step(batch)\n", " return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}\n", "\n", " def validation_epoch_end(self, outputs):\n", " avg_loss = torch.stack([x['loss'] for x in outputs]).mean()\n", " acc = self.get_accuracy(outputs)\n", " print(f\"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}\")\n", " return {'loss': avg_loss}\n", " \n", " def get_accuracy(self, outputs):\n", " from sklearn.metrics import accuracy_score\n", " y = torch.cat([x['y'] for x in outputs])\n", " y_hat = torch.cat([x['y_hat'] for x in outputs])\n", " preds = y_hat.argmax(1)\n", " return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# define PL versions \n", "model_bn = Model(resnet34_bn)\n", "model_gn = Model(resnet34_gn)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n" ] } ], "source": [ "debug = False\n", "gpus = 
torch.cuda.device_count()\n", "trainer = Trainer(gpus=gpus, max_epochs=50, \n", " num_sanity_val_steps=1 if debug else 0)" ] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### `batch_size=4`" ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "hidden": true, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6cd9f36c8d9a4833989984fc388cab19", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:3.638690710067749 | Accuracy:0.022372881355932205\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:1 | Loss:3.5767452716827393 | Accuracy:0.03728813559322034\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, 
bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:3.532081365585327 | Accuracy:0.05152542372881356\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:3.497438907623291 | Accuracy:0.06033898305084746\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:3.437784194946289 | Accuracy:0.07457627118644068\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:3.3992772102355957 | Accuracy:0.07322033898305084\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:3.3322556018829346 | Accuracy:0.08203389830508474\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:3.278475761413574 | Accuracy:0.09220338983050848\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:3.2041637897491455 | Accuracy:0.12\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:3.1338086128234863 | Accuracy:0.13288135593220338\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:10 | Loss:2.9662578105926514 | Accuracy:0.15593220338983052\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:2.9380886554718018 | Accuracy:0.16203389830508474\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:2.7531585693359375 | Accuracy:0.21627118644067797\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:2.7896103858947754 | Accuracy:0.2223728813559322\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:2.5649585723876953 | Accuracy:0.26372881355932204\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:2.5243453979492188 | Accuracy:0.3071186440677966\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:2.453778028488159 | Accuracy:0.3220338983050847\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:2.575655460357666 | Accuracy:0.33016949152542374\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:2.723491668701172 | Accuracy:0.3193220338983051\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | Loss:3.0088090896606445 | Accuracy:0.3369491525423729\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:3.221853494644165 | 
Accuracy:0.3213559322033898\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:3.3212766647338867 | Accuracy:0.34576271186440677\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:3.6144063472747803 | Accuracy:0.3247457627118644\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:3.542142868041992 | Accuracy:0.34440677966101696\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:24 | Loss:3.8027701377868652 | Accuracy:0.32610169491525426\n", "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train model with `GroupNorm` with `bs=4` on the `Pets` dataset\n", "trainer = Trainer(gpus=gpus, max_epochs=25, 
\n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "hidden": true, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7ed8716b52e44ed2848966a0c9f50794", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:4.403476715087891 | Accuracy:0.01966101694915254\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:1 | Loss:3.615051746368408 | Accuracy:0.03932203389830508\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": 
"display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:3.6922903060913086 | Accuracy:0.05084745762711865\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:3.4302172660827637 | Accuracy:0.062372881355932205\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:3.351684331893921 | Accuracy:0.08271186440677966\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:3.2836146354675293 | Accuracy:0.0935593220338983\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:3.2269628047943115 | Accuracy:0.10915254237288136\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:3.2704873085021973 | Accuracy:0.1023728813559322\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:3.071798801422119 | Accuracy:0.1423728813559322\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:3.0656063556671143 | Accuracy:0.15457627118644068\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:10 | Loss:3.0375216007232666 | Accuracy:0.17288135593220338\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:2.8739380836486816 | 
Accuracy:0.2094915254237288\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:2.7329418659210205 | Accuracy:0.23186440677966103\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:2.737560510635376 | Accuracy:0.24813559322033898\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:2.541532516479492 | Accuracy:0.27728813559322035\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:2.540792226791382 | Accuracy:0.3064406779661017\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] 
}, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:2.485729217529297 | Accuracy:0.3328813559322034\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:2.7257814407348633 | Accuracy:0.31389830508474575\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:3.07981276512146 | Accuracy:0.3247457627118644\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | Loss:3.1801645755767822 | Accuracy:0.31661016949152543\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:3.270585298538208 | Accuracy:0.3328813559322034\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:3.355048656463623 | Accuracy:0.3376271186440678\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:3.362093687057495 | Accuracy:0.29898305084745763\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:3.470551013946533 | Accuracy:0.3389830508474576\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:24 | Loss:3.5411648750305176 | Accuracy:0.31254237288135595\n", "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train model with `BatchNorm` with `bs=4` on the `Pets` dataset\n", "trainer = Trainer(gpus=gpus, max_epochs=25, \n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)" 
] }, { "cell_type": "markdown", "metadata": { "heading_collapsed": true }, "source": [ "### `batch_size=64`" ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "hidden": true }, "outputs": [], "source": [ "trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=4, shuffle=True)\n", "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=4, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "hidden": true }, "outputs": [], "source": [ "# redefine PL versions to remove trained weights\n", "model_bn = Model(resnet34_bn)\n", "model_gn = Model(resnet34_gn)" ] }, { "cell_type": "code", "execution_count": 45, "metadata": { "hidden": true, "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:4.571255683898926 | Accuracy:0.33152542372881355\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', 
layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:1 | Loss:4.823599815368652 | Accuracy:0.33084745762711865\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:4.738388538360596 | Accuracy:0.33152542372881355\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:4.6921844482421875 | Accuracy:0.3383050847457627\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:5.571420669555664 | Accuracy:0.3227118644067797\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:4.973819255828857 | Accuracy:0.31864406779661014\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", 
"version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:4.960039138793945 | Accuracy:0.31186440677966104\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:4.72049617767334 | Accuracy:0.33152542372881355\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:4.7438859939575195 | Accuracy:0.3410169491525424\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:4.7650861740112305 | Accuracy:0.33220338983050846\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ 
"Epoch:10 | Loss:4.842560768127441 | Accuracy:0.33491525423728813\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:5.002099514007568 | Accuracy:0.3410169491525424\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:4.969579696655273 | Accuracy:0.3328813559322034\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:4.797631740570068 | Accuracy:0.3328813559322034\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:4.790388107299805 | Accuracy:0.33220338983050846\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', 
description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:4.84404993057251 | Accuracy:0.3464406779661017\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:4.882577896118164 | Accuracy:0.3416949152542373\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:4.831890106201172 | Accuracy:0.3403389830508475\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:4.815413475036621 | Accuracy:0.34576271186440677\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | Loss:4.880715370178223 | Accuracy:0.34779661016949154\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:4.870474815368652 | Accuracy:0.34508474576271186\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:4.8547258377075195 | Accuracy:0.3430508474576271\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:4.814042568206787 | Accuracy:0.3505084745762712\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:5.573678970336914 | Accuracy:0.29152542372881357\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" 
}, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:24 | Loss:4.861083030700684 | Accuracy:0.33220338983050846\n", "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# train model with `BatchNorm` with `bs=64` on the `Pets` dataset\n", "trainer = Trainer(gpus=gpus, max_epochs=25, \n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": { "hidden": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:4.338170051574707 | Accuracy:0.36135593220338985\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:1 | Loss:4.264873027801514 | Accuracy:0.3593220338983051\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:4.475521564483643 | Accuracy:0.368135593220339\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:4.5568928718566895 | Accuracy:0.37559322033898307\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:4.563418865203857 | Accuracy:0.36610169491525424\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:4.532094955444336 | Accuracy:0.36677966101694914\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, 
{ "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:4.709390163421631 | Accuracy:0.36474576271186443\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:4.703502178192139 | Accuracy:0.34983050847457625\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:4.687512397766113 | Accuracy:0.36135593220338985\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:4.453052997589111 | Accuracy:0.37559322033898307\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:10 | Loss:4.729727745056152 | Accuracy:0.3423728813559322\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:4.887462139129639 | Accuracy:0.34847457627118644\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:4.761058807373047 | Accuracy:0.36\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:4.628625869750977 | Accuracy:0.36610169491525424\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:4.939492225646973 | Accuracy:0.3735593220338983\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:4.9373321533203125 | Accuracy:0.36\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:4.884154796600342 | Accuracy:0.3701694915254237\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:5.015425682067871 | Accuracy:0.34576271186440677\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:5.0034356117248535 | Accuracy:0.34372881355932206\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | Loss:5.081662178039551 | Accuracy:0.34372881355932206\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": 
"display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:5.115207195281982 | Accuracy:0.3403389830508475\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:4.923257827758789 | Accuracy:0.368135593220339\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:5.064967632293701 | Accuracy:0.3701694915254237\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:4.966062545776367 | Accuracy:0.368135593220339\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:24 | Loss:5.010922431945801 | Accuracy:0.376271186440678\n", "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = 
Trainer(gpus=gpus, max_epochs=25, \n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `batch_size=1`" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=1, num_workers=4, shuffle=True)\n", "val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=4, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "model_bn = Model(resnet34_bn)\n", "model_gn = Model(resnet34_gn)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7abca97fed0441029e17f0288b767a9d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:3.6087236404418945 | Accuracy:0.04067796610169491\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:1 | Loss:3.8362090587615967 | Accuracy:0.025084745762711864\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:3.6673178672790527 | Accuracy:0.03593220338983051\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:3.7399044036865234 | Accuracy:0.03389830508474576\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:4.054337501525879 | Accuracy:0.0488135593220339\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:4.010653972625732 | 
Accuracy:0.04542372881355932\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:4.764206886291504 | Accuracy:0.05288135593220339\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:10.56059455871582 | Accuracy:0.04474576271186441\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:5.048521041870117 | Accuracy:0.05830508474576271\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:4.828557014465332 | Accuracy:0.06508474576271187\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, 
"metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:10 | Loss:7.225879192352295 | Accuracy:0.05694915254237288\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:6.472527027130127 | Accuracy:0.06779661016949153\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:9.755941390991211 | Accuracy:0.07050847457627119\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:13.05939769744873 | Accuracy:0.059661016949152545\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:18.591503143310547 | Accuracy:0.06508474576271187\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, 
"version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:11.946345329284668 | Accuracy:0.06915254237288136\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:16.744611740112305 | Accuracy:0.06983050847457627\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:12.913531303405762 | Accuracy:0.07661016949152542\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:23.76015281677246 | Accuracy:0.06508474576271187\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | 
Loss:26.5297794342041 | Accuracy:0.06576271186440678\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:35.212242126464844 | Accuracy:0.05898305084745763\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:16.634546279907227 | Accuracy:0.06169491525423729\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:21.815725326538086 | Accuracy:0.062372881355932205\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:12.68907356262207 | Accuracy:0.0711864406779661\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', 
layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:24 | Loss:19.639753341674805 | Accuracy:0.06779661016949153\n", "\n" ] }, { "data": { "text/plain": [ "1" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainer = Trainer(gpus=gpus, max_epochs=25, \n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_bn, train_dataloader=trn_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True, used: True\n", "TPU available: False, using: 0 TPU cores\n", "CUDA_VISIBLE_DEVICES: [0]\n", "\n", " | Name | Type | Params\n", "--------------------------------\n", "0 | base | ResNet | 21 M \n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "6c67d4b333ab4eeb97bf38fc4dde3faa", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:0 | Loss:3.6178038120269775 | Accuracy:0.03593220338983051\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ 
"Epoch:1 | Loss:3.5887539386749268 | Accuracy:0.03864406779661017\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:2 | Loss:3.4937922954559326 | Accuracy:0.0576271186440678\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:3 | Loss:3.426539421081543 | Accuracy:0.06508474576271187\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:4 | Loss:3.4010708332061768 | Accuracy:0.06915254237288136\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:5 | Loss:3.352757453918457 | Accuracy:0.08949152542372882\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', 
description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:6 | Loss:3.3006396293640137 | Accuracy:0.10033898305084746\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:7 | Loss:3.2513763904571533 | Accuracy:0.09966101694915254\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:8 | Loss:3.2186825275421143 | Accuracy:0.11254237288135593\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:9 | Loss:3.1824042797088623 | Accuracy:0.12067796610169491\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:10 | Loss:3.1842432022094727 | Accuracy:0.1152542372881356\n" ] }, { "data": { 
"application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:11 | Loss:3.080850839614868 | Accuracy:0.1342372881355932\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:12 | Loss:3.1100575923919678 | Accuracy:0.1430508474576271\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:13 | Loss:3.085071563720703 | Accuracy:0.14508474576271185\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:14 | Loss:3.007901906967163 | Accuracy:0.17559322033898306\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" 
}, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:15 | Loss:3.1437573432922363 | Accuracy:0.1694915254237288\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:16 | Loss:3.110459089279175 | Accuracy:0.1864406779661017\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:17 | Loss:3.5012593269348145 | Accuracy:0.18847457627118644\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:18 | Loss:3.4454123973846436 | Accuracy:0.21084745762711865\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:19 | Loss:3.8177714347839355 | Accuracy:0.21152542372881356\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ 
"HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:20 | Loss:4.031371116638184 | Accuracy:0.1952542372881356\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:21 | Loss:4.404645919799805 | Accuracy:0.1769491525423729\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:22 | Loss:4.856805324554443 | Accuracy:0.1959322033898305\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch:23 | Loss:4.558755874633789 | Accuracy:0.21152542372881356\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "IOPub message rate exceeded.\n", "The notebook server will temporarily stop sending output\n", "to the client in order to avoid crashing it.\n", "To change this limit, set the config variable\n", "`--NotebookApp.iopub_msg_rate_limit`.\n", "\n", "Current values:\n", "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", "NotebookApp.rate_limit_window=3.0 
(secs)\n", "\n" ] } ], "source": [ "trainer = Trainer(gpus=gpus, max_epochs=25, \n", " num_sanity_val_steps=1 if debug else 0)\n", "trainer.fit(model_gn, train_dataloader=trn_loader, val_dataloaders=val_loader)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Conclusion" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- The model with `GroupNorm` + `Weight Standardization` achieved performance similar to the `BatchNorm` model. Thus, `GroupNorm` can be considered an alternative to `BatchNorm`.\n", "\n", "- On the `Pets` dataset, `GroupNorm` does not necessarily outperform `BatchNorm` at lower batch sizes, unlike what the paper reports.
\n", "\n", "- The research paper uses the `Imagenet` dataset, whereas this experiment was run on the `Pets` dataset because training on `Imagenet` requires more compute than was available.\n", "\n", "- For more details, refer to my blog post.\n", "\n", "- For `bs=1`, `GroupNorm` performs significantly better than `BatchNorm`." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "165px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }