{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.1\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Variational Autoencoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A simple variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 15-dimensional latent vector representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:0\n", "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "print('Device:', device)\n", "\n", "# Hyperparameters\n", "random_seed = 0\n", "learning_rate = 0.001\n", "num_epochs = 50\n", "batch_size = 128\n", "\n", "# Architecture\n", "num_features = 784\n", "num_hidden_1 = 500\n", "num_latent = 15\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', 
labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "class VariationalAutoencoder(torch.nn.Module):\n", "\n", " def __init__(self, num_features, num_hidden_1, num_latent):\n", " super(VariationalAutoencoder, self).__init__()\n", " \n", " ### ENCODER\n", " self.hidden_1 = torch.nn.Linear(num_features, num_hidden_1)\n", " self.z_mean = torch.nn.Linear(num_hidden_1, num_latent)\n", " # in the original paper (Kingma & Welling 2015, we use\n", " # have a z_mean and z_var, but the problem is that\n", " # the z_var can be negative, which would cause issues\n", " # in the log later. Hence we assume that latent vector\n", " # has a z_mean and z_log_var component, and when we need\n", " # the regular variance or std_dev, we simply use \n", " # an exponential function\n", " self.z_log_var = torch.nn.Linear(num_hidden_1, num_latent)\n", " \n", " \n", " ### DECODER\n", " self.linear_3 = torch.nn.Linear(num_latent, num_hidden_1)\n", " self.linear_4 = torch.nn.Linear(num_hidden_1, num_features)\n", "\n", " def reparameterize(self, z_mu, z_log_var):\n", " # Sample epsilon from standard normal distribution\n", " eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(device)\n", " # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev\n", " # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)\n", " z = z_mu + eps * torch.exp(z_log_var/2.) 
\n", " return z\n", " \n", " def encoder(self, features):\n", " x = self.hidden_1(features)\n", " x = F.leaky_relu(x, negative_slope=0.0001)\n", " z_mean = self.z_mean(x)\n", " z_log_var = self.z_log_var(x)\n", " encoded = self.reparameterize(z_mean, z_log_var)\n", " return z_mean, z_log_var, encoded\n", " \n", " def decoder(self, encoded):\n", " x = self.linear_3(encoded)\n", " x = F.leaky_relu(x, negative_slope=0.0001)\n", " x = self.linear_4(x)\n", " decoded = torch.sigmoid(x)\n", " return decoded\n", "\n", " def forward(self, features):\n", " \n", " z_mean, z_log_var, encoded = self.encoder(features)\n", " decoded = self.decoder(encoded)\n", " \n", " return z_mean, z_log_var, encoded, decoded\n", "\n", " \n", "torch.manual_seed(random_seed)\n", "model = VariationalAutoencoder(num_features,\n", " num_hidden_1,\n", " num_latent)\n", "model = model.to(device)\n", " \n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/469 | Cost: 70481.2422\n", "Epoch: 001/050 | Batch 050/469 | Cost: 27139.5547\n", "Epoch: 001/050 | Batch 100/469 | Cost: 22833.3730\n", "Epoch: 001/050 | Batch 150/469 | Cost: 19493.1523\n", "Epoch: 001/050 | Batch 200/469 | Cost: 18727.4688\n", "Epoch: 001/050 | Batch 250/469 | Cost: 18074.2676\n", "Epoch: 001/050 | Batch 300/469 | Cost: 16633.2852\n", "Epoch: 001/050 | Batch 350/469 | Cost: 17136.7852\n", "Epoch: 001/050 | Batch 400/469 | Cost: 16402.5293\n", "Epoch: 001/050 | Batch 450/469 | Cost: 16062.9814\n", "Epoch: 002/050 | Batch 000/469 | Cost: 16577.0840\n", "Epoch: 002/050 | Batch 050/469 | Cost: 15451.8242\n", "Epoch: 002/050 | Batch 100/469 | Cost: 15667.8535\n", "Epoch: 002/050 | Batch 150/469 | Cost: 15734.0801\n", "Epoch: 002/050 | Batch 200/469 | Cost: 15145.4365\n", "Epoch: 
002/050 | Batch 250/469 | Cost: 15326.6953\n", "Epoch: 002/050 | Batch 300/469 | Cost: 15408.5801\n", "Epoch: 002/050 | Batch 350/469 | Cost: 15637.5430\n", "Epoch: 002/050 | Batch 400/469 | Cost: 14793.0332\n", "Epoch: 002/050 | Batch 450/469 | Cost: 15046.4414\n", "Epoch: 003/050 | Batch 000/469 | Cost: 14457.0537\n", "Epoch: 003/050 | Batch 050/469 | Cost: 14483.2910\n", "Epoch: 003/050 | Batch 100/469 | Cost: 14374.9258\n", "Epoch: 003/050 | Batch 150/469 | Cost: 13934.8672\n", "Epoch: 003/050 | Batch 200/469 | Cost: 15053.9336\n", "Epoch: 003/050 | Batch 250/469 | Cost: 14673.1025\n", "Epoch: 003/050 | Batch 300/469 | Cost: 14324.3916\n", "Epoch: 003/050 | Batch 350/469 | Cost: 14318.4229\n", "Epoch: 003/050 | Batch 400/469 | Cost: 14501.4912\n", "Epoch: 003/050 | Batch 450/469 | Cost: 13753.9082\n", "Epoch: 004/050 | Batch 000/469 | Cost: 15024.3789\n", "Epoch: 004/050 | Batch 050/469 | Cost: 14310.9219\n", "Epoch: 004/050 | Batch 100/469 | Cost: 14723.5176\n", "Epoch: 004/050 | Batch 150/469 | Cost: 15469.9473\n", "Epoch: 004/050 | Batch 200/469 | Cost: 14126.0586\n", "Epoch: 004/050 | Batch 250/469 | Cost: 14321.4062\n", "Epoch: 004/050 | Batch 300/469 | Cost: 13834.0576\n", "Epoch: 004/050 | Batch 350/469 | Cost: 14363.6494\n", "Epoch: 004/050 | Batch 400/469 | Cost: 14136.7422\n", "Epoch: 004/050 | Batch 450/469 | Cost: 13603.2012\n", "Epoch: 005/050 | Batch 000/469 | Cost: 14002.0479\n", "Epoch: 005/050 | Batch 050/469 | Cost: 14221.5488\n", "Epoch: 005/050 | Batch 100/469 | Cost: 13972.6787\n", "Epoch: 005/050 | Batch 150/469 | Cost: 13918.2402\n", "Epoch: 005/050 | Batch 200/469 | Cost: 13839.3809\n", "Epoch: 005/050 | Batch 250/469 | Cost: 14421.0020\n", "Epoch: 005/050 | Batch 300/469 | Cost: 14611.1816\n", "Epoch: 005/050 | Batch 350/469 | Cost: 13653.8027\n", "Epoch: 005/050 | Batch 400/469 | Cost: 13632.8047\n", "Epoch: 005/050 | Batch 450/469 | Cost: 13612.9375\n", "Epoch: 006/050 | Batch 000/469 | Cost: 13993.7344\n", "Epoch: 006/050 | Batch 
050/469 | Cost: 13976.1006\n", "Epoch: 006/050 | Batch 100/469 | Cost: 14309.5527\n", "Epoch: 006/050 | Batch 150/469 | Cost: 13427.8916\n", "Epoch: 006/050 | Batch 200/469 | Cost: 13811.6260\n", "Epoch: 006/050 | Batch 250/469 | Cost: 14130.3496\n", "Epoch: 006/050 | Batch 300/469 | Cost: 12895.7324\n", "Epoch: 006/050 | Batch 350/469 | Cost: 13445.3213\n", "Epoch: 006/050 | Batch 400/469 | Cost: 13374.8242\n", "Epoch: 006/050 | Batch 450/469 | Cost: 13549.5098\n", "Epoch: 007/050 | Batch 000/469 | Cost: 13913.4043\n", "Epoch: 007/050 | Batch 050/469 | Cost: 13703.5654\n", "Epoch: 007/050 | Batch 100/469 | Cost: 14132.1758\n", "Epoch: 007/050 | Batch 150/469 | Cost: 14052.9814\n", "Epoch: 007/050 | Batch 200/469 | Cost: 13750.3535\n", "Epoch: 007/050 | Batch 250/469 | Cost: 14316.6953\n", "Epoch: 007/050 | Batch 300/469 | Cost: 13224.3281\n", "Epoch: 007/050 | Batch 350/469 | Cost: 14139.7979\n", "Epoch: 007/050 | Batch 400/469 | Cost: 13795.6016\n", "Epoch: 007/050 | Batch 450/469 | Cost: 13915.5020\n", "Epoch: 008/050 | Batch 000/469 | Cost: 13548.9512\n", "Epoch: 008/050 | Batch 050/469 | Cost: 13558.6338\n", "Epoch: 008/050 | Batch 100/469 | Cost: 13883.1074\n", "Epoch: 008/050 | Batch 150/469 | Cost: 13128.7617\n", "Epoch: 008/050 | Batch 200/469 | Cost: 13133.5879\n", "Epoch: 008/050 | Batch 250/469 | Cost: 13518.8672\n", "Epoch: 008/050 | Batch 300/469 | Cost: 13679.2324\n", "Epoch: 008/050 | Batch 350/469 | Cost: 13928.9824\n", "Epoch: 008/050 | Batch 400/469 | Cost: 14079.2256\n", "Epoch: 008/050 | Batch 450/469 | Cost: 13294.2021\n", "Epoch: 009/050 | Batch 000/469 | Cost: 13619.6504\n", "Epoch: 009/050 | Batch 050/469 | Cost: 13831.6201\n", "Epoch: 009/050 | Batch 100/469 | Cost: 13848.1406\n", "Epoch: 009/050 | Batch 150/469 | Cost: 14622.0889\n", "Epoch: 009/050 | Batch 200/469 | Cost: 13843.3887\n", "Epoch: 009/050 | Batch 250/469 | Cost: 13673.2441\n", "Epoch: 009/050 | Batch 300/469 | Cost: 13646.6543\n", "Epoch: 009/050 | Batch 350/469 | Cost: 
13411.1816\n", "Epoch: 009/050 | Batch 400/469 | Cost: 14463.7988\n", "Epoch: 009/050 | Batch 450/469 | Cost: 13585.7891\n", "Epoch: 010/050 | Batch 000/469 | Cost: 13929.6816\n", "Epoch: 010/050 | Batch 050/469 | Cost: 13659.5176\n", "Epoch: 010/050 | Batch 100/469 | Cost: 13504.2568\n", "Epoch: 010/050 | Batch 150/469 | Cost: 13717.9434\n", "Epoch: 010/050 | Batch 200/469 | Cost: 13711.8818\n", "Epoch: 010/050 | Batch 250/469 | Cost: 13554.4062\n", "Epoch: 010/050 | Batch 300/469 | Cost: 13317.5156\n", "Epoch: 010/050 | Batch 350/469 | Cost: 13279.9912\n", "Epoch: 010/050 | Batch 400/469 | Cost: 13069.9648\n", "Epoch: 010/050 | Batch 450/469 | Cost: 13087.7695\n", "Epoch: 011/050 | Batch 000/469 | Cost: 13800.6113\n", "Epoch: 011/050 | Batch 050/469 | Cost: 13924.1973\n", "Epoch: 011/050 | Batch 100/469 | Cost: 13173.4414\n", "Epoch: 011/050 | Batch 150/469 | Cost: 13963.7402\n", "Epoch: 011/050 | Batch 200/469 | Cost: 13682.3281\n", "Epoch: 011/050 | Batch 250/469 | Cost: 13664.8027\n", "Epoch: 011/050 | Batch 300/469 | Cost: 14188.4707\n", "Epoch: 011/050 | Batch 350/469 | Cost: 13625.5840\n", "Epoch: 011/050 | Batch 400/469 | Cost: 13482.8643\n", "Epoch: 011/050 | Batch 450/469 | Cost: 13912.9238\n", "Epoch: 012/050 | Batch 000/469 | Cost: 13048.4648\n", "Epoch: 012/050 | Batch 050/469 | Cost: 13041.4395\n", "Epoch: 012/050 | Batch 100/469 | Cost: 13212.3662\n", "Epoch: 012/050 | Batch 150/469 | Cost: 13304.4463\n", "Epoch: 012/050 | Batch 200/469 | Cost: 13445.9287\n", "Epoch: 012/050 | Batch 250/469 | Cost: 13693.2676\n", "Epoch: 012/050 | Batch 300/469 | Cost: 13072.9004\n", "Epoch: 012/050 | Batch 350/469 | Cost: 13414.5361\n", "Epoch: 012/050 | Batch 400/469 | Cost: 13669.4121\n", "Epoch: 012/050 | Batch 450/469 | Cost: 13366.8633\n", "Epoch: 013/050 | Batch 000/469 | Cost: 13785.0518\n", "Epoch: 013/050 | Batch 050/469 | Cost: 13788.7734\n", "Epoch: 013/050 | Batch 100/469 | Cost: 13442.9023\n", "Epoch: 013/050 | Batch 150/469 | Cost: 13771.4902\n", 
"Epoch: 013/050 | Batch 200/469 | Cost: 13357.2217\n", "Epoch: 013/050 | Batch 250/469 | Cost: 13402.6758\n", "Epoch: 013/050 | Batch 300/469 | Cost: 13852.4033\n", "Epoch: 013/050 | Batch 350/469 | Cost: 13301.3457\n", "Epoch: 013/050 | Batch 400/469 | Cost: 13379.6172\n", "Epoch: 013/050 | Batch 450/469 | Cost: 14275.3047\n", "Epoch: 014/050 | Batch 000/469 | Cost: 13433.3906\n", "Epoch: 014/050 | Batch 050/469 | Cost: 13059.2354\n", "Epoch: 014/050 | Batch 100/469 | Cost: 14031.3721\n", "Epoch: 014/050 | Batch 150/469 | Cost: 13950.9883\n", "Epoch: 014/050 | Batch 200/469 | Cost: 13684.6611\n", "Epoch: 014/050 | Batch 250/469 | Cost: 13630.9336\n", "Epoch: 014/050 | Batch 300/469 | Cost: 13527.6230\n", "Epoch: 014/050 | Batch 350/469 | Cost: 13746.0527\n", "Epoch: 014/050 | Batch 400/469 | Cost: 13490.6982\n", "Epoch: 014/050 | Batch 450/469 | Cost: 13669.7402\n", "Epoch: 015/050 | Batch 000/469 | Cost: 13422.4238\n", "Epoch: 015/050 | Batch 050/469 | Cost: 13303.6855\n", "Epoch: 015/050 | Batch 100/469 | Cost: 13421.2900\n", "Epoch: 015/050 | Batch 150/469 | Cost: 13129.2764\n", "Epoch: 015/050 | Batch 200/469 | Cost: 13276.9336\n", "Epoch: 015/050 | Batch 250/469 | Cost: 13776.0889\n", "Epoch: 015/050 | Batch 300/469 | Cost: 13634.2188\n", "Epoch: 015/050 | Batch 350/469 | Cost: 13438.3828\n", "Epoch: 015/050 | Batch 400/469 | Cost: 13401.1045\n", "Epoch: 015/050 | Batch 450/469 | Cost: 13567.6631\n", "Epoch: 016/050 | Batch 000/469 | Cost: 13150.5625\n", "Epoch: 016/050 | Batch 050/469 | Cost: 13414.6553\n", "Epoch: 016/050 | Batch 100/469 | Cost: 12979.0029\n", "Epoch: 016/050 | Batch 150/469 | Cost: 13146.0801\n", "Epoch: 016/050 | Batch 200/469 | Cost: 13257.3301\n", "Epoch: 016/050 | Batch 250/469 | Cost: 14350.7471\n", "Epoch: 016/050 | Batch 300/469 | Cost: 13836.4316\n", "Epoch: 016/050 | Batch 350/469 | Cost: 13865.9902\n", "Epoch: 016/050 | Batch 400/469 | Cost: 13237.3877\n", "Epoch: 016/050 | Batch 450/469 | Cost: 13339.0303\n", "Epoch: 017/050 | 
Batch 000/469 | Cost: 13266.2793\n", "Epoch: 017/050 | Batch 050/469 | Cost: 13568.0957\n", "Epoch: 017/050 | Batch 100/469 | Cost: 12923.4482\n", "Epoch: 017/050 | Batch 150/469 | Cost: 14093.2070\n", "Epoch: 017/050 | Batch 200/469 | Cost: 13326.7510\n", "Epoch: 017/050 | Batch 250/469 | Cost: 13965.0625\n", "Epoch: 017/050 | Batch 300/469 | Cost: 13380.1445\n", "Epoch: 017/050 | Batch 350/469 | Cost: 13277.1875\n", "Epoch: 017/050 | Batch 400/469 | Cost: 13872.7607\n", "Epoch: 017/050 | Batch 450/469 | Cost: 13272.6797\n", "Epoch: 018/050 | Batch 000/469 | Cost: 13048.2461\n", "Epoch: 018/050 | Batch 050/469 | Cost: 13508.2314\n", "Epoch: 018/050 | Batch 100/469 | Cost: 12814.9834\n", "Epoch: 018/050 | Batch 150/469 | Cost: 13623.1924\n", "Epoch: 018/050 | Batch 200/469 | Cost: 13246.6113\n", "Epoch: 018/050 | Batch 250/469 | Cost: 13471.1328\n", "Epoch: 018/050 | Batch 300/469 | Cost: 13271.7930\n", "Epoch: 018/050 | Batch 350/469 | Cost: 13494.4883\n", "Epoch: 018/050 | Batch 400/469 | Cost: 13280.4316\n", "Epoch: 018/050 | Batch 450/469 | Cost: 13408.9775\n", "Epoch: 019/050 | Batch 000/469 | Cost: 13347.4941\n", "Epoch: 019/050 | Batch 050/469 | Cost: 13403.6006\n", "Epoch: 019/050 | Batch 100/469 | Cost: 12944.8574\n", "Epoch: 019/050 | Batch 150/469 | Cost: 13410.6201\n", "Epoch: 019/050 | Batch 200/469 | Cost: 13398.5342\n", "Epoch: 019/050 | Batch 250/469 | Cost: 13786.6992\n", "Epoch: 019/050 | Batch 300/469 | Cost: 12185.1465\n", "Epoch: 019/050 | Batch 350/469 | Cost: 13143.7744\n", "Epoch: 019/050 | Batch 400/469 | Cost: 13101.7451\n", "Epoch: 019/050 | Batch 450/469 | Cost: 13410.3252\n", "Epoch: 020/050 | Batch 000/469 | Cost: 13459.7676\n", "Epoch: 020/050 | Batch 050/469 | Cost: 13551.5127\n", "Epoch: 020/050 | Batch 100/469 | Cost: 13246.3486\n", "Epoch: 020/050 | Batch 150/469 | Cost: 13524.1133\n", "Epoch: 020/050 | Batch 200/469 | Cost: 13695.5605\n", "Epoch: 020/050 | Batch 250/469 | Cost: 13447.3887\n", "Epoch: 020/050 | Batch 300/469 | 
Cost: 13389.4941\n", "Epoch: 020/050 | Batch 350/469 | Cost: 13180.2422\n", "Epoch: 020/050 | Batch 400/469 | Cost: 13606.3457\n", "Epoch: 020/050 | Batch 450/469 | Cost: 13646.9355\n", "Epoch: 021/050 | Batch 000/469 | Cost: 13557.3623\n", "Epoch: 021/050 | Batch 050/469 | Cost: 13147.0098\n", "Epoch: 021/050 | Batch 100/469 | Cost: 13287.7227\n", "Epoch: 021/050 | Batch 150/469 | Cost: 12849.9639\n", "Epoch: 021/050 | Batch 200/469 | Cost: 13058.6406\n", "Epoch: 021/050 | Batch 250/469 | Cost: 13192.6367\n", "Epoch: 021/050 | Batch 300/469 | Cost: 13393.4082\n", "Epoch: 021/050 | Batch 350/469 | Cost: 13834.2705\n", "Epoch: 021/050 | Batch 400/469 | Cost: 13503.1680\n", "Epoch: 021/050 | Batch 450/469 | Cost: 13592.5518\n", "Epoch: 022/050 | Batch 000/469 | Cost: 13658.5986\n", "Epoch: 022/050 | Batch 050/469 | Cost: 13389.6855\n", "Epoch: 022/050 | Batch 100/469 | Cost: 13313.9707\n", "Epoch: 022/050 | Batch 150/469 | Cost: 13508.8438\n", "Epoch: 022/050 | Batch 200/469 | Cost: 12984.4082\n", "Epoch: 022/050 | Batch 250/469 | Cost: 13159.5137\n", "Epoch: 022/050 | Batch 300/469 | Cost: 13195.3516\n", "Epoch: 022/050 | Batch 350/469 | Cost: 13606.1777\n", "Epoch: 022/050 | Batch 400/469 | Cost: 12865.0508\n", "Epoch: 022/050 | Batch 450/469 | Cost: 13227.6514\n", "Epoch: 023/050 | Batch 000/469 | Cost: 13067.6016\n", "Epoch: 023/050 | Batch 050/469 | Cost: 13425.5498\n", "Epoch: 023/050 | Batch 100/469 | Cost: 13016.2773\n", "Epoch: 023/050 | Batch 150/469 | Cost: 13322.1260\n", "Epoch: 023/050 | Batch 200/469 | Cost: 12861.3926\n", "Epoch: 023/050 | Batch 250/469 | Cost: 13000.5967\n", "Epoch: 023/050 | Batch 300/469 | Cost: 13761.4629\n", "Epoch: 023/050 | Batch 350/469 | Cost: 13482.9814\n", "Epoch: 023/050 | Batch 400/469 | Cost: 12838.6201\n", "Epoch: 023/050 | Batch 450/469 | Cost: 13252.9746\n", "Epoch: 024/050 | Batch 000/469 | Cost: 13445.4590\n", "Epoch: 024/050 | Batch 050/469 | Cost: 13583.6416\n", "Epoch: 024/050 | Batch 100/469 | Cost: 
13507.9443\n", "Epoch: 024/050 | Batch 150/469 | Cost: 13385.5938\n", "Epoch: 024/050 | Batch 200/469 | Cost: 13271.1357\n", "Epoch: 024/050 | Batch 250/469 | Cost: 13643.7109\n", "Epoch: 024/050 | Batch 300/469 | Cost: 13713.0889\n", "Epoch: 024/050 | Batch 350/469 | Cost: 12844.5703\n", "Epoch: 024/050 | Batch 400/469 | Cost: 12984.4746\n", "Epoch: 024/050 | Batch 450/469 | Cost: 13015.4365\n", "Epoch: 025/050 | Batch 000/469 | Cost: 13575.6875\n", "Epoch: 025/050 | Batch 050/469 | Cost: 13195.2832\n", "Epoch: 025/050 | Batch 100/469 | Cost: 13478.4746\n", "Epoch: 025/050 | Batch 150/469 | Cost: 13194.2852\n", "Epoch: 025/050 | Batch 200/469 | Cost: 12877.8242\n", "Epoch: 025/050 | Batch 250/469 | Cost: 13061.2148\n", "Epoch: 025/050 | Batch 300/469 | Cost: 13397.2266\n", "Epoch: 025/050 | Batch 350/469 | Cost: 12763.3711\n", "Epoch: 025/050 | Batch 400/469 | Cost: 13262.5332\n", "Epoch: 025/050 | Batch 450/469 | Cost: 13390.2393\n", "Epoch: 026/050 | Batch 000/469 | Cost: 13211.5508\n", "Epoch: 026/050 | Batch 050/469 | Cost: 13458.3652\n", "Epoch: 026/050 | Batch 100/469 | Cost: 12846.0557\n", "Epoch: 026/050 | Batch 150/469 | Cost: 12842.9570\n", "Epoch: 026/050 | Batch 200/469 | Cost: 13594.9395\n", "Epoch: 026/050 | Batch 250/469 | Cost: 13021.0605\n", "Epoch: 026/050 | Batch 300/469 | Cost: 13126.7686\n", "Epoch: 026/050 | Batch 350/469 | Cost: 12951.5898\n", "Epoch: 026/050 | Batch 400/469 | Cost: 13600.2119\n", "Epoch: 026/050 | Batch 450/469 | Cost: 13313.3535\n", "Epoch: 027/050 | Batch 000/469 | Cost: 12835.4717\n", "Epoch: 027/050 | Batch 050/469 | Cost: 12731.1875\n", "Epoch: 027/050 | Batch 100/469 | Cost: 13234.9297\n", "Epoch: 027/050 | Batch 150/469 | Cost: 13105.2148\n", "Epoch: 027/050 | Batch 200/469 | Cost: 13234.0684\n", "Epoch: 027/050 | Batch 250/469 | Cost: 13147.0801\n", "Epoch: 027/050 | Batch 300/469 | Cost: 13271.8262\n", "Epoch: 027/050 | Batch 350/469 | Cost: 12936.5947\n", "Epoch: 027/050 | Batch 400/469 | Cost: 13336.0293\n", 
"Epoch: 027/050 | Batch 450/469 | Cost: 13387.3662\n", "Epoch: 028/050 | Batch 000/469 | Cost: 13452.3438\n", "Epoch: 028/050 | Batch 050/469 | Cost: 13245.0342\n", "Epoch: 028/050 | Batch 100/469 | Cost: 13007.5234\n", "Epoch: 028/050 | Batch 150/469 | Cost: 13068.0166\n", "Epoch: 028/050 | Batch 200/469 | Cost: 12575.0166\n", "Epoch: 028/050 | Batch 250/469 | Cost: 13051.9434\n", "Epoch: 028/050 | Batch 300/469 | Cost: 13185.3330\n", "Epoch: 028/050 | Batch 350/469 | Cost: 13587.7715\n", "Epoch: 028/050 | Batch 400/469 | Cost: 12877.1436\n", "Epoch: 028/050 | Batch 450/469 | Cost: 13305.9297\n", "Epoch: 029/050 | Batch 000/469 | Cost: 13244.1865\n", "Epoch: 029/050 | Batch 050/469 | Cost: 13002.6309\n", "Epoch: 029/050 | Batch 100/469 | Cost: 13432.6504\n", "Epoch: 029/050 | Batch 150/469 | Cost: 13128.8027\n", "Epoch: 029/050 | Batch 200/469 | Cost: 12879.6543\n", "Epoch: 029/050 | Batch 250/469 | Cost: 13248.0068\n", "Epoch: 029/050 | Batch 300/469 | Cost: 13176.9912\n", "Epoch: 029/050 | Batch 350/469 | Cost: 13055.7490\n", "Epoch: 029/050 | Batch 400/469 | Cost: 13092.9580\n", "Epoch: 029/050 | Batch 450/469 | Cost: 13179.1875\n", "Epoch: 030/050 | Batch 000/469 | Cost: 13205.4668\n", "Epoch: 030/050 | Batch 050/469 | Cost: 13425.4883\n", "Epoch: 030/050 | Batch 100/469 | Cost: 12924.2070\n", "Epoch: 030/050 | Batch 150/469 | Cost: 13293.8105\n", "Epoch: 030/050 | Batch 200/469 | Cost: 12805.0674\n", "Epoch: 030/050 | Batch 250/469 | Cost: 12823.4629\n", "Epoch: 030/050 | Batch 300/469 | Cost: 12680.0322\n", "Epoch: 030/050 | Batch 350/469 | Cost: 13412.4023\n", "Epoch: 030/050 | Batch 400/469 | Cost: 13796.5479\n", "Epoch: 030/050 | Batch 450/469 | Cost: 13084.7051\n", "Epoch: 031/050 | Batch 000/469 | Cost: 13054.2988\n", "Epoch: 031/050 | Batch 050/469 | Cost: 13315.4570\n", "Epoch: 031/050 | Batch 100/469 | Cost: 13284.9463\n", "Epoch: 031/050 | Batch 150/469 | Cost: 13184.4668\n", "Epoch: 031/050 | Batch 200/469 | Cost: 13099.4189\n", "Epoch: 031/050 | 
Batch 250/469 | Cost: 13391.0918\n", "Epoch: 031/050 | Batch 300/469 | Cost: 13057.3223\n", "Epoch: 031/050 | Batch 350/469 | Cost: 13442.3750\n", "Epoch: 031/050 | Batch 400/469 | Cost: 13491.5635\n", "Epoch: 031/050 | Batch 450/469 | Cost: 13054.0693\n", "Epoch: 032/050 | Batch 000/469 | Cost: 13219.3789\n", "Epoch: 032/050 | Batch 050/469 | Cost: 12822.7051\n", "Epoch: 032/050 | Batch 100/469 | Cost: 13439.6436\n", "Epoch: 032/050 | Batch 150/469 | Cost: 12843.7061\n", "Epoch: 032/050 | Batch 200/469 | Cost: 13097.7012\n", "Epoch: 032/050 | Batch 250/469 | Cost: 12950.4707\n", "Epoch: 032/050 | Batch 300/469 | Cost: 13238.1094\n", "Epoch: 032/050 | Batch 350/469 | Cost: 13027.9121\n", "Epoch: 032/050 | Batch 400/469 | Cost: 13150.9277\n", "Epoch: 032/050 | Batch 450/469 | Cost: 13239.6348\n", "Epoch: 033/050 | Batch 000/469 | Cost: 12967.9863\n", "Epoch: 033/050 | Batch 050/469 | Cost: 13261.3467\n", "Epoch: 033/050 | Batch 100/469 | Cost: 13218.9023\n", "Epoch: 033/050 | Batch 150/469 | Cost: 13092.8994\n", "Epoch: 033/050 | Batch 200/469 | Cost: 12983.0459\n", "Epoch: 033/050 | Batch 250/469 | Cost: 13031.2188\n", "Epoch: 033/050 | Batch 300/469 | Cost: 12894.7129\n", "Epoch: 033/050 | Batch 350/469 | Cost: 13563.2578\n", "Epoch: 033/050 | Batch 400/469 | Cost: 13094.8340\n", "Epoch: 033/050 | Batch 450/469 | Cost: 13279.9639\n", "Epoch: 034/050 | Batch 000/469 | Cost: 12986.0615\n", "Epoch: 034/050 | Batch 050/469 | Cost: 12981.4004\n", "Epoch: 034/050 | Batch 100/469 | Cost: 13308.1504\n", "Epoch: 034/050 | Batch 150/469 | Cost: 13338.7227\n", "Epoch: 034/050 | Batch 200/469 | Cost: 13310.7227\n", "Epoch: 034/050 | Batch 250/469 | Cost: 13158.7334\n", "Epoch: 034/050 | Batch 300/469 | Cost: 13248.9336\n", "Epoch: 034/050 | Batch 350/469 | Cost: 13256.2227\n", "Epoch: 034/050 | Batch 400/469 | Cost: 12818.7148\n", "Epoch: 034/050 | Batch 450/469 | Cost: 12835.6738\n", "Epoch: 035/050 | Batch 000/469 | Cost: 12766.6123\n", "Epoch: 035/050 | Batch 050/469 | 
Cost: 12521.5166\n", "Epoch: 035/050 | Batch 100/469 | Cost: 12340.0430\n", "Epoch: 035/050 | Batch 150/469 | Cost: 12873.6191\n", "Epoch: 035/050 | Batch 200/469 | Cost: 13027.2266\n", "Epoch: 035/050 | Batch 250/469 | Cost: 13575.8379\n", "Epoch: 035/050 | Batch 300/469 | Cost: 13458.8867\n", "Epoch: 035/050 | Batch 350/469 | Cost: 12816.4980\n", "Epoch: 035/050 | Batch 400/469 | Cost: 12663.2207\n", "Epoch: 035/050 | Batch 450/469 | Cost: 12733.6777\n", "Epoch: 036/050 | Batch 000/469 | Cost: 13078.8682\n", "Epoch: 036/050 | Batch 050/469 | Cost: 13072.0742\n", "Epoch: 036/050 | Batch 100/469 | Cost: 12666.5215\n", "Epoch: 036/050 | Batch 150/469 | Cost: 13091.2852\n", "Epoch: 036/050 | Batch 200/469 | Cost: 13462.2529\n", "Epoch: 036/050 | Batch 250/469 | Cost: 12630.4287\n", "Epoch: 036/050 | Batch 300/469 | Cost: 13213.3223\n", "Epoch: 036/050 | Batch 350/469 | Cost: 13298.2490\n", "Epoch: 036/050 | Batch 400/469 | Cost: 12989.6328\n", "Epoch: 036/050 | Batch 450/469 | Cost: 12918.6348\n", "Epoch: 037/050 | Batch 000/469 | Cost: 12605.0732\n", "Epoch: 037/050 | Batch 050/469 | Cost: 13055.5742\n", "Epoch: 037/050 | Batch 100/469 | Cost: 12719.5420\n", "Epoch: 037/050 | Batch 150/469 | Cost: 12599.2461\n", "Epoch: 037/050 | Batch 200/469 | Cost: 12545.8223\n", "Epoch: 037/050 | Batch 250/469 | Cost: 12449.0918\n", "Epoch: 037/050 | Batch 300/469 | Cost: 13342.7930\n", "Epoch: 037/050 | Batch 350/469 | Cost: 13066.0029\n", "Epoch: 037/050 | Batch 400/469 | Cost: 13258.0957\n", "Epoch: 037/050 | Batch 450/469 | Cost: 13180.6914\n", "Epoch: 038/050 | Batch 000/469 | Cost: 12527.9854\n", "Epoch: 038/050 | Batch 050/469 | Cost: 13618.6875\n", "Epoch: 038/050 | Batch 100/469 | Cost: 13039.2627\n", "Epoch: 038/050 | Batch 150/469 | Cost: 13062.9453\n", "Epoch: 038/050 | Batch 200/469 | Cost: 13139.8945\n", "Epoch: 038/050 | Batch 250/469 | Cost: 13168.6621\n", "Epoch: 038/050 | Batch 300/469 | Cost: 12623.4629\n", "Epoch: 038/050 | Batch 350/469 | Cost: 
12757.8447\n", "Epoch: 038/050 | Batch 400/469 | Cost: 12830.0762\n", "Epoch: 038/050 | Batch 450/469 | Cost: 12733.7969\n", "Epoch: 039/050 | Batch 000/469 | Cost: 12927.6729\n", "Epoch: 039/050 | Batch 050/469 | Cost: 13016.1133\n", "Epoch: 039/050 | Batch 100/469 | Cost: 12955.1621\n", "Epoch: 039/050 | Batch 150/469 | Cost: 12945.2852\n", "Epoch: 039/050 | Batch 200/469 | Cost: 12680.2188\n", "Epoch: 039/050 | Batch 250/469 | Cost: 12958.4688\n", "Epoch: 039/050 | Batch 300/469 | Cost: 13075.4912\n", "Epoch: 039/050 | Batch 350/469 | Cost: 12962.3750\n", "Epoch: 039/050 | Batch 400/469 | Cost: 12863.8867\n", "Epoch: 039/050 | Batch 450/469 | Cost: 13399.3818\n", "Epoch: 040/050 | Batch 000/469 | Cost: 12694.0283\n", "Epoch: 040/050 | Batch 050/469 | Cost: 13524.2754\n", "Epoch: 040/050 | Batch 100/469 | Cost: 12840.9316\n", "Epoch: 040/050 | Batch 150/469 | Cost: 12661.5918\n", "Epoch: 040/050 | Batch 200/469 | Cost: 13256.4902\n", "Epoch: 040/050 | Batch 250/469 | Cost: 13027.6816\n", "Epoch: 040/050 | Batch 300/469 | Cost: 12941.4727\n", "Epoch: 040/050 | Batch 350/469 | Cost: 12656.1348\n", "Epoch: 040/050 | Batch 400/469 | Cost: 12979.4785\n", "Epoch: 040/050 | Batch 450/469 | Cost: 12705.2158\n", "Epoch: 041/050 | Batch 000/469 | Cost: 12759.9707\n", "Epoch: 041/050 | Batch 050/469 | Cost: 12406.5781\n", "Epoch: 041/050 | Batch 100/469 | Cost: 12696.9307\n", "Epoch: 041/050 | Batch 150/469 | Cost: 13398.3613\n", "Epoch: 041/050 | Batch 200/469 | Cost: 12777.3418\n", "Epoch: 041/050 | Batch 250/469 | Cost: 12854.2783\n", "Epoch: 041/050 | Batch 300/469 | Cost: 13037.7236\n", "Epoch: 041/050 | Batch 350/469 | Cost: 13410.0801\n", "Epoch: 041/050 | Batch 400/469 | Cost: 13350.9121\n", "Epoch: 041/050 | Batch 450/469 | Cost: 12898.7432\n", "Epoch: 042/050 | Batch 000/469 | Cost: 12766.4316\n", "Epoch: 042/050 | Batch 050/469 | Cost: 13303.4766\n", "Epoch: 042/050 | Batch 100/469 | Cost: 13112.1465\n", "Epoch: 042/050 | Batch 150/469 | Cost: 12951.8428\n", 
"Epoch: 042/050 | Batch 200/469 | Cost: 13367.4814\n", "Epoch: 042/050 | Batch 250/469 | Cost: 13274.8955\n", "Epoch: 042/050 | Batch 300/469 | Cost: 12908.4805\n", "Epoch: 042/050 | Batch 350/469 | Cost: 12858.0723\n", "Epoch: 042/050 | Batch 400/469 | Cost: 13279.5410\n", "Epoch: 042/050 | Batch 450/469 | Cost: 12669.7422\n", "Epoch: 043/050 | Batch 000/469 | Cost: 13124.0225\n", "Epoch: 043/050 | Batch 050/469 | Cost: 12976.3857\n", "Epoch: 043/050 | Batch 100/469 | Cost: 12655.5703\n", "Epoch: 043/050 | Batch 150/469 | Cost: 12876.7061\n", "Epoch: 043/050 | Batch 200/469 | Cost: 13277.1592\n", "Epoch: 043/050 | Batch 250/469 | Cost: 12657.5117\n", "Epoch: 043/050 | Batch 300/469 | Cost: 12915.8867\n", "Epoch: 043/050 | Batch 350/469 | Cost: 13254.9941\n", "Epoch: 043/050 | Batch 400/469 | Cost: 12649.9102\n", "Epoch: 043/050 | Batch 450/469 | Cost: 13198.5771\n", "Epoch: 044/050 | Batch 000/469 | Cost: 13573.9121\n", "Epoch: 044/050 | Batch 050/469 | Cost: 12972.9453\n", "Epoch: 044/050 | Batch 100/469 | Cost: 12764.2188\n", "Epoch: 044/050 | Batch 150/469 | Cost: 13482.2910\n", "Epoch: 044/050 | Batch 200/469 | Cost: 13304.8975\n", "Epoch: 044/050 | Batch 250/469 | Cost: 13446.4141\n", "Epoch: 044/050 | Batch 300/469 | Cost: 13096.8887\n", "Epoch: 044/050 | Batch 350/469 | Cost: 13551.5537\n", "Epoch: 044/050 | Batch 400/469 | Cost: 12693.8301\n", "Epoch: 044/050 | Batch 450/469 | Cost: 12812.8682\n", "Epoch: 045/050 | Batch 000/469 | Cost: 12698.2207\n", "Epoch: 045/050 | Batch 050/469 | Cost: 12705.6787\n", "Epoch: 045/050 | Batch 100/469 | Cost: 13046.6162\n", "Epoch: 045/050 | Batch 150/469 | Cost: 13085.3457\n", "Epoch: 045/050 | Batch 200/469 | Cost: 12922.5312\n", "Epoch: 045/050 | Batch 250/469 | Cost: 13367.9189\n", "Epoch: 045/050 | Batch 300/469 | Cost: 12917.3457\n", "Epoch: 045/050 | Batch 350/469 | Cost: 12896.9463\n", "Epoch: 045/050 | Batch 400/469 | Cost: 13378.9902\n", "Epoch: 045/050 | Batch 450/469 | Cost: 12873.3105\n", "Epoch: 046/050 | 
Batch 000/469 | Cost: 12739.6260\n", "Epoch: 046/050 | Batch 050/469 | Cost: 13021.6465\n", "Epoch: 046/050 | Batch 100/469 | Cost: 13027.7256\n", "Epoch: 046/050 | Batch 150/469 | Cost: 12995.7490\n", "Epoch: 046/050 | Batch 200/469 | Cost: 12588.5645\n", "Epoch: 046/050 | Batch 250/469 | Cost: 13288.1494\n", "Epoch: 046/050 | Batch 300/469 | Cost: 12766.4707\n", "Epoch: 046/050 | Batch 350/469 | Cost: 12326.2334\n", "Epoch: 046/050 | Batch 400/469 | Cost: 13174.7734\n", "Epoch: 046/050 | Batch 450/469 | Cost: 12531.1074\n", "Epoch: 047/050 | Batch 000/469 | Cost: 13050.0781\n", "Epoch: 047/050 | Batch 050/469 | Cost: 12920.9609\n", "Epoch: 047/050 | Batch 100/469 | Cost: 12954.7383\n", "Epoch: 047/050 | Batch 150/469 | Cost: 12598.8203\n", "Epoch: 047/050 | Batch 200/469 | Cost: 12969.3066\n", "Epoch: 047/050 | Batch 250/469 | Cost: 12893.0693\n", "Epoch: 047/050 | Batch 300/469 | Cost: 12975.1309\n", "Epoch: 047/050 | Batch 350/469 | Cost: 13452.9775\n", "Epoch: 047/050 | Batch 400/469 | Cost: 13320.0781\n", "Epoch: 047/050 | Batch 450/469 | Cost: 13015.5547\n", "Epoch: 048/050 | Batch 000/469 | Cost: 12687.9102\n", "Epoch: 048/050 | Batch 050/469 | Cost: 12957.4678\n", "Epoch: 048/050 | Batch 100/469 | Cost: 13027.3281\n", "Epoch: 048/050 | Batch 150/469 | Cost: 13472.4619\n", "Epoch: 048/050 | Batch 200/469 | Cost: 12705.8525\n", "Epoch: 048/050 | Batch 250/469 | Cost: 13201.4590\n", "Epoch: 048/050 | Batch 300/469 | Cost: 13181.9707\n", "Epoch: 048/050 | Batch 350/469 | Cost: 13372.9746\n", "Epoch: 048/050 | Batch 400/469 | Cost: 12831.2305\n", "Epoch: 048/050 | Batch 450/469 | Cost: 12963.0156\n", "Epoch: 049/050 | Batch 000/469 | Cost: 12832.8057\n", "Epoch: 049/050 | Batch 050/469 | Cost: 12771.7109\n", "Epoch: 049/050 | Batch 100/469 | Cost: 13143.8457\n", "Epoch: 049/050 | Batch 150/469 | Cost: 12863.8740\n", "Epoch: 049/050 | Batch 200/469 | Cost: 12841.9248\n", "Epoch: 049/050 | Batch 250/469 | Cost: 13240.2529\n", "Epoch: 049/050 | Batch 300/469 | 
Cost: 12889.4521\n", "Epoch: 049/050 | Batch 350/469 | Cost: 13022.1709\n", "Epoch: 049/050 | Batch 400/469 | Cost: 12963.3125\n", "Epoch: 049/050 | Batch 450/469 | Cost: 12778.5381\n", "Epoch: 050/050 | Batch 000/469 | Cost: 12820.4805\n", "Epoch: 050/050 | Batch 050/469 | Cost: 12769.4893\n", "Epoch: 050/050 | Batch 100/469 | Cost: 12682.0566\n", "Epoch: 050/050 | Batch 150/469 | Cost: 13312.6172\n", "Epoch: 050/050 | Batch 200/469 | Cost: 13716.5430\n", "Epoch: 050/050 | Batch 250/469 | Cost: 12201.1973\n", "Epoch: 050/050 | Batch 300/469 | Cost: 13228.4199\n", "Epoch: 050/050 | Batch 350/469 | Cost: 12986.6201\n", "Epoch: 050/050 | Batch 400/469 | Cost: 13097.1562\n", "Epoch: 050/050 | Batch 450/469 | Cost: 13272.6777\n" ] } ], "source": [
"start_time = time.time()\n",
"for epoch in range(num_epochs):\n",
"    # labels are not needed for the autoencoder, only the images (features)\n",
"    for batch_idx, (features, _) in enumerate(train_loader):\n",
"\n",
"        # flatten each 1x28x28 image into a num_features-long vector\n",
"        features = features.view(-1, num_features).to(device)\n",
"\n",
"        ### FORWARD AND BACK PROP\n",
"        z_mean, z_log_var, encoded, decoded = model(features)\n",
"\n",
"        # cost = reconstruction loss + Kullback-Leibler divergence\n",
"        # closed-form KL between N(z_mean, exp(z_log_var)) and N(0, I),\n",
"        # summed over latent dims and batch to match the summed BCE below\n",
"        kl_divergence = (0.5 * (z_mean**2 +\n",
"                                torch.exp(z_log_var) - z_log_var - 1)).sum()\n",
"        pixelwise_bce = F.binary_cross_entropy(decoded, features, reduction='sum')\n",
"        cost = kl_divergence + pixelwise_bce\n",
"\n",
"        optimizer.zero_grad()\n",
"        cost.backward()\n",
"\n",
"        ### UPDATE MODEL PARAMETERS\n",
"        optimizer.step()\n",
"\n",
"        ### LOGGING\n",
"        if not batch_idx % 50:\n",
"            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'\n",
"                  % (epoch+1, num_epochs, batch_idx,\n",
"                     len(train_loader), cost.item()))\n",
"\n",
"    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
"\n",
"print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"### Reconstruction" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "##########################\n", "### VISUALIZATION\n", "##########################\n", "\n", "n_images = 15\n", "image_width = 28\n", "\n", "fig, axes = plt.subplots(nrows=2, ncols=n_images, \n", " sharex=True, sharey=True, figsize=(20, 2.5))\n", "orig_images = features[:n_images]\n", "decoded_images = decoded[:n_images]\n", "\n", "for i in range(n_images):\n", " for ax, img in zip(axes, [orig_images, decoded_images]):\n", " curr_img = img[i].detach().to(torch.device('cpu'))\n", " ax[i].imshow(curr_img.view((image_width, image_width)), cmap='binary')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate new images" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", "\n", " ##########################\n", " ### RANDOM SAMPLE\n", " ########################## \n", " \n", " n_images = 10\n", " rand_features = torch.randn(n_images, num_latent).to(device)\n", " new_images = model.decoder(rand_features)\n", "\n", " ##########################\n", " ### VISUALIZATION\n", " ##########################\n", "\n", " image_width = 28\n", "\n", " fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)\n", " decoded_images = new_images[:n_images]\n", "\n", " for ax, img in zip(axes, decoded_images):\n", " curr_img = img.detach().to(torch.device('cpu'))\n", " ax.imshow(curr_img.view((image_width, image_width)), cmap='binary')\n", " \n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "torch 1.0.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }