{ "cells": [ { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.quantization import fuse_modules\n", "import os\n", "\n", "__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n", " 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n", " 'wide_resnet50_2', 'wide_resnet101_2']\n", "\n", "\n", "model_urls = {\n", " 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n", " 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n", " 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n", " 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n", " 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n", " 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n", " 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n", " 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n", " 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n", "}\n", "\n", "\n", "def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):\n", " \"\"\"3x3 convolution with padding\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n", " padding=dilation, groups=groups, bias=False, dilation=dilation)\n", "\n", "\n", "def conv1x1(in_planes, out_planes, stride=1):\n", " \"\"\"1x1 convolution\"\"\"\n", " return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n", "\n", "\n", "class BasicBlock(nn.Module):\n", " expansion = 1\n", " __constants__ = ['downsample']\n", "\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(BasicBlock, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " if groups != 1 or base_width != 64:\n", " raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n", " if dilation > 1:\n", " raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n", " # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = conv3x3(inplanes, planes, stride)\n", " self.bn1 = norm_layer(planes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.conv2 = conv3x3(planes, planes)\n", " self.bn2 = norm_layer(planes)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "\n", "class Bottleneck(nn.Module):\n", " expansion = 4\n", " __constants__ = ['downsample']\n", "\n", " def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,\n", " base_width=64, dilation=1, norm_layer=None):\n", " super(Bottleneck, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " width = int(planes * (base_width / 64.)) * groups\n", " # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n", " self.conv1 = conv1x1(inplanes, width)\n", " self.bn1 = norm_layer(width)\n", " self.conv2 = conv3x3(width, 
width, stride, groups, dilation)\n", " self.bn2 = norm_layer(width)\n", " self.conv3 = conv1x1(width, planes * self.expansion)\n", " self.bn3 = norm_layer(planes * self.expansion)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.downsample = downsample\n", " self.stride = stride\n", "\n", " def forward(self, x):\n", " identity = x\n", "\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", "\n", " out += identity\n", " out = self.relu(out)\n", "\n", " return out\n", "\n", "\n", "class ResNet(nn.Module):\n", "\n", " def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,\n", " groups=1, width_per_group=64, replace_stride_with_dilation=None,\n", " norm_layer=None):\n", " super(ResNet, self).__init__()\n", " if norm_layer is None:\n", " norm_layer = nn.BatchNorm2d\n", " self._norm_layer = norm_layer\n", "\n", " self.inplanes = 64\n", " self.dilation = 1\n", " if replace_stride_with_dilation is None:\n", " # each element in the tuple indicates if we should replace\n", " # the 2x2 stride with a dilated convolution instead\n", " replace_stride_with_dilation = [False, False, False]\n", " if len(replace_stride_with_dilation) != 3:\n", " raise ValueError(\"replace_stride_with_dilation should be None \"\n", " \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n", " self.groups = groups\n", " self.base_width = width_per_group\n", " self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n", " bias=False)\n", " self.bn1 = norm_layer(self.inplanes)\n", " self.relu = nn.ReLU(inplace=True)\n", " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n", " self.layer1 = self._make_layer(block, 64, layers[0])\n", " self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n", " dilate=replace_stride_with_dilation[0])\n", " self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n", " dilate=replace_stride_with_dilation[1])\n", " self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n", " dilate=replace_stride_with_dilation[2])\n", " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n", " self.fc = nn.Linear(512 * block.expansion, num_classes)\n", "\n", " for m in self.modules():\n", " if isinstance(m, nn.Conv2d):\n", " nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n", " elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n", " nn.init.constant_(m.weight, 1)\n", " nn.init.constant_(m.bias, 0)\n", "\n", " # Zero-initialize the last BN in each residual branch,\n", " # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n", " # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n", " if zero_init_residual:\n", " for m in self.modules():\n", " if isinstance(m, Bottleneck):\n", " nn.init.constant_(m.bn3.weight, 0)\n", " elif isinstance(m, BasicBlock):\n", " nn.init.constant_(m.bn2.weight, 0)\n", "\n", " def _make_layer(self, block, planes, blocks, stride=1, dilate=False):\n", " norm_layer = self._norm_layer\n", " downsample = None\n", " previous_dilation = self.dilation\n", " if dilate:\n", " self.dilation *= stride\n", " stride = 1\n", " if stride != 1 or self.inplanes != planes * block.expansion:\n", " downsample = nn.Sequential(\n", " conv1x1(self.inplanes, planes * 
block.expansion, stride),\n", " norm_layer(planes * block.expansion),\n", " )\n", "\n", " layers = []\n", " layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n", " self.base_width, previous_dilation, norm_layer))\n", " self.inplanes = planes * block.expansion\n", " for _ in range(1, blocks):\n", " layers.append(block(self.inplanes, planes, groups=self.groups,\n", " base_width=self.base_width, dilation=self.dilation,\n", " norm_layer=norm_layer))\n", "\n", " return nn.Sequential(*layers)\n", "\n", " def _forward(self, x):\n", " x = self.conv1(x)\n", " x = self.bn1(x)\n", " x = self.relu(x)\n", " x = self.maxpool(x)\n", "\n", " x = self.layer1(x)\n", " x = self.layer2(x)\n", " x = self.layer3(x)\n", " x = self.layer4(x)\n", "\n", " x = self.avgpool(x)\n", " x = torch.flatten(x, 1)\n", " x = self.fc(x)\n", "\n", " return x\n", "\n", " # Allow for accessing forward method in an inherited class\n", " forward = _forward\n", " \n", "def _replace_relu(module):\n", " reassign = {}\n", " for name, mod in module.named_children():\n", " _replace_relu(mod)\n", " # Checking for explicit type instead of instance\n", " # as we only want to replace modules of the exact type,\n", " # not inherited classes\n", " # NOTE: the replacement ReLU below is freshly constructed, so it defaults\n", " # to training mode. Call _replace_relu BEFORE model.eval(); otherwise the\n", " # new ReLUs stay in train mode while conv/bn are in eval mode, and\n", " # conv/bn/relu fusion fails its same-mode assertion (see the last cells).\n", " if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:\n", " reassign[name] = nn.ReLU(inplace=False)\n", "\n", " for key, value in reassign.items():\n", " module._modules[key] = value\n", " \n", " \n", "class QuantizableBottleneck(Bottleneck):\n", " def __init__(self, *args, **kwargs):\n", " super(QuantizableBottleneck, self).__init__(*args, **kwargs)\n", " self.skip_add_relu = nn.quantized.FloatFunctional()\n", " self.relu1 = nn.ReLU(inplace=False)\n", " self.relu2 = nn.ReLU(inplace=False)\n", "\n", " def forward(self, x):\n", " identity = x\n", " out = self.conv1(x)\n", " out = self.bn1(out)\n", " out = self.relu1(out)\n", " out = self.conv2(out)\n", " out = self.bn2(out)\n", " out = self.relu2(out)\n", "\n", " out = self.conv3(out)\n", " out = self.bn3(out)\n", "\n", " if self.downsample is not None:\n", " identity = self.downsample(x)\n", " out = self.skip_add_relu.add_relu(out, identity)\n", "\n", " return out\n", "\n", " def fuse_model(self):\n", " fuse_modules(self, [['conv1', 'bn1', 'relu1'],\n", " ['conv2', 'bn2', 'relu2'],\n", " ['conv3', 'bn3']], inplace=True)\n", " if self.downsample is not None:\n", " torch.quantization.fuse_modules(self.downsample, ['0', '1'], inplace=True)\n", "\n", "\n", "class QuantizableResNet(ResNet):\n", "\n", " def __init__(self, *args, **kwargs):\n", " super(QuantizableResNet, self).__init__(*args, **kwargs)\n", "\n", " self.quant = torch.quantization.QuantStub()\n", " self.dequant = torch.quantization.DeQuantStub()\n", "\n", " def forward(self, x):\n", " x = self.quant(x)\n", " # Ensure scriptability:\n", " # super(QuantizableResNet, self).forward(x)\n", " # is not scriptable\n", " x = self._forward(x)\n", " x = self.dequant(x)\n", " return x\n", "\n", " def fuse_model(self):\n", " r\"\"\"Fuse conv/bn/relu modules in ResNet models.\n", " Fuses conv+bn+relu, conv+relu, and conv+bn sequences to prepare for quantization.\n", " The model is modified in place. 
Note that this operation does not change the numerics;\n", " the model after modification still runs in floating point.\n", " \"\"\"\n", "\n", " fuse_modules(self, ['conv1', 'bn1', 'relu'], inplace=True)\n", " for m in self.modules():\n", " if type(m) == QuantizableBottleneck:\n", " m.fuse_model()\n", "\n", "def print_size_of_model(model):\n", " torch.save(model.state_dict(), \"temp.p\")\n", " print('Size (MB):', os.path.getsize(\"temp.p\")/1e6)\n", " os.remove('temp.p')" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Size (MB): 178.742957\n", "Size (MB): 172.638462\n" ] } ], "source": [ "# NOTE: this cell quantizes a training-mode model: eval() is never called,\n", "# and no calibration data is passed through the prepared model before\n", "# convert(). The size barely shrinks (178.7 MB -> 172.6 MB below), which\n", "# suggests most layers are never actually converted; see the corrected\n", "# pipeline in the final cell.\n", "model = QuantizableResNet(QuantizableBottleneck, [3, 4, 23, 3])\n", "_replace_relu(model)\n", "model.fuse_model()\n", "print_size_of_model(model)\n", "model.qconfig = torch.quantization.default_qconfig\n", "torch.quantization.prepare(model, inplace=True)\n", "quantized_net = torch.quantization.convert(model)\n", "print_size_of_model(quantized_net)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "ename": "AssertionError", "evalue": "Conv and BN both must be in the same mode (train or eval).", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0m_replace_relu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfuse_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint_size_of_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mqconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mquantization\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdefault_qconfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m\u001b[0m in \u001b[0;36mfuse_model\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 286\u001b[0m \"\"\"\n\u001b[1;32m 287\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 288\u001b[0;31m \u001b[0mfuse_modules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'conv1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'bn1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'relu'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 289\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 290\u001b[0m \u001b[0;32mif\u001b[0m 
\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mQuantizableBottleneck\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/test/lib/python3.7/site-packages/torch/quantization/fuse_modules.py\u001b[0m in \u001b[0;36mfuse_modules\u001b[0;34m(model, modules_to_fuse, inplace, fuser_func)\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule_element\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule_element\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules_to_fuse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;31m# Handle case of modules_to_fuse being a list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0m_fuse_modules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodules_to_fuse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuser_func\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;31m# Handle case of modules_to_fuse being a list of lists\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/test/lib/python3.7/site-packages/torch/quantization/fuse_modules.py\u001b[0m in \u001b[0;36m_fuse_modules\u001b[0;34m(model, modules_to_fuse, fuser_func)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;31m# Fuse list of modules\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mnew_mod_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfuser_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;31m# Replace original module list with fused module list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/test/lib/python3.7/site-packages/torch/quantization/fuse_modules.py\u001b[0m in \u001b[0;36mfuse_known_modules\u001b[0;34m(mod_list)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot fuse modules: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtypes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0mnew_mod\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mnew_mod\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mfuser_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mmod_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmod_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/anaconda3/envs/test/lib/python3.7/site-packages/torch/quantization/fuse_modules.py\u001b[0m in \u001b[0;36mfuse_conv_bn_relu\u001b[0;34m(conv, bn, relu)\u001b[0m\n\u001b[1;32m 45\u001b[0m \"\"\"\n\u001b[1;32m 46\u001b[0m \u001b[0;32massert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mrelu\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0;34m\"Conv and BN both must be in the same mode (train or eval).\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mAssertionError\u001b[0m: Conv and BN both must be in the same mode (train or eval)." ] } ], "source": [ "model = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3])\n", "model.eval()\n", "_replace_relu(model)\n", "model.fuse_model()\n", "print_size_of_model(model)\n", "model.qconfig = torch.quantization.default_qconfig\n", "torch.quantization.prepare(model, inplace=True)\n", "quantized_net = torch.quantization.convert(model)\n", "print_size_of_model(quantized_net)\n", "# NOTE: this cell fails at model.fuse_model(): model.eval() runs BEFORE\n", "# _replace_relu, so the freshly constructed replacement ReLUs are still in\n", "# their default training mode while conv/bn are already in eval mode, and\n", "# the same-mode assertion in fuse_conv_bn_relu fires. The next cell swaps\n", "# the two calls (and adds a calibration pass) to fix this." ] }
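, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Corrected pipeline: a minimal sketch, not a recorded run. It assumes a\n", "# random tensor is an acceptable stand-in for a real calibration set and\n", "# reuses the [3, 4, 6, 3] (ResNet-50-sized) configuration from the failing\n", "# cell above. The two fixes: (1) _replace_relu runs before eval(), so the\n", "# replacement ReLUs switch to eval mode together with conv/bn; (2) a\n", "# calibration pass runs between prepare() and convert(), so the observers\n", "# see realistic activation ranges before quantization params are computed.\n", "model = QuantizableResNet(QuantizableBottleneck, [3, 4, 6, 3])\n", "_replace_relu(model)  # replace ReLUs while everything is still in train mode\n", "model.eval()          # conv, bn and the new ReLUs now enter eval mode together\n", "x = torch.randn(2, 3, 224, 224)\n", "with torch.no_grad():\n", "    before = model(x)\n", "model.fuse_model()    # fusing an eval-mode model should not change numerics\n", "with torch.no_grad():\n", "    after = model(x)\n", "print('max |before - after|:', (before - after).abs().max().item())\n", "print_size_of_model(model)\n", "model.qconfig = torch.quantization.default_qconfig\n", "torch.quantization.prepare(model, inplace=True)\n", "with torch.no_grad():\n", "    model(x)          # stand-in calibration batch for the observers\n", "quantized_net = torch.quantization.convert(model)\n", "print_size_of_model(quantized_net)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" } }, "nbformat": 4, "nbformat_minor": 2 }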