{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "wAMRwrH0Qg22"
},
"source": [
"# Generating Images of Clothes Using Deep Convolutional Generative Adversarial Network (DCGAN)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
 - 🤖 See [full list">
"> - 🤖 See [full list of Machine Learning Experiments](https://github.com/trekhleb/machine-learning-experiments) on **GitHub**\n",
"> - ▶️ **Interactive Demo**: [try this model and other machine learning experiments in action](https://trekhleb.github.io/machine-learning-experiments/)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Experiment overview"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this experiment we will generate images of clothing using a [Deep Convolutional Generative Adversarial Network](https://arxiv.org/pdf/1511.06434.pdf) (DCGAN). The code is written using the [Keras Sequential API](https://www.tensorflow.org/guide/keras) with a [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) training loop. For training we will be using Fashion [MNIST dataset](https://www.tensorflow.org/datasets/catalog/fashion_mnist).\n",
"\n",
"A **generative adversarial network** (GAN) is a class of machine learning frameworks. Two neural networks contest with each other in a game. Two models are trained simultaneously by an adversarial process. A generator (\"the artist\") learns to create images that look real, while a discriminator (\"the art critic\") learns to tell real images apart from fakes.\n",
"\n",
"![clothes_generation_dcgan.jpg](../../demos/src/images/clothes_generation_dcgan.jpg)"
]
},
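{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the adversarial game described above, written in plain NumPy rather than TensorFlow so that it runs on its own. It assumes the discriminator outputs raw logits and mirrors the binary cross-entropy losses used in the TensorFlow DCGAN tutorial: the critic is rewarded for scoring real images as 1 and fakes as 0, while the artist is rewarded when its fakes are scored as 1. The helper names `bce_with_logits`, `discriminator_loss` and `generator_loss` and the toy logits are illustrative assumptions, not code from this notebook.

```python
import numpy as np


def bce_with_logits(labels, logits):
    # Numerically stable binary cross-entropy on raw logits, averaged over the batch:
    # max(x, 0) - x * z + log(1 + exp(-|x|))
    return np.mean(
        np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    )


def discriminator_loss(real_logits, fake_logits):
    # The critic should score real images as 1 and generated images as 0.
    real_loss = bce_with_logits(np.ones_like(real_logits), real_logits)
    fake_loss = bce_with_logits(np.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss


def generator_loss(fake_logits):
    # The artist wins when the critic scores its fakes as real (1).
    return bce_with_logits(np.ones_like(fake_logits), fake_logits)


# Toy logits for a critic that currently separates real from fake very well.
real_logits = np.array([4.0, 3.0])
fake_logits = np.array([-4.0, -3.0])
print('discriminator loss:', discriminator_loss(real_logits, fake_logits))  # close to 0
print('generator loss:', generator_loss(fake_logits))  # large: the fakes are being caught
```

The two losses pull in opposite directions on the same fake logits: pushing `fake_logits` up lowers the generator loss and raises the discriminator loss, which is what makes the training adversarial."
]
},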
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Ul5K4DQaQg23"
},
"source": [
"Inspired by: [Deep Convolutional Generative Adversarial Network](https://www.tensorflow.org/tutorials/generative/dcgan) tutorial."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dhA2OtD0Qg24"
},
"source": [
"## Importing dependencies"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 68
},
"colab_type": "code",
"executionInfo": {
"elapsed": 1863,
"status": "ok",
"timestamp": 1590383290949,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "da0YhaOJQg25",
"outputId": "2802bf95-8172-44e3-df5c-3cf51d525799"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Python version: 3.7.6\n",
"Tensorflow version: 2.1.0\n",
"Keras version: 2.2.4-tf\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import math\n",
"import datetime\n",
"import platform\n",
"import imageio\n",
"import PIL\n",
"import time\n",
"import os\n",
"import glob\n",
"import zipfile\n",
"\n",
"from IPython import display\n",
"\n",
"print('Python version:', platform.python_version())\n",
"print('Tensorflow version:', tf.__version__)\n",
"print('Keras version:', tf.keras.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"executionInfo": {
"elapsed": 3449,
"status": "ok",
"timestamp": 1590383292584,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "0GRH9c5coMoK",
"outputId": "929093b9-484e-4f2f-d483-56ede033e46d"
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Checking the eager execution availability.\n",
"tf.executing_eagerly()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "bhS6rs6eQg29"
},
"source": [
"## Loading the data"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 173
},
"colab_type": "code",
"executionInfo": {
"elapsed": 4260,
"status": "ok",
"timestamp": 1590383293429,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "FpudNNDnQg2-",
"outputId": "7a2e4e23-8584-4b7e-a22c-268caaf99859"
},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 102
},
"colab_type": "code",
"executionInfo": {
"elapsed": 4218,
"status": "ok",
"timestamp": 1590383293432,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "IzIOnN83Qg3C",
"outputId": "67b8f46a-0206-4857-bf9e-e87d972b55fe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train.shape: (60000, 28, 28)\n",
"y_train.shape: (60000,)\n",
"\n",
"x_test.shape: (10000, 28, 28)\n",
"y_test.shape: (10000,)\n"
]
}
],
"source": [
"print('x_train.shape: ', x_train.shape)\n",
"print('y_train.shape: ', y_train.shape)\n",
"print()\n",
"print('x_test.shape: ', x_test.shape)\n",
"print('y_test.shape: ', y_test.shape)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"executionInfo": {
"elapsed": 4142,
"status": "ok",
"timestamp": 1590383293433,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "gdfkpOpMVDkh",
"outputId": "45fe19eb-bbd7-4a50-e5d6-aad6822d70a6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train.shape: (70000, 28, 28)\n"
]
}
],
"source": [
"# Since we don't need test examples we may concatenate both sets\n",
"x_train = np.concatenate((x_train, x_test), axis=0)\n",
"\n",
"print('x_train.shape: ', x_train.shape)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"executionInfo": {
"elapsed": 4051,
"status": "ok",
"timestamp": 1590383293433,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "vq-kHJWm1uAn",
"outputId": "e35139f6-bdb8-4c8b-ef45-fe639918eca4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"TOTAL_EXAMPLES_NUM: 70000\n"
]
}
],
"source": [
"TOTAL_EXAMPLES_NUM = x_train.shape[0]\n",
"\n",
"print('TOTAL_EXAMPLES_NUM: ', TOTAL_EXAMPLES_NUM)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"colab_type": "code",
"executionInfo": {
"elapsed": 4017,
"status": "ok",
"timestamp": 1590383293435,
"user": {
"displayName": "Oleksii Trekhleb",
"photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GiiA4aUKCbFho88Jd0WWMoAqQUt3jbuCtfNYpHVOA=s64",
"userId": "03172675069638383074"
},
"user_tz": -120
},
"id": "nMQ3k3-bQg3E",
"outputId": "1ba14f75-1a81-4a49-a8a8-abaac9a68192"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"y_train[0] = 9\n"
]
}
],
"source": [
"print('y_train[0] =', y_train[0])"
]
},
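{
"cell_type": "markdown",
"metadata": {},
"source": [
"The label printed above is only an integer. A minimal sketch for turning labels into readable class names, assuming the class order published in the Fashion MNIST documentation; `CLASS_NAMES` and `label_to_name` are hypothetical helpers that are not defined elsewhere in this notebook, and the `labels` array is a made-up stand-in for a few entries of `y_train`.

```python
import numpy as np

# Class names in label order (0..9), as listed in the Fashion MNIST documentation.
CLASS_NAMES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]


def label_to_name(label):
    # Labels are small integers (uint8 in the Keras dataset), so they index the list directly.
    return CLASS_NAMES[int(label)]


# Hypothetical stand-in for a few entries of y_train.
labels = np.array([9, 0, 3], dtype=np.uint8)
print([label_to_name(lbl) for lbl in labels])  # ['Ankle boot', 'T-shirt/top', 'Dress']
```

With this helper, `label_to_name(y_train[0])` returns `'Ankle boot'`, since the first training label is 9."
]
},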
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "dU36mGsZUdah"
},
"source": [
"Here are the map of classes for the dataset according to [documentation](https://github.com/zalandoresearch/fashion-mnist):\n",
"\n",
"
Label | \n", "Class | \n", "
---|---|
0 | \n", "T-shirt/top | \n", "
1 | \n", "Trouser | \n", "
2 | \n", "Pullover | \n", "
3 | \n", "Dress | \n", "
4 | \n", "Coat | \n", "
5 | \n", "Sandal | \n", "
6 | \n", "Shirt | \n", "
7 | \n", "Sneaker | \n", "
8 | \n", "Bag | \n", "
9 | \n", "Ankle boot | \n", "