{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Zoo -- Generative Adversarial Networks (GAN)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of a standard GAN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "if torch.cuda.is_available():\n", "    torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device: the default CUDA device if one is available, otherwise the CPU.\n", "# (A hard-coded index such as \"cuda:2\" fails on machines with fewer GPUs.)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# 
Hyperparameters\n", "random_seed = 123\n", "generator_learning_rate = 0.001\n", "discriminator_learning_rate = 0.001\n", "NUM_EPOCHS = 100\n", "BATCH_SIZE = 128\n", "LATENT_DIM = 75\n", "IMG_SHAPE = (1, 28, 28)\n", "IMG_SIZE = 1\n", "for x in IMG_SHAPE:\n", "    IMG_SIZE *= x\n", "\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", "                               train=True, \n", "                               transform=transforms.ToTensor(),\n", "                               download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", "                              train=False, \n", "                              transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", "                          batch_size=BATCH_SIZE, \n", "                          shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", "                         batch_size=BATCH_SIZE, \n", "                         shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", "    print('Image batch dimensions:', images.shape)\n", "    print('Image label dimensions:', labels.shape)\n", "    break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "\n", "class GAN(torch.nn.Module):\n", "    \"\"\"Fully connected GAN: generator and discriminator in one module.\n", "\n", "    - generator:     LATENT_DIM -> 128 -> IMG_SIZE, Tanh output in (-1, 1)\n", "    - discriminator: IMG_SIZE -> 128 -> 1, Sigmoid output in (0, 1)\n", "\n", "    Both are stacks of nn.Linear layers, so their inputs must have a last\n", "    dimension of LATENT_DIM (generator) or IMG_SIZE (discriminator),\n", "    i.e. images are fed in flattened form.\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        super(GAN, self).__init__()\n", "\n", "        self.generator = nn.Sequential(\n", "            nn.Linear(LATENT_DIM, 128),\n", "            nn.LeakyReLU(inplace=True),\n", "            nn.Dropout(p=0.5),\n", "            nn.Linear(128, IMG_SIZE),\n", "            nn.Tanh()\n", "        )\n", "\n", "        self.discriminator = nn.Sequential(\n", "            nn.Linear(IMG_SIZE, 128),\n", "            nn.LeakyReLU(inplace=True),\n", "            nn.Dropout(p=0.5),\n", "            nn.Linear(128, 1),\n", "            nn.Sigmoid()\n", "        )\n", "\n", "    def generator_forward(self, z):\n", "        \"\"\"Map latent vectors z (last dim LATENT_DIM) to flat images (last dim IMG_SIZE).\"\"\"\n", "        img = self.generator(z)\n", "        return img\n", "\n", "    def discriminator_forward(self, img):\n", "        \"\"\"Return the discriminator's probabilities flattened to a 1-D tensor.\n", "\n", "        Uses this instance's own discriminator rather than the global `model`,\n", "        so the method works on any GAN object (and before `model` exists).\n", "        \"\"\"\n", "        pred = self.discriminator(img)\n", "        return pred.view(-1)" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(random_seed)\n", "\n", "model = GAN()\n", "model = model.to(device)\n", "\n", "optim_gener = torch.optim.Adam(model.generator.parameters(), lr=generator_learning_rate)\n", "optim_discr = torch.optim.Adam(model.discriminator.parameters(), lr=discriminator_learning_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/469 | Gen/Dis Loss: 0.6576/0.7134\n", "Epoch: 001/100 | Batch 100/469 | Gen/Dis Loss: 5.1797/0.0280\n", "Epoch: 001/100 | Batch 200/469 | Gen/Dis Loss: 1.8944/0.0933\n", "Epoch: 001/100 | Batch 300/469 | Gen/Dis Loss: 1.5018/0.1451\n", "Epoch: 001/100 | Batch 400/469 | Gen/Dis Loss: 2.0884/0.1026\n", "Time elapsed: 0.27 min\n", "Epoch: 002/100 | Batch 000/469 | Gen/Dis Loss: 2.8803/0.0496\n", "Epoch: 002/100 | Batch 100/469 | Gen/Dis Loss: 3.4923/0.0483\n", "Epoch: 002/100 | Batch 200/469 | Gen/Dis Loss: 2.9812/0.1615\n", "Epoch: 002/100 | Batch 300/469 | Gen/Dis Loss: 2.2371/0.1658\n", "Epoch: 002/100 | Batch 400/469 | Gen/Dis Loss: 1.7027/0.2905\n", "Time elapsed: 0.51 min\n", "Epoch: 003/100 | Batch 000/469 | Gen/Dis Loss: 1.2188/0.3533\n", "Epoch: 003/100 | Batch 100/469 | Gen/Dis Loss: 1.8254/0.2083\n", "Epoch: 003/100 | Batch 200/469 | Gen/Dis Loss: 1.9774/0.2238\n", "Epoch: 003/100 | Batch 300/469 | Gen/Dis Loss: 1.9323/0.2806\n", "Epoch: 003/100 | Batch 400/469 | Gen/Dis Loss: 1.9518/0.2712\n", "Time elapsed: 0.77 min\n", "Epoch: 004/100 | Batch 000/469 | Gen/Dis Loss: 1.2785/0.3455\n", "Epoch: 004/100 | Batch 100/469 | Gen/Dis Loss: 1.3979/0.3208\n", "Epoch: 004/100 | Batch 200/469 | Gen/Dis Loss: 1.4295/0.3638\n", "Epoch: 004/100 | Batch 300/469 | 
Gen/Dis Loss: 1.2798/0.3620\n", "Epoch: 004/100 | Batch 400/469 | Gen/Dis Loss: 1.1321/0.4751\n", "Time elapsed: 1.04 min\n", "Epoch: 005/100 | Batch 000/469 | Gen/Dis Loss: 1.1786/0.3932\n", "Epoch: 005/100 | Batch 100/469 | Gen/Dis Loss: 1.1437/0.4343\n", "Epoch: 005/100 | Batch 200/469 | Gen/Dis Loss: 1.0105/0.4453\n", "Epoch: 005/100 | Batch 300/469 | Gen/Dis Loss: 1.3987/0.4194\n", "Epoch: 005/100 | Batch 400/469 | Gen/Dis Loss: 1.3960/0.4005\n", "Time elapsed: 1.28 min\n", "Epoch: 006/100 | Batch 000/469 | Gen/Dis Loss: 1.3119/0.4792\n", "Epoch: 006/100 | Batch 100/469 | Gen/Dis Loss: 1.6029/0.4045\n", "Epoch: 006/100 | Batch 200/469 | Gen/Dis Loss: 1.6302/0.3768\n", "Epoch: 006/100 | Batch 300/469 | Gen/Dis Loss: 0.9141/0.4838\n", "Epoch: 006/100 | Batch 400/469 | Gen/Dis Loss: 0.9891/0.4810\n", "Time elapsed: 1.56 min\n", "Epoch: 007/100 | Batch 000/469 | Gen/Dis Loss: 1.3198/0.4820\n", "Epoch: 007/100 | Batch 100/469 | Gen/Dis Loss: 1.1527/0.4620\n", "Epoch: 007/100 | Batch 200/469 | Gen/Dis Loss: 1.3668/0.3967\n", "Epoch: 007/100 | Batch 300/469 | Gen/Dis Loss: 1.6183/0.4676\n", "Epoch: 007/100 | Batch 400/469 | Gen/Dis Loss: 1.0077/0.4841\n", "Time elapsed: 1.85 min\n", "Epoch: 008/100 | Batch 000/469 | Gen/Dis Loss: 1.2245/0.5437\n", "Epoch: 008/100 | Batch 100/469 | Gen/Dis Loss: 1.0142/0.4928\n", "Epoch: 008/100 | Batch 200/469 | Gen/Dis Loss: 0.8817/0.4939\n", "Epoch: 008/100 | Batch 300/469 | Gen/Dis Loss: 1.0748/0.4967\n", "Epoch: 008/100 | Batch 400/469 | Gen/Dis Loss: 2.1265/0.4329\n", "Time elapsed: 2.11 min\n", "Epoch: 009/100 | Batch 000/469 | Gen/Dis Loss: 0.9277/0.4871\n", "Epoch: 009/100 | Batch 100/469 | Gen/Dis Loss: 1.1624/0.4473\n", "Epoch: 009/100 | Batch 200/469 | Gen/Dis Loss: 1.1869/0.4800\n", "Epoch: 009/100 | Batch 300/469 | Gen/Dis Loss: 1.9998/0.4295\n", "Epoch: 009/100 | Batch 400/469 | Gen/Dis Loss: 1.6921/0.5037\n", "Time elapsed: 2.34 min\n", "Epoch: 010/100 | Batch 000/469 | Gen/Dis Loss: 1.3091/0.4358\n", "Epoch: 010/100 | 
Batch 100/469 | Gen/Dis Loss: 1.2604/0.5375\n", "Epoch: 010/100 | Batch 200/469 | Gen/Dis Loss: 1.1491/0.4537\n", "Epoch: 010/100 | Batch 300/469 | Gen/Dis Loss: 1.3843/0.5068\n", "Epoch: 010/100 | Batch 400/469 | Gen/Dis Loss: 1.3413/0.5051\n", "Time elapsed: 2.60 min\n", "Epoch: 011/100 | Batch 000/469 | Gen/Dis Loss: 1.2368/0.5161\n", "Epoch: 011/100 | Batch 100/469 | Gen/Dis Loss: 1.3715/0.4692\n", "Epoch: 011/100 | Batch 200/469 | Gen/Dis Loss: 1.1182/0.5274\n", "Epoch: 011/100 | Batch 300/469 | Gen/Dis Loss: 1.2770/0.4649\n", "Epoch: 011/100 | Batch 400/469 | Gen/Dis Loss: 1.1847/0.5504\n", "Time elapsed: 2.84 min\n", "Epoch: 012/100 | Batch 000/469 | Gen/Dis Loss: 0.9930/0.5509\n", "Epoch: 012/100 | Batch 100/469 | Gen/Dis Loss: 1.1921/0.5310\n", "Epoch: 012/100 | Batch 200/469 | Gen/Dis Loss: 0.9925/0.6062\n", "Epoch: 012/100 | Batch 300/469 | Gen/Dis Loss: 1.1246/0.5170\n", "Epoch: 012/100 | Batch 400/469 | Gen/Dis Loss: 1.0432/0.4437\n", "Time elapsed: 3.07 min\n", "Epoch: 013/100 | Batch 000/469 | Gen/Dis Loss: 1.1419/0.5287\n", "Epoch: 013/100 | Batch 100/469 | Gen/Dis Loss: 1.0053/0.5152\n", "Epoch: 013/100 | Batch 200/469 | Gen/Dis Loss: 1.1308/0.5384\n", "Epoch: 013/100 | Batch 300/469 | Gen/Dis Loss: 1.1822/0.5124\n", "Epoch: 013/100 | Batch 400/469 | Gen/Dis Loss: 1.4501/0.5495\n", "Time elapsed: 3.32 min\n", "Epoch: 014/100 | Batch 000/469 | Gen/Dis Loss: 1.1417/0.5364\n", "Epoch: 014/100 | Batch 100/469 | Gen/Dis Loss: 0.9595/0.5884\n", "Epoch: 014/100 | Batch 200/469 | Gen/Dis Loss: 0.9887/0.5216\n", "Epoch: 014/100 | Batch 300/469 | Gen/Dis Loss: 1.0332/0.5686\n", "Epoch: 014/100 | Batch 400/469 | Gen/Dis Loss: 1.5268/0.4554\n", "Time elapsed: 3.60 min\n", "Epoch: 015/100 | Batch 000/469 | Gen/Dis Loss: 1.1181/0.4960\n", "Epoch: 015/100 | Batch 100/469 | Gen/Dis Loss: 1.2722/0.4632\n", "Epoch: 015/100 | Batch 200/469 | Gen/Dis Loss: 0.9523/0.6012\n", "Epoch: 015/100 | Batch 300/469 | Gen/Dis Loss: 0.9905/0.5274\n", "Epoch: 015/100 | Batch 
400/469 | Gen/Dis Loss: 1.0448/0.5855\n", "Time elapsed: 3.82 min\n", "Epoch: 016/100 | Batch 000/469 | Gen/Dis Loss: 1.0641/0.5432\n", "Epoch: 016/100 | Batch 100/469 | Gen/Dis Loss: 0.9587/0.5636\n", "Epoch: 016/100 | Batch 200/469 | Gen/Dis Loss: 1.3602/0.5691\n", "Epoch: 016/100 | Batch 300/469 | Gen/Dis Loss: 1.1294/0.5564\n", "Epoch: 016/100 | Batch 400/469 | Gen/Dis Loss: 1.0727/0.5042\n", "Time elapsed: 4.04 min\n", "Epoch: 017/100 | Batch 000/469 | Gen/Dis Loss: 0.9285/0.6045\n", "Epoch: 017/100 | Batch 100/469 | Gen/Dis Loss: 1.0024/0.6384\n", "Epoch: 017/100 | Batch 200/469 | Gen/Dis Loss: 1.5662/0.4652\n", "Epoch: 017/100 | Batch 300/469 | Gen/Dis Loss: 1.3644/0.4632\n", "Epoch: 017/100 | Batch 400/469 | Gen/Dis Loss: 1.2681/0.5238\n", "Time elapsed: 4.22 min\n", "Epoch: 018/100 | Batch 000/469 | Gen/Dis Loss: 1.2578/0.5151\n", "Epoch: 018/100 | Batch 100/469 | Gen/Dis Loss: 1.6475/0.4929\n", "Epoch: 018/100 | Batch 200/469 | Gen/Dis Loss: 1.0610/0.5496\n", "Epoch: 018/100 | Batch 300/469 | Gen/Dis Loss: 1.0613/0.5634\n", "Epoch: 018/100 | Batch 400/469 | Gen/Dis Loss: 1.4675/0.4589\n", "Time elapsed: 4.40 min\n", "Epoch: 019/100 | Batch 000/469 | Gen/Dis Loss: 1.1211/0.5027\n", "Epoch: 019/100 | Batch 100/469 | Gen/Dis Loss: 1.1444/0.5655\n", "Epoch: 019/100 | Batch 200/469 | Gen/Dis Loss: 1.2471/0.5716\n", "Epoch: 019/100 | Batch 300/469 | Gen/Dis Loss: 1.0223/0.5106\n", "Epoch: 019/100 | Batch 400/469 | Gen/Dis Loss: 1.0361/0.5805\n", "Time elapsed: 4.58 min\n", "Epoch: 020/100 | Batch 000/469 | Gen/Dis Loss: 0.9195/0.5428\n", "Epoch: 020/100 | Batch 100/469 | Gen/Dis Loss: 1.3110/0.4955\n", "Epoch: 020/100 | Batch 200/469 | Gen/Dis Loss: 1.2449/0.4973\n", "Epoch: 020/100 | Batch 300/469 | Gen/Dis Loss: 1.3258/0.4992\n", "Epoch: 020/100 | Batch 400/469 | Gen/Dis Loss: 1.2196/0.5279\n", "Time elapsed: 4.77 min\n", "Epoch: 021/100 | Batch 000/469 | Gen/Dis Loss: 1.5621/0.5584\n", "Epoch: 021/100 | Batch 100/469 | Gen/Dis Loss: 1.1148/0.5888\n", "Epoch: 
021/100 | Batch 200/469 | Gen/Dis Loss: 1.5108/0.4636\n", "Epoch: 021/100 | Batch 300/469 | Gen/Dis Loss: 1.0957/0.4912\n", "Epoch: 021/100 | Batch 400/469 | Gen/Dis Loss: 1.0342/0.5184\n", "Time elapsed: 4.92 min\n", "Epoch: 022/100 | Batch 000/469 | Gen/Dis Loss: 1.9312/0.4366\n", "Epoch: 022/100 | Batch 100/469 | Gen/Dis Loss: 1.2312/0.5260\n", "Epoch: 022/100 | Batch 200/469 | Gen/Dis Loss: 1.1939/0.5075\n", "Epoch: 022/100 | Batch 300/469 | Gen/Dis Loss: 1.1393/0.5692\n", "Epoch: 022/100 | Batch 400/469 | Gen/Dis Loss: 1.0390/0.5261\n", "Time elapsed: 5.05 min\n", "Epoch: 023/100 | Batch 000/469 | Gen/Dis Loss: 1.3148/0.4902\n", "Epoch: 023/100 | Batch 100/469 | Gen/Dis Loss: 1.2077/0.6129\n", "Epoch: 023/100 | Batch 200/469 | Gen/Dis Loss: 1.0886/0.5545\n", "Epoch: 023/100 | Batch 300/469 | Gen/Dis Loss: 1.0762/0.4948\n", "Epoch: 023/100 | Batch 400/469 | Gen/Dis Loss: 1.5361/0.5476\n", "Time elapsed: 5.17 min\n", "Epoch: 024/100 | Batch 000/469 | Gen/Dis Loss: 1.1752/0.5881\n", "Epoch: 024/100 | Batch 100/469 | Gen/Dis Loss: 1.3408/0.5339\n", "Epoch: 024/100 | Batch 200/469 | Gen/Dis Loss: 1.2613/0.4555\n", "Epoch: 024/100 | Batch 300/469 | Gen/Dis Loss: 1.0707/0.5099\n", "Epoch: 024/100 | Batch 400/469 | Gen/Dis Loss: 1.1063/0.5695\n", "Time elapsed: 5.32 min\n", "Epoch: 025/100 | Batch 000/469 | Gen/Dis Loss: 1.2911/0.5084\n", "Epoch: 025/100 | Batch 100/469 | Gen/Dis Loss: 1.1280/0.5151\n", "Epoch: 025/100 | Batch 200/469 | Gen/Dis Loss: 1.3799/0.5784\n", "Epoch: 025/100 | Batch 300/469 | Gen/Dis Loss: 1.1675/0.6001\n", "Epoch: 025/100 | Batch 400/469 | Gen/Dis Loss: 0.9834/0.6158\n", "Time elapsed: 5.48 min\n", "Epoch: 026/100 | Batch 000/469 | Gen/Dis Loss: 1.2713/0.5475\n", "Epoch: 026/100 | Batch 100/469 | Gen/Dis Loss: 1.3814/0.5652\n", "Epoch: 026/100 | Batch 200/469 | Gen/Dis Loss: 1.1782/0.4850\n", "Epoch: 026/100 | Batch 300/469 | Gen/Dis Loss: 0.9917/0.5888\n", "Epoch: 026/100 | Batch 400/469 | Gen/Dis Loss: 1.0909/0.5825\n", "Time elapsed: 5.64 
min\n", "Epoch: 027/100 | Batch 000/469 | Gen/Dis Loss: 1.0873/0.5579\n", "Epoch: 027/100 | Batch 100/469 | Gen/Dis Loss: 0.9639/0.5860\n", "Epoch: 027/100 | Batch 200/469 | Gen/Dis Loss: 1.0458/0.5526\n", "Epoch: 027/100 | Batch 300/469 | Gen/Dis Loss: 1.3373/0.5140\n", "Epoch: 027/100 | Batch 400/469 | Gen/Dis Loss: 1.2790/0.5223\n", "Time elapsed: 5.79 min\n", "Epoch: 028/100 | Batch 000/469 | Gen/Dis Loss: 0.9300/0.5869\n", "Epoch: 028/100 | Batch 100/469 | Gen/Dis Loss: 1.0022/0.6056\n", "Epoch: 028/100 | Batch 200/469 | Gen/Dis Loss: 1.0688/0.5447\n", "Epoch: 028/100 | Batch 300/469 | Gen/Dis Loss: 1.0161/0.5702\n", "Epoch: 028/100 | Batch 400/469 | Gen/Dis Loss: 0.8731/0.5543\n", "Time elapsed: 5.92 min\n", "Epoch: 029/100 | Batch 000/469 | Gen/Dis Loss: 0.8719/0.5524\n", "Epoch: 029/100 | Batch 100/469 | Gen/Dis Loss: 1.3005/0.5179\n", "Epoch: 029/100 | Batch 200/469 | Gen/Dis Loss: 1.2986/0.5312\n", "Epoch: 029/100 | Batch 300/469 | Gen/Dis Loss: 1.1084/0.5207\n", "Epoch: 029/100 | Batch 400/469 | Gen/Dis Loss: 1.0591/0.5577\n", "Time elapsed: 6.07 min\n", "Epoch: 030/100 | Batch 000/469 | Gen/Dis Loss: 1.0231/0.6170\n", "Epoch: 030/100 | Batch 100/469 | Gen/Dis Loss: 0.9142/0.6046\n", "Epoch: 030/100 | Batch 200/469 | Gen/Dis Loss: 1.2140/0.5290\n", "Epoch: 030/100 | Batch 300/469 | Gen/Dis Loss: 0.8784/0.5804\n", "Epoch: 030/100 | Batch 400/469 | Gen/Dis Loss: 1.1178/0.5165\n", "Time elapsed: 6.20 min\n", "Epoch: 031/100 | Batch 000/469 | Gen/Dis Loss: 0.9555/0.5921\n", "Epoch: 031/100 | Batch 100/469 | Gen/Dis Loss: 0.9644/0.5432\n", "Epoch: 031/100 | Batch 200/469 | Gen/Dis Loss: 0.9531/0.5465\n", "Epoch: 031/100 | Batch 300/469 | Gen/Dis Loss: 1.3496/0.5550\n", "Epoch: 031/100 | Batch 400/469 | Gen/Dis Loss: 1.2137/0.5672\n", "Time elapsed: 6.32 min\n", "Epoch: 032/100 | Batch 000/469 | Gen/Dis Loss: 1.0849/0.5020\n", "Epoch: 032/100 | Batch 100/469 | Gen/Dis Loss: 0.9098/0.5481\n", "Epoch: 032/100 | Batch 200/469 | Gen/Dis Loss: 1.2349/0.5024\n", 
"Epoch: 032/100 | Batch 300/469 | Gen/Dis Loss: 0.9468/0.5599\n", "Epoch: 032/100 | Batch 400/469 | Gen/Dis Loss: 1.4531/0.4928\n", "Time elapsed: 6.45 min\n", "Epoch: 033/100 | Batch 000/469 | Gen/Dis Loss: 1.3397/0.5521\n", "Epoch: 033/100 | Batch 100/469 | Gen/Dis Loss: 1.0106/0.5472\n", "Epoch: 033/100 | Batch 200/469 | Gen/Dis Loss: 0.9787/0.5606\n", "Epoch: 033/100 | Batch 300/469 | Gen/Dis Loss: 1.1434/0.5388\n", "Epoch: 033/100 | Batch 400/469 | Gen/Dis Loss: 1.0476/0.5259\n", "Time elapsed: 6.57 min\n", "Epoch: 034/100 | Batch 000/469 | Gen/Dis Loss: 1.3847/0.5294\n", "Epoch: 034/100 | Batch 100/469 | Gen/Dis Loss: 0.8550/0.5800\n", "Epoch: 034/100 | Batch 200/469 | Gen/Dis Loss: 1.0220/0.5527\n", "Epoch: 034/100 | Batch 300/469 | Gen/Dis Loss: 0.9255/0.5751\n", "Epoch: 034/100 | Batch 400/469 | Gen/Dis Loss: 1.0400/0.5554\n", "Time elapsed: 6.72 min\n", "Epoch: 035/100 | Batch 000/469 | Gen/Dis Loss: 0.9723/0.5789\n", "Epoch: 035/100 | Batch 100/469 | Gen/Dis Loss: 1.4414/0.4769\n", "Epoch: 035/100 | Batch 200/469 | Gen/Dis Loss: 0.9431/0.5898\n", "Epoch: 035/100 | Batch 300/469 | Gen/Dis Loss: 0.8252/0.6573\n", "Epoch: 035/100 | Batch 400/469 | Gen/Dis Loss: 0.9694/0.5427\n", "Time elapsed: 6.84 min\n", "Epoch: 036/100 | Batch 000/469 | Gen/Dis Loss: 1.3664/0.5839\n", "Epoch: 036/100 | Batch 100/469 | Gen/Dis Loss: 1.0854/0.5739\n", "Epoch: 036/100 | Batch 200/469 | Gen/Dis Loss: 1.0429/0.5457\n", "Epoch: 036/100 | Batch 300/469 | Gen/Dis Loss: 0.8601/0.6151\n", "Epoch: 036/100 | Batch 400/469 | Gen/Dis Loss: 1.2785/0.5850\n", "Time elapsed: 6.97 min\n", "Epoch: 037/100 | Batch 000/469 | Gen/Dis Loss: 1.0251/0.5933\n", "Epoch: 037/100 | Batch 100/469 | Gen/Dis Loss: 1.2177/0.5053\n", "Epoch: 037/100 | Batch 200/469 | Gen/Dis Loss: 0.8804/0.5925\n", "Epoch: 037/100 | Batch 300/469 | Gen/Dis Loss: 1.2797/0.6173\n", "Epoch: 037/100 | Batch 400/469 | Gen/Dis Loss: 0.9189/0.6238\n", "Time elapsed: 7.10 min\n", "Epoch: 038/100 | Batch 000/469 | Gen/Dis Loss: 
1.3463/0.5419\n", "Epoch: 038/100 | Batch 100/469 | Gen/Dis Loss: 1.0166/0.6045\n", "Epoch: 038/100 | Batch 200/469 | Gen/Dis Loss: 0.9895/0.6320\n", "Epoch: 038/100 | Batch 300/469 | Gen/Dis Loss: 0.9749/0.5621\n", "Epoch: 038/100 | Batch 400/469 | Gen/Dis Loss: 1.0448/0.5945\n", "Time elapsed: 7.24 min\n", "Epoch: 039/100 | Batch 000/469 | Gen/Dis Loss: 0.9662/0.5669\n", "Epoch: 039/100 | Batch 100/469 | Gen/Dis Loss: 1.1476/0.5462\n", "Epoch: 039/100 | Batch 200/469 | Gen/Dis Loss: 0.9662/0.5554\n", "Epoch: 039/100 | Batch 300/469 | Gen/Dis Loss: 1.0850/0.6031\n", "Epoch: 039/100 | Batch 400/469 | Gen/Dis Loss: 1.1491/0.6014\n", "Time elapsed: 7.41 min\n", "Epoch: 040/100 | Batch 000/469 | Gen/Dis Loss: 0.9942/0.5999\n", "Epoch: 040/100 | Batch 100/469 | Gen/Dis Loss: 0.9034/0.5979\n", "Epoch: 040/100 | Batch 200/469 | Gen/Dis Loss: 1.1880/0.5693\n", "Epoch: 040/100 | Batch 300/469 | Gen/Dis Loss: 1.0893/0.5933\n", "Epoch: 040/100 | Batch 400/469 | Gen/Dis Loss: 1.0711/0.5501\n", "Time elapsed: 7.59 min\n", "Epoch: 041/100 | Batch 000/469 | Gen/Dis Loss: 0.9100/0.5957\n", "Epoch: 041/100 | Batch 100/469 | Gen/Dis Loss: 0.7538/0.5947\n", "Epoch: 041/100 | Batch 200/469 | Gen/Dis Loss: 0.9743/0.5999\n", "Epoch: 041/100 | Batch 300/469 | Gen/Dis Loss: 0.8305/0.6395\n", "Epoch: 041/100 | Batch 400/469 | Gen/Dis Loss: 1.1106/0.6419\n", "Time elapsed: 7.73 min\n", "Epoch: 042/100 | Batch 000/469 | Gen/Dis Loss: 1.1241/0.5890\n", "Epoch: 042/100 | Batch 100/469 | Gen/Dis Loss: 0.8509/0.6164\n", "Epoch: 042/100 | Batch 200/469 | Gen/Dis Loss: 1.2024/0.5684\n", "Epoch: 042/100 | Batch 300/469 | Gen/Dis Loss: 0.9708/0.6378\n", "Epoch: 042/100 | Batch 400/469 | Gen/Dis Loss: 1.1171/0.5501\n", "Time elapsed: 7.85 min\n", "Epoch: 043/100 | Batch 000/469 | Gen/Dis Loss: 1.0931/0.5653\n", "Epoch: 043/100 | Batch 100/469 | Gen/Dis Loss: 1.0468/0.5782\n", "Epoch: 043/100 | Batch 200/469 | Gen/Dis Loss: 1.0359/0.6329\n", "Epoch: 043/100 | Batch 300/469 | Gen/Dis Loss: 
1.1976/0.6114\n", "Epoch: 043/100 | Batch 400/469 | Gen/Dis Loss: 0.8817/0.6200\n", "Time elapsed: 7.98 min\n", "Epoch: 044/100 | Batch 000/469 | Gen/Dis Loss: 0.9911/0.6061\n", "Epoch: 044/100 | Batch 100/469 | Gen/Dis Loss: 1.0196/0.6435\n", "Epoch: 044/100 | Batch 200/469 | Gen/Dis Loss: 1.0005/0.6266\n", "Epoch: 044/100 | Batch 300/469 | Gen/Dis Loss: 0.8342/0.6092\n", "Epoch: 044/100 | Batch 400/469 | Gen/Dis Loss: 0.8342/0.5589\n", "Time elapsed: 8.10 min\n", "Epoch: 045/100 | Batch 000/469 | Gen/Dis Loss: 0.7638/0.6289\n", "Epoch: 045/100 | Batch 100/469 | Gen/Dis Loss: 0.9049/0.5920\n", "Epoch: 045/100 | Batch 200/469 | Gen/Dis Loss: 1.0077/0.5975\n", "Epoch: 045/100 | Batch 300/469 | Gen/Dis Loss: 0.9315/0.6066\n", "Epoch: 045/100 | Batch 400/469 | Gen/Dis Loss: 0.7719/0.6624\n", "Time elapsed: 8.23 min\n", "Epoch: 046/100 | Batch 000/469 | Gen/Dis Loss: 1.0064/0.5672\n", "Epoch: 046/100 | Batch 100/469 | Gen/Dis Loss: 0.8730/0.6217\n", "Epoch: 046/100 | Batch 200/469 | Gen/Dis Loss: 1.2217/0.5859\n", "Epoch: 046/100 | Batch 300/469 | Gen/Dis Loss: 1.1649/0.5878\n", "Epoch: 046/100 | Batch 400/469 | Gen/Dis Loss: 0.9912/0.5882\n", "Time elapsed: 8.35 min\n", "Epoch: 047/100 | Batch 000/469 | Gen/Dis Loss: 0.8579/0.6209\n", "Epoch: 047/100 | Batch 100/469 | Gen/Dis Loss: 1.0072/0.5908\n", "Epoch: 047/100 | Batch 200/469 | Gen/Dis Loss: 0.8694/0.6285\n", "Epoch: 047/100 | Batch 300/469 | Gen/Dis Loss: 0.9354/0.6087\n", "Epoch: 047/100 | Batch 400/469 | Gen/Dis Loss: 0.8800/0.6521\n", "Time elapsed: 8.48 min\n", "Epoch: 048/100 | Batch 000/469 | Gen/Dis Loss: 0.8513/0.6051\n", "Epoch: 048/100 | Batch 100/469 | Gen/Dis Loss: 0.8803/0.6090\n", "Epoch: 048/100 | Batch 200/469 | Gen/Dis Loss: 1.0930/0.6115\n", "Epoch: 048/100 | Batch 300/469 | Gen/Dis Loss: 0.7406/0.6692\n", "Epoch: 048/100 | Batch 400/469 | Gen/Dis Loss: 0.8551/0.6188\n", "Time elapsed: 8.62 min\n", "Epoch: 049/100 | Batch 000/469 | Gen/Dis Loss: 0.8792/0.5986\n", "Epoch: 049/100 | Batch 100/469 
| Gen/Dis Loss: 0.8424/0.6277\n", "Epoch: 049/100 | Batch 200/469 | Gen/Dis Loss: 0.7973/0.6320\n", "Epoch: 049/100 | Batch 300/469 | Gen/Dis Loss: 0.9188/0.5828\n", "Epoch: 049/100 | Batch 400/469 | Gen/Dis Loss: 0.9253/0.6013\n", "Time elapsed: 8.80 min\n", "Epoch: 050/100 | Batch 000/469 | Gen/Dis Loss: 1.3241/0.5689\n", "Epoch: 050/100 | Batch 100/469 | Gen/Dis Loss: 1.0220/0.5922\n", "Epoch: 050/100 | Batch 200/469 | Gen/Dis Loss: 0.9210/0.6024\n", "Epoch: 050/100 | Batch 300/469 | Gen/Dis Loss: 0.8139/0.6578\n", "Epoch: 050/100 | Batch 400/469 | Gen/Dis Loss: 1.0371/0.5987\n", "Time elapsed: 8.93 min\n", "Epoch: 051/100 | Batch 000/469 | Gen/Dis Loss: 0.9253/0.6002\n", "Epoch: 051/100 | Batch 100/469 | Gen/Dis Loss: 0.8154/0.5774\n", "Epoch: 051/100 | Batch 200/469 | Gen/Dis Loss: 0.9697/0.6240\n", "Epoch: 051/100 | Batch 300/469 | Gen/Dis Loss: 1.1185/0.5541\n", "Epoch: 051/100 | Batch 400/469 | Gen/Dis Loss: 0.8016/0.6642\n", "Time elapsed: 9.06 min\n", "Epoch: 052/100 | Batch 000/469 | Gen/Dis Loss: 0.8716/0.6364\n", "Epoch: 052/100 | Batch 100/469 | Gen/Dis Loss: 0.9636/0.5944\n", "Epoch: 052/100 | Batch 200/469 | Gen/Dis Loss: 0.9511/0.6204\n", "Epoch: 052/100 | Batch 300/469 | Gen/Dis Loss: 0.9293/0.5901\n", "Epoch: 052/100 | Batch 400/469 | Gen/Dis Loss: 1.1139/0.5535\n", "Time elapsed: 9.18 min\n", "Epoch: 053/100 | Batch 000/469 | Gen/Dis Loss: 0.8345/0.6399\n", "Epoch: 053/100 | Batch 100/469 | Gen/Dis Loss: 1.0420/0.5847\n", "Epoch: 053/100 | Batch 200/469 | Gen/Dis Loss: 0.8887/0.6183\n", "Epoch: 053/100 | Batch 300/469 | Gen/Dis Loss: 1.1280/0.5869\n", "Epoch: 053/100 | Batch 400/469 | Gen/Dis Loss: 0.8391/0.6031\n", "Time elapsed: 9.30 min\n", "Epoch: 054/100 | Batch 000/469 | Gen/Dis Loss: 1.0584/0.5659\n", "Epoch: 054/100 | Batch 100/469 | Gen/Dis Loss: 0.8722/0.5991\n", "Epoch: 054/100 | Batch 200/469 | Gen/Dis Loss: 0.8416/0.6067\n", "Epoch: 054/100 | Batch 300/469 | Gen/Dis Loss: 0.9295/0.5910\n", "Epoch: 054/100 | Batch 400/469 | Gen/Dis 
Loss: 0.7705/0.6145\n", "Time elapsed: 9.43 min\n", "Epoch: 055/100 | Batch 000/469 | Gen/Dis Loss: 0.9697/0.6207\n", "Epoch: 055/100 | Batch 100/469 | Gen/Dis Loss: 1.3702/0.5782\n", "Epoch: 055/100 | Batch 200/469 | Gen/Dis Loss: 0.8874/0.6034\n", "Epoch: 055/100 | Batch 300/469 | Gen/Dis Loss: 0.9273/0.6095\n", "Epoch: 055/100 | Batch 400/469 | Gen/Dis Loss: 1.0736/0.5893\n", "Time elapsed: 9.57 min\n", "Epoch: 056/100 | Batch 000/469 | Gen/Dis Loss: 0.9631/0.5959\n", "Epoch: 056/100 | Batch 100/469 | Gen/Dis Loss: 0.8657/0.6398\n", "Epoch: 056/100 | Batch 200/469 | Gen/Dis Loss: 0.8120/0.6027\n", "Epoch: 056/100 | Batch 300/469 | Gen/Dis Loss: 1.1529/0.6493\n", "Epoch: 056/100 | Batch 400/469 | Gen/Dis Loss: 0.9172/0.5788\n", "Time elapsed: 9.77 min\n", "Epoch: 057/100 | Batch 000/469 | Gen/Dis Loss: 0.9197/0.6090\n", "Epoch: 057/100 | Batch 100/469 | Gen/Dis Loss: 0.9413/0.6255\n", "Epoch: 057/100 | Batch 200/469 | Gen/Dis Loss: 0.9020/0.5870\n", "Epoch: 057/100 | Batch 300/469 | Gen/Dis Loss: 0.9947/0.5586\n", "Epoch: 057/100 | Batch 400/469 | Gen/Dis Loss: 0.9077/0.6454\n", "Time elapsed: 10.03 min\n", "Epoch: 058/100 | Batch 000/469 | Gen/Dis Loss: 0.8899/0.6106\n", "Epoch: 058/100 | Batch 100/469 | Gen/Dis Loss: 0.8154/0.6554\n", "Epoch: 058/100 | Batch 200/469 | Gen/Dis Loss: 0.9307/0.5997\n", "Epoch: 058/100 | Batch 300/469 | Gen/Dis Loss: 0.8293/0.5881\n", "Epoch: 058/100 | Batch 400/469 | Gen/Dis Loss: 0.9434/0.6448\n", "Time elapsed: 10.31 min\n", "Epoch: 059/100 | Batch 000/469 | Gen/Dis Loss: 0.9638/0.6325\n", "Epoch: 059/100 | Batch 100/469 | Gen/Dis Loss: 0.9374/0.6304\n", "Epoch: 059/100 | Batch 200/469 | Gen/Dis Loss: 0.8452/0.6464\n", "Epoch: 059/100 | Batch 300/469 | Gen/Dis Loss: 1.0170/0.6210\n", "Epoch: 059/100 | Batch 400/469 | Gen/Dis Loss: 0.8808/0.5950\n", "Time elapsed: 10.56 min\n", "Epoch: 060/100 | Batch 000/469 | Gen/Dis Loss: 0.9076/0.5969\n", "Epoch: 060/100 | Batch 100/469 | Gen/Dis Loss: 1.1195/0.6040\n", "Epoch: 060/100 | 
Batch 200/469 | Gen/Dis Loss: 0.9015/0.6149\n", "Epoch: 060/100 | Batch 300/469 | Gen/Dis Loss: 0.8414/0.5804\n", "Epoch: 060/100 | Batch 400/469 | Gen/Dis Loss: 0.8220/0.6557\n", "Time elapsed: 10.83 min\n", "Epoch: 061/100 | Batch 000/469 | Gen/Dis Loss: 0.8411/0.6360\n", "Epoch: 061/100 | Batch 100/469 | Gen/Dis Loss: 0.8431/0.6304\n", "Epoch: 061/100 | Batch 200/469 | Gen/Dis Loss: 0.7740/0.6395\n", "Epoch: 061/100 | Batch 300/469 | Gen/Dis Loss: 0.8840/0.5987\n", "Epoch: 061/100 | Batch 400/469 | Gen/Dis Loss: 0.8510/0.6232\n", "Time elapsed: 11.07 min\n", "Epoch: 062/100 | Batch 000/469 | Gen/Dis Loss: 1.0286/0.6151\n", "Epoch: 062/100 | Batch 100/469 | Gen/Dis Loss: 1.0516/0.5767\n", "Epoch: 062/100 | Batch 200/469 | Gen/Dis Loss: 0.8182/0.5654\n", "Epoch: 062/100 | Batch 300/469 | Gen/Dis Loss: 0.8658/0.6156\n", "Epoch: 062/100 | Batch 400/469 | Gen/Dis Loss: 0.9674/0.6434\n", "Time elapsed: 11.33 min\n", "Epoch: 063/100 | Batch 000/469 | Gen/Dis Loss: 0.6952/0.6601\n", "Epoch: 063/100 | Batch 100/469 | Gen/Dis Loss: 0.8180/0.6041\n", "Epoch: 063/100 | Batch 200/469 | Gen/Dis Loss: 0.8224/0.6683\n", "Epoch: 063/100 | Batch 300/469 | Gen/Dis Loss: 0.9604/0.5938\n", "Epoch: 063/100 | Batch 400/469 | Gen/Dis Loss: 0.7969/0.6561\n", "Time elapsed: 11.54 min\n", "Epoch: 064/100 | Batch 000/469 | Gen/Dis Loss: 0.8544/0.6290\n", "Epoch: 064/100 | Batch 100/469 | Gen/Dis Loss: 0.8685/0.5925\n", "Epoch: 064/100 | Batch 200/469 | Gen/Dis Loss: 1.4746/0.5992\n", "Epoch: 064/100 | Batch 300/469 | Gen/Dis Loss: 0.8570/0.6417\n", "Epoch: 064/100 | Batch 400/469 | Gen/Dis Loss: 0.8588/0.6461\n", "Time elapsed: 11.78 min\n", "Epoch: 065/100 | Batch 000/469 | Gen/Dis Loss: 0.8579/0.6151\n", "Epoch: 065/100 | Batch 100/469 | Gen/Dis Loss: 0.9720/0.5867\n", "Epoch: 065/100 | Batch 200/469 | Gen/Dis Loss: 0.8870/0.6215\n", "Epoch: 065/100 | Batch 300/469 | Gen/Dis Loss: 0.8184/0.6506\n", "Epoch: 065/100 | Batch 400/469 | Gen/Dis Loss: 0.9247/0.6219\n", "Time elapsed: 12.03 
min\n", "Epoch: 066/100 | Batch 000/469 | Gen/Dis Loss: 0.9073/0.6157\n", "Epoch: 066/100 | Batch 100/469 | Gen/Dis Loss: 0.8459/0.6364\n", "Epoch: 066/100 | Batch 200/469 | Gen/Dis Loss: 1.0687/0.5647\n", "Epoch: 066/100 | Batch 300/469 | Gen/Dis Loss: 0.9213/0.6136\n", "Epoch: 066/100 | Batch 400/469 | Gen/Dis Loss: 0.7895/0.6409\n", "Time elapsed: 12.30 min\n", "Epoch: 067/100 | Batch 000/469 | Gen/Dis Loss: 0.8258/0.6246\n", "Epoch: 067/100 | Batch 100/469 | Gen/Dis Loss: 0.9616/0.5776\n", "Epoch: 067/100 | Batch 200/469 | Gen/Dis Loss: 0.9039/0.6012\n", "Epoch: 067/100 | Batch 300/469 | Gen/Dis Loss: 0.9857/0.5949\n", "Epoch: 067/100 | Batch 400/469 | Gen/Dis Loss: 1.1779/0.5773\n", "Time elapsed: 12.58 min\n", "Epoch: 068/100 | Batch 000/469 | Gen/Dis Loss: 0.9631/0.6006\n", "Epoch: 068/100 | Batch 100/469 | Gen/Dis Loss: 0.7157/0.6103\n", "Epoch: 068/100 | Batch 200/469 | Gen/Dis Loss: 0.8400/0.6223\n", "Epoch: 068/100 | Batch 300/469 | Gen/Dis Loss: 1.0586/0.5840\n", "Epoch: 068/100 | Batch 400/469 | Gen/Dis Loss: 0.9487/0.6224\n", "Time elapsed: 12.84 min\n", "Epoch: 069/100 | Batch 000/469 | Gen/Dis Loss: 1.0124/0.5248\n", "Epoch: 069/100 | Batch 100/469 | Gen/Dis Loss: 0.8849/0.6481\n", "Epoch: 069/100 | Batch 200/469 | Gen/Dis Loss: 0.9250/0.6130\n", "Epoch: 069/100 | Batch 300/469 | Gen/Dis Loss: 0.9207/0.6420\n", "Epoch: 069/100 | Batch 400/469 | Gen/Dis Loss: 0.8661/0.6100\n", "Time elapsed: 13.11 min\n", "Epoch: 070/100 | Batch 000/469 | Gen/Dis Loss: 1.0647/0.6247\n", "Epoch: 070/100 | Batch 100/469 | Gen/Dis Loss: 0.8877/0.6254\n", "Epoch: 070/100 | Batch 200/469 | Gen/Dis Loss: 0.8151/0.6462\n", "Epoch: 070/100 | Batch 300/469 | Gen/Dis Loss: 0.8807/0.6079\n", "Epoch: 070/100 | Batch 400/469 | Gen/Dis Loss: 0.9690/0.6432\n", "Time elapsed: 13.34 min\n", "Epoch: 071/100 | Batch 000/469 | Gen/Dis Loss: 0.8764/0.6338\n", "Epoch: 071/100 | Batch 100/469 | Gen/Dis Loss: 0.9052/0.5937\n", "Epoch: 071/100 | Batch 200/469 | Gen/Dis Loss: 
1.0023/0.5866\n", "Epoch: 071/100 | Batch 300/469 | Gen/Dis Loss: 0.7945/0.6066\n", "Epoch: 071/100 | Batch 400/469 | Gen/Dis Loss: 0.8566/0.6092\n", "Time elapsed: 13.57 min\n", "Epoch: 072/100 | Batch 000/469 | Gen/Dis Loss: 1.0826/0.5474\n", "Epoch: 072/100 | Batch 100/469 | Gen/Dis Loss: 0.9077/0.6232\n", "Epoch: 072/100 | Batch 200/469 | Gen/Dis Loss: 1.0860/0.6291\n", "Epoch: 072/100 | Batch 300/469 | Gen/Dis Loss: 0.9009/0.6444\n", "Epoch: 072/100 | Batch 400/469 | Gen/Dis Loss: 0.9546/0.6265\n", "Time elapsed: 13.82 min\n", "Epoch: 073/100 | Batch 000/469 | Gen/Dis Loss: 0.9126/0.5977\n", "Epoch: 073/100 | Batch 100/469 | Gen/Dis Loss: 1.0169/0.6357\n", "Epoch: 073/100 | Batch 200/469 | Gen/Dis Loss: 0.8760/0.6333\n", "Epoch: 073/100 | Batch 300/469 | Gen/Dis Loss: 0.8972/0.5929\n", "Epoch: 073/100 | Batch 400/469 | Gen/Dis Loss: 0.9535/0.6609\n", "Time elapsed: 14.05 min\n", "Epoch: 074/100 | Batch 000/469 | Gen/Dis Loss: 0.8905/0.6017\n", "Epoch: 074/100 | Batch 100/469 | Gen/Dis Loss: 0.9040/0.6458\n", "Epoch: 074/100 | Batch 200/469 | Gen/Dis Loss: 0.8277/0.6424\n", "Epoch: 074/100 | Batch 300/469 | Gen/Dis Loss: 1.6138/0.5738\n", "Epoch: 074/100 | Batch 400/469 | Gen/Dis Loss: 0.9943/0.6718\n", "Time elapsed: 14.31 min\n", "Epoch: 075/100 | Batch 000/469 | Gen/Dis Loss: 1.0839/0.6357\n", "Epoch: 075/100 | Batch 100/469 | Gen/Dis Loss: 0.8858/0.6300\n", "Epoch: 075/100 | Batch 200/469 | Gen/Dis Loss: 0.9034/0.6045\n", "Epoch: 075/100 | Batch 300/469 | Gen/Dis Loss: 0.8336/0.5991\n", "Epoch: 075/100 | Batch 400/469 | Gen/Dis Loss: 0.8414/0.6642\n", "Time elapsed: 14.54 min\n", "Epoch: 076/100 | Batch 000/469 | Gen/Dis Loss: 0.8422/0.6506\n", "Epoch: 076/100 | Batch 100/469 | Gen/Dis Loss: 0.8560/0.5884\n", "Epoch: 076/100 | Batch 200/469 | Gen/Dis Loss: 0.8066/0.6215\n", "Epoch: 076/100 | Batch 300/469 | Gen/Dis Loss: 0.7987/0.6537\n", "Epoch: 076/100 | Batch 400/469 | Gen/Dis Loss: 0.8784/0.5854\n", "Time elapsed: 14.82 min\n", "Epoch: 077/100 | Batch 
000/469 | Gen/Dis Loss: 0.9845/0.6067\n", "Epoch: 077/100 | Batch 100/469 | Gen/Dis Loss: 0.8514/0.6269\n", "Epoch: 077/100 | Batch 200/469 | Gen/Dis Loss: 1.0448/0.6637\n", "Epoch: 077/100 | Batch 300/469 | Gen/Dis Loss: 0.9325/0.5811\n", "Epoch: 077/100 | Batch 400/469 | Gen/Dis Loss: 0.9169/0.5837\n", "Time elapsed: 15.08 min\n", "Epoch: 078/100 | Batch 000/469 | Gen/Dis Loss: 0.9746/0.6398\n", "Epoch: 078/100 | Batch 100/469 | Gen/Dis Loss: 0.8518/0.6321\n", "Epoch: 078/100 | Batch 200/469 | Gen/Dis Loss: 0.9485/0.5925\n", "Epoch: 078/100 | Batch 300/469 | Gen/Dis Loss: 0.8646/0.6530\n", "Epoch: 078/100 | Batch 400/469 | Gen/Dis Loss: 0.8851/0.6056\n", "Time elapsed: 15.33 min\n", "Epoch: 079/100 | Batch 000/469 | Gen/Dis Loss: 0.9215/0.6184\n", "Epoch: 079/100 | Batch 100/469 | Gen/Dis Loss: 0.8766/0.5987\n", "Epoch: 079/100 | Batch 200/469 | Gen/Dis Loss: 0.9273/0.6339\n", "Epoch: 079/100 | Batch 300/469 | Gen/Dis Loss: 1.0428/0.6016\n", "Epoch: 079/100 | Batch 400/469 | Gen/Dis Loss: 0.8676/0.6156\n", "Time elapsed: 15.63 min\n", "Epoch: 080/100 | Batch 000/469 | Gen/Dis Loss: 0.8753/0.6354\n", "Epoch: 080/100 | Batch 100/469 | Gen/Dis Loss: 0.7689/0.6156\n", "Epoch: 080/100 | Batch 200/469 | Gen/Dis Loss: 0.9524/0.5874\n", "Epoch: 080/100 | Batch 300/469 | Gen/Dis Loss: 1.1452/0.5870\n", "Epoch: 080/100 | Batch 400/469 | Gen/Dis Loss: 0.9418/0.5921\n", "Time elapsed: 15.87 min\n", "Epoch: 081/100 | Batch 000/469 | Gen/Dis Loss: 0.9341/0.5982\n", "Epoch: 081/100 | Batch 100/469 | Gen/Dis Loss: 0.9412/0.6336\n", "Epoch: 081/100 | Batch 200/469 | Gen/Dis Loss: 0.8976/0.6561\n", "Epoch: 081/100 | Batch 300/469 | Gen/Dis Loss: 0.8531/0.6544\n", "Epoch: 081/100 | Batch 400/469 | Gen/Dis Loss: 0.8658/0.6275\n", "Time elapsed: 16.14 min\n", "Epoch: 082/100 | Batch 000/469 | Gen/Dis Loss: 0.8624/0.6454\n", "Epoch: 082/100 | Batch 100/469 | Gen/Dis Loss: 0.8182/0.5911\n", "Epoch: 082/100 | Batch 200/469 | Gen/Dis Loss: 0.8794/0.6080\n", "Epoch: 082/100 | Batch 
300/469 | Gen/Dis Loss: 0.9631/0.6111\n", "Epoch: 082/100 | Batch 400/469 | Gen/Dis Loss: 1.0426/0.6404\n", "Time elapsed: 16.39 min\n", "Epoch: 083/100 | Batch 000/469 | Gen/Dis Loss: 1.0449/0.6439\n", "Epoch: 083/100 | Batch 100/469 | Gen/Dis Loss: 0.9290/0.6319\n", "Epoch: 083/100 | Batch 200/469 | Gen/Dis Loss: 0.8768/0.6186\n", "Epoch: 083/100 | Batch 300/469 | Gen/Dis Loss: 0.8202/0.6050\n", "Epoch: 083/100 | Batch 400/469 | Gen/Dis Loss: 0.8840/0.6135\n", "Time elapsed: 16.63 min\n", "Epoch: 084/100 | Batch 000/469 | Gen/Dis Loss: 1.0632/0.6157\n", "Epoch: 084/100 | Batch 100/469 | Gen/Dis Loss: 0.8863/0.5954\n", "Epoch: 084/100 | Batch 200/469 | Gen/Dis Loss: 1.0618/0.6428\n", "Epoch: 084/100 | Batch 300/469 | Gen/Dis Loss: 1.0627/0.5874\n", "Epoch: 084/100 | Batch 400/469 | Gen/Dis Loss: 0.9114/0.6118\n", "Time elapsed: 16.90 min\n", "Epoch: 085/100 | Batch 000/469 | Gen/Dis Loss: 0.8453/0.6248\n", "Epoch: 085/100 | Batch 100/469 | Gen/Dis Loss: 1.0609/0.6182\n", "Epoch: 085/100 | Batch 200/469 | Gen/Dis Loss: 0.8899/0.6170\n", "Epoch: 085/100 | Batch 300/469 | Gen/Dis Loss: 0.9211/0.6023\n", "Epoch: 085/100 | Batch 400/469 | Gen/Dis Loss: 0.8161/0.6840\n", "Time elapsed: 17.21 min\n", "Epoch: 086/100 | Batch 000/469 | Gen/Dis Loss: 0.9190/0.5845\n", "Epoch: 086/100 | Batch 100/469 | Gen/Dis Loss: 1.0762/0.6450\n", "Epoch: 086/100 | Batch 200/469 | Gen/Dis Loss: 1.0070/0.6302\n", "Epoch: 086/100 | Batch 300/469 | Gen/Dis Loss: 0.8805/0.6313\n", "Epoch: 086/100 | Batch 400/469 | Gen/Dis Loss: 0.8568/0.6320\n", "Time elapsed: 17.47 min\n", "Epoch: 087/100 | Batch 000/469 | Gen/Dis Loss: 0.9597/0.6527\n", "Epoch: 087/100 | Batch 100/469 | Gen/Dis Loss: 0.8664/0.6339\n", "Epoch: 087/100 | Batch 200/469 | Gen/Dis Loss: 1.0466/0.6181\n", "Epoch: 087/100 | Batch 300/469 | Gen/Dis Loss: 0.8645/0.6272\n", "Epoch: 087/100 | Batch 400/469 | Gen/Dis Loss: 0.8296/0.6125\n", "Time elapsed: 17.71 min\n", "Epoch: 088/100 | Batch 000/469 | Gen/Dis Loss: 0.8497/0.6134\n", 
"Epoch: 088/100 | Batch 100/469 | Gen/Dis Loss: 0.7984/0.6551\n", "Epoch: 088/100 | Batch 200/469 | Gen/Dis Loss: 0.7777/0.6737\n", "Epoch: 088/100 | Batch 300/469 | Gen/Dis Loss: 0.8157/0.6250\n", "Epoch: 088/100 | Batch 400/469 | Gen/Dis Loss: 0.7993/0.6446\n", "Time elapsed: 17.96 min\n", "Epoch: 089/100 | Batch 000/469 | Gen/Dis Loss: 0.8526/0.6219\n", "Epoch: 089/100 | Batch 100/469 | Gen/Dis Loss: 0.9565/0.6241\n", "Epoch: 089/100 | Batch 200/469 | Gen/Dis Loss: 1.0437/0.6488\n", "Epoch: 089/100 | Batch 300/469 | Gen/Dis Loss: 0.8082/0.6521\n", "Epoch: 089/100 | Batch 400/469 | Gen/Dis Loss: 0.9082/0.6187\n", "Time elapsed: 18.20 min\n", "Epoch: 090/100 | Batch 000/469 | Gen/Dis Loss: 0.8507/0.6127\n", "Epoch: 090/100 | Batch 100/469 | Gen/Dis Loss: 0.8370/0.6160\n", "Epoch: 090/100 | Batch 200/469 | Gen/Dis Loss: 0.8270/0.6310\n", "Epoch: 090/100 | Batch 300/469 | Gen/Dis Loss: 0.9313/0.6230\n", "Epoch: 090/100 | Batch 400/469 | Gen/Dis Loss: 0.9462/0.6391\n", "Time elapsed: 18.46 min\n", "Epoch: 091/100 | Batch 000/469 | Gen/Dis Loss: 0.9294/0.6189\n", "Epoch: 091/100 | Batch 100/469 | Gen/Dis Loss: 1.0533/0.6279\n", "Epoch: 091/100 | Batch 200/469 | Gen/Dis Loss: 0.9623/0.6491\n", "Epoch: 091/100 | Batch 300/469 | Gen/Dis Loss: 0.8521/0.6031\n", "Epoch: 091/100 | Batch 400/469 | Gen/Dis Loss: 0.8233/0.6487\n", "Time elapsed: 18.70 min\n", "Epoch: 092/100 | Batch 000/469 | Gen/Dis Loss: 0.9691/0.6357\n", "Epoch: 092/100 | Batch 100/469 | Gen/Dis Loss: 0.8876/0.6303\n", "Epoch: 092/100 | Batch 200/469 | Gen/Dis Loss: 0.9333/0.6201\n", "Epoch: 092/100 | Batch 300/469 | Gen/Dis Loss: 0.8813/0.5981\n", "Epoch: 092/100 | Batch 400/469 | Gen/Dis Loss: 0.9026/0.6128\n", "Time elapsed: 18.94 min\n", "Epoch: 093/100 | Batch 000/469 | Gen/Dis Loss: 0.8874/0.6373\n", "Epoch: 093/100 | Batch 100/469 | Gen/Dis Loss: 0.8537/0.6204\n", "Epoch: 093/100 | Batch 200/469 | Gen/Dis Loss: 0.7982/0.6342\n", "Epoch: 093/100 | Batch 300/469 | Gen/Dis Loss: 0.9005/0.6010\n", 
"Epoch: 093/100 | Batch 400/469 | Gen/Dis Loss: 1.0532/0.6091\n", "Time elapsed: 19.20 min\n", "Epoch: 094/100 | Batch 000/469 | Gen/Dis Loss: 0.9877/0.6426\n", "Epoch: 094/100 | Batch 100/469 | Gen/Dis Loss: 0.8308/0.6501\n", "Epoch: 094/100 | Batch 200/469 | Gen/Dis Loss: 0.9217/0.6269\n", "Epoch: 094/100 | Batch 300/469 | Gen/Dis Loss: 0.9183/0.6632\n", "Epoch: 094/100 | Batch 400/469 | Gen/Dis Loss: 0.8859/0.6128\n", "Time elapsed: 19.46 min\n", "Epoch: 095/100 | Batch 000/469 | Gen/Dis Loss: 0.9032/0.6331\n", "Epoch: 095/100 | Batch 100/469 | Gen/Dis Loss: 0.8298/0.6976\n", "Epoch: 095/100 | Batch 200/469 | Gen/Dis Loss: 1.0004/0.6347\n", "Epoch: 095/100 | Batch 300/469 | Gen/Dis Loss: 0.9161/0.6169\n", "Epoch: 095/100 | Batch 400/469 | Gen/Dis Loss: 0.7622/0.6884\n", "Time elapsed: 19.71 min\n", "Epoch: 096/100 | Batch 000/469 | Gen/Dis Loss: 0.8816/0.5997\n", "Epoch: 096/100 | Batch 100/469 | Gen/Dis Loss: 0.9499/0.5969\n", "Epoch: 096/100 | Batch 200/469 | Gen/Dis Loss: 0.8974/0.6214\n", "Epoch: 096/100 | Batch 300/469 | Gen/Dis Loss: 0.8853/0.6259\n", "Epoch: 096/100 | Batch 400/469 | Gen/Dis Loss: 0.8107/0.6027\n", "Time elapsed: 19.95 min\n", "Epoch: 097/100 | Batch 000/469 | Gen/Dis Loss: 0.9242/0.6189\n", "Epoch: 097/100 | Batch 100/469 | Gen/Dis Loss: 0.8917/0.6491\n", "Epoch: 097/100 | Batch 200/469 | Gen/Dis Loss: 0.8729/0.6375\n", "Epoch: 097/100 | Batch 300/469 | Gen/Dis Loss: 0.8848/0.5950\n", "Epoch: 097/100 | Batch 400/469 | Gen/Dis Loss: 0.8502/0.6296\n", "Time elapsed: 20.21 min\n", "Epoch: 098/100 | Batch 000/469 | Gen/Dis Loss: 0.9020/0.6453\n", "Epoch: 098/100 | Batch 100/469 | Gen/Dis Loss: 1.1077/0.5882\n", "Epoch: 098/100 | Batch 200/469 | Gen/Dis Loss: 0.9468/0.6364\n", "Epoch: 098/100 | Batch 300/469 | Gen/Dis Loss: 0.8636/0.6313\n", "Epoch: 098/100 | Batch 400/469 | Gen/Dis Loss: 0.9089/0.6911\n", "Time elapsed: 20.45 min\n", "Epoch: 099/100 | Batch 000/469 | Gen/Dis Loss: 0.9101/0.6386\n", "Epoch: 099/100 | Batch 100/469 | Gen/Dis 
Loss: 0.8036/0.6396\n", "Epoch: 099/100 | Batch 200/469 | Gen/Dis Loss: 0.9393/0.6060\n", "Epoch: 099/100 | Batch 300/469 | Gen/Dis Loss: 0.8776/0.6242\n", "Epoch: 099/100 | Batch 400/469 | Gen/Dis Loss: 0.8244/0.6278\n", "Time elapsed: 20.68 min\n", "Epoch: 100/100 | Batch 000/469 | Gen/Dis Loss: 0.8623/0.6496\n", "Epoch: 100/100 | Batch 100/469 | Gen/Dis Loss: 0.9965/0.5964\n", "Epoch: 100/100 | Batch 200/469 | Gen/Dis Loss: 0.8666/0.6306\n", "Epoch: 100/100 | Batch 300/469 | Gen/Dis Loss: 1.1555/0.6634\n", "Epoch: 100/100 | Batch 400/469 | Gen/Dis Loss: 0.9071/0.6545\n", "Time elapsed: 20.94 min\n", "Total Training Time: 20.94 min\n" ] } ], "source": [
    "start_time = time.time()\n",
    "\n",
    "discr_costs = []\n",
    "gener_costs = []\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model = model.train()\n",
    "    for batch_idx, (features, _) in enumerate(train_loader):\n",
    "\n",
    "        # ToTensor() yields pixels in [0, 1]; rescale them to [-1, 1] so real\n",
    "        # images share the range of the generator's Tanh output, then flatten.\n",
    "        features = (features - 0.5)*2.\n",
    "        features = features.view(-1, IMG_SIZE).to(device)\n",
    "        n_examples = features.size(0)  # the last batch can be smaller than BATCH_SIZE\n",
    "\n",
    "        # BCE targets: 1 for real images, 0 for generated ones\n",
    "        valid = torch.ones(n_examples, device=device)\n",
    "        fake = torch.zeros(n_examples, device=device)\n",
    "\n",
    "        ### FORWARD AND BACK PROP\n",
    "\n",
    "        # --------------------------\n",
    "        # Train Generator\n",
    "        # --------------------------\n",
    "\n",
    "        # Make new images from latent noise drawn from U(-1, 1)\n",
    "        z = torch.zeros((n_examples, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)\n",
    "        generated_features = model.generator_forward(z)\n",
    "\n",
    "        # Loss for fooling the discriminator: generated images are labeled as real\n",
    "        discr_pred = model.discriminator_forward(generated_features)\n",
    "        gener_loss = F.binary_cross_entropy(discr_pred, valid)\n",
    "\n",
    "        optim_gener.zero_grad()\n",
    "        gener_loss.backward()\n",
    "        optim_gener.step()\n",
    "\n",
    "        # --------------------------\n",
    "        # Train Discriminator\n",
    "        # --------------------------\n",
    "\n",
    "        # features was already flattened to (n_examples, IMG_SIZE) above\n",
    "        discr_pred_real = model.discriminator_forward(features)\n",
    "        real_loss = F.binary_cross_entropy(discr_pred_real, valid)\n",
    "\n",
    "        # detach() stops the discriminator loss from backpropagating into the generator\n",
    "        discr_pred_fake = model.discriminator_forward(generated_features.detach())\n",
    "        fake_loss = F.binary_cross_entropy(discr_pred_fake, fake)\n",
    "\n",
    "        discr_loss = 0.5*(real_loss + fake_loss)\n",
    "\n",
    "        optim_discr.zero_grad()\n",
    "        discr_loss.backward()\n",
    "        optim_discr.step()\n",
    "\n",
    "        # Record plain Python floats: appending the loss tensors themselves keeps\n",
    "        # every iteration's device tensor and its autograd history alive.\n",
    "        discr_costs.append(discr_loss.item())\n",
    "        gener_costs.append(gener_loss.item())\n",
    "\n",
    "        ### LOGGING\n",
    "        if not batch_idx % 100:\n",
    "            print('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f'\n",
    "                  % (epoch+1, NUM_EPOCHS, batch_idx,\n",
    "                     len(train_loader), gener_costs[-1], discr_costs[-1]))\n",
    "\n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "\n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ax1 = plt.subplot(1, 1, 1)\n", "ax1.plot(range(len(gener_costs)), gener_costs, label='Generator loss')\n", "ax1.plot(range(len(discr_costs)), discr_costs, label='Discriminator loss')\n", "ax1.set_xlabel('Iterations')\n", "ax1.set_ylabel('Loss')\n", "ax1.legend()\n", "\n", "###################\n", "# Set scond x-axis\n", "ax2 = ax1.twiny()\n", "newlabel = list(range(NUM_EPOCHS+1))\n", "iter_per_epoch = len(train_loader)\n", "newpos = [e*iter_per_epoch for e in newlabel]\n", "\n", "ax2.set_xticklabels(newlabel[::10])\n", "ax2.set_xticks(newpos[::10])\n", "\n", "ax2.xaxis.set_ticks_position('bottom')\n", "ax2.xaxis.set_label_position('bottom')\n", "ax2.spines['bottom'].set_position(('outward', 45))\n", "ax2.set_xlabel('Epochs')\n", "ax2.set_xlim(ax1.get_xlim())\n", "###################\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### VISUALIZATION\n", "##########################\n", "\n", "\n", "model.eval()\n", "# Make new images\n", "z = torch.zeros((5, LATENT_DIM)).uniform_(-1.0, 1.0).to(device)\n", "generated_features = model.generator_forward(z)\n", "imgs = generated_features.view(-1, 28, 28)\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(20, 2.5))\n", "\n", "\n", "for i, ax in enumerate(axes):\n", " axes[i].imshow(imgs[i].to(torch.device('cpu')).detach(), cmap='binary')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }