{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## GANs" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "hide_input": true }, "outputs": [], "source": [ "%matplotlib inline\n", "from fastai.gen_doc.nbdoc import *\n", "from fastai.vision import * \n", "from fastai.vision.gan import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GAN stands for [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf), which were introduced by Ian Goodfellow et al. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, and the critic tries to distinguish the real images from the fake ones the generator produces. The generator returns images, while the critic (also called the discriminator) returns a feature map, which can be a single number depending on the input size. Usually the discriminator is trained to return 0. everywhere for fake images and 1. everywhere for real ones.\n", "\n", "This module contains all the functions necessary to create a GAN." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We train them against each other in the sense that at each step (more or less), we:\n", "1. Freeze the generator and train the discriminator for one step by:\n", " - getting one batch of real images (let's call that `real`)\n", " - generating one batch of fake images (let's call that `fake`)\n", " - having the discriminator evaluate each batch and computing a loss function from that; the important part is that it rewards the discriminator for scoring real images as real and fake images as fake\n", " - updating the weights of the discriminator with the gradients of this loss\n", " \n", " \n", "2. Freeze the discriminator and train the generator for one step by:\n", " - generating one batch of fake images\n", " - evaluating the discriminator on it\n", " - returning a loss that is low when the discriminator thinks those are real images, so the generator is rewarded for fooling it\n", " - updating the weights of the generator with the gradients of this loss" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "hide_input": true }, "outputs": [ { "data": { "text/markdown": [ "
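The two alternating steps above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the fastai implementation: the tiny fully connected `generator` and `critic`, the hypothetical helpers `set_trainable`, `critic_step` and `generator_step`, the Adam settings and the random placeholder data are all made up for the example. Freezing a model is modelled by switching off `requires_grad` on its parameters, and the critic outputs a small feature map of logits that binary cross-entropy (one common choice of loss) pushes towards 1. everywhere for real inputs and 0. everywhere for fakes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
noise_sz = 8

# Toy models (hypothetical, not fastai's basic_generator/basic_critic):
# the generator maps noise to a flat 16-value 'image', and the critic maps
# an image to a 4-value feature map of logits rather than a single number.
generator = nn.Sequential(nn.Linear(noise_sz, 32), nn.ReLU(), nn.Linear(32, 16), nn.Tanh())
critic = nn.Sequential(nn.Linear(16, 32), nn.LeakyReLU(0.2), nn.Linear(32, 4))

opt_gen = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.99))
opt_crit = torch.optim.Adam(critic.parameters(), lr=2e-4, betas=(0.5, 0.99))

def set_trainable(model, flag):
    # 'Freezing' a model: stop computing gradients for its parameters.
    for p in model.parameters():
        p.requires_grad_(flag)

def critic_step(real):
    # Step 1: generator frozen, critic trained.
    set_trainable(generator, False)
    set_trainable(critic, True)
    fake = generator(torch.randn(real.size(0), noise_sz))
    real_pred, fake_pred = critic(real), critic(fake)
    # Target 1. everywhere on the feature map for real images, 0. for fakes.
    loss = (F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
            + F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred)))
    opt_crit.zero_grad()
    loss.backward()
    opt_crit.step()
    return loss.item()

def generator_step(batch_size):
    # Step 2: critic frozen, generator trained.
    set_trainable(generator, True)
    set_trainable(critic, False)
    fake = generator(torch.randn(batch_size, noise_sz))
    fake_pred = critic(fake)
    # Low loss when the critic outputs 1. (real) on the generated images.
    loss = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    opt_gen.zero_grad()
    loss.backward()
    opt_gen.step()
    return loss.item()

real_data = torch.rand(64, 16) * 2 - 1  # placeholder 'dataset' with values in [-1, 1]
for real in real_data.split(16):
    crit_loss = critic_step(real)
    gen_loss = generator_step(real.size(0))
```

Note that the generator step back-propagates through the frozen critic: the critic's weights receive no gradient, but the gradient still flows through it to the generator. In this module, that alternation is handled by the `GANTrainer` callback, with a switcher such as `FixedGANSwitcher` or `AdaptiveGANSwitcher` deciding when to swap between the two models.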
class `GANLearner`(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`generator`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`critic`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`gen_loss_func`**:`LossFunction`, **`crit_loss_func`**:`LossFunction`, **`switcher`**:[`Callback`](/callback.html#Callback)=***`None`***, **`gen_first`**:`bool`=***`False`***, **`switch_eval`**:`bool`=***`True`***, **`show_img`**:`bool`=***`True`***, **`clip`**:`float`=***`None`***, **\\*\\*`learn_kwargs`**) :: [`Learner`](/basic_train.html#Learner)\n",
"\n",
"No tests found for GANLearner
`from_learners`(**`learn_gen`**:[`Learner`](/basic_train.html#Learner), **`learn_crit`**:[`Learner`](/basic_train.html#Learner), **`switcher`**:[`Callback`](/callback.html#Callback)=***`None`***, **`weights_gen`**:`Point`=***`None`***, **\\*\\*`learn_kwargs`**)\n",
"\n",
"No tests found for from_learners
`wgan`(**`data`**:[`DataBunch`](/basic_data.html#DataBunch), **`generator`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`critic`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module), **`switcher`**:[`Callback`](/callback.html#Callback)=***`None`***, **`clip`**:`float`=***`0.01`***, **\\*\\*`learn_kwargs`**)\n",
"\n",
"No tests found for wgan
class `FixedGANSwitcher`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`n_crit`**:`Union`\\[`int`, `Callable`\\]=***`1`***, **`n_gen`**:`Union`\\[`int`, `Callable`\\]=***`1`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for FixedGANSwitcher
`on_train_begin`(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_begin
`on_batch_end`(**`iteration`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_end
class `AdaptiveGANSwitcher`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`gen_thresh`**:`float`=***`None`***, **`critic_thresh`**:`float`=***`None`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for AdaptiveGANSwitcher
`on_batch_end`(**`last_loss`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_end
class `GANDiscriminativeLR`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`mult_lr`**:`float`=***`5.0`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"No tests found for GANDiscriminativeLR
`on_batch_begin`(**`train`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_begin
`on_step_end`(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_step_end
`basic_critic`(**`in_size`**:`int`, **`n_channels`**:`int`, **`n_features`**:`int`=***`64`***, **`n_extra_layers`**:`int`=***`0`***, **\\*\\*`conv_kwargs`**)\n",
"\n",
"basic_generator
(**`in_size`**:`int`, **`n_channels`**:`int`, **`noise_sz`**:`int`=***`100`***, **`n_features`**:`int`=***`64`***, **`n_extra_layers`**=***`0`***, **\\*\\*`conv_kwargs`**)\n",
"\n",
"gan_critic
(**`n_channels`**:`int`=***`3`***, **`nf`**:`int`=***`128`***, **`n_blocks`**:`int`=***`3`***, **`p`**:`int`=***`0.15`***)\n",
"\n",
"No tests found for gan_critic
class `GANTrainer`(**`learn`**:[`Learner`](/basic_train.html#Learner), **`switch_eval`**:`bool`=***`False`***, **`clip`**:`float`=***`None`***, **`beta`**:`float`=***`0.98`***, **`gen_first`**:`bool`=***`False`***, **`show_img`**:`bool`=***`True`***) :: [`LearnerCallback`](/basic_train.html#LearnerCallback)\n",
"\n",
"switch
(**`gen_mode`**:`bool`=***`None`***)\n",
"\n",
"on_train_begin
(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_begin
`on_epoch_begin`(**`epoch`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_begin
`on_batch_begin`(**`last_input`**, **`last_target`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_batch_begin
`on_backward_begin`(**`last_loss`**, **`last_output`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_backward_begin
`on_epoch_end`(**`pbar`**, **`epoch`**, **`last_metrics`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_epoch_end
`on_train_end`(**\\*\\*`kwargs`**)\n",
"\n",
"No tests found for on_train_end
class `GANModule`(**`generator`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`critic`**:[`Module`](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)=***`None`***, **`gen_mode`**:`bool`=***`False`***) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`Module`](/torch_core.html#Module)\n",
"\n",
"switch
(**`gen_mode`**:`bool`=***`None`***)\n",
"\n",
"class
GANLoss
(**`loss_funcG`**:`Callable`, **`loss_funcC`**:`Callable`, **`gan_model`**:[`GANModule`](/vision.gan.html#GANModule)) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`GANModule`](/vision.gan.html#GANModule)\n",
"\n",
"No tests found for GANLoss
class `AdaptiveLoss`(**`crit`**) :: [`PrePostInitMeta`](/core.html#PrePostInitMeta) :: [`Module`](/torch_core.html#Module)\n",
"\n",
"No tests found for AdaptiveLoss
`accuracy_thresh_expand`(**`y_pred`**:`Tensor`, **`y_true`**:`Tensor`, **`thresh`**:`float`=***`0.5`***, **`sigmoid`**:`bool`=***`True`***) → `Rank0Tensor`\n",
"\n",
"No tests found for accuracy_thresh_expand
class `NoisyItem`(**`noise_sz`**) :: [`ItemBase`](/core.html#ItemBase)\n",
"\n",
"class
GANItemList
(**`items`**, **`noise_sz`**:`int`=***`100`***, **\\*\\*`kwargs`**) :: [`ImageList`](/vision.data.html#ImageList)\n",
"\n",
"show_xys
(**`xs`**, **`ys`**, **`imgsize`**:`int`=***`4`***, **`figsize`**:`Optional`\\[`Tuple`\\[`int`, `int`\\]\\]=***`None`***, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for show_xys
`show_xyzs`(**`xs`**, **`ys`**, **`zs`**, **`imgsize`**:`int`=***`4`***, **`figsize`**:`Optional`\\[`Tuple`\\[`int`, `int`\\]\\]=***`None`***, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for show_xyzs
`critic`(**`real_pred`**, **`input`**)\n",
"\n",
"forward
(**\\*`args`**)\n",
"\n",
"No tests found for forward
`generator`(**`output`**, **`target`**)\n",
"\n",
"apply_tfms
(**`tfms`**, **\\*\\*`kwargs`**)\n",
"\n",
"No tests found for apply_tfms
`forward`(**`output`**, **`target`**)\n",
"\n",
"No tests found for forward
`get`(**`i`**)\n",
"\n",
"No tests found for get
`reconstruct`(**`t`**)\n",
"\n",
"No tests found for reconstruct
`forward`(**`output`**, **`target`**)\n",
"\n",
"No tests found for forward