In [None]:
# This demo shows how to train an MNIST classifier
# on top of the latent codes of a pre-trained model.
#
# In this example, the classifier greatly overfits
# after only a few seconds of training!
# It achieves 100% training accuracy, but only
# 93% test accuracy.
#
# Can you figure out how to improve it?

In [1]:
from IPython.display import display
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from vq_draw import MSELoss, Encoder, MNISTRefiner

In [2]:
# Load pre-trained encoder model.
# The architecture is rebuilt here and then filled with the saved weights, so
# these constructor arguments need to agree with whatever configuration the
# checkpoint was trained with.
# NOTE(review): options=64 / num_stages=10 presumably mean "10 latent codes
# with 64 choices each" (the classifier further down expects 10 one-hots of
# size 64) -- confirm against vq_draw.
# NOTE(review): the encoder is never switched to eval mode in this notebook;
# verify whether the refiner contains mode-dependent layers.
encoder = Encoder(shape=(1, 28, 28),
 options=64,
 refiner=MNISTRefiner(64, 10),
 loss_fn=MSELoss(),
 num_stages=10)
# map_location='cpu' remaps the stored tensors onto the CPU while loading,
# so the checkpoint can be read on a machine without a GPU.
encoder.load_state_dict(torch.load('pretrained/mnist_model.pt', map_location='cpu'))



In [4]:
# Build the raw (pixel-space) MNIST train and test splits.
# NOTE: ENCODE_BATCH is not referenced by any cell visible in this notebook.
ENCODE_BATCH = 512

