{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.6.8\n", "IPython 7.2.0\n", "\n", "torch 1.0.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Convolutional Variational Autoencoder" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A simple convolutional variational autoencoder that compresses 784-pixel (28x28) MNIST images down to a 15-dimensional latent vector representation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import DataLoader\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Device: cuda:1\n", "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:1\" if torch.cuda.is_available() else \"cpu\")\n", "print('Device:', device)\n", "\n", "# Hyperparameters\n", "random_seed = 0\n", "learning_rate = 0.001\n", "num_epochs = 50\n", "batch_size = 128\n", "\n", "# Architecture\n", "num_features = 784\n", "num_latent = 15\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=batch_size, \n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=batch_size, \n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "## ### MODEL\n", "##########################\n", "\n", "class ConvVariationalAutoencoder(torch.nn.Module):\n", "\n", " def __init__(self, num_features, num_latent):\n", " super(ConvVariationalAutoencoder, self).__init__()\n", " \n", " ###############\n", " # ENCODER\n", " ##############\n", " \n", " # calculate same padding:\n", " # (w - k + 2*p)/s + 1 = o\n", " # => p = (s(o-1) - w + k)/2\n", "\n", " self.enc_conv_1 = torch.nn.Conv2d(in_channels=1,\n", " out_channels=16,\n", " kernel_size=(6, 6),\n", " stride=(2, 2),\n", " padding=0)\n", "\n", " self.enc_conv_2 = torch.nn.Conv2d(in_channels=16,\n", " out_channels=32,\n", " kernel_size=(4, 4),\n", " stride=(2, 2),\n", " padding=0) \n", " \n", " self.enc_conv_3 = torch.nn.Conv2d(in_channels=32,\n", " out_channels=64,\n", " kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0) \n", " \n", " self.z_mean = torch.nn.Linear(64*2*2, num_latent)\n", " # in the original paper (Kingma & Welling 2015, we use\n", " # have a z_mean and z_var, but the problem is that\n", " # the z_var can be negative, which would cause issues\n", " # in the log later. 
Hence we assume that latent vector\n", " # has a z_mean and z_log_var component, and when we need\n", " # the regular variance or std_dev, we simply use \n", " # an exponential function\n", " self.z_log_var = torch.nn.Linear(64*2*2, num_latent)\n", " \n", " \n", " \n", " ###############\n", " # DECODER\n", " ##############\n", " \n", " self.dec_linear_1 = torch.nn.Linear(num_latent, 64*2*2)\n", " \n", " self.dec_deconv_1 = torch.nn.ConvTranspose2d(in_channels=64,\n", " out_channels=32,\n", " kernel_size=(2, 2),\n", " stride=(2, 2),\n", " padding=0)\n", " \n", " self.dec_deconv_2 = torch.nn.ConvTranspose2d(in_channels=32,\n", " out_channels=16,\n", " kernel_size=(4, 4),\n", " stride=(3, 3),\n", " padding=1)\n", " \n", " self.dec_deconv_3 = torch.nn.ConvTranspose2d(in_channels=16,\n", " out_channels=1,\n", " kernel_size=(6, 6),\n", " stride=(3, 3),\n", " padding=4)\n", "\n", "\n", " def reparameterize(self, z_mu, z_log_var):\n", " # Sample epsilon from standard normal distribution\n", " eps = torch.randn(z_mu.size(0), z_mu.size(1)).to(device)\n", " # note that log(x^2) = 2*log(x); hence divide by 2 to get std_dev\n", " # i.e., std_dev = exp(log(std_dev^2)/2) = exp(log(var)/2)\n", " z = z_mu + eps * torch.exp(z_log_var/2.) 
\n", " return z\n", " \n", " def encoder(self, features):\n", " x = self.enc_conv_1(features)\n", " x = F.leaky_relu(x)\n", " #print('conv1 out:', x.size())\n", " \n", " x = self.enc_conv_2(x)\n", " x = F.leaky_relu(x)\n", " #print('conv2 out:', x.size())\n", " \n", " x = self.enc_conv_3(x)\n", " x = F.leaky_relu(x)\n", " #print('conv3 out:', x.size())\n", " \n", " z_mean = self.z_mean(x.view(-1, 64*2*2))\n", " z_log_var = self.z_log_var(x.view(-1, 64*2*2))\n", " encoded = self.reparameterize(z_mean, z_log_var)\n", " \n", " return z_mean, z_log_var, encoded\n", " \n", " def decoder(self, encoded):\n", " x = self.dec_linear_1(encoded)\n", " x = x.view(-1, 64, 2, 2)\n", " \n", " x = self.dec_deconv_1(x)\n", " x = F.leaky_relu(x)\n", " #print('deconv1 out:', x.size())\n", " \n", " x = self.dec_deconv_2(x)\n", " x = F.leaky_relu(x)\n", " #print('deconv2 out:', x.size())\n", " \n", " x = self.dec_deconv_3(x)\n", " x = F.leaky_relu(x)\n", " #print('deconv1 out:', x.size())\n", " \n", " decoded = torch.sigmoid(x)\n", " return decoded\n", "\n", " def forward(self, features):\n", " \n", " z_mean, z_log_var, encoded = self.encoder(features)\n", " decoded = self.decoder(encoded)\n", " \n", " return z_mean, z_log_var, encoded, decoded\n", "\n", " \n", "torch.manual_seed(random_seed)\n", "model = ConvVariationalAutoencoder(num_features,\n", " num_latent)\n", "model = model.to(device)\n", " \n", "\n", "##########################\n", "### COST AND OPTIMIZER\n", "##########################\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/050 | Batch 000/469 | Cost: 70508.8125\n", "Epoch: 001/050 | Batch 050/469 | Cost: 69360.6953\n", "Epoch: 001/050 | Batch 100/469 | Cost: 41214.9258\n", "Epoch: 001/050 | Batch 150/469 | Cost: 
34571.9688\n", "Epoch: 001/050 | Batch 200/469 | Cost: 28148.7500\n", "Epoch: 001/050 | Batch 250/469 | Cost: 29157.5918\n", "Epoch: 001/050 | Batch 300/469 | Cost: 27405.8359\n", "Epoch: 001/050 | Batch 350/469 | Cost: 26417.7324\n", "Epoch: 001/050 | Batch 400/469 | Cost: 25685.3320\n", "Epoch: 001/050 | Batch 450/469 | Cost: 25538.5293\n", "Time elapsed: 0.14 min\n", "Epoch: 002/050 | Batch 000/469 | Cost: 25301.1914\n", "Epoch: 002/050 | Batch 050/469 | Cost: 24180.0254\n", "Epoch: 002/050 | Batch 100/469 | Cost: 24124.1309\n", "Epoch: 002/050 | Batch 150/469 | Cost: 24285.9629\n", "Epoch: 002/050 | Batch 200/469 | Cost: 24396.1914\n", "Epoch: 002/050 | Batch 250/469 | Cost: 23214.3711\n", "Epoch: 002/050 | Batch 300/469 | Cost: 23690.0957\n", "Epoch: 002/050 | Batch 350/469 | Cost: 23472.6250\n", "Epoch: 002/050 | Batch 400/469 | Cost: 22830.4551\n", "Epoch: 002/050 | Batch 450/469 | Cost: 23235.6602\n", "Time elapsed: 0.28 min\n", "Epoch: 003/050 | Batch 000/469 | Cost: 21816.1934\n", "Epoch: 003/050 | Batch 050/469 | Cost: 22230.9316\n", "Epoch: 003/050 | Batch 100/469 | Cost: 21450.3535\n", "Epoch: 003/050 | Batch 150/469 | Cost: 22165.5781\n", "Epoch: 003/050 | Batch 200/469 | Cost: 20737.0566\n", "Epoch: 003/050 | Batch 250/469 | Cost: 22005.0625\n", "Epoch: 003/050 | Batch 300/469 | Cost: 21199.4375\n", "Epoch: 003/050 | Batch 350/469 | Cost: 20788.6172\n", "Epoch: 003/050 | Batch 400/469 | Cost: 21091.8125\n", "Epoch: 003/050 | Batch 450/469 | Cost: 20739.7637\n", "Time elapsed: 0.43 min\n", "Epoch: 004/050 | Batch 000/469 | Cost: 20847.3320\n", "Epoch: 004/050 | Batch 050/469 | Cost: 21085.3945\n", "Epoch: 004/050 | Batch 100/469 | Cost: 19874.4863\n", "Epoch: 004/050 | Batch 150/469 | Cost: 20014.3594\n", "Epoch: 004/050 | Batch 200/469 | Cost: 19658.4062\n", "Epoch: 004/050 | Batch 250/469 | Cost: 19592.5801\n", "Epoch: 004/050 | Batch 300/469 | Cost: 20177.1680\n", "Epoch: 004/050 | Batch 350/469 | Cost: 19534.2344\n", "Epoch: 004/050 | Batch 
400/469 | Cost: 19567.2852\n", "Epoch: 004/050 | Batch 450/469 | Cost: 19308.6367\n", "Time elapsed: 0.57 min\n", "Epoch: 005/050 | Batch 000/469 | Cost: 18299.0762\n", "Epoch: 005/050 | Batch 050/469 | Cost: 17929.8359\n", "Epoch: 005/050 | Batch 100/469 | Cost: 19014.4316\n", "Epoch: 005/050 | Batch 150/469 | Cost: 18907.9668\n", "Epoch: 005/050 | Batch 200/469 | Cost: 18992.6836\n", "Epoch: 005/050 | Batch 250/469 | Cost: 18611.7383\n", "Epoch: 005/050 | Batch 300/469 | Cost: 18453.7012\n", "Epoch: 005/050 | Batch 350/469 | Cost: 18959.7227\n", "Epoch: 005/050 | Batch 400/469 | Cost: 18798.1758\n", "Epoch: 005/050 | Batch 450/469 | Cost: 18019.3672\n", "Time elapsed: 0.71 min\n", "Epoch: 006/050 | Batch 000/469 | Cost: 18124.0820\n", "Epoch: 006/050 | Batch 050/469 | Cost: 18439.6680\n", "Epoch: 006/050 | Batch 100/469 | Cost: 17569.6094\n", "Epoch: 006/050 | Batch 150/469 | Cost: 18261.6934\n", "Epoch: 006/050 | Batch 200/469 | Cost: 17973.9492\n", "Epoch: 006/050 | Batch 250/469 | Cost: 16992.7305\n", "Epoch: 006/050 | Batch 300/469 | Cost: 18452.1992\n", "Epoch: 006/050 | Batch 350/469 | Cost: 17165.4297\n", "Epoch: 006/050 | Batch 400/469 | Cost: 18000.2754\n", "Epoch: 006/050 | Batch 450/469 | Cost: 16839.3262\n", "Time elapsed: 0.86 min\n", "Epoch: 007/050 | Batch 000/469 | Cost: 17863.0645\n", "Epoch: 007/050 | Batch 050/469 | Cost: 17572.0059\n", "Epoch: 007/050 | Batch 100/469 | Cost: 17348.5625\n", "Epoch: 007/050 | Batch 150/469 | Cost: 17124.4922\n", "Epoch: 007/050 | Batch 200/469 | Cost: 17443.2051\n", "Epoch: 007/050 | Batch 250/469 | Cost: 17221.6523\n", "Epoch: 007/050 | Batch 300/469 | Cost: 17059.4297\n", "Epoch: 007/050 | Batch 350/469 | Cost: 17353.8359\n", "Epoch: 007/050 | Batch 400/469 | Cost: 18116.3086\n", "Epoch: 007/050 | Batch 450/469 | Cost: 17090.7910\n", "Time elapsed: 1.00 min\n", "Epoch: 008/050 | Batch 000/469 | Cost: 17174.0098\n", "Epoch: 008/050 | Batch 050/469 | Cost: 16741.3477\n", "Epoch: 008/050 | Batch 100/469 | Cost: 
16833.8691\n", "Epoch: 008/050 | Batch 150/469 | Cost: 17041.8145\n", "Epoch: 008/050 | Batch 200/469 | Cost: 16583.4785\n", "Epoch: 008/050 | Batch 250/469 | Cost: 17148.7363\n", "Epoch: 008/050 | Batch 300/469 | Cost: 16401.9492\n", "Epoch: 008/050 | Batch 350/469 | Cost: 16366.9717\n", "Epoch: 008/050 | Batch 400/469 | Cost: 16309.4883\n", "Epoch: 008/050 | Batch 450/469 | Cost: 16813.7383\n", "Time elapsed: 1.14 min\n", "Epoch: 009/050 | Batch 000/469 | Cost: 16475.1348\n", "Epoch: 009/050 | Batch 050/469 | Cost: 16717.6797\n", "Epoch: 009/050 | Batch 100/469 | Cost: 16681.8125\n", "Epoch: 009/050 | Batch 150/469 | Cost: 16367.4902\n", "Epoch: 009/050 | Batch 200/469 | Cost: 16425.5449\n", "Epoch: 009/050 | Batch 250/469 | Cost: 16841.6738\n", "Epoch: 009/050 | Batch 300/469 | Cost: 16003.4609\n", "Epoch: 009/050 | Batch 350/469 | Cost: 15953.0732\n", "Epoch: 009/050 | Batch 400/469 | Cost: 15981.5557\n", "Epoch: 009/050 | Batch 450/469 | Cost: 15866.3105\n", "Time elapsed: 1.28 min\n", "Epoch: 010/050 | Batch 000/469 | Cost: 16785.4121\n", "Epoch: 010/050 | Batch 050/469 | Cost: 16397.5430\n", "Epoch: 010/050 | Batch 100/469 | Cost: 16289.5566\n", "Epoch: 010/050 | Batch 150/469 | Cost: 16549.7812\n", "Epoch: 010/050 | Batch 200/469 | Cost: 16190.5586\n", "Epoch: 010/050 | Batch 250/469 | Cost: 15208.5176\n", "Epoch: 010/050 | Batch 300/469 | Cost: 15649.6221\n", "Epoch: 010/050 | Batch 350/469 | Cost: 15850.7285\n", "Epoch: 010/050 | Batch 400/469 | Cost: 15607.8145\n", "Epoch: 010/050 | Batch 450/469 | Cost: 16352.9707\n", "Time elapsed: 1.42 min\n", "Epoch: 011/050 | Batch 000/469 | Cost: 14833.6748\n", "Epoch: 011/050 | Batch 050/469 | Cost: 14793.8174\n", "Epoch: 011/050 | Batch 100/469 | Cost: 16031.2539\n", "Epoch: 011/050 | Batch 150/469 | Cost: 16403.2148\n", "Epoch: 011/050 | Batch 200/469 | Cost: 16180.4619\n", "Epoch: 011/050 | Batch 250/469 | Cost: 15964.9424\n", "Epoch: 011/050 | Batch 300/469 | Cost: 16027.1377\n", "Epoch: 011/050 | Batch 
350/469 | Cost: 16350.3730\n", "Epoch: 011/050 | Batch 400/469 | Cost: 15546.7812\n", "Epoch: 011/050 | Batch 450/469 | Cost: 15494.3408\n", "Time elapsed: 1.56 min\n", "Epoch: 012/050 | Batch 000/469 | Cost: 15366.8662\n", "Epoch: 012/050 | Batch 050/469 | Cost: 15567.0410\n", "Epoch: 012/050 | Batch 100/469 | Cost: 15825.4131\n", "Epoch: 012/050 | Batch 150/469 | Cost: 15356.7363\n", "Epoch: 012/050 | Batch 200/469 | Cost: 16218.4111\n", "Epoch: 012/050 | Batch 250/469 | Cost: 15840.4629\n", "Epoch: 012/050 | Batch 300/469 | Cost: 15789.5957\n", "Epoch: 012/050 | Batch 350/469 | Cost: 16290.2920\n", "Epoch: 012/050 | Batch 400/469 | Cost: 16000.1152\n", "Epoch: 012/050 | Batch 450/469 | Cost: 15458.4883\n", "Time elapsed: 1.71 min\n", "Epoch: 013/050 | Batch 000/469 | Cost: 14845.1387\n", "Epoch: 013/050 | Batch 050/469 | Cost: 14813.1328\n", "Epoch: 013/050 | Batch 100/469 | Cost: 15130.9199\n", "Epoch: 013/050 | Batch 150/469 | Cost: 15422.9141\n", "Epoch: 013/050 | Batch 200/469 | Cost: 15566.4805\n", "Epoch: 013/050 | Batch 250/469 | Cost: 15794.4580\n", "Epoch: 013/050 | Batch 300/469 | Cost: 15083.1582\n", "Epoch: 013/050 | Batch 350/469 | Cost: 15447.7637\n", "Epoch: 013/050 | Batch 400/469 | Cost: 15675.3779\n", "Epoch: 013/050 | Batch 450/469 | Cost: 15165.6543\n", "Time elapsed: 1.85 min\n", "Epoch: 014/050 | Batch 000/469 | Cost: 15194.8164\n", "Epoch: 014/050 | Batch 050/469 | Cost: 15119.1504\n", "Epoch: 014/050 | Batch 100/469 | Cost: 15796.2129\n", "Epoch: 014/050 | Batch 150/469 | Cost: 14884.1680\n", "Epoch: 014/050 | Batch 200/469 | Cost: 15225.4922\n", "Epoch: 014/050 | Batch 250/469 | Cost: 15586.4531\n", "Epoch: 014/050 | Batch 300/469 | Cost: 14798.0352\n", "Epoch: 014/050 | Batch 350/469 | Cost: 15295.6680\n", "Epoch: 014/050 | Batch 400/469 | Cost: 15782.0469\n", "Epoch: 014/050 | Batch 450/469 | Cost: 15226.4424\n", "Time elapsed: 1.99 min\n", "Epoch: 015/050 | Batch 000/469 | Cost: 15213.7441\n", "Epoch: 015/050 | Batch 050/469 | Cost: 
15049.6631\n", "Epoch: 015/050 | Batch 100/469 | Cost: 15464.3105\n", "Epoch: 015/050 | Batch 150/469 | Cost: 15114.6406\n", "Epoch: 015/050 | Batch 200/469 | Cost: 15309.3145\n", "Epoch: 015/050 | Batch 250/469 | Cost: 14940.2734\n", "Epoch: 015/050 | Batch 300/469 | Cost: 15016.6035\n", "Epoch: 015/050 | Batch 350/469 | Cost: 15046.3008\n", "Epoch: 015/050 | Batch 400/469 | Cost: 15167.2373\n", "Epoch: 015/050 | Batch 450/469 | Cost: 14859.8359\n", "Time elapsed: 2.13 min\n", "Epoch: 016/050 | Batch 000/469 | Cost: 15028.2578\n", "Epoch: 016/050 | Batch 050/469 | Cost: 14834.3887\n", "Epoch: 016/050 | Batch 100/469 | Cost: 15176.1133\n", "Epoch: 016/050 | Batch 150/469 | Cost: 15468.7812\n", "Epoch: 016/050 | Batch 200/469 | Cost: 15083.7363\n", "Epoch: 016/050 | Batch 250/469 | Cost: 14691.1562\n", "Epoch: 016/050 | Batch 300/469 | Cost: 15369.2461\n", "Epoch: 016/050 | Batch 350/469 | Cost: 14979.9854\n", "Epoch: 016/050 | Batch 400/469 | Cost: 14710.5820\n", "Epoch: 016/050 | Batch 450/469 | Cost: 15753.2812\n", "Time elapsed: 2.27 min\n", "Epoch: 017/050 | Batch 000/469 | Cost: 15676.1680\n", "Epoch: 017/050 | Batch 050/469 | Cost: 15123.8203\n", "Epoch: 017/050 | Batch 100/469 | Cost: 15131.5918\n", "Epoch: 017/050 | Batch 150/469 | Cost: 14856.3496\n", "Epoch: 017/050 | Batch 200/469 | Cost: 15176.2002\n", "Epoch: 017/050 | Batch 250/469 | Cost: 14768.6816\n", "Epoch: 017/050 | Batch 300/469 | Cost: 14871.7480\n", "Epoch: 017/050 | Batch 350/469 | Cost: 14418.3633\n", "Epoch: 017/050 | Batch 400/469 | Cost: 15398.9326\n", "Epoch: 017/050 | Batch 450/469 | Cost: 14675.2832\n", "Time elapsed: 2.40 min\n", "Epoch: 018/050 | Batch 000/469 | Cost: 15558.1592\n", "Epoch: 018/050 | Batch 050/469 | Cost: 14836.9766\n", "Epoch: 018/050 | Batch 100/469 | Cost: 14535.8203\n", "Epoch: 018/050 | Batch 150/469 | Cost: 15062.1992\n", "Epoch: 018/050 | Batch 200/469 | Cost: 15094.6914\n", "Epoch: 018/050 | Batch 250/469 | Cost: 15006.5684\n", "Epoch: 018/050 | Batch 
300/469 | Cost: 14656.5703\n", "Epoch: 018/050 | Batch 350/469 | Cost: 15232.4990\n", "Epoch: 018/050 | Batch 400/469 | Cost: 15159.4854\n", "Epoch: 018/050 | Batch 450/469 | Cost: 15619.9785\n", "Time elapsed: 2.53 min\n", "Epoch: 019/050 | Batch 000/469 | Cost: 14647.2051\n", "Epoch: 019/050 | Batch 050/469 | Cost: 15262.9062\n", "Epoch: 019/050 | Batch 100/469 | Cost: 15305.6738\n", "Epoch: 019/050 | Batch 150/469 | Cost: 14550.4102\n", "Epoch: 019/050 | Batch 200/469 | Cost: 15431.4395\n", "Epoch: 019/050 | Batch 250/469 | Cost: 15205.6074\n", "Epoch: 019/050 | Batch 300/469 | Cost: 15149.4453\n", "Epoch: 019/050 | Batch 350/469 | Cost: 14836.1543\n", "Epoch: 019/050 | Batch 400/469 | Cost: 14699.8994\n", "Epoch: 019/050 | Batch 450/469 | Cost: 15564.8604\n", "Time elapsed: 2.67 min\n", "Epoch: 020/050 | Batch 000/469 | Cost: 15190.4043\n", "Epoch: 020/050 | Batch 050/469 | Cost: 15331.2246\n", "Epoch: 020/050 | Batch 100/469 | Cost: 14559.0176\n", "Epoch: 020/050 | Batch 150/469 | Cost: 14311.1699\n", "Epoch: 020/050 | Batch 200/469 | Cost: 14561.7070\n", "Epoch: 020/050 | Batch 250/469 | Cost: 15366.1982\n", "Epoch: 020/050 | Batch 300/469 | Cost: 14740.9365\n", "Epoch: 020/050 | Batch 350/469 | Cost: 14924.1406\n", "Epoch: 020/050 | Batch 400/469 | Cost: 14399.0762\n", "Epoch: 020/050 | Batch 450/469 | Cost: 15144.8867\n", "Time elapsed: 2.81 min\n", "Epoch: 021/050 | Batch 000/469 | Cost: 14497.8389\n", "Epoch: 021/050 | Batch 050/469 | Cost: 14999.7617\n", "Epoch: 021/050 | Batch 100/469 | Cost: 14503.3086\n", "Epoch: 021/050 | Batch 150/469 | Cost: 15366.3564\n", "Epoch: 021/050 | Batch 200/469 | Cost: 15190.8740\n", "Epoch: 021/050 | Batch 250/469 | Cost: 14832.3369\n", "Epoch: 021/050 | Batch 300/469 | Cost: 15091.1016\n", "Epoch: 021/050 | Batch 350/469 | Cost: 14928.2930\n", "Epoch: 021/050 | Batch 400/469 | Cost: 14790.8223\n", "Epoch: 021/050 | Batch 450/469 | Cost: 14803.0596\n", "Time elapsed: 2.95 min\n", "Epoch: 022/050 | Batch 000/469 | Cost: 
14677.4326\n", "Epoch: 022/050 | Batch 050/469 | Cost: 14652.6543\n", "Epoch: 022/050 | Batch 100/469 | Cost: 15094.6904\n", "Epoch: 022/050 | Batch 150/469 | Cost: 14702.5977\n", "Epoch: 022/050 | Batch 200/469 | Cost: 15014.6758\n", "Epoch: 022/050 | Batch 250/469 | Cost: 14506.5420\n", "Epoch: 022/050 | Batch 300/469 | Cost: 14207.6309\n", "Epoch: 022/050 | Batch 350/469 | Cost: 14883.4453\n", "Epoch: 022/050 | Batch 400/469 | Cost: 14935.6797\n", "Epoch: 022/050 | Batch 450/469 | Cost: 14522.0771\n", "Time elapsed: 3.09 min\n", "Epoch: 023/050 | Batch 000/469 | Cost: 14545.0410\n", "Epoch: 023/050 | Batch 050/469 | Cost: 15465.3301\n", "Epoch: 023/050 | Batch 100/469 | Cost: 14911.1807\n", "Epoch: 023/050 | Batch 150/469 | Cost: 14108.9902\n", "Epoch: 023/050 | Batch 200/469 | Cost: 14171.8672\n", "Epoch: 023/050 | Batch 250/469 | Cost: 14510.0352\n", "Epoch: 023/050 | Batch 300/469 | Cost: 14746.7100\n", "Epoch: 023/050 | Batch 350/469 | Cost: 15409.6055\n", "Epoch: 023/050 | Batch 400/469 | Cost: 14423.5654\n", "Epoch: 023/050 | Batch 450/469 | Cost: 15278.3594\n", "Time elapsed: 3.23 min\n", "Epoch: 024/050 | Batch 000/469 | Cost: 14552.7031\n", "Epoch: 024/050 | Batch 050/469 | Cost: 14798.7969\n", "Epoch: 024/050 | Batch 100/469 | Cost: 14998.7012\n", "Epoch: 024/050 | Batch 150/469 | Cost: 14323.0811\n", "Epoch: 024/050 | Batch 200/469 | Cost: 13328.8086\n", "Epoch: 024/050 | Batch 250/469 | Cost: 15235.0488\n", "Epoch: 024/050 | Batch 300/469 | Cost: 14539.9482\n", "Epoch: 024/050 | Batch 350/469 | Cost: 13984.4404\n", "Epoch: 024/050 | Batch 400/469 | Cost: 14394.9082\n", "Epoch: 024/050 | Batch 450/469 | Cost: 14836.1758\n", "Time elapsed: 3.37 min\n", "Epoch: 025/050 | Batch 000/469 | Cost: 14210.6611\n", "Epoch: 025/050 | Batch 050/469 | Cost: 14331.7012\n", "Epoch: 025/050 | Batch 100/469 | Cost: 14440.1592\n", "Epoch: 025/050 | Batch 150/469 | Cost: 14585.4521\n", "Epoch: 025/050 | Batch 200/469 | Cost: 14941.8232\n", "Epoch: 025/050 | Batch 
250/469 | Cost: 14408.6523\n", "Epoch: 025/050 | Batch 300/469 | Cost: 13879.6191\n", "Epoch: 025/050 | Batch 350/469 | Cost: 14163.3799\n", "Epoch: 025/050 | Batch 400/469 | Cost: 15489.8164\n", "Epoch: 025/050 | Batch 450/469 | Cost: 14584.5352\n", "Time elapsed: 3.51 min\n", "Epoch: 026/050 | Batch 000/469 | Cost: 14449.3213\n", "Epoch: 026/050 | Batch 050/469 | Cost: 14182.0420\n", "Epoch: 026/050 | Batch 100/469 | Cost: 14822.8936\n", "Epoch: 026/050 | Batch 150/469 | Cost: 15550.9629\n", "Epoch: 026/050 | Batch 200/469 | Cost: 14777.4414\n", "Epoch: 026/050 | Batch 250/469 | Cost: 14844.9375\n", "Epoch: 026/050 | Batch 300/469 | Cost: 14236.6016\n", "Epoch: 026/050 | Batch 350/469 | Cost: 14573.4326\n", "Epoch: 026/050 | Batch 400/469 | Cost: 14540.6592\n", "Epoch: 026/050 | Batch 450/469 | Cost: 15272.1367\n", "Time elapsed: 3.65 min\n", "Epoch: 027/050 | Batch 000/469 | Cost: 14737.4766\n", "Epoch: 027/050 | Batch 050/469 | Cost: 14636.1719\n", "Epoch: 027/050 | Batch 100/469 | Cost: 14763.8066\n", "Epoch: 027/050 | Batch 150/469 | Cost: 14228.8965\n", "Epoch: 027/050 | Batch 200/469 | Cost: 14508.6289\n", "Epoch: 027/050 | Batch 250/469 | Cost: 14433.5488\n", "Epoch: 027/050 | Batch 300/469 | Cost: 14199.0078\n", "Epoch: 027/050 | Batch 350/469 | Cost: 14910.3555\n", "Epoch: 027/050 | Batch 400/469 | Cost: 14825.3359\n", "Epoch: 027/050 | Batch 450/469 | Cost: 14556.9355\n", "Time elapsed: 3.79 min\n", "Epoch: 028/050 | Batch 000/469 | Cost: 14801.7754\n", "Epoch: 028/050 | Batch 050/469 | Cost: 14283.8076\n", "Epoch: 028/050 | Batch 100/469 | Cost: 14157.8916\n", "Epoch: 028/050 | Batch 150/469 | Cost: 14591.0586\n", "Epoch: 028/050 | Batch 200/469 | Cost: 14707.6934\n", "Epoch: 028/050 | Batch 250/469 | Cost: 14730.5000\n", "Epoch: 028/050 | Batch 300/469 | Cost: 14761.3613\n", "Epoch: 028/050 | Batch 350/469 | Cost: 15279.7812\n", "Epoch: 028/050 | Batch 400/469 | Cost: 14528.2744\n", "Epoch: 028/050 | Batch 450/469 | Cost: 14167.2188\n", "Time elapsed: 
3.93 min\n", "Epoch: 029/050 | Batch 000/469 | Cost: 14382.7207\n", "Epoch: 029/050 | Batch 050/469 | Cost: 15143.0254\n", "Epoch: 029/050 | Batch 100/469 | Cost: 14207.4375\n", "Epoch: 029/050 | Batch 150/469 | Cost: 15312.8730\n", "Epoch: 029/050 | Batch 200/469 | Cost: 14714.6807\n", "Epoch: 029/050 | Batch 250/469 | Cost: 14761.9023\n", "Epoch: 029/050 | Batch 300/469 | Cost: 13909.5557\n", "Epoch: 029/050 | Batch 350/469 | Cost: 15295.2285\n", "Epoch: 029/050 | Batch 400/469 | Cost: 14590.0059\n", "Epoch: 029/050 | Batch 450/469 | Cost: 13771.6270\n", "Time elapsed: 4.07 min\n", "Epoch: 030/050 | Batch 000/469 | Cost: 14302.2412\n", "Epoch: 030/050 | Batch 050/469 | Cost: 14636.1582\n", "Epoch: 030/050 | Batch 100/469 | Cost: 14535.5391\n", "Epoch: 030/050 | Batch 150/469 | Cost: 14794.7129\n", "Epoch: 030/050 | Batch 200/469 | Cost: 14745.2432\n", "Epoch: 030/050 | Batch 250/469 | Cost: 14465.8652\n", "Epoch: 030/050 | Batch 300/469 | Cost: 14903.6123\n", "Epoch: 030/050 | Batch 350/469 | Cost: 14062.1025\n", "Epoch: 030/050 | Batch 400/469 | Cost: 14659.8281\n", "Epoch: 030/050 | Batch 450/469 | Cost: 14638.7471\n", "Time elapsed: 4.21 min\n", "Epoch: 031/050 | Batch 000/469 | Cost: 13900.5020\n", "Epoch: 031/050 | Batch 050/469 | Cost: 14276.7793\n", "Epoch: 031/050 | Batch 100/469 | Cost: 14385.0371\n", "Epoch: 031/050 | Batch 150/469 | Cost: 15063.9482\n", "Epoch: 031/050 | Batch 200/469 | Cost: 14061.3789\n", "Epoch: 031/050 | Batch 250/469 | Cost: 14794.1172\n", "Epoch: 031/050 | Batch 300/469 | Cost: 14461.4004\n", "Epoch: 031/050 | Batch 350/469 | Cost: 14760.6582\n", "Epoch: 031/050 | Batch 400/469 | Cost: 14211.6348\n", "Epoch: 031/050 | Batch 450/469 | Cost: 15117.2490\n", "Time elapsed: 4.36 min\n", "Epoch: 032/050 | Batch 000/469 | Cost: 14433.7568\n", "Epoch: 032/050 | Batch 050/469 | Cost: 14379.6641\n", "Epoch: 032/050 | Batch 100/469 | Cost: 14304.6387\n", "Epoch: 032/050 | Batch 150/469 | Cost: 13829.6826\n", "Epoch: 032/050 | Batch 200/469 
| Cost: 14619.5654\n", "Epoch: 032/050 | Batch 250/469 | Cost: 14488.1992\n", "Epoch: 032/050 | Batch 300/469 | Cost: 14025.6309\n", "Epoch: 032/050 | Batch 350/469 | Cost: 14557.8555\n", "Epoch: 032/050 | Batch 400/469 | Cost: 14625.9219\n", "Epoch: 032/050 | Batch 450/469 | Cost: 14467.3330\n", "Time elapsed: 4.50 min\n", "Epoch: 033/050 | Batch 000/469 | Cost: 13708.5605\n", "Epoch: 033/050 | Batch 050/469 | Cost: 14030.7461\n", "Epoch: 033/050 | Batch 100/469 | Cost: 15058.2783\n", "Epoch: 033/050 | Batch 150/469 | Cost: 14089.2373\n", "Epoch: 033/050 | Batch 200/469 | Cost: 14830.2188\n", "Epoch: 033/050 | Batch 250/469 | Cost: 14473.9287\n", "Epoch: 033/050 | Batch 300/469 | Cost: 14349.3984\n", "Epoch: 033/050 | Batch 350/469 | Cost: 14528.9199\n", "Epoch: 033/050 | Batch 400/469 | Cost: 14033.7891\n", "Epoch: 033/050 | Batch 450/469 | Cost: 14026.8301\n", "Time elapsed: 4.64 min\n", "Epoch: 034/050 | Batch 000/469 | Cost: 15065.5000\n", "Epoch: 034/050 | Batch 050/469 | Cost: 14807.9961\n", "Epoch: 034/050 | Batch 100/469 | Cost: 14439.8008\n", "Epoch: 034/050 | Batch 150/469 | Cost: 14711.3418\n", "Epoch: 034/050 | Batch 200/469 | Cost: 14689.3828\n", "Epoch: 034/050 | Batch 250/469 | Cost: 13956.6719\n", "Epoch: 034/050 | Batch 300/469 | Cost: 14398.5410\n", "Epoch: 034/050 | Batch 350/469 | Cost: 14900.2051\n", "Epoch: 034/050 | Batch 400/469 | Cost: 14035.2871\n", "Epoch: 034/050 | Batch 450/469 | Cost: 14370.9922\n", "Time elapsed: 4.78 min\n", "Epoch: 035/050 | Batch 000/469 | Cost: 14394.5488\n", "Epoch: 035/050 | Batch 050/469 | Cost: 14367.2725\n", "Epoch: 035/050 | Batch 100/469 | Cost: 14434.9248\n", "Epoch: 035/050 | Batch 150/469 | Cost: 14409.7148\n", "Epoch: 035/050 | Batch 200/469 | Cost: 14353.3174\n", "Epoch: 035/050 | Batch 250/469 | Cost: 14548.2354\n", "Epoch: 035/050 | Batch 300/469 | Cost: 14818.1543\n", "Epoch: 035/050 | Batch 350/469 | Cost: 13898.6777\n", "Epoch: 035/050 | Batch 400/469 | Cost: 14176.9395\n", "Epoch: 035/050 | 
Batch 450/469 | Cost: 13999.2061\n", "Time elapsed: 4.91 min\n", "Epoch: 036/050 | Batch 000/469 | Cost: 14288.9336\n", "Epoch: 036/050 | Batch 050/469 | Cost: 14487.9365\n", "Epoch: 036/050 | Batch 100/469 | Cost: 14154.0234\n", "Epoch: 036/050 | Batch 150/469 | Cost: 14574.0762\n", "Epoch: 036/050 | Batch 200/469 | Cost: 14200.3008\n", "Epoch: 036/050 | Batch 250/469 | Cost: 14022.4297\n", "Epoch: 036/050 | Batch 300/469 | Cost: 14053.5713\n", "Epoch: 036/050 | Batch 350/469 | Cost: 14348.0186\n", "Epoch: 036/050 | Batch 400/469 | Cost: 14567.2314\n", "Epoch: 036/050 | Batch 450/469 | Cost: 14527.6348\n", "Time elapsed: 5.05 min\n", "Epoch: 037/050 | Batch 000/469 | Cost: 14948.3877\n", "Epoch: 037/050 | Batch 050/469 | Cost: 14357.0439\n", "Epoch: 037/050 | Batch 100/469 | Cost: 13578.9121\n", "Epoch: 037/050 | Batch 150/469 | Cost: 14657.7266\n", "Epoch: 037/050 | Batch 200/469 | Cost: 14293.0732\n", "Epoch: 037/050 | Batch 250/469 | Cost: 13609.5859\n", "Epoch: 037/050 | Batch 300/469 | Cost: 13738.5283\n", "Epoch: 037/050 | Batch 350/469 | Cost: 14079.2803\n", "Epoch: 037/050 | Batch 400/469 | Cost: 14029.6797\n", "Epoch: 037/050 | Batch 450/469 | Cost: 14522.6406\n", "Time elapsed: 5.18 min\n", "Epoch: 038/050 | Batch 000/469 | Cost: 14005.6035\n", "Epoch: 038/050 | Batch 050/469 | Cost: 13756.8330\n", "Epoch: 038/050 | Batch 100/469 | Cost: 15247.8760\n", "Epoch: 038/050 | Batch 150/469 | Cost: 14034.3789\n", "Epoch: 038/050 | Batch 200/469 | Cost: 14204.7061\n", "Epoch: 038/050 | Batch 250/469 | Cost: 14023.4863\n", "Epoch: 038/050 | Batch 300/469 | Cost: 13636.5508\n", "Epoch: 038/050 | Batch 350/469 | Cost: 14509.3711\n", "Epoch: 038/050 | Batch 400/469 | Cost: 14496.3965\n", "Epoch: 038/050 | Batch 450/469 | Cost: 14460.8896\n", "Time elapsed: 5.32 min\n", "Epoch: 039/050 | Batch 000/469 | Cost: 14317.6602\n", "Epoch: 039/050 | Batch 050/469 | Cost: 14440.6855\n", "Epoch: 039/050 | Batch 100/469 | Cost: 13772.3691\n", "Epoch: 039/050 | Batch 150/469 | 
Cost: 14023.2480\n", "Epoch: 039/050 | Batch 200/469 | Cost: 14576.5449\n", "Epoch: 039/050 | Batch 250/469 | Cost: 14164.7266\n", "Epoch: 039/050 | Batch 300/469 | Cost: 13657.8369\n", "Epoch: 039/050 | Batch 350/469 | Cost: 14456.4014\n", "Epoch: 039/050 | Batch 400/469 | Cost: 14202.3047\n", "Epoch: 039/050 | Batch 450/469 | Cost: 14564.9531\n", "Time elapsed: 5.45 min\n", "Epoch: 040/050 | Batch 000/469 | Cost: 14392.9277\n", "Epoch: 040/050 | Batch 050/469 | Cost: 13708.4375\n", "Epoch: 040/050 | Batch 100/469 | Cost: 14689.3535\n", "Epoch: 040/050 | Batch 150/469 | Cost: 13887.5840\n", "Epoch: 040/050 | Batch 200/469 | Cost: 14047.1543\n", "Epoch: 040/050 | Batch 250/469 | Cost: 14142.5859\n", "Epoch: 040/050 | Batch 300/469 | Cost: 14016.5820\n", "Epoch: 040/050 | Batch 350/469 | Cost: 14962.1387\n", "Epoch: 040/050 | Batch 400/469 | Cost: 14433.1416\n", "Epoch: 040/050 | Batch 450/469 | Cost: 14622.0762\n", "Time elapsed: 5.60 min\n", "Epoch: 041/050 | Batch 000/469 | Cost: 15024.6074\n", "Epoch: 041/050 | Batch 050/469 | Cost: 14015.1895\n", "Epoch: 041/050 | Batch 100/469 | Cost: 14236.3535\n", "Epoch: 041/050 | Batch 150/469 | Cost: 13553.2012\n", "Epoch: 041/050 | Batch 200/469 | Cost: 14393.0205\n", "Epoch: 041/050 | Batch 250/469 | Cost: 14220.9316\n", "Epoch: 041/050 | Batch 300/469 | Cost: 13906.4434\n", "Epoch: 041/050 | Batch 350/469 | Cost: 13650.9873\n", "Epoch: 041/050 | Batch 400/469 | Cost: 14031.2979\n", "Epoch: 041/050 | Batch 450/469 | Cost: 14202.7402\n", "Time elapsed: 5.74 min\n", "Epoch: 042/050 | Batch 000/469 | Cost: 13856.5684\n", "Epoch: 042/050 | Batch 050/469 | Cost: 14359.9023\n", "Epoch: 042/050 | Batch 100/469 | Cost: 14294.4902\n", "Epoch: 042/050 | Batch 150/469 | Cost: 14577.5811\n", "Epoch: 042/050 | Batch 200/469 | Cost: 14028.5820\n", "Epoch: 042/050 | Batch 250/469 | Cost: 13892.3926\n", "Epoch: 042/050 | Batch 300/469 | Cost: 13972.0322\n", "Epoch: 042/050 | Batch 350/469 | Cost: 14635.3506\n", "Epoch: 042/050 | Batch 
400/469 | Cost: 13453.1562\n", "Epoch: 042/050 | Batch 450/469 | Cost: 14930.7197\n", "Time elapsed: 5.88 min\n", "Epoch: 043/050 | Batch 000/469 | Cost: 14080.6318\n", "Epoch: 043/050 | Batch 050/469 | Cost: 14356.2100\n", "Epoch: 043/050 | Batch 100/469 | Cost: 14747.7344\n", "Epoch: 043/050 | Batch 150/469 | Cost: 14025.0693\n", "Epoch: 043/050 | Batch 200/469 | Cost: 14294.0615\n", "Epoch: 043/050 | Batch 250/469 | Cost: 14147.0391\n", "Epoch: 043/050 | Batch 300/469 | Cost: 14254.3008\n", "Epoch: 043/050 | Batch 350/469 | Cost: 13503.6582\n", "Epoch: 043/050 | Batch 400/469 | Cost: 14689.1816\n", "Epoch: 043/050 | Batch 450/469 | Cost: 14308.2051\n", "Time elapsed: 6.02 min\n", "Epoch: 044/050 | Batch 000/469 | Cost: 13875.5928\n", "Epoch: 044/050 | Batch 050/469 | Cost: 14699.4229\n", "Epoch: 044/050 | Batch 100/469 | Cost: 14394.9424\n", "Epoch: 044/050 | Batch 150/469 | Cost: 14657.7197\n", "Epoch: 044/050 | Batch 200/469 | Cost: 14011.2949\n", "Epoch: 044/050 | Batch 250/469 | Cost: 13314.2246\n", "Epoch: 044/050 | Batch 300/469 | Cost: 14493.9434\n", "Epoch: 044/050 | Batch 350/469 | Cost: 13947.0000\n", "Epoch: 044/050 | Batch 400/469 | Cost: 14538.6055\n", "Epoch: 044/050 | Batch 450/469 | Cost: 13822.2129\n", "Time elapsed: 6.16 min\n", "Epoch: 045/050 | Batch 000/469 | Cost: 14430.2080\n", "Epoch: 045/050 | Batch 050/469 | Cost: 13560.6621\n", "Epoch: 045/050 | Batch 100/469 | Cost: 14101.0293\n", "Epoch: 045/050 | Batch 150/469 | Cost: 13972.5605\n", "Epoch: 045/050 | Batch 200/469 | Cost: 13934.4883\n", "Epoch: 045/050 | Batch 250/469 | Cost: 14146.7676\n", "Epoch: 045/050 | Batch 300/469 | Cost: 14229.7588\n", "Epoch: 045/050 | Batch 350/469 | Cost: 14473.1758\n", "Epoch: 045/050 | Batch 400/469 | Cost: 14182.4443\n", "Epoch: 045/050 | Batch 450/469 | Cost: 13847.8311\n", "Time elapsed: 6.30 min\n", "Epoch: 046/050 | Batch 000/469 | Cost: 13579.7725\n", "Epoch: 046/050 | Batch 050/469 | Cost: 14197.9629\n", "Epoch: 046/050 | Batch 100/469 | Cost: 
14378.0156\n", "Epoch: 046/050 | Batch 150/469 | Cost: 13889.5391\n", "Epoch: 046/050 | Batch 200/469 | Cost: 14234.4473\n", "Epoch: 046/050 | Batch 250/469 | Cost: 14565.4922\n", "Epoch: 046/050 | Batch 300/469 | Cost: 14121.4434\n", "Epoch: 046/050 | Batch 350/469 | Cost: 13544.7070\n", "Epoch: 046/050 | Batch 400/469 | Cost: 13669.2461\n", "Epoch: 046/050 | Batch 450/469 | Cost: 14321.6992\n", "Time elapsed: 6.45 min\n", "Epoch: 047/050 | Batch 000/469 | Cost: 14563.6592\n", "Epoch: 047/050 | Batch 050/469 | Cost: 14157.8525\n", "Epoch: 047/050 | Batch 100/469 | Cost: 14169.9375\n", "Epoch: 047/050 | Batch 150/469 | Cost: 14047.9561\n", "Epoch: 047/050 | Batch 200/469 | Cost: 14237.7090\n", "Epoch: 047/050 | Batch 250/469 | Cost: 14265.3633\n", "Epoch: 047/050 | Batch 300/469 | Cost: 14120.1963\n", "Epoch: 047/050 | Batch 350/469 | Cost: 13613.9072\n", "Epoch: 047/050 | Batch 400/469 | Cost: 13844.0146\n", "Epoch: 047/050 | Batch 450/469 | Cost: 13815.9531\n", "Time elapsed: 6.59 min\n", "Epoch: 048/050 | Batch 000/469 | Cost: 14768.5332\n", "Epoch: 048/050 | Batch 050/469 | Cost: 13807.6055\n", "Epoch: 048/050 | Batch 100/469 | Cost: 14027.3555\n", "Epoch: 048/050 | Batch 150/469 | Cost: 14198.5234\n", "Epoch: 048/050 | Batch 200/469 | Cost: 14043.7871\n", "Epoch: 048/050 | Batch 250/469 | Cost: 14150.2158\n", "Epoch: 048/050 | Batch 300/469 | Cost: 14136.1113\n", "Epoch: 048/050 | Batch 350/469 | Cost: 13921.3516\n", "Epoch: 048/050 | Batch 400/469 | Cost: 14452.8145\n", "Epoch: 048/050 | Batch 450/469 | Cost: 13998.9541\n", "Time elapsed: 6.73 min\n", "Epoch: 049/050 | Batch 000/469 | Cost: 14730.7822\n", "Epoch: 049/050 | Batch 050/469 | Cost: 14744.3809\n", "Epoch: 049/050 | Batch 100/469 | Cost: 14377.9961\n", "Epoch: 049/050 | Batch 150/469 | Cost: 13894.9863\n", "Epoch: 049/050 | Batch 200/469 | Cost: 14319.2900\n", "Epoch: 049/050 | Batch 250/469 | Cost: 14335.9785\n", "Epoch: 049/050 | Batch 300/469 | Cost: 14045.4326\n", "Epoch: 049/050 | Batch 
350/469 | Cost: 14342.3359\n", "Epoch: 049/050 | Batch 400/469 | Cost: 13990.9199\n", "Epoch: 049/050 | Batch 450/469 | Cost: 13979.7559\n", "Time elapsed: 6.87 min\n", "Epoch: 050/050 | Batch 000/469 | Cost: 13741.7539\n", "Epoch: 050/050 | Batch 050/469 | Cost: 14258.0557\n", "Epoch: 050/050 | Batch 100/469 | Cost: 14187.6738\n", "Epoch: 050/050 | Batch 150/469 | Cost: 14332.6895\n", "Epoch: 050/050 | Batch 200/469 | Cost: 14304.8984\n", "Epoch: 050/050 | Batch 250/469 | Cost: 13983.0000\n", "Epoch: 050/050 | Batch 300/469 | Cost: 14277.3750\n", "Epoch: 050/050 | Batch 350/469 | Cost: 13838.4023\n", "Epoch: 050/050 | Batch 400/469 | Cost: 13978.5732\n", "Epoch: 050/050 | Batch 450/469 | Cost: 13924.4717\n", "Time elapsed: 7.01 min\n", "Total Training Time: 7.01 min\n" ] } ], "source": [ "##########################\n", "### TRAINING\n", "##########################\n", "\n", "start_time = time.time()\n", "\n", "for epoch in range(num_epochs):\n", "    for batch_idx, (features, targets) in enumerate(train_loader):\n", "\n", "        # Unsupervised: the class labels (`targets`) are unused; the images\n", "        # are both the network input and the reconstruction target.\n", "        features = features.to(device)\n", "\n", "        ### FORWARD PASS\n", "        z_mean, z_log_var, encoded, decoded = model(features)\n", "\n", "        ### LOSS = KL divergence + pixel-wise binary cross-entropy\n", "        # Closed-form KL( N(z_mean, exp(z_log_var)) || N(0, 1) ), summed over\n", "        # all latent dimensions and all examples in the batch.\n", "        kl_divergence = (0.5 * (z_mean**2 + torch.exp(z_log_var)\n", "                                - z_log_var - 1)).sum()\n", "        # Reconstruction term, summed over all pixels and examples.\n", "        pixelwise_bce = F.binary_cross_entropy(decoded, features,\n", "                                               reduction='sum')\n", "        cost = kl_divergence + pixelwise_bce\n", "\n", "        ### BACKWARD PASS + PARAMETER UPDATE\n", "        optimizer.zero_grad()\n", "        cost.backward()\n", "        optimizer.step()\n", "\n", "        ### LOGGING (every 50th batch, starting with batch 0)\n", "        if batch_idx % 50 == 0:\n", "            print('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f'\n", "                  % (epoch + 1, num_epochs, batch_idx,\n", "                     len(train_loader), cost))\n", "\n", "    print('Time elapsed: %.2f min' % ((time.time() - start_time) / 60))\n", "\n", "print('Total Training Time: %.2f min' % ((time.time() - start_time) / 60))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Evaluation" ] }, { "cell_type":
"markdown", "metadata": {}, "source": [ "### Reconstruction" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "##########################\n", "### VISUALIZATION\n", "##########################\n", "\n", "n_images = 15\n", "image_width = 28\n", "\n", "fig, axes = plt.subplots(nrows=2, ncols=n_images, \n", " sharex=True, sharey=True, figsize=(20, 2.5))\n", "orig_images = features[:n_images]\n", "decoded_images = decoded[:n_images]\n", "\n", "for i in range(n_images):\n", " for ax, img in zip(axes, [orig_images, decoded_images]):\n", " ax[i].imshow(img[i].detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate new images" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlMAAABSCAYAAABwglFkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJztnXmUVNXVt59LN900DdphEmgbWkFEEUFRIw5RNFFB45TPOITEZVQ0miw1k5poYmJi1GWMmtdoTKIhmqif8ZXPIc5DHHFAFByYRAURGaRBhp7rfn9cfrtuFzPVXXWr2M9avaq7uoaz7xnuOb+9zz5BGIY4juM4juM4W0enfBfAcRzHcRynkPHJlOM4juM4Thb4ZMpxHMdxHCcLfDLlOI7jOI6TBT6ZchzHcRzHyQKfTDmO4ziO42SBT6Ycx3Ecx3GyIKvJVBAERwVBMDMIgjlBEFzSXoVyHMdxHMcpFIKtTdoZBEEJMAv4GvAJ8DpwahiG77Vf8RzHcRzHcZJNNsrUfsCcMAznhmHYBNwDHNc+xXIcx3EcxykMSrN4bzUwP/b3J8CXM18UBMEEYAJAZWXlqKFDh2bxlfnjo48+YunSpcH6/lfsNhaLfQBTpkxZGoZh78zni8XGbbmdQvHbWCz2gfdF3MaCYGM2tiEMw636Af4P8NfY398G/mdj7xk1alRYqKwt+yavS7HbWMj2hWEYAm+ERWyjt9Ntx8ZCti8MvS+GbmNBsLk2ZuPmWwDUxP7ece1zjuM4juM42wzZTKZeB3YJgmCnIAjKgFOAB9unWI7jOI7jOIXBVsdMhWHYEgTB94HHgRLg9jAM3223kjmO4ziOs9mkUikAmpqaaGhoALDHNWvWsGBB5DxavXq1PQcwaNAgunfvDkBtbS0AnTp5GsotIZsAdMIw/A/wn3Yqi+M4juM4TsGR1WTKcdqLpqYmAGbNmgXAwIEDbdW0ePFiAGbMmAFAVVUV3bp1A6BHjx5AtLIqLc1Pcw7X5moLw5CWlhYgbY8oLS2ltbUVSK8KGxsbWblyJQBf+tKXANhuu+0AqKys7PiCO+1CGIYEwaY3+ySZVatWATBt2jSWLFkCwG677QZAv379TLVwkkMYhixatAiA999/H4B3342cQ0uXLuW1114DYMqUKQA0Nzfb2JM5PvXo0YPDDjsMgNNPPx2AcePGuTq1BfiVchzHcRzHyQJXpooAKR7Lly8HopXIp59+CsCHH34IRMrN4YcfDkQrTUiOTzwMQx5//HEgUmsgsqVz585AtKICmDlzJgBPP/20/U+rqR/84Ac5VaZSqRQfffQRAB9//DEAc+fONRVt9uzZQDomYciQIaa6qeyLFy+mvLwcwMq+5557AjB69GgOPPBAAFPhCpHm5ma7JmVlZQB06dIFiOxKuqKTqSZ+/vnnfP7550C6bj/99FOzY8iQIQBUVFQAMHjwYEpKSgASZavsevPNNwF45plnAHj11VetTeo1O+20ExdffDGQVoK3BaQ4J6neAGt/jz/+uI0pN954I5COj4J0/anfVVRU2NgjNLYuX76cadOmAVGbheTcH2RHfX09AF988QUAy5YtsxgxjSlr1qyhT58+APTuHaUw09ja0fVYcJMp3Wx1QcvKyqzR62LX1dWxbNmyNu9TUF3cRZRUZIeCBVetWmUyvBp/S0uLSbxyFemaTJkyhRdffLHNZ+24445sv/32ABxxxBFAesDPN01NTTb5GzFiBAArVqyw8mqSKKZMmWI3KE0QczWRUlurq6vj4YcfBtKD21133UVdXR2QHtR0jf/73//aZ8QHPLkFVa+6qT377LNceOGFABx99NFAelAsBJYuXQrA7bffbu1Tg9nOO+8MRBNH9
UvVteo1n7S2tprb5IUXXgDSdTx37lxzN0+fPh2IJscqvxYqmhTvt99+9O3b136HyKWbTzvDMLSyaxHz1FNPAdENVC7nd955B4AlS5bY604++WR7XTGgm/E777zDgw9Gm9F181Zf79Gjh/VjucDy0Rd179NYMmvWLK677jogPc4rPKCkpIQLLrgAaDuJkE133HEHgN0nKyoqOOmkk4B1XYD5pKGhgUcffRSAt956C4DHHnsMiMqpcXPevHlAZL/627HHHgvAwQcfDET3ll69egEdM7Eqjh7hOI7jOI6TJwpGmdIMeuLEiQDcf//9QDq4DtKrjFQqZVK1ZvNizJgx3HDDDQDsvvvuHVvoLSAMQ5599lkgvRq+++67gUihuuyyy4C0AtO1a1dbQXbt2hVIB2i/9dZb62x9/eKLLxgwYACQlkSTwurVq2019MknnwBQXV1tq47PPvsMwAIqGxoarH616sj1SrG1tdVUFbXN5uZmUxyqqqqAtFuke/fuprBJqenWrZsF+0p+j7sA5UaUaqfvSyqpVMpcBZdccgkAL730kl0DrRj33ntvAObPn091dTUANTVR/t+RI0eaCqDHXKs4kyZNMmXqmmuuAdIB2qWlpaYmik6dOtnYI/T6qVOnctRRRwFpl9phhx3GTjvtBMAOO+xgn5ErmpqaePLJJwH4y1/+AqQVta985Su2apeS1tjYyNtvvw2k+9uOO+6Ys/K2F7oXLFu2jOeeew5Ihw5cddVV9jrVr1ScsrIyG1+kWp199tk5b5dqIxojunTpwoQJE4D0OCMXV9euXU2FUdmnT59uNkn1lso1cuRIuy9IOU4CkyZNsj743nvvAel6KSkpsfpQndXX15uKLGVVat2IESP42c9+BqTHILkC2wNXphzHcRzHcbKgYJQprVK1qtdsu7S01IIjtVIIw5DbbrsNSCszet/06dNZsWJF7gq+mUybNs1WCTfddBOQXoEcf/zxNpOW4lFaWsqcOXOAtI9bqlV8q7a22v/oRz9i0KBBQPICKufNm2dKk9SoRYsWWTm1+pBi19LSwi9+8QuAnNuk7+nduzcHHHAAkF4B7bHHHha7sNdeewHp679w4UJToaQazp4921ZN/fv3t88FGDt2rLVnxeMkFa34n3zySe655x4gveKtqqqyFfKYMWOAdLtuaWkxJVKrycbGRrsWUm+kvJaVlXVoPcfVpOuvvx5It8d4LEqm2l1ZWcn5558PrNsOFy1aZCtqteNevXrZ66QoqP3ngtbWVlvdK85y1KhRQFR+qYUqL8AHH3wApD0BqqMkx06pnhRjo40idXV1/O53vwPaxgfpmmgcjauNcZUKomuTa2VK3ycFqV+/fja+qE9pA0RVVZWVX+/r1q2btWc96n0DBgxgl112sdflGylNdXV11vb0nGyuqKiwelEcahAE62wcUN+aNWsWkydPBtIbs84666x2i7ctuMmUBi0Fy82bN4+ePXsC6Qs6b948C8yTq+zVV18FogE6Pkjkm4ULFwIwZ84crr32WiDdcb/85S8DcOaZZ7LrrrsC2I2pubnZ3D5///vfAXj99deBKHhbDUnX6ZBDDklcrhgNdtOmTbNBTZ27vr7eJsK6yclN1q1bN5tgyE2Sa4IgsHanNllVVWUuP3Vg2VBdXW03IA1yAwYMYODAgW0+VztpwjC0+srljXZLkB0KEL3uuussz40GqN122812JWYSz8ulvjtnzhybdKre5Z4eMGBAh06mdKN877337GYjG/W9ZWVl67gtjzrqKJtsyV0p13TPnj154403gLSb+sMPP+TSSy8F0u03l1mn6+vr7Xviuw4h6luZOdw+/fRT+7/qVxs/knDjXR+pVIrnn38eiHb/Avz5z38G2k6gNN6UlZWZLZpgxt256ot6bz4mkWqD++yzDxAt4hQCoLFF7a+xsdEWIdoU07NnTyv//PnzATjmmGOA6L6oBWAS0ASwsbHRxhLVlf4uLy83d7Oe69u3rwkOu
hYKn5k7d66F+Pz85z8Hogm2FuTZktxlheM4juM4TgFQMMqUkBqgx6FDh66TdXrQoEE88sgjQNptoplu3759ExVgN3XqVCByw2l1vv/++wPw3e9+F4iygWuVqNVJp06dLNhXipS2Ozc2NtpKd/z48UAUbJ80957SAHzwwQemNEnCLikpsQDsX/7yl23e98tf/pJzzz03dwXdALqeWtEecMAB5rZSXYgwDC2IV7J1XJmSQiEVqr6+3oKCM3PDJAVteJAb5eWXX7Z+KSl+6NCh1q4z897Eg+y1mhw5cqRdT6kB+qyORuVcsmTJOtK/yjJ48GBTBvTcfvvtZ3Wl98lF+cQTT5iy9sorrwCRaqWt3OrrueybJSUlpqyo3anf7bvvvqZyqE/W1NRw+eWXA1H9ANx3330AnHHGGTkr95aiulDbks1BEFib0iaegQMHWjC62mc80FnpZJTXLh+nLaiNSE3s37+/ha8odEXjR9euXW0MktIWBAFz584F0m1XjwMGDDAlK0lUVlZaG1WqII0Pe++9N3vssQeQtqOmpsYUY10v/W/SpEnm3ZAbv6amxvpntuOsK1OO4ziO4zhZUHDKVCZBEJjqpNXCBx98YLEr//znP4F0fM6IESNs9ZxPFEsjlWLBggW2WtIqKB5Yrtdr9fzJJ59YEjOpcAp+7tOnj20BHT58OJAsdUM2KDB35syZlphSyR7XrFnD1VdfDaRjF7SKHzlyZKLs0Yq3urraYhEUC6dUHmvWrLFVo1aWTU1NthrUYzwQX78nyVZIxxEp0aNi9sIwtLLGt2Vnrp7V/6qqqiyuQQrdsGHD7Ew4PZercwpVvvnz55uKprFFjzvssIPVlVbFe+yxh8VYSc1Q34wHMeva1NbWWgySlPNcKlOVlZWmzkt9UuDy0KFDrSwaf9577z3L6C81UjEnqVQqkUHonTp1MsVBSpPaWn19PSeccAKQjgubMmWKtVONo6q7IUOG8PWvfx1IK475RG0zlUpZ4mYF12v8eeWVV0wl170vlUpZG/ztb38LpAPR6+vr7XOTUJ/xzObasKN2qfv8sGHDTHWSYlpbW2tKlvqixtGlS5eaMqc6nj9//jrxultL/q+a4ziO4zhOAVPwylQcqQGTJ0/miiuuANIzUKkap59+eiJ2R2XGA7W0tKyzOlVcxdSpU201rC3k8+bNs/QAmlnLrosuusgUqVzFm2wOWilo+7xUw5NPPtlWBVodNjU1mV1SJo4//niARO06gfSKqbKycp1dNYpXgLTCFj/vTe3goYceArDkrJ06dbIVYqZKkm+04lUMkFa+Xbp0MbuluKxcudJ2HGlFqTiVeNoEKSNHHnlkzneI6fr+5z//ASIVMTO1iMrUu3dv+5/ipBYsWGB9TzsRFX/T0tJiipdSDowZM8bacD4S6AZBYH1KKqHq4eOPP7bn1DZnzpxpdsl2qQX19fU5Uw63FMWNnnnmmW2e79y5sykfasNPPfWUtWONU1LtJk6caKkj8qkSa/yQ16W1tdXKoz6lx0GDBtm9Jb57Xb/Lc3HllVcCUSyxrlcSlPB4gmPdE3QPVLLtIAjsvq76KS8vtzaqelQM1aBBg8wrorpuaGhYJwnv1lLwk6kwDE161llpl156qXUWNS7duDXY5xsNzipfXV2dTfyUnVjuvqlTp5o9avxxl4oelT5h2LBhljMkSUHnd955J5DO8XHiiScC0cRBDVrX5bHHHrMbzQ9/+EMgnSpCAaVJo6KiwupA9fWvf/0LgPPOO89sU+deunSpDYyaOP76178GIvdDZnbsoUOH5sKMjRKGoW140OYJTSS6dOliN12xYMECs0OTiXiQup6TGyUfCx0NsHKZxDPZq13q71WrVplrSBtA+vfvb4c5y83y0ksvAZHrUPW37777ApFbUJmq89E/S0pKzF0ll6pypfXv399CDzRpfvrpp2080XN6nDFjhi3cknZ2ZOZpBIceeigQ5au75ZZbgPTiLX5epiaH2lI/f
PjwvAScbwjZ1b17d8t1l3n+5WeffcYf/vAHIO3K+8c//mETZE2Gf/WrXwFwxRVX2GfIBZzPE0IUbH7PPffYpEjo/NbRo0db/sX1nTObuUGovLy8TcoFiCbM7dVu3c3nOI7jOI6TBcmZbm8hmq1+/PHH/PWvfwWwx1QqZasRnQSehMDBOFqZKnXB1VdfbSsIrfi1tbWhoWEd6bKsrGydrawHHXQQEMmZ65up55Np06bZivf2228H4JxzzgEi5UUJGhWAPnfuXLseQgnakuLuyiTuPtGqTqrMvffea8pjPMBXAZFCitbf/vY3awdyN+y66655Vxqbm5vNnS4pXuUrLy+3Fa/qsbq62tSmzGDr7t27M3r0aHtvvlC9SIVavXr1Okk7ZevixYttpaug7Lq6OnPlSWnUKr+0tNQ+Q+fa7bzzznl1pQRBYKpvPE0FwNtvv20qqRKPQvraKOBXbukBAwaYkpE0ZUpIVVLfvO+++0yFFEEQmJdAm0aGDRvW5v35JJVKWR3IJRs/my7z/rDDDjuYK1mvGTt2rHlv4kHsADfccIN9vsblCy+8MOdeAJVLp3vMnDnTyqXxQxs/Ro0atVnpHHRtmpubrZ+q/dfU1LTb2OPKlOM4juM4Thbkf8q9lcS31yvuQqvi7bbbzvzde+65J5Cs2CFIz4y/973vAVGZlRhPqwvFWnTp0sVm5ZpFt7S02ApSQc5aWcWDZPOFVhO33norEPmo//jHPwLpFaxW783Nzfa76Ny5s62C9VlSr5JMPBgd0kGT1dXV9j/Z09LSYgqOjjvSFufhw4db7JyCK+vr6/OeWG/BggWW0kL1qJVyU1OTrfzUFvv3728rfBEP8G3PU9u3FsX/3HvvvUDUHjMTJKotNjQ0WNyJVs/QNjEipFfYTU1Ndn3ibSLf6qqUNo2ZCr7v1q2bxaEoHqe+vt6UOsWC6ezFhx9+mG984xu5K/hWIGVJ8Xx77rkns2bNAtIpasrLyy1dgBTEJChtakfxmC6Vq1OnThtsR+s7Zqtv37620UNxq2+++SYQpbpQbOrvf/97IDq2Zty4cfZduSCukEJUP7JXGwmkTGmM2RAaZ6RCPvTQQ3bPjMedtZdtBTuZ0mQk7iLSTqorr7zSgnWTkDNjYygYecKECVZ+7eKLu0N0fpBk2VtuucXOntJAp9fk+7yshoYGOwdLg9XNN99sDVhB2nFXpGyVfTNmzLAgQdVhZnBzISC74oNefOOAgpk16ZesvnLlSnPzPvDAA0C0i1HBvvmaLE+cONEWKkKD1rBhw6yu4hntdeNW/alv1tbWJuKGpbLGz2zTokX1Ec/QrrP21O+6dOli41HmIaslJSVmo+q6qqoq74sdlUljjW7a48ePN1u14WPVqlV2wLEmWI8//jgQuWHk9tXuxqSha606ii9I9L+amhrbEJOEg8XVFnVtm5ub7V6RDeqfcs1KiHjggQdsY4jcuzNnzuSQQw4ByNm5rppM6XSMlpYWW4yozJmnmmwI3T8UXjJ16lTrs3KBamNFe5DsmYbjOI7jOE7CKVhlSiuKgw8+mMmTJwNpGXDkyJF5yd+SDd27d+fYY48F0u4vrQaCIFhnC/0JJ5zAo48+CqTVLa0y8h0wOXPmTDupXdvoR48ebSsKrWBVh0EQ2IpJLqTM4HNIZ8QuVDK3+MbTW6i9amv9559/btfktddeA6KAZwW257qO4xm+5eZSmdXuhgwZso7iEoahBb1Kbtd2/F122SXvCg20VbkhuraqF20UUdurqKgwN5jqoHv37m3y4kDaXQtp1Uffk4Q8PlKAlan9O9/5DhApNErJorqpr683l6aUg/jZdklw1W4OyqM0atQo2wSj63DaaaeZApJPNEZovI8H+gu58CoqKrZ6HMjMJt65c2e7fyoc45prrjG1LlfKlMIE4mebxhVdSKvGGxo7Mq+h3Pjz58+3/x111FFAun+3B65MOY7jOI7jZMEmp7VBENQA/wB2AELgtjAMbwyCo
AdwL1ALfAR8MwzDuo4r6vopLS3lhhtuANLJHXVOT6GhANX1ZRSWn18r3oqKCpuZy/+rwLx8M2PGDAtwVBkHDx5sMVLxAEqI6lBB2cpIvHLlSlsNyTeulWWSWZ/6BJGyo9+1GlyxYoUpGYpd0ePzzz9vKoCu4YIFC9qc15hLpEykUikLhlWbVEqO+EpRtk6cONFipRSTcdhhhwHJiE2BdHs844wzgCjmUu1R11ur2xUrVtgpBOqnq1evNrVJaRbi54JJOdZnrVmzJmcr/Q2htqVyqjwHHnigKWiqz/Lyckuqes011wDp9nfAAQfkPUZzc4nXSTw+DtqeOJBPVEYpNNdffz0QtZnf/OY3QDpFzNChQ61tbUqt2RB6/cCBA3n33XcB2qQFUTvJRWqh1tZWG/80xsTtUfyt+ub6zoVMpVIWQy177rjjDntfZuLk9kzJsjmtpwX4URiGuwP7A+cHQbA7cAnwdBiGuwBPr/3bcRzHcRxnm2KTy9swDBcCC9f+vjIIgveBauA44NC1L5sIPAdc3CGl3AhXX321zVilTBUjirvQKl+75QDOOussgHbZ7dEe9O3b11QLbf1ftWqVKVNabcTPnlPcm/zcpaWl/PjHPwawrddJ2Pm1PqQuzZkzx+LBpMzV1UVibVVV1TpKU79+/SxGTEeT6NzCZcuWWTxHfLdZvrbUq/0pMS6kd9XE4xO1sn7kkUeAaGeQynzdddcB2A6hpCBVRnEp8Xam3VTxFXJm+43HTOk66X9VVVW2C0ntP98xjZAeK5S2Qmp+XPEWqVTKkstKGY8n8cx3uo7NRbvULr/88nV2XSodQlLITDGycOFCfvKTnwBwwQUXAFFiY4032g2s+MX4GXUbQzGtr776qu0Ol7LTp0+fvKnHsmvRokXWp6QI66zPxsZGi91TfS5fvtzSzCiNh+ICU6mUnVf47W9/G2jfXdFb1KuDIKgF9gJeBXZYO9EC+IzIDZgzdGFvvfVWTjnlFCD/KQE6ijAMbcIo6XLq1Kk2oCkgNin07NmTY445BsDOh1q0aJG5eeI5UyCSlXUWk2TXsWPHWqoIpXxIGrqB3njjjUCUq0cTJrl7JJmvXr3aJkcaHDp37twmsDv+WF5ezuDBg4HoIGiItq3ne0K53Xbb2QQ5s00OGjSIxx57DEgf9Nva2srFF0drrMzA5qQgF50G8AsuuMDOVNSGgLg7QeXX+1asWGFtWhNHTZxGjhxp7VjjUxIOH1dfVN+Kb3BRG4tnQldAsDaWaALao0ePREwON4b6lDLWL1myxBYA2gyRtBMyNEHVONrc3GwusJtuugmIJv0qv/qWXIDdu3dnv/32A9IB3EEQWF4xTYoVVvHCCy/w3HPP2esAxo0bt8lcTu1JSUmJTepPOOEEAG677Taze9KkSUB6UlxRUWG2ya5ly5ZZ2I82vOg+ecQRR9hCriMO595sJ3EQBN2A+4ELwzD8Iv6/MGqt4QbeNyEIgjeCIHijEPMEbQ7FbmOx2wduY7FQ7DYWu33gNhYL24KNcTZrSREEQWeiidQ/wzD837VPLwqCoF8YhguDIOgHLF7fe8MwvA24DWCfffZZ74RrS9AsU4kD6+rqmDBhQrYfmxXtbWMmy5cvtyR5mpWvWrXKVocdHcy6pfb17NnTVjRaAU+fPt1cWZmB5StXrrSVQnyL9kknnQTkxr23pTaGYWgZs+NBolKdMgPR10djY6PZpkdlKD722GPN9aAVZmVlZVaqTjbtVC6w0047jT/96U9A2h3wzjvvAFEKB2UCl8J40UUX8a1vfQtIr5A7kq2xUdd0zJgxQHSdpSbJpXD33XcDkaqhZIB6LCkpsfYrG6UUjBw50hQvJcHMJjVCe4012mYuF/r9998PRK4Q1bWU1w8//NCUKbmtdXLDKaec0u6u5/YeT+XF0A29a9euVl9KR5Oplnc0G7JRCqjqYJ999gEixUn3PKm/q1atMmXp5ZdfBtq6oOXKPfTQQ4FIfdO4JJX8y
iuvtL/1Xrl+TzzxxKzSXmxNPSqhr7xNc+bM4cUXXwQi9QywpLn9+/e36yUlfMWKFdYvxYgRIwC47LLLOtSLs0llKoiu8N+A98MwvD72rweB09f+fjrw/9q/eI7jOI7jOMlmc5SpA4FvA9ODIHhr7XM/A64G/m8QBGcCHwPf7JgitkVBcoojOeiggyyRYbGhIObJkydbSnypIC0tLVx11VVA7hKqbS59+vRh/PjxQNqXfeedd5oPWzE3Wh316tXLYgOU3v+b3/xm3uODNsbSpUtt5ackeitXrrTAURFXnrSKkvp06qmn2tE6UvAUT1NdXZ0o+1Wu4447zsqlTRBSKyorK+1aSGE855xz2jUxXkei2J/Ro0fbClYrf/399ttvr6PEdOnSxdqClCldo2HDhtlzSUokLBukQkhdvPnmm01BU/985plnTBE/99xzATj77LOBSEFOOgqyvvbaa4Gon2amRkhCItU46m9jx44FIsVFY4UeFy1aZEHWuj8oBrO+vt7uH3qME0+YDFHf1Yaf888/H4jG5VzHN6oPKt3PT3/6U7snKPmv4vtmz57dJt0IRPdF9bPTTjsNiNRxgN13371D019szm6+F4ENXdHD27c4G0Y7ptSZJVeOHz8+p0FyuUDyus5ke/LJJ03GVaOpra1l//33B5IX0FtSUmKDgW5CAwcOtI4ePyAWIklabgNlIU5CzpeNsf3221tWbO0M0c0H2gZ9Ck0w4rsakx68m8muu+5qfTEzAH3YsGF2dptuyLlw7bU3QRDYgKyAVdXt4YcfboN5fMOLJk+ZOZo6d+6cuP4ZR+XVpGjMmDHccsstQHqC0dzczKmnngrAV7/6VSB9TlqS0XijcwU16W9tbV2nnhS4nTQ0PvTs2ZMDDzwQSI+pQRC02eASp7Gx0Q6vjoceZI6v+t+RRx5pfTUJY5La3qhRo6yf6fGJJ54AopAX9UtdkzFjxphrVBMyjdMdfU9J9h3LcRzHcRwn4eR/CrqZKPhOcqYCKJMknWdL5nlMOgNNAYeQnp2ff/75DB8+PLcF3AK0Cvja175mz2UGP4p85U7KhrKyMtvGqwDHeBBn0pW1raWiosIC4mXr6NGjgUiZSsr5kO2F6lGP22+/fWIyt7cHqidlse/du7eFDWgTQRiGHH545ISQIlUI7Vs53P79738DaRWmtbV1nTEoSS71DaFxcn0pgHQ/jPP973+/w8vU0VRUVNh9TvYDOHrTAAAEVklEQVTH3ZzySum52tpaU99y7bpNfo9wHMdxHMdJMAWxfFy8eLEl7lKwubb09urVa51stoWKVryKhVJCvYceesiSymnL9XnnnZe4oMlNofopRCVqYxSrXRtCioXUDKfwUZLIvfbay5RWtetUKlWQbVs2SdlQrGZzczNHH300kA7wLvR7RzGjtqd61Bm0ra2t6yjH+ST/JXAcx3EcxylgEq1MSXGqr6+32Ki77roLSM9Wa2trEzErbQ8yfcJ6nD17tilz2nVTaKqU4ziFQeZ4WoiqFKSPB7r88ssB7Fy2hoYGU62008spHKQiJi0uM1mlyUCTqZqaGs4880xg25RjS0tLC2IrsuM4TlJQoLbyujlOR1Icko7jOI7jOE6eCDbnDLF2+7IgWAKsBpbm7Eu3nl60LefAMAw3eVBREAQrgZkdVqr2ZYttLPA6hOK3cXPb6bZgo/fF5OB9cQNsIzYWdV+EHE+mAIIgeCMMw31y+qVbwdaWs1Dsg+K3MZtyuo3JodjbKRS/jd5OO+69uaTY2ylsfVndzec4juM4jpMFPplyHMdxHMfJgnxMpm7Lw3duDVtbzkKxD4rfxmzK6TYmh2Jvp1D8Nno77bj35pJib6ewlWXNecyU4ziO4zhOMeFuPsdxHMdxnCzI2WQqCIKjgiCYGQTBnCAILsnV926KIAhqgiB4NgiC94IgeDcIggvWPn9FEAQLgiB4a+3PuM34LLcxT7SXjUm1D4rfRm+nbmPG5xS1fWvf4zbmi
fa0EYiyjHf0D1ACfADsDJQBbwO75+K7N6Ns/YC91/7eHZgF7A5cAfzYbdx2bEyyfduCjd5O3cZtxT63sXhs1E+ulKn9gDlhGM4Nw7AJuAc4LkffvVHCMFwYhuGba39fCbwPVG/FR7mNeaSdbEysfVD8Nno73SKK3cZitw/cxrzSjjYCuXPzVQPzY39/QhaF7iiCIKgF9gJeXfvU94MgmBYEwe1BEHxpE293GxNCFjYWhH1Q/DZ6O93mbSx2+8BtTAxZ2gh4ALoRBEE34H7gwjAMvwBuAQYBI4GFwO/zWLx2wW10GwuBYrcP3EaKwMZitw/cRrbAxlxNphYANbG/d1z7XCIIgqAz0cX8ZxiG/wsQhuGiMAxbwzBMAX8hkis3htuYZ9rBxkTbB8Vvo7dTt3EtxW4fuI15p51sBHI3mXod2CUIgp2CICgDTgEezNF3b5QgCALgb8D7YRheH3u+X+xlJwDvbOKj3MY80k42JtY+KH4bvZ0abmPx2wduY15pRxsjtjRifWt/gHFE0fIfAD/P1fduRrkOAkJgGvDW2p9xwJ3A9LXPPwj0cxuL38ak2rct2Ojt1G3cluxzG4vHxjAMPQO64ziO4zhONngAuuM4juM4Thb4ZMpxHMdxHCcLfDLlOI7jOI6TBT6ZchzHcRzHyQKfTDmO4ziO42SBT6Ycx3Ecx3GywCdTjuM4juM4WeCTKcdxHMdxnCz4/7A7Q7JA7sAiAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "for i in range(10):\n", "\n", " ##########################\n", " ### RANDOM SAMPLE\n", " ########################## \n", " \n", " n_images = 10\n", " rand_features = torch.randn(n_images, num_latent).to(device)\n", " new_images = model.decoder(rand_features)\n", "\n", " ##########################\n", " ### VISUALIZATION\n", " ##########################\n", "\n", " image_width = 28\n", "\n", " fig, axes = plt.subplots(nrows=1, ncols=n_images, figsize=(10, 2.5), sharey=True)\n", " decoded_images = new_images[:n_images]\n", "\n", " for ax, img in zip(axes, decoded_images):\n", " ax.imshow(img.detach().to(torch.device('cpu')).reshape((image_width, image_width)), cmap='binary')\n", " \n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "numpy 1.15.4\n", "torch 1.0.0\n", "\n" ] } ], "source": [ "%watermark -iv" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.1" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 2 }