{ "cells": [
{ "cell_type": "markdown", "metadata": {}, "source": [ "# DCGAN with skorch" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Code is adapted from the [PyTorch examples](https://github.com/pytorch/examples/tree/master/dcgan)." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"# %matplotlib only selects the inline backend; %pylab would also star-import numpy and pylab\n",
"%matplotlib inline"
] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.parallel\n",
"import torch.optim as optim\n",
"import torch.utils.data\n",
"import torchvision.datasets as dset\n",
"import torchvision.transforms as transforms\n",
"import torchvision.utils as vutils"
] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [
"from skorch import NeuralNet\n",
"from skorch.utils import to_tensor\n",
"from skorch.callbacks import PassthroughScoring"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Parameters" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [], "source": [
"torch.manual_seed(0)\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else 'cpu'"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
"nz = 100  # size of the latent z vector\n",
"ngf = 32  # units of generator\n",
"ndf = 32  # units of discriminator\n",
"nc = 1  # number of channels\n",
"batch_size = 64\n",
"lr = 0.0002\n",
"beta1 = 0.5  # for adam\n",
"max_epochs = 5\n",
"ngpu = 1\n",
"img_size = 32  # 32 is easier than 28 to work with\n",
"workers = 2  # for dataloader"
] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "path = './mnist'" ]
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
"dataset = dset.MNIST(\n",
"    root=path,\n",
"    download=True,\n",
"    transform=transforms.Compose([\n",
"        transforms.Resize(img_size),\n",
"        transforms.ToTensor(),\n",
"        transforms.Normalize((0.5,), (0.5,)),\n",
"    ]),\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Custom code" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [
"def weights_init(m):\n",
"    \"\"\"Custom weights initialization called on generator and discriminator.\n",
"\n",
"    Modules whose class name contains 'Conv' get weights drawn from a normal\n",
"    distribution (mean 0.0, std 0.02); modules whose class name contains\n",
"    'BatchNorm' get weights drawn with mean 1.0, std 0.02 and a zero bias.\n",
"    Intended to be passed to ``module.apply``.\n",
"    \"\"\"\n",
"    classname = m.__class__.__name__\n",
"    if classname.find('Conv') != -1:\n",
"        m.weight.data.normal_(0.0, 0.02)\n",
"    elif classname.find('BatchNorm') != -1:\n",
"        m.weight.data.normal_(1.0, 0.02)\n",
"        m.bias.data.fill_(0)"
] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
"class Generator(nn.Module):\n",
"    \"\"\"Generator: maps latent noise (nz x 1 x 1) to an image (nc x 32 x 32).\n",
"\n",
"    Parameters: nz is the latent size, ngf the base number of feature maps,\n",
"    ngpu the number of GPUs used for data-parallel forward passes and nc the\n",
"    number of output image channels.\n",
"    \"\"\"\n",
"    def __init__(self, nz, ngf, ngpu, nc=1):\n",
"        super().__init__()\n",
"        self.nz = nz\n",
"        self.ngf = ngf\n",
"        self.ngpu = ngpu\n",
"        # nc is an explicit argument (it used to be read from a notebook\n",
"        # global) so the output channels follow the module configuration\n",
"        self.nc = nc\n",
"\n",
"        self.main = nn.Sequential(\n",
"            # input is Z, going into a convolution\n",
"            nn.ConvTranspose2d(nz, ngf * 4, 4, 1, 0, bias=False),\n",
"            nn.BatchNorm2d(ngf * 4),\n",
"            nn.ReLU(True),\n",
"            # state size. (ngf*4) x 4 x 4\n",
"            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),\n",
"            nn.BatchNorm2d(ngf * 2),\n",
"            nn.ReLU(True),\n",
"            # state size. (ngf*2) x 8 x 8\n",
"            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),\n",
"            nn.BatchNorm2d(ngf),\n",
"            nn.ReLU(True),\n",
"            # state size. (ngf) x 16 x 16\n",
"            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),\n",
"            nn.Tanh(),\n",
"            # state size. (nc) x 32 x 32\n",
"        )\n",
"\n",
"    def forward(self, input):\n",
"        # only go through data_parallel when several GPUs were requested\n",
"        if input.is_cuda and self.ngpu > 1:\n",
"            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))\n",
"        else:\n",
"            output = self.main(input)\n",
"        return output"
] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [
"class Discriminator(nn.Module):\n",
"    \"\"\"Discriminator: maps an image (nc x 32 x 32) to one probability per sample.\n",
"\n",
"    Parameters: nc is the number of image channels, ndf the base number of\n",
"    feature maps and ngpu the number of GPUs used for data-parallel forward\n",
"    passes.\n",
"    \"\"\"\n",
"    def __init__(self, nc, ndf, ngpu):\n",
"        super().__init__()\n",
"        self.nc = nc\n",
"        self.ndf = ndf\n",
"        self.ngpu = ngpu\n",
"\n",
"        self.main = nn.Sequential(\n",
"            # input is (nc) x 32 x 32\n",
"            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            # state size. (ndf) x 16 x 16\n",
"            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),\n",
"            nn.BatchNorm2d(ndf * 2),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            # state size. (ndf*2) x 8 x 8\n",
"            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),\n",
"            nn.BatchNorm2d(ndf * 4),\n",
"            nn.LeakyReLU(0.2, inplace=True),\n",
"            # state size. (ndf*4) x 4 x 4\n",
"            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),\n",
"            nn.Sigmoid()\n",
"        )\n",
"\n",
"    def forward(self, input):\n",
"        # only go through data_parallel when several GPUs were requested\n",
"        if input.is_cuda and self.ngpu > 1:\n",
"            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))\n",
"        else:\n",
"            output = self.main(input)\n",
"\n",
"        # flatten to a 1-d tensor with one entry per sample\n",
"        return output.view(-1, 1).squeeze(1)"
] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
"class Dcgan(nn.Module):\n",
"    \"\"\"Module holding both the discriminator and the generator.\n",
"\n",
"    Both sub-modules are initialized with ``weights_init``. The forward\n",
"    method only runs the generator.\n",
"    \"\"\"\n",
"    def __init__(self, nc, nz, ndf, ngf, ngpu):\n",
"        super().__init__()\n",
"\n",
"        self.nc = nc\n",
"        self.nz = nz\n",
"        self.ndf = ndf\n",
"        self.ngf = ngf\n",
"        self.ngpu = ngpu\n",
"\n",
"        self.discriminator = Discriminator(\n",
"            nc=self.nc,\n",
"            ndf=self.ndf,\n",
"            ngpu=self.ngpu,\n",
"        )\n",
"        self.discriminator.apply(weights_init)\n",
"        self.generator = Generator(\n",
"            nz=self.nz,\n",
"            ngf=self.ngf,\n",
"            ngpu=self.ngpu,\n",
"            nc=self.nc,\n",
"        )\n",
"        # the custom init must be applied to the generator here (the\n",
"        # discriminator already received it above)\n",
"        self.generator.apply(weights_init)\n",
"\n",
"    def forward(self, X, y=None):\n",
"        # general forward method just returns fake images\n",
"        return self.generator(X)"
] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [
"class DcganNet(NeuralNet):\n",
"    \"\"\"NeuralNet that trains the generator and the discriminator of its module\n",
"    with two separate optimizers (``optimizer_gen`` and ``optimizer_dis``).\n",
"    \"\"\"\n",
"    def __init__(self, *args, optimizer_gen, optimizer_dis, **kwargs):\n",
"        self.optimizer_gen = optimizer_gen\n",
"        self.optimizer_dis = optimizer_dis\n",
"        super().__init__(*args, **kwargs)\n",
"\n",
"    def initialize_optimizer(self, *_, **__):\n",
"        \"\"\"Create one optimizer for the generator and one for the discriminator.\"\"\"\n",
"        args, kwargs = self.get_params_for_optimizer(\n",
"            'optimizer_gen', self.module_.generator.named_parameters())\n",
"        self.optimizer_gen_ = self.optimizer_gen(*args, **kwargs)\n",
"\n",
"        args, kwargs = self.get_params_for_optimizer(\n",
"            'optimizer_dis', self.module_.discriminator.named_parameters())\n",
"        self.optimizer_dis_ = self.optimizer_dis(*args, **kwargs)\n",
"\n",
"        return self\n",
"\n",
"    def validation_step(self, Xi, yi, **fit_params):\n",
"        raise NotImplementedError\n",
"\n",
"    def train_step(self, Xi, yi=None, 
**fit_params):\n",
"        \"\"\"Run one GAN update: a discriminator step followed by a generator step.\n",
"\n",
"        Xi is a batch of real images; yi is not used. The discriminator and\n",
"        generator losses are recorded in the history as 'loss_dis' and\n",
"        'loss_gen'. Returns a dict with the generated images ('y_pred') and\n",
"        the sum of both losses ('loss').\n",
"        \"\"\"\n",
"        # This override replaces the default training step, so switch the\n",
"        # module to training mode here; otherwise BatchNorm would stay in\n",
"        # eval mode when training continues after e.g. predict().\n",
"        self.module_.train()\n",
"\n",
"        Xi = to_tensor(Xi, device=self.device)\n",
"        discriminator = self.module_.discriminator\n",
"        generator = self.module_.generator\n",
"        label_real = torch.ones((len(Xi),), device=self.device)\n",
"        label_fake = torch.zeros((len(Xi),), device=self.device)\n",
"\n",
"        # (1) Update discriminator: maximize log(D(x)) + log(1 - D(G(z)))\n",
"        discriminator.zero_grad()\n",
"        output_real = discriminator(Xi)\n",
"        loss_real = self.criterion_(output_real, label_real)\n",
"        loss_real.backward()\n",
"\n",
"        noise = torch.randn(Xi.shape[0], self.module_.nz, 1, 1, device=self.device)\n",
"        fake = generator(noise)\n",
"        # detach so that this backward pass does not reach the generator\n",
"        output_fake = discriminator(fake.detach())\n",
"        loss_fake = self.criterion_(output_fake, label_fake)\n",
"        loss_fake.backward()\n",
"\n",
"        self.optimizer_dis_.step()\n",
"\n",
"        # (2) Update generator: maximize log(D(G(z)))\n",
"        generator.zero_grad()\n",
"        output_fake = discriminator(fake)\n",
"        loss_gen = self.criterion_(output_fake, label_real)\n",
"        loss_gen.backward()\n",
"        self.optimizer_gen_.step()\n",
"\n",
"        loss_dis = loss_real + loss_fake\n",
"\n",
"        self.history.record_batch('loss_dis', loss_dis.item())\n",
"        self.history.record_batch('loss_gen', loss_gen.item())\n",
"\n",
"        return {\n",
"            'y_pred': fake,\n",
"            'loss': loss_dis + loss_gen,\n",
"        }"
] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
"net = DcganNet(\n",
"    Dcgan,\n",
"    module__nz=nz,\n",
"    module__ndf=ndf,\n",
"    module__ngf=ngf,\n",
"    module__nc=nc,\n",
"    module__ngpu=ngpu,\n",
"\n",
"    criterion=nn.BCELoss,\n",
"\n",
"    optimizer_gen=optim.Adam,\n",
"    optimizer_gen__lr=lr,\n",
"    optimizer_gen__betas=(beta1, 0.999),\n",
"\n",
"    optimizer_dis=optim.Adam,\n",
"    optimizer_dis__lr=0.00002,\n",
"    optimizer_dis__betas=(beta1, 0.999),\n",
"\n",
"    batch_size=batch_size,\n",
"    max_epochs=max_epochs,\n",
"    device=device,  # use the device selected in the parameters section\n",
"\n",
"    train_split=False,  # not implemented\n",
"    iterator_train__shuffle=True,\n",
"    iterator_train__num_workers=workers,\n",
"    iterator_valid__num_workers=workers,\n",
"\n",
"    callbacks=[\n",
"        PassthroughScoring('loss_dis', on_train=True),\n",
"        PassthroughScoring('loss_gen', on_train=True),\n",
"    ],\n",
")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Train net" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[uninitialized](\n", " module=,\n", " module_=Dcgan(\n", " (discriminator): Discriminator(\n", " (main): Sequential(\n", " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (4): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (7): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (8): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", " (9): Sigmoid()\n", " )\n", " )\n", " (generator): Generator(\n", " (main): Sequential(\n", " (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU(inplace=True)\n", " (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): ReLU(inplace=True)\n", " (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (8): ReLU(inplace=True)\n", " (9): 
ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (10): Tanh()\n", " )\n", " )\n", " ),\n", " module__nc=1,\n", " module__ndf=32,\n", " module__ngf=32,\n", " module__ngpu=1,\n", " module__nz=100,\n", ")" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "net.set_params(optimizer_gen__lr=0.0003, optimizer_dis__lr=0.0004)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " epoch loss_dis loss_gen train_loss dur\n", "------- ---------- ---------- ------------ --------\n", " 1 \u001b[36m0.4713\u001b[0m \u001b[32m3.2007\u001b[0m \u001b[35m3.6720\u001b[0m 288.9744\n", " 2 0.4876 \u001b[32m2.8905\u001b[0m \u001b[35m3.3781\u001b[0m 363.5930\n", " 3 \u001b[36m0.3492\u001b[0m 3.6009 3.9501 300.6626\n", " 4 \u001b[36m0.3198\u001b[0m 3.7600 4.0798 297.8958\n", " 5 \u001b[36m0.1154\u001b[0m 5.0819 5.1972 293.7685\n" ] }, { "data": { "text/plain": [ "[initialized](\n", " module_=Dcgan(\n", " (discriminator): Discriminator(\n", " (main): Sequential(\n", " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (4): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (5): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (6): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (7): LeakyReLU(negative_slope=0.2, inplace=True)\n", " (8): Conv2d(128, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)\n", " (9): Sigmoid()\n", " )\n", " )\n", " (generator): Generator(\n", " (main): Sequential(\n", " (0): ConvTranspose2d(100, 128, kernel_size=(4, 4), stride=(1, 1), 
bias=False)\n", " (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): ReLU(inplace=True)\n", " (3): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): ReLU(inplace=True)\n", " (6): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (8): ReLU(inplace=True)\n", " (9): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n", " (10): Tanh()\n", " )\n", " )\n", " ),\n", ")" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#net.fit(torch.utils.data.Subset(dataset, torch.arange(0, 500)))\n", "net.fit(dataset)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(net.history[:, 'loss_dis'], label='loss discriminator')\n", "plt.plot(net.history[:, 'loss_gen'], label='loss generator')\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspect images" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "noise = torch.randn(10, nz, 1, 1)" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "fakes = net.predict(noise)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "_, axes = plt.subplots(10, 1, figsize=(8, 30))\n", "for i in range(10):\n", " axes[i].imshow(fakes[i][0], cmap='gray')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.2" } }, "nbformat": 4, "nbformat_minor": 2 }