{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", "- Author: Sebastian Raschka\n", "- GitHub Repository: https://github.com/rasbt/deeplearning-models" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sebastian Raschka \n", "\n", "CPython 3.7.3\n", "IPython 7.6.1\n", "\n", "tensorflow 1.13.1\n" ] } ], "source": [ "%load_ext watermark\n", "%watermark -a 'Sebastian Raschka' -v -p tensorflow" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Convolutional Generative Adversarial Networks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Implementation of a Generative Adversarial Network (GAN) in which the discriminator uses convolutional layers and the generator uses deconvolutional (transposed convolution) layers. In this example, the GAN generator was trained to generate MNIST images.\n", "\n", "Uses\n", "\n", "- generator inputs sampled from a random uniform distribution (range [-1, 1])\n", "- dropout\n", "- leaky ReLUs\n", "- batch normalization\n", "- separate batches for \"fake\" and \"real\" images (where the labels are 1 = real images, 0 = fake images)\n", "- MNIST images normalized to [-1, 1] range\n", "- generator with tanh output\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/device:GPU:0'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "import tensorflow as tf\n", "from tensorflow.examples.tutorials.mnist import input_data\n", "import pickle as pkl\n", "\n", "tf.test.gpu_device_name()" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From :17: read_data_sets (from 
tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please write your own downloading logic.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting MNIST_data/train-images-idx3-ubyte.gz\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use tf.data to implement this functionality.\n", "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n", "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n", "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n", "WARNING:tensorflow:From :64: dense (from tensorflow.python.layers.core) is 
deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.dense instead.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Colocations handled automatically by placer.\n", "WARNING:tensorflow:From :65: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.batch_normalization instead.\n", "WARNING:tensorflow:From :74: conv2d_transpose (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.conv2d_transpose instead.\n", "WARNING:tensorflow:From :77: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.dropout instead.\n", "WARNING:tensorflow:From :121: conv2d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use keras.layers.conv2d instead.\n", "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use tf.cast instead.\n" ] } ], "source": [ "### Abbreviations\n", "# dis_*: discriminator network\n", "# gen_*: generator network\n", "\n", "########################\n", "### Helper functions\n", "########################\n", "\n", "def leaky_relu(x, alpha=0.0001):\n", " return tf.maximum(alpha * x, x)\n", "\n", "\n", "########################\n", "### DATASET\n", "########################\n", "\n", 
"mnist = input_data.read_data_sets('MNIST_data')\n", "\n", "\n", "#########################\n", "### SETTINGS\n", "#########################\n", "\n", "# Hyperparameters\n", "learning_rate = 0.001\n", "training_epochs = 50\n", "batch_size = 64\n", "dropout_rate = 0.5\n", "\n", "# Architecture\n", "dis_input_size = 784\n", "gen_input_size = 100\n", "\n", "# Other settings\n", "print_interval = 200\n", "\n", "#########################\n", "### GRAPH DEFINITION\n", "#########################\n", "\n", "g = tf.Graph()\n", "with g.as_default():\n", " \n", " # Placeholders for settings\n", " dropout = tf.placeholder(tf.float32, shape=None, name='dropout')\n", " is_training = tf.placeholder(tf.bool, shape=None, name='is_training')\n", " \n", " # Input data\n", " dis_x = tf.placeholder(tf.float32, shape=[None, dis_input_size],\n", " name='discriminator_inputs') \n", " gen_x = tf.placeholder(tf.float32, [None, gen_input_size],\n", " name='generator_inputs')\n", "\n", "\n", " ##################\n", " # Generator Model\n", " ##################\n", "\n", " with tf.variable_scope('generator'):\n", " \n", " # 100 => 3136 => 7x7x64\n", " gen_fc = tf.layers.dense(inputs=gen_x, units=3136,\n", " use_bias=False, # no bias required when using batch_norm\n", " activation=None)\n", " gen_fc = tf.layers.batch_normalization(gen_fc, training=is_training)\n", " gen_fc = leaky_relu(gen_fc)\n", " gen_fc = tf.reshape(gen_fc, (-1, 7, 7, 64))\n", " \n", " # 7x7x64 => 14x14x32\n", " deconv1 = tf.layers.conv2d_transpose(gen_fc, filters=32, \n", " kernel_size=(3, 3), strides=(2, 2), \n", " padding='same',\n", " use_bias=False,\n", " activation=None)\n", " deconv1 = tf.layers.batch_normalization(deconv1, training=is_training)\n", " deconv1 = leaky_relu(deconv1) \n", " deconv1 = tf.layers.dropout(deconv1, rate=dropout_rate, training=is_training)\n", " \n", " # 14x14x32 => 28x28x16\n", " deconv2 = tf.layers.conv2d_transpose(deconv1, filters=16, \n", " kernel_size=(3, 3), strides=(2, 2), \n", " padding='same',\n", " use_bias=False,\n", " activation=None)\n", " deconv2 = tf.layers.batch_normalization(deconv2, training=is_training)\n", " deconv2 = leaky_relu(deconv2) \n", " deconv2 = tf.layers.dropout(deconv2, rate=dropout_rate, training=is_training)\n", " \n", " # 28x28x16 => 28x28x8\n", " deconv3 = tf.layers.conv2d_transpose(deconv2, filters=8, \n", " kernel_size=(3, 3), strides=(1, 1), \n", " padding='same',\n", " use_bias=False,\n", " activation=None)\n", " deconv3 = tf.layers.batch_normalization(deconv3, training=is_training)\n", " deconv3 = leaky_relu(deconv3) \n", " deconv3 = tf.layers.dropout(deconv3, rate=dropout_rate, training=is_training)\n", " \n", " # 28x28x8 => 28x28x1\n", " gen_logits = tf.layers.conv2d_transpose(deconv3, filters=1, \n", " kernel_size=(3, 3), strides=(1, 1), \n", " padding='same',\n", " bias_initializer=None,\n", " activation=None)\n", " gen_out = tf.tanh(gen_logits, 'generator_outputs')\n", "\n", "\n", " ######################\n", " # Discriminator Model\n", " ######################\n", " \n", " def build_discriminator_graph(input_x, reuse=None):\n", "\n", " with tf.variable_scope('discriminator', reuse=reuse):\n", " \n", " # 28x28x1 => 14x14x8\n", " conv_input = tf.reshape(input_x, (-1, 28, 28, 1))\n", " conv1 = tf.layers.conv2d(conv_input, filters=8, kernel_size=(3, 3),\n", " strides=(2, 2), padding='same',\n", " use_bias=False,\n", " activation=None)\n", " conv1 = tf.layers.batch_normalization(conv1, training=is_training)\n", " conv1 = leaky_relu(conv1)\n", " conv1 = tf.layers.dropout(conv1, rate=dropout_rate, training=is_training)\n", " \n", " # 14x14x8 => 7x7x32\n", " conv2 = tf.layers.conv2d(conv1, filters=32, kernel_size=(3, 3),\n", " strides=(2, 2), padding='same',\n", " use_bias=False,\n", " activation=None)\n", " conv2 = tf.layers.batch_normalization(conv2, training=is_training)\n", " conv2 = leaky_relu(conv2)\n", " conv2 = tf.layers.dropout(conv2, rate=dropout_rate, training=is_training)\n", "\n", " # fully connected layer\n", " fc_input = tf.reshape(conv2, (-1, 
7*7*32))\n", " logits = tf.layers.dense(inputs=fc_input, units=1, activation=None)\n", " out = tf.sigmoid(logits)\n", " \n", " return logits, out \n", "\n", " # Apply the same discriminator (weights shared via reuse=True) to real data and to fake data\n", " dis_real_logits, dis_real_out = build_discriminator_graph(dis_x, reuse=False)\n", " dis_fake_logits, dis_fake_out = build_discriminator_graph(gen_out, reuse=True)\n", "\n", "\n", " #####################################\n", " # Generator and Discriminator Losses\n", " #####################################\n", " \n", " # Two discriminator cost components: loss on real data + loss on fake data\n", " # Real data has class label 1, fake data has class label 0\n", " dis_real_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_real_logits, \n", " labels=tf.ones_like(dis_real_logits))\n", " dis_fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits, \n", " labels=tf.zeros_like(dis_fake_logits))\n", " dis_cost = tf.add(tf.reduce_mean(dis_fake_loss), \n", " tf.reduce_mean(dis_real_loss), \n", " name='discriminator_cost')\n", " \n", " # Generator cost: cross-entropy between the discriminator's predictions on fake images and the \"real\" label 1\n", " gen_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=dis_fake_logits,\n", " labels=tf.ones_like(dis_fake_logits))\n", " gen_cost = tf.reduce_mean(gen_loss, name='generator_cost')\n", " \n", " \n", " #########################################\n", " # Generator and Discriminator Optimizers\n", " #########################################\n", " \n", " dis_optimizer = tf.train.AdamOptimizer(learning_rate)\n", " dis_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='discriminator')\n", " dis_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='discriminator')\n", " \n", " with tf.control_dependencies(dis_update_ops): # required to upd. 
batch_norm params\n", " dis_train = dis_optimizer.minimize(dis_cost, var_list=dis_train_vars,\n", " name='train_discriminator')\n", " \n", " gen_optimizer = tf.train.AdamOptimizer(learning_rate)\n", " gen_train_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')\n", " gen_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='generator')\n", " \n", " with tf.control_dependencies(gen_update_ops): # required to upd. batch_norm params\n", " gen_train = gen_optimizer.minimize(gen_cost, var_list=gen_train_vars,\n", " name='train_generator')\n", " \n", " # Saver to save session for reuse\n", " saver = tf.train.Saver()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Minibatch: 0001 | Dis/Gen Cost: 1.630/0.866\n", "Minibatch: 0201 | Dis/Gen Cost: 0.850/1.879\n", "Minibatch: 0401 | Dis/Gen Cost: 0.606/2.467\n", "Minibatch: 0601 | Dis/Gen Cost: 0.695/1.661\n", "Minibatch: 0801 | Dis/Gen Cost: 1.149/1.297\n", "Epoch: 0001 | Dis/Gen AvgCost: 0.820/1.887\n", "Minibatch: 0001 | Dis/Gen Cost: 0.707/1.486\n", "Minibatch: 0201 | Dis/Gen Cost: 0.924/1.438\n", "Minibatch: 0401 | Dis/Gen Cost: 0.751/1.508\n", "Minibatch: 0601 | Dis/Gen Cost: 0.899/1.611\n", "Minibatch: 0801 | Dis/Gen Cost: 0.914/1.535\n", "Epoch: 0002 | Dis/Gen AvgCost: 0.954/1.510\n", "Minibatch: 0001 | Dis/Gen Cost: 0.498/1.955\n", "Minibatch: 0201 | Dis/Gen Cost: 0.757/1.670\n", "Minibatch: 0401 | Dis/Gen Cost: 1.100/1.204\n", "Minibatch: 0601 | Dis/Gen Cost: 0.656/2.054\n", "Minibatch: 0801 | Dis/Gen Cost: 1.036/1.174\n", "Epoch: 0003 | Dis/Gen AvgCost: 0.784/1.720\n", "Minibatch: 0001 | Dis/Gen Cost: 1.576/0.992\n", "Minibatch: 0201 | Dis/Gen Cost: 0.663/2.002\n", "Minibatch: 0401 | Dis/Gen Cost: 0.869/1.773\n", "Minibatch: 0601 | Dis/Gen Cost: 0.675/1.772\n", "Minibatch: 0801 | Dis/Gen Cost: 0.881/1.489\n", "Epoch: 0004 | Dis/Gen AvgCost: 0.898/1.575\n", "Minibatch: 0001 | Dis/Gen Cost: 
1.201/1.386\n", "Minibatch: 0201 | Dis/Gen Cost: 1.245/1.606\n", "Minibatch: 0401 | Dis/Gen Cost: 1.281/1.015\n", "Minibatch: 0601 | Dis/Gen Cost: 0.925/1.124\n", "Minibatch: 0801 | Dis/Gen Cost: 1.126/1.634\n", "Epoch: 0005 | Dis/Gen AvgCost: 1.037/1.435\n", "Minibatch: 0001 | Dis/Gen Cost: 0.853/1.626\n", "Minibatch: 0201 | Dis/Gen Cost: 1.204/0.929\n", "Minibatch: 0401 | Dis/Gen Cost: 1.070/1.365\n", "Minibatch: 0601 | Dis/Gen Cost: 1.366/0.927\n", "Minibatch: 0801 | Dis/Gen Cost: 1.253/1.500\n", "Epoch: 0006 | Dis/Gen AvgCost: 1.168/1.186\n", "Minibatch: 0001 | Dis/Gen Cost: 1.590/0.945\n", "Minibatch: 0201 | Dis/Gen Cost: 0.822/1.563\n", "Minibatch: 0401 | Dis/Gen Cost: 0.894/1.410\n", "Minibatch: 0601 | Dis/Gen Cost: 1.292/1.131\n", "Minibatch: 0801 | Dis/Gen Cost: 1.361/1.005\n", "Epoch: 0007 | Dis/Gen AvgCost: 1.248/1.103\n", "Minibatch: 0001 | Dis/Gen Cost: 1.860/0.697\n", "Minibatch: 0201 | Dis/Gen Cost: 1.291/0.986\n", "Minibatch: 0401 | Dis/Gen Cost: 1.097/0.934\n", "Minibatch: 0601 | Dis/Gen Cost: 1.316/0.788\n", "Minibatch: 0801 | Dis/Gen Cost: 1.437/0.885\n", "Epoch: 0008 | Dis/Gen AvgCost: 1.298/0.995\n", "Minibatch: 0001 | Dis/Gen Cost: 1.150/1.072\n", "Minibatch: 0201 | Dis/Gen Cost: 1.177/1.148\n", "Minibatch: 0401 | Dis/Gen Cost: 1.351/0.884\n", "Minibatch: 0601 | Dis/Gen Cost: 1.434/0.797\n", "Minibatch: 0801 | Dis/Gen Cost: 1.291/0.929\n", "Epoch: 0009 | Dis/Gen AvgCost: 1.333/0.968\n", "Minibatch: 0001 | Dis/Gen Cost: 1.324/0.764\n", "Minibatch: 0201 | Dis/Gen Cost: 1.255/0.942\n", "Minibatch: 0401 | Dis/Gen Cost: 1.181/1.007\n", "Minibatch: 0601 | Dis/Gen Cost: 1.132/1.134\n", "Minibatch: 0801 | Dis/Gen Cost: 1.170/1.249\n", "Epoch: 0010 | Dis/Gen AvgCost: 1.328/0.922\n", "Minibatch: 0001 | Dis/Gen Cost: 1.539/0.739\n", "Minibatch: 0201 | Dis/Gen Cost: 1.181/1.186\n", "Minibatch: 0401 | Dis/Gen Cost: 1.014/1.331\n", "Minibatch: 0601 | Dis/Gen Cost: 1.380/0.884\n", "Minibatch: 0801 | Dis/Gen Cost: 1.441/0.893\n", "Epoch: 0011 | Dis/Gen 
AvgCost: 1.306/0.949\n", "Minibatch: 0001 | Dis/Gen Cost: 1.248/0.953\n", "Minibatch: 0201 | Dis/Gen Cost: 1.421/0.751\n", "Minibatch: 0401 | Dis/Gen Cost: 1.323/0.891\n", "Minibatch: 0601 | Dis/Gen Cost: 1.363/0.912\n", "Minibatch: 0801 | Dis/Gen Cost: 1.174/1.112\n", "Epoch: 0012 | Dis/Gen AvgCost: 1.334/0.931\n", "Minibatch: 0001 | Dis/Gen Cost: 1.463/0.792\n", "Minibatch: 0201 | Dis/Gen Cost: 1.296/0.992\n", "Minibatch: 0401 | Dis/Gen Cost: 1.213/1.037\n", "Minibatch: 0601 | Dis/Gen Cost: 1.273/0.899\n", "Minibatch: 0801 | Dis/Gen Cost: 1.282/0.893\n", "Epoch: 0013 | Dis/Gen AvgCost: 1.323/0.910\n", "Minibatch: 0001 | Dis/Gen Cost: 1.192/0.921\n", "Minibatch: 0201 | Dis/Gen Cost: 1.287/0.933\n", "Minibatch: 0401 | Dis/Gen Cost: 1.292/0.898\n", "Minibatch: 0601 | Dis/Gen Cost: 1.164/0.945\n", "Minibatch: 0801 | Dis/Gen Cost: 1.469/0.776\n", "Epoch: 0014 | Dis/Gen AvgCost: 1.312/0.890\n", "Minibatch: 0001 | Dis/Gen Cost: 1.363/0.876\n", "Minibatch: 0201 | Dis/Gen Cost: 1.398/0.759\n", "Minibatch: 0401 | Dis/Gen Cost: 1.099/1.088\n", "Minibatch: 0601 | Dis/Gen Cost: 1.415/0.831\n", "Minibatch: 0801 | Dis/Gen Cost: 1.287/0.813\n", "Epoch: 0015 | Dis/Gen AvgCost: 1.310/0.896\n", "Minibatch: 0001 | Dis/Gen Cost: 1.309/0.910\n", "Minibatch: 0201 | Dis/Gen Cost: 1.397/0.829\n", "Minibatch: 0401 | Dis/Gen Cost: 1.221/0.949\n", "Minibatch: 0601 | Dis/Gen Cost: 1.284/0.918\n", "Minibatch: 0801 | Dis/Gen Cost: 1.315/0.737\n", "Epoch: 0016 | Dis/Gen AvgCost: 1.306/0.860\n", "Minibatch: 0001 | Dis/Gen Cost: 1.193/0.901\n", "Minibatch: 0201 | Dis/Gen Cost: 1.339/0.908\n", "Minibatch: 0401 | Dis/Gen Cost: 1.119/0.969\n", "Minibatch: 0601 | Dis/Gen Cost: 1.293/0.907\n", "Minibatch: 0801 | Dis/Gen Cost: 1.368/0.882\n", "Epoch: 0017 | Dis/Gen AvgCost: 1.320/0.892\n", "Minibatch: 0001 | Dis/Gen Cost: 1.308/1.014\n", "Minibatch: 0201 | Dis/Gen Cost: 1.194/0.936\n", "Minibatch: 0401 | Dis/Gen Cost: 1.536/0.755\n", "Minibatch: 0601 | Dis/Gen Cost: 1.443/0.810\n", "Minibatch: 0801 | 
Dis/Gen Cost: 1.288/0.730\n", "Epoch: 0018 | Dis/Gen AvgCost: 1.315/0.867\n", "Minibatch: 0001 | Dis/Gen Cost: 1.259/0.979\n", "Minibatch: 0201 | Dis/Gen Cost: 1.307/0.822\n", "Minibatch: 0401 | Dis/Gen Cost: 1.242/0.845\n", "Minibatch: 0601 | Dis/Gen Cost: 1.422/0.891\n", "Minibatch: 0801 | Dis/Gen Cost: 1.263/0.904\n", "Epoch: 0019 | Dis/Gen AvgCost: 1.306/0.866\n", "Minibatch: 0001 | Dis/Gen Cost: 1.204/0.811\n", "Minibatch: 0201 | Dis/Gen Cost: 1.340/0.810\n", "Minibatch: 0401 | Dis/Gen Cost: 1.278/0.963\n", "Minibatch: 0601 | Dis/Gen Cost: 1.249/0.936\n", "Minibatch: 0801 | Dis/Gen Cost: 1.285/0.945\n", "Epoch: 0020 | Dis/Gen AvgCost: 1.316/0.853\n", "Minibatch: 0001 | Dis/Gen Cost: 1.370/0.772\n", "Minibatch: 0201 | Dis/Gen Cost: 1.478/0.762\n", "Minibatch: 0401 | Dis/Gen Cost: 1.440/0.822\n", "Minibatch: 0601 | Dis/Gen Cost: 1.269/0.809\n", "Minibatch: 0801 | Dis/Gen Cost: 1.260/0.923\n", "Epoch: 0021 | Dis/Gen AvgCost: 1.324/0.837\n", "Minibatch: 0001 | Dis/Gen Cost: 1.401/0.892\n", "Minibatch: 0201 | Dis/Gen Cost: 1.361/0.762\n", "Minibatch: 0401 | Dis/Gen Cost: 1.121/1.012\n", "Minibatch: 0601 | Dis/Gen Cost: 1.366/0.822\n", "Minibatch: 0801 | Dis/Gen Cost: 1.484/0.744\n", "Epoch: 0022 | Dis/Gen AvgCost: 1.314/0.851\n", "Minibatch: 0001 | Dis/Gen Cost: 1.207/0.829\n", "Minibatch: 0201 | Dis/Gen Cost: 1.320/0.786\n", "Minibatch: 0401 | Dis/Gen Cost: 1.327/0.807\n", "Minibatch: 0601 | Dis/Gen Cost: 1.250/0.909\n", "Minibatch: 0801 | Dis/Gen Cost: 1.339/0.769\n", "Epoch: 0023 | Dis/Gen AvgCost: 1.323/0.833\n", "Minibatch: 0001 | Dis/Gen Cost: 1.363/0.825\n", "Minibatch: 0201 | Dis/Gen Cost: 1.416/0.738\n", "Minibatch: 0401 | Dis/Gen Cost: 1.290/0.876\n", "Minibatch: 0601 | Dis/Gen Cost: 1.257/0.825\n", "Minibatch: 0801 | Dis/Gen Cost: 1.510/0.633\n", "Epoch: 0024 | Dis/Gen AvgCost: 1.323/0.841\n", "Minibatch: 0001 | Dis/Gen Cost: 1.291/0.694\n", "Minibatch: 0201 | Dis/Gen Cost: 1.400/0.720\n", "Minibatch: 0401 | Dis/Gen Cost: 1.340/0.802\n", "Minibatch: 0601 
| Dis/Gen Cost: 1.339/0.784\n", "Minibatch: 0801 | Dis/Gen Cost: 1.211/0.886\n", "Epoch: 0025 | Dis/Gen AvgCost: 1.339/0.811\n", "Minibatch: 0001 | Dis/Gen Cost: 1.395/0.865\n", "Minibatch: 0201 | Dis/Gen Cost: 1.400/0.823\n", "Minibatch: 0401 | Dis/Gen Cost: 1.357/0.811\n", "Minibatch: 0601 | Dis/Gen Cost: 1.404/0.741\n", "Minibatch: 0801 | Dis/Gen Cost: 1.298/0.930\n", "Epoch: 0026 | Dis/Gen AvgCost: 1.340/0.819\n", "Minibatch: 0001 | Dis/Gen Cost: 1.257/0.833\n", "Minibatch: 0201 | Dis/Gen Cost: 1.359/0.772\n", "Minibatch: 0401 | Dis/Gen Cost: 1.453/0.798\n", "Minibatch: 0601 | Dis/Gen Cost: 1.389/0.853\n", "Minibatch: 0801 | Dis/Gen Cost: 1.447/0.754\n", "Epoch: 0027 | Dis/Gen AvgCost: 1.340/0.808\n", "Minibatch: 0001 | Dis/Gen Cost: 1.353/0.764\n", "Minibatch: 0201 | Dis/Gen Cost: 1.353/0.811\n", "Minibatch: 0401 | Dis/Gen Cost: 1.458/0.748\n", "Minibatch: 0601 | Dis/Gen Cost: 1.448/0.753\n", "Minibatch: 0801 | Dis/Gen Cost: 1.475/0.696\n", "Epoch: 0028 | Dis/Gen AvgCost: 1.349/0.792\n", "Minibatch: 0001 | Dis/Gen Cost: 1.271/0.932\n", "Minibatch: 0201 | Dis/Gen Cost: 1.294/0.894\n", "Minibatch: 0401 | Dis/Gen Cost: 1.156/0.866\n", "Minibatch: 0601 | Dis/Gen Cost: 1.292/0.778\n", "Minibatch: 0801 | Dis/Gen Cost: 1.309/0.817\n", "Epoch: 0029 | Dis/Gen AvgCost: 1.347/0.799\n", "Minibatch: 0001 | Dis/Gen Cost: 1.459/0.727\n", "Minibatch: 0201 | Dis/Gen Cost: 1.396/0.753\n", "Minibatch: 0401 | Dis/Gen Cost: 1.367/0.754\n", "Minibatch: 0601 | Dis/Gen Cost: 1.336/0.785\n", "Minibatch: 0801 | Dis/Gen Cost: 1.304/0.756\n", "Epoch: 0030 | Dis/Gen AvgCost: 1.347/0.780\n", "Minibatch: 0001 | Dis/Gen Cost: 1.431/0.726\n", "Minibatch: 0201 | Dis/Gen Cost: 1.348/0.793\n", "Minibatch: 0401 | Dis/Gen Cost: 1.102/0.823\n", "Minibatch: 0601 | Dis/Gen Cost: 1.276/0.772\n", "Minibatch: 0801 | Dis/Gen Cost: 1.390/0.776\n", "Epoch: 0031 | Dis/Gen AvgCost: 1.337/0.801\n", "Minibatch: 0001 | Dis/Gen Cost: 1.507/0.704\n", "Minibatch: 0201 | Dis/Gen Cost: 1.295/0.873\n", "Minibatch: 
0401 | Dis/Gen Cost: 1.312/0.835\n", "Minibatch: 0601 | Dis/Gen Cost: 1.346/0.842\n", "Minibatch: 0801 | Dis/Gen Cost: 1.328/0.721\n", "Epoch: 0032 | Dis/Gen AvgCost: 1.342/0.792\n", "Minibatch: 0001 | Dis/Gen Cost: 1.401/0.717\n", "Minibatch: 0201 | Dis/Gen Cost: 1.436/0.737\n", "Minibatch: 0401 | Dis/Gen Cost: 1.332/0.774\n", "Minibatch: 0601 | Dis/Gen Cost: 1.311/0.804\n", "Minibatch: 0801 | Dis/Gen Cost: 1.391/0.650\n", "Epoch: 0033 | Dis/Gen AvgCost: 1.352/0.783\n", "Minibatch: 0001 | Dis/Gen Cost: 1.317/0.740\n", "Minibatch: 0201 | Dis/Gen Cost: 1.343/0.810\n", "Minibatch: 0401 | Dis/Gen Cost: 1.394/0.717\n", "Minibatch: 0601 | Dis/Gen Cost: 1.455/0.779\n", "Minibatch: 0801 | Dis/Gen Cost: 1.445/0.704\n", "Epoch: 0034 | Dis/Gen AvgCost: 1.348/0.785\n", "Minibatch: 0001 | Dis/Gen Cost: 1.294/0.791\n", "Minibatch: 0201 | Dis/Gen Cost: 1.277/0.886\n", "Minibatch: 0401 | Dis/Gen Cost: 1.349/0.721\n", "Minibatch: 0601 | Dis/Gen Cost: 1.297/0.717\n", "Minibatch: 0801 | Dis/Gen Cost: 1.320/0.777\n", "Epoch: 0035 | Dis/Gen AvgCost: 1.353/0.780\n", "Minibatch: 0001 | Dis/Gen Cost: 1.338/0.756\n", "Minibatch: 0201 | Dis/Gen Cost: 1.273/0.778\n", "Minibatch: 0401 | Dis/Gen Cost: 1.325/0.865\n", "Minibatch: 0601 | Dis/Gen Cost: 1.438/0.717\n", "Minibatch: 0801 | Dis/Gen Cost: 1.328/0.785\n", "Epoch: 0036 | Dis/Gen AvgCost: 1.352/0.770\n", "Minibatch: 0001 | Dis/Gen Cost: 1.375/0.764\n", "Minibatch: 0201 | Dis/Gen Cost: 1.453/0.723\n", "Minibatch: 0401 | Dis/Gen Cost: 1.270/0.807\n", "Minibatch: 0601 | Dis/Gen Cost: 1.392/0.775\n", "Minibatch: 0801 | Dis/Gen Cost: 1.318/0.824\n", "Epoch: 0037 | Dis/Gen AvgCost: 1.353/0.773\n", "Minibatch: 0001 | Dis/Gen Cost: 1.270/0.874\n", "Minibatch: 0201 | Dis/Gen Cost: 1.214/0.833\n", "Minibatch: 0401 | Dis/Gen Cost: 1.456/0.666\n", "Minibatch: 0601 | Dis/Gen Cost: 1.400/0.824\n", "Minibatch: 0801 | Dis/Gen Cost: 1.328/0.736\n", "Epoch: 0038 | Dis/Gen AvgCost: 1.354/0.776\n", "Minibatch: 0001 | Dis/Gen Cost: 1.332/0.743\n", 
"Minibatch: 0201 | Dis/Gen Cost: 1.389/0.710\n", "Minibatch: 0401 | Dis/Gen Cost: 1.375/0.708\n", "Minibatch: 0601 | Dis/Gen Cost: 1.296/0.758\n", "Minibatch: 0801 | Dis/Gen Cost: 1.337/0.783\n", "Epoch: 0039 | Dis/Gen AvgCost: 1.356/0.765\n", "Minibatch: 0001 | Dis/Gen Cost: 1.388/0.706\n", "Minibatch: 0201 | Dis/Gen Cost: 1.371/0.712\n", "Minibatch: 0401 | Dis/Gen Cost: 1.349/0.698\n", "Minibatch: 0601 | Dis/Gen Cost: 1.380/0.723\n", "Minibatch: 0801 | Dis/Gen Cost: 1.371/0.746\n", "Epoch: 0040 | Dis/Gen AvgCost: 1.358/0.759\n", "Minibatch: 0001 | Dis/Gen Cost: 1.349/0.702\n", "Minibatch: 0201 | Dis/Gen Cost: 1.315/0.742\n", "Minibatch: 0401 | Dis/Gen Cost: 1.353/0.760\n", "Minibatch: 0601 | Dis/Gen Cost: 1.335/0.799\n", "Minibatch: 0801 | Dis/Gen Cost: 1.403/0.726\n", "Epoch: 0041 | Dis/Gen AvgCost: 1.362/0.755\n", "Minibatch: 0001 | Dis/Gen Cost: 1.363/0.782\n", "Minibatch: 0201 | Dis/Gen Cost: 1.335/0.742\n", "Minibatch: 0401 | Dis/Gen Cost: 1.344/0.751\n", "Minibatch: 0601 | Dis/Gen Cost: 1.338/0.740\n", "Minibatch: 0801 | Dis/Gen Cost: 1.460/0.735\n", "Epoch: 0042 | Dis/Gen AvgCost: 1.361/0.764\n", "Minibatch: 0001 | Dis/Gen Cost: 1.308/0.767\n", "Minibatch: 0201 | Dis/Gen Cost: 1.367/0.764\n", "Minibatch: 0401 | Dis/Gen Cost: 1.382/0.764\n", "Minibatch: 0601 | Dis/Gen Cost: 1.419/0.625\n", "Minibatch: 0801 | Dis/Gen Cost: 1.393/0.777\n", "Epoch: 0043 | Dis/Gen AvgCost: 1.361/0.753\n", "Minibatch: 0001 | Dis/Gen Cost: 1.413/0.749\n", "Minibatch: 0201 | Dis/Gen Cost: 1.370/0.724\n", "Minibatch: 0401 | Dis/Gen Cost: 1.314/0.756\n", "Minibatch: 0601 | Dis/Gen Cost: 1.321/0.763\n", "Minibatch: 0801 | Dis/Gen Cost: 1.354/0.771\n", "Epoch: 0044 | Dis/Gen AvgCost: 1.364/0.752\n", "Minibatch: 0001 | Dis/Gen Cost: 1.363/0.748\n", "Minibatch: 0201 | Dis/Gen Cost: 1.365/0.727\n", "Minibatch: 0401 | Dis/Gen Cost: 1.439/0.714\n", "Minibatch: 0601 | Dis/Gen Cost: 1.429/0.696\n", "Minibatch: 0801 | Dis/Gen Cost: 1.427/0.699\n", "Epoch: 0045 | Dis/Gen AvgCost: 
1.363/0.745\n", "Minibatch: 0001 | Dis/Gen Cost: 1.398/0.713\n", "Minibatch: 0201 | Dis/Gen Cost: 1.408/0.717\n", "Minibatch: 0401 | Dis/Gen Cost: 1.298/0.734\n", "Minibatch: 0601 | Dis/Gen Cost: 1.345/0.805\n", "Minibatch: 0801 | Dis/Gen Cost: 1.331/0.828\n", "Epoch: 0046 | Dis/Gen AvgCost: 1.366/0.752\n", "Minibatch: 0001 | Dis/Gen Cost: 1.319/0.751\n", "Minibatch: 0201 | Dis/Gen Cost: 1.482/0.713\n", "Minibatch: 0401 | Dis/Gen Cost: 1.341/0.803\n", "Minibatch: 0601 | Dis/Gen Cost: 1.386/0.651\n", "Minibatch: 0801 | Dis/Gen Cost: 1.428/0.701\n", "Epoch: 0047 | Dis/Gen AvgCost: 1.369/0.758\n", "Minibatch: 0001 | Dis/Gen Cost: 1.378/0.747\n", "Minibatch: 0201 | Dis/Gen Cost: 1.355/0.716\n", "Minibatch: 0401 | Dis/Gen Cost: 1.357/0.686\n", "Minibatch: 0601 | Dis/Gen Cost: 1.333/0.767\n", "Minibatch: 0801 | Dis/Gen Cost: 1.380/0.712\n", "Epoch: 0048 | Dis/Gen AvgCost: 1.370/0.735\n", "Minibatch: 0001 | Dis/Gen Cost: 1.409/0.706\n", "Minibatch: 0201 | Dis/Gen Cost: 1.307/0.789\n", "Minibatch: 0401 | Dis/Gen Cost: 1.396/0.731\n", "Minibatch: 0601 | Dis/Gen Cost: 1.375/0.711\n", "Minibatch: 0801 | Dis/Gen Cost: 1.365/0.782\n", "Epoch: 0049 | Dis/Gen AvgCost: 1.371/0.733\n", "Minibatch: 0001 | Dis/Gen Cost: 1.409/0.701\n", "Minibatch: 0201 | Dis/Gen Cost: 1.369/0.728\n", "Minibatch: 0401 | Dis/Gen Cost: 1.315/0.730\n", "Minibatch: 0601 | Dis/Gen Cost: 1.321/0.774\n", "Minibatch: 0801 | Dis/Gen Cost: 1.336/0.735\n", "Epoch: 0050 | Dis/Gen AvgCost: 1.372/0.735\n" ] } ], "source": [ "##########################\n", "### TRAINING & EVALUATION\n", "##########################\n", "\n", "with tf.Session(graph=g) as sess:\n", " sess.run(tf.global_variables_initializer())\n", " \n", " avg_costs = {'discriminator': [], 'generator': []}\n", "\n", " for epoch in range(training_epochs):\n", " dis_avg_cost, gen_avg_cost = 0., 0.\n", " total_batch = mnist.train.num_examples // batch_size\n", "\n", " for i in range(total_batch):\n", " \n", " batch_x, batch_y = 
mnist.train.next_batch(batch_size)\n", " batch_x = batch_x*2 - 1 # rescale pixel values from [0, 1] to [-1, 1]\n", " batch_randsample = np.random.uniform(-1, 1, size=(batch_size, gen_input_size))\n", " \n", " # Train\n", " \n", " _, dc = sess.run(['train_discriminator', 'discriminator_cost:0'],\n", " feed_dict={'discriminator_inputs:0': batch_x, \n", " 'generator_inputs:0': batch_randsample,\n", " 'dropout:0': dropout_rate,\n", " 'is_training:0': True})\n", " \n", " _, gc = sess.run(['train_generator', 'generator_cost:0'],\n", " feed_dict={'generator_inputs:0': batch_randsample,\n", " 'dropout:0': dropout_rate,\n", " 'is_training:0': True})\n", " \n", " dis_avg_cost += dc\n", " gen_avg_cost += gc\n", "\n", " if not i % print_interval:\n", " print(\"Minibatch: %04d | Dis/Gen Cost: %.3f/%.3f\" % (i + 1, dc, gc))\n", " \n", "\n", " print(\"Epoch: %04d | Dis/Gen AvgCost: %.3f/%.3f\" % \n", " (epoch + 1, dis_avg_cost / total_batch, gen_avg_cost / total_batch))\n", " \n", " avg_costs['discriminator'].append(dis_avg_cost / total_batch)\n", " avg_costs['generator'].append(gen_avg_cost / total_batch)\n", " \n", " \n", " saver.save(sess, save_path='./gan-conv.ckpt')" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "plt.plot(range(len(avg_costs['discriminator'])), \n", " avg_costs['discriminator'], label='discriminator')\n", "plt.plot(range(len(avg_costs['generator'])),\n", " avg_costs['generator'], label='generator')\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/raschka/miniconda3/lib/python3.7/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use standard file APIs to check for files with this prefix.\n", "INFO:tensorflow:Restoring parameters from ./gan-conv.ckpt\n" ] }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "####################################\n", "### RELOAD & GENERATE SAMPLE IMAGES\n", "####################################\n", "\n", "\n", "n_examples = 25\n", "\n", "with tf.Session(graph=g) as sess:\n", " saver.restore(sess, save_path='./gan-conv.ckpt')\n", "\n", " batch_randsample = np.random.uniform(-1, 1, size=(n_examples, gen_input_size))\n", " new_examples = sess.run('generator/generator_outputs:0',\n", " feed_dict={'generator_inputs:0': batch_randsample,\n", " 'dropout:0': 0.0,\n", " 'is_training:0': False})\n", "\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(8, 8),\n", " sharey=True, sharex=True)\n", "\n", "for image, ax in zip(new_examples, axes.flatten()):\n", " ax.imshow(image.reshape((dis_input_size // 28, dis_input_size // 28)), cmap='binary')\n", "\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 4 }