{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LSun data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "PATH = Path('../data/lsun')\n", "IMG_PATH = PATH/'bedroom'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def create_csv_file(sample=False):\n", " files = PATH.glob('bedroom/**/*.jpg')\n", " with (PATH/'files.csv').open('w') as fo:\n", " for f in files: \n", " if not sample or random.random() < 0.1: fo.write(f'{f},0\\n')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#create_csv_file(sample=False)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(PATH/'files.csv', header=None)\n", "fns, ys = np.array(df[0]), np.array(df[1])\n", "train_ds = ImageDataset(fns, ys)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "size = 64\n", "train_tds = DatasetTfm(train_ds, tfms = [crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "norm, denorm = normalize_funcs(mean = torch.tensor([0.5,0.5,0.5]), std = torch.tensor([0.5,0.5,0.5]))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = DataBunch.create(train_tds, valid_ds=None, path=PATH, bs=128, tfms=[norm])\n", "data.valid_dl = None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def conv_layer1(ni:int, nf:int, ks:int=3, stride:int=1, padding:int=None, bias:bool=False, bn:bool=True, \n", " leaky:bool=False, slope:float=0.1, transpose:bool=False):\n", " if padding is None: padding = (ks-1)//2 if not transpose else 0\n", " conv_func = nn.ConvTranspose2d if transpose else nn.Conv2d\n", " activ = nn.LeakyReLU(inplace=True, negative_slope=slope) if leaky else nn.ReLU(inplace=True) \n", " layers = [conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), activ]\n", " if bn: layers.append(nn.BatchNorm2d(nf))\n", " return nn.Sequential(*layers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def AvgFlatten():\n", " return Lambda(lambda x: x.mean(0).view(1))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def discriminator(in_size, n_channels, n_features, n_extra_layers=0):\n", " layers = [conv_layer1(n_channels, n_features, 4, 2, 1, bn=False, leaky=True, slope=0.2)]\n", " cur_size, cur_ftrs = in_size//2, n_features\n", " layers.append(nn.Sequential(*[conv_layer1(cur_ftrs, cur_ftrs, 3, 1, leaky=True, slope=0.2) for _ in range(n_extra_layers)]))\n", " while cur_size > 4:\n", " layers.append(conv_layer1(cur_ftrs, cur_ftrs*2, 4, 2, 1, leaky=True, slope=0.2))\n", " cur_ftrs *= 2 ; cur_size //= 2\n", " layers += [conv2d(cur_ftrs, 1, 4, padding=0), AvgFlatten()]\n", " return nn.Sequential(*layers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def 
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generator(in_size, noise_sz, n_channels, n_features, n_extra_layers=0):\n",
    "    \"DCGAN-style generator: transposed convs upsample a noise vector from 1x1 to an in_size x in_size image.\"\n",
    "    cur_size, cur_ftrs = 4, n_features//2\n",
    "    while cur_size < in_size: cur_size *= 2; cur_ftrs *= 2\n",
    "    layers = [conv_layer1(noise_sz, cur_ftrs, 4, 1, transpose=True)]\n",
    "    cur_size = 4\n",
    "    while cur_size < in_size // 2:\n",
    "        layers.append(conv_layer1(cur_ftrs, cur_ftrs//2, 4, 2, 1, transpose=True))\n",
    "        cur_ftrs //= 2; cur_size *= 2\n",
    "    layers += [conv_layer1(cur_ftrs, cur_ftrs, 3, 1, 1, transpose=True) for _ in range(n_extra_layers)]\n",
    "    layers += [conv2d_trans(cur_ftrs, n_channels, 4, 2, 1, bias=False), nn.Tanh()]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator(64, 100, 3, 64, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator(64, 3, 64, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicGAN(nn.Module):\n",
    "    \"Container holding both networks; forward dispatches to the generator or the critic.\"\n",
    "    def __init__(self, in_size, noise_sz, n_channels, n_features, n_extra_layers=0):\n",
    "        super().__init__()\n",
    "        self.discriminator = discriminator(in_size, n_channels, n_features, n_extra_layers)\n",
    "        self.generator = generator(in_size, noise_sz, n_channels, n_features, n_extra_layers)\n",
    "\n",
    "    def forward(self, x, gen=False):\n",
    "        return self.generator(x) if gen else self.discriminator(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def first_disc_iter(gen_iter):\n",
    "    \"WGAN warm-up schedule: 100 critic steps for the first 25 generator steps (and every 500th), else 5.\"\n",
    "    return 100 if (gen_iter < 25 or gen_iter%500 == 0) else 5\n",
    "\n",
    "def standard_disc_iter(gen_iter):\n",
    "    return 100 if gen_iter%500 == 0 else 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise_sz = 100\n",
    "# Noise is shaped (b, noise_sz, 1, 1): a 1x1 'image' with noise_sz channels for the first transposed conv.\n",
    "def create_noise(x, b, grad=True): return x.new(b, noise_sz, 1, 1).normal_(0, 1).requires_grad_(grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class WassersteinLoss(nn.Module):\n",
    "    \"Critic loss: difference between the mean score on real images and the mean score on fakes.\"\n",
    "    def forward(self, real, fake): return real[0] - fake[0]"
   ]
  },
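  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The original WGAN enforces the critic's Lipschitz constraint by clipping every weight into a small interval after each update, which is what `GANTrainer` below does with `clamp_` in `on_batch_begin`. The cell below is a toy illustration of that operation (hedged: made-up values; the actual interval comes from the `clip` field of `GANTrainer`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of WGAN weight clipping (assumption: made-up values; clip=0.01 as in GANTrainer).\n",
    "w = torch.tensor([-0.5, 0.003, 0.2])\n",
    "w.clamp_(-0.01, 0.01)  # in-place clamp, applied to every critic parameter before each batch\n",
    "w  # -> tensor([-0.0100,  0.0030,  0.0100])"
   ]
  },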
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class GANTrainer(LearnerCallback):\n",
    "    loss_fn:LossFunction = WassersteinLoss()\n",
    "    n_disc_iter:Callable = standard_disc_iter\n",
    "    clip:float = 0.01\n",
    "    bs:int = 64\n",
    "\n",
    "    def _set_trainable(self, gen=False):\n",
    "        requires_grad(self.learn.model.generator, gen)\n",
    "        requires_grad(self.learn.model.discriminator, not gen)\n",
    "        if gen:\n",
    "            self.opt_gen.lr, self.opt_gen.mom = self.learn.opt.lr, self.learn.opt.mom\n",
    "            self.opt_gen.wd, self.opt_gen.beta = self.learn.opt.wd, self.learn.opt.beta\n",
    "\n",
    "    def on_train_begin(self, **kwargs):\n",
    "        opt_fn = self.learn.opt_fn\n",
    "        lr, wd, true_wd, bn_wd = self.learn.opt.lr, self.learn.opt.wd, self.learn.opt.true_wd, self.learn.opt.bn_wd\n",
    "        self.opt_gen = OptimWrapper.create(opt_fn, lr,\n",
    "                                           [nn.Sequential(*flatten_model(self.learn.model.generator))],\n",
    "                                           wd=wd, true_wd=true_wd, bn_wd=bn_wd)\n",
    "        self.opt_disc = OptimWrapper.create(opt_fn, lr,\n",
    "                                            [nn.Sequential(*flatten_model(self.learn.model.discriminator))],\n",
    "                                            wd=wd, true_wd=true_wd, bn_wd=bn_wd)\n",
    "        self.learn.opt.opt = self.opt_disc.opt\n",
    "        self.disc_iters, self.gen_iters = 0, 0\n",
    "        self._set_trainable()\n",
    "        self.dlosses,self.glosses = [],[]\n",
    "\n",
    "    def on_batch_begin(self, **kwargs):\n",
    "        # Enforce the Lipschitz constraint: clip every critic weight into [-clip, clip].\n",
    "        for p in self.learn.model.discriminator.parameters():\n",
    "            p.data.clamp_(-self.clip, self.clip)\n",
    "\n",
    "    def on_backward_begin(self, last_output, last_input, **kwargs):\n",
    "        # Replace the loss: score a fresh fake batch and take the Wasserstein critic loss.\n",
    "        fake = self.learn.model(create_noise(last_input, last_input.size(0), False), gen=True)\n",
    "        fake.requires_grad_(True)\n",
    "        loss = self.loss_fn(last_output, self.learn.model(fake))\n",
    "        self.dlosses.append(loss.detach().cpu())\n",
    "        return loss\n",
    "\n",
    "    def on_batch_end(self, last_input, **kwargs):\n",
    "        self.disc_iters += 1\n",
    "        if self.disc_iters == self.n_disc_iter(self.gen_iters):\n",
    "            # After n_disc_iter critic steps, take one generator step.\n",
    "            self.disc_iters = 0\n",
    "            self._set_trainable(True)\n",
    "            loss = self.learn.model(self.learn.model(create_noise(last_input,self.bs), gen=True)).mean().view(1)[0]\n",
    "            self.glosses.append(loss.detach().cpu())\n",
    "            self.learn.model.generator.zero_grad()\n",
    "            loss.backward()\n",
    "            self.opt_gen.step()\n",
    "            self.gen_iters += 1\n",
    "            self._set_trainable()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NoopLoss(nn.Module):\n",
    "    \"Placeholder loss for the Learner: the real losses are computed in GANTrainer.\"\n",
    "    def forward(self, output, target): return output[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wgan = BasicGAN(64, 100, 3, 64, 1)\n",
    "learn = Learner(data, wgan, loss_fn=NoopLoss(), opt_fn=optim.RMSprop, wd=0.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cb = GANTrainer(learn, bs=128, n_disc_iter=first_disc_iter)\n",
    "learn.callbacks.append(cb)\n",
    "learn.fit(1, 1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x,y = next(iter(learn.data.train_dl))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = learn.model(create_noise(x,64,False), gen=True)\n",
    "imgs = denorm(tst.cpu()).numpy().clip(0,1)\n",
    "fig,axs = plt.subplots(5,5,figsize=(8,8))\n",
    "for i,ax in enumerate(axs.flatten()):\n",
    "    ax.imshow(imgs[i].transpose(1,2,0))\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.save('temp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cb = GANTrainer(learn, bs=128, n_disc_iter=standard_disc_iter)\n",
    "learn.callbacks.append(cb)\n",
    "learn.fit(1, 1e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst = learn.model(create_noise(x,64,False), gen=True)\n",
    "imgs = denorm(tst.cpu()).numpy().clip(0,1)\n",
    "fig,axs = plt.subplots(5,5,figsize=(8,8))\n",
    "for i,ax in enumerate(axs.flatten()):\n",
    "    ax.imshow(imgs[i].transpose(1,2,0))\n",
    "    ax.axis('off')\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}