{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iOyOci4Ki_xk" }, "source": [ "# Training a Variational Auto-Encoder\n", "\n", "This guide gives a quick walkthrough of training a variational auto-encoder (VAE) in torchbearer. We will use the VAE example from the PyTorch examples [here](https://github.com/pytorch/examples/tree/master/vae).\n", "\n", "We will compare the implementations of a standard VAE and one that uses torchbearer's persistent state.\n", "\n", "**Note**: The easiest way to use this tutorial is as a colab notebook, which allows you to dive in with no setup. We recommend you enable a free GPU with\n", "\n", "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n", "\n", "## Install Torchbearer\n", "\n", "First we install torchbearer if needed." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.3.2\n" ] } ], "source": [ "try:\n", "    import torchbearer\n", "except ImportError:\n", "    !pip install -q torchbearer\n", "    import torchbearer\n", "\n", "print(torchbearer.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ZvXLMbXYj8Pd" }, "source": [ "## Defining the Models\n", "\n", "First, we define the standard PyTorch VAE. 
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": {}, "colab_type": "code", "id": "UJUVUXLHjN-4" }, "outputs": [], "source": [ "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "class VAE(nn.Module):\n", " def __init__(self):\n", " super(VAE, self).__init__()\n", "\n", " self.fc1 = nn.Linear(784, 400)\n", " self.fc21 = nn.Linear(400, 20)\n", " self.fc22 = nn.Linear(400, 20)\n", " self.fc3 = nn.Linear(20, 400)\n", " self.fc4 = nn.Linear(400, 784)\n", "\n", " def encode(self, x):\n", " h1 = F.relu(self.fc1(x))\n", " return self.fc21(h1), self.fc22(h1)\n", "\n", " def reparameterize(self, mu, logvar):\n", " if self.training:\n", " std = torch.exp(0.5*logvar)\n", " eps = torch.randn_like(std)\n", " return eps.mul(std).add_(mu)\n", " else:\n", " return mu\n", "\n", " def decode(self, z):\n", " h3 = F.relu(self.fc3(z))\n", " return torch.sigmoid(self.fc4(h3)).view(-1, 1, 28, 28)\n", "\n", " def forward(self, x):\n", " mu, logvar = self.encode(x.view(-1, 784))\n", " z = self.reparameterize(mu, logvar)\n", " return self.decode(z), mu, logvar" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HJZY0Z_4jsZs" }, "source": [ "Now lets modify this to use torchbearers state by overriding the forward method. Here, we define some state keys with the [`state_key` method](https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.state.state_key) which will store our `MU` and `LOGVAR`." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": {}, "colab_type": "code", "id": "MlsxPtlujwiy" }, "outputs": [], "source": [ "import torchbearer\n", "\n", "# Define state keys for storing things in torchbearer's state\n", "MU, LOGVAR = torchbearer.state_key('mu'), torchbearer.state_key('logvar')\n", "\n", "\n", "class TorchbearerVAE(VAE):\n", "    def forward(self, x, state):\n", "        mu, logvar = self.encode(x.view(-1, 784))\n", "        z = self.reparameterize(mu, logvar)\n", "        state[MU], state[LOGVAR] = mu, logvar\n", "        return self.decode(z)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "JZevN3eqk_pS" }, "source": [ "There is very little difference between these models, except that for torchbearer's VAE we store the mean and log-variance in state instead of outputting them. This allows us to access them from within callbacks as well as for the loss.\n", "\n", "## Defining the Loss Functions\n", "\n", "Let's now look at loss functions for these models. First we see the standard VAE loss function." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": {}, "colab_type": "code", "id": "X3FRGq-llg7a" }, "outputs": [], "source": [ "def binary_cross_entropy(y_pred, y_true):\n", "    BCE = F.binary_cross_entropy(y_pred.view(-1, 784), y_true.view(-1, 784), reduction='sum')\n", "    return BCE\n", "\n", "def kld(mu, logvar):\n", "    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n", "    return KLD\n", "\n", "def loss_function(y_pred, y_true):\n", "    recon_x, mu, logvar = y_pred\n", "    x = y_true\n", "\n", "    BCE = binary_cross_entropy(recon_x, x)\n", "\n", "    KLD = kld(mu, logvar)\n", "\n", "    return BCE + KLD" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "AOqmSTCqly0C" }, "source": [ "In Torchbearer we have a couple of options for how to define this loss. 
Since Torchbearer loss functions can either be a function of (y_pred, y_true) or (state), we could actually write the standard loss function as a function of state and grab the mean and log-variance from there. \n", "Instead we shall showcase the `add_to_loss` callback decorator to add the KL loss, alongside a base reconstruction loss." ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": {}, "colab_type": "code", "id": "4Pc86tbWmw53" }, "outputs": [], "source": [ "main_loss = binary_cross_entropy\n", "\n", "@torchbearer.callbacks.add_to_loss\n", "def add_kld_loss_callback(state):\n", "    KLD = kld(state[MU], state[LOGVAR])\n", "    return KLD" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HED6rD8znOtE" }, "source": [ "## Data\n", "\n", "Both of these models need data in the same format, so let's define our data now. We create a simple dataset class to wrap the PyTorch MNIST dataset so that we can replace the target (usually a class label) with the input image. As in the [quickstart example](https://torchbearer.readthedocs.io/en/latest/examples/notebooks.html#notebooks-list), we use the [`DatasetValidationSplitter`](https://torchbearer.readthedocs.io/en/latest/code/main.html#torchbearer.cv_utils.DatasetValidationSplitter) here to obtain a validation set." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 269 }, "colab_type": "code", "id": "0KcT1A1tnL3-", "outputId": "d069c20d-9561-45d1-af3d-53fab145fa82" }, "outputs": [], "source": [ "import torch\n", "from torch.utils.data.dataset import Dataset\n", "import torchvision\n", "from torchvision import transforms\n", "\n", "from torchbearer.cv_utils import DatasetValidationSplitter\n", "\n", "class AutoEncoderMNIST(Dataset):\n", "    def __init__(self, mnist_dataset):\n", "        super().__init__()\n", "        self.mnist_dataset = mnist_dataset\n", "\n", "    def __getitem__(self, index):\n", "        character, label = self.mnist_dataset[index]\n", "        return character, character\n", "\n", "    def __len__(self):\n", "        return len(self.mnist_dataset)\n", "\n", "\n", "BATCH_SIZE = 128\n", "\n", "transform = transforms.Compose([transforms.ToTensor()])\n", "\n", "# Define standard classification MNIST dataset with random validation set\n", "\n", "dataset = torchvision.datasets.MNIST('./data/mnist', train=True, download=True, transform=transform)\n", "splitter = DatasetValidationSplitter(len(dataset), 0.1)\n", "basetrainset = splitter.get_train_dataset(dataset)\n", "basevalset = splitter.get_val_dataset(dataset)\n", "basetestset = torchvision.datasets.MNIST('./data/mnist', train=False, download=True, transform=transform)\n", "\n", "# Wrap base classification MNIST dataset to return the image as the target\n", "\n", "trainset = AutoEncoderMNIST(basetrainset)\n", "\n", "valset = AutoEncoderMNIST(basevalset)\n", "\n", "testset = AutoEncoderMNIST(basetestset)\n", "\n", "traingen = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)\n", "\n", "valgen = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)\n", "\n", "testgen = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)" ] }, { "cell_type": "markdown", 
"metadata": { "colab_type": "text", "id": "dQ4P8re5onL0" }, "source": [ "## Visualising the Model\n", "\n", "For auto-encoding problems it is often useful to visualise the reconstructions. We can do this in torchbearer by using the [`MakeGrid` callback](https://torchbearer.readthedocs.io/en/latest/code/callbacks.html#torchbearer.callbacks.imaging.imaging.MakeGrid), from the [`imaging` sub-package](https://torchbearer.readthedocs.io/en/latest/code/callbacks.html#module-torchbearer.callbacks.imaging). This is an [`ImagingCallback`](https://torchbearer.readthedocs.io/en/latest/code/callbacks.html#torchbearer.callbacks.imaging.imaging.ImagingCallback) which uses torchvisions [save_image](https://pytorch.org/docs/stable/torchvision/utils.html?highlight=save#torchvision.utils.save_image) to make a grid of images." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": {}, "colab_type": "code", "id": "1t8Ks-l_oasX" }, "outputs": [], "source": [ "from torchbearer.callbacks import imaging\n", "\n", "targets = imaging.MakeGrid(torchbearer.TARGET, num_images=64, nrow=8)\n", "targets = targets.on_test().to_pyplot().to_file('targets.png')\n", "\n", "predictions = imaging.MakeGrid(torchbearer.PREDICTION, num_images=64, nrow=8)\n", "predictions = predictions.on_test().to_pyplot().to_file('predictions.png')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Q5Ohha9SpzbI" }, "source": [ "In the above code we create two callbacks, one which makes a grid of target images, another which makes a grid of predictions. These will be saved to a file and plotted with pyplot.\n", "\n", "## Training the Model\n", "\n", "Now lets train the model. We shall skip training the standard PyTorch model since that is covered in the PyTorch examples. To train our Torchbearer model we first create a `Trial` and then call `run` on it. Along the way we add some metrics to be displayed and add our visualisation callback. 
**Note**: We set `verbose=1` here so that the progress bar only ticks once per epoch (rather than once per batch, which creates a lot of output). This can be set at the trial level or for each call to `run` or `evaluate`." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1917, "resources": { "http://localhost:8080/nbextensions/google.colab/colabwidgets/controls.css": { "data": "", "headers": [ [ "content-type", "text/css" ] ], "ok": true, "status": 200, "status_text": "" } } }, "colab_type": "code", "id": "gFxV7zVHpgwl", "outputId": "02be4d59-ddd5-4e94-b66e-2a5e96e85d19" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "44c14d513b2b4267829eef6680c343c1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, max=10), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "import torch.optim as optim\n", "from torchbearer import Trial\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "model = TorchbearerVAE()\n", "optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)\n", "\n", "# main_loss is the reconstruction loss; the KL term is added by add_kld_loss_callback\n", "trial = Trial(model, optimizer, main_loss, metrics=['acc', 'loss'],\n", "              callbacks=[add_kld_loss_callback, predictions, targets]).to(device)\n", "trial.with_generators(train_generator=traingen, val_generator=valgen, test_generator=testgen)\n", "_ = trial.run(epochs=10, verbose=1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "dz-yoPgjq4KP" }, "source": [ "We now evaluate on the test data to see how well our model performed." ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 79 }, "colab_type": "code", "id": "9S7Mbsmnqh3k", "outputId": 
"6c11dc9e-2042-4f2b-ffd5-175f8f1dcf59" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4e92f17308d648bf8ce744b7b98309df", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(IntProgress(value=0, description='0/1(e)', max=79), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n" ] }, { "data": { "text/plain": [ "{'test_binary_acc': 0.9711461663246155, 'test_loss': 12219.810546875}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trial.evaluate(data_key=torchbearer.TEST_DATA)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "dI7EufU8rH-7" }, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "name": "VAE.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 1 }