{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch implementation of the StyleGAN Generator and eiscriminator\n", "*by Piotr Bialecki and Thomas Viehmann*\n", "\n", "This is a hacky addition of the Discriminator to [our StyleGAN Generator notebook](pytorch_style_gan.ipynb) that is much nicer and annotated. This here is even less ready for consumption.\n", "We also don't have losses or training.\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "#import os\n", "#os.environ['CUDA_LAUNCH_BLOCKING'] = \"1\"\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from collections import OrderedDict\n", "import pickle\n", "import numpy as np\n", "\n", "%matplotlib inline\n", "from matplotlib import pyplot\n", "\n", "import IPython\n", "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n", "device = 'cpu'" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class MyLinear(nn.Module):\n", " \"\"\"Linear layer with equalized learning rate and custom learning rate multiplier.\"\"\"\n", " def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):\n", " super().__init__()\n", " he_std = gain * input_size**(-0.5) # He init\n", " # Equalized learning rate and custom learning rate multiplier.\n", " if use_wscale:\n", " init_std = 1.0 / lrmul\n", " self.w_mul = he_std * lrmul\n", " else:\n", " init_std = he_std / lrmul\n", " self.w_mul = lrmul\n", " self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)\n", " if bias:\n", " self.bias = torch.nn.Parameter(torch.zeros(output_size))\n", " self.b_mul = lrmul\n", " else:\n", " self.bias = None\n", "\n", " def forward(self, x):\n", " bias = self.bias\n", " if bias is not None:\n", " bias = bias * self.b_mul\n", " return F.linear(x, self.weight * self.w_mul, bias)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class MyConv2d(nn.Module):\n", " \"\"\"Conv layer with equalized learning rate and custom learning rate multiplier.\"\"\"\n", " def __init__(self, input_channels, output_channels, kernel_size, stride=1, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,\n", " intermediate=None, upscale=False, downscale=False):\n", " super().__init__()\n", " if upscale:\n", " self.upscale = Upscale2d()\n", " else:\n", " self.upscale = None\n", " if downscale:\n", " self.downscale = Downscale2d()\n", " else:\n", " self.downscale = None\n", " he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init\n", " self.kernel_size = kernel_size\n", " if use_wscale:\n", " init_std = 1.0 / lrmul\n", " self.w_mul = he_std * lrmul\n", " else:\n", " init_std = he_std / lrmul\n", " self.w_mul = lrmul\n", " self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)\n", " if bias:\n", " self.bias = torch.nn.Parameter(torch.zeros(output_channels))\n", " self.b_mul = lrmul\n", " else:\n", " self.bias = None\n", " self.intermediate = intermediate\n", "\n", " def forward(self, x):\n", " bias = self.bias\n", " if bias is not None:\n", " bias = bias * self.b_mul\n", " \n", " have_convolution = False\n", " if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:\n", " # this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way\n", " # this really needs to be cleaned up and go into the conv...\n", " w = 
self.weight * self.w_mul\n", " w = w.permute(1, 0, 2, 3)\n", " # probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!\n", " w = F.pad(w, (1,1,1,1))\n", " w = w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]\n", " x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)\n", " have_convolution = True\n", " elif self.upscale is not None:\n", " x = self.upscale(x)\n", " \n", " downscale = self.downscale\n", " intermediate = self.intermediate\n", " if downscale is not None and min(x.shape[2:]) >= 128:\n", " w = self.weight * self.w_mul\n", " w = F.pad(w, (1,1,1,1))\n", " # in contrast to the upscale, this is a mean...\n", " w = (w[:, :, 1:, 1:] + w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1])*0.25 # avg_pool?\n", " x = F.conv2d(x, w, stride=2, padding=(w.size(-1)-1)//2)\n", " have_convolution = True\n", " downscale = None\n", " elif downscale is not None:\n", " assert intermediate is None\n", " intermediate = downscale\n", " \n", " if not have_convolution and intermediate is None:\n", " return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)\n", " elif not have_convolution:\n", " x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)\n", "\n", " if intermediate is not None:\n", " x = intermediate(x)\n", "\n", " if bias is not None:\n", " x = x + bias.view(1, -1, 1, 1)\n", " return x\n", " " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "class NoiseLayer(nn.Module):\n", " \"\"\"adds noise. noise is per pixel (constant over channels) with per-channel weight\"\"\"\n", " def __init__(self, channels):\n", " super().__init__()\n", " self.weight = nn.Parameter(torch.zeros(channels))\n", " self.noise = None\n", " \n", " def forward(self, x, noise=None):\n", " if noise is None and self.noise is None:\n", " noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)\n", " elif noise is None:\n", " # here is a little trick: if you grab all the NoiseLayers and set each\n", " # module's .noise attribute, you can have pre-defined noise.\n", " # Very useful for analysis\n", " noise = self.noise\n", " x = x + self.weight.view(1, -1, 1, 1) * noise\n", " return x " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class StyleMod(nn.Module):\n", " def __init__(self, latent_size, channels, use_wscale):\n", " super(StyleMod, self).__init__()\n", " self.lin = MyLinear(latent_size,\n", " channels * 2,\n", " gain=1.0, use_wscale=use_wscale)\n", " \n", " def forward(self, x, latent):\n", " style = self.lin(latent) # style => [batch_size, n_channels*2]\n", " shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]\n", " style = style.view(shape) # [batch_size, 2, n_channels, ...]\n", " x = x * (style[:, 0] + 1.) 
+ style[:, 1]\n", " return x" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "class PixelNormLayer(nn.Module):\n", " def __init__(self, epsilon=1e-8):\n", " super().__init__()\n", " self.epsilon = epsilon\n", " def forward(self, x):\n", " return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Upscale and blur layers\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "class BlurLayer(nn.Module):\n", " def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):\n", " super(BlurLayer, self).__init__()\n", " kernel = torch.tensor(kernel, dtype=torch.float32)\n", " kernel = kernel[:, None] * kernel[None, :]\n", " kernel = kernel[None, None]\n", " if normalize:\n", " kernel = kernel / kernel.sum()\n", " if flip:\n", " kernel = torch.flip(kernel, [2, 3]) # PyTorch tensors do not support negative-step slicing\n", " self.register_buffer('kernel', kernel)\n", " self.stride = stride\n", " \n", " def forward(self, x):\n", " # expand kernel channels\n", " kernel = self.kernel.expand(x.size(1), -1, -1, -1)\n", " x = F.conv2d(\n", " x,\n", " kernel,\n", " stride=self.stride,\n", " padding=int((self.kernel.size(2)-1)/2),\n", " groups=x.size(1)\n", " )\n", " return x\n", "\n", "def upscale2d(x, factor=2, gain=1):\n", " assert x.dim() == 4\n", " if gain != 1:\n", " x = x * gain\n", " if factor != 1:\n", " shape = x.shape\n", " x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)\n", " x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])\n", " return x\n", "\n", "class Upscale2d(nn.Module):\n", " def __init__(self, factor=2, gain=1):\n", " super().__init__()\n", " assert isinstance(factor, int) and factor >= 1\n", " self.gain = gain\n", " self.factor = factor\n", " def forward(self, x):\n", " return upscale2d(x, factor=self.factor, gain=self.gain)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "class G_mapping(nn.Sequential):\n", " def __init__(self, nonlinearity='lrelu', use_wscale=True):\n", " act, gain = {'relu': (torch.relu, np.sqrt(2)),\n", " 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]\n", " layers = [\n", " ('pixel_norm', PixelNormLayer()),\n", " ('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense0_act', act),\n", " ('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense1_act', act),\n", " ('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense2_act', act),\n", " ('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense3_act', act),\n", " ('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense4_act', act),\n", " ('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense5_act', act),\n", " ('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense6_act', act),\n", " ('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),\n", " ('dense7_act', act)\n", " ]\n", " super().__init__(OrderedDict(layers))\n", " \n", " def forward(self, x):\n", " x = super().forward(x)\n", " # Broadcast the single w vector to all 18 synthesis layers\n", " x = x.unsqueeze(1).expand(-1, 18, -1)\n", " return x\n", "\n", "class Truncation(nn.Module):\n", " def __init__(self, avg_latent, max_layer=8, threshold=0.7):\n", " 
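# truncation trick: pull W toward the average latent for the first max_layer layers;\n", " # threshold=1 disables truncation, threshold=0 returns the average latent itself\n", " 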
super().__init__()\n", " self.max_layer = max_layer\n", " self.threshold = threshold\n", " self.register_buffer('avg_latent', avg_latent)\n", " def forward(self, x):\n", " assert x.dim() == 3\n", " interp = torch.lerp(self.avg_latent, x, self.threshold)\n", " do_trunc = (torch.arange(x.size(1), device=x.device) < self.max_layer).view(1, -1, 1)\n", " return torch.where(do_trunc, interp, x)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "class LayerEpilogue(nn.Module):\n", " \"\"\"Things to do at the end of each layer.\"\"\"\n", " def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):\n", " super().__init__()\n", " layers = []\n", " if use_noise:\n", " layers.append(('noise', NoiseLayer(channels)))\n", " layers.append(('activation', activation_layer))\n", " if use_pixel_norm:\n", " layers.append(('pixel_norm', PixelNormLayer()))\n", " if use_instance_norm:\n", " layers.append(('instance_norm', nn.InstanceNorm2d(channels)))\n", " self.top_epi = nn.Sequential(OrderedDict(layers))\n", " if use_styles:\n", " self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)\n", " else:\n", " self.style_mod = None\n", " def forward(self, x, dlatents_in_slice=None):\n", " x = self.top_epi(x)\n", " if self.style_mod is not None:\n", " x = self.style_mod(x, dlatents_in_slice)\n", " else:\n", " assert dlatents_in_slice is None\n", " return x\n", "\n", "\n", "class InputBlock(nn.Module):\n", " def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):\n", " super().__init__()\n", " self.const_input_layer = const_input_layer\n", " self.nf = nf\n", " if self.const_input_layer:\n", " # called 'const' in tf\n", " self.const = nn.Parameter(torch.ones(1, nf, 4, 4))\n", " self.bias = nn.Parameter(torch.ones(nf))\n", " else:\n", " self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressive GAN\n", " self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)\n", " self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)\n", " self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)\n", " \n", " def forward(self, dlatents_in_range):\n", " batch_size = dlatents_in_range.size(0)\n", " if self.const_input_layer:\n", " x = self.const.expand(batch_size, -1, -1, -1)\n", " x = x + self.bias.view(1, -1, 1, 1)\n", " else:\n", " x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)\n", " x = self.epi1(x, dlatents_in_range[:, 0])\n", " x = self.conv(x)\n", " x = self.epi2(x, dlatents_in_range[:, 1])\n", " return x\n", "\n", "\n", "class GSynthesisBlock(nn.Module):\n", " def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):\n", " # 2**res x 2**res # res = 3..resolution_log2\n", " super().__init__()\n", " if blur_filter:\n", " blur = BlurLayer(blur_filter)\n", " else:\n", " blur = None\n", " self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,\n", " intermediate=blur, upscale=True)\n", " self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, 
activation_layer)\n", " self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)\n", " self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)\n", " \n", " def forward(self, x, dlatents_in_range):\n", " x = self.conv0_up(x)\n", " x = self.epi1(x, dlatents_in_range[:, 0])\n", " x = self.conv1(x)\n", " x = self.epi2(x, dlatents_in_range[:, 1])\n", " return x\n", " " ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "class G_synthesis(nn.Module):\n", " def __init__(self,\n", " dlatent_size = 512, # Disentangled latent (W) dimensionality.\n", " num_channels = 3, # Number of output color channels.\n", " resolution = 1024, # Output resolution.\n", " fmap_base = 8192, # Overall multiplier for the number of feature maps.\n", " fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.\n", " fmap_max = 512, # Maximum number of feature maps in any layer.\n", " use_styles = True, # Enable style inputs?\n", " const_input_layer = True, # First layer is a learned constant?\n", " use_noise = True, # Enable noise inputs?\n", " randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.\n", " nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu'\n", " use_wscale = True, # Enable equalized learning rate?\n", " use_pixel_norm = False, # Enable pixelwise feature vector normalization?\n", " use_instance_norm = True, # Enable instance normalization?\n", " dtype = torch.float32, # Data type to use for activations and outputs.\n", " fused_scale = 'auto', # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.\n", " blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. 
None = no filtering.\n", " structure = 'auto', # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.\n", " is_template_graph = False, # True = template graph constructed by the Network class, False = actual evaluation.\n", " force_clean_graph = False, # True = construct a clean graph that looks nice in TensorBoard, False = default behavior.\n", " ):\n", " \n", " super().__init__()\n", " def nf(stage):\n", " return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n", " self.dlatent_size = dlatent_size\n", " resolution_log2 = int(np.log2(resolution))\n", " assert resolution == 2**resolution_log2 and resolution >= 4\n", " if is_template_graph: force_clean_graph = True\n", " if force_clean_graph: randomize_noise = False\n", " if structure == 'auto': structure = 'linear' if force_clean_graph else 'recursive'\n", "\n", " act, gain = {'relu': (torch.relu, np.sqrt(2)),\n", " 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]\n", " num_layers = resolution_log2 * 2 - 2\n", " num_styles = num_layers if use_styles else 1\n", " torgbs = []\n", " blocks = []\n", " for res in range(2, resolution_log2 + 1):\n", " channels = nf(res-1)\n", " name = '{s}x{s}'.format(s=2**res)\n", " if res == 2:\n", " blocks.append((name,\n", " InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,\n", " use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))\n", " \n", " else:\n", " blocks.append((name,\n", " GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))\n", " last_channels = channels\n", " self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)\n", " self.blocks = nn.ModuleDict(OrderedDict(blocks))\n", " \n", " def forward(self, dlatents_in):\n", " # Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].\n", " # lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)\n", " batch_size = dlatents_in.size(0) \n", " for i, m in enumerate(self.blocks.values()):\n", " if i == 0:\n", " x = m(dlatents_in[:, 2*i:2*i+2])\n", " else:\n", " x = m(x, dlatents_in[:, 2*i:2*i+2])\n", " rgb = self.torgb(x)\n", " return rgb" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## All done, let's define the model!" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "g_all = nn.Sequential(OrderedDict([\n", " ('g_mapping', G_mapping()),\n", " #('truncation', Truncation(avg_latent)),\n", " ('g_synthesis', G_synthesis()) \n", "]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### But we need weights. Can we use the pretrained ones?\n", "\n", "Yes, we can! The following can be used to convert them from the authors' format. We have already done this for you, and you can get the weights from \n", "[here](https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/karras2019stylegan-ffhq-1024x1024.for_g_all.pt).\n", "\n", "Note that the weights are taken from [the reference implementation](https://github.com/NVlabs/stylegan) distributed by NVIDIA Corporation, licensed under the CC BY-NC 4.0 license. As such, the same license applies here.\n", "\n", "For completeness, our conversion is below, but you can skip it if you download the PyTorch-ready weights." 
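] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One detail worth spelling out before the conversion: TensorFlow stores dense weights as [in, out] and conv kernels as [height, width, in_channels, out_channels], while PyTorch expects [out, in] and [out_channels, in_channels, height, width]. That is what `weight_translate` below undoes; a tiny sketch with hypothetical shapes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sketch: map a TF-layout conv kernel to the PyTorch layout\n", "w_tf = torch.zeros(3, 3, 512, 256) # hypothetical [kH, kW, inC, outC] kernel\n", "w_pt = w_tf.permute(3, 2, 0, 1) # -> [outC, inC, kH, kW]\n", "print(w_pt.shape) # torch.Size([256, 512, 3, 3])" 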
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "if 0:\n", " # this can be run to get the weights, but you need the reference implementation and weights\n", " import dnnlib, dnnlib.tflib, pickle, torch, collections\n", " dnnlib.tflib.init_tf()\n", " weights = pickle.load(open('./karras2019stylegan-ffhq-1024x1024.pkl','rb'))\n", " weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k,v in w.trainables.items()]) for w in weights]\n", " torch.save(weights_pt, './karras2019stylegan-ffhq-1024x1024.pt')\n", "if 1:\n", " # then on the PyTorch side run\n", " state_G, state_D, state_Gs = torch.load('../karras2019stylegan-ffhq-1024x1024.pt')\n", " def key_translate(k):\n", " k = k.lower().split('/')\n", " if k[0] == 'g_synthesis':\n", " if not k[1].startswith('torgb'):\n", " k.insert(1, 'blocks')\n", " k = '.'.join(k)\n", " k = (k.replace('const.const','const').replace('const.bias','bias').replace('const.stylemod','epi1.style_mod.lin')\n", " .replace('const.noise.weight','epi1.top_epi.noise.weight')\n", " .replace('conv.noise.weight','epi2.top_epi.noise.weight')\n", " .replace('conv.stylemod','epi2.style_mod.lin')\n", " .replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')\n", " .replace('conv0_up.stylemod','epi1.style_mod.lin')\n", " .replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')\n", " .replace('conv1.stylemod','epi2.style_mod.lin')\n", " .replace('torgb_lod0','torgb')\n", " .replace('fromrgb_lod0','fromrgb'))\n", " if 'torgb_lod' in k or 'fromrgb_lod' in k: # we don't want the lower layers to/from RGB\n", " k = None\n", " return k\n", "\n", " def weight_translate(k, w):\n", " k = key_translate(k)\n", " if k.endswith('.weight'):\n", " if w.dim() == 2:\n", " w = w.t()\n", " elif w.dim() == 1:\n", " pass\n", " else:\n", " assert w.dim() == 4\n", " w = w.permute(3, 2, 0, 1)\n", " return w\n", "\n", "if 0:\n", " param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_Gs.items() if key_translate(k) is not None}\n", " if 1:\n", " sd_shapes = {k : v.shape for k,v in g_all.state_dict().items()}\n", " param_shapes = {k : v.shape for k,v in param_dict.items() }\n", "\n", " for k in list(sd_shapes)+list(param_shapes):\n", " pds = param_shapes.get(k)\n", " sds = sd_shapes.get(k)\n", " if pds is None:\n", " print (\"sd only\", k, sds)\n", " elif sds is None:\n", " print (\"pd only\", k, pds)\n", " elif sds != pds:\n", " print (\"mismatch!\", k, pds, sds)\n", "\n", " g_all.load_state_dict(param_dict, strict=False) # needed for the blur kernels\n", " torch.save(g_all.state_dict(), './karras2019stylegan-ffhq-1024x1024.for_g_all.pt')\n", " \n", "if 0:\n", " param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_D.items() if key_translate(k) is not None}\n", " if 1:\n", " sd_shapes = {k : v.shape for k,v in d_basic.state_dict().items()}\n", " param_shapes = {k : v.shape for k,v in param_dict.items() }\n", "\n", " for k in list(sd_shapes)+list(param_shapes):\n", " pds = param_shapes.get(k)\n", " sds = sd_shapes.get(k)\n", " if pds is None:\n", " print (\"sd only\", k, sds)\n", " elif sds is None:\n", " print (\"pd only\", k, pds)\n", " elif sds != pds:\n", " print (\"mismatch!\", k, pds, sds)\n", "\n", " d_basic.load_state_dict(param_dict, strict=False) # needed for the blur kernels\n", " torch.save(d_basic.state_dict(), '../karras2019stylegan-ffhq-1024x1024.for_d_basic.pt')\n", " " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's load the weights." 
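] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before that, two asides in the spirit of the `if 0:` blocks above. First, the trick mentioned in `NoiseLayer`: collecting all noise layers and setting their `.noise` attributes makes the noise deterministic across calls. A minimal sketch, assuming the 18-layer 1024x1024 generator (two layers per resolution from 4x4 up):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if 0:\n", "    # freeze the per-layer noise so repeated forward passes give identical images\n", "    noise_layers = [m for m in g_all.modules() if isinstance(m, NoiseLayer)]\n", "    for i, m in enumerate(noise_layers):\n", "        res = 4 * 2 ** (i // 2) # resolutions 4, 4, 8, 8, ..., 1024, 1024\n", "        m.noise = torch.randn(1, 1, res, res, device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Second, `Truncation` is defined above but commented out in `g_all`. A minimal sketch of wiring it in once the weights are loaded; estimating `avg_latent` from random samples is only an approximation (the reference implementation tracks this average during training):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if 0:\n", "    # estimate the mean dlatent W and rebuild g_all with the truncation module\n", "    with torch.no_grad():\n", "        avg_latent = g_all.g_mapping(torch.randn(2048, 512, device=device))[:, 0].mean(0)\n", "    g_all = nn.Sequential(OrderedDict([\n", "        ('g_mapping', g_all.g_mapping),\n", "        ('truncation', Truncation(avg_latent)),\n", "        ('g_synthesis', g_all.g_synthesis)\n", "    ]))" 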
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "g_all.load_state_dict(torch.load('../karras2019stylegan-ffhq-1024x1024.for_g_all.pt', map_location=device))\n", "g_all.eval()\n", "g_all.to(device);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we're all set! Let's generate faces!" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "if 0:\n", " %matplotlib inline\n", " from matplotlib import pyplot\n", " import torchvision\n", " g_all.eval()\n", " g_all.to(device)\n", "\n", " torch.manual_seed(500)\n", " nb_rows = 2\n", " nb_cols = 8\n", " nb_samples = nb_rows * nb_cols\n", " latents = torch.randn(nb_samples, 512, device=device)\n", " with torch.no_grad():\n", " imgs = g_all(latents)\n", " imgs = (imgs.clamp(-1, 1) + 1) / 2.0 # normalization to 0..1 range\n", " imgs = imgs.cpu()\n", "\n", " imgs = torchvision.utils.make_grid(imgs, nrow=nb_cols)\n", "\n", " pyplot.figure(figsize=(15, 6))\n", " pyplot.imshow(imgs.permute(1, 2, 0).detach().numpy())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Discriminator" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [], "source": [ "class StddevLayer(nn.Module):\n", " def __init__(self, group_size=4, num_new_features=1):\n", " super().__init__()\n", " self.group_size = group_size\n", " self.num_new_features = num_new_features\n", " def forward(self, x):\n", " b, c, h, w = x.shape\n", " group_size = min(self.group_size, b)\n", " y = x.reshape([group_size, -1, self.num_new_features,\n", " c // self.num_new_features, h, w])\n", " y = y - y.mean(0, keepdim=True)\n", " y = (y**2).mean(0, keepdim=True)\n", " y = (y + 1e-8)**0.5\n", " y = y.mean([3, 4, 5], keepdim=True).squeeze(3) # don't keep the meaned-out channels\n", " y = y.expand(group_size, -1, -1, h, w).clone().reshape(b, self.num_new_features, h, w)\n", " z = torch.cat([x, y], dim=1)\n", " return z\n", "\n", "class Downscale2d(nn.Module):\n", " def __init__(self, factor=2, gain=1):\n", " super().__init__()\n", " assert isinstance(factor, int) and factor >= 1\n", " self.factor = factor\n", " self.gain = gain\n", " if factor == 2:\n", " f = [np.sqrt(gain) / factor] * factor\n", " self.blur = BlurLayer(kernel=f, normalize=False, stride=factor)\n", " else:\n", " self.blur = None\n", "\n", " def forward(self, x):\n", " assert x.dim()==4\n", " # 2x2, float32 => downscale using the blur kernel.\n", " if self.blur is not None and x.dtype == torch.float32:\n", " return self.blur(x)\n", "\n", " # Apply gain.\n", " if self.gain != 1:\n", " x = x * self.gain\n", "\n", " # No-op => early exit.\n", " if self.factor == 1:\n", " return x\n", "\n", " # Large factor => downscale using average pooling.\n", " return F.avg_pool2d(x, self.factor)\n", "\n", "class DiscriminatorBlock(nn.Sequential):\n", " def __init__(self, in_channels, out_channels, gain, use_wscale, activation_layer):\n", " super().__init__(OrderedDict([\n", " ('conv0', MyConv2d(in_channels, in_channels, 3, gain=gain, use_wscale=use_wscale)), # out channels nf(res-1)\n", " ('act0', activation_layer),\n", " ('blur', BlurLayer()),\n", " ('conv1_down', MyConv2d(in_channels, out_channels, 3, gain=gain, use_wscale=use_wscale, downscale=True)),\n", " ('act1', activation_layer)]))\n", "\n", "class View(nn.Module):\n", " def __init__(self, *shape):\n", " super().__init__()\n", " self.shape = shape\n", " def forward(self, x):\n", " return 
x.view(x.size(0), *self.shape)\n", "\n", "class DiscriminatorTop(nn.Sequential):\n", " def __init__(self, mbstd_group_size, mbstd_num_features, in_channels, intermediate_channels, gain, use_wscale, activation_layer, resolution=4, in_channels2=None, output_features=1, last_gain=1):\n", " layers = []\n", " if mbstd_group_size > 1:\n", " layers.append(('stddev_layer', StddevLayer(mbstd_group_size, mbstd_num_features)))\n", " if in_channels2 is None:\n", " in_channels2 = in_channels\n", " layers.append(('conv', MyConv2d(in_channels + mbstd_num_features, in_channels2, 3, gain=gain, use_wscale=use_wscale)))\n", " layers.append(('act0', activation_layer))\n", " layers.append(('view', View(-1)))\n", " layers.append(('dense0', MyLinear(in_channels2*resolution*resolution, intermediate_channels, gain=gain, use_wscale=use_wscale)))\n", " layers.append(('act1', activation_layer))\n", " layers.append(('dense1', MyLinear(intermediate_channels, output_features, gain=last_gain, use_wscale=use_wscale)))\n", " super().__init__(OrderedDict(layers))\n", " " ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "class D_basic(nn.Sequential):\n", " \n", " def __init__(self,\n", " #images_in, # First input: Images [minibatch, channel, height, width].\n", " #labels_in, # Second input: Labels [minibatch, label_size].\n", " num_channels = 3, # Number of input color channels. Overridden based on dataset.\n", " resolution = 1024, # Input resolution. Overridden based on dataset.\n", " fmap_base = 8192, # Overall multiplier for the number of feature maps.\n", " fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.\n", " fmap_max = 512, # Maximum number of feature maps in any layer.\n", " nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu',\n", " use_wscale = True, # Enable equalized learning rate?\n", " mbstd_group_size = 4, # Group size for the minibatch standard deviation layer, 0 = disable.\n", " mbstd_num_features = 1, # Number of features for the minibatch standard deviation layer.\n", " #blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. 
None = no filtering.\n", " ):\n", " self.mbstd_group_size = mbstd_group_size\n", " self.mbstd_num_features = mbstd_num_features\n", " resolution_log2 = int(np.log2(resolution))\n", " assert resolution == 2**resolution_log2 and resolution >= 4\n", " def nf(stage):\n", " return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)\n", "\n", " act, gain = {'relu': (torch.relu, np.sqrt(2)),\n", " 'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]\n", " self.gain = gain\n", " self.use_wscale = use_wscale\n", " super().__init__(OrderedDict([\n", " ('fromrgb', MyConv2d(num_channels, nf(resolution_log2-1), 1, gain=gain, use_wscale=use_wscale)),\n", " ('act', act)]\n", " +[('{s}x{s}'.format(s=2**res), DiscriminatorBlock(nf(res-1), nf(res-2), gain=gain, use_wscale=use_wscale, activation_layer=act)) for res in range(resolution_log2, 2, -1)]\n", " +[('4x4', DiscriminatorTop(mbstd_group_size, mbstd_num_features, nf(2), nf(2), gain=gain, use_wscale=use_wscale, activation_layer=act))]))\n", " \n", " " ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "if 1:\n", " d_basic = D_basic()\n", " d_basic.load_state_dict(torch.load('../karras2019stylegan-ffhq-1024x1024.for_d_basic.pt', map_location=device))\n", " d_basic.to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3rc1" } }, "nbformat": 4, "nbformat_minor": 2 }