{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "torch 1.2.0\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Runs on CPU or GPU (if available)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Convolutional Wasserstein GAN" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of a deep convolutional Wasserstein GAN based on the paper \n", "\n", "- Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. arXiv preprint arXiv:1701.07875. (https://arxiv.org/abs/1701.07875)\n", "\n", "The main differences from a conventional deep convolutional GAN are annotated in the code. In short, the main differences are:\n", "\n", "1. Not using a sigmoid activation function and just using a linear output layer for the critic (i.e., discriminator).\n", "2. Using label -1 instead of 1 for the real images; using label 1 instead of 0 for fake images.\n", "3. Using the Wasserstein distance (loss) for training both the critic and the generator.\n", "4. After each weight update, clip the weights to be in the range [-0.01, 0.01] (`weight_clip_value = 0.01` in the settings below).\n", "5. 
Train the critic 5 times for each generator training update.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import time\n", "import numpy as np\n", "import torch\n", "import torch.nn.functional as F\n", "from torchvision import datasets\n", "from torchvision import transforms\n", "import torch.nn as nn\n", "from torch.utils.data import DataLoader\n", "\n", "\n", "if torch.cuda.is_available():\n", " torch.backends.cudnn.deterministic = True" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Settings and Dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Image batch dimensions: torch.Size([128, 1, 28, 28])\n", "Image label dimensions: torch.Size([128])\n" ] } ], "source": [ "##########################\n", "### SETTINGS\n", "##########################\n", "\n", "# Device\n", "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Hyperparameters\n", "random_seed = 0\n", "generator_learning_rate = 0.00005\n", "discriminator_learning_rate = 0.00005\n", "NUM_EPOCHS = 100\n", "BATCH_SIZE = 128\n", "LATENT_DIM = 100\n", "IMG_SHAPE = (1, 28, 28)\n", "IMG_SIZE = 1\n", "for x in IMG_SHAPE:\n", " IMG_SIZE *= x\n", "\n", "## WGAN-specific settings\n", "num_iter_critic = 5\n", "weight_clip_value = 0.01\n", "\n", "\n", "##########################\n", "### MNIST DATASET\n", "##########################\n", "\n", "# Note transforms.ToTensor() scales input images\n", "# to 0-1 range\n", "train_dataset = datasets.MNIST(root='data', \n", " train=True, \n", " transform=transforms.ToTensor(),\n", " download=True)\n", "\n", "test_dataset = datasets.MNIST(root='data', \n", " train=False, \n", " transform=transforms.ToTensor())\n", "\n", "\n", "train_loader = DataLoader(dataset=train_dataset, \n", " batch_size=BATCH_SIZE,\n", " 
num_workers=4,\n", " shuffle=True)\n", "\n", "test_loader = DataLoader(dataset=test_dataset, \n", " batch_size=BATCH_SIZE,\n", " num_workers=4,\n", " shuffle=False)\n", "\n", "# Checking the dataset\n", "for images, labels in train_loader: \n", " print('Image batch dimensions:', images.shape)\n", " print('Image label dimensions:', labels.shape)\n", " break" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "##########################\n", "### MODEL\n", "##########################\n", "\n", "class Flatten(nn.Module):\n", " def forward(self, input):\n", " return input.view(input.size(0), -1)\n", " \n", "class Reshape1(nn.Module):\n", " def forward(self, input):\n", " return input.view(input.size(0), 64, 7, 7)\n", "\n", "\n", "def wasserstein_loss(y_true, y_pred):\n", " return torch.mean(y_true * y_pred)\n", " \n", " \n", "class GAN(torch.nn.Module):\n", "\n", " def __init__(self):\n", " super(GAN, self).__init__()\n", " \n", " \n", " self.generator = nn.Sequential(\n", " \n", " nn.Linear(LATENT_DIM, 3136, bias=False),\n", " nn.BatchNorm1d(num_features=3136),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001),\n", " Reshape1(),\n", " \n", " nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False),\n", " nn.BatchNorm2d(num_features=32),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001),\n", " #nn.Dropout2d(p=0.2),\n", " \n", " nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 3), stride=(2, 2), padding=1, bias=False),\n", " nn.BatchNorm2d(num_features=16),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001),\n", " #nn.Dropout2d(p=0.2),\n", " \n", " nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=0, bias=False),\n", " nn.BatchNorm2d(num_features=8),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001),\n", " 
#nn.Dropout2d(p=0.2),\n", " \n", " nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=(2, 2), stride=(1, 1), padding=0, bias=False),\n", " nn.Tanh()\n", " )\n", " \n", " self.discriminator = nn.Sequential(\n", " nn.Conv2d(in_channels=1, out_channels=8, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False),\n", " nn.BatchNorm2d(num_features=8),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001), \n", " #nn.Dropout2d(p=0.2),\n", " \n", " nn.Conv2d(in_channels=8, out_channels=16, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False),\n", " nn.BatchNorm2d(num_features=16),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001), \n", " #nn.Dropout2d(p=0.2),\n", " \n", " nn.Conv2d(in_channels=16, out_channels=32, padding=1, kernel_size=(3, 3), stride=(2, 2), bias=False),\n", " nn.BatchNorm2d(num_features=32),\n", " nn.LeakyReLU(inplace=True, negative_slope=0.0001), \n", " #nn.Dropout2d(p=0.2),\n", " \n", " Flatten(),\n", "\n", " # WGAN critic: linear output layer, no sigmoid\n", " nn.Linear(512, 1),\n", " #nn.Sigmoid()\n", " )\n", "\n", " \n", " def generator_forward(self, z):\n", " img = self.generator(z)\n", " return img\n", " \n", " def discriminator_forward(self, img):\n", " # use this instance's critic rather than the global `model`\n", " pred = self.discriminator(img)\n", " return pred.view(-1)\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GAN(\n", " (generator): Sequential(\n", " (0): Linear(in_features=100, out_features=3136, bias=False)\n", " (1): BatchNorm1d(3136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (3): Reshape1()\n", " (4): ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (6): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (7): ConvTranspose2d(32, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " 
(8): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (9): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (10): ConvTranspose2d(16, 8, kernel_size=(3, 3), stride=(1, 1), bias=False)\n", " (11): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (12): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (13): ConvTranspose2d(8, 1, kernel_size=(2, 2), stride=(1, 1), bias=False)\n", " (14): Tanh()\n", " )\n", " (discriminator): Sequential(\n", " (0): Conv2d(1, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (2): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (3): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (4): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (5): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (6): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", " (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (8): LeakyReLU(negative_slope=0.0001, inplace=True)\n", " (9): Flatten()\n", " (10): Linear(in_features=512, out_features=1, bias=True)\n", " )\n", ")\n" ] } ], "source": [ "torch.manual_seed(random_seed)\n", "\n", "#del model\n", "model = GAN()\n", "model = model.to(device)\n", "\n", "print(model)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\noutputs = []\\ndef hook(module, input, output):\\n outputs.append(output)\\n\\nfor i, layer in enumerate(model.discriminator):\\n if isinstance(layer, torch.nn.modules.conv.Conv2d):\\n model.discriminator[i].register_forward_hook(hook)\\n\\n#for i, layer in enumerate(model.generator):\\n# if isinstance(layer, torch.nn.modules.ConvTranspose2d):\\n# model.generator[i].register_forward_hook(hook)\\n'" ] }, 
"execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "### ## FOR DEBUGGING\n", "\n", "\"\"\"\n", "outputs = []\n", "def hook(module, input, output):\n", " outputs.append(output)\n", "\n", "for i, layer in enumerate(model.discriminator):\n", " if isinstance(layer, torch.nn.modules.conv.Conv2d):\n", " model.discriminator[i].register_forward_hook(hook)\n", "\n", "#for i, layer in enumerate(model.generator):\n", "# if isinstance(layer, torch.nn.modules.ConvTranspose2d):\n", "# model.generator[i].register_forward_hook(hook)\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "optim_gener = torch.optim.RMSprop(model.generator.parameters(), lr=generator_learning_rate)\n", "optim_discr = torch.optim.RMSprop(model.discriminator.parameters(), lr=discriminator_learning_rate)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 001/100 | Batch 000/469 | Gen/Dis Loss: 0.3318/-0.0001\n", "Epoch: 001/100 | Batch 100/469 | Gen/Dis Loss: 0.0037/-0.0026\n", "Epoch: 001/100 | Batch 200/469 | Gen/Dis Loss: 0.0121/-0.0126\n", "Epoch: 001/100 | Batch 300/469 | Gen/Dis Loss: 0.0117/-0.0123\n", "Epoch: 001/100 | Batch 400/469 | Gen/Dis Loss: 0.0110/-0.0124\n", "Time elapsed: 0.31 min\n", "Epoch: 002/100 | Batch 000/469 | Gen/Dis Loss: 0.0123/-0.0140\n", "Epoch: 002/100 | Batch 100/469 | Gen/Dis Loss: 0.0124/-0.0136\n", "Epoch: 002/100 | Batch 200/469 | Gen/Dis Loss: 0.0108/-0.0126\n", "Epoch: 002/100 | Batch 300/469 | Gen/Dis Loss: 0.0089/-0.0104\n", "Epoch: 002/100 | Batch 400/469 | Gen/Dis Loss: 0.0093/-0.0108\n", "Time elapsed: 0.64 min\n", "Epoch: 003/100 | Batch 000/469 | Gen/Dis Loss: 0.0095/-0.0107\n", "Epoch: 003/100 | Batch 100/469 | Gen/Dis Loss: 0.0094/-0.0097\n", "Epoch: 003/100 | Batch 200/469 | Gen/Dis Loss: 0.0089/-0.0099\n", 
"Epoch: 003/100 | Batch 300/469 | Gen/Dis Loss: 0.0084/-0.0087\n", "Epoch: 003/100 | Batch 400/469 | Gen/Dis Loss: 0.0083/-0.0081\n", "Time elapsed: 1.12 min\n", "Epoch: 004/100 | Batch 000/469 | Gen/Dis Loss: 0.0071/-0.0080\n", "Epoch: 004/100 | Batch 100/469 | Gen/Dis Loss: 0.0077/-0.0076\n", "Epoch: 004/100 | Batch 200/469 | Gen/Dis Loss: 0.0090/-0.0070\n", "Epoch: 004/100 | Batch 300/469 | Gen/Dis Loss: 0.0079/-0.0082\n", "Epoch: 004/100 | Batch 400/469 | Gen/Dis Loss: 0.0101/-0.0072\n", "Time elapsed: 1.65 min\n", "Epoch: 005/100 | Batch 000/469 | Gen/Dis Loss: 0.0098/-0.0080\n", "Epoch: 005/100 | Batch 100/469 | Gen/Dis Loss: 0.0089/-0.0078\n", "Epoch: 005/100 | Batch 200/469 | Gen/Dis Loss: 0.0087/-0.0075\n", "Epoch: 005/100 | Batch 300/469 | Gen/Dis Loss: 0.0079/-0.0073\n", "Epoch: 005/100 | Batch 400/469 | Gen/Dis Loss: 0.0058/-0.0078\n", "Time elapsed: 2.15 min\n", "Epoch: 006/100 | Batch 000/469 | Gen/Dis Loss: 0.0048/-0.0071\n", "Epoch: 006/100 | Batch 100/469 | Gen/Dis Loss: 0.0050/-0.0070\n", "Epoch: 006/100 | Batch 200/469 | Gen/Dis Loss: 0.0046/-0.0069\n", "Epoch: 006/100 | Batch 300/469 | Gen/Dis Loss: 0.0060/-0.0069\n", "Epoch: 006/100 | Batch 400/469 | Gen/Dis Loss: 0.0067/-0.0067\n", "Time elapsed: 2.67 min\n", "Epoch: 007/100 | Batch 000/469 | Gen/Dis Loss: 0.0066/-0.0075\n", "Epoch: 007/100 | Batch 100/469 | Gen/Dis Loss: 0.0074/-0.0067\n", "Epoch: 007/100 | Batch 200/469 | Gen/Dis Loss: 0.0053/-0.0028\n", "Epoch: 007/100 | Batch 300/469 | Gen/Dis Loss: 0.0029/-0.0043\n", "Epoch: 007/100 | Batch 400/469 | Gen/Dis Loss: 0.0018/-0.0043\n", "Time elapsed: 3.20 min\n", "Epoch: 008/100 | Batch 000/469 | Gen/Dis Loss: 0.0025/-0.0040\n", "Epoch: 008/100 | Batch 100/469 | Gen/Dis Loss: 0.0015/-0.0034\n", "Epoch: 008/100 | Batch 200/469 | Gen/Dis Loss: -0.0001/-0.0023\n", "Epoch: 008/100 | Batch 300/469 | Gen/Dis Loss: 0.0014/-0.0017\n", "Epoch: 008/100 | Batch 400/469 | Gen/Dis Loss: -0.0003/-0.0022\n", "Time elapsed: 3.74 min\n", "Epoch: 009/100 | 
Batch 000/469 | Gen/Dis Loss: 0.0006/-0.0021\n", "Epoch: 009/100 | Batch 100/469 | Gen/Dis Loss: 0.0017/-0.0022\n", "Epoch: 009/100 | Batch 200/469 | Gen/Dis Loss: 0.0014/-0.0016\n", "Epoch: 009/100 | Batch 300/469 | Gen/Dis Loss: -0.0005/-0.0015\n", "Epoch: 009/100 | Batch 400/469 | Gen/Dis Loss: -0.0032/-0.0012\n", "Time elapsed: 4.25 min\n", "Epoch: 010/100 | Batch 000/469 | Gen/Dis Loss: -0.0036/-0.0015\n", "Epoch: 010/100 | Batch 100/469 | Gen/Dis Loss: -0.0000/-0.0015\n", "Epoch: 010/100 | Batch 200/469 | Gen/Dis Loss: -0.0024/-0.0009\n", "Epoch: 010/100 | Batch 300/469 | Gen/Dis Loss: -0.0010/-0.0012\n", "Epoch: 010/100 | Batch 400/469 | Gen/Dis Loss: 0.0012/-0.0015\n", "Time elapsed: 4.76 min\n", "Epoch: 011/100 | Batch 000/469 | Gen/Dis Loss: 0.0013/-0.0010\n", "Epoch: 011/100 | Batch 100/469 | Gen/Dis Loss: 0.0003/-0.0011\n", "Epoch: 011/100 | Batch 200/469 | Gen/Dis Loss: -0.0005/-0.0013\n", "Epoch: 011/100 | Batch 300/469 | Gen/Dis Loss: 0.0000/-0.0014\n", "Epoch: 011/100 | Batch 400/469 | Gen/Dis Loss: -0.0002/-0.0014\n", "Time elapsed: 5.26 min\n", "Epoch: 012/100 | Batch 000/469 | Gen/Dis Loss: -0.0000/-0.0012\n", "Epoch: 012/100 | Batch 100/469 | Gen/Dis Loss: 0.0009/-0.0010\n", "Epoch: 012/100 | Batch 200/469 | Gen/Dis Loss: -0.0001/-0.0011\n", "Epoch: 012/100 | Batch 300/469 | Gen/Dis Loss: -0.0016/-0.0010\n", "Epoch: 012/100 | Batch 400/469 | Gen/Dis Loss: -0.0021/-0.0010\n", "Time elapsed: 5.79 min\n", "Epoch: 013/100 | Batch 000/469 | Gen/Dis Loss: -0.0032/-0.0009\n", "Epoch: 013/100 | Batch 100/469 | Gen/Dis Loss: -0.0023/-0.0009\n", "Epoch: 013/100 | Batch 200/469 | Gen/Dis Loss: -0.0038/-0.0013\n", "Epoch: 013/100 | Batch 300/469 | Gen/Dis Loss: 0.0004/-0.0014\n", "Epoch: 013/100 | Batch 400/469 | Gen/Dis Loss: -0.0002/-0.0012\n", "Time elapsed: 6.30 min\n", "Epoch: 014/100 | Batch 000/469 | Gen/Dis Loss: -0.0007/-0.0011\n", "Epoch: 014/100 | Batch 100/469 | Gen/Dis Loss: -0.0009/-0.0012\n", "Epoch: 014/100 | Batch 200/469 | Gen/Dis Loss: 
-0.0007/-0.0010\n", "Epoch: 014/100 | Batch 300/469 | Gen/Dis Loss: -0.0002/-0.0009\n", "Epoch: 014/100 | Batch 400/469 | Gen/Dis Loss: -0.0009/-0.0008\n", "Time elapsed: 6.82 min\n", "Epoch: 015/100 | Batch 000/469 | Gen/Dis Loss: -0.0006/-0.0008\n", "Epoch: 015/100 | Batch 100/469 | Gen/Dis Loss: -0.0014/-0.0009\n", "Epoch: 015/100 | Batch 200/469 | Gen/Dis Loss: -0.0029/-0.0008\n", "Epoch: 015/100 | Batch 300/469 | Gen/Dis Loss: -0.0030/-0.0008\n", "Epoch: 015/100 | Batch 400/469 | Gen/Dis Loss: -0.0022/-0.0009\n", "Time elapsed: 7.32 min\n", "Epoch: 016/100 | Batch 000/469 | Gen/Dis Loss: -0.0015/-0.0010\n", "Epoch: 016/100 | Batch 100/469 | Gen/Dis Loss: -0.0013/-0.0008\n", "Epoch: 016/100 | Batch 200/469 | Gen/Dis Loss: -0.0011/-0.0008\n", "Epoch: 016/100 | Batch 300/469 | Gen/Dis Loss: -0.0008/-0.0007\n", "Epoch: 016/100 | Batch 400/469 | Gen/Dis Loss: -0.0023/-0.0008\n", "Time elapsed: 7.84 min\n", "Epoch: 017/100 | Batch 000/469 | Gen/Dis Loss: -0.0017/-0.0009\n", "Epoch: 017/100 | Batch 100/469 | Gen/Dis Loss: -0.0017/-0.0008\n", "Epoch: 017/100 | Batch 200/469 | Gen/Dis Loss: -0.0038/-0.0009\n", "Epoch: 017/100 | Batch 300/469 | Gen/Dis Loss: -0.0036/-0.0009\n", "Epoch: 017/100 | Batch 400/469 | Gen/Dis Loss: -0.0029/-0.0007\n", "Time elapsed: 8.39 min\n", "Epoch: 018/100 | Batch 000/469 | Gen/Dis Loss: -0.0024/-0.0009\n", "Epoch: 018/100 | Batch 100/469 | Gen/Dis Loss: -0.0029/-0.0008\n", "Epoch: 018/100 | Batch 200/469 | Gen/Dis Loss: -0.0029/-0.0007\n", "Epoch: 018/100 | Batch 300/469 | Gen/Dis Loss: -0.0014/-0.0007\n", "Epoch: 018/100 | Batch 400/469 | Gen/Dis Loss: -0.0017/-0.0008\n", "Time elapsed: 8.91 min\n", "Epoch: 019/100 | Batch 000/469 | Gen/Dis Loss: -0.0038/-0.0009\n", "Epoch: 019/100 | Batch 100/469 | Gen/Dis Loss: -0.0054/-0.0009\n", "Epoch: 019/100 | Batch 200/469 | Gen/Dis Loss: -0.0035/-0.0010\n", "Epoch: 019/100 | Batch 300/469 | Gen/Dis Loss: -0.0027/-0.0008\n", "Epoch: 019/100 | Batch 400/469 | Gen/Dis Loss: -0.0005/-0.0008\n", 
"Time elapsed: 9.44 min\n", "Epoch: 020/100 | Batch 000/469 | Gen/Dis Loss: -0.0005/-0.0006\n", "Epoch: 020/100 | Batch 100/469 | Gen/Dis Loss: -0.0010/-0.0005\n", "Epoch: 020/100 | Batch 200/469 | Gen/Dis Loss: -0.0012/-0.0006\n", "Epoch: 020/100 | Batch 300/469 | Gen/Dis Loss: -0.0038/-0.0007\n", "Epoch: 020/100 | Batch 400/469 | Gen/Dis Loss: -0.0041/-0.0008\n", "Time elapsed: 9.97 min\n", "Epoch: 021/100 | Batch 000/469 | Gen/Dis Loss: -0.0043/-0.0008\n", "Epoch: 021/100 | Batch 100/469 | Gen/Dis Loss: -0.0029/-0.0008\n", "Epoch: 021/100 | Batch 200/469 | Gen/Dis Loss: -0.0021/-0.0007\n", "Epoch: 021/100 | Batch 300/469 | Gen/Dis Loss: -0.0023/-0.0007\n", "Epoch: 021/100 | Batch 400/469 | Gen/Dis Loss: -0.0018/-0.0006\n", "Time elapsed: 10.47 min\n", "Epoch: 022/100 | Batch 000/469 | Gen/Dis Loss: -0.0014/-0.0006\n", "Epoch: 022/100 | Batch 100/469 | Gen/Dis Loss: -0.0033/-0.0007\n", "Epoch: 022/100 | Batch 200/469 | Gen/Dis Loss: -0.0007/-0.0005\n", "Epoch: 022/100 | Batch 300/469 | Gen/Dis Loss: 0.0003/-0.0007\n", "Epoch: 022/100 | Batch 400/469 | Gen/Dis Loss: -0.0019/-0.0006\n", "Time elapsed: 10.99 min\n", "Epoch: 023/100 | Batch 000/469 | Gen/Dis Loss: -0.0046/-0.0006\n", "Epoch: 023/100 | Batch 100/469 | Gen/Dis Loss: -0.0029/-0.0006\n", "Epoch: 023/100 | Batch 200/469 | Gen/Dis Loss: -0.0027/-0.0005\n", "Epoch: 023/100 | Batch 300/469 | Gen/Dis Loss: -0.0024/-0.0004\n", "Epoch: 023/100 | Batch 400/469 | Gen/Dis Loss: -0.0037/-0.0005\n", "Time elapsed: 11.48 min\n", "Epoch: 024/100 | Batch 000/469 | Gen/Dis Loss: -0.0032/-0.0005\n", "Epoch: 024/100 | Batch 100/469 | Gen/Dis Loss: -0.0027/-0.0006\n", "Epoch: 024/100 | Batch 200/469 | Gen/Dis Loss: -0.0013/-0.0006\n", "Epoch: 024/100 | Batch 300/469 | Gen/Dis Loss: -0.0010/-0.0006\n", "Epoch: 024/100 | Batch 400/469 | Gen/Dis Loss: -0.0025/-0.0007\n", "Time elapsed: 11.83 min\n", "Epoch: 025/100 | Batch 000/469 | Gen/Dis Loss: -0.0036/-0.0006\n", "Epoch: 025/100 | Batch 100/469 | Gen/Dis Loss: 
-0.0038/-0.0005\n", "Epoch: 025/100 | Batch 200/469 | Gen/Dis Loss: -0.0030/-0.0006\n", "Epoch: 025/100 | Batch 300/469 | Gen/Dis Loss: -0.0029/-0.0008\n", "Epoch: 025/100 | Batch 400/469 | Gen/Dis Loss: -0.0022/-0.0005\n", "Time elapsed: 12.14 min\n", "Epoch: 026/100 | Batch 000/469 | Gen/Dis Loss: -0.0010/-0.0005\n", "Epoch: 026/100 | Batch 100/469 | Gen/Dis Loss: -0.0030/-0.0005\n", "Epoch: 026/100 | Batch 200/469 | Gen/Dis Loss: -0.0002/-0.0005\n", "Epoch: 026/100 | Batch 300/469 | Gen/Dis Loss: -0.0004/-0.0004\n", "Epoch: 026/100 | Batch 400/469 | Gen/Dis Loss: 0.0006/-0.0005\n", "Time elapsed: 12.45 min\n", "Epoch: 027/100 | Batch 000/469 | Gen/Dis Loss: -0.0004/-0.0004\n", "Epoch: 027/100 | Batch 100/469 | Gen/Dis Loss: -0.0005/-0.0005\n", "Epoch: 027/100 | Batch 200/469 | Gen/Dis Loss: -0.0029/-0.0006\n", "Epoch: 027/100 | Batch 300/469 | Gen/Dis Loss: -0.0031/-0.0005\n", "Epoch: 027/100 | Batch 400/469 | Gen/Dis Loss: -0.0033/-0.0006\n", "Time elapsed: 12.76 min\n", "Epoch: 028/100 | Batch 000/469 | Gen/Dis Loss: 0.0026/-0.0005\n", "Epoch: 028/100 | Batch 100/469 | Gen/Dis Loss: 0.0000/-0.0006\n", "Epoch: 028/100 | Batch 200/469 | Gen/Dis Loss: 0.0007/-0.0002\n", "Epoch: 028/100 | Batch 300/469 | Gen/Dis Loss: -0.0001/-0.0004\n", "Epoch: 028/100 | Batch 400/469 | Gen/Dis Loss: 0.0024/-0.0005\n", "Time elapsed: 13.06 min\n", "Epoch: 029/100 | Batch 000/469 | Gen/Dis Loss: 0.0015/-0.0005\n", "Epoch: 029/100 | Batch 100/469 | Gen/Dis Loss: 0.0006/-0.0004\n", "Epoch: 029/100 | Batch 200/469 | Gen/Dis Loss: 0.0006/-0.0003\n", "Epoch: 029/100 | Batch 300/469 | Gen/Dis Loss: -0.0056/-0.0002\n", "Epoch: 029/100 | Batch 400/469 | Gen/Dis Loss: 0.0086/-0.0007\n", "Time elapsed: 13.36 min\n", "Epoch: 030/100 | Batch 000/469 | Gen/Dis Loss: 0.0015/-0.0006\n", "Epoch: 030/100 | Batch 100/469 | Gen/Dis Loss: -0.0056/-0.0008\n", "Epoch: 030/100 | Batch 200/469 | Gen/Dis Loss: 0.0057/-0.0007\n", "Epoch: 030/100 | Batch 300/469 | Gen/Dis Loss: -0.0112/-0.0001\n", "Epoch: 
030/100 | Batch 400/469 | Gen/Dis Loss: 0.0086/-0.0005\n", "Time elapsed: 13.67 min\n", "Epoch: 031/100 | Batch 000/469 | Gen/Dis Loss: 0.0026/-0.0005\n", "Epoch: 031/100 | Batch 100/469 | Gen/Dis Loss: 0.0044/-0.0002\n", "Epoch: 031/100 | Batch 200/469 | Gen/Dis Loss: 0.0021/-0.0003\n", "Epoch: 031/100 | Batch 300/469 | Gen/Dis Loss: 0.0005/-0.0004\n", "Epoch: 031/100 | Batch 400/469 | Gen/Dis Loss: 0.0001/-0.0005\n", "Time elapsed: 13.98 min\n", "Epoch: 032/100 | Batch 000/469 | Gen/Dis Loss: 0.0011/-0.0005\n", "Epoch: 032/100 | Batch 100/469 | Gen/Dis Loss: 0.0046/-0.0008\n", "Epoch: 032/100 | Batch 200/469 | Gen/Dis Loss: 0.0025/-0.0007\n", "Epoch: 032/100 | Batch 300/469 | Gen/Dis Loss: 0.0029/-0.0005\n", "Epoch: 032/100 | Batch 400/469 | Gen/Dis Loss: 0.0069/-0.0007\n", "Time elapsed: 14.29 min\n", "Epoch: 033/100 | Batch 000/469 | Gen/Dis Loss: 0.0048/-0.0006\n", "Epoch: 033/100 | Batch 100/469 | Gen/Dis Loss: 0.0011/-0.0005\n", "Epoch: 033/100 | Batch 200/469 | Gen/Dis Loss: 0.0008/-0.0003\n", "Epoch: 033/100 | Batch 300/469 | Gen/Dis Loss: 0.0039/-0.0006\n", "Epoch: 033/100 | Batch 400/469 | Gen/Dis Loss: 0.0039/-0.0004\n", "Time elapsed: 14.60 min\n", "Epoch: 034/100 | Batch 000/469 | Gen/Dis Loss: 0.0010/-0.0005\n", "Epoch: 034/100 | Batch 100/469 | Gen/Dis Loss: 0.0001/-0.0004\n", "Epoch: 034/100 | Batch 200/469 | Gen/Dis Loss: 0.0026/-0.0004\n", "Epoch: 034/100 | Batch 300/469 | Gen/Dis Loss: 0.0008/-0.0003\n", "Epoch: 034/100 | Batch 400/469 | Gen/Dis Loss: 0.0043/-0.0005\n", "Time elapsed: 14.90 min\n", "Epoch: 035/100 | Batch 000/469 | Gen/Dis Loss: 0.0033/-0.0004\n", "Epoch: 035/100 | Batch 100/469 | Gen/Dis Loss: 0.0017/-0.0002\n", "Epoch: 035/100 | Batch 200/469 | Gen/Dis Loss: 0.0012/-0.0004\n", "Epoch: 035/100 | Batch 300/469 | Gen/Dis Loss: 0.0013/-0.0004\n", "Epoch: 035/100 | Batch 400/469 | Gen/Dis Loss: 0.0018/-0.0003\n", "Time elapsed: 15.20 min\n", "Epoch: 036/100 | Batch 000/469 | Gen/Dis Loss: 0.0046/-0.0004\n", "Epoch: 036/100 | Batch 
100/469 | Gen/Dis Loss: 0.0046/-0.0004\n", "Epoch: 036/100 | Batch 200/469 | Gen/Dis Loss: 0.0022/-0.0004\n", "Epoch: 036/100 | Batch 300/469 | Gen/Dis Loss: 0.0007/-0.0002\n", "Epoch: 036/100 | Batch 400/469 | Gen/Dis Loss: 0.0027/-0.0003\n", "Time elapsed: 15.51 min\n", "Epoch: 037/100 | Batch 000/469 | Gen/Dis Loss: 0.0006/-0.0004\n", "Epoch: 037/100 | Batch 100/469 | Gen/Dis Loss: 0.0016/-0.0004\n", "Epoch: 037/100 | Batch 200/469 | Gen/Dis Loss: -0.0014/-0.0003\n", "Epoch: 037/100 | Batch 300/469 | Gen/Dis Loss: 0.0015/-0.0004\n", "Epoch: 037/100 | Batch 400/469 | Gen/Dis Loss: 0.0015/-0.0002\n", "Time elapsed: 15.82 min\n", "Epoch: 038/100 | Batch 000/469 | Gen/Dis Loss: 0.0013/-0.0003\n", "Epoch: 038/100 | Batch 100/469 | Gen/Dis Loss: 0.0011/-0.0002\n", "Epoch: 038/100 | Batch 200/469 | Gen/Dis Loss: 0.0023/-0.0003\n", "Epoch: 038/100 | Batch 300/469 | Gen/Dis Loss: 0.0008/-0.0003\n", "Epoch: 038/100 | Batch 400/469 | Gen/Dis Loss: 0.0031/-0.0003\n", "Time elapsed: 16.28 min\n", "Epoch: 039/100 | Batch 000/469 | Gen/Dis Loss: 0.0041/-0.0002\n", "Epoch: 039/100 | Batch 100/469 | Gen/Dis Loss: 0.0047/-0.0002\n", "Epoch: 039/100 | Batch 200/469 | Gen/Dis Loss: 0.0040/-0.0002\n", "Epoch: 039/100 | Batch 300/469 | Gen/Dis Loss: 0.0051/-0.0003\n", "Epoch: 039/100 | Batch 400/469 | Gen/Dis Loss: 0.0094/-0.0003\n", "Time elapsed: 16.76 min\n", "Epoch: 040/100 | Batch 000/469 | Gen/Dis Loss: 0.0061/-0.0002\n", "Epoch: 040/100 | Batch 100/469 | Gen/Dis Loss: 0.0054/-0.0003\n", "Epoch: 040/100 | Batch 200/469 | Gen/Dis Loss: 0.0064/-0.0002\n", "Epoch: 040/100 | Batch 300/469 | Gen/Dis Loss: 0.0075/-0.0001\n", "Epoch: 040/100 | Batch 400/469 | Gen/Dis Loss: 0.0066/-0.0002\n", "Time elapsed: 17.29 min\n", "Epoch: 041/100 | Batch 000/469 | Gen/Dis Loss: 0.0054/-0.0002\n", "Epoch: 041/100 | Batch 100/469 | Gen/Dis Loss: 0.0021/-0.0002\n", "Epoch: 041/100 | Batch 200/469 | Gen/Dis Loss: 0.0018/-0.0002\n", "Epoch: 041/100 | Batch 300/469 | Gen/Dis Loss: -0.0017/-0.0001\n", 
"Epoch: 041/100 | Batch 400/469 | Gen/Dis Loss: 0.0028/-0.0002\n", "Time elapsed: 17.79 min\n", "Epoch: 042/100 | Batch 000/469 | Gen/Dis Loss: 0.0041/-0.0002\n", "Epoch: 042/100 | Batch 100/469 | Gen/Dis Loss: 0.0044/-0.0003\n", "Epoch: 042/100 | Batch 200/469 | Gen/Dis Loss: -0.0007/-0.0002\n", "Epoch: 042/100 | Batch 300/469 | Gen/Dis Loss: -0.0091/0.0000\n", "Epoch: 042/100 | Batch 400/469 | Gen/Dis Loss: -0.0043/0.0001\n", "Time elapsed: 18.34 min\n", "Epoch: 043/100 | Batch 000/469 | Gen/Dis Loss: 0.0023/-0.0008\n", "Epoch: 043/100 | Batch 100/469 | Gen/Dis Loss: -0.0025/-0.0006\n", "Epoch: 043/100 | Batch 200/469 | Gen/Dis Loss: -0.0066/-0.0006\n", "Epoch: 043/100 | Batch 300/469 | Gen/Dis Loss: -0.0074/0.0004\n", "Epoch: 043/100 | Batch 400/469 | Gen/Dis Loss: -0.0021/-0.0018\n", "Time elapsed: 18.86 min\n", "Epoch: 044/100 | Batch 000/469 | Gen/Dis Loss: -0.0033/0.0001\n", "Epoch: 044/100 | Batch 100/469 | Gen/Dis Loss: 0.0019/0.0003\n", "Epoch: 044/100 | Batch 200/469 | Gen/Dis Loss: 0.0014/-0.0021\n", "Epoch: 044/100 | Batch 300/469 | Gen/Dis Loss: -0.0003/-0.0016\n", "Epoch: 044/100 | Batch 400/469 | Gen/Dis Loss: -0.0072/-0.0009\n", "Time elapsed: 19.35 min\n", "Epoch: 045/100 | Batch 000/469 | Gen/Dis Loss: -0.0013/-0.0012\n", "Epoch: 045/100 | Batch 100/469 | Gen/Dis Loss: -0.0002/-0.0002\n", "Epoch: 045/100 | Batch 200/469 | Gen/Dis Loss: 0.0008/0.0005\n", "Epoch: 045/100 | Batch 300/469 | Gen/Dis Loss: 0.0056/-0.0006\n", "Epoch: 045/100 | Batch 400/469 | Gen/Dis Loss: -0.0134/0.0001\n", "Time elapsed: 19.88 min\n", "Epoch: 046/100 | Batch 000/469 | Gen/Dis Loss: -0.0147/0.0003\n", "Epoch: 046/100 | Batch 100/469 | Gen/Dis Loss: 0.0120/0.0002\n", "Epoch: 046/100 | Batch 200/469 | Gen/Dis Loss: -0.0061/-0.0006\n", "Epoch: 046/100 | Batch 300/469 | Gen/Dis Loss: 0.0007/-0.0012\n", "Epoch: 046/100 | Batch 400/469 | Gen/Dis Loss: -0.0118/0.0007\n", "Time elapsed: 20.40 min\n", "Epoch: 047/100 | Batch 000/469 | Gen/Dis Loss: 0.0015/-0.0018\n", "Epoch: 
047/100 | Batch 100/469 | Gen/Dis Loss: -0.0118/-0.0000\n", "Epoch: 047/100 | Batch 200/469 | Gen/Dis Loss: 0.0048/0.0009\n", "Epoch: 047/100 | Batch 300/469 | Gen/Dis Loss: -0.0124/-0.0005\n", "Epoch: 047/100 | Batch 400/469 | Gen/Dis Loss: -0.0039/0.0002\n", "Time elapsed: 20.91 min\n", "Epoch: 048/100 | Batch 000/469 | Gen/Dis Loss: 0.0008/-0.0021\n", "Epoch: 048/100 | Batch 100/469 | Gen/Dis Loss: -0.0005/-0.0018\n", "Epoch: 048/100 | Batch 200/469 | Gen/Dis Loss: 0.0010/-0.0005\n", "Epoch: 048/100 | Batch 300/469 | Gen/Dis Loss: 0.0115/0.0001\n", "Epoch: 048/100 | Batch 400/469 | Gen/Dis Loss: 0.0111/-0.0002\n", "Time elapsed: 21.40 min\n", "Epoch: 049/100 | Batch 000/469 | Gen/Dis Loss: -0.0005/-0.0015\n", "Epoch: 049/100 | Batch 100/469 | Gen/Dis Loss: 0.0011/0.0006\n", "Epoch: 049/100 | Batch 200/469 | Gen/Dis Loss: -0.0071/-0.0001\n", "Epoch: 049/100 | Batch 300/469 | Gen/Dis Loss: -0.0178/0.0002\n", "Epoch: 049/100 | Batch 400/469 | Gen/Dis Loss: 0.0072/-0.0016\n", "Time elapsed: 21.93 min\n", "Epoch: 050/100 | Batch 000/469 | Gen/Dis Loss: -0.0129/0.0002\n", "Epoch: 050/100 | Batch 100/469 | Gen/Dis Loss: 0.0003/-0.0013\n", "Epoch: 050/100 | Batch 200/469 | Gen/Dis Loss: -0.0002/-0.0005\n", "Epoch: 050/100 | Batch 300/469 | Gen/Dis Loss: -0.0052/-0.0002\n", "Epoch: 050/100 | Batch 400/469 | Gen/Dis Loss: -0.0026/0.0008\n", "Time elapsed: 22.46 min\n", "Epoch: 051/100 | Batch 000/469 | Gen/Dis Loss: -0.0113/0.0000\n", "Epoch: 051/100 | Batch 100/469 | Gen/Dis Loss: -0.0013/-0.0012\n", "Epoch: 051/100 | Batch 200/469 | Gen/Dis Loss: 0.0119/0.0004\n", "Epoch: 051/100 | Batch 300/469 | Gen/Dis Loss: -0.0066/0.0003\n", "Epoch: 051/100 | Batch 400/469 | Gen/Dis Loss: -0.0078/-0.0013\n", "Time elapsed: 22.98 min\n", "Epoch: 052/100 | Batch 000/469 | Gen/Dis Loss: 0.0033/-0.0019\n", "Epoch: 052/100 | Batch 100/469 | Gen/Dis Loss: -0.0053/-0.0001\n", "Epoch: 052/100 | Batch 200/469 | Gen/Dis Loss: 0.0040/0.0007\n", "Epoch: 052/100 | Batch 300/469 | Gen/Dis Loss: 
0.0033/-0.0021\n", "Time elapsed: 23.50 min\n", "Epoch: 053/100 | Batch 000/469 | Gen/Dis Loss: -0.0034/-0.0008\n", "Epoch: 053/100 | Batch 100/469 | Gen/Dis Loss: 0.0042/-0.0001\n", "Epoch: 053/100 | Batch 200/469 | Gen/Dis Loss: 0.0124/-0.0001\n", "Epoch: 053/100 | Batch 300/469 | Gen/Dis Loss: -0.0017/0.0004\n", "Epoch: 053/100 | Batch 400/469 | Gen/Dis Loss: -0.0066/0.0002\n", "Time elapsed: 24.02 min\n", "Epoch: 054/100 | Batch 000/469 | Gen/Dis Loss: 0.0108/-0.0000\n", "Epoch: 054/100 | Batch 100/469 | Gen/Dis Loss: -0.0019/-0.0019\n", "Epoch: 054/100 | Batch 200/469 | Gen/Dis Loss: 0.0058/-0.0010\n", "Epoch: 054/100 | Batch 300/469 | Gen/Dis Loss: 0.0080/-0.0002\n", "Epoch: 054/100 | Batch 400/469 | Gen/Dis Loss: -0.0115/-0.0003\n", "Time elapsed: 24.55 min\n", "Epoch: 055/100 | Batch 000/469 | Gen/Dis Loss: 0.0126/-0.0001\n", "Epoch: 055/100 | Batch 100/469 | Gen/Dis Loss: 0.0151/-0.0007\n", "Epoch: 055/100 | Batch 200/469 | Gen/Dis Loss: -0.0005/0.0007\n", "Epoch: 055/100 | Batch 300/469 | Gen/Dis Loss: 0.0079/-0.0014\n", "Epoch: 055/100 | Batch 400/469 | Gen/Dis Loss: -0.0089/-0.0005\n", "Time elapsed: 25.07 min\n", "Epoch: 056/100 | Batch 000/469 | Gen/Dis Loss: -0.0097/0.0002\n", "Epoch: 056/100 | Batch 100/469 | Gen/Dis Loss: -0.0038/0.0010\n", "Epoch: 056/100 | Batch 200/469 | Gen/Dis Loss: -0.0095/0.0006\n", "Epoch: 056/100 | Batch 300/469 | Gen/Dis Loss: -0.0044/-0.0008\n", "Epoch: 056/100 | Batch 400/469 | Gen/Dis Loss: -0.0044/-0.0016\n", "Time elapsed: 25.58 min\n", "Epoch: 057/100 | Batch 000/469 | Gen/Dis Loss: -0.0152/-0.0004\n", "Epoch: 057/100 | Batch 100/469 | Gen/Dis Loss: 0.0012/-0.0002\n", "Epoch: 057/100 | Batch 200/469 | Gen/Dis Loss: -0.0033/-0.0004\n", "Epoch: 057/100 | Batch 300/469 | Gen/Dis Loss: 0.0100/-0.0000\n", "Epoch: 057/100 | Batch 400/469 | Gen/Dis Loss: -0.0003/-0.0003\n", "Time elapsed: 26.10 min\n", "Epoch: 058/100 | Batch 000/469 | Gen/Dis Loss: -0.0068/-0.0003\n", "Epoch: 058/100 | Batch 100/469 | Gen/Dis Loss: 
-0.0049/0.0001\n", "Epoch: 058/100 | Batch 200/469 | Gen/Dis Loss: 0.0008/0.0003\n", "Epoch: 058/100 | Batch 300/469 | Gen/Dis Loss: -0.0002/-0.0016\n", "Epoch: 058/100 | Batch 400/469 | Gen/Dis Loss: -0.0016/-0.0004\n", "Time elapsed: 26.57 min\n", "Epoch: 059/100 | Batch 000/469 | Gen/Dis Loss: -0.0093/-0.0001\n", "Epoch: 059/100 | Batch 100/469 | Gen/Dis Loss: 0.0033/-0.0002\n", "Epoch: 059/100 | Batch 200/469 | Gen/Dis Loss: 0.0009/-0.0004\n", "Epoch: 059/100 | Batch 300/469 | Gen/Dis Loss: -0.0142/-0.0001\n", "Epoch: 059/100 | Batch 400/469 | Gen/Dis Loss: -0.0129/0.0001\n", "Time elapsed: 26.97 min\n", "Epoch: 060/100 | Batch 000/469 | Gen/Dis Loss: -0.0021/-0.0009\n", "Epoch: 060/100 | Batch 100/469 | Gen/Dis Loss: 0.0020/-0.0002\n", "Epoch: 060/100 | Batch 200/469 | Gen/Dis Loss: -0.0099/-0.0010\n", "Epoch: 060/100 | Batch 300/469 | Gen/Dis Loss: -0.0112/0.0001\n", "Epoch: 060/100 | Batch 400/469 | Gen/Dis Loss: -0.0024/-0.0004\n", "Time elapsed: 27.28 min\n", "Epoch: 061/100 | Batch 000/469 | Gen/Dis Loss: -0.0044/-0.0012\n", "Epoch: 061/100 | Batch 100/469 | Gen/Dis Loss: -0.0034/-0.0005\n", "Epoch: 061/100 | Batch 200/469 | Gen/Dis Loss: -0.0031/-0.0009\n", "Epoch: 061/100 | Batch 300/469 | Gen/Dis Loss: -0.0058/-0.0000\n", "Epoch: 061/100 | Batch 400/469 | Gen/Dis Loss: -0.0034/0.0003\n", "Time elapsed: 27.59 min\n", "Epoch: 062/100 | Batch 000/469 | Gen/Dis Loss: -0.0017/-0.0014\n", "Epoch: 062/100 | Batch 100/469 | Gen/Dis Loss: -0.0043/0.0000\n", "Epoch: 062/100 | Batch 200/469 | Gen/Dis Loss: -0.0055/0.0001\n", "Epoch: 062/100 | Batch 300/469 | Gen/Dis Loss: 0.0066/-0.0009\n", "Epoch: 062/100 | Batch 400/469 | Gen/Dis Loss: -0.0013/0.0007\n", "Time elapsed: 27.89 min\n", "Epoch: 063/100 | Batch 000/469 | Gen/Dis Loss: 0.0031/-0.0013\n", "Epoch: 063/100 | Batch 100/469 | Gen/Dis Loss: -0.0093/0.0000\n", "Epoch: 063/100 | Batch 200/469 | Gen/Dis Loss: -0.0041/-0.0003\n", "Epoch: 063/100 | Batch 300/469 | Gen/Dis Loss: 0.0093/-0.0003\n", "Epoch: 
063/100 | Batch 400/469 | Gen/Dis Loss: 0.0140/0.0003\n", "Time elapsed: 28.20 min\n", "Epoch: 064/100 | Batch 000/469 | Gen/Dis Loss: 0.0119/0.0003\n", "Epoch: 064/100 | Batch 100/469 | Gen/Dis Loss: 0.0124/-0.0009\n", "Epoch: 064/100 | Batch 200/469 | Gen/Dis Loss: -0.0076/0.0001\n", "Epoch: 064/100 | Batch 300/469 | Gen/Dis Loss: -0.0027/-0.0002\n", "Epoch: 064/100 | Batch 400/469 | Gen/Dis Loss: 0.0005/0.0002\n", "Time elapsed: 28.51 min\n", "Epoch: 065/100 | Batch 000/469 | Gen/Dis Loss: -0.0002/0.0005\n", "Epoch: 065/100 | Batch 100/469 | Gen/Dis Loss: 0.0005/-0.0006\n", "Epoch: 065/100 | Batch 200/469 | Gen/Dis Loss: 0.0089/0.0001\n", "Epoch: 065/100 | Batch 300/469 | Gen/Dis Loss: -0.0060/-0.0007\n", "Epoch: 065/100 | Batch 400/469 | Gen/Dis Loss: 0.0010/-0.0006\n", "Time elapsed: 28.82 min\n", "Epoch: 066/100 | Batch 000/469 | Gen/Dis Loss: 0.0078/0.0003\n", "Epoch: 066/100 | Batch 100/469 | Gen/Dis Loss: 0.0001/0.0003\n", "Epoch: 066/100 | Batch 200/469 | Gen/Dis Loss: -0.0047/-0.0001\n", "Epoch: 066/100 | Batch 300/469 | Gen/Dis Loss: 0.0067/-0.0005\n", "Epoch: 066/100 | Batch 400/469 | Gen/Dis Loss: 0.0030/-0.0004\n", "Time elapsed: 29.13 min\n", "Epoch: 067/100 | Batch 000/469 | Gen/Dis Loss: 0.0048/-0.0002\n", "Epoch: 067/100 | Batch 100/469 | Gen/Dis Loss: 0.0046/-0.0004\n", "Epoch: 067/100 | Batch 200/469 | Gen/Dis Loss: -0.0027/-0.0002\n", "Epoch: 067/100 | Batch 300/469 | Gen/Dis Loss: 0.0040/-0.0005\n", "Epoch: 067/100 | Batch 400/469 | Gen/Dis Loss: 0.0062/-0.0004\n", "Time elapsed: 29.44 min\n", "Epoch: 068/100 | Batch 000/469 | Gen/Dis Loss: 0.0033/0.0000\n", "Epoch: 068/100 | Batch 100/469 | Gen/Dis Loss: 0.0000/-0.0007\n", "Epoch: 068/100 | Batch 200/469 | Gen/Dis Loss: -0.0038/0.0001\n", "Epoch: 068/100 | Batch 300/469 | Gen/Dis Loss: -0.0047/-0.0000\n", "Epoch: 068/100 | Batch 400/469 | Gen/Dis Loss: 0.0131/-0.0001\n", "Time elapsed: 29.74 min\n", "Epoch: 069/100 | Batch 000/469 | Gen/Dis Loss: 0.0051/-0.0004\n", "Epoch: 069/100 | Batch 
100/469 | Gen/Dis Loss: 0.0027/-0.0002\n", "Epoch: 069/100 | Batch 200/469 | Gen/Dis Loss: -0.0019/-0.0005\n", "Epoch: 069/100 | Batch 300/469 | Gen/Dis Loss: 0.0072/-0.0000\n", "Epoch: 069/100 | Batch 400/469 | Gen/Dis Loss: -0.0068/0.0002\n", "Time elapsed: 30.05 min\n", "Epoch: 070/100 | Batch 000/469 | Gen/Dis Loss: 0.0098/-0.0007\n", "Epoch: 070/100 | Batch 100/469 | Gen/Dis Loss: -0.0070/-0.0001\n", "Epoch: 070/100 | Batch 200/469 | Gen/Dis Loss: 0.0102/-0.0005\n", "Epoch: 070/100 | Batch 300/469 | Gen/Dis Loss: -0.0098/0.0000\n", "Epoch: 070/100 | Batch 400/469 | Gen/Dis Loss: -0.0121/-0.0001\n", "Time elapsed: 30.37 min\n", "Epoch: 071/100 | Batch 000/469 | Gen/Dis Loss: 0.0081/0.0000\n", "Epoch: 071/100 | Batch 100/469 | Gen/Dis Loss: 0.0028/-0.0002\n", "Epoch: 071/100 | Batch 200/469 | Gen/Dis Loss: -0.0082/-0.0002\n", "Epoch: 071/100 | Batch 300/469 | Gen/Dis Loss: 0.0113/-0.0003\n", "Epoch: 071/100 | Batch 400/469 | Gen/Dis Loss: 0.0028/-0.0009\n", "Time elapsed: 30.68 min\n", "Epoch: 072/100 | Batch 000/469 | Gen/Dis Loss: 0.0087/-0.0003\n", "Epoch: 072/100 | Batch 100/469 | Gen/Dis Loss: -0.0092/0.0004\n", "Epoch: 072/100 | Batch 200/469 | Gen/Dis Loss: 0.0016/-0.0002\n", "Epoch: 072/100 | Batch 300/469 | Gen/Dis Loss: 0.0059/-0.0011\n", "Epoch: 072/100 | Batch 400/469 | Gen/Dis Loss: -0.0058/-0.0004\n", "Time elapsed: 30.99 min\n", "Epoch: 073/100 | Batch 000/469 | Gen/Dis Loss: -0.0079/-0.0006\n", "Epoch: 073/100 | Batch 100/469 | Gen/Dis Loss: 0.0076/-0.0001\n", "Epoch: 073/100 | Batch 200/469 | Gen/Dis Loss: -0.0003/-0.0004\n", "Epoch: 073/100 | Batch 300/469 | Gen/Dis Loss: 0.0090/-0.0000\n", "Epoch: 073/100 | Batch 400/469 | Gen/Dis Loss: 0.0064/-0.0003\n", "Time elapsed: 31.29 min\n", "Epoch: 074/100 | Batch 000/469 | Gen/Dis Loss: -0.0062/0.0002\n", "Epoch: 074/100 | Batch 100/469 | Gen/Dis Loss: 0.0074/0.0004\n", "Epoch: 074/100 | Batch 200/469 | Gen/Dis Loss: 0.0034/-0.0004\n", "Epoch: 074/100 | Batch 300/469 | Gen/Dis Loss: 
-0.0032/0.0000\n", "Epoch: 074/100 | Batch 400/469 | Gen/Dis Loss: 0.0045/-0.0016\n", "Time elapsed: 31.61 min\n", "Epoch: 075/100 | Batch 000/469 | Gen/Dis Loss: 0.0067/-0.0018\n", "Epoch: 075/100 | Batch 100/469 | Gen/Dis Loss: -0.0029/-0.0007\n", "Epoch: 075/100 | Batch 200/469 | Gen/Dis Loss: -0.0014/-0.0001\n", "Epoch: 075/100 | Batch 300/469 | Gen/Dis Loss: -0.0001/-0.0013\n", "Epoch: 075/100 | Batch 400/469 | Gen/Dis Loss: -0.0023/-0.0006\n", "Time elapsed: 31.91 min\n", "Epoch: 076/100 | Batch 000/469 | Gen/Dis Loss: 0.0036/-0.0008\n", "Epoch: 076/100 | Batch 100/469 | Gen/Dis Loss: -0.0003/-0.0001\n", "Epoch: 076/100 | Batch 200/469 | Gen/Dis Loss: 0.0024/-0.0001\n", "Epoch: 076/100 | Batch 300/469 | Gen/Dis Loss: 0.0006/-0.0003\n", "Epoch: 076/100 | Batch 400/469 | Gen/Dis Loss: -0.0000/0.0000\n", "Time elapsed: 32.23 min\n", "Epoch: 077/100 | Batch 000/469 | Gen/Dis Loss: -0.0022/0.0005\n", "Epoch: 077/100 | Batch 100/469 | Gen/Dis Loss: 0.0091/-0.0000\n", "Epoch: 077/100 | Batch 200/469 | Gen/Dis Loss: 0.0090/-0.0004\n", "Epoch: 077/100 | Batch 300/469 | Gen/Dis Loss: -0.0045/-0.0001\n", "Epoch: 077/100 | Batch 400/469 | Gen/Dis Loss: 0.0035/0.0006\n", "Time elapsed: 32.53 min\n", "Epoch: 078/100 | Batch 000/469 | Gen/Dis Loss: 0.0089/0.0001\n", "Epoch: 078/100 | Batch 100/469 | Gen/Dis Loss: 0.0075/-0.0003\n", "Epoch: 078/100 | Batch 200/469 | Gen/Dis Loss: -0.0023/-0.0014\n", "Epoch: 078/100 | Batch 300/469 | Gen/Dis Loss: 0.0030/-0.0012\n", "Epoch: 078/100 | Batch 400/469 | Gen/Dis Loss: -0.0115/0.0000\n", "Time elapsed: 32.84 min\n", "Epoch: 079/100 | Batch 000/469 | Gen/Dis Loss: -0.0055/0.0006\n", "Epoch: 079/100 | Batch 100/469 | Gen/Dis Loss: -0.0082/-0.0001\n", "Epoch: 079/100 | Batch 200/469 | Gen/Dis Loss: -0.0013/-0.0006\n", "Epoch: 079/100 | Batch 300/469 | Gen/Dis Loss: -0.0147/0.0006\n", "Epoch: 079/100 | Batch 400/469 | Gen/Dis Loss: 0.0019/-0.0005\n", "Time elapsed: 33.15 min\n", "Epoch: 080/100 | Batch 000/469 | Gen/Dis Loss: 
-0.0017/-0.0001\n", "Epoch: 080/100 | Batch 100/469 | Gen/Dis Loss: -0.0035/-0.0014\n", "Epoch: 080/100 | Batch 200/469 | Gen/Dis Loss: -0.0055/0.0005\n", "Epoch: 080/100 | Batch 300/469 | Gen/Dis Loss: 0.0093/0.0001\n", "Epoch: 080/100 | Batch 400/469 | Gen/Dis Loss: 0.0036/-0.0003\n", "Time elapsed: 33.45 min\n", "Epoch: 081/100 | Batch 000/469 | Gen/Dis Loss: -0.0003/-0.0008\n", "Epoch: 081/100 | Batch 100/469 | Gen/Dis Loss: -0.0013/-0.0002\n", "Epoch: 081/100 | Batch 200/469 | Gen/Dis Loss: -0.0011/0.0001\n", "Epoch: 081/100 | Batch 300/469 | Gen/Dis Loss: 0.0014/-0.0009\n", "Epoch: 081/100 | Batch 400/469 | Gen/Dis Loss: -0.0065/0.0005\n", "Time elapsed: 33.76 min\n", "Epoch: 082/100 | Batch 000/469 | Gen/Dis Loss: 0.0072/-0.0007\n", "Epoch: 082/100 | Batch 100/469 | Gen/Dis Loss: 0.0079/-0.0005\n", "Epoch: 082/100 | Batch 200/469 | Gen/Dis Loss: -0.0043/-0.0005\n", "Epoch: 082/100 | Batch 300/469 | Gen/Dis Loss: -0.0119/0.0002\n", "Epoch: 082/100 | Batch 400/469 | Gen/Dis Loss: -0.0008/-0.0007\n", "Time elapsed: 34.06 min\n", "Epoch: 083/100 | Batch 000/469 | Gen/Dis Loss: -0.0010/-0.0015\n", "Epoch: 083/100 | Batch 100/469 | Gen/Dis Loss: 0.0126/-0.0000\n", "Epoch: 083/100 | Batch 200/469 | Gen/Dis Loss: -0.0006/-0.0008\n", "Epoch: 083/100 | Batch 300/469 | Gen/Dis Loss: 0.0055/-0.0005\n", "Epoch: 083/100 | Batch 400/469 | Gen/Dis Loss: 0.0085/-0.0000\n", "Time elapsed: 34.37 min\n", "Epoch: 084/100 | Batch 000/469 | Gen/Dis Loss: -0.0085/-0.0003\n", "Epoch: 084/100 | Batch 100/469 | Gen/Dis Loss: -0.0008/-0.0001\n", "Epoch: 084/100 | Batch 200/469 | Gen/Dis Loss: 0.0046/-0.0001\n", "Epoch: 084/100 | Batch 300/469 | Gen/Dis Loss: -0.0052/-0.0002\n", "Epoch: 084/100 | Batch 400/469 | Gen/Dis Loss: -0.0037/-0.0002\n", "Time elapsed: 34.67 min\n", "Epoch: 085/100 | Batch 000/469 | Gen/Dis Loss: -0.0008/-0.0006\n", "Epoch: 085/100 | Batch 100/469 | Gen/Dis Loss: -0.0061/-0.0001\n", "Epoch: 085/100 | Batch 200/469 | Gen/Dis Loss: -0.0102/0.0001\n", "Epoch: 
085/100 | Batch 300/469 | Gen/Dis Loss: 0.0008/-0.0008\n", "Epoch: 085/100 | Batch 400/469 | Gen/Dis Loss: -0.0019/-0.0004\n", "Time elapsed: 34.99 min\n", "Epoch: 086/100 | Batch 000/469 | Gen/Dis Loss: -0.0029/0.0001\n", "Epoch: 086/100 | Batch 100/469 | Gen/Dis Loss: 0.0046/-0.0001\n", "Epoch: 086/100 | Batch 200/469 | Gen/Dis Loss: -0.0042/-0.0005\n", "Epoch: 086/100 | Batch 300/469 | Gen/Dis Loss: -0.0082/-0.0002\n", "Epoch: 086/100 | Batch 400/469 | Gen/Dis Loss: -0.0093/-0.0004\n", "Time elapsed: 35.29 min\n", "Epoch: 087/100 | Batch 000/469 | Gen/Dis Loss: -0.0035/0.0002\n", "Epoch: 087/100 | Batch 100/469 | Gen/Dis Loss: -0.0071/0.0000\n", "Epoch: 087/100 | Batch 200/469 | Gen/Dis Loss: 0.0018/0.0002\n", "Epoch: 087/100 | Batch 300/469 | Gen/Dis Loss: -0.0019/-0.0004\n", "Epoch: 087/100 | Batch 400/469 | Gen/Dis Loss: 0.0075/-0.0002\n", "Time elapsed: 35.60 min\n", "Epoch: 088/100 | Batch 000/469 | Gen/Dis Loss: 0.0017/-0.0003\n", "Epoch: 088/100 | Batch 100/469 | Gen/Dis Loss: 0.0024/-0.0005\n", "Epoch: 088/100 | Batch 200/469 | Gen/Dis Loss: -0.0023/-0.0003\n", "Epoch: 088/100 | Batch 300/469 | Gen/Dis Loss: 0.0001/-0.0005\n", "Epoch: 088/100 | Batch 400/469 | Gen/Dis Loss: -0.0027/-0.0003\n", "Time elapsed: 35.91 min\n", "Epoch: 089/100 | Batch 000/469 | Gen/Dis Loss: -0.0039/-0.0007\n", "Epoch: 089/100 | Batch 100/469 | Gen/Dis Loss: -0.0031/-0.0003\n", "Epoch: 089/100 | Batch 200/469 | Gen/Dis Loss: 0.0024/-0.0003\n", "Epoch: 089/100 | Batch 300/469 | Gen/Dis Loss: -0.0041/-0.0001\n", "Epoch: 089/100 | Batch 400/469 | Gen/Dis Loss: 0.0014/-0.0005\n", "Time elapsed: 36.22 min\n", "Epoch: 090/100 | Batch 000/469 | Gen/Dis Loss: -0.0012/-0.0004\n", "Epoch: 090/100 | Batch 100/469 | Gen/Dis Loss: -0.0022/-0.0004\n", "Epoch: 090/100 | Batch 200/469 | Gen/Dis Loss: -0.0083/-0.0005\n", "Epoch: 090/100 | Batch 300/469 | Gen/Dis Loss: -0.0047/-0.0004\n", "Epoch: 090/100 | Batch 400/469 | Gen/Dis Loss: 0.0001/-0.0003\n", "Time elapsed: 36.52 min\n", "Epoch: 
091/100 | Batch 000/469 | Gen/Dis Loss: -0.0013/-0.0003\n", "Epoch: 091/100 | Batch 100/469 | Gen/Dis Loss: -0.0040/-0.0005\n", "Epoch: 091/100 | Batch 200/469 | Gen/Dis Loss: -0.0029/-0.0003\n", "Epoch: 091/100 | Batch 300/469 | Gen/Dis Loss: -0.0026/-0.0003\n", "Epoch: 091/100 | Batch 400/469 | Gen/Dis Loss: -0.0001/-0.0002\n", "Time elapsed: 36.84 min\n", "Epoch: 092/100 | Batch 000/469 | Gen/Dis Loss: 0.0051/-0.0004\n", "Epoch: 092/100 | Batch 100/469 | Gen/Dis Loss: -0.0005/-0.0003\n", "Epoch: 092/100 | Batch 200/469 | Gen/Dis Loss: 0.0041/-0.0004\n", "Epoch: 092/100 | Batch 300/469 | Gen/Dis Loss: 0.0020/-0.0004\n", "Epoch: 092/100 | Batch 400/469 | Gen/Dis Loss: 0.0004/-0.0003\n", "Time elapsed: 37.15 min\n", "Epoch: 093/100 | Batch 000/469 | Gen/Dis Loss: -0.0005/-0.0003\n", "Epoch: 093/100 | Batch 100/469 | Gen/Dis Loss: 0.0008/-0.0004\n", "Epoch: 093/100 | Batch 200/469 | Gen/Dis Loss: -0.0013/-0.0004\n", "Epoch: 093/100 | Batch 300/469 | Gen/Dis Loss: -0.0007/-0.0004\n", "Epoch: 093/100 | Batch 400/469 | Gen/Dis Loss: -0.0013/-0.0002\n", "Time elapsed: 37.45 min\n", "Epoch: 094/100 | Batch 000/469 | Gen/Dis Loss: -0.0017/-0.0003\n", "Epoch: 094/100 | Batch 100/469 | Gen/Dis Loss: -0.0018/-0.0003\n", "Epoch: 094/100 | Batch 200/469 | Gen/Dis Loss: -0.0018/-0.0003\n", "Epoch: 094/100 | Batch 300/469 | Gen/Dis Loss: -0.0017/-0.0003\n", "Epoch: 094/100 | Batch 400/469 | Gen/Dis Loss: -0.0019/-0.0003\n", "Time elapsed: 37.75 min\n", "Epoch: 095/100 | Batch 000/469 | Gen/Dis Loss: -0.0026/-0.0003\n", "Epoch: 095/100 | Batch 100/469 | Gen/Dis Loss: -0.0022/-0.0003\n", "Epoch: 095/100 | Batch 200/469 | Gen/Dis Loss: -0.0014/-0.0003\n", "Epoch: 095/100 | Batch 300/469 | Gen/Dis Loss: -0.0005/-0.0002\n", "Epoch: 095/100 | Batch 400/469 | Gen/Dis Loss: -0.0008/-0.0002\n", "Time elapsed: 38.06 min\n", "Epoch: 096/100 | Batch 000/469 | Gen/Dis Loss: -0.0002/-0.0003\n", "Epoch: 096/100 | Batch 100/469 | Gen/Dis Loss: 0.0011/-0.0003\n", "Epoch: 096/100 | Batch 200/469 
| Gen/Dis Loss: 0.0006/-0.0003\n", "Epoch: 096/100 | Batch 300/469 | Gen/Dis Loss: 0.0019/-0.0004\n", "Epoch: 096/100 | Batch 400/469 | Gen/Dis Loss: 0.0012/-0.0003\n", "Time elapsed: 38.37 min\n", "Epoch: 097/100 | Batch 000/469 | Gen/Dis Loss: -0.0001/-0.0003\n", "Epoch: 097/100 | Batch 100/469 | Gen/Dis Loss: -0.0007/-0.0004\n", "Epoch: 097/100 | Batch 200/469 | Gen/Dis Loss: -0.0015/-0.0003\n", "Epoch: 097/100 | Batch 300/469 | Gen/Dis Loss: -0.0034/-0.0003\n", "Epoch: 097/100 | Batch 400/469 | Gen/Dis Loss: -0.0007/-0.0004\n", "Time elapsed: 38.67 min\n", "Epoch: 098/100 | Batch 000/469 | Gen/Dis Loss: -0.0004/-0.0003\n", "Epoch: 098/100 | Batch 100/469 | Gen/Dis Loss: -0.0015/-0.0004\n", "Epoch: 098/100 | Batch 200/469 | Gen/Dis Loss: -0.0017/-0.0002\n", "Epoch: 098/100 | Batch 300/469 | Gen/Dis Loss: -0.0014/-0.0004\n", "Epoch: 098/100 | Batch 400/469 | Gen/Dis Loss: -0.0037/-0.0004\n", "Time elapsed: 38.97 min\n", "Epoch: 099/100 | Batch 000/469 | Gen/Dis Loss: -0.0055/-0.0003\n", "Epoch: 099/100 | Batch 100/469 | Gen/Dis Loss: -0.0039/-0.0003\n", "Epoch: 099/100 | Batch 200/469 | Gen/Dis Loss: -0.0037/-0.0003\n", "Epoch: 099/100 | Batch 300/469 | Gen/Dis Loss: -0.0041/-0.0003\n", "Epoch: 099/100 | Batch 400/469 | Gen/Dis Loss: -0.0044/-0.0003\n", "Time elapsed: 39.27 min\n", "Epoch: 100/100 | Batch 000/469 | Gen/Dis Loss: -0.0041/-0.0002\n", "Epoch: 100/100 | Batch 100/469 | Gen/Dis Loss: -0.0024/-0.0003\n", "Epoch: 100/100 | Batch 200/469 | Gen/Dis Loss: -0.0021/-0.0002\n", "Epoch: 100/100 | Batch 300/469 | Gen/Dis Loss: -0.0033/-0.0002\n", "Epoch: 100/100 | Batch 400/469 | Gen/Dis Loss: -0.0035/-0.0002\n", "Time elapsed: 39.57 min\n", "Total Training Time: 39.57 min\n" ] } ], "source": [ "start_time = time.time() \n", "\n", "discr_costs = []\n", "gener_costs = []\n", "for epoch in range(NUM_EPOCHS):\n", " model = model.train()\n", " for batch_idx, (features, targets) in enumerate(train_loader):\n", "\n", " \n", " # Normalize images to [-1, 1] range\n", " 
features = (features - 0.5)*2.\n", " features = features.view(-1, IMG_SIZE).to(device) \n", "\n", " targets = targets.to(device)\n", "\n", " # Regular GAN:\n", " # valid = torch.ones(targets.size(0)).float().to(device)\n", " # fake = torch.zeros(targets.size(0)).float().to(device)\n", " \n", " # WGAN:\n", " valid = -(torch.ones(targets.size(0)).float()).to(device)\n", " fake = torch.ones(targets.size(0)).float().to(device)\n", " \n", "\n", " ### FORWARD AND BACK PROP\n", " \n", " \n", " # --------------------------\n", " # Train Generator\n", " # --------------------------\n", " \n", " # Make new images\n", " z = torch.zeros((targets.size(0), LATENT_DIM)).uniform_(0.0, 1.0).to(device)\n", " generated_features = model.generator_forward(z)\n", " \n", " # Loss for fooling the discriminator\n", " discr_pred = model.discriminator_forward(generated_features.view(targets.size(0), 1, 28, 28))\n", " \n", " # Regular GAN:\n", " # gener_loss = F.binary_cross_entropy_with_logits(discr_pred, valid)\n", " \n", " # WGAN:\n", " gener_loss = wasserstein_loss(valid, discr_pred)\n", " \n", " optim_gener.zero_grad()\n", " gener_loss.backward()\n", " optim_gener.step()\n", " \n", " \n", " # --------------------------\n", " # Train Discriminator\n", " # -------------------------- \n", "\n", " # WGAN: Multiple loops for the discriminator\n", " for _ in range(num_iter_critic):\n", " \n", " discr_pred_real = model.discriminator_forward(features.view(targets.size(0), 1, 28, 28))\n", " # Regular GAN:\n", " # real_loss = F.binary_cross_entropy_with_logits(discr_pred_real, valid)\n", " # WGAN:\n", " real_loss = wasserstein_loss(valid, discr_pred_real)\n", "\n", " discr_pred_fake = model.discriminator_forward(generated_features.view(targets.size(0), 1, 28, 28).detach())\n", "\n", " # Regular GAN:\n", " # fake_loss = F.binary_cross_entropy_with_logits(discr_pred_fake, fake)\n", " # WGAN:\n", " fake_loss = wasserstein_loss(fake, discr_pred_fake)\n", "\n", " discr_loss = 0.5*(real_loss + 
fake_loss)\n", " \n", " optim_discr.zero_grad()\n", " discr_loss.backward()\n", " optim_discr.step() \n", "\n", " # WGAN:\n", " for p in model.discriminator.parameters():\n", " p.data.clamp_(-weight_clip_value, weight_clip_value)\n", "\n", " \n", " discr_costs.append(discr_loss.item())\n", " gener_costs.append(gener_loss.item())\n", " \n", " \n", " ### LOGGING\n", " if not batch_idx % 100:\n", " print ('Epoch: %03d/%03d | Batch %03d/%03d | Gen/Dis Loss: %.4f/%.4f' \n", " %(epoch+1, NUM_EPOCHS, batch_idx, \n", " len(train_loader), gener_loss, discr_loss))\n", "\n", " print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n", " \n", "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\nfor i in outputs:\\n print(i.size())\\n'" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "### For Debugging\n", "\n", "\"\"\"\n", "for i in outputs:\n", " print(i.size())\n", "\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "ax1 = plt.subplot(1, 1, 1)\n", "ax1.plot(range(len(gener_costs)), gener_costs, label='Generator loss')\n", "ax1.plot(range(len(discr_costs)), discr_costs, label='Discriminator loss')\n", "ax1.set_xlabel('Iterations')\n", "ax1.set_ylabel('Loss')\n", "ax1.legend()\n", "\n", "###################\n", "# Set second x-axis (epochs)\n", "ax2 = ax1.twiny()\n", "newlabel = list(range(NUM_EPOCHS+1))\n", "iter_per_epoch = len(train_loader)\n", "newpos = [e*iter_per_epoch for e in newlabel]\n", "\n", "# Set tick positions first, then their labels\n", "ax2.set_xticks(newpos[::10])\n", "ax2.set_xticklabels(newlabel[::10])\n", "\n", "ax2.xaxis.set_ticks_position('bottom')\n", "ax2.xaxis.set_label_position('bottom')\n", "ax2.spines['bottom'].set_position(('outward', 45))\n", "ax2.set_xlabel('Epochs')\n", "ax2.set_xlim(ax1.get_xlim())\n", "###################\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "##########################\n", "### VISUALIZATION\n", "##########################\n", "\n", "\n", "model.eval()\n", "# Make new images\n", "z = torch.zeros((10, LATENT_DIM)).uniform_(0.0, 1.0).to(device)\n", "generated_features = model.generator_forward(z)\n", "imgs = generated_features.view(-1, 28, 28)\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=10, figsize=(20, 2.5))\n", "\n", "\n", "for i, ax in enumerate(axes):\n", " ax.imshow(imgs[i].to(torch.device('cpu')).detach(), cmap='binary')" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Linear-1 [-1, 3136] 313,600\n", " BatchNorm1d-2 [-1, 3136] 6,272\n", " LeakyReLU-3 [-1, 3136] 0\n", " Reshape1-4 [-1, 64, 7, 7] 0\n", " ConvTranspose2d-5 [-1, 32, 13, 13] 18,432\n", " BatchNorm2d-6 [-1, 32, 13, 13] 64\n", " LeakyReLU-7 [-1, 32, 13, 13] 0\n", " ConvTranspose2d-8 [-1, 16, 25, 25] 4,608\n", " BatchNorm2d-9 [-1, 16, 25, 25] 32\n", " LeakyReLU-10 [-1, 16, 25, 25] 0\n", " ConvTranspose2d-11 [-1, 8, 27, 27] 1,152\n", " BatchNorm2d-12 [-1, 8, 27, 27] 16\n", " LeakyReLU-13 [-1, 8, 27, 27] 0\n", " ConvTranspose2d-14 [-1, 1, 28, 28] 32\n", " Tanh-15 [-1, 1, 28, 28] 0\n", "================================================================\n", "Total params: 344,208\n", "Trainable params: 344,208\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.59\n", "Params size (MB): 1.31\n", "Estimated Total Size (MB): 1.91\n", "----------------------------------------------------------------\n", 
"----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Conv2d-1 [-1, 8, 14, 14] 72\n", " BatchNorm2d-2 [-1, 8, 14, 14] 16\n", " LeakyReLU-3 [-1, 8, 14, 14] 0\n", " Conv2d-4 [-1, 16, 7, 7] 1,152\n", " BatchNorm2d-5 [-1, 16, 7, 7] 32\n", " LeakyReLU-6 [-1, 16, 7, 7] 0\n", " Conv2d-7 [-1, 32, 4, 4] 4,608\n", " BatchNorm2d-8 [-1, 32, 4, 4] 64\n", " LeakyReLU-9 [-1, 32, 4, 4] 0\n", " Flatten-10 [-1, 512] 0\n", " Linear-11 [-1, 1] 513\n", "================================================================\n", "Total params: 6,457\n", "Trainable params: 6,457\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.00\n", "Forward/backward pass size (MB): 0.07\n", "Params size (MB): 0.02\n", "Estimated Total Size (MB): 0.10\n", "----------------------------------------------------------------\n" ] } ], "source": [ "from torchsummary import summary\n", "\n", "# torchsummary's summary() expects device='cuda' or device='cpu';\n", "# pick whichever is available so this cell also runs on CPU-only machines\n", "summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = model.to(summary_device)\n", "summary(model.generator, input_size=(100,), device=summary_device)\n", "summary(model.discriminator, input_size=(1, 28, 28), device=summary_device)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" }, "toc": { "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }