{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "25P38JgWSYbZ" }, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timsainb/tensorflow2-generative-models/blob/master/3.0-WGAN-GP-fashion-mnist.ipynb)\n", "\n", "## Wasserstein GAN with Gradient Penalty (WGAN-GP) ([WGAN article](https://arxiv.org/abs/1701.07875), [WGAN-GP article](https://arxiv.org/abs/1704.00028))\n", "\n", "WGAN-GP improves training stability over the original GAN by replacing its loss with an estimate of the Wasserstein distance and by enforcing the critic's 1-Lipschitz constraint with a gradient penalty instead of the weight clipping used in the original WGAN.\n", "\n", "![wgan gp](https://github.com/timsainb/tensorflow2-generative-models/blob/f3360a819b5773692e943dfe181972a76b9d91bb/imgs/gan.png?raw=1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "DoEPSlfmSYbc" }, "source": [ "### Install packages if in Colab" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:29.973887Z", "start_time": "2019-05-14T06:31:29.969185Z" }, "colab": {}, "colab_type": "code", "id": "WbqrTgB_SYbf" }, "outputs": [], "source": [ "### install necessary packages if in colab\n", "def run_subprocess_command(cmd):\n", "    process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE)\n", "    for line in process.stdout:\n", "        print(line.decode().strip())\n", "\n", "import sys, subprocess\n", "IN_COLAB = 'google.colab' in sys.modules\n", "colab_requirements = ['pip install tf-nightly-gpu-2.0-preview==2.0.0.dev20190513']\n", "if IN_COLAB:\n", "    for i in colab_requirements:\n", "        run_subprocess_command(i)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "3eKFKF5HSYbi" }, "source": [ "### Load packages" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "UjL-sOZzSYbj" }, "outputs": [], "source": [ "" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:30.061880Z", "start_time": "2019-05-14T06:31:29.975587Z" }, 
"colab": {}, "colab_type": "code", "id": "at1xYevFSYbl", "outputId": "d70e29a0-b0d0-416b-b163-28974fab61fa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "env: CUDA_VISIBLE_DEVICES=3\n" ] } ], "source": [ "# make only one GPU (device 3) visible\n", "%env CUDA_VISIBLE_DEVICES=3" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:33.702580Z", "start_time": "2019-05-14T06:31:30.063437Z" }, "colab": {}, "colab_type": "code", "id": "759gzUFlSYbq", "outputId": "d2ec559a-bbb8-4785-8e42-f355c270fbce" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/mnt/cube/tsainbur/conda_envs/tpy3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", " \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n" ] } ], "source": [ "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from tqdm.autonotebook import tqdm\n", "%matplotlib inline\n", "from IPython import display\n", "import pandas as pd" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:33.711214Z", "start_time": "2019-05-14T06:31:33.706313Z" }, "colab": {}, "colab_type": "code", "id": "AxY3I4SfSYbt", "outputId": "64769acf-bcf3-4ab4-d753-a69ec24b2ee5" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.0.0-dev20190513\n" ] } ], "source": [ "print(tf.__version__)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "LdCkp6ybSYbw" }, "source": [ "### Create a Fashion-MNIST dataset" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:33.803523Z", "start_time": "2019-05-14T06:31:33.714599Z" }, "colab": {}, "colab_type": "code", "id": "Ypym6ZAESYbx" }, "outputs": [], "source": [ 
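"# Batch arithmetic: int() truncates, so each epoch visits only full batches.\n", "# int(60000 / 512) = 117 training batches; the final partial batch of 96 images is never visited.\n", "# int(10000 / 512) = 19 test batches; the final partial batch of 272 images is never visited.\n", "\n", 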
"TRAIN_BUF = 60000\n", "BATCH_SIZE = 512\n", "TEST_BUF = 10000\n", "DIMS = (28, 28, 1)\n", "N_TRAIN_BATCHES = int(TRAIN_BUF / BATCH_SIZE)\n", "N_TEST_BATCHES = int(TEST_BUF / BATCH_SIZE)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:38.044471Z", "start_time": "2019-05-14T06:31:33.805821Z" }, "colab": {}, "colab_type": "code", "id": "xhqU6sqiSYbz" }, "outputs": [], "source": [ "# load dataset\n", "(train_images, _), (test_images, _) = tf.keras.datasets.fashion_mnist.load_data()\n", "\n", "# reshape to (N, 28, 28, 1) and scale pixel values to [0, 1]\n", "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype(\n", "    \"float32\"\n", ") / 255.0\n", "test_images = test_images.reshape(test_images.shape[0], 28, 28, 1).astype(\"float32\") / 255.0\n", "\n", "# shuffle and batch datasets\n", "train_dataset = (\n", "    tf.data.Dataset.from_tensor_slices(train_images)\n", "    .shuffle(TRAIN_BUF)\n", "    .batch(BATCH_SIZE)\n", ")\n", "test_dataset = (\n", "    tf.data.Dataset.from_tensor_slices(test_images)\n", "    .shuffle(TEST_BUF)\n", "    .batch(BATCH_SIZE)\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "HLxPlL7QSYb1" }, "source": [ "### Define the network as a tf.keras.Model object" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:38.068468Z", "start_time": "2019-05-14T06:31:38.046751Z" }, "colab": {}, "colab_type": "code", "id": "Wyipg-4oSYb1" }, "outputs": [], "source": [ "class WGAN(tf.keras.Model):\n", "    \"\"\"Wasserstein GAN with gradient penalty (WGAN-GP).\n", "    I used github/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Tensorflow-2/ as a reference on this.\n", "\n", "    Extends:\n", "        tf.keras.Model\n", "    \"\"\"\n", "\n", "    def __init__(self, **kwargs):\n", "        super(WGAN, self).__init__()\n", "        self.__dict__.update(kwargs)\n", "\n", "        self.gen = tf.keras.Sequential(self.gen)\n", "        self.disc = tf.keras.Sequential(self.disc)\n", "\n", "    def generate(self, z):\n", "        return self.gen(z)\n", "\n", "    def discriminate(self, x):\n", "        return self.disc(x)\n", "\n", "    def compute_loss(self, x):\n", "        \"\"\" passes through the network and computes the critic and generator losses\n", "        \"\"\"\n", "        ### pass through network\n", "        # sample latent noise from a standard normal distribution\n", "        z_samp = tf.random.normal([x.shape[0], 1, 1, self.n_Z])\n", "\n", "        # run noise through generator\n", "        x_gen = self.generate(z_samp)\n", "        # critic scores for real and generated samples\n", "        logits_x = self.discriminate(x)\n", "        logits_x_gen = self.discriminate(x_gen)\n", "\n", "        # gradient penalty\n", "        d_regularizer = self.gradient_penalty(x, x_gen)\n", "        ### losses\n", "        # the critic minimizes mean(D(x)) - mean(D(G(z))) + penalty,\n", "        # i.e. it learns to score real samples low and generated samples high\n", "        disc_loss = (\n", "            tf.reduce_mean(logits_x)\n", "            - tf.reduce_mean(logits_x_gen)\n", "            + d_regularizer * self.gradient_penalty_weight\n", "        )\n", "\n", "        # the generator minimizes the critic's score on generated samples,\n", "        # pushing them toward the low scores given to real samples\n", "        gen_loss = tf.reduce_mean(logits_x_gen)\n", "\n", "        return disc_loss, gen_loss\n", "\n", "    def compute_gradients(self, x):\n", "        \"\"\" computes the gradients of the generator and critic losses\n", "        \"\"\"\n", "        ### pass through network\n", "        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", "            disc_loss, gen_loss = self.compute_loss(x)\n", "\n", "        # compute gradients\n", "        gen_gradients = gen_tape.gradient(gen_loss, self.gen.trainable_variables)\n", "\n", "        disc_gradients = disc_tape.gradient(disc_loss, self.disc.trainable_variables)\n", "\n", "        return gen_gradients, disc_gradients\n", "\n", "    def apply_gradients(self, gen_gradients, disc_gradients):\n", "\n", "        self.gen_optimizer.apply_gradients(\n", "            zip(gen_gradients, self.gen.trainable_variables)\n", "        )\n", "        self.disc_optimizer.apply_gradients(\n", "            zip(disc_gradients, self.disc.trainable_variables)\n", "        )\n", "\n", "    def gradient_penalty(self, x, x_gen):\n", "        # random points on straight lines between real and generated samples\n", "        epsilon = tf.random.uniform([x.shape[0], 1, 1, 1], 0.0, 1.0)\n", "        x_hat = epsilon * x + (1 - epsilon) * x_gen\n", "        with tf.GradientTape() as t:\n", "            t.watch(x_hat)\n", "            d_hat = self.discriminate(x_hat)\n", "        gradients = t.gradient(d_hat, x_hat)\n", "        # per-sample L2 norm of the critic's gradient over all non-batch axes (H, W, C)\n", "        ddx = tf.sqrt(tf.reduce_sum(gradients ** 2, axis=[1, 2, 3]))\n", "        # penalize deviation of the gradient norm from 1\n", "        d_regularizer = tf.reduce_mean((ddx - 1.0) ** 2)\n", "        return d_regularizer\n", "\n", "    @tf.function\n", "    def train(self, train_x):\n", "        gen_gradients, disc_gradients = self.compute_gradients(train_x)\n", "        self.apply_gradients(gen_gradients, disc_gradients)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qEVl58nDSYb4" }, "source": [ "### Define the network architecture" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:38.219862Z", "start_time": "2019-05-14T06:31:38.070570Z" }, "colab": {}, "colab_type": "code", "id": "dyU21SGbSYb4" }, "outputs": [], "source": [ "N_Z = 64\n", "\n", "generator = [\n", "    tf.keras.layers.Dense(units=7 * 7 * 64, activation=\"relu\"),\n", "    tf.keras.layers.Reshape(target_shape=(7, 7, 64)),\n", "    tf.keras.layers.Conv2DTranspose(\n", "        filters=64, kernel_size=3, strides=(2, 2), padding=\"SAME\", activation=\"relu\"\n", "    ),\n", "    tf.keras.layers.Conv2DTranspose(\n", "        filters=32, kernel_size=3, strides=(2, 2), padding=\"SAME\", activation=\"relu\"\n", "    ),\n", "    tf.keras.layers.Conv2DTranspose(\n", "        filters=1, kernel_size=3, strides=(1, 1), padding=\"SAME\", activation=\"sigmoid\"\n", "    ),\n", "]\n", "\n", "discriminator = [\n", "    tf.keras.layers.InputLayer(input_shape=DIMS),\n", "    tf.keras.layers.Conv2D(\n", "        filters=32, kernel_size=3, strides=(2, 2), activation=\"relu\"\n", "    ),\n", "    tf.keras.layers.Conv2D(\n", "        filters=64, kernel_size=3, strides=(2, 2), activation=\"relu\"\n", "    ),\n", "    tf.keras.layers.Flatten(),\n", "    # linear output: a WGAN critic produces an unbounded score, not a probability\n", "    tf.keras.layers.Dense(units=1),\n", "]" ] }, { "cell_type": "markdown", "metadata": { "ExecuteTime": { "end_time": "2019-05-10T18:40:40.306731Z", "start_time": "2019-05-10T18:40:40.292930Z" }, "colab_type": "text", "id": "wi_ZuWBdSYb6" }, "source": [ "### Create the model" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { 
"ExecuteTime": { "end_time": "2019-05-14T06:31:39.047233Z", "start_time": "2019-05-14T06:31:38.222179Z" }, "colab": {}, "colab_type": "code", "id": "dSYjNRAwSYb7" }, "outputs": [], "source": [ "# optimizers\n", "gen_optimizer = tf.keras.optimizers.Adam(0.0001, beta_1=0.5)\n", "disc_optimizer = tf.keras.optimizers.RMSprop(0.0005)\n", "# model\n", "model = WGAN(\n", "    gen = generator,\n", "    disc = discriminator,\n", "    gen_optimizer = gen_optimizer,\n", "    disc_optimizer = disc_optimizer,\n", "    n_Z = N_Z,\n", "    gradient_penalty_weight = 10.0\n", ")" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "qwBg8NwrSYb9" }, "source": [ "### Train the model" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:39.056490Z", "start_time": "2019-05-14T06:31:39.049635Z" }, "colab": {}, "colab_type": "code", "id": "47sz8RMeSYb-" }, "outputs": [], "source": [ "# plot a row of nex samples drawn from the generator\n", "def plot_reconstruction(model, nex=8, zm=2):\n", "    # latent noise with the same (batch, 1, 1, N_Z) shape used during training\n", "    samples = model.generate(tf.random.normal(shape=(nex, 1, 1, N_Z)))\n", "    fig, axs = plt.subplots(ncols=nex, nrows=1, figsize=(zm * nex, zm))\n", "    for axi in range(nex):\n", "        axs[axi].matshow(\n", "            samples.numpy()[axi].squeeze(), cmap=plt.cm.Greys, vmin=0, vmax=1\n", "        )\n", "        axs[axi].axis('off')\n", "    plt.show()" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:31:39.152670Z", "start_time": "2019-05-14T06:31:39.058505Z" }, "colab": {}, "colab_type": "code", "id": "pKkEX9yBSYcB" }, "outputs": [], "source": [ "# a pandas DataFrame that stores the per-epoch losses\n", "losses = pd.DataFrame(columns = ['disc_loss', 'gen_loss'])" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T07:04:26.791634Z", "start_time": "2019-05-14T07:04:17.126436Z" }, "colab": {}, "colab_type": "code", "id": "00dI2M4iSYcE", "outputId": 
"8312d004-9e5d-43a1-f28f-6d182bd9add2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 0 | disc_loss: -0.050283897668123245 | gen_loss: 0.5204998254776001\n" ] } ], "source": [ "n_epochs = 200\n", "for epoch in range(n_epochs):\n", "    # train\n", "    for batch, train_x in tqdm(\n", "        zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES\n", "    ):\n", "        model.train(train_x)\n", "    # evaluate the losses on the held-out test set\n", "    loss = []\n", "    for batch, test_x in tqdm(\n", "        zip(range(N_TEST_BATCHES), test_dataset), total=N_TEST_BATCHES\n", "    ):\n", "        loss.append(model.compute_loss(test_x))\n", "    losses.loc[len(losses)] = np.mean(loss, axis=0)\n", "    # plot results\n", "    display.clear_output()\n", "    print(\n", "        \"Epoch: {} | disc_loss: {} | gen_loss: {}\".format(\n", "            epoch, losses.disc_loss.values[-1], losses.gen_loss.values[-1]\n", "        )\n", "    )\n", "    plot_reconstruction(model)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:46:24.425722Z", "start_time": "2019-05-14T06:46:24.188266Z" }, "colab": {}, "colab_type": "code", "id": "XZbwB70ESYcH", "outputId": "40913f5e-a991-4d25-c56a-086acb4c0cd6" }, "outputs": [], "source": [ "plt.plot(losses.gen_loss.values)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "ExecuteTime": { "end_time": "2019-05-14T06:46:38.649136Z", "start_time": "2019-05-14T06:46:38.440378Z" }, "colab": {}, "colab_type": "code", "id": "FydTnKDxSYcL", "outputId": "503e0a49-2ed1-4c2d-9072-adc57dae5ae8" }, "outputs": [], "source": [ "plt.plot(losses.disc_loss.values)" ] }, { "cell_type": "code", "execution_count": 0, "metadata": { "colab": {}, "colab_type": "code", "id": "R4EjPZD1SYcO" }, "outputs": [], "source": [ "" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "3.0-WGAN-GP-fashion-mnist.ipynb", "provenance": [], "toc_visible": true, "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.8" } }, "nbformat": 4, "nbformat_minor": 0 }