{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CycleGAN\n",
    "\n",
    "Defines a CycleGAN (two resnet generators and two discriminators), its loss and its training callback with fastai, then trains it on the images in `data/horse2zebra`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai import *\n",
    "from fastai.vision import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generators\n",
    "\n",
    "Building blocks and constructor for the resnet generators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convT_norm_relu(ch_in:int, ch_out:int, norm_layer:nn.Module, ks:int=3, stride:int=2, bias:bool=True):\n",
    "    \"Return `[ConvTranspose2d (padding=1, output_padding=1), norm_layer(ch_out), ReLU(inplace)]`.\"\n",
    "    return [nn.ConvTranspose2d(ch_in, ch_out, kernel_size=ks, stride=stride, padding=1, output_padding=1, bias=bias),\n",
    "            norm_layer(ch_out), nn.ReLU(True)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_conv_norm_relu(ch_in:int, ch_out:int, pad_mode:str, norm_layer:nn.Module, ks:int=3, bias:bool=True,\n",
    "                       pad=1, stride:int=1, activ:bool=True, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n",
    "    \"\"\"Return the layers of a padded conv block: optional padding layer, conv, `norm_layer(ch_out)`, optional ReLU.\n",
    "\n",
    "    `pad_mode` 'reflection' / 'border' prepend a `ReflectionPad2d` / `ReplicationPad2d` of size `pad`;\n",
    "    'zeros' passes `pad` to the conv itself. When `init` is truthy it is applied to the conv weight\n",
    "    and the conv bias (when there is one) is filled with zeros.\"\"\"\n",
    "    layers = []\n",
    "    if pad_mode == 'reflection': layers.append(nn.ReflectionPad2d(pad))\n",
    "    elif pad_mode == 'border':   layers.append(nn.ReplicationPad2d(pad))\n",
    "    # Explicit padding layers already pad the input, so the conv only pads in 'zeros' mode.\n",
    "    p = pad if pad_mode == 'zeros' else 0\n",
    "    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=p, stride=stride, bias=bias)\n",
    "    if init:\n",
    "        init(conv.weight)\n",
    "        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)\n",
    "    layers += [conv, norm_layer(ch_out)]\n",
    "    if activ: layers.append(nn.ReLU(inplace=True))\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResnetBlock(nn.Module):\n",
    "    \"Residual block: two `pad_conv_norm_relu` blocks (optional dropout in between) whose output is added to the input.\"\n",
    "    def __init__(self, dim:int, pad_mode:str='reflection', norm_layer:nn.Module=None, dropout:float=0., bias:bool=True):\n",
    "        super().__init__()\n",
    "        assert pad_mode in ['zeros', 'reflection', 'border'], f'padding {pad_mode} not implemented.'\n",
    "        norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "        layers = pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias)\n",
    "        if dropout != 0: layers.append(nn.Dropout(dropout))\n",
    "        # No activation after the second conv block.\n",
    "        layers += pad_conv_norm_relu(dim, dim, pad_mode, norm_layer, bias=bias, activ=False)\n",
    "        self.conv_block = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x): return x + self.conv_block(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def resnet_generator(ch_in:int, ch_out:int, n_ftrs:int=64, norm_layer:nn.Module=None,\n",
    "                     dropout:float=0., n_blocks:int=6, pad_mode:str='reflection')->nn.Module:\n",
    "    \"\"\"Build a generator as an `nn.Sequential`.\n",
    "\n",
    "    Layout: 7x7 reflection-padded conv, two stride-2 convs (each doubling `n_ftrs`), `n_blocks` `ResnetBlock`s,\n",
    "    two transposed convs (each halving `n_ftrs`), then a reflection-padded 7x7 conv to `ch_out` followed by `Tanh`.\"\"\"\n",
    "    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "    # Convs only get a bias when the norm layer is `nn.InstanceNorm2d`.\n",
    "    bias = (norm_layer == nn.InstanceNorm2d)\n",
    "    layers = pad_conv_norm_relu(ch_in, n_ftrs, 'reflection', norm_layer, pad=3, ks=7, bias=bias)\n",
    "    for i in range(2):\n",
    "        layers += pad_conv_norm_relu(n_ftrs, n_ftrs *2, 'zeros', norm_layer, stride=2, bias=bias)\n",
    "        n_ftrs *= 2\n",
    "    layers += [ResnetBlock(n_ftrs, pad_mode, norm_layer, dropout, bias) for _ in range(n_blocks)]\n",
    "    for i in range(2):\n",
    "        layers += convT_norm_relu(n_ftrs, n_ftrs//2, norm_layer, bias=bias)\n",
    "        n_ftrs //= 2\n",
    "    layers += [nn.ReflectionPad2d(3), nn.Conv2d(n_ftrs, ch_out, kernel_size=7, padding=0), nn.Tanh()]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = resnet_generator(3, 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Discriminators"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv_norm_lr(ch_in:int, ch_out:int, norm_layer:nn.Module=None, ks:int=3, bias:bool=True, pad:int=1, stride:int=1,\n",
    "                 activ:bool=True, slope:float=0.2, init:Callable=nn.init.kaiming_normal_)->List[nn.Module]:\n",
    "    \"\"\"Return `[conv, norm_layer(ch_out) (if `norm_layer` is given), LeakyReLU(slope) (if `activ`)]`.\n",
    "\n",
    "    When `init` is truthy it is applied to the conv weight and the conv bias (when there is one) is filled with zeros.\"\"\"\n",
    "    conv = nn.Conv2d(ch_in, ch_out, kernel_size=ks, padding=pad, stride=stride, bias=bias)\n",
    "    if init:\n",
    "        init(conv.weight)\n",
    "        if hasattr(conv, 'bias') and hasattr(conv.bias, 'data'): conv.bias.data.fill_(0.)\n",
    "    layers = [conv]\n",
    "    if norm_layer is not None: layers.append(norm_layer(ch_out))\n",
    "    if activ: layers.append(nn.LeakyReLU(slope, inplace=True))\n",
    "    return layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def discriminator(ch_in:int, n_ftrs:int=64, n_layers:int=3, norm_layer:nn.Module=None, sigmoid:bool=False)->nn.Module:\n",
    "    \"\"\"Build a discriminator as an `nn.Sequential` of 4x4 convs ending in a 1-channel conv.\n",
    "\n",
    "    The first conv has no norm layer; it is followed by `n_layers-1` stride-2 blocks, one stride-1 block and\n",
    "    the final 1-channel conv. A `Sigmoid` is appended when `sigmoid` is True.\"\"\"\n",
    "    norm_layer = ifnone(norm_layer, nn.InstanceNorm2d)\n",
    "    # Convs only get a bias when the norm layer is `nn.InstanceNorm2d`.\n",
    "    bias = (norm_layer == nn.InstanceNorm2d)\n",
    "    layers = conv_norm_lr(ch_in, n_ftrs, ks=4, stride=2, pad=1)\n",
    "    for i in range(n_layers-1):\n",
    "        # The width doubles while i <= 3 and stays constant afterwards.\n",
    "        new_ftrs = 2*n_ftrs if i <= 3 else n_ftrs\n",
    "        layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=2, pad=1, bias=bias)\n",
    "        n_ftrs = new_ftrs\n",
    "    new_ftrs = 2*n_ftrs if n_layers <=3 else n_ftrs\n",
    "    layers += conv_norm_lr(n_ftrs, new_ftrs, norm_layer, ks=4, stride=1, pad=1, bias=bias)\n",
    "    layers.append(nn.Conv2d(new_ftrs, 1, kernel_size=4, stride=1, padding=1))\n",
    "    if sigmoid: layers.append(nn.Sigmoid())\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator(3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CycleGAN module, loss and training callback"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGAN(nn.Module):\n",
    "    \"Holds the two generators (`G_A`, `G_B`) and the two discriminators (`D_A`, `D_B`).\"\n",
    "\n",
    "    def __init__(self, ch_in:int, ch_out:int, n_features:int=64, disc_layers:int=3, gen_blocks:int=6, lsgan:bool=True,\n",
    "                 drop:float=0., norm_layer:nn.Module=None):\n",
    "        super().__init__()\n",
    "        # The discriminators end with a sigmoid only when `lsgan` is False.\n",
    "        self.D_A = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)\n",
    "        self.D_B = discriminator(ch_in, n_features, disc_layers, norm_layer, sigmoid=not lsgan)\n",
    "        self.G_A = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n",
    "        self.G_B = resnet_generator(ch_in, ch_out, n_features, norm_layer, drop, gen_blocks)\n",
    "        #G_A: takes real input B and generates fake input A\n",
    "        #G_B: takes real input A and generates fake input B\n",
    "        #D_A: trained to make the difference between real input A and fake input A\n",
    "        #D_B: trained to make the difference between real input B and fake input B\n",
    "\n",
    "    def forward(self, real_A, real_B):\n",
    "        \"\"\"Return `[fake_A, fake_B, idt_A, idt_B]` in training mode.\n",
    "\n",
    "        Outside training mode, return `fake_A` and `fake_B` concatenated along a new dimension 1.\"\"\"\n",
    "        fake_A, fake_B = self.G_A(real_B), self.G_B(real_A)\n",
    "        if not self.training: return torch.cat([fake_A[:,None],fake_B[:,None]], 1)\n",
    "        # Identity outputs: each generator applied to an input of the kind it generates.\n",
    "        idt_A, idt_B = self.G_A(real_A), self.G_B(real_B)\n",
    "        return [fake_A, fake_B, idt_A, idt_B]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AdaptiveLoss(nn.Module):\n",
    "    \"Wrap `crit` so it can be called with a boolean target: all ones when True, all zeros when False.\"\n",
    "    def __init__(self, crit):\n",
    "        super().__init__()\n",
    "        self.crit = crit\n",
    "\n",
    "    def forward(self, output, target:bool):\n",
    "        # Build a target with the same size (and, via `new_ones`/`new_zeros`, dtype and device) as `output`.\n",
    "        targ = output.new_ones(*output.size()) if target else output.new_zeros(*output.size())\n",
    "        return self.crit(output, targ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CycleGanLoss(nn.Module):\n",
    "    \"\"\"Generator loss for `cgan`: identity loss + adversarial loss + cycle loss.\n",
    "\n",
    "    `set_input` must be called with the `(real_A, real_B)` batch before `forward`.\"\"\"\n",
    "\n",
    "    def __init__(self, cgan:nn.Module, lambda_A:float=10., lambda_B:float=10, lambda_idt:float=0.5, lsgan:bool=True):\n",
    "        super().__init__()\n",
    "        self.cgan,self.l_A,self.l_B,self.l_idt = cgan,lambda_A,lambda_B,lambda_idt\n",
    "        self.crit = AdaptiveLoss(F.mse_loss if lsgan else F.binary_cross_entropy)\n",
    "\n",
    "    def set_input(self, input):\n",
    "        \"Store the real images of the current batch.\"\n",
    "        self.real_A,self.real_B = input\n",
    "\n",
    "    def forward(self, output, target):\n",
    "        \"Compute the generator loss from the training-mode `output` of `CycleGAN.forward`; `target` is unused.\"\n",
    "        fake_A, fake_B, idt_A, idt_B = output\n",
    "        #Generators should return identity on the datasets they try to convert to\n",
    "        #idt_A = G_A(real_A) and idt_B = G_B(real_B), so each is compared with its own input.\n",
    "        loss = self.l_idt * (self.l_A * F.l1_loss(idt_A, self.real_A) + self.l_B * F.l1_loss(idt_B, self.real_B))\n",
    "        #Generators are trained to trick the discriminators so the following should be ones\n",
    "        loss += self.crit(self.cgan.D_A(fake_A), True) + self.crit(self.cgan.D_B(fake_B), True)\n",
    "        #Cycle loss\n",
    "        loss += self.l_A * F.l1_loss(self.cgan.G_A(fake_B), self.real_A)\n",
    "        loss += self.l_B * F.l1_loss(self.cgan.G_B(fake_A), self.real_B)\n",
    "        return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class CycleGANTrainer(LearnerCallback):\n",
    "    \"Callback that trains the two discriminators at the end of every batch.\"\n",
    "\n",
    "    def _set_trainable(self, D_A=False, D_B=False):\n",
    "        \"\"\"Set `requires_grad` on the four sub-models.\n",
    "\n",
    "        The generators are trainable only when neither discriminator is selected. When a discriminator is\n",
    "        selected, lr/mom/wd/beta are copied from the learner's optimizer to both discriminator optimizers.\"\"\"\n",
    "        gen = (not D_A) and (not D_B)\n",
    "        requires_grad(self.learn.model.G_A, gen)\n",
    "        requires_grad(self.learn.model.G_B, gen)\n",
    "        requires_grad(self.learn.model.D_A, D_A)\n",
    "        requires_grad(self.learn.model.D_B, D_B)\n",
    "        if not gen:\n",
    "            self.opt_D_A.lr, self.opt_D_A.mom = self.learn.opt.lr, self.learn.opt.mom\n",
    "            self.opt_D_A.wd, self.opt_D_A.beta = self.learn.opt.wd, self.learn.opt.beta\n",
    "            self.opt_D_B.lr, self.opt_D_B.mom = self.learn.opt.lr, self.learn.opt.mom\n",
    "            self.opt_D_B.wd, self.opt_D_B.beta = self.learn.opt.wd, self.learn.opt.beta\n",
    "\n",
    "    def on_train_begin(self, **kwargs):\n",
    "        \"Create one optimizer for both generators and one per discriminator; the learner then steps the generator one.\"\n",
    "        self.G_A,self.G_B = self.learn.model.G_A,self.learn.model.G_B\n",
    "        self.D_A,self.D_B = self.learn.model.D_A,self.learn.model.D_B\n",
    "        self.crit = self.learn.loss_func.crit\n",
    "        self.opt_G = self.learn.opt.new([nn.Sequential(*flatten_model(self.G_A), *flatten_model(self.G_B))])\n",
    "        self.opt_D_A = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_A))])\n",
    "        self.opt_D_B = self.learn.opt.new([nn.Sequential(*flatten_model(self.D_B))])\n",
    "        self.learn.opt.opt = self.opt_G.opt\n",
    "        self._set_trainable()\n",
    "\n",
    "    def on_batch_begin(self, last_input, **kwargs):\n",
    "        \"Pass the real images of the batch to the loss function.\"\n",
    "        self.learn.loss_func.set_input(last_input)\n",
    "\n",
    "    def on_batch_end(self, last_input, last_output, **kwargs):\n",
    "        \"Take one optimizer step for each discriminator on the real inputs and the detached generator outputs.\"\n",
    "        self.G_A.zero_grad(); self.G_B.zero_grad()\n",
    "        # Detach so the discriminator losses do not backpropagate into the generators.\n",
    "        fake_A, fake_B = last_output[0].detach(), last_output[1].detach()\n",
    "        real_A, real_B = last_input\n",
    "        self._set_trainable(D_A=True)\n",
    "        self.D_A.zero_grad()\n",
    "        loss_D_A = 0.5 * (self.crit(self.D_A(real_A), True) + self.crit(self.D_A(fake_A), False))\n",
    "        loss_D_A.backward()\n",
    "        self.opt_D_A.step()\n",
    "        self._set_trainable(D_B=True)\n",
    "        self.D_B.zero_grad()\n",
    "        loss_D_B = 0.5 * (self.crit(self.D_B(real_B), True) + self.crit(self.D_B(fake_B), False))\n",
    "        loss_D_B.backward()\n",
    "        self.opt_D_B.step()\n",
    "        # Back to generator training for the next batch.\n",
    "        self._set_trainable()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = Path('data/horse2zebra')\n",
    "path.ls()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageTuple(ItemBase):\n",
    "    \"Pair of images; `data` holds both image tensors mapped through `-1+2*x`.\"\n",
    "    def __init__(self, img1, img2):\n",
    "        self.img1,self.img2 = img1,img2\n",
    "        self.obj,self.data = (img1,img2),[-1+2*img1.data,-1+2*img2.data]\n",
    "\n",
    "    def apply_tfms(self, tfms, **kwargs):\n",
    "        \"Apply `tfms` to both images and rebuild `obj` and `data` from the results.\"\n",
    "        self.img1 = self.img1.apply_tfms(tfms, **kwargs)\n",
    "        self.img2 = self.img2.apply_tfms(tfms, **kwargs)\n",
    "        # `obj`/`data` were computed in `__init__`; refresh them so they reflect the transformed images.\n",
    "        self.obj,self.data = (self.img1,self.img2),[-1+2*self.img1.data,-1+2*self.img2.data]\n",
    "        return self\n",
    "\n",
    "    def to_one(self):\n",
    "        \"Return a single `Image` with both tensors concatenated along dimension 2, mapped back through `0.5+x/2`.\"\n",
    "        return Image(0.5+torch.cat(self.data,2)/2)\n",
    "\n",
    "    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):\n",
    "        \"Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method.\"\n",
    "        rows = int(math.sqrt(len(xs)))\n",
    "        fig, axs = plt.subplots(rows,rows,figsize=figsize)\n",
    "        for i, ax in enumerate(axs.flatten() if rows > 1 else [axs]):\n",
    "            xs[i].to_one().show(ax=ax, **kwargs)\n",
    "        plt.tight_layout()\n",
    "\n",
    "    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):\n",
    "        \"\"\"Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.\n",
    "        `kwargs` are passed to the show method.\"\"\"\n",
    "        figsize = ifnone(figsize, (12,3*len(xs)))\n",
    "        fig,axs = plt.subplots(len(xs), 2, figsize=figsize)\n",
    "        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)\n",
    "        for i,(x,z) in enumerate(zip(xs,zs)):\n",
    "            x.to_one().show(ax=axs[i,0], **kwargs)\n",
    "            z.to_one().show(ax=axs[i,1], **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defined before `ImageTupleList`, whose class body references it through `_label_cls`.\n",
    "class TargetTupleList(ItemList):\n",
    "    \"Label list whose `reconstruct` rebuilds an `ImageTuple` (0-dim tensors are returned unchanged).\"\n",
    "    def reconstruct(self, t:Tensor):\n",
    "        if len(t.size()) == 0: return t\n",
    "        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageTupleList(ImageItemList):\n",
    "    \"Item list that pairs each of its images with a randomly drawn image from `itemsB`.\"\n",
    "    _label_cls=TargetTupleList\n",
    "    def __init__(self, items, itemsB=None, **kwargs):\n",
    "        self.itemsB = itemsB\n",
    "        super().__init__(items, **kwargs)\n",
    "\n",
    "    def new(self, items, **kwargs):\n",
    "        \"Create a new list that keeps the same `itemsB`.\"\n",
    "        return super().new(items, itemsB=self.itemsB, **kwargs)\n",
    "\n",
    "    def get(self, i):\n",
    "        \"Return item `i` paired with an image opened from a random entry of `itemsB`.\"\n",
    "        img1 = super().get(i)\n",
    "        fn = self.itemsB[random.randint(0, len(self.itemsB)-1)]\n",
    "        return ImageTuple(img1, open_image(fn))\n",
    "\n",
    "    def reconstruct(self, t:Tensor):\n",
    "        \"Rebuild an `ImageTuple` from `t`, mapping both tensors through `x/2+0.5`.\"\n",
    "        return ImageTuple(Image(t[0]/2+0.5),Image(t[1]/2+0.5))\n",
    "\n",
    "    @classmethod\n",
    "    def from_folders(cls, path, folderA, folderB, **kwargs):\n",
    "        \"Create the list from the images in `path/folderA`, with `itemsB` taken from `path/folderB`.\"\n",
    "        itemsB = ImageItemList.from_folder(path/folderB).items\n",
    "        res = super().from_folder(path/folderA, itemsB=itemsB, **kwargs)\n",
    "        res.path = path\n",
    "        return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "itemsB = ImageItemList.from_folder(path/'trainB').items\n",
    "tst = (ImageTupleList.from_folder(path/'trainA', itemsB=itemsB)\n",
    "                     .split_by_idx([])\n",
    "                     .label_const(ImageTuple(Image(torch.ones(1,1,1)),Image(torch.ones(1,1,1)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = tst.transform(get_transforms(), size=128).databunch(bs=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = (ImageTupleList.from_folders(path, 'trainA', 'trainB')\n",
    "                      .split_by_idx([])\n",
    "                      .label_const(0.)\n",
    "                      .transform(get_transforms(), size=128)\n",
    "                      .databunch(bs=4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data.show_batch(rows=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cycle_gan = CycleGAN(3,3, gen_blocks=9)\n",
    "learn = Learner(data, cycle_gan, loss_func=CycleGanLoss(cycle_gan), opt_func=partial(optim.Adam, betas=(0.5,0.99)),\n",
    "                callback_fns=[CycleGANTrainer])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.recorder.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#learn.load('20epoch')\n",
    "learn.fit_one_cycle(100,8e-4,moms=(0.5,0.5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('100epoch')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the checkpoint written by `learn.save('100epoch')` above.\n",
    "learn.load('100epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.show_results(DatasetType.Train)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}