{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generative Adversarial Networks (GANs)\n",
    "\n",
    "This notebook trains DCGAN-style generators and discriminators on the Stanford Dogs images and then samples new dog pictures from the trained generator."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19",
    "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5",
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "import zipfile\n",
    "\n",
    "directory = \"./tmp/all-dogs\"\n",
    "zip_url = \"https://static-1300131294.cos.ap-shanghai.myqcloud.com/data/deep-learning/Gan/all-dogs.zip\"\n",
    "# The images are read from ./tmp/all-dogs/dogs/ further down (see PATH).\n",
    "images_dir = os.path.join(directory, \"dogs\")\n",
    "os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "# Skip the download when a previous run already extracted the images,\n",
    "# so Restart & Run All does not re-fetch the archive every time.\n",
    "if os.path.isdir(images_dir) and os.listdir(images_dir):\n",
    "    print(\"Images already present, skipping download\")\n",
    "else:\n",
    "    response = requests.get(zip_url, timeout=300)\n",
    "    # Fail loudly on HTTP errors instead of saving an error page as the ZIP\n",
    "    # and failing later inside zipfile with a confusing BadZipFile.\n",
    "    response.raise_for_status()\n",
    "    zip_filename = os.path.join(directory, \"all-dogs.zip\")\n",
    "\n",
    "    with open(zip_filename, \"wb\") as file:\n",
    "        file.write(response.content)\n",
    "    print(\"ZIP File successfully downloaded\")\n",
    "\n",
    "    with zipfile.ZipFile(zip_filename, \"r\") as zip_ref:\n",
    "        zip_ref.extractall(directory)\n",
    "\n",
    "    print(\"ZIP File successfully unzipped\")\n",
    "    os.remove(zip_filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(os.listdir(\"./tmp\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Importing the libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "from torchvision import datasets\n",
    "from torchvision.utils import save_image\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "# Seed every RNG the notebook uses so sampling, augmentation and training are repeatable.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Some dogs\n",
    "\n",
    "The Stanford Dogs dataset contains images of 120 breeds of dogs from around the world."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "PATH = './tmp/all-dogs/dogs/'\n",
    "images = os.listdir(PATH)\n",
    "print(f'There are {len(images)} pictures of dogs.')\n",
    "\n",
    "fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(12, 10))\n",
    "\n",
    "# Show 9 randomly chosen pictures (with replacement) together with their file names.\n",
    "for axis in axes.flatten():\n",
    "    name = images[np.random.randint(0, len(images))]\n",
    "    axis.imshow(plt.imread(os.path.join(PATH, name)))\n",
    "    axis.set_title(name)\n",
    "    axis.set_axis_off()\n",
    "plt.tight_layout(rect=[0, 0.03, 1, 0.95])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Image Preprocessing\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "image_size = 64\n",
    "\n",
    "# Light augmentation: colour jitter / small rotations applied to 20% of the images.\n",
    "random_transforms = [transforms.ColorJitter(), transforms.RandomRotation(degrees=20)]\n",
    "transform = transforms.Compose([transforms.Resize(image_size),\n",
    "                                transforms.CenterCrop(image_size),\n",
    "                                transforms.RandomHorizontalFlip(p=0.5),\n",
    "                                transforms.RandomApply(random_transforms, p=0.2),\n",
    "                                transforms.ToTensor(),\n",
    "                                # (x - 0.5) / 0.5 maps ToTensor's [0, 1] to [-1, 1], the generator's Tanh range.\n",
    "                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "# ImageFolder treats each sub-directory of ./tmp as a class; the class labels are not used by the GAN.\n",
    "train_data = datasets.ImageFolder('./tmp', transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_data, shuffle=True,\n",
    "                                           batch_size=batch_size)\n",
    "\n",
    "imgs, _ = next(iter(train_loader))\n",
    "imgs = imgs.numpy().transpose(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for matplotlib"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Undo Normalize(0.5, 0.5) so imshow receives values in [0, 1] instead of clipping [-1, 1] data.\n",
    "for i in range(min(5, len(imgs))):\n",
    "    plt.imshow(np.clip(imgs[i] * 0.5 + 0.5, 0, 1))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Weights\n",
    "\n",
    "DCGAN-style initialisation for the baseline `G` and `D` models: convolution weights are drawn from N(0, 0.02), batch-norm scales from N(1, 0.02), and batch-norm biases are set to zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def weights_init(m):\n",
    "    \"\"\"Initialise one module in place; intended for use as `model.apply(weights_init)`.\n",
    "\n",
    "    Matches on the class name, so 'Conv' covers both Conv2d and ConvTranspose2d.\n",
    "    \"\"\"\n",
    "    classname = m.__class__.__name__\n",
    "    if classname.find('Conv') != -1:\n",
    "        nn.init.normal_(m.weight, 0.0, 0.02)\n",
    "    elif classname.find('BatchNorm') != -1:\n",
    "        nn.init.normal_(m.weight, 1.0, 0.02)\n",
    "        nn.init.zeros_(m.bias)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class G(nn.Module):\n",
    "    \"\"\"Baseline DCGAN generator: (N, 100, 1, 1) latent tensor -> (N, 3, 64, 64) image in [-1, 1].\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(G, self).__init__()\n",
    "\n",
    "        self.main = nn.Sequential(\n",
    "            nn.ConvTranspose2d(100, 512, 4, stride=1, padding=0, bias=False),  # 1x1 -> 4x4\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.ReLU(True),\n",
    "            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),  # 4x4 -> 8x8\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.ReLU(True),\n",
    "            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),  # 8x8 -> 16x16\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.ReLU(True),\n",
    "            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),  # 16x16 -> 32x32\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.ReLU(True),\n",
    "            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1, bias=False),  # 32x32 -> 64x64\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        return self.main(input)\n",
    "\n",
    "\n",
    "netG = G()\n",
    "netG.apply(weights_init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class D(nn.Module):\n",
    "    \"\"\"Baseline DCGAN discriminator: (N, 3, 64, 64) image -> (N,) probability that the image is real.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(D, self).__init__()\n",
    "        self.main = nn.Sequential(\n",
    "            nn.Conv2d(3, 64, 4, stride=2, padding=1, bias=False),  # 64x64 -> 32x32\n",
    "            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n",
    "            nn.Conv2d(64, 128, 4, stride=2, padding=1, bias=False),  # 32x32 -> 16x16\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n",
    "            nn.Conv2d(128, 256, 4, stride=2, padding=1, bias=False),  # 16x16 -> 8x8\n",
    "            nn.BatchNorm2d(256),\n",
    "            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n",
    "            nn.Conv2d(256, 512, 4, stride=2, padding=1, bias=False),  # 8x8 -> 4x4\n",
    "            nn.BatchNorm2d(512),\n",
    "            nn.LeakyReLU(negative_slope=0.2, inplace=True),\n",
    "            nn.Conv2d(512, 1, 4, stride=1, padding=0, bias=False),  # 4x4 -> 1x1\n",
    "            nn.Sigmoid()\n",
    "        )\n",
    "\n",
    "    def forward(self, input):\n",
    "        output = self.main(input)\n",
    "        # Flatten (N, 1, 1, 1) to (N,) to match the 1-D targets used with BCELoss below.\n",
    "        return output.view(-1)\n",
    "\n",
    "\n",
    "netD = D()\n",
    "netD.apply(weights_init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Another setup\n",
    "\n",
    "A deeper generator (latent size `nz`, 128 by default) and a lighter discriminator. These are the models trained in the *Best public training* section below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "class Generator(nn.Module):\n",
    "    \"\"\"DCGAN generator: latent vector of size `nz` -> (N, channels, 64, 64) image in [-1, 1].\"\"\"\n",
    "\n",
    "    def __init__(self, nz=128, channels=3):\n",
    "        super(Generator, self).__init__()\n",
    "\n",
    "        self.nz = nz\n",
    "        self.channels = channels\n",
    "\n",
    "        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0):\n",
    "            # ConvTranspose -> BatchNorm -> ReLU upsampling block.\n",
    "            block = [\n",
    "                nn.ConvTranspose2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False),\n",
    "                nn.BatchNorm2d(n_output),\n",
    "                nn.ReLU(inplace=True),\n",
    "            ]\n",
    "            return block\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            *convlayer(self.nz, 1024, 4, 1, 0),  # 1x1 -> 4x4\n",
    "            *convlayer(1024, 512, 4, 2, 1),      # 4x4 -> 8x8\n",
    "            *convlayer(512, 256, 4, 2, 1),       # 8x8 -> 16x16\n",
    "            *convlayer(256, 128, 4, 2, 1),       # 16x16 -> 32x32\n",
    "            *convlayer(128, 64, 4, 2, 1),        # 32x32 -> 64x64\n",
    "            nn.ConvTranspose2d(64, self.channels, 3, 1, 1),  # keeps 64x64, maps to image channels\n",
    "            nn.Tanh()\n",
    "        )\n",
    "\n",
    "    def forward(self, z):\n",
    "        z = z.view(-1, self.nz, 1, 1)  # accepts (N, nz) as well as (N, nz, 1, 1)\n",
    "        img = self.model(z)\n",
    "        return img\n",
    "\n",
    "\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"DCGAN discriminator: (N, channels, 64, 64) image -> (N, 1) probability that the image is real.\"\"\"\n",
    "\n",
    "    def __init__(self, channels=3):\n",
    "        super(Discriminator, self).__init__()\n",
    "\n",
    "        self.channels = channels\n",
    "\n",
    "        def convlayer(n_input, n_output, k_size=4, stride=2, padding=0, bn=False):\n",
    "            # Conv -> [BatchNorm] -> LeakyReLU downsampling block.\n",
    "            block = [nn.Conv2d(n_input, n_output, kernel_size=k_size, stride=stride, padding=padding, bias=False)]\n",
    "            if bn:\n",
    "                block.append(nn.BatchNorm2d(n_output))\n",
    "            block.append(nn.LeakyReLU(0.2, inplace=True))\n",
    "            return block\n",
    "\n",
    "        self.model = nn.Sequential(\n",
    "            *convlayer(self.channels, 32, 4, 2, 1),  # 64x64 -> 32x32\n",
    "            *convlayer(32, 64, 4, 2, 1),             # 32x32 -> 16x16\n",
    "            *convlayer(64, 128, 4, 2, 1, bn=True),   # 16x16 -> 8x8\n",
    "            *convlayer(128, 256, 4, 2, 1, bn=True),  # 8x8 -> 4x4\n",
    "            nn.Conv2d(256, 1, 4, 1, 0, bias=False),  # 4x4 -> 1x1 logit\n",
    "        )\n",
    "\n",
    "    def forward(self, imgs):\n",
    "        logits = self.model(imgs)\n",
    "        out = torch.sigmoid(logits)\n",
    "\n",
    "        return out.view(-1, 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training\n",
    "\n",
    "Baseline training loop for the `G` / `D` pair defined above. It is disabled by default (`EPOCH = 0`); the main training run is in the next section."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "EPOCH = 0  # set > 0 to train the baseline G/D pair; 0 skips the loop below\n",
    "LR = 0.001\n",
    "criterion = nn.BCELoss()\n",
    "optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(0.5, 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(0.5, 0.999))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-output": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "RESULTS_DIR = \"./results\"\n",
    "\n",
    "for epoch in range(EPOCH):\n",
    "    for i, (real, _) in enumerate(train_loader):\n",
    "        b_size = real.size(0)  # the last batch may be smaller than batch_size\n",
    "\n",
    "        # 1st Step: Updating the weights of the neural network of the discriminator\n",
    "        netD.zero_grad()\n",
    "\n",
    "        # Training the discriminator with a real image of the dataset\n",
    "        output = netD(real)\n",
    "        errD_real = criterion(output, torch.ones(b_size))\n",
    "\n",
    "        # Training the discriminator with a fake image generated by the generator\n",
    "        noise = torch.randn(b_size, 100, 1, 1)\n",
    "        fake = netG(noise)\n",
    "        output = netD(fake.detach())  # detach: the D step must not backpropagate into G\n",
    "        errD_fake = criterion(output, torch.zeros(b_size))\n",
    "\n",
    "        # Backpropagating the total error\n",
    "        errD = errD_real + errD_fake\n",
    "        errD.backward()\n",
    "        optimizerD.step()\n",
    "\n",
    "        # 2nd Step: Updating the weights of the neural network of the generator\n",
    "        netG.zero_grad()\n",
    "        output = netD(fake)\n",
    "        errG = criterion(output, torch.ones(b_size))\n",
    "        errG.backward()\n",
    "        optimizerG.step()\n",
    "\n",
    "        # 3rd Step: Printing the losses and saving the real images and the generated images of the minibatch every 100 steps\n",
    "        if i % 100 == 0:\n",
    "            print('[%d/%d][%d/%d] Loss_D: %.4f; Loss_G: %.4f' % (epoch, EPOCH, i, len(train_loader), errD.item(), errG.item()))\n",
    "            os.makedirs(RESULTS_DIR, exist_ok=True)\n",
    "            vutils.save_image(real, os.path.join(RESULTS_DIR, 'real_samples.png'), normalize=True)\n",
    "            with torch.no_grad():\n",
    "                fake = netG(noise)\n",
    "            vutils.save_image(fake, os.path.join(RESULTS_DIR, 'fake_samples_epoch_%03d.png' % epoch), normalize=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Best public training\n",
    "\n",
    "Main training run, using the `Generator` / `Discriminator` pair from *Another setup*."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "batch_size = 32  # note: train_loader was already built above, so changing this has no effect on it\n",
    "LR_G = 0.001\n",
    "LR_D = 0.0005\n",
    "\n",
    "beta1 = 0.5\n",
    "epochs = 100\n",
    "\n",
    "real_label = 0.9  # one-sided label smoothing: real targets are 0.9, fake targets stay 0\n",
    "fake_label = 0.0\n",
    "nz = 128  # size of the latent noise vector\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Initialize models and optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "netG = Generator(nz).to(device)\n",
    "netD = Discriminator().to(device)\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "optimizerD = optim.Adam(netD.parameters(), lr=LR_D, betas=(beta1, 0.999))\n",
    "optimizerG = optim.Adam(netG.parameters(), lr=LR_G, betas=(beta1, 0.999))\n",
    "\n",
    "fixed_noise = torch.randn(25, nz, 1, 1, device=device)\n",
    "\n",
    "G_losses = []\n",
    "D_losses = []\n",
    "epoch_time = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "_kg_hide-output": false,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def plot_loss(G_losses, D_losses, epoch):\n",
    "    \"\"\"Plot the per-iteration generator and discriminator losses collected during one epoch.\"\"\"\n",
    "    plt.figure(figsize=(10, 5))\n",
    "    plt.title(\"Generator and Discriminator Loss - EPOCH \" + str(epoch))\n",
    "    plt.plot(G_losses, label=\"G\")\n",
    "    plt.plot(D_losses, label=\"D\")\n",
    "    plt.xlabel(\"iterations\")\n",
    "    plt.ylabel(\"Loss\")\n",
    "    plt.legend()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Show generated images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "_kg_hide-output": false,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "def show_generated_img(n_images=5):\n",
    "    \"\"\"Sample `n_images` latent vectors, run each through `netG` and plot the results in one row.\n",
    "\n",
    "    Uses the globals `netG`, `nz` and `device`. Does nothing when `n_images < 1`.\n",
    "    \"\"\"\n",
    "    if n_images < 1:\n",
    "        return\n",
    "\n",
    "    sample = []\n",
    "    with torch.no_grad():  # inference only: no autograd graph needed\n",
    "        for _ in range(n_images):\n",
    "            noise = torch.randn(1, nz, 1, 1, device=device)\n",
    "            gen_image = netG(noise).cpu().squeeze(0)\n",
    "            # The generator ends in Tanh ([-1, 1]); rescale to [0, 1] so imshow does not clip.\n",
    "            gen_image = (gen_image * 0.5 + 0.5).clamp(0, 1)\n",
    "            sample.append(gen_image.numpy().transpose(1, 2, 0))\n",
    "\n",
    "    # squeeze=False keeps `axes` 2-D, so it is iterable even when n_images == 1.\n",
    "    figure, axes = plt.subplots(1, len(sample), figsize=(3 * len(sample), 3), squeeze=False)\n",
    "    for axis, image_array in zip(axes[0], sample):\n",
    "        axis.axis('off')\n",
    "        axis.imshow(image_array)\n",
    "\n",
    "    plt.show()\n",
    "    plt.close(figure)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "_kg_hide-output": false,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "# Log twice per epoch; max(1, ...) avoids a modulo-by-zero when the loader has a single batch.\n",
    "log_every = max(1, len(train_loader) // 2)\n",
    "\n",
    "for epoch in range(epochs):\n",
    "\n",
    "    start = time.time()\n",
    "    for ii, (real_images, _) in tqdm(enumerate(train_loader), total=len(train_loader)):\n",
    "        ############################\n",
    "        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))\n",
    "        ###########################\n",
    "        # train with real\n",
    "        netD.zero_grad()\n",
    "        real_images = real_images.to(device)\n",
    "        b_size = real_images.size(0)  # the last batch may be smaller than batch_size\n",
    "        labels = torch.full((b_size, 1), real_label, dtype=torch.float, device=device)\n",
    "\n",
    "        output = netD(real_images)\n",
    "        errD_real = criterion(output, labels)\n",
    "        errD_real.backward()\n",
    "        D_x = output.mean().item()\n",
    "\n",
    "        # train with fake\n",
    "        noise = torch.randn(b_size, nz, 1, 1, device=device)\n",
    "        fake = netG(noise)\n",
    "        labels.fill_(fake_label)\n",
    "        output = netD(fake.detach())  # detach: the D step must not backpropagate into G\n",
    "        errD_fake = criterion(output, labels)\n",
    "        errD_fake.backward()\n",
    "        D_G_z1 = output.mean().item()\n",
    "        errD = errD_real + errD_fake\n",
    "        optimizerD.step()\n",
    "\n",
    "        ############################\n",
    "        # (2) Update G network: maximize log(D(G(z)))\n",
    "        ###########################\n",
    "        netG.zero_grad()\n",
    "        labels.fill_(real_label)  # fake labels are real for generator cost\n",
    "        output = netD(fake)\n",
    "        errG = criterion(output, labels)\n",
    "        errG.backward()\n",
    "        D_G_z2 = output.mean().item()\n",
    "        optimizerG.step()\n",
    "\n",
    "        # Save Losses for plotting later\n",
    "        G_losses.append(errG.item())\n",
    "        D_losses.append(errD.item())\n",
    "\n",
    "        if (ii + 1) % log_every == 0:\n",
    "            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'\n",
    "                  % (epoch + 1, epochs, ii + 1, len(train_loader),\n",
    "                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))\n",
    "\n",
    "    plot_loss(G_losses, D_losses, epoch)\n",
    "    G_losses = []\n",
    "    D_losses = []\n",
    "    if epoch % 10 == 0:\n",
    "        show_generated_img()\n",
    "\n",
    "    epoch_time.append(time.time() - start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-output": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "print(\">> average EPOCH duration = \", np.mean(epoch_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "show_generated_img(7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "OUTPUT_DIR = '../output_images'\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "im_batch_size = 50\n",
    "n_images = 10000\n",
    "\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    for i_batch in tqdm(range(0, n_images, im_batch_size)):\n",
    "        gen_z = torch.randn(im_batch_size, nz, 1, 1, device=device)\n",
    "        # The generator outputs Tanh values in [-1, 1]; save_image expects [0, 1] and would\n",
    "        # otherwise clip every negative value to 0, producing darkened images.\n",
    "        gen_images = (netG(gen_z) * 0.5 + 0.5).clamp(0, 1)\n",
    "        for i_image in range(gen_images.size(0)):\n",
    "            save_image(gen_images[i_image], os.path.join(OUTPUT_DIR, f'image_{i_batch + i_image:05d}.png'))\n",
    "\n",
    "# Keep the last generated batch as an (N, H, W, C) array in [0, 1] for the preview below.\n",
    "images = gen_images.cpu().numpy().transpose(0, 2, 3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_kg_hide-input": true,
    "trusted": true
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(25, 16))\n",
    "# display the first 32 images of the last generated batch\n",
    "for i, j in enumerate(images[:32]):\n",
    "    ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])\n",
    "    ax.imshow(j)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "trusted": true
   },
   "outputs": [],
   "source": [
    "torch.save(netG.state_dict(), 'generator.pth')\n",
    "torch.save(netD.state_dict(), 'discriminator.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Acknowledgement\n",
    "Thanks to [jesucristo](https://www.kaggle.com/jesucristo) for creating [GAN Introduction](https://www.kaggle.com/code/jesucristo/gan-introduction/notebook). It inspired the majority of the content in this notebook."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}