{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Conditional Variational Autoencoder \n", "\n", "## (without labels in reconstruction loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A simple conditional variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 35-dimensional latent vector representation.\n", "\n", "\n", "This implementation DOES NOT concatenate the inputs with the class labels when computing the reconstruction loss, in contrast to how it is commonly done in non-convolutional conditional variational autoencoders. Not considering class-labels in the reconstruction loss leads to slightly worse results compared to the implementation that does concatenate the labels with the inputs to compute the reconstruction loss. 
For reference, see the implementation [./autoencoder-cvae.ipynb](./autoencoder-cvae.ipynb)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
"import time\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import DataLoader\n",
"from torchvision import datasets\n",
"from torchvision import transforms\n",
"\n",
"\n",
"# Ask cuDNN for deterministic kernels so GPU runs are repeatable\n",
"if torch.cuda.is_available():\n",
"    torch.backends.cudnn.deterministic = True" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:1\n", "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [
"##########################\n",
"### SETTINGS\n",
"##########################\n",
"\n",
"# Device\n",
"# Use the default CUDA device when a GPU is present and fall back to the CPU\n",
"# otherwise. A hard-coded index such as \"cuda:1\" fails on single-GPU machines;\n",
"# pick a specific GPU with the CUDA_VISIBLE_DEVICES environment variable.\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print('Device:', device)\n",
"\n",
"# Hyperparameters\n",
"random_seed = 0\n",
"learning_rate = 0.001\n",
"num_epochs = 50\n",
"batch_size = 128\n",
"\n",
"# Architecture\n",
"num_classes = 10\n",
"num_features = 784  # 28*28 pixels per flattened image\n",
"num_hidden_1 = 500\n",
"num_latent = 35\n",
"\n",
"\n",
"##########################\n",
"### MNIST DATASET\n",
"##########################\n",
"\n",
"# Note transforms.ToTensor() scales input images\n",
"# to 0-1 range\n",
"train_dataset = datasets.MNIST(root='data',\n",
"                               train=True,\n",
"                               transform=transforms.ToTensor(),\n",
"                               download=True)\n",
"\n",
"test_dataset = datasets.MNIST(root='data',\n",
"                              train=False,\n",
"                              transform=transforms.ToTensor())\n",
"\n",
"\n",
"train_loader = DataLoader(dataset=train_dataset,\n",
"                          batch_size=batch_size,\n",
"                          shuffle=True)\n",
"\n",
"test_loader = DataLoader(dataset=test_dataset,\n",
"                         batch_size=batch_size,\n",
"                         shuffle=False)\n",
"\n",
"# Checking the dataset\n",
"for images, labels in train_loader: 
\n",
"    print('Image batch dimensions:', images.shape)\n",
"    print('Image label dimensions:', labels.shape)\n",
"    break" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
"##########################\n",
"### MODEL\n",
"##########################\n",
"\n",
"\n",
"def to_onehot(labels, num_classes, device):\n",
"    \"\"\"One-hot encode integer class labels.\n",
"\n",
"    Args:\n",
"        labels: integer (int64) tensor of class indices in [0, num_classes);\n",
"            it is flattened to one label per output row.\n",
"        num_classes: width of the one-hot encoding.\n",
"        device: device on which the result is created.\n",
"\n",
"    Returns:\n",
"        Float tensor of shape (number of labels, num_classes) holding a\n",
"        single 1 per row.\n",
"    \"\"\"\n",
"    # Allocate directly on the target device (no CPU round trip) and make\n",
"    # sure the indices live on that device as well before scattering.\n",
"    labels = labels.to(device)\n",
"    labels_onehot = torch.zeros(labels.size(0), num_classes, device=device)\n",
"    labels_onehot.scatter_(1, labels.view(-1, 1), 1)\n",
"\n",
"    return labels_onehot\n",
"\n",
"\n",
"class ConditionalVariationalAutoencoder(torch.nn.Module):\n",
"    \"\"\"Fully connected conditional VAE.\n",
"\n",
"    The one-hot class label is concatenated to the encoder input and to the\n",
"    latent code given to the decoder; the decoder outputs image pixels only.\n",
"    Tensors created inside the model follow the device of its inputs, so the\n",
"    module does not depend on a notebook-level `device` variable.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_features, num_hidden_1, num_latent, num_classes):\n",
"        super(ConditionalVariationalAutoencoder, self).__init__()\n",
"\n",
"        self.num_classes = num_classes\n",
"\n",
"        ### ENCODER\n",
"        self.hidden_1 = torch.nn.Linear(num_features+num_classes, num_hidden_1)\n",
"        self.z_mean = torch.nn.Linear(num_hidden_1, num_latent)\n",
"        # A linear layer can output negative numbers, which are not valid\n",
"        # variances and would break the log in the KL term. The layer therefore\n",
"        # predicts the log-variance; the variance or std_dev is recovered with\n",
"        # an exponential where needed.\n",
"        self.z_log_var = torch.nn.Linear(num_hidden_1, num_latent)\n",
"\n",
"        ### DECODER\n",
"        self.linear_3 = torch.nn.Linear(num_latent+num_classes, num_hidden_1)\n",
"        # The output layer has num_features units only: the labels are not\n",
"        # part of the reconstruction in this notebook variant.\n",
"        self.linear_4 = torch.nn.Linear(num_hidden_1, num_features)\n",
"\n",
"    def reparameterize(self, z_mu, z_log_var):\n",
"        \"\"\"Sample z = mu + eps * std_dev with eps drawn from N(0, I).\"\"\"\n",
"        # randn_like creates eps with the shape, dtype and device of z_mu,\n",
"        # so no global device is needed.\n",
"        eps = torch.randn_like(z_mu)\n",
"        # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev\n",
"        # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)\n",
"        z = z_mu + eps * torch.exp(z_log_var/2.)\n",
"        return z\n",
"\n",
"    def encoder(self, features, targets):\n",
"        \"\"\"Map (features, targets) to (z_mean, z_log_var, sampled latent code).\"\"\"\n",
"        ### Add condition\n",
"        onehot_targets = to_onehot(targets, self.num_classes, features.device)\n",
"        x = torch.cat((features, onehot_targets), dim=1)\n",
"\n",
"        ### ENCODER\n",
"        x = self.hidden_1(x)\n",
"        x = F.leaky_relu(x)\n",
"        z_mean = self.z_mean(x)\n",
"        z_log_var = self.z_log_var(x)\n",
"        encoded = self.reparameterize(z_mean, z_log_var)\n",
"        return z_mean, z_log_var, encoded\n",
"\n",
"    def decoder(self, encoded, targets):\n",
"        \"\"\"Decode a latent code, conditioned on targets, to pixel values in (0, 1).\"\"\"\n",
"        ### Add condition\n",
"        onehot_targets = to_onehot(targets, self.num_classes, encoded.device)\n",
"        encoded = torch.cat((encoded, onehot_targets), dim=1)\n",
"\n",
"        ### DECODER\n",
"        x = self.linear_3(encoded)\n",
"        x = F.leaky_relu(x)\n",
"        x = self.linear_4(x)\n",
"        decoded = torch.sigmoid(x)\n",
"        return decoded\n",
"\n",
"    def forward(self, features, targets):\n",
"        \"\"\"Return (z_mean, z_log_var, encoded, decoded) for a batch.\"\"\"\n",
"        z_mean, z_log_var, encoded = self.encoder(features, targets)\n",
"        decoded = self.decoder(encoded, targets)\n",
"\n",
"        return z_mean, z_log_var, encoded, decoded\n",
"\n",
"\n",
"# Seed before constructing the model so the initial weights are reproducible\n",
"torch.manual_seed(random_seed)\n",
"model = ConditionalVariationalAutoencoder(num_features,\n",
"                                          num_hidden_1,\n",
"                                          num_latent,\n",
"                                          num_classes)\n",
"model = model.to(device)\n",
"\n",
"\n",
"##########################\n",
"### COST AND OPTIMIZER\n",
"##########################\n",
"\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/469 | Cost: 70192.2188\n", "Epoch: 001/050 | Batch 050/469 | Cost: 25749.1211\n", "Epoch: 001/050 | Batch 100/469 | Cost: 21997.0703\n", "Epoch: 001/050 | Batch 150/469 | Cost: 19867.7832\n", "Epoch: 001/050 | Batch 200/469 | Cost: 19859.0391\n", "Epoch: 001/050 | Batch 250/469 | Cost: 18637.0820\n", "Epoch: 001/050 | Batch 300/469 | Cost: 17863.7227\n", "Epoch: 001/050 | Batch 350/469 | Cost: 16946.2188\n", "Epoch: 001/050 | Batch 400/469 | Cost: 16602.0566\n", "Epoch: 001/050 | Batch 450/469 | Cost: 16563.0742\n", "Epoch: 002/050 | Batch 000/469 | Cost: 15998.6934\n", "Epoch: 002/050 | Batch 050/469 | Cost: 15980.9590\n", "Epoch: 002/050 | Batch 100/469 | Cost: 15842.0508\n", "Epoch: 002/050 | Batch 150/469 | Cost: 15789.3096\n", "Epoch: 002/050 | Batch 200/469 | Cost: 15471.0068\n", "Epoch: 002/050 | Batch 250/469 | Cost: 15047.9609\n", "Epoch: 002/050 | Batch 300/469 | Cost: 14812.4609\n", "Epoch: 002/050 | Batch 350/469 | Cost: 15064.4570\n", "Epoch: 002/050 | Batch 400/469 | Cost: 14684.2832\n", "Epoch: 002/050 | Batch 450/469 | Cost: 14621.3662\n", "Epoch: 003/050 | Batch 000/469 | Cost: 14662.3740\n", "Epoch: 003/050 | Batch 050/469 | Cost: 14373.9258\n", "Epoch: 003/050 | Batch 100/469 | Cost: 14580.2539\n", "Epoch: 003/050 | Batch 150/469 | Cost: 14757.9639\n", "Epoch: 003/050 | Batch 200/469 | Cost: 14678.1953\n", "Epoch: 003/050 | Batch 250/469 | Cost: 
14471.2031\n", "Epoch: 003/050 | Batch 300/469 | Cost: 14082.8926\n", "Epoch: 003/050 | Batch 350/469 | Cost: 14371.0566\n", "Epoch: 003/050 | Batch 400/469 | Cost: 13371.3496\n", "Epoch: 003/050 | Batch 450/469 | Cost: 14689.0518\n", "Epoch: 004/050 | Batch 000/469 | Cost: 13579.7764\n", "Epoch: 004/050 | Batch 050/469 | Cost: 14012.3291\n", "Epoch: 004/050 | Batch 100/469 | Cost: 13688.7236\n", "Epoch: 004/050 | Batch 150/469 | Cost: 13726.9590\n", "Epoch: 004/050 | Batch 200/469 | Cost: 13938.3027\n", "Epoch: 004/050 | Batch 250/469 | Cost: 14040.9287\n", "Epoch: 004/050 | Batch 300/469 | Cost: 13934.2998\n", "Epoch: 004/050 | Batch 350/469 | Cost: 13701.2197\n", "Epoch: 004/050 | Batch 400/469 | Cost: 13695.0098\n", "Epoch: 004/050 | Batch 450/469 | Cost: 13170.8828\n", "Epoch: 005/050 | Batch 000/469 | Cost: 13500.4805\n", "Epoch: 005/050 | Batch 050/469 | Cost: 13655.4971\n", "Epoch: 005/050 | Batch 100/469 | Cost: 13458.8867\n", "Epoch: 005/050 | Batch 150/469 | Cost: 13754.9385\n", "Epoch: 005/050 | Batch 200/469 | Cost: 13421.4209\n", "Epoch: 005/050 | Batch 250/469 | Cost: 13213.7803\n", "Epoch: 005/050 | Batch 300/469 | Cost: 12693.9590\n", "Epoch: 005/050 | Batch 350/469 | Cost: 13030.9766\n", "Epoch: 005/050 | Batch 400/469 | Cost: 13811.0107\n", "Epoch: 005/050 | Batch 450/469 | Cost: 14092.8613\n", "Epoch: 006/050 | Batch 000/469 | Cost: 13308.8340\n", "Epoch: 006/050 | Batch 050/469 | Cost: 13082.6172\n", "Epoch: 006/050 | Batch 100/469 | Cost: 13904.7197\n", "Epoch: 006/050 | Batch 150/469 | Cost: 13171.1230\n", "Epoch: 006/050 | Batch 200/469 | Cost: 13504.8125\n", "Epoch: 006/050 | Batch 250/469 | Cost: 13535.9785\n", "Epoch: 006/050 | Batch 300/469 | Cost: 13284.6123\n", "Epoch: 006/050 | Batch 350/469 | Cost: 13442.9844\n", "Epoch: 006/050 | Batch 400/469 | Cost: 13444.6738\n", "Epoch: 006/050 | Batch 450/469 | Cost: 13511.4160\n", "Epoch: 007/050 | Batch 000/469 | Cost: 13908.5977\n", "Epoch: 007/050 | Batch 050/469 | Cost: 13451.8223\n", 
"Epoch: 007/050 | Batch 100/469 | Cost: 13165.7402\n", "Epoch: 007/050 | Batch 150/469 | Cost: 13328.9355\n", "Epoch: 007/050 | Batch 200/469 | Cost: 12998.3633\n", "Epoch: 007/050 | Batch 250/469 | Cost: 13683.5605\n", "Epoch: 007/050 | Batch 300/469 | Cost: 13152.0830\n", "Epoch: 007/050 | Batch 350/469 | Cost: 13329.2920\n", "Epoch: 007/050 | Batch 400/469 | Cost: 13330.4443\n", "Epoch: 007/050 | Batch 450/469 | Cost: 13523.7051\n", "Epoch: 008/050 | Batch 000/469 | Cost: 13326.9102\n", "Epoch: 008/050 | Batch 050/469 | Cost: 12950.0498\n", "Epoch: 008/050 | Batch 100/469 | Cost: 13583.9219\n", "Epoch: 008/050 | Batch 150/469 | Cost: 12776.9805\n", "Epoch: 008/050 | Batch 200/469 | Cost: 12999.6387\n", "Epoch: 008/050 | Batch 250/469 | Cost: 13324.9883\n", "Epoch: 008/050 | Batch 300/469 | Cost: 13418.3408\n", "Epoch: 008/050 | Batch 350/469 | Cost: 13043.9551\n", "Epoch: 008/050 | Batch 400/469 | Cost: 13222.0293\n", "Epoch: 008/050 | Batch 450/469 | Cost: 12615.9102\n", "Epoch: 009/050 | Batch 000/469 | Cost: 13419.1953\n", "Epoch: 009/050 | Batch 050/469 | Cost: 13427.7383\n", "Epoch: 009/050 | Batch 100/469 | Cost: 12853.0498\n", "Epoch: 009/050 | Batch 150/469 | Cost: 13082.6865\n", "Epoch: 009/050 | Batch 200/469 | Cost: 12926.8877\n", "Epoch: 009/050 | Batch 250/469 | Cost: 13223.6982\n", "Epoch: 009/050 | Batch 300/469 | Cost: 12966.2803\n", "Epoch: 009/050 | Batch 350/469 | Cost: 12672.2607\n", "Epoch: 009/050 | Batch 400/469 | Cost: 13285.1992\n", "Epoch: 009/050 | Batch 450/469 | Cost: 12638.7812\n", "Epoch: 010/050 | Batch 000/469 | Cost: 13139.2168\n", "Epoch: 010/050 | Batch 050/469 | Cost: 12674.6816\n", "Epoch: 010/050 | Batch 100/469 | Cost: 13080.3828\n", "Epoch: 010/050 | Batch 150/469 | Cost: 12448.9199\n", "Epoch: 010/050 | Batch 200/469 | Cost: 12761.3613\n", "Epoch: 010/050 | Batch 250/469 | Cost: 13139.7520\n", "Epoch: 010/050 | Batch 300/469 | Cost: 12969.9932\n", "Epoch: 010/050 | Batch 350/469 | Cost: 12518.5615\n", "Epoch: 010/050 | 
Batch 400/469 | Cost: 13042.4551\n", "Epoch: 010/050 | Batch 450/469 | Cost: 13296.8926\n", "Epoch: 011/050 | Batch 000/469 | Cost: 13233.5322\n", "Epoch: 011/050 | Batch 050/469 | Cost: 13085.0918\n", "Epoch: 011/050 | Batch 100/469 | Cost: 12664.7422\n", "Epoch: 011/050 | Batch 150/469 | Cost: 13344.7686\n", "Epoch: 011/050 | Batch 200/469 | Cost: 12498.5938\n", "Epoch: 011/050 | Batch 250/469 | Cost: 13314.7920\n", "Epoch: 011/050 | Batch 300/469 | Cost: 13175.9463\n", "Epoch: 011/050 | Batch 350/469 | Cost: 13034.4180\n", "Epoch: 011/050 | Batch 400/469 | Cost: 12425.1221\n", "Epoch: 011/050 | Batch 450/469 | Cost: 12548.4668\n", "Epoch: 012/050 | Batch 000/469 | Cost: 12779.4053\n", "Epoch: 012/050 | Batch 050/469 | Cost: 13129.1328\n", "Epoch: 012/050 | Batch 100/469 | Cost: 12274.7422\n", "Epoch: 012/050 | Batch 150/469 | Cost: 13289.4688\n", "Epoch: 012/050 | Batch 200/469 | Cost: 13256.5312\n", "Epoch: 012/050 | Batch 250/469 | Cost: 12437.4629\n", "Epoch: 012/050 | Batch 300/469 | Cost: 12500.7627\n", "Epoch: 012/050 | Batch 350/469 | Cost: 13362.0430\n", "Epoch: 012/050 | Batch 400/469 | Cost: 13271.1768\n", "Epoch: 012/050 | Batch 450/469 | Cost: 13070.1992\n", "Epoch: 013/050 | Batch 000/469 | Cost: 12979.0723\n", "Epoch: 013/050 | Batch 050/469 | Cost: 12714.0527\n", "Epoch: 013/050 | Batch 100/469 | Cost: 12925.5879\n", "Epoch: 013/050 | Batch 150/469 | Cost: 13068.3555\n", "Epoch: 013/050 | Batch 200/469 | Cost: 12462.0791\n", "Epoch: 013/050 | Batch 250/469 | Cost: 12443.1250\n", "Epoch: 013/050 | Batch 300/469 | Cost: 12773.1631\n", "Epoch: 013/050 | Batch 350/469 | Cost: 12435.6836\n", "Epoch: 013/050 | Batch 400/469 | Cost: 12659.2441\n", "Epoch: 013/050 | Batch 450/469 | Cost: 12680.4297\n", "Epoch: 014/050 | Batch 000/469 | Cost: 12963.3291\n", "Epoch: 014/050 | Batch 050/469 | Cost: 12406.1680\n", "Epoch: 014/050 | Batch 100/469 | Cost: 13342.7998\n", "Epoch: 014/050 | Batch 150/469 | Cost: 13050.4004\n", "Epoch: 014/050 | Batch 200/469 | 
Cost: 12695.7129\n", "Epoch: 014/050 | Batch 250/469 | Cost: 12899.9678\n", "Epoch: 014/050 | Batch 300/469 | Cost: 12568.9746\n", "Epoch: 014/050 | Batch 350/469 | Cost: 12800.3164\n", "Epoch: 014/050 | Batch 400/469 | Cost: 12908.6758\n", "Epoch: 014/050 | Batch 450/469 | Cost: 13055.2197\n", "Epoch: 015/050 | Batch 000/469 | Cost: 12697.0527\n", "Epoch: 015/050 | Batch 050/469 | Cost: 13206.3340\n", "Epoch: 015/050 | Batch 100/469 | Cost: 12505.6865\n", "Epoch: 015/050 | Batch 150/469 | Cost: 12765.6504\n", "Epoch: 015/050 | Batch 200/469 | Cost: 12692.5625\n", "Epoch: 015/050 | Batch 250/469 | Cost: 12564.1904\n", "Epoch: 015/050 | Batch 300/469 | Cost: 12480.1055\n", "Epoch: 015/050 | Batch 350/469 | Cost: 12703.9590\n", "Epoch: 015/050 | Batch 400/469 | Cost: 12782.6943\n", "Epoch: 015/050 | Batch 450/469 | Cost: 12501.6982\n", "Epoch: 016/050 | Batch 000/469 | Cost: 12316.8369\n", "Epoch: 016/050 | Batch 050/469 | Cost: 12879.1367\n", "Epoch: 016/050 | Batch 100/469 | Cost: 12799.4814\n", "Epoch: 016/050 | Batch 150/469 | Cost: 13116.8818\n", "Epoch: 016/050 | Batch 200/469 | Cost: 12788.3652\n", "Epoch: 016/050 | Batch 250/469 | Cost: 12618.8379\n", "Epoch: 016/050 | Batch 300/469 | Cost: 13378.3730\n", "Epoch: 016/050 | Batch 350/469 | Cost: 12751.9121\n", "Epoch: 016/050 | Batch 400/469 | Cost: 12654.6123\n", "Epoch: 016/050 | Batch 450/469 | Cost: 12693.1211\n", "Epoch: 017/050 | Batch 000/469 | Cost: 13261.9746\n", "Epoch: 017/050 | Batch 050/469 | Cost: 13040.6025\n", "Epoch: 017/050 | Batch 100/469 | Cost: 12892.2832\n", "Epoch: 017/050 | Batch 150/469 | Cost: 12776.0957\n", "Epoch: 017/050 | Batch 200/469 | Cost: 12676.0645\n", "Epoch: 017/050 | Batch 250/469 | Cost: 13100.1250\n", "Epoch: 017/050 | Batch 300/469 | Cost: 12229.5742\n", "Epoch: 017/050 | Batch 350/469 | Cost: 12896.2207\n", "Epoch: 017/050 | Batch 400/469 | Cost: 12986.7246\n", "Epoch: 017/050 | Batch 450/469 | Cost: 12528.6777\n", "Epoch: 018/050 | Batch 000/469 | Cost: 
12395.8604\n", "Epoch: 018/050 | Batch 050/469 | Cost: 12674.4678\n", "Epoch: 018/050 | Batch 100/469 | Cost: 12528.5469\n", "Epoch: 018/050 | Batch 150/469 | Cost: 13454.7070\n", "Epoch: 018/050 | Batch 200/469 | Cost: 12878.5322\n", "Epoch: 018/050 | Batch 250/469 | Cost: 12682.8457\n", "Epoch: 018/050 | Batch 300/469 | Cost: 12604.6943\n", "Epoch: 018/050 | Batch 350/469 | Cost: 13185.8828\n", "Epoch: 018/050 | Batch 400/469 | Cost: 12933.7246\n", "Epoch: 018/050 | Batch 450/469 | Cost: 12973.7314\n", "Epoch: 019/050 | Batch 000/469 | Cost: 12347.1924\n", "Epoch: 019/050 | Batch 050/469 | Cost: 12655.2314\n", "Epoch: 019/050 | Batch 100/469 | Cost: 12840.0889\n", "Epoch: 019/050 | Batch 150/469 | Cost: 12790.1152\n", "Epoch: 019/050 | Batch 200/469 | Cost: 12546.3301\n", "Epoch: 019/050 | Batch 250/469 | Cost: 12630.3662\n", "Epoch: 019/050 | Batch 300/469 | Cost: 12877.1553\n", "Epoch: 019/050 | Batch 350/469 | Cost: 12754.5049\n", "Epoch: 019/050 | Batch 400/469 | Cost: 12562.9287\n", "Epoch: 019/050 | Batch 450/469 | Cost: 12670.7939\n", "Epoch: 020/050 | Batch 000/469 | Cost: 13078.5391\n", "Epoch: 020/050 | Batch 050/469 | Cost: 13251.5137\n", "Epoch: 020/050 | Batch 100/469 | Cost: 12222.6816\n", "Epoch: 020/050 | Batch 150/469 | Cost: 13020.2549\n", "Epoch: 020/050 | Batch 200/469 | Cost: 12660.7695\n", "Epoch: 020/050 | Batch 250/469 | Cost: 12797.1309\n", "Epoch: 020/050 | Batch 300/469 | Cost: 12559.7441\n", "Epoch: 020/050 | Batch 350/469 | Cost: 12983.9473\n", "Epoch: 020/050 | Batch 400/469 | Cost: 12665.8516\n", "Epoch: 020/050 | Batch 450/469 | Cost: 12557.4512\n", "Epoch: 021/050 | Batch 000/469 | Cost: 12259.2539\n", "Epoch: 021/050 | Batch 050/469 | Cost: 12225.6787\n", "Epoch: 021/050 | Batch 100/469 | Cost: 13265.6328\n", "Epoch: 021/050 | Batch 150/469 | Cost: 12958.9795\n", "Epoch: 021/050 | Batch 200/469 | Cost: 13201.1504\n", "Epoch: 021/050 | Batch 250/469 | Cost: 12173.3027\n", "Epoch: 021/050 | Batch 300/469 | Cost: 11880.8125\n", 
"Epoch: 021/050 | Batch 350/469 | Cost: 12684.7500\n", "Epoch: 021/050 | Batch 400/469 | Cost: 12973.6250\n", "Epoch: 021/050 | Batch 450/469 | Cost: 12326.9854\n", "Epoch: 022/050 | Batch 000/469 | Cost: 12506.0596\n", "Epoch: 022/050 | Batch 050/469 | Cost: 12992.8047\n", "Epoch: 022/050 | Batch 100/469 | Cost: 12908.5557\n", "Epoch: 022/050 | Batch 150/469 | Cost: 12658.6768\n", "Epoch: 022/050 | Batch 200/469 | Cost: 13097.6426\n", "Epoch: 022/050 | Batch 250/469 | Cost: 12514.5166\n", "Epoch: 022/050 | Batch 300/469 | Cost: 13067.9795\n", "Epoch: 022/050 | Batch 350/469 | Cost: 13335.4814\n", "Epoch: 022/050 | Batch 400/469 | Cost: 12482.6094\n", "Epoch: 022/050 | Batch 450/469 | Cost: 12887.1328\n", "Epoch: 023/050 | Batch 000/469 | Cost: 12895.0732\n", "Epoch: 023/050 | Batch 050/469 | Cost: 12596.9219\n", "Epoch: 023/050 | Batch 100/469 | Cost: 12961.1699\n", "Epoch: 023/050 | Batch 150/469 | Cost: 12497.6240\n", "Epoch: 023/050 | Batch 200/469 | Cost: 12390.3174\n", "Epoch: 023/050 | Batch 250/469 | Cost: 12916.2070\n", "Epoch: 023/050 | Batch 300/469 | Cost: 12608.6494\n", "Epoch: 023/050 | Batch 350/469 | Cost: 12270.3037\n", "Epoch: 023/050 | Batch 400/469 | Cost: 12774.8906\n", "Epoch: 023/050 | Batch 450/469 | Cost: 12438.0068\n", "Epoch: 024/050 | Batch 000/469 | Cost: 12060.3467\n", "Epoch: 024/050 | Batch 050/469 | Cost: 12482.3770\n", "Epoch: 024/050 | Batch 100/469 | Cost: 12389.7715\n", "Epoch: 024/050 | Batch 150/469 | Cost: 13020.0859\n", "Epoch: 024/050 | Batch 200/469 | Cost: 12233.1670\n", "Epoch: 024/050 | Batch 250/469 | Cost: 12507.4473\n", "Epoch: 024/050 | Batch 300/469 | Cost: 12403.1035\n", "Epoch: 024/050 | Batch 350/469 | Cost: 12475.9551\n", "Epoch: 024/050 | Batch 400/469 | Cost: 12369.6104\n", "Epoch: 024/050 | Batch 450/469 | Cost: 12104.8066\n", "Epoch: 025/050 | Batch 000/469 | Cost: 12380.4355\n", "Epoch: 025/050 | Batch 050/469 | Cost: 12826.3662\n", "Epoch: 025/050 | Batch 100/469 | Cost: 12431.5898\n", "Epoch: 025/050 | 
Batch 150/469 | Cost: 12982.6113\n", "Epoch: 025/050 | Batch 200/469 | Cost: 12823.1465\n", "Epoch: 025/050 | Batch 250/469 | Cost: 12800.5156\n", "Epoch: 025/050 | Batch 300/469 | Cost: 13140.7812\n", "Epoch: 025/050 | Batch 350/469 | Cost: 12483.5723\n", "Epoch: 025/050 | Batch 400/469 | Cost: 12694.3594\n", "Epoch: 025/050 | Batch 450/469 | Cost: 12767.1543\n", "Epoch: 026/050 | Batch 000/469 | Cost: 11855.4678\n", "Epoch: 026/050 | Batch 050/469 | Cost: 12363.9590\n", "Epoch: 026/050 | Batch 100/469 | Cost: 13079.2793\n", "Epoch: 026/050 | Batch 150/469 | Cost: 12977.3594\n", "Epoch: 026/050 | Batch 200/469 | Cost: 12642.0938\n", "Epoch: 026/050 | Batch 250/469 | Cost: 12530.8447\n", "Epoch: 026/050 | Batch 300/469 | Cost: 12514.3311\n", "Epoch: 026/050 | Batch 350/469 | Cost: 12100.2314\n", "Epoch: 026/050 | Batch 400/469 | Cost: 12814.0479\n", "Epoch: 026/050 | Batch 450/469 | Cost: 12364.0166\n", "Epoch: 027/050 | Batch 000/469 | Cost: 12499.8721\n", "Epoch: 027/050 | Batch 050/469 | Cost: 12678.9111\n", "Epoch: 027/050 | Batch 100/469 | Cost: 12261.5918\n", "Epoch: 027/050 | Batch 150/469 | Cost: 12901.1641\n", "Epoch: 027/050 | Batch 200/469 | Cost: 12548.0469\n", "Epoch: 027/050 | Batch 250/469 | Cost: 12211.9111\n", "Epoch: 027/050 | Batch 300/469 | Cost: 13003.7646\n", "Epoch: 027/050 | Batch 350/469 | Cost: 12214.5781\n", "Epoch: 027/050 | Batch 400/469 | Cost: 12604.0361\n", "Epoch: 027/050 | Batch 450/469 | Cost: 12504.3213\n", "Epoch: 028/050 | Batch 000/469 | Cost: 12680.8613\n", "Epoch: 028/050 | Batch 050/469 | Cost: 13018.3525\n", "Epoch: 028/050 | Batch 100/469 | Cost: 13040.8760\n", "Epoch: 028/050 | Batch 150/469 | Cost: 12745.8643\n", "Epoch: 028/050 | Batch 200/469 | Cost: 12417.4248\n", "Epoch: 028/050 | Batch 250/469 | Cost: 12684.0645\n", "Epoch: 028/050 | Batch 300/469 | Cost: 12119.3633\n", "Epoch: 028/050 | Batch 350/469 | Cost: 12281.8008\n", "Epoch: 028/050 | Batch 400/469 | Cost: 12434.8438\n", "Epoch: 028/050 | Batch 450/469 | 
Cost: 12379.5928\n", "Epoch: 029/050 | Batch 000/469 | Cost: 12527.9355\n", "Epoch: 029/050 | Batch 050/469 | Cost: 12694.2578\n", "Epoch: 029/050 | Batch 100/469 | Cost: 12318.5742\n", "Epoch: 029/050 | Batch 150/469 | Cost: 12357.7070\n", "Epoch: 029/050 | Batch 200/469 | Cost: 12823.2246\n", "Epoch: 029/050 | Batch 250/469 | Cost: 12532.8555\n", "Epoch: 029/050 | Batch 300/469 | Cost: 12343.1777\n", "Epoch: 029/050 | Batch 350/469 | Cost: 12207.8662\n", "Epoch: 029/050 | Batch 400/469 | Cost: 12553.4434\n", "Epoch: 029/050 | Batch 450/469 | Cost: 12426.8096\n", "Epoch: 030/050 | Batch 000/469 | Cost: 12391.7988\n", "Epoch: 030/050 | Batch 050/469 | Cost: 12414.6650\n", "Epoch: 030/050 | Batch 100/469 | Cost: 12213.8281\n", "Epoch: 030/050 | Batch 150/469 | Cost: 12527.5752\n", "Epoch: 030/050 | Batch 200/469 | Cost: 12135.3281\n", "Epoch: 030/050 | Batch 250/469 | Cost: 12099.4062\n", "Epoch: 030/050 | Batch 300/469 | Cost: 12891.4102\n", "Epoch: 030/050 | Batch 350/469 | Cost: 12546.6768\n", "Epoch: 030/050 | Batch 400/469 | Cost: 12653.6172\n", "Epoch: 030/050 | Batch 450/469 | Cost: 12576.2285\n", "Epoch: 031/050 | Batch 000/469 | Cost: 12499.4316\n", "Epoch: 031/050 | Batch 050/469 | Cost: 12517.8770\n", "Epoch: 031/050 | Batch 100/469 | Cost: 12340.2480\n", "Epoch: 031/050 | Batch 150/469 | Cost: 12368.0469\n", "Epoch: 031/050 | Batch 200/469 | Cost: 12331.4121\n", "Epoch: 031/050 | Batch 250/469 | Cost: 12736.1953\n", "Epoch: 031/050 | Batch 300/469 | Cost: 12985.6914\n", "Epoch: 031/050 | Batch 350/469 | Cost: 12383.8086\n", "Epoch: 031/050 | Batch 400/469 | Cost: 12270.4277\n", "Epoch: 031/050 | Batch 450/469 | Cost: 12418.8633\n", "Epoch: 032/050 | Batch 000/469 | Cost: 12244.7559\n", "Epoch: 032/050 | Batch 050/469 | Cost: 12531.9453\n", "Epoch: 032/050 | Batch 100/469 | Cost: 12477.5752\n", "Epoch: 032/050 | Batch 150/469 | Cost: 12838.6650\n", "Epoch: 032/050 | Batch 200/469 | Cost: 12590.4707\n", "Epoch: 032/050 | Batch 250/469 | Cost: 
12658.5674\n", "Epoch: 032/050 | Batch 300/469 | Cost: 12619.9316\n", "Epoch: 032/050 | Batch 350/469 | Cost: 12790.5488\n", "Epoch: 032/050 | Batch 400/469 | Cost: 12336.5918\n", "Epoch: 032/050 | Batch 450/469 | Cost: 11956.5361\n", "Epoch: 033/050 | Batch 000/469 | Cost: 12257.5645\n", "Epoch: 033/050 | Batch 050/469 | Cost: 12238.9277\n", "Epoch: 033/050 | Batch 100/469 | Cost: 12166.1533\n", "Epoch: 033/050 | Batch 150/469 | Cost: 12442.1953\n", "Epoch: 033/050 | Batch 200/469 | Cost: 12383.0957\n", "Epoch: 033/050 | Batch 250/469 | Cost: 12242.8730\n", "Epoch: 033/050 | Batch 300/469 | Cost: 12493.3262\n", "Epoch: 033/050 | Batch 350/469 | Cost: 12194.9941\n", "Epoch: 033/050 | Batch 400/469 | Cost: 12441.2207\n", "Epoch: 033/050 | Batch 450/469 | Cost: 12835.3838\n", "Epoch: 034/050 | Batch 000/469 | Cost: 12413.8838\n", "Epoch: 034/050 | Batch 050/469 | Cost: 12801.7031\n", "Epoch: 034/050 | Batch 100/469 | Cost: 12464.5234\n", "Epoch: 034/050 | Batch 150/469 | Cost: 12432.2822\n", "Epoch: 034/050 | Batch 200/469 | Cost: 12561.4375\n", "Epoch: 034/050 | Batch 250/469 | Cost: 12854.5889\n", "Epoch: 034/050 | Batch 300/469 | Cost: 12125.7393\n", "Epoch: 034/050 | Batch 350/469 | Cost: 12752.3701\n", "Epoch: 034/050 | Batch 400/469 | Cost: 12496.3652\n", "Epoch: 034/050 | Batch 450/469 | Cost: 12751.6465\n", "Epoch: 035/050 | Batch 000/469 | Cost: 12277.0820\n", "Epoch: 035/050 | Batch 050/469 | Cost: 12367.7256\n", "Epoch: 035/050 | Batch 100/469 | Cost: 12402.5156\n", "Epoch: 035/050 | Batch 150/469 | Cost: 12334.3750\n", "Epoch: 035/050 | Batch 200/469 | Cost: 12532.5967\n", "Epoch: 035/050 | Batch 250/469 | Cost: 12294.9727\n", "Epoch: 035/050 | Batch 300/469 | Cost: 12221.8359\n", "Epoch: 035/050 | Batch 350/469 | Cost: 12979.2939\n", "Epoch: 035/050 | Batch 400/469 | Cost: 12789.4639\n", "Epoch: 035/050 | Batch 450/469 | Cost: 12396.4160\n", "Epoch: 036/050 | Batch 000/469 | Cost: 12536.0049\n", "Epoch: 036/050 | Batch 050/469 | Cost: 12159.3613\n", 
"Epoch: 036/050 | Batch 100/469 | Cost: 12361.6260\n", "Epoch: 036/050 | Batch 150/469 | Cost: 12638.1709\n", "Epoch: 036/050 | Batch 200/469 | Cost: 12634.9355\n", "Epoch: 036/050 | Batch 250/469 | Cost: 12643.7432\n", "Epoch: 036/050 | Batch 300/469 | Cost: 12563.5137\n", "Epoch: 036/050 | Batch 350/469 | Cost: 12375.0566\n", "Epoch: 036/050 | Batch 400/469 | Cost: 12551.1367\n", "Epoch: 036/050 | Batch 450/469 | Cost: 12317.5762\n", "Epoch: 037/050 | Batch 000/469 | Cost: 12063.4453\n", "Epoch: 037/050 | Batch 050/469 | Cost: 11987.3984\n", "Epoch: 037/050 | Batch 100/469 | Cost: 12577.7441\n", "Epoch: 037/050 | Batch 150/469 | Cost: 12403.6309\n", "Epoch: 037/050 | Batch 200/469 | Cost: 12922.1729\n", "Epoch: 037/050 | Batch 250/469 | Cost: 12302.4805\n", "Epoch: 037/050 | Batch 300/469 | Cost: 12353.8057\n", "Epoch: 037/050 | Batch 350/469 | Cost: 12627.0859\n", "Epoch: 037/050 | Batch 400/469 | Cost: 12517.3809\n", "Epoch: 037/050 | Batch 450/469 | Cost: 11899.2090\n", "Epoch: 038/050 | Batch 000/469 | Cost: 11766.3467\n", "Epoch: 038/050 | Batch 050/469 | Cost: 12509.6875\n", "Epoch: 038/050 | Batch 100/469 | Cost: 12706.8721\n", "Epoch: 038/050 | Batch 150/469 | Cost: 12288.3730\n", "Epoch: 038/050 | Batch 200/469 | Cost: 12531.9883\n", "Epoch: 038/050 | Batch 250/469 | Cost: 12904.4297\n", "Epoch: 038/050 | Batch 300/469 | Cost: 12279.0957\n", "Epoch: 038/050 | Batch 350/469 | Cost: 13053.5732\n", "Epoch: 038/050 | Batch 400/469 | Cost: 12317.9678\n", "Epoch: 038/050 | Batch 450/469 | Cost: 12069.1924\n", "Epoch: 039/050 | Batch 000/469 | Cost: 12420.7734\n", "Epoch: 039/050 | Batch 050/469 | Cost: 12101.2764\n", "Epoch: 039/050 | Batch 100/469 | Cost: 12663.4492\n", "Epoch: 039/050 | Batch 150/469 | Cost: 12434.3320\n", "Epoch: 039/050 | Batch 200/469 | Cost: 12394.2676\n", "Epoch: 039/050 | Batch 250/469 | Cost: 12588.5234\n", "Epoch: 039/050 | Batch 300/469 | Cost: 12016.5742\n", "Epoch: 039/050 | Batch 350/469 | Cost: 11895.7480\n", "Epoch: 039/050 | 
Batch 400/469 | Cost: 12270.1885\n", "Epoch: 039/050 | Batch 450/469 | Cost: 12623.2764\n", "Epoch: 040/050 | Batch 000/469 | Cost: 12347.0195\n", "Epoch: 040/050 | Batch 050/469 | Cost: 12172.0439\n", "Epoch: 040/050 | Batch 100/469 | Cost: 12112.8770\n", "Epoch: 040/050 | Batch 150/469 | Cost: 12661.4824\n", "Epoch: 040/050 | Batch 200/469 | Cost: 12516.9434\n", "Epoch: 040/050 | Batch 250/469 | Cost: 11665.0059\n", "Epoch: 040/050 | Batch 300/469 | Cost: 12424.7168\n", "Epoch: 040/050 | Batch 350/469 | Cost: 12546.3516\n", "Epoch: 040/050 | Batch 400/469 | Cost: 12085.0430\n", "Epoch: 040/050 | Batch 450/469 | Cost: 12052.1777\n", "Epoch: 041/050 | Batch 000/469 | Cost: 12553.8594\n", "Epoch: 041/050 | Batch 050/469 | Cost: 12719.8916\n", "Epoch: 041/050 | Batch 100/469 | Cost: 12318.2598\n", "Epoch: 041/050 | Batch 150/469 | Cost: 12868.4424\n", "Epoch: 041/050 | Batch 200/469 | Cost: 12110.4648\n", "Epoch: 041/050 | Batch 250/469 | Cost: 12877.4014\n", "Epoch: 041/050 | Batch 300/469 | Cost: 12044.2422\n", "Epoch: 041/050 | Batch 350/469 | Cost: 12094.7090\n", "Epoch: 041/050 | Batch 400/469 | Cost: 12124.3301\n", "Epoch: 041/050 | Batch 450/469 | Cost: 12671.7217\n", "Epoch: 042/050 | Batch 000/469 | Cost: 12054.0957\n", "Epoch: 042/050 | Batch 050/469 | Cost: 12345.2227\n", "Epoch: 042/050 | Batch 100/469 | Cost: 12810.0957\n", "Epoch: 042/050 | Batch 150/469 | Cost: 11998.7207\n", "Epoch: 042/050 | Batch 200/469 | Cost: 12693.5879\n", "Epoch: 042/050 | Batch 250/469 | Cost: 11996.5615\n", "Epoch: 042/050 | Batch 300/469 | Cost: 12084.2832\n", "Epoch: 042/050 | Batch 350/469 | Cost: 12159.6025\n", "Epoch: 042/050 | Batch 400/469 | Cost: 12514.6943\n", "Epoch: 042/050 | Batch 450/469 | Cost: 12273.8809\n", "Epoch: 043/050 | Batch 000/469 | Cost: 12472.9395\n", "Epoch: 043/050 | Batch 050/469 | Cost: 12462.2734\n", "Epoch: 043/050 | Batch 100/469 | Cost: 12303.0898\n", "Epoch: 043/050 | Batch 150/469 | Cost: 12641.2676\n", "Epoch: 043/050 | Batch 200/469 | 
Cost: 11870.0820\n", "Epoch: 043/050 | Batch 250/469 | Cost: 12087.6504\n", "Epoch: 043/050 | Batch 300/469 | Cost: 12615.6992\n", "Epoch: 043/050 | Batch 350/469 | Cost: 12327.5391\n", "Epoch: 043/050 | Batch 400/469 | Cost: 12761.4795\n", "Epoch: 043/050 | Batch 450/469 | Cost: 12429.0576\n", "Epoch: 044/050 | Batch 000/469 | Cost: 12172.6055\n", "Epoch: 044/050 | Batch 050/469 | Cost: 12338.0742\n", "Epoch: 044/050 | Batch 100/469 | Cost: 12473.4297\n", "Epoch: 044/050 | Batch 150/469 | Cost: 12260.2695\n", "Epoch: 044/050 | Batch 200/469 | Cost: 12475.7871\n", "Epoch: 044/050 | Batch 250/469 | Cost: 12570.5645\n", "Epoch: 044/050 | Batch 300/469 | Cost: 12297.6982\n", "Epoch: 044/050 | Batch 350/469 | Cost: 12525.9111\n", "Epoch: 044/050 | Batch 400/469 | Cost: 12596.0791\n", "Epoch: 044/050 | Batch 450/469 | Cost: 11957.3623\n", "Epoch: 045/050 | Batch 000/469 | Cost: 12849.4238\n", "Epoch: 045/050 | Batch 050/469 | Cost: 12080.3203\n", "Epoch: 045/050 | Batch 100/469 | Cost: 12260.8994\n", "Epoch: 045/050 | Batch 150/469 | Cost: 12638.3770\n", "Epoch: 045/050 | Batch 200/469 | Cost: 12635.9248\n", "Epoch: 045/050 | Batch 250/469 | Cost: 12265.3184\n", "Epoch: 045/050 | Batch 300/469 | Cost: 12359.3242\n", "Epoch: 045/050 | Batch 350/469 | Cost: 12409.3135\n", "Epoch: 045/050 | Batch 400/469 | Cost: 12485.5879\n", "Epoch: 045/050 | Batch 450/469 | Cost: 12399.2988\n", "Epoch: 046/050 | Batch 000/469 | Cost: 12027.0762\n", "Epoch: 046/050 | Batch 050/469 | Cost: 12070.3789\n", "Epoch: 046/050 | Batch 100/469 | Cost: 12531.2441\n", "Epoch: 046/050 | Batch 150/469 | Cost: 12265.9395\n", "Epoch: 046/050 | Batch 200/469 | Cost: 12452.6680\n", "Epoch: 046/050 | Batch 250/469 | Cost: 13118.8711\n", "Epoch: 046/050 | Batch 300/469 | Cost: 12208.3818\n", "Epoch: 046/050 | Batch 350/469 | Cost: 12624.9814\n", "Epoch: 046/050 | Batch 400/469 | Cost: 12488.5791\n", "Epoch: 046/050 | Batch 450/469 | Cost: 12633.9775\n", "Epoch: 047/050 | Batch 000/469 | Cost: 
12152.1914\n", "Epoch: 047/050 | Batch 050/469 | Cost: 12525.3857\n", "Epoch: 047/050 | Batch 100/469 | Cost: 12195.7227\n", "Epoch: 047/050 | Batch 150/469 | Cost: 12642.2949\n", "Epoch: 047/050 | Batch 200/469 | Cost: 12667.8174\n", "Epoch: 047/050 | Batch 250/469 | Cost: 12729.5176\n", "Epoch: 047/050 | Batch 300/469 | Cost: 12052.5898\n", "Epoch: 047/050 | Batch 350/469 | Cost: 12097.2480\n", "Epoch: 047/050 | Batch 400/469 | Cost: 12530.8574\n", "Epoch: 047/050 | Batch 450/469 | Cost: 12496.5098\n", "Epoch: 048/050 | Batch 000/469 | Cost: 12613.0137\n", "Epoch: 048/050 | Batch 050/469 | Cost: 12692.5273\n", "Epoch: 048/050 | Batch 100/469 | Cost: 12363.4863\n", "Epoch: 048/050 | Batch 150/469 | Cost: 11625.2861\n", "Epoch: 048/050 | Batch 200/469 | Cost: 12005.9697\n", "Epoch: 048/050 | Batch 250/469 | Cost: 12227.3750\n", "Epoch: 048/050 | Batch 300/469 | Cost: 12684.3359\n", "Epoch: 048/050 | Batch 350/469 | Cost: 12430.2783\n", "Epoch: 048/050 | Batch 400/469 | Cost: 12213.2578\n", "Epoch: 048/050 | Batch 450/469 | Cost: 13208.1133\n", "Epoch: 049/050 | Batch 000/469 | Cost: 12118.3057\n", "Epoch: 049/050 | Batch 050/469 | Cost: 12340.2715\n", "Epoch: 049/050 | Batch 100/469 | Cost: 12029.6094\n", "Epoch: 049/050 | Batch 150/469 | Cost: 12366.4453\n", "Epoch: 049/050 | Batch 200/469 | Cost: 12537.2998\n", "Epoch: 049/050 | Batch 250/469 | Cost: 12324.0312\n", "Epoch: 049/050 | Batch 300/469 | Cost: 12378.3457\n", "Epoch: 049/050 | Batch 350/469 | Cost: 12218.6914\n", "Epoch: 049/050 | Batch 400/469 | Cost: 12550.4785\n", "Epoch: 049/050 | Batch 450/469 | Cost: 12444.4463\n", "Epoch: 050/050 | Batch 000/469 | Cost: 12246.0020\n", "Epoch: 050/050 | Batch 050/469 | Cost: 12554.1836\n", "Epoch: 050/050 | Batch 100/469 | Cost: 12373.2930\n", "Epoch: 050/050 | Batch 150/469 | Cost: 12895.8096\n", "Epoch: 050/050 | Batch 200/469 | Cost: 12233.0605\n", "Epoch: 050/050 | Batch 250/469 | Cost: 12621.2920\n", "Epoch: 050/050 | Batch 300/469 | Cost: 12492.7812\n", 
"Epoch: 050/050 | Batch 350/469 | Cost: 12525.1934\n", "Epoch: 050/050 | Batch 400/469 | Cost: 13032.4062\n", "Epoch: 050/050 | Batch 450/469 | Cost: 12618.7773\n" ] } ], "source": [ "start_time = time.time()\n", "for epoch in range(num_epochs):\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", " \n", " features = features.view(-1, 28*28).to(device)\n", " targets = targets.to(device)\n", "\n", " ### FORWARD AND BACK PROP\n", " z_mean, z_log_var, encoded, decoded = model(features, targets)\n", "\n", " # cost = reconstruction loss + Kullback-Leibler divergence\n", " kl_divergence = (0.5 * (z_mean**2 + \n", " torch.exp(z_log_var) - z_log_var - 1)).sum()\n", " \n", " ### Add condition\n", " # Disabled for reconstruction loss as it gives poor results\n", " # x_con = torch.cat((features, to_onehot(targets, num_classes, device)), dim=1)\n", " \n", " ### Compute loss\n", " # pixelwise_bce = F.binary_cross_entropy(decoded, x_con, reduction='sum')\n", " pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')\n", " cost = kl_divergence + pixelwise_bce\n", " \n", " ### UPDATE MODEL PARAMETERS\n", " optimizer.zero_grad()\n", " cost.backward()\n", " optimizer.step()\n", " \n", " ### LOGGING\n", " if not batch_idx % 50:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n", " %(epoch+1, num_epochs, batch_idx, \n", " len(train_loader), cost))\n", " \n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Reconstruction" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "##########################\n", "### VISUALIZATION\n", "##########################\n", "\n", "n_images = 15\n", "image_width = 28\n", "\n", "fig, axes = plt.subplots(nrows=2, ncols=n_images, \n", " sharex=True, sharey=True, figsize=(20, 2.5))\n", "orig_images = features[:n_images]\n", "decoded_images = decoded[:n_images]\n", "\n", "for i in range(n_images):\n", " for ax, img in zip(axes, [orig_images, decoded_images]):\n", " curr_img = img[i].detach().to(torch.device('cpu'))\n", " ax[i].imshow(curr_img.view((image_width, image_width)), cmap='binary')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### New random-conditional images" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Class Label 0\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 1\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 2\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 3\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 4\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 5\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 6\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 7\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 8\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Class Label 9\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", "\n", " ##########################\n", " ### RANDOM SAMPLE\n", " ########################## \n", " \n", " labels = torch.tensor([i]*10).to(device)\n", " n_images = labels.size()[0]\n", " rand_features = torch.randn(n_images, num_latent).to(device)\n", " new_images = model.decoder(rand_features, labels)\n", "\n", " ##########################\n", " ### VISUALIZATION\n", " ##########################\n", "\n", " image_width = 28\n", "\n", " fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)\n", " decoded_images = new_images[:n_images]\n", "\n", " print('Class Label %d' % i)\n", "\n", " for ax, img in zip(axes, decoded_images):\n", " curr_img = img.detach().to(torch.device('cpu'))\n", " ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')\n", " \n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "torch 1.0.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }