{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai import *\n", "from fastai.vision import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## LSun bedroom data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this lesson, we'll be using the bedrooms from the [LSUN dataset](http://lsun.cs.princeton.edu/2017/). The full dataset is a bit too large so we'll use a sample from [kaggle](https://www.kaggle.com/jhoward/lsun_bedroom)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "path = untar_data(URLs.LSUN_BEDROOMS)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We then grab all the images in the folder with the data block API. We don't create a validation set here for reasons we'll explain later." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class NoisyItem(ItemBase):\n", " def __init__(self, noise_sz): self.obj,self.data = noise_sz,torch.randn(noise_sz, 1, 1)\n", " def __str__(self): return ''\n", " def apply_tfms(self, tfms, **kwargs): return self" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class GANItemList(ImageItemList):\n", " _label_cls = ImageItemList\n", " \n", " def __init__(self, items, noise_sz:int=100, **kwargs):\n", " super().__init__(items, **kwargs)\n", " self.noise_sz = noise_sz\n", " self.copy_new.append('noise_sz')\n", " \n", " def get(self, i): return NoisyItem(self.noise_sz)\n", " def reconstruct(self, t): return NoisyItem(t.size(0))\n", " \n", " def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):\n", " super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_data(bs, size):\n", " train_ds = (GANItemList.from_folder(path).label_from_func(noop)\n", " .transform(tfms=[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size, tfm_y=True))\n", " return (ImageDataBunch.create(train_ds, valid_ds=None, path=path, bs=bs)\n", " .normalize(do_x=False, stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])], do_y=True))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll begin with a small side and use gradual resizing." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data = get_data(128, 64)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "data.show_batch(rows=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GAN stands for [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf) and were invented by Ian Goodfellow. The concept is that we will train two models at the same time: a generator and a discriminator. The generator will try to make new images similar to the ones in our dataset, and the discriminator job will try to classify real images from the ones the generator does. The generator returns images, the discriminator a single number (usually 0. for fake images and 1. for real ones).\n", "\n", "We train them against each other in the sense that at each step (more or less), we:\n", "1. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "We create a generator and a critic (the WGAN name for the discriminator) that we pass to `GANLearner`. The `noise_sz` is the size of the random vector from which our generator creates images." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.vision.gan import *" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)\n", "critic = basic_critic(in_size=64, n_channels=3, n_extra_layers=1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = GANLearner.wgan(data, generator, critic, opt_func=optim.RMSprop, wd=0.)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit(1, 1e-4)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 2 }