{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Wasserstein 2 Minibatch GAN with PyTorch\n\n

Note

Example added in release: 0.8.0.

\n\nIn this example we train a Wasserstein GAN using Wasserstein 2 on minibatches\nas a distribution fitting term.\n\nWe want to train a generator $G_\\theta$ that generates realistic\ndata from random noise drawn form a Gaussian $\\mu_n$ distribution so\nthat the data is indistinguishable from true data in the data distribution\n$\\mu_d$. To this end Wasserstein GAN [Arjovsky2017] aim at optimizing\nthe parameters $\\theta$ of the generator with the following\noptimization problem:\n\n\\begin{align}\\min_{\\theta} W(\\mu_d,G_\\theta\\#\\mu_n)\\end{align}\n\n\nIn practice we do not have access to the full distribution $\\mu_d$ but\nsamples and we cannot compute the Wasserstein distance for large dataset.\n[Arjovsky2017] proposed to approximate the dual potential of Wasserstein 1\nwith a neural network recovering an optimization problem similar to GAN.\nIn this example\nwe will optimize the expectation of the Wasserstein distance over minibatches\nat each iterations as proposed in [Genevay2018]. Optimizing the Minibatches\nof the Wasserstein distance has been studied in [Fatras2019].\n\n[Arjovsky2017] Arjovsky, M., Chintala, S., & Bottou, L. (2017, July).\nWasserstein generative adversarial networks. In International conference\non machine learning (pp. 214-223). PMLR.\n\n[Genevay2018] Genevay, Aude, Gabriel Peyr\u00e9, and Marco Cuturi. \"Learning generative models\nwith sinkhorn divergences.\" International Conference on Artificial Intelligence\nand Statistics. PMLR, 2018.\n\n[Fatras2019] Fatras, K., Zine, Y., Flamary, R., Gribonval, R., & Courty, N.\n(2020, June). Learning with minibatch Wasserstein: asymptotic and gradient\nproperties. In the 23nd International Conference on Artificial Intelligence\nand Statistics (Vol. 108).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 3\n\nimport numpy as np\nimport matplotlib.pyplot as pl\nimport matplotlib.animation as animation\nimport torch\nfrom torch import nn\nimport ot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data generation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(1)\nsigma = 0.1\nn_dims = 2\nn_features = 2\n\n\ndef get_data(n_samples):\n c = torch.rand(size=(n_samples, 1))\n angle = c * 2 * np.pi\n x = torch.cat((torch.cos(angle), torch.sin(angle)), 1)\n x += torch.randn(n_samples, 2) * sigma\n return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# plot the distributions\nx = get_data(500)\npl.figure(1)\npl.scatter(x[:, 0], x[:, 1], label=\"Data samples from $\\mu_d$\", alpha=0.5)\npl.title(\"Data distribution\")\npl.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generator Model\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# define the MLP model\nclass Generator(torch.nn.Module):\n def __init__(self):\n super(Generator, self).__init__()\n self.fc1 = nn.Linear(n_features, 200)\n self.fc2 = nn.Linear(200, 500)\n self.fc3 = nn.Linear(500, n_dims)\n self.relu = torch.nn.ReLU() # instead of Heaviside step fn\n\n def forward(self, x):\n output = self.fc1(x)\n output = self.relu(output) # instead of Heaviside step fn\n output = self.fc2(output)\n output = self.relu(output)\n output = self.fc3(output)\n return output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the model\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "G = Generator()\noptimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)\n\n# number of iteration and size of the batches\nn_iter = 200 # set to 200 for doc build but 1000 is better ;)\nsize_batch = 500\n\n# generate statis samples to see their trajectory along training\nn_visu = 100\nxnvisu = torch.randn(n_visu, n_features)\nxvisu = torch.zeros(n_iter, n_visu, n_dims)\n\nab = torch.ones(size_batch) / size_batch\nlosses = []\n\n\nfor i in range(n_iter):\n # generate noise samples\n xn = torch.randn(size_batch, n_features)\n\n # generate data samples\n xd = get_data(size_batch)\n\n # generate sample along iterations\n xvisu[i, :, :] = G(xnvisu).detach()\n\n # generate samples and compte distance matrix\n xg = G(xn)\n M = ot.dist(xg, xd)\n\n loss = ot.emd2(ab, ab, M)\n losses.append(float(loss.detach()))\n\n if i % 10 == 0:\n print(\"Iter: {:3d}, loss={}\".format(i, losses[-1]))\n\n loss.backward()\n optimizer.step()\n optimizer.zero_grad()\n\n del M\n\npl.figure(2)\npl.semilogy(losses)\npl.grid()\npl.title(\"Wasserstein distance\")\npl.xlabel(\"Iterations\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot trajectories of generated samples along iterations\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(3, (10, 10))\n\nivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]\n\nfor i in range(9):\n pl.subplot(3, 3, i + 1)\n pl.scatter(xd[:, 0], xd[:, 1], label=\"Data samples from $\\mu_d$\", alpha=0.1)\n pl.scatter(\n xvisu[ivisu[i], :, 0],\n xvisu[ivisu[i], :, 1],\n label=\"Data samples from $G\\#\\mu_n$\",\n alpha=0.5,\n )\n pl.xticks(())\n pl.yticks(())\n pl.title(\"Iter. {}\".format(ivisu[i]))\n if i == 0:\n pl.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Animate trajectories of generated samples along iteration\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(4, (8, 8))\n\n\ndef _update_plot(i):\n pl.clf()\n pl.scatter(xd[:, 0], xd[:, 1], label=\"Data samples from $\\mu_d$\", alpha=0.1)\n pl.scatter(\n xvisu[i, :, 0], xvisu[i, :, 1], label=\"Data samples from $G\\#\\mu_n$\", alpha=0.5\n )\n pl.xticks(())\n pl.yticks(())\n pl.xlim((-1.5, 1.5))\n pl.ylim((-1.5, 1.5))\n pl.title(\"Iter. {}\".format(i))\n return 1\n\n\ni = 0\npl.scatter(xd[:, 0], xd[:, 1], label=\"Data samples from $\\mu_d$\", alpha=0.1)\npl.scatter(\n xvisu[i, :, 0], xvisu[i, :, 1], label=\"Data samples from $G\\#\\mu_n$\", alpha=0.5\n)\npl.xticks(())\npl.yticks(())\npl.xlim((-1.5, 1.5))\npl.ylim((-1.5, 1.5))\npl.title(\"Iter. {}\".format(ivisu[i]))\n\n\nani = animation.FuncAnimation(\n pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate and visualize data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "size_batch = 500\nxd = get_data(size_batch)\nxn = torch.randn(size_batch, 2)\nx = G(xn).detach().numpy()\n\npl.figure(5)\npl.scatter(xd[:, 0], xd[:, 1], label=\"Data samples from $\\mu_d$\", alpha=0.5)\npl.scatter(x[:, 0], x[:, 1], label=\"Data samples from $G\\#\\mu_n$\", alpha=0.5)\npl.title(\"Sources and Target distributions\")\npl.legend()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }