import os import torch import torch.nn as nn import torchvision.transforms as transforms import torchvision.models as models from .matcher import Matcher from collections import OrderedDict from torchvision.models.vgg import model_urls from torchvision.models import vgg19 from torch.autograd import Variable from .vgg_modified import VGGModified def get_pretrained_net(name): """Loads pretrained network""" if name == 'alexnet_caffe': if not os.path.exists('alexnet-torch_py3.pth'): print('Downloading AlexNet') os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download') return torch.load('alexnet-torch_py3.pth') elif name == 'vgg19_caffe': if not os.path.exists('vgg19-caffe-py3.pth'): print('Downloading VGG-19') os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download') vgg = get_vgg19_caffe() return vgg elif name == 'vgg16_caffe': if not os.path.exists('vgg16-caffe-py3.pth'): print('Downloading VGG-16') os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download') vgg = get_vgg16_caffe() return vgg elif name == 'vgg19_pytorch_modified': # os.system('wget -O data/feature_inversion/vgg19-caffe.pth --no-check-certificate -nc https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1') model = VGGModified(vgg19(pretrained=False), 0.2) model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict']) return model else: assert False class PerceputalLoss(nn.modules.loss._Loss): """ Assumes input image is in range [0,1] if `input_range` is 'sigmoid', [-1, 1] if 'tanh' """ def __init__(self, input_range='sigmoid', net_type = 'vgg_torch', input_preprocessing='corresponding', match=[{'layers':[11,20,29],'what':'features'}]): if input_range not in ['sigmoid', 'tanh']: assert False self.net = get_pretrained_net(net_type).cuda() self.matchers = [get_matcher(self.net, match_opts) for match_opts in match] preprocessing_correspondence = { 'vgg19_torch': vgg_preprocess_caffe, 'vgg16_torch': vgg_preprocess_caffe, 'vgg19_pytorch': vgg_preprocess_pytorch, 'vgg19_pytorch_modified': vgg_preprocess_pytorch, } if input_preprocessing == 'corresponding': self.preprocess_input = preprocessing_correspondence[net_type] else: self.preprocessing = preprocessing_correspondence[input_preprocessing] def preprocess_input(self, x): if self.input_range == 'tanh': x = (x + 1.) / 2. return self.preprocess(x) def __call__(self, x, y): # for self.matcher_content.mode = 'store' self.net(self.preprocess_input(y)); self.matcher_content.mode = 'match' self.net(self.preprocess_input(x)); return sum([sum(matcher.losses.values()) for matcher in self.matchers]) def get_vgg19_caffe(): model = vgg19() model.classifier = nn.Sequential(View(), *model.classifier._modules.values()) vgg = model.features vgg_classifier = model.classifier names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3', 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4', 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5', 'torch_view','fc6','relu6','drop6','fc7','relu7','drop7','fc8'] model = nn.Sequential() for n, m in zip(names, list(vgg) + list(vgg_classifier)): model.add_module(n, m) model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) return model def get_vgg16_caffe(): vgg = torch.load('vgg16-caffe-py3.pth') names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','pool3', 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4', 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5', 'torch_view','fc6','relu6','drop6','fc7','relu7','fc8'] model = nn.Sequential() for n, m in zip(names, list(vgg)): model.add_module(n, m) # model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) return model class View(nn.Module): def __init__(self): super(View, self).__init__() def forward(self, x): return x.view(x.size(0), -1) def get_matcher(vgg, opt): # idxs = [int(x) for x in opt['layers'].split(',')] matcher = Matcher(opt['what'], 'mse', opt['map_idx']) def hook(module, input, output): matcher(module, output) for layer_name in opt['layers']: vgg._modules[layer_name].register_forward_hook(hook) return matcher def get_vgg(cut_idx=-1, vgg_type='pytorch'): f = get_vanilla_vgg_features(cut_idx, vgg_type) keys = [x for x in cnn._modules.keys()] max_idx = max(keys.index(x) for x in opt_content['layers'].split(',')) for k in keys[max_idx+1:]: cnn._modules.pop(k) return f vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1) def vgg_preprocess_caffe(var): (r, g, b) = torch.chunk(var, 3, dim=1) bgr = torch.cat((b, g, r), 1) out = bgr * 255 - torch.autograd.Variable(vgg_mean).type(var.type()) return out mean_pytorch = Variable(torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) std_pytorch = Variable(torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)) def vgg_preprocess_pytorch(var): return (var - mean_pytorch.type_as(var))/std_pytorch.type_as(var) def get_preprocessor(imsize): def vgg_preprocess(tensor): (r, g, b) = torch.chunk(tensor, 3, dim=0) bgr = torch.cat((b, g, r), 0) out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr) return out preprocess = transforms.Compose([ transforms.Resize(imsize), transforms.ToTensor(), transforms.Lambda(vgg_preprocess) ]) return preprocess def get_deprocessor(): def vgg_deprocess(tensor): bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0 (b, g, r) = torch.chunk(bgr, 3, dim=0) rgb = torch.cat((r, g, b), 0) return rgb deprocess = transforms.Compose([ transforms.Lambda(vgg_deprocess), transforms.Lambda(lambda x: torch.clamp(x, 0, 1)), transforms.ToPILImage() ]) return deprocess