# Turn each image into a tensor, then standardize it with the commonly used
# MNIST mean / standard deviation.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share the same root directory and preprocessing; only the
# `train` flag differs. The files are downloaded on first use.
raw_train, raw_test = (
    datasets.MNIST('mnist_data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [5]:
def encode_dataset(raw_dataset, batch_size=1000):
    """Convert a dataset into the encoder's latent space.

    Every batch is passed through the module-level ``encoder`` and the first
    element of its output is kept as the latent codes for that batch.

    Args:
        raw_dataset: dataset yielding ``(inputs, targets)`` pairs that a
            ``DataLoader`` can batch.
        batch_size: number of samples encoded per forward pass. Defaults to
            1000, the value that used to be hard-coded.

    Returns:
        A ``TensorDataset`` pairing the concatenated latents with the
        concatenated targets, in the original dataset order.
    """
    loader = torch.utils.data.DataLoader(raw_dataset, batch_size=batch_size)
    latents = []
    labels = []
    num_encoded = 0
    # No gradients are needed for encoding, so disable autograd for the
    # whole pass instead of re-entering the context on every batch.
    with torch.no_grad():
        for inputs, targets in loader:
            labels.append(targets)
            latents.append(encoder(inputs)[0])
            # Count the samples actually seen: the last batch may be smaller
            # than batch_size, so "batches * batch_size" would over-report.
            num_encoded += len(targets)
            print('encoded %d samples' % num_encoded)
    return torch.utils.data.TensorDataset(torch.cat(latents), torch.cat(labels))

In [6]:
# Push the whole training split through the encoder.
# Be patient: on a laptop CPU this can take 10-20 minutes :(
train_dataset = encode_dataset(raw_dataset=raw_train)

encoded 1000 samples
encoded 2000 samples
encoded 3000 samples
encoded 4000 samples
encoded 5000 samples
encoded 6000 samples
encoded 7000 samples
encoded 8000 samples
encoded 9000 samples
encoded 10000 samples
encoded 11000 samples
encoded 12000 samples
encoded 13000 samples
encoded 14000 samples
encoded 15000 samples
encoded 16000 samples
encoded 17000 samples
encoded 18000 samples
encoded 19000 samples
encoded 20000 samples
encoded 21000 samples
encoded 22000 samples
encoded 23000 samples
encoded 24000 samples
encoded 25000 samples
encoded 26000 samples
encoded 27000 samples
encoded 28000 samples
encoded 29000 samples
encoded 30000 samples
encoded 31000 samples
encoded 32000 samples
encoded 33000 samples
encoded 34000 samples
encoded 35000 samples
encoded 36000 samples
encoded 37000 samples
encoded 38000 samples
encoded 39000 samples
encoded 40000 samples
encoded 41000 samples
encoded 42000 samples
encoded 43000 samples
encoded 44000 samples
encoded 45000 samples
encoded 46000 sampl

In [7]:
# Push the test split through the encoder as well. This takes a few
# minutes on a laptop CPU, but it is faster than the training set!
test_dataset = encode_dataset(raw_dataset=raw_test)

encoded 1000 samples
encoded 2000 samples
encoded 3000 samples
encoded 4000 samples
encoded 5000 samples
encoded 6000 samples
encoded 7000 samples
encoded 8000 samples
encoded 9000 samples
encoded 10000 samples


In [8]:
# Reshuffle the training set at the start of every epoch so the optimizer
# does not see the exact same batches in the exact same order each time.
# The test loader keeps a fixed order: evaluation does not depend on it.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100)

In [9]:
def latents_to_vec(latents, num_options=64):
    """Convert a Long latent Tensor into one-hots.

    Args:
        latents: Long tensor of code indices whose first dimension is the
            batch; every value must lie in ``[0, num_options)``.
        num_options: length of each one-hot vector. Defaults to 64, the
            value that used to be hard-coded.

    Returns:
        A float tensor on the same device as ``latents`` with one row per
        batch element, holding that element's one-hots packed side by side
        and scaled by ``sqrt(num_options)``.
    """
    one_hots = torch.zeros(*latents.shape, num_options, device=latents.device)
    # Write a 1 at the position named by each code along the new last axis.
    one_hots.scatter_(-1, latents[..., None], 1)
    flat = one_hots.flatten(start_dim=1)
    # Scale result to have second moment of 1: a one-hot of length k has a
    # mean square of 1/k, so multiply by sqrt(k) (8 for the default k = 64).
    return flat * num_options ** 0.5

# The classifier is deliberately tiny: a two-layer MLP that emits
# log-probabilities over the ten digit classes.
classifier = nn.Sequential(
    # Each sample arrives as 10 one-hots of size 64 each,
    # concatenated into a single flat vector.
    nn.Linear(in_features=64 * 10, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=10),
    nn.LogSoftmax(dim=-1),
)

# Adam over every classifier parameter, with a learning rate of 1e-3.
optimizer = optim.Adam(params=classifier.parameters(), lr=1e-3)

# Train for ten epochs, reporting train and test accuracy after each one.
for epoch in range(10):
    print('-------- Epoch %d ---------' % epoch)

    # ---- Training pass ----
    # Accuracy is accumulated while the weights are still being updated, so
    # it reflects the model as it was when each batch was seen.
    classifier.train()
    total_train_correct = 0
    total_train_count = 0
    for inputs, targets in train_loader:
        preds = classifier(latents_to_vec(inputs))
        # The model already outputs log-probabilities, hence NLL loss.
        loss = F.nll_loss(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred_labels = preds.argmax(dim=-1)
        total_train_correct += torch.eq(pred_labels, targets).sum().item()
        total_train_count += targets.shape[0]
    print('train accuracy: %f' % (total_train_correct / total_train_count))

    # ---- Evaluation pass: accuracy on the held-out test set ----
    classifier.eval()
    total_test_correct = 0
    total_test_count = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            preds = classifier(latents_to_vec(inputs))
            pred_labels = preds.argmax(dim=-1)
            total_test_correct += torch.eq(pred_labels, targets).sum().item()
            total_test_count += targets.shape[0]
    print('test accuracy: %f' % (total_test_correct / total_test_count))

-------- Epoch 0 ---------
train accuracy: 0.869517
test accuracy: 0.924400
-------- Epoch 1 ---------
train accuracy: 0.934900
test accuracy: 0.927000
-------- Epoch 2 ---------
train accuracy: 0.948783
test accuracy: 0.927900
-------- Epoch 3 ---------
train accuracy: 0.962150
test accuracy: 0.927100
-------- Epoch 4 ---------
train accuracy: 0.975883
test accuracy: 0.927400
-------- Epoch 5 ---------
train accuracy: 0.987733
test accuracy: 0.927000
-------- Epoch 6 ---------
train accuracy: 0.995667
test accuracy: 0.927300
-------- Epoch 7 ---------
train accuracy: 0.999283
test accuracy: 0.928200
-------- Epoch 8 ---------
train accuracy: 0.999950
test accuracy: 0.932600
-------- Epoch 9 ---------
train accuracy: 1.000000
test accuracy: 0.932400
