{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Part 1: Import Libraries and Configure Dataset"
   ],
   "metadata": {
    "id": "xXLDcU4S9FEx"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZLDpr0NC4HG2"
   },
   "outputs": [],
   "source": [
    "#@title import libraries\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torchvision import transforms\n",
    "import torchvision.datasets as datasets\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_fJ2P54F4m6F"
   },
   "outputs": [],
   "source": [
    "#@title load dataset\n",
    "\n",
    "batch_size = 100\n",
    "#TODO: define transform that turns images to torch tensors and normalizes them to the range [-1, 1]\n",
    "#Hint: use transforms.ToTensor() and transforms.Normalize()\n",
    "transform = None\n",
    "\n",
    "\n",
    "mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
    "mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "\n",
    "trainloader = DataLoader(mnist_trainset, batch_size = batch_size, num_workers = 0, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Part 2: Implement models"
   ],
   "metadata": {
    "id": "-ZVs6usI9WsB"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U9KcNnjs49Di"
   },
   "outputs": [],
   "source": [
    "#@title Define generator model\n",
    "class Generator(nn.Module):\n",
    "    \"\"\"Fully connected generator mapping a noise vector to a flattened image (to be implemented).\"\"\"\n",
    "\n",
    "    def __init__(self, noise_dim, out_dim):\n",
    "        super(Generator, self).__init__()\n",
    "        #noise_dim: dimension of input noise vector\n",
    "        #out_dim: dimension of output image, in our case 28 * 28\n",
    "\n",
    "        #TODO: define fully connected network with dims: noise_dim -> 256 -> 512 -> 512 -> out_dim\n",
    "        #Use ReLU as activation functions of hidden layers\n",
    "        #Use Tanh as activation function of the final layer\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        #TODO: implement the forward function of the generator\n",
    "        pass\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dD-bNFse7kVI"
   },
   "outputs": [],
   "source": [
    "G = Generator(100, 28 * 28)\n",
    "print(G)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sl8GjZVd7G5s"
   },
   "outputs": [],
   "source": [
    "#@title Define discriminator model\n",
    "class Discriminator(nn.Module):\n",
    "    \"\"\"Fully connected discriminator scoring an image as real or fake (to be implemented).\"\"\"\n",
    "\n",
    "    def __init__(self, image_dim):\n",
    "        super(Discriminator, self).__init__()\n",
    "        #image_dim: dimension of input image, in our case 28 * 28\n",
    "        #TODO: define linear layers with dims image_dim -> 256 -> 128 -> 64 -> 1\n",
    "        #Use LeakyReLU as activation functions of hidden layers\n",
    "        #Use Sigmoid as activation function of the final layer\n",
    "\n",
    "    def forward(self, x):\n",
    "        #TODO: implement the forward function of the discriminator\n",
    "        pass\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sKKJ4Ngq7vA8"
   },
   "outputs": [],
   "source": [
    "D = Discriminator(28 * 28)\n",
    "print(D)"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Part 3: Training"
   ],
   "metadata": {
    "id": "-yxq7aFh-snD"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SrVMkV7p710c"
   },
   "outputs": [],
   "source": [
    "#@title Training\n",
    "\n",
    "lr = 0.0002 #learning rate\n",
    "nepochs = 20 #number of training epochs\n",
    "noise_dim = 100 #dimension of input noise vector\n",
    "#use the GPU when one is available, otherwise fall back to the CPU\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class Trainer:\n",
    "    \"\"\"Alternates discriminator and generator updates over the batches of `trainloader`.\n",
    "\n",
    "    Sample images are written to `./figs` and generator weights to `./checkpoints`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.G = Generator(noise_dim = noise_dim, out_dim = 28 * 28).to(device)\n",
    "        self.D = Discriminator(image_dim = 28 * 28).to(device)\n",
    "\n",
    "        #TODO: define optimizers. one for generator and one for discriminator\n",
    "        #Hint: use torch.optim.Adam()\n",
    "\n",
    "        self.G_optimizer = None\n",
    "        self.D_optimizer = None\n",
    "\n",
    "\n",
    "        #define loss function\n",
    "        self.criterion = nn.BCELoss()\n",
    "\n",
    "        self.eval_freq = 1 #sample and checkpoint every eval_freq epochs\n",
    "        self.fig_dir = './figs'\n",
    "        self.checkpoint_dir = './checkpoints'\n",
    "\n",
    "        os.makedirs(self.fig_dir, exist_ok = True)\n",
    "        os.makedirs(self.checkpoint_dir, exist_ok = True)\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Train for `nepochs` epochs, sampling and checkpointing every `eval_freq` epochs.\"\"\"\n",
    "        for e in range(1, nepochs + 1):\n",
    "            #train first so that fig_e / gen_weights_e reflect the model after epoch e\n",
    "            #and the weights of the final epoch are actually saved\n",
    "            self.train_step(e)\n",
    "            if e % self.eval_freq == 0:\n",
    "                self.eval_step(e)\n",
    "                self.save_step(e)\n",
    "\n",
    "    def train_step(self, epoch):\n",
    "        \"\"\"Run one pass over `trainloader`, updating D and then G on every batch.\"\"\"\n",
    "        self.G.train()\n",
    "        self.D.train()\n",
    "        pbar = tqdm(trainloader)\n",
    "        for i, data in enumerate(pbar):\n",
    "            real_data, _ = data\n",
    "            real_data = real_data.to(device)\n",
    "\n",
    "            D_loss = self.train_D(real_data)\n",
    "            G_loss = self.train_G()\n",
    "\n",
    "            pbar.set_description(\"Epoch: {}, G_loss = {:.4f}, D_loss = {:.4f}\".format(epoch, G_loss, D_loss))\n",
    "\n",
    "    def train_D(self, real_data):\n",
    "        \"\"\"Perform one discriminator update on a batch of real data and return the loss.\"\"\"\n",
    "        self.D_optimizer.zero_grad()\n",
    "        D_loss = 0.\n",
    "        #TODO: train discriminator\n",
    "        #real data: a batch of real data with shape (batch_size, 1, 28, 28)\n",
    "        #1. feed real data to D\n",
    "        #2. generate labels for real data (should be all ones). Hint: use torch.ones()\n",
    "        #3. compute loss for real data\n",
    "        #4. generate noise. Hint: use torch.randn()\n",
    "        #5. feed noise to G to get fake data\n",
    "        #6. feed fake data to D\n",
    "        #7. generate labels for fake data (should be all zeros). Hint: use torch.zeros()\n",
    "        #8. compute loss for fake data\n",
    "        #9. add losses and optimize D\n",
    "        #Note: create new tensors on `device` so they live alongside the models\n",
    "        return D_loss\n",
    "\n",
    "    def train_G(self):\n",
    "        \"\"\"Perform one generator update and return the loss.\"\"\"\n",
    "        self.G_optimizer.zero_grad()\n",
    "        G_loss = 0.\n",
    "        #TODO: train generator\n",
    "        #1. generate noise. Hint: use torch.randn()\n",
    "        #2. feed noise to G to get fake data\n",
    "        #3. feed fake data to D\n",
    "        #4. generate labels for fake data (should be all ones) (why?). Hint: use torch.ones()\n",
    "        #5. compute loss for fake data\n",
    "        #6. optimize generator\n",
    "        #Note: create new tensors on `device` so they live alongside the models\n",
    "        return G_loss\n",
    "\n",
    "\n",
    "    def eval_step(self, epoch):\n",
    "        \"\"\"Generate one sample and save it as `fig_<epoch>.png` in `fig_dir`.\"\"\"\n",
    "        self.G.eval()\n",
    "        noise = torch.randn((1, noise_dim)).to(device)\n",
    "        #sampling needs no gradients\n",
    "        with torch.no_grad():\n",
    "            #reshape replaces the deprecated Tensor.resize\n",
    "            image = self.G(noise).reshape(28, 28)\n",
    "        image = image.clamp(-1, 1).cpu().numpy()\n",
    "        #map [-1, 1] to [0, 255] to obtain an 8-bit grayscale image\n",
    "        image = ((image + 1) * 127.5).astype('uint8')\n",
    "        Image.fromarray(image).save(os.path.join(self.fig_dir, 'fig_{}.png'.format(epoch)))\n",
    "\n",
    "    def save_step(self, epoch):\n",
    "        \"\"\"Save the generator weights as `gen_weights_<epoch>.pth` in `checkpoint_dir`.\"\"\"\n",
    "        torch.save(self.G.state_dict(), os.path.join(self.checkpoint_dir, 'gen_weights_{}.pth'.format(epoch)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DMUGbZb295nX"
   },
   "outputs": [],
   "source": [
    "trainer = Trainer()\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Part 4: Evaluation"
   ],
   "metadata": {
    "id": "sfZ5gTKH_JME"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "#@title Load pretrained weights\n",
    "!wget 'https://github.com/arash-mham/visual-computing-II/blob/main/labs/mnist_gan/gen_weights.pth?raw=true' -O gen_weights.pth"
   ],
   "metadata": {
    "id": "rM7JV1Xh_Ntz"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DB3bjbIc-PuC"
   },
   "outputs": [],
   "source": [
    "#@title Generate samples using trained generator\n",
    "\n",
    "G = Generator(noise_dim, 28 * 28).to(device)\n",
    "#TODO: load weights into model from gen_weights.pth\n",
    "#Hint: use torch.load(..., map_location=device) and G.load_state_dict()\n",
    "\n",
    "#TODO: generate 8 fake samples and plot them"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}