{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "tensorflow 1.13.1\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p tensorflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolutional Generative Adversarial Networks with Label Smoothing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Same as [./gan-conv.ipynb](./gan-conv.ipynb) but with **label smoothing**.\n", "\n", "Here, label smoothing replaces the real-image labels (1's) with 0.9, based on the idea in\n", "\n", "- Salimans, Tim, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. \"Improved techniques for training GANs.\" In Advances in Neural Information Processing Systems, pp. 2234-2242. 
2016.\n", "\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/device:GPU:0'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "import pickle as pkl\n", "\n", "tf.test.gpu_device_name()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :17: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please write your own downloading logic.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting MNIST_data/train-images-idx3-ubyte.gz\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting 
MNIST_data/train-labels-idx1-ubyte.gz\n", "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", "WARNING:tensorflow:From :64: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.dense instead.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Colocations handled automatically by placer.\n", "WARNING:tensorflow:From :65: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.batch_normalization instead.\n", "WARNING:tensorflow:From :74: conv2d_transpose (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.conv2d_transpose instead.\n", "WARNING:tensorflow:From :77: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.dropout instead.\n", "WARNING:tensorflow:From :121: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.conv2d instead.\n", "WARNING:tensorflow:From 
/home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.cast instead.\n" ] } ], "source": [ "### Abbreviations\n", "# dis_*: discriminator network\n", "# gen_*: generator network\n", "\n", "########################\n", "### Helper functions\n", "########################\n", "\n", "def leaky_relu(x, alpha=0.0001):\n", " return tf.maximum(alpha * x, x)\n", "\n", "\n", "########################\n", "### DATASET\n", "########################\n", "\n", "mnist = input_data.read_data_sets('MNIST_data')\n", "\n", "\n", "#########################\n", "### SETTINGS\n", "#########################\n", "\n", "# Hyperparameters\n", "learning_rate = 0.001\n", "training_epochs = 50\n", "batch_size = 64\n", "dropout_rate = 0.5\n", "\n", "# Architecture\n", "dis_input_size = 784\n", "gen_input_size = 100\n", "\n", "# Other settings\n", "print_interval = 200\n", "\n", "#########################\n", "### GRAPH DEFINITION\n", "#########################\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", " \n", " # Placeholders for settings\n", " dropout = tf.placeholder(tf.float32, shape=None, name='dropout')\n", " is_training = tf.placeholder(tf.bool, shape=None, name='is_training')\n", " \n", " # Input data\n", " dis_x = tf.placeholder(tf.float32, shape=[None, dis_input_size],\n", " name='discriminator_inputs') \n", " gen_x = tf.placeholder(tf.float32, [None, gen_input_size],\n", " name='generator_inputs')\n", "\n", "\n", " ##################\n", " # Generator Model\n", " ##################\n", "\n", " with tf.variable_scope('generator'):\n", " \n", " # 100 => 3136 => 7x7x64\n", " gen_fc = tf.layers.dense(inputs=gen_x, units=3136,\n", " bias_initializer=None, # no bias required when using batch_norm\n", " activation=None)\n", " gen_fc = tf.layers.batch_normalization(gen_fc, training=is_training)\n", 
" gen_fc = leaky_relu(gen_fc)\n", " gen_fc = tf.reshape(gen_fc, (-1, 7, 7, 64))\n", " \n", " # 7x7x64 => 14x14x32\n", " deconv1 = tf.layers.conv2d_transpose(gen_fc, filters=32, \n", " kernel_size=(3, 3), strides=(2, 2), \n", " padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " deconv1 = tf.layers.batch_normalization(deconv1, training=is_training)\n", " deconv1 = leaky_relu(deconv1) \n", " deconv1 = tf.layers.dropout(deconv1, rate=dropout_rate)\n", " \n", " # 14x14x32 => 28x28x16\n", " deconv2 = tf.layers.conv2d_transpose(deconv1, filters=16, \n", " kernel_size=(3, 3), strides=(2, 2), \n", " padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " deconv2 = tf.layers.batch_normalization(deconv2, training=is_training)\n", " deconv2 = leaky_relu(deconv2) \n", " deconv2 = tf.layers.dropout(deconv2, rate=dropout_rate)\n", " \n", " # 28x28x16 => 28x28x8\n", " deconv3 = tf.layers.conv2d_transpose(deconv2, filters=8, \n", " kernel_size=(3, 3), strides=(1, 1), \n", " padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " deconv3 = tf.layers.batch_normalization(deconv3, training=is_training)\n", " deconv3 = leaky_relu(deconv3) \n", " deconv3 = tf.layers.dropout(deconv3, rate=dropout_rate)\n", " \n", " # 28x28x8 => 28x28x1\n", " gen_logits = tf.layers.conv2d_transpose(deconv3, filters=1, \n", " kernel_size=(3, 3), strides=(1, 1), \n", " padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " gen_out = tf.tanh(gen_logits, 'generator_outputs')\n", "\n", "\n", " ######################\n", " # Discriminator Model\n", " ######################\n", " \n", " def build_discriminator_graph(input_x, reuse=None):\n", "\n", " with tf.variable_scope('discriminator', reuse=reuse):\n", " \n", " # 28x28x1 => 14x14x8\n", " conv_input = tf.reshape(input_x, (-1, 28, 28, 1))\n", " conv1 = tf.layers.conv2d(conv_input, filters=8, kernel_size=(3, 3),\n", " strides=(2, 2), padding='same',\n", " bias_initializer=None,\n", " 
activation=None)\n", " conv1 = tf.layers.batch_normalization(conv1, training=is_training)\n", " conv1 = leaky_relu(conv1)\n", " conv1 = tf.layers.dropout(conv1, rate=dropout_rate)\n", " \n", " # 14x14x8 => 7x7x32\n", " conv2 = tf.layers.conv2d(conv1, filters=32, kernel_size=(3, 3),\n", " strides=(2, 2), padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " conv2 = tf.layers.batch_normalization(conv2, training=is_training)\n", " conv2 = leaky_relu(conv2)\n", " conv2 = tf.layers.dropout(conv2, rate=dropout_rate)\n", "\n", " # fully connected layer\n", " fc_input = tf.reshape(conv2, (-1, 7*7*32))\n", " logits = tf.layers.dense(inputs=fc_input, units=1, activation=None)\n", " out = tf.sigmoid(logits)\n", " \n", " return logits, out \n", "\n", " # Create a discriminator for real data and a discriminator for fake data\n", " dis_real_logits, dis_real_out = build_discriminator_graph(dis_x, reuse=False)\n", " dis_fake_logits, dis_fake_out = build_discriminator_graph(gen_out, reuse=True)\n", "\n", "\n", " #####################################\n", " # Generator and Discriminator Losses\n", " #####################################\n", " \n", " # Two discriminator cost components: loss on real data + loss on fake data\n", " # Real data has class label 1, fake data has class label 0\n", " dis_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real_logits, \n", " labels=tf.ones_like(dis_real_logits) * 0.9)\n", " dis_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits, \n", " labels=tf.zeros_like(dis_fake_logits))\n", " dis_cost = tf.add(tf.reduce_mean(dis_fake_loss), \n", " tf.reduce_mean(dis_real_loss), \n", " name='discriminator_cost')\n", " \n", " # Generator cost: difference between dis. 
prediction and label \"1\" for real images\n", " gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits,\n", " labels=tf.ones_like(dis_fake_logits) * 0.9)\n", " gen_cost = tf.reduce_mean(gen_loss, name='generator_cost')\n", " \n", " \n", " #########################################\n", " # Generator and Discriminator Optimizers\n", " #########################################\n", " \n", " dis_optimizer = tf.train.AdamOptimizer(learning_rate)\n", " dis_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')\n", " dis_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')\n", " \n", " with tf.control_dependencies(dis_update_ops): # required to upd. batch_norm params\n", " dis_train = dis_optimizer.minimize(dis_cost, var_list=dis_train_vars,\n", " name='train_discriminator')\n", " \n", " gen_optimizer = tf.train.AdamOptimizer(learning_rate)\n", " gen_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')\n", " gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')\n", " \n", " with tf.control_dependencies(gen_update_ops): # required to upd. 
batch_norm params\n", " gen_train = gen_optimizer.minimize(gen_cost, var_list=gen_train_vars,\n", " name='train_generator')\n", " \n", " # Saver to save session for reuse\n", " saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Minibatch: 0001 | Dis/Gen Cost: 1.290/1.057\n", "Minibatch: 0201 | Dis/Gen Cost: 1.054/1.374\n", "Minibatch: 0401 | Dis/Gen Cost: 0.993/1.822\n", "Minibatch: 0601 | Dis/Gen Cost: 0.580/2.214\n", "Minibatch: 0801 | Dis/Gen Cost: 0.708/1.710\n", "Epoch: 0001 | Dis/Gen AvgCost: 0.821/1.891\n", "Minibatch: 0001 | Dis/Gen Cost: 1.013/2.102\n", "Minibatch: 0201 | Dis/Gen Cost: 1.113/1.212\n", "Minibatch: 0401 | Dis/Gen Cost: 1.247/1.045\n", "Minibatch: 0601 | Dis/Gen Cost: 0.935/1.376\n", "Minibatch: 0801 | Dis/Gen Cost: 0.685/2.113\n", "Epoch: 0002 | Dis/Gen AvgCost: 0.884/1.699\n", "Minibatch: 0001 | Dis/Gen Cost: 1.089/1.191\n", "Minibatch: 0201 | Dis/Gen Cost: 0.971/1.760\n", "Minibatch: 0401 | Dis/Gen Cost: 0.769/2.239\n", "Minibatch: 0601 | Dis/Gen Cost: 1.275/1.442\n", "Minibatch: 0801 | Dis/Gen Cost: 1.023/1.846\n", "Epoch: 0003 | Dis/Gen AvgCost: 0.996/1.580\n", "Minibatch: 0001 | Dis/Gen Cost: 1.088/1.628\n", "Minibatch: 0201 | Dis/Gen Cost: 1.095/1.451\n", "Minibatch: 0401 | Dis/Gen Cost: 0.891/1.560\n", "Minibatch: 0601 | Dis/Gen Cost: 0.974/1.353\n", "Minibatch: 0801 | Dis/Gen Cost: 1.140/1.344\n", "Epoch: 0004 | Dis/Gen AvgCost: 1.119/1.341\n", "Minibatch: 0001 | Dis/Gen Cost: 1.197/1.256\n", "Minibatch: 0201 | Dis/Gen Cost: 1.281/1.192\n", "Minibatch: 0401 | Dis/Gen Cost: 1.159/1.402\n", "Minibatch: 0601 | Dis/Gen Cost: 1.397/0.997\n", "Minibatch: 0801 | Dis/Gen Cost: 1.230/1.087\n", "Epoch: 0005 | Dis/Gen AvgCost: 1.177/1.229\n", "Minibatch: 0001 | Dis/Gen Cost: 1.104/1.103\n", "Minibatch: 0201 | Dis/Gen Cost: 1.385/1.217\n", "Minibatch: 0401 | Dis/Gen Cost: 1.069/1.247\n", "Minibatch: 0601 | Dis/Gen Cost: 1.126/1.309\n", 
"Minibatch: 0801 | Dis/Gen Cost: 1.202/1.529\n", "Epoch: 0006 | Dis/Gen AvgCost: 1.257/1.143\n", "Minibatch: 0001 | Dis/Gen Cost: 1.274/1.314\n", "Minibatch: 0201 | Dis/Gen Cost: 1.362/0.915\n", "Minibatch: 0401 | Dis/Gen Cost: 1.395/1.082\n", "Minibatch: 0601 | Dis/Gen Cost: 1.270/0.947\n", "Minibatch: 0801 | Dis/Gen Cost: 1.327/1.151\n", "Epoch: 0007 | Dis/Gen AvgCost: 1.324/1.042\n", "Minibatch: 0001 | Dis/Gen Cost: 1.417/0.794\n", "Minibatch: 0201 | Dis/Gen Cost: 1.210/0.995\n", "Minibatch: 0401 | Dis/Gen Cost: 1.558/0.925\n", "Minibatch: 0601 | Dis/Gen Cost: 1.191/1.106\n", "Minibatch: 0801 | Dis/Gen Cost: 1.150/1.047\n", "Epoch: 0008 | Dis/Gen AvgCost: 1.306/1.026\n", "Minibatch: 0001 | Dis/Gen Cost: 1.186/0.991\n", "Minibatch: 0201 | Dis/Gen Cost: 1.332/1.005\n", "Minibatch: 0401 | Dis/Gen Cost: 1.185/1.090\n", "Minibatch: 0601 | Dis/Gen Cost: 1.314/1.000\n", "Minibatch: 0801 | Dis/Gen Cost: 1.115/1.158\n", "Epoch: 0009 | Dis/Gen AvgCost: 1.305/1.006\n", "Minibatch: 0001 | Dis/Gen Cost: 1.348/0.868\n", "Minibatch: 0201 | Dis/Gen Cost: 1.367/0.863\n", "Minibatch: 0401 | Dis/Gen Cost: 1.328/1.020\n", "Minibatch: 0601 | Dis/Gen Cost: 1.395/0.962\n", "Minibatch: 0801 | Dis/Gen Cost: 1.390/0.979\n", "Epoch: 0010 | Dis/Gen AvgCost: 1.300/1.025\n", "Minibatch: 0001 | Dis/Gen Cost: 1.403/1.199\n", "Minibatch: 0201 | Dis/Gen Cost: 1.222/0.985\n", "Minibatch: 0401 | Dis/Gen Cost: 1.212/1.235\n", "Minibatch: 0601 | Dis/Gen Cost: 1.052/1.168\n", "Minibatch: 0801 | Dis/Gen Cost: 1.268/0.917\n", "Epoch: 0011 | Dis/Gen AvgCost: 1.305/1.002\n", "Minibatch: 0001 | Dis/Gen Cost: 1.304/0.949\n", "Minibatch: 0201 | Dis/Gen Cost: 1.198/1.137\n", "Minibatch: 0401 | Dis/Gen Cost: 1.237/1.077\n", "Minibatch: 0601 | Dis/Gen Cost: 1.337/0.930\n", "Minibatch: 0801 | Dis/Gen Cost: 1.341/0.909\n", "Epoch: 0012 | Dis/Gen AvgCost: 1.315/0.986\n", "Minibatch: 0001 | Dis/Gen Cost: 1.411/0.964\n", "Minibatch: 0201 | Dis/Gen Cost: 1.335/0.955\n", "Minibatch: 0401 | Dis/Gen Cost: 
1.319/0.927\n", "Minibatch: 0601 | Dis/Gen Cost: 1.257/0.952\n", "Minibatch: 0801 | Dis/Gen Cost: 1.283/0.973\n", "Epoch: 0013 | Dis/Gen AvgCost: 1.329/0.974\n", "Minibatch: 0001 | Dis/Gen Cost: 1.266/1.170\n", "Minibatch: 0201 | Dis/Gen Cost: 1.478/0.830\n", "Minibatch: 0401 | Dis/Gen Cost: 1.300/0.954\n", "Minibatch: 0601 | Dis/Gen Cost: 1.305/0.980\n", "Minibatch: 0801 | Dis/Gen Cost: 1.435/0.809\n", "Epoch: 0014 | Dis/Gen AvgCost: 1.325/0.946\n", "Minibatch: 0001 | Dis/Gen Cost: 1.305/0.940\n", "Minibatch: 0201 | Dis/Gen Cost: 1.473/0.910\n", "Minibatch: 0401 | Dis/Gen Cost: 1.408/0.976\n", "Minibatch: 0601 | Dis/Gen Cost: 1.312/0.944\n", "Minibatch: 0801 | Dis/Gen Cost: 1.412/0.905\n", "Epoch: 0015 | Dis/Gen AvgCost: 1.338/0.949\n", "Minibatch: 0001 | Dis/Gen Cost: 1.297/0.971\n", "Minibatch: 0201 | Dis/Gen Cost: 1.196/1.051\n", "Minibatch: 0401 | Dis/Gen Cost: 1.262/0.956\n", "Minibatch: 0601 | Dis/Gen Cost: 1.248/0.974\n", "Minibatch: 0801 | Dis/Gen Cost: 1.278/0.954\n", "Epoch: 0016 | Dis/Gen AvgCost: 1.331/0.947\n", "Minibatch: 0001 | Dis/Gen Cost: 1.227/0.928\n", "Minibatch: 0201 | Dis/Gen Cost: 1.304/0.998\n", "Minibatch: 0401 | Dis/Gen Cost: 1.195/0.963\n", "Minibatch: 0601 | Dis/Gen Cost: 1.230/0.910\n", "Minibatch: 0801 | Dis/Gen Cost: 1.281/1.064\n", "Epoch: 0017 | Dis/Gen AvgCost: 1.335/0.914\n", "Minibatch: 0001 | Dis/Gen Cost: 1.423/0.921\n", "Minibatch: 0201 | Dis/Gen Cost: 1.309/0.892\n", "Minibatch: 0401 | Dis/Gen Cost: 1.311/0.895\n", "Minibatch: 0601 | Dis/Gen Cost: 1.378/0.842\n", "Minibatch: 0801 | Dis/Gen Cost: 1.388/0.833\n", "Epoch: 0018 | Dis/Gen AvgCost: 1.344/0.902\n", "Minibatch: 0001 | Dis/Gen Cost: 1.177/1.030\n", "Minibatch: 0201 | Dis/Gen Cost: 1.255/1.045\n", "Minibatch: 0401 | Dis/Gen Cost: 1.359/0.986\n", "Minibatch: 0601 | Dis/Gen Cost: 1.273/0.944\n", "Minibatch: 0801 | Dis/Gen Cost: 1.297/0.914\n", "Epoch: 0019 | Dis/Gen AvgCost: 1.333/0.928\n", "Minibatch: 0001 | Dis/Gen Cost: 1.403/0.921\n", "Minibatch: 0201 | Dis/Gen 
Cost: 1.272/0.932\n", "Minibatch: 0401 | Dis/Gen Cost: 1.250/0.931\n", "Minibatch: 0601 | Dis/Gen Cost: 1.298/0.904\n", "Minibatch: 0801 | Dis/Gen Cost: 1.290/0.852\n", "Epoch: 0020 | Dis/Gen AvgCost: 1.332/0.916\n", "Minibatch: 0001 | Dis/Gen Cost: 1.384/0.898\n", "Minibatch: 0201 | Dis/Gen Cost: 1.386/0.886\n", "Minibatch: 0401 | Dis/Gen Cost: 1.314/1.025\n", "Minibatch: 0601 | Dis/Gen Cost: 1.546/0.881\n", "Minibatch: 0801 | Dis/Gen Cost: 1.202/1.017\n", "Epoch: 0021 | Dis/Gen AvgCost: 1.330/0.930\n", "Minibatch: 0001 | Dis/Gen Cost: 1.232/1.135\n", "Minibatch: 0201 | Dis/Gen Cost: 1.317/0.930\n", "Minibatch: 0401 | Dis/Gen Cost: 1.194/1.068\n", "Minibatch: 0601 | Dis/Gen Cost: 1.378/0.859\n", "Minibatch: 0801 | Dis/Gen Cost: 1.267/0.955\n", "Epoch: 0022 | Dis/Gen AvgCost: 1.339/0.907\n", "Minibatch: 0001 | Dis/Gen Cost: 1.294/0.937\n", "Minibatch: 0201 | Dis/Gen Cost: 1.347/0.860\n", "Minibatch: 0401 | Dis/Gen Cost: 1.362/0.878\n", "Minibatch: 0601 | Dis/Gen Cost: 1.228/0.866\n", "Minibatch: 0801 | Dis/Gen Cost: 1.344/0.900\n", "Epoch: 0023 | Dis/Gen AvgCost: 1.339/0.895\n", "Minibatch: 0001 | Dis/Gen Cost: 1.454/0.811\n", "Minibatch: 0201 | Dis/Gen Cost: 1.448/0.924\n", "Minibatch: 0401 | Dis/Gen Cost: 1.300/0.950\n", "Minibatch: 0601 | Dis/Gen Cost: 1.326/0.881\n", "Minibatch: 0801 | Dis/Gen Cost: 1.283/1.006\n", "Epoch: 0024 | Dis/Gen AvgCost: 1.340/0.889\n", "Minibatch: 0001 | Dis/Gen Cost: 1.348/0.922\n", "Minibatch: 0201 | Dis/Gen Cost: 1.430/0.758\n", "Minibatch: 0401 | Dis/Gen Cost: 1.369/0.870\n", "Minibatch: 0601 | Dis/Gen Cost: 1.343/0.838\n", "Minibatch: 0801 | Dis/Gen Cost: 1.189/0.967\n", "Epoch: 0025 | Dis/Gen AvgCost: 1.347/0.891\n", "Minibatch: 0001 | Dis/Gen Cost: 1.395/0.865\n", "Minibatch: 0201 | Dis/Gen Cost: 1.495/0.803\n", "Minibatch: 0401 | Dis/Gen Cost: 1.450/0.861\n", "Minibatch: 0601 | Dis/Gen Cost: 1.299/0.953\n", "Minibatch: 0801 | Dis/Gen Cost: 1.426/0.793\n", "Epoch: 0026 | Dis/Gen AvgCost: 1.339/0.891\n", "Minibatch: 0001 | 
Dis/Gen Cost: 1.348/0.856\n", "Minibatch: 0201 | Dis/Gen Cost: 1.303/0.942\n", "Minibatch: 0401 | Dis/Gen Cost: 1.344/0.846\n", "Minibatch: 0601 | Dis/Gen Cost: 1.276/0.888\n", "Minibatch: 0801 | Dis/Gen Cost: 1.393/0.855\n", "Epoch: 0027 | Dis/Gen AvgCost: 1.347/0.881\n", "Minibatch: 0001 | Dis/Gen Cost: 1.305/0.963\n", "Minibatch: 0201 | Dis/Gen Cost: 1.391/0.850\n", "Minibatch: 0401 | Dis/Gen Cost: 1.380/0.795\n", "Minibatch: 0601 | Dis/Gen Cost: 1.295/0.840\n", "Minibatch: 0801 | Dis/Gen Cost: 1.194/0.927\n", "Epoch: 0028 | Dis/Gen AvgCost: 1.350/0.867\n", "Minibatch: 0001 | Dis/Gen Cost: 1.394/0.805\n", "Minibatch: 0201 | Dis/Gen Cost: 1.288/0.889\n", "Minibatch: 0401 | Dis/Gen Cost: 1.331/0.922\n", "Minibatch: 0601 | Dis/Gen Cost: 1.466/0.795\n", "Minibatch: 0801 | Dis/Gen Cost: 1.430/0.779\n", "Epoch: 0029 | Dis/Gen AvgCost: 1.341/0.873\n", "Minibatch: 0001 | Dis/Gen Cost: 1.297/0.879\n", "Minibatch: 0201 | Dis/Gen Cost: 1.268/0.932\n", "Minibatch: 0401 | Dis/Gen Cost: 1.432/0.831\n", "Minibatch: 0601 | Dis/Gen Cost: 1.335/0.845\n", "Minibatch: 0801 | Dis/Gen Cost: 1.401/0.962\n", "Epoch: 0030 | Dis/Gen AvgCost: 1.337/0.872\n", "Minibatch: 0001 | Dis/Gen Cost: 1.300/0.910\n", "Minibatch: 0201 | Dis/Gen Cost: 1.369/0.872\n", "Minibatch: 0401 | Dis/Gen Cost: 1.421/0.826\n", "Minibatch: 0601 | Dis/Gen Cost: 1.351/0.946\n", "Minibatch: 0801 | Dis/Gen Cost: 1.401/0.864\n", "Epoch: 0031 | Dis/Gen AvgCost: 1.344/0.863\n", "Minibatch: 0001 | Dis/Gen Cost: 1.273/0.875\n", "Minibatch: 0201 | Dis/Gen Cost: 1.353/0.836\n", "Minibatch: 0401 | Dis/Gen Cost: 1.372/0.867\n", "Minibatch: 0601 | Dis/Gen Cost: 1.368/0.853\n", "Minibatch: 0801 | Dis/Gen Cost: 1.186/0.904\n", "Epoch: 0032 | Dis/Gen AvgCost: 1.342/0.868\n", "Minibatch: 0001 | Dis/Gen Cost: 1.405/0.823\n", "Minibatch: 0201 | Dis/Gen Cost: 1.321/0.931\n", "Minibatch: 0401 | Dis/Gen Cost: 1.361/0.858\n", "Minibatch: 0601 | Dis/Gen Cost: 1.274/0.891\n", "Minibatch: 0801 | Dis/Gen Cost: 1.397/0.848\n", "Epoch: 0033 | 
Dis/Gen AvgCost: 1.345/0.858\n", "Minibatch: 0001 | Dis/Gen Cost: 1.174/0.992\n", "Minibatch: 0201 | Dis/Gen Cost: 1.278/0.902\n", "Minibatch: 0401 | Dis/Gen Cost: 1.341/0.900\n", "Minibatch: 0601 | Dis/Gen Cost: 1.267/0.906\n", "Minibatch: 0801 | Dis/Gen Cost: 1.369/0.820\n", "Epoch: 0034 | Dis/Gen AvgCost: 1.346/0.862\n", "Minibatch: 0001 | Dis/Gen Cost: 1.305/0.838\n", "Minibatch: 0201 | Dis/Gen Cost: 1.403/0.846\n", "Minibatch: 0401 | Dis/Gen Cost: 1.338/0.850\n", "Minibatch: 0601 | Dis/Gen Cost: 1.343/0.833\n", "Minibatch: 0801 | Dis/Gen Cost: 1.334/0.797\n", "Epoch: 0035 | Dis/Gen AvgCost: 1.353/0.850\n", "Minibatch: 0001 | Dis/Gen Cost: 1.394/0.846\n", "Minibatch: 0201 | Dis/Gen Cost: 1.407/0.841\n", "Minibatch: 0401 | Dis/Gen Cost: 1.481/0.732\n", "Minibatch: 0601 | Dis/Gen Cost: 1.328/0.884\n", "Minibatch: 0801 | Dis/Gen Cost: 1.414/0.789\n", "Epoch: 0036 | Dis/Gen AvgCost: 1.352/0.850\n", "Minibatch: 0001 | Dis/Gen Cost: 1.310/0.838\n", "Minibatch: 0201 | Dis/Gen Cost: 1.376/0.805\n", "Minibatch: 0401 | Dis/Gen Cost: 1.341/0.864\n", "Minibatch: 0601 | Dis/Gen Cost: 1.328/0.896\n", "Minibatch: 0801 | Dis/Gen Cost: 1.383/0.791\n", "Epoch: 0037 | Dis/Gen AvgCost: 1.352/0.840\n", "Minibatch: 0001 | Dis/Gen Cost: 1.295/0.861\n", "Minibatch: 0201 | Dis/Gen Cost: 1.455/0.826\n", "Minibatch: 0401 | Dis/Gen Cost: 1.420/0.796\n", "Minibatch: 0601 | Dis/Gen Cost: 1.337/0.871\n", "Minibatch: 0801 | Dis/Gen Cost: 1.328/0.863\n", "Epoch: 0038 | Dis/Gen AvgCost: 1.348/0.852\n", "Minibatch: 0001 | Dis/Gen Cost: 1.382/0.824\n", "Minibatch: 0201 | Dis/Gen Cost: 1.302/0.897\n", "Minibatch: 0401 | Dis/Gen Cost: 1.385/0.792\n", "Minibatch: 0601 | Dis/Gen Cost: 1.314/0.847\n", "Minibatch: 0801 | Dis/Gen Cost: 1.423/0.779\n", "Epoch: 0039 | Dis/Gen AvgCost: 1.350/0.848\n", "Minibatch: 0001 | Dis/Gen Cost: 1.419/0.852\n", "Minibatch: 0201 | Dis/Gen Cost: 1.390/0.885\n", "Minibatch: 0401 | Dis/Gen Cost: 1.348/0.802\n", "Minibatch: 0601 | Dis/Gen Cost: 1.349/0.833\n", "Minibatch: 
0801 | Dis/Gen Cost: 1.382/0.775\n", "Epoch: 0040 | Dis/Gen AvgCost: 1.349/0.842\n", "Minibatch: 0001 | Dis/Gen Cost: 1.289/0.918\n", "Minibatch: 0201 | Dis/Gen Cost: 1.410/0.772\n", "Minibatch: 0401 | Dis/Gen Cost: 1.393/0.790\n", "Minibatch: 0601 | Dis/Gen Cost: 1.317/0.829\n", "Minibatch: 0801 | Dis/Gen Cost: 1.267/0.878\n", "Epoch: 0041 | Dis/Gen AvgCost: 1.358/0.837\n", "Minibatch: 0001 | Dis/Gen Cost: 1.342/0.859\n", "Minibatch: 0201 | Dis/Gen Cost: 1.340/0.870\n", "Minibatch: 0401 | Dis/Gen Cost: 1.394/0.803\n", "Minibatch: 0601 | Dis/Gen Cost: 1.355/0.820\n", "Minibatch: 0801 | Dis/Gen Cost: 1.359/0.836\n", "Epoch: 0042 | Dis/Gen AvgCost: 1.348/0.847\n", "Minibatch: 0001 | Dis/Gen Cost: 1.330/0.807\n", "Minibatch: 0201 | Dis/Gen Cost: 1.386/0.836\n", "Minibatch: 0401 | Dis/Gen Cost: 1.400/0.816\n", "Minibatch: 0601 | Dis/Gen Cost: 1.355/0.855\n", "Minibatch: 0801 | Dis/Gen Cost: 1.315/0.919\n", "Epoch: 0043 | Dis/Gen AvgCost: 1.354/0.845\n", "Minibatch: 0001 | Dis/Gen Cost: 1.338/0.838\n", "Minibatch: 0201 | Dis/Gen Cost: 1.317/0.866\n", "Minibatch: 0401 | Dis/Gen Cost: 1.341/0.819\n", "Minibatch: 0601 | Dis/Gen Cost: 1.260/0.863\n", "Minibatch: 0801 | Dis/Gen Cost: 1.285/0.917\n", "Epoch: 0044 | Dis/Gen AvgCost: 1.351/0.850\n", "Minibatch: 0001 | Dis/Gen Cost: 1.378/0.826\n", "Minibatch: 0201 | Dis/Gen Cost: 1.332/0.881\n", "Minibatch: 0401 | Dis/Gen Cost: 1.247/0.920\n", "Minibatch: 0601 | Dis/Gen Cost: 1.339/0.807\n", "Minibatch: 0801 | Dis/Gen Cost: 1.350/0.850\n", "Epoch: 0045 | Dis/Gen AvgCost: 1.356/0.836\n", "Minibatch: 0001 | Dis/Gen Cost: 1.341/0.872\n", "Minibatch: 0201 | Dis/Gen Cost: 1.406/0.818\n", "Minibatch: 0401 | Dis/Gen Cost: 1.478/0.765\n", "Minibatch: 0601 | Dis/Gen Cost: 1.426/0.837\n", "Minibatch: 0801 | Dis/Gen Cost: 1.271/0.824\n", "Epoch: 0046 | Dis/Gen AvgCost: 1.356/0.832\n", "Minibatch: 0001 | Dis/Gen Cost: 1.388/0.812\n", "Minibatch: 0201 | Dis/Gen Cost: 1.279/0.916\n", "Minibatch: 0401 | Dis/Gen Cost: 1.331/0.805\n", 
"Minibatch: 0601 | Dis/Gen Cost: 1.321/0.861\n", "Minibatch: 0801 | Dis/Gen Cost: 1.344/0.860\n", "Epoch: 0047 | Dis/Gen AvgCost: 1.351/0.843\n", "Minibatch: 0001 | Dis/Gen Cost: 1.342/0.807\n", "Minibatch: 0201 | Dis/Gen Cost: 1.356/0.813\n", "Minibatch: 0401 | Dis/Gen Cost: 1.361/0.806\n", "Minibatch: 0601 | Dis/Gen Cost: 1.393/0.811\n", "Minibatch: 0801 | Dis/Gen Cost: 1.379/0.783\n", "Epoch: 0048 | Dis/Gen AvgCost: 1.357/0.824\n", "Minibatch: 0001 | Dis/Gen Cost: 1.368/0.793\n", "Minibatch: 0201 | Dis/Gen Cost: 1.364/0.812\n", "Minibatch: 0401 | Dis/Gen Cost: 1.339/0.843\n", "Minibatch: 0601 | Dis/Gen Cost: 1.331/0.798\n", "Minibatch: 0801 | Dis/Gen Cost: 1.358/0.815\n", "Epoch: 0049 | Dis/Gen AvgCost: 1.359/0.823\n", "Minibatch: 0001 | Dis/Gen Cost: 1.367/0.819\n", "Minibatch: 0201 | Dis/Gen Cost: 1.300/0.845\n", "Minibatch: 0401 | Dis/Gen Cost: 1.364/0.808\n", "Minibatch: 0601 | Dis/Gen Cost: 1.284/0.912\n", "Minibatch: 0801 | Dis/Gen Cost: 1.334/0.837\n", "Epoch: 0050 | Dis/Gen AvgCost: 1.355/0.833\n" ] } ], "source": [ "##########################\n", "### TRAINING & EVALUATION\n", "##########################\n", "\n", "with tf.Session(graph=g) as sess:\n", " sess.run(tf.global_variables_initializer())\n", " \n", " avg_costs = {'discriminator': [], 'generator': []}\n", "\n", " for epoch in range(training_epochs):\n", " dis_avg_cost, gen_avg_cost = 0., 0.\n", " total_batch = mnist.train.num_examples // batch_size\n", "\n", " for i in range(total_batch):\n", " \n", " batch_x, batch_y = mnist.train.next_batch(batch_size)\n", " batch_x = batch_x*2 - 1 # normalize\n", " batch_randsample = np.random.uniform(-1, 1, size=(batch_size, gen_input_size))\n", " \n", " # Train\n", " \n", " _, dc = sess.run(['train_discriminator', 'discriminator_cost:0'],\n", " feed_dict={'discriminator_inputs:0': batch_x, \n", " 'generator_inputs:0': batch_randsample,\n", " 'dropout:0': dropout_rate,\n", " 'is_training:0': True})\n", " \n", " _, gc = sess.run(['train_generator', 
'generator_cost:0'],\n", " feed_dict={'generator_inputs:0': batch_randsample,\n", " 'dropout:0': dropout_rate,\n", " 'is_training:0': True})\n", " \n", " dis_avg_cost += dc\n", " gen_avg_cost += gc\n", "\n", " if not i % print_interval:\n", " print(\"Minibatch: %04d | Dis/Gen Cost: %.3f/%.3f\" % (i + 1, dc, gc))\n", " \n", "\n", " print(\"Epoch: %04d | Dis/Gen AvgCost: %.3f/%.3f\" % \n", " (epoch + 1, dis_avg_cost / total_batch, gen_avg_cost / total_batch))\n", " \n", " avg_costs['discriminator'].append(dis_avg_cost / total_batch)\n", " avg_costs['generator'].append(gen_avg_cost / total_batch)\n", " \n", " \n", " saver.save(sess, save_path='./gan-conv.ckpt')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "plt.plot(range(len(avg_costs['discriminator'])), \n", " avg_costs['discriminator'], label='discriminator')\n", "plt.plot(range(len(avg_costs['generator'])),\n", " avg_costs['generator'], label='generator')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use standard file APIs to check for files with this prefix.\n", "INFO:tensorflow:Restoring parameters from ./gan-conv.ckpt\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "####################################\n", "### RELOAD & GENERATE SAMPLE IMAGES\n", "####################################\n", "\n", "\n", "n_examples = 25\n", "\n", "with tf.Session(graph=g) as sess:\n", " saver.restore(sess, save_path='./gan-conv.ckpt')\n", "\n", " batch_randsample = np.random.uniform(-1, 1, size=(n_examples, gen_input_size))\n", " new_examples = sess.run('generator/generator_outputs:0',\n", " feed_dict={'generator_inputs:0': batch_randsample,\n", " 'dropout:0': 0.0,\n", " 'is_training:0': False})\n", "\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(8, 8),\n", " sharey=True, sharex=True)\n", "\n", "for image, ax in zip(new_examples, axes.flatten()):\n", " ax.imshow(image.reshape((dis_input_size // 28, dis_input_size // 28)), cmap='binary')\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }