In [None]:
# Execute this code block to install dependencies when running on colab
try:
 import torch
except:
 from os.path import exists
 from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
 platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
 cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
 accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'

 !pip install -q http://download.pytorch.org/whl/{accelerator}/torch-1.0.0-{platform}-linux_x86_64.whl torchvision

# Part 3: Variational Autoencoders (VAE)

For this part of the lab, you will implement a VAE. There is some code below which will help you run the learning procedure; however, you will need to complete the definition of the loss function. Start by implementing the encoder and decoder for the VAE. There is a diagram below illustrating what these architectures should look like.

Let's start by loading the Fashion-MNIST dataset again and transforming the data to a flattened tensor.

### Loading the Dataset

In [None]:
%matplotlib inline

import torchvision
import torchvision.transforms as transforms

batch_size = 256
image_dim = 784 #flattened

# dataset construction
transform = transforms.Compose([
 transforms.ToTensor(), # convert to tensor
 transforms.Lambda(lambda x: x.view(image_dim)) # flatten into vector
 ])

train_set = torchvision.datasets.FashionMNIST(
 root='./data/FashionMNIST'
 ,train=True
 ,download=True
 ,transform=transform
)

train_loader = torch.utils.data.DataLoader(
 train_set, batch_size=batch_size
)

## Build a Simple Variational Autoencoder



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.autonotebook import tqdm
from itertools import chain

import numpy as np

class Encoder(nn.Module):
    '''
    simple encoder with a single hidden dense layer (ReLU activation)
    and linear projections to the diag-Gauss parameters

    Args:
        input_dim: size of each (flattened) input vector.
        hidden_dim: number of units in the hidden layer.
        latent_dim: size of the latent code.

    forward(x) returns the pair (mu, log_sigma2), each of shape
    (batch, latent_dim): the mean and the log-variance of the diagonal
    Gaussian Q(z|x).
    '''
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        # Two separate linear heads share the same hidden representation.
        self.mu = nn.Linear(hidden_dim, latent_dim)
        self.log_sigma2 = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        h = F.relu(self.hidden(x))
        # The log-variance is left unconstrained (no activation) so that
        # exp(log_sigma2) is always a valid, positive variance.
        return self.mu(h), self.log_sigma2(h)

class Decoder(nn.Module):
    '''
    simple decoder: single dense hidden layer (ReLU activation) followed by
    output layer with a sigmoid to squish values

    Args:
        latent_dim: size of the latent code fed to the decoder.
        hidden_dim: number of units in the hidden layer.
        output_dim: size of each (flattened) output vector.

    forward(z) returns a tensor of shape (batch, output_dim) with every
    value in [0, 1].
    '''
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.hidden = nn.Linear(latent_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        h = F.relu(self.hidden(z))
        # Sigmoid keeps outputs in [0, 1], as binary cross-entropy requires.
        return torch.sigmoid(self.out(h))

## You may complete the code below to test your implementation or, alternatively, write your own.

#### Once you've trained the network, plot some reconstructions side-by-side with the original images and reflect on how good the reconstructions are (or aren't!). Also try generating some random images by sampling the prior and feeding the sampled vectors to the decoder.

In [None]:
import matplotlib.gridspec as gridspec 
import os

# Sampling function (using the reparameterisation trick)
def sample(mu, log_sigma2):
    '''
    Draw z ~ N(mu, diag(exp(log_sigma2))) using the reparameterisation trick.

    The randomness is isolated in eps ~ N(0, I), so the result stays
    differentiable with respect to both mu and log_sigma2.

    Args:
        mu: tensor of means.
        log_sigma2: tensor of log-variances, broadcastable with mu.

    Returns:
        mu + sigma * eps, where sigma = exp(log_sigma2 / 2).
    '''
    # randn_like matches mu's shape, dtype and device, so this works for
    # inputs of any rank and for tensors that do not live on the CPU.
    eps = torch.randn_like(mu)
    return mu + torch.exp(log_sigma2 / 2) * eps


# hyperparameters
batch_size = 256
embedding_dim = 2  # size of the code passed from the encoder to the decoder
enc_hidden_units = dec_hidden_units = 512
nEpoch = 10

# build both networks, then a single Adam optimiser covering all their weights
enc = Encoder(image_dim, enc_hidden_units, embedding_dim)
dec = Decoder(embedding_dim, dec_hidden_units, image_dim)
optimizer = optim.Adam([*enc.parameters(), *dec.parameters()], lr=1e-3)

# training loop: minimise the negative ELBO (reconstruction term + KL term)
for epoch in range(nEpoch):
    losses = []
    trainloader = tqdm(train_loader)

    for i, data in enumerate(trainloader, 0):
        inputs, _ = data  # labels are not needed for the VAE objective

        optimizer.zero_grad()

        mu, log_sigma2 = enc(inputs)
        z = sample(mu, log_sigma2)
        outputs = dec(z)

        # -E[log P(X|z)] - binary cross entropy between decoder output and input.
        # we need to be a little careful - by default torch averages over every observation
        # (e.g. each pixel in each image of each batch), whereas we want the average over entire
        # images instead, so sum over everything and divide by the batch size
        recon = F.binary_cross_entropy(outputs, inputs, reduction='sum') / inputs.shape[0]

        # kl = D_KL(Q(z|X) || P(z)) with prior P(z) = N(0, I), in closed form:
        #   -1/2 * sum(1 + log sigma^2 - mu^2 - sigma^2)
        # summed over latent dimensions and averaged over the batch so that it
        # is on the same per-image scale as the reconstruction term.
        kl = -0.5 * torch.sum(1 + log_sigma2 - mu.pow(2) - log_sigma2.exp()) / inputs.shape[0]

        loss = recon + kl
        loss.backward()
        optimizer.step()

        # keep track of the loss and update the stats
        losses.append(loss.item())
        trainloader.set_postfix(loss=np.mean(losses), epoch=epoch)


## Display originals, their reconstructions, and images generated from the prior
import matplotlib.pyplot as plt

n_show = 8
side = int(round(image_dim ** 0.5))  # images are flattened squares

# No gradients are needed for visualisation.
with torch.no_grad():
    originals, _ = next(iter(train_loader))
    originals = originals[:n_show]
    # Decode the posterior mean rather than a random sample of it.
    mu_show, _ = enc(originals)
    reconstructions = dec(mu_show)
    # Generate new images by decoding draws from the prior N(0, I).
    generated = dec(torch.randn(n_show, embedding_dim))

fig, axes = plt.subplots(3, n_show, figsize=(1.5 * n_show, 4.5))
rows = (
    (originals, 'original'),
    (reconstructions, 'reconstruction'),
    (generated, 'prior sample'),
)
for row_axes, (images, label) in zip(axes, rows):
    for ax, image in zip(row_axes, images):
        ax.imshow(image.view(side, side).numpy(), cmap='gray')
        ax.axis('off')
    row_axes[0].set_title(label, fontsize=8, loc='left')
plt.show()