{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This demo shows how to train an MNIST classifier\n", "# on top of the latent codes of a pre-trained model.\n", "#\n", "# In this example, the classifier greatly overfits\n", "# after only a few seconds of training!\n", "# It achieves 100% training accuracy, but only\n", "# 93% test accuracy.\n", "#\n", "# Can you figure out how to improve it?" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from IPython.display import display\n", "from PIL import Image\n", "\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import torch.optim as optim\n", "from torchvision import datasets, transforms\n", "\n", "from vq_draw import MSELoss, Encoder, MNISTRefiner" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load pre-trained encoder model\n", "encoder = Encoder(shape=(1, 28, 28),\n", " options=64,\n", " refiner=MNISTRefiner(64, 10),\n", " loss_fn=MSELoss(),\n", " num_stages=10)\n", "encoder.load_state_dict(torch.load('pretrained/mnist_model.pt', map_location='cpu'))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Create raw MNIST datasets.\n", "ENCODE_BATCH = 512\n", "transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " transforms.Normalize((0.1307,), (0.3081,))\n", "])\n", "raw_train = datasets.MNIST('mnist_data', train=True, download=True, transform=transform)\n", "raw_test = datasets.MNIST('mnist_data', train=False, download=True, transform=transform)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def encode_dataset(raw_dataset):\n", " \"\"\"Convert a dataset into the encoder's latent space.\"\"\"\n", " loader = torch.utils.data.DataLoader(raw_dataset, batch_size=1000)\n", " latents = []\n", " labels = []\n", " for inputs, targets in loader:\n", " labels.append(targets)\n", " with torch.no_grad():\n", " latents.append(encoder(inputs)[0])\n", " print('encoded %d samples' % (len(latents) * 1000))\n", " return torch.utils.data.TensorDataset(torch.cat(latents), torch.cat(labels))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "encoded 1000 samples\n", "encoded 2000 samples\n", "encoded 3000 samples\n", "encoded 4000 samples\n", "encoded 5000 samples\n", "encoded 6000 samples\n", "encoded 7000 samples\n", "encoded 8000 samples\n", "encoded 9000 samples\n", "encoded 10000 samples\n", "encoded 11000 samples\n", "encoded 12000 samples\n", "encoded 13000 samples\n", "encoded 14000 samples\n", "encoded 15000 samples\n", "encoded 16000 samples\n", "encoded 17000 samples\n", "encoded 18000 samples\n", "encoded 19000 samples\n", "encoded 20000 samples\n", "encoded 21000 samples\n", "encoded 22000 samples\n", "encoded 23000 samples\n", "encoded 24000 samples\n", "encoded 25000 samples\n", "encoded 26000 samples\n", "encoded 27000 samples\n", "encoded 28000 samples\n", "encoded 29000 samples\n", "encoded 30000 samples\n", "encoded 31000 samples\n", "encoded 32000 samples\n", "encoded 33000 samples\n", "encoded 34000 samples\n", "encoded 35000 samples\n", "encoded 36000 samples\n", "encoded 37000 samples\n", "encoded 38000 samples\n", "encoded 39000 
samples\n", "encoded 40000 samples\n", "encoded 41000 samples\n", "encoded 42000 samples\n", "encoded 43000 samples\n", "encoded 44000 samples\n", "encoded 45000 samples\n", "encoded 46000 samples\n", "encoded 47000 samples\n", "encoded 48000 samples\n", "encoded 49000 samples\n", "encoded 50000 samples\n", "encoded 51000 samples\n", "encoded 52000 samples\n", "encoded 53000 samples\n", "encoded 54000 samples\n", "encoded 55000 samples\n", "encoded 56000 samples\n", "encoded 57000 samples\n", "encoded 58000 samples\n", "encoded 59000 samples\n", "encoded 60000 samples\n" ] } ], "source": [ "# Encoding the training dataset may take 10-20\n", "# minutes on a laptop CPU :(\n", "train_dataset = encode_dataset(raw_train)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "encoded 1000 samples\n", "encoded 2000 samples\n", "encoded 3000 samples\n", "encoded 4000 samples\n", "encoded 5000 samples\n", "encoded 6000 samples\n", "encoded 7000 samples\n", "encoded 8000 samples\n", "encoded 9000 samples\n", "encoded 10000 samples\n" ] } ], "source": [ "# Encoding the test dataset may take a few\n", "# minutes on a laptop CPU, but it's faster\n", "# than the training set!\n", "test_dataset = encode_dataset(raw_test)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100)\n", "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-------- Epoch 0 ---------\n", "train accuracy: 0.869517\n", "test accuracy: 0.924400\n", "-------- Epoch 1 ---------\n", "train accuracy: 0.934900\n", "test accuracy: 0.927000\n", "-------- Epoch 2 ---------\n", "train accuracy: 0.948783\n", "test accuracy: 0.927900\n", "-------- Epoch 3 ---------\n", "train accuracy: 0.962150\n", "test accuracy: 0.927100\n", "-------- Epoch 4 ---------\n", "train accuracy: 0.975883\n", "test accuracy: 0.927400\n", "-------- Epoch 5 ---------\n", "train accuracy: 0.987733\n", "test accuracy: 0.927000\n", "-------- Epoch 6 ---------\n", "train accuracy: 0.995667\n", "test accuracy: 0.927300\n", "-------- Epoch 7 ---------\n", "train accuracy: 0.999283\n", "test accuracy: 0.928200\n", "-------- Epoch 8 ---------\n", "train accuracy: 0.999950\n", "test accuracy: 0.932600\n", "-------- Epoch 9 ---------\n", "train accuracy: 1.000000\n", "test accuracy: 0.932400\n" ] } ], "source": [ "def latents_to_vec(latents):\n", " \"\"\"Convert a Long latent Tensor into one-hots.\"\"\"\n", " ones = torch.ones_like(latents[..., None]).float()\n", " result = torch.zeros(*latents.shape, 64, device=latents.device)\n", " result.scatter_(-1, latents[..., None], 1)\n", " result = result.view(result.shape[0], -1)\n", " # Scale result to have second moment of 1.\n", " return result * 8\n", "\n", "# Create our model, which is just a small MLP.\n", "classifier = nn.Sequential(\n", " # The input size expects 10 one-hots of size 64\n", " # each, packed together in a single vector.\n", " nn.Linear(64 * 10, 128),\n", " nn.ReLU(),\n", " nn.Linear(128, 10),\n", " nn.LogSoftmax(dim=-1),\n", ")\n", "\n", "optimizer = optim.Adam(classifier.parameters(), lr=1e-3)\n", "\n", "# Run the training loop!\n", "for epoch in range(10):\n", " print('-------- Epoch %d ---------' % epoch)\n", "\n", " total_train_correct = 0\n", " total_train_count 
= 0\n", " classifier.train()\n", " for inputs, targets in train_loader:\n", " preds = classifier(latents_to_vec(inputs))\n", " pred_labels = torch.argmax(preds, dim=-1)\n", " loss = F.nll_loss(preds, targets)\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " total_train_correct += (pred_labels == targets).sum().item()\n", " total_train_count += preds.shape[0]\n", " print('train accuracy: %f' % (total_train_correct / total_train_count))\n", "\n", " # Evaluate test loss.\n", " total_test_correct = 0\n", " total_test_count = 0\n", " classifier.eval()\n", " for inputs, targets in test_loader:\n", " with torch.no_grad():\n", " preds = classifier(latents_to_vec(inputs))\n", " pred_labels = torch.argmax(preds, dim=-1)\n", " total_test_correct += (pred_labels == targets).sum().item()\n", " total_test_count += preds.shape[0]\n", " print('test accuracy: %f' % (total_test_correct / total_test_count))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.3" } }, "nbformat": 4, "nbformat_minor": 2 }