""" Pytorch implementation of VGG models. Reference: [1] . Simonyan and A. Zisserman. Very deep convolutional networks for large-scale image recognition. In ICLR, 2015. """ import torch import torch.nn as nn from net.spectral_normalization.spectral_norm_conv_inplace import spectral_norm_conv from net.spectral_normalization.spectral_norm_fc import spectral_norm_fc cfg_cifar = { "VGG11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M",], "VGG19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M",], } inp_size_cifar = { "VGG11": [32, 16, 16, 8, 8, 8, 4, 4, 4, 2, 2, 2, 1], "VGG13": [32, 32, 16, 16, 16, 8, 8, 8, 4, 4, 4, 2, 2, 2, 1], "VGG16": [32, 32, 16, 16, 16, 8, 8, 8, 8, 4, 4, 4, 4, 2, 2, 2, 2, 1], "VGG19": [32, 32, 16, 16, 16, 8, 8, 8, 8, 8, 4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 1], } cfg_mnist = { "VGG11": [64, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG13": [64, 64, 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], "VGG16": [64, 64, 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M",], "VGG19": [64, 64, 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M",], } inp_size_mnist = { "VGG11": [28, 28, 14, 14, 14, 7, 7, 7, 3, 3, 3, 1], "VGG13": [28, 28, 28, 28, 14, 14, 14, 7, 7, 7, 3, 3, 3, 1], "VGG16": [28, 28, 28, 28, 14, 14, 14, 14, 7, 7, 7, 7, 3, 3, 3, 3, 1], "VGG19": [28, 28, 28, 28, 14, 14, 14, 14, 14, 7, 7, 7, 7, 7, 3, 3, 3, 3, 3, 1], } class VGG(nn.Module): def __init__( self, vgg_name, num_classes=10, temp=1.0, spectral_normalization=True, mod=True, coeff=3, n_power_iterations=1, mnist=False, ): """ If the "mod" parameter is set to True, the architecture uses 2 modifications: 1. LeakyReLU instead of normal ReLU 2. Average Pooling on the residual connections. """ super(VGG, self).__init__() self.temp = temp self.mod = mod def wrapped_conv(input_size, in_c, out_c, kernel_size, stride): padding = 1 if kernel_size == 3 else 0 conv = nn.Conv2d(in_c, out_c, kernel_size, stride, padding, bias=False) if not spectral_normalization: return conv # NOTE: Google uses the spectral_norm_fc in all cases if kernel_size == 1: # use spectral norm fc, because bound are tight for 1x1 convolutions wrapped_conv = spectral_norm_fc(conv, coeff, n_power_iterations) else: # Otherwise use spectral norm conv, with loose bound shapes = (in_c, input_size, input_size) wrapped_conv = spectral_norm_conv(conv, coeff, shapes, n_power_iterations) return wrapped_conv self.wrapped_conv = wrapped_conv self.mnist = mnist if mnist: self.inp_sizes = inp_size_mnist[vgg_name] self.features = self._make_layers(cfg_mnist[vgg_name]) else: self.inp_sizes = inp_size_cifar[vgg_name] self.features = self._make_layers(cfg_cifar[vgg_name]) self.classifier = nn.Linear(512, num_classes) self.feature = None def forward(self, x): out = self.features(x) out = out.view(out.size(0), -1) self.feature = out.clone().detach() out = self.classifier(out) / self.temp return out def _make_layers(self, cfg): layers = [] in_channels = 1 if self.mnist else 3 for i, x in enumerate(cfg): if x == "M": layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: layers += [ self.wrapped_conv(self.inp_sizes[i], in_channels, x, kernel_size=3, stride=1), nn.BatchNorm2d(x), nn.LeakyReLU(inplace=True) if self.mod else nn.ReLU(inplace=True), ] in_channels = x layers += [nn.AvgPool2d(kernel_size=1, stride=1)] return nn.Sequential(*layers) def vgg11(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs): model = VGG("VGG11", spectral_normalization=spectral_normalization, mod=mod, temp=temp, mnist=mnist, **kwargs) return model def vgg13(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs): model = VGG("VGG13", spectral_normalization=spectral_normalization, mod=mod, temp=temp, mnist=mnist, **kwargs) return model def vgg16(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs): model = VGG("VGG16", spectral_normalization=spectral_normalization, mod=mod, temp=temp, mnist=mnist, **kwargs) return model def vgg19(spectral_normalization=True, mod=True, temp=1.0, mnist=False, **kwargs): model = VGG("VGG19", spectral_normalization=spectral_normalization, mod=mod, temp=temp, mnist=mnist, **kwargs) return model