# Part 1: Import Libraries and Configure the Dataset

In [1]:
#@title import libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
from torch.utils.data import DataLoader

import numpy as np
from PIL import Image

from tqdm import tqdm

import os

In [None]:
#@title load dataset

batch_size = 100
# ToTensor() converts each PIL image to a float tensor with values in [0, 1];
# Normalize(mean=0.5, std=0.5) then maps [0, 1] -> [-1, 1], the same range as
# the generator's Tanh output. MNIST is grayscale, hence one-element tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(mnist_trainset, batch_size=batch_size, num_workers=0, shuffle=True)

# Part 2: Implement models

In [None]:
#@title Define generator model
class Generator(nn.Module):
    """Fully connected GAN generator mapping noise vectors to flat images.

    Architecture: noise_dim -> 256 -> 512 -> 512 -> out_dim, with ReLU on the
    hidden layers and Tanh on the output, so every generated pixel lies in
    [-1, 1]. Real images must be normalized to that same range.

    Args:
        noise_dim: dimension of the input noise vector.
        out_dim: dimension of the flattened output image (28 * 28 for MNIST).
    """

    def __init__(self, noise_dim, out_dim):
        super(Generator, self).__init__()
        # NOTE(review): pretrained weights (gen_weights.pth) are loaded into
        # this class later; their state-dict keys must match the layer names
        # used here ("net.0.weight", ...) -- confirm against that file.
        self.net = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape (batch, noise_dim) to images of shape (batch, out_dim)."""
        return self.net(x)


In [None]:
# Build a generator for 100-dim noise -> flattened 28x28 images and show its layers.
G = Generator(noise_dim=100, out_dim=28 * 28)
print(G)

In [None]:
#@title Define discriminator model
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring images as real (1) or fake (0).

    Architecture: image_dim -> 256 -> 128 -> 64 -> 1, with LeakyReLU on the
    hidden layers and Sigmoid on the output, so each score is a probability
    in [0, 1] suitable for ``nn.BCELoss``.

    Args:
        image_dim: dimension of the flattened input image (28 * 28 for MNIST).
    """

    def __init__(self, image_dim):
        super(Discriminator, self).__init__()
        # Negative slope 0.2 is the usual choice for GAN discriminators
        # (PyTorch's default is 0.01).
        self.net = nn.Sequential(
            nn.Linear(image_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return real/fake probabilities of shape (batch, 1)."""
        # Real batches arrive as (batch, 1, 28, 28), while generator output is
        # already flat (batch, 28 * 28); flatten everything after the batch
        # dimension so both are accepted.
        x = torch.flatten(x, start_dim=1)
        return self.net(x)


In [None]:
# Build a discriminator for flattened 28x28 images and show its layers.
D = Discriminator(image_dim=28 * 28)
print(D)

# Part 3: Training

In [4]:
#@title Training

# Training hyper-parameters.
lr = 2e-4        # learning rate
nepochs = 20     # number of training epochs
noise_dim = 100  # length of the generator's input noise vector

class Trainer:
    """Adversarially train the MNIST ``Generator`` / ``Discriminator`` pair.

    Uses the module-level ``noise_dim``, ``lr``, ``nepochs`` and
    ``trainloader`` defined in earlier cells. After every ``eval_freq``-th
    epoch, one generated sample is written to ``fig_dir`` and the generator
    weights to ``checkpoint_dir``.

    Args:
        device: device to train on. Defaults to ``'cuda:0'`` when CUDA is
            available and ``'cpu'`` otherwise.
    """

    def __init__(self, device=None):
        if device is None:
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.noise_dim = noise_dim

        self.G = Generator(noise_dim=noise_dim, out_dim=28 * 28).to(self.device)
        self.D = Discriminator(image_dim=28 * 28).to(self.device)

        # One Adam optimizer per network, so each update step only moves that
        # network's parameters. betas=(0.5, 0.999) is the Adam setting
        # recommended for GAN training in the DCGAN paper.
        self.G_optimizer = torch.optim.Adam(self.G.parameters(), lr=lr, betas=(0.5, 0.999))
        self.D_optimizer = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(0.5, 0.999))

        # BCELoss expects D's outputs to already be probabilities in [0, 1].
        self.criterion = nn.BCELoss()

        self.eval_freq = 1
        self.fig_dir = './figs'
        self.checkpoint_dir = './checkpoints'

        os.makedirs(self.fig_dir, exist_ok=True)
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def run(self):
        """Train for ``nepochs`` epochs, sampling and checkpointing every ``eval_freq`` epochs."""
        for epoch in range(1, nepochs + 1):
            # Train first, so fig_{e}.png / gen_weights_{e}.pth reflect the
            # model *after* epoch e and the final epoch's weights get saved.
            self.train_step(epoch)
            if epoch % self.eval_freq == 0:
                self.eval_step(epoch)
                self.save_step(epoch)

    def train_step(self, epoch):
        """Run one pass over ``trainloader``, alternating a D update and a G update per batch."""
        self.G.train()
        self.D.train()
        pbar = tqdm(trainloader)
        for real_data, _ in pbar:
            real_data = real_data.to(self.device)

            D_loss = self.train_D(real_data)
            # Generate as many fakes as there are real samples in this batch.
            G_loss = self.train_G(real_data.size(0))

            pbar.set_description("Epoch: {}, G_loss = {:.4f}, D_loss = {:.4f}".format(epoch, G_loss, D_loss))

    def train_D(self, real_data):
        """Do one discriminator update on a real batch plus an equal-sized fake batch.

        Args:
            real_data: batch of real images already on ``self.device``.

        Returns:
            The discriminator loss (real + fake BCE) as a Python float.
        """
        self.D_optimizer.zero_grad()
        num_samples = real_data.size(0)

        # Real images should be classified as real (label 1).
        real_pred = self.D(real_data)
        real_loss = self.criterion(real_pred, torch.ones_like(real_pred))

        # Fake images should be classified as fake (label 0). detach() stops
        # this loss from backpropagating into the generator.
        noise = torch.randn(num_samples, self.noise_dim, device=self.device)
        fake_data = self.G(noise).detach()
        fake_pred = self.D(fake_data)
        fake_loss = self.criterion(fake_pred, torch.zeros_like(fake_pred))

        D_loss = real_loss + fake_loss
        D_loss.backward()
        self.D_optimizer.step()
        return D_loss.item()

    def train_G(self, num_samples=None):
        """Do one generator update.

        Args:
            num_samples: number of fake images to generate. Defaults to
                ``trainloader.batch_size``.

        Returns:
            The generator loss as a Python float.
        """
        if num_samples is None:
            num_samples = trainloader.batch_size
        self.G_optimizer.zero_grad()

        noise = torch.randn(num_samples, self.noise_dim, device=self.device)
        fake_pred = self.D(self.G(noise))
        # Fakes are labelled as real (1): G is rewarded when D is fooled. This
        # is the non-saturating generator loss. D's parameters receive
        # gradients here but are not stepped; train_D zeroes them next time.
        G_loss = self.criterion(fake_pred, torch.ones_like(fake_pred))
        G_loss.backward()
        self.G_optimizer.step()
        return G_loss.item()

    def eval_step(self, epoch):
        """Save one generated 28x28 sample as ``fig_dir/fig_{epoch}.png``."""
        self.G.eval()
        with torch.no_grad():
            noise = torch.randn(1, self.noise_dim, device=self.device)
            # reshape, not the deprecated Tensor.resize.
            image = self.G(noise).reshape(28, 28)
        image = image.clamp(-1, 1).cpu().numpy()
        # Map [-1, 1] to [0, 255] pixel values.
        image = ((image + 1) * 127.5).astype('uint8')
        Image.fromarray(image).save(os.path.join(self.fig_dir, 'fig_{}.png'.format(epoch)))

    def save_step(self, epoch):
        """Save the generator's state dict as ``checkpoint_dir/gen_weights_{epoch}.pth``."""
        torch.save(self.G.state_dict(), os.path.join(self.checkpoint_dir, 'gen_weights_{}.pth'.format(epoch)))


In [None]:
# Build the models and optimizers, then run the full training loop.
trainer = Trainer()
trainer.run()

# Part 4: Evaluation

In [None]:
#@title Load pretrained weights
!wget https://github.com/arash-mham/visual-computing-II/blob/main/labs/mnist_gan/gen_weights.pth?raw=true -O gen_weights.pth

In [None]:
#@title Generate samples using trained generator

device = 'cuda' if torch.cuda.is_available() else 'cpu'
G = Generator(noise_dim, 28 * 28).to(device)

# map_location lets a checkpoint saved on GPU load on CPU-only machines.
# The file comes from the internet: weights_only=True stops torch.load from
# unpickling arbitrary objects (requires torch >= 1.13).
# NOTE(review): the checkpoint's parameter names must match Generator's layer
# names, otherwise load_state_dict raises -- confirm.
G.load_state_dict(torch.load('gen_weights.pth', map_location=device, weights_only=True))
G.eval()

num_samples = 8
with torch.no_grad():
    noise = torch.randn(num_samples, noise_dim, device=device)
    samples = G(noise).view(num_samples, 28, 28)

# Map the [-1, 1] output range to [0, 255] pixels and tile the samples side by side.
samples = ((samples.clamp(-1, 1).cpu().numpy() + 1) * 127.5).astype('uint8')
grid = Image.fromarray(np.concatenate(list(samples), axis=1))
grid.save('samples.png')
grid  # last expression of the cell: Jupyter renders the PIL image inline