{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "skip" }, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "%%html\n", "\n", "\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# Install the necessary dependencies\n", "\n", "import os\n", "import sys\n", "!{sys.executable} -m pip install --quiet pandas scikit-learn numpy matplotlib jupyterlab_myst ipython tensorflow" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "remove-cell" ] }, "source": [ "---\n", "license:\n", " code: MIT\n", " content: CC-BY-4.0\n", "github: https://github.com/ocademy-ai/machine-learning\n", "venue: By Ocademy\n", "open_access: true\n", "bibliography:\n", " - https://raw.githubusercontent.com/ocademy-ai/machine-learning/main/open-machine-learning-jupyter-book/references.bib\n", "---" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Generative Adversarial Network" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## What is a gerative model? (1/3)\n", "\n", "- Many machine learning systems look at some kind of complicated input (say, an image) and produce a simple output (a label like, \"cat\")\n", "- By contrast, the goal of a generative model is something like the opposite: take a small piece of input—perhaps a few random numbers—and produce a complex output, like an image of a realistic-looking face" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## What is a gerative model? 
(2/3)\n", "\n", "- \"Generative\" describes a class of statistical models that contrasts with discriminative models.\n", "- Informally:\n", "  - Generative models can generate new data instances.\n", "  - Discriminative models discriminate between different kinds of data instances.\n", "- A generative model could generate new photos of animals that look like real animals, while a discriminative model could tell a dog from a cat." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## What is a generative model? (3/3)\n", "\n", "- A generative model can model a distribution by producing convincing \"fake\" data that looks like it's drawn from that distribution.\n", "- A generative model for images might capture correlations like \"things that look like boats are probably going to appear near things that look like water\" and \"eyes are unlikely to appear on foreheads.\" These are very complicated distributions." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## What is a generative adversarial network (GAN)?\n", "\n", "- A GAN is an especially effective type of generative model, introduced in 2014 by Ian Goodfellow, which has been a subject of intense interest in the machine learning community.\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- A generative adversarial network (GAN) has two parts:\n", "  - The generator learns to generate plausible data. The generated instances become negative training examples for the discriminator.\n", "  - The discriminator learns to distinguish the generator's fake data from real data. The discriminator penalizes the generator for producing implausible results." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- When training begins, the generator produces obviously fake data, and the discriminator quickly learns to tell that it's fake:\n", "\n", "![](../images/gan/gan-g1.svg)\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- As training progresses, the generator gets closer to producing output that can fool the discriminator:\n", "\n", "\n", "![](../images/gan/gan-g2.svg)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- Finally, if generator training goes well, the discriminator gets worse at telling the difference between real and fake. It starts to classify fake data as real, and its accuracy decreases.\n", "\n", "![](../images/gan/gan-g3.svg)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- Here's a picture of the whole system:\n", "\n", "![](../images/gan/gan-g4.svg)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## How does a GAN work?\n", "\n", "- Both the generator and the discriminator are neural networks. The generator output is connected directly to the discriminator input. Through backpropagation, the discriminator's classification provides a signal that the generator uses to update its weights." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Time to play!\n", "\n", "- https://poloclub.github.io/ganlab/" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Let's code a GAN!" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "from tensorflow import keras" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "## Packages & Libraries\n", "from keras.layers import Dense, Dropout, Input, ReLU\n", "from keras.models import Model, Sequential\n", "from keras.optimizers import Adam\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from tensorflow.keras.datasets import mnist" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "## Data Import\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "print(\"X train shape: \", x_train.shape)\n", "print(\"Y train shape: \", y_train.shape)\n", "print(\"X test shape: \", x_test.shape)\n", "print(\"Y test shape: \", y_test.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "plt.imshow(x_train[108])" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# x_train to (-1, 1)\n", "x_train = (x_train.astype(np.float32)-127.5)/127.5" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "#reshape data\n", "x_train = 
x_train.reshape(x_train.shape[0], x_train.shape[1] * x_train.shape[2])\n", "x_test = x_test.reshape(x_test.shape[0], x_test.shape[1] * x_test.shape[2])\n", "print(\"X train shape: \", x_train.shape)\n", "print(\"X test shape: \", x_test.shape)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "subslide" } }, "source": [ "## Create Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [], "source": [ "# create generator: maps a 100-dimensional noise vector to a flattened 28x28 image in (-1, 1)\n", "def create_generator():\n", "    generator = Sequential()\n", "    generator.add(Dense(units = 512, input_dim = 100))\n", "    generator.add(ReLU())\n", "    generator.add(Dense(units = 512))\n", "    generator.add(ReLU())\n", "    generator.add(Dense(units = 1024))\n", "    generator.add(ReLU())\n", "    generator.add(Dense(units = 784, activation = \"tanh\"))\n", "    generator.compile(loss = \"binary_crossentropy\", optimizer = Adam(learning_rate = 0.0001, beta_1 = 0.5))\n", "    return generator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "g = create_generator()\n", "g.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# create discriminator: maps a flattened 28x28 image to the probability that it is real\n", "def create_discriminator():\n", "    discriminator = Sequential()\n", "    discriminator.add(Dense(units = 1024, input_dim = 784))\n", "    discriminator.add(ReLU())\n", "    discriminator.add(Dropout(0.4))\n", "    discriminator.add(Dense(units = 512))\n", "    discriminator.add(ReLU())\n", "    discriminator.add(Dropout(0.4))\n", "    discriminator.add(Dense(units = 256))\n", "    discriminator.add(ReLU())\n", "    discriminator.add(Dense(units = 1, activation = \"sigmoid\"))\n", "    discriminator.compile(loss = \"binary_crossentropy\", optimizer = Adam(learning_rate = 0.0001, beta_1 = 0.5))\n", "    return discriminator" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "d = create_discriminator()\n", "d.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "# GAN: the generator stacked on a frozen discriminator\n", "def create_gan(discriminator, generator):\n", "    discriminator.trainable = False\n", "    gan_input = Input(shape = (100,))\n", "    x = generator(gan_input)\n", "    gan_output = discriminator(x)\n", "    gan = Model(inputs = gan_input, outputs = gan_output)\n", "    gan.compile(loss = \"binary_crossentropy\", optimizer = \"adam\")\n", "    return gan" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "gan = create_gan(d, g)\n", "gan.summary()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "epochs = 2 # should be large, e.g. 50\n", "batch_size = 256\n", "\n", "for e in range(epochs):\n", "    # number of training steps per epoch (here set equal to batch_size)\n", "    for _ in range(batch_size):\n", "        noise = np.random.normal(0, 1, [batch_size, 100])\n", "        # generated image batch\n", "        generated_images = g.predict(noise)\n", "        # real image batch, sampled at random from the training set\n", "        image_batch = x_train[np.random.randint(low = 0, high = x_train.shape[0], size = batch_size)]\n", "        x = np.concatenate([image_batch, generated_images])\n", "        # labels for the discriminator: 1 for real images, 0 for generated images\n", "        y_dis = np.zeros(batch_size * 2)\n", "        y_dis[:batch_size] = 1\n", "        # train the discriminator on the mixed batch\n", "        d.trainable = True\n", "        d.train_on_batch(x, y_dis)\n", "        # train the generator through the combined model, labeling its output as real\n", "        noise = np.random.normal(0, 1, [batch_size, 100])\n", "        y_gen = np.ones(batch_size)\n", "        d.trainable = False\n", "        gan.train_on_batch(noise, y_gen)\n", "    print(\"Epoch: \", e)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "g.save(\"g.hdf5\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { 
"slide_type": "fragment" } }, "outputs": [], "source": [ "noise = np.random.normal(loc = 0, scale = 1, size = [100, 100])\n", "generated_images = g.predict(noise)\n", "generated_images = generated_images.reshape(100, 28, 28)\n", "plt.imshow(generated_images[66], interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "gp = keras.models.load_model('g-pretrained.hdf5')" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "slideshow": { "slide_type": "subslide" } }, "outputs": [], "source": [ "noise = np.random.normal(loc = 0, scale = 1, size = [100, 100])\n", "generated_images = gp.predict(noise)\n", "generated_images = generated_images.reshape(100, 28, 28)\n", "plt.imshow(generated_images[66], interpolation=\"nearest\")\n", "plt.axis(\"off\")\n", "plt.show()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Ackowledgement\n", "\n", "- https://developers.google.com/machine-learning/gan/gan_structure\n", "- https://www.kaggle.com/code/emreustundag/generative-adversarial-networks-gans-tutorial\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 2 }