{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Wasserstein 1D (flow and barycenter) with PyTorch\n\nIn this small example, we consider the following minimization problem:\n\n\\begin{align}\\mu^* = \\arg\\min_\\mu W(\\mu,\\nu)\\end{align}\n\nwhere $\\nu$ is a reference 1D measure. The problem is handled\nby a projected gradient descent method, where the gradient is computed\nby PyTorch automatic differentiation. The projection onto the simplex\nensures that the iterates remain valid probability distributions.\n\nThis example illustrates both the `wasserstein_1d` function and backend use within\nthe POT framework.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Nicolas Courty\n# R\u00e9mi Flamary\n#\n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport matplotlib as mpl\nimport torch\n\nfrom ot.lp import wasserstein_1d\nfrom ot.datasets import make_1D_gauss as gauss\nfrom ot.utils import proj_simplex\n\nred = np.array(mpl.colors.to_rgb(\"red\"))\nblue = np.array(mpl.colors.to_rgb(\"blue\"))\n\n\nn = 100  # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = gauss(n, m=20, s=5)  # m = mean, s = std\nb = gauss(n, m=60, s=10)\n\n# enforce sum to one on the support\na = a / a.sum()\nb = b / b.sum()\n\n\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n# use PyTorch for our data\nx_torch = torch.tensor(x).to(device=device)\na_torch = torch.tensor(a).to(device=device).requires_grad_(True)\nb_torch = torch.tensor(b).to(device=device)\n\nlr = 1e-6\nnb_iter_max = 800\n\nloss_iter = []\n\npl.figure(1, figsize=(8, 4))\npl.plot(x, a, \"b\", label=\"Source distribution\")\npl.plot(x, b, \"r\", label=\"Target distribution\")\n\nfor i in range(nb_iter_max):\n    # Compute the 1D Wasserstein distance with the torch backend\n    loss = wasserstein_1d(x_torch, x_torch, a_torch, b_torch, p=2)\n    # record the corresponding loss value\n    loss_iter.append(loss.clone().detach().cpu().numpy())\n    loss.backward()\n\n    # perform a step of projected gradient descent\n    with torch.no_grad():\n        a_torch -= a_torch.grad * lr  # gradient step\n        a_torch.grad.zero_()\n        a_torch.data = proj_simplex(a_torch)  # projection onto the simplex\n\n    # plot one curve every 10 iterations\n    if i % 10 == 0:\n        mix = float(i) / nb_iter_max\n        pl.plot(\n            x, a_torch.clone().detach().cpu().numpy(), c=(1 - mix) * blue + mix * red\n        )\n\npl.legend()\npl.title(\"Distribution along the iterations of the projected gradient descent\")\npl.show()\n\npl.figure(2)\npl.plot(range(nb_iter_max), loss_iter, lw=3)\npl.title(\"Evolution of the loss along iterations\", fontsize=16)\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Wasserstein barycenter\nIn this example, we consider the following Wasserstein barycenter problem\n$$ \\eta^* = \\arg\\min_\\eta \\; (1-t)W(\\mu,\\eta) + tW(\\eta,\\nu)$$\nwhere $\\mu$ and $\\nu$ are reference 1D measures, and $t$\nis a parameter $\\in [0,1]$. 
The problem is handled by a projected gradient\ndescent method, where the gradient is computed by PyTorch automatic differentiation.\nThe projection onto the simplex ensures that the iterates remain valid\nprobability distributions.\n\nThis example illustrates both the `wasserstein_1d` function and backend use within the\nPOT framework.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n# use PyTorch for our data\nx_torch = torch.tensor(x).to(device=device)\na_torch = torch.tensor(a).to(device=device)\nb_torch = torch.tensor(b).to(device=device)\n# initialize the barycenter at the midpoint of a and b\nbary_torch = torch.tensor((a + b).copy() / 2).to(device=device).requires_grad_(True)\n\n\nlr = 1e-6\nnb_iter_max = 2000\n\nloss_iter = []\n\n# interpolation parameter t in [0, 1]\nt = 0.5\n\nfor i in range(nb_iter_max):\n    # Compute the weighted sum of 1D Wasserstein distances with the torch backend\n    loss = (1 - t) * wasserstein_1d(\n        x_torch, x_torch, a_torch.detach(), bary_torch, p=2\n    ) + t * wasserstein_1d(x_torch, x_torch, b_torch, bary_torch, p=2)\n    # record the corresponding loss value\n    loss_iter.append(loss.clone().detach().cpu().numpy())\n    loss.backward()\n\n    # perform a step of projected gradient descent\n    with torch.no_grad():\n        bary_torch -= bary_torch.grad * lr  # gradient step\n        bary_torch.grad.zero_()\n        bary_torch.data = proj_simplex(bary_torch)  # projection onto the simplex\n\npl.figure(3, figsize=(8, 4))\npl.plot(x, a, \"b\", label=\"Source distribution\")\npl.plot(x, b, \"r\", label=\"Target distribution\")\npl.plot(x, bary_torch.clone().detach().cpu().numpy(), c=\"green\", label=\"W barycenter\")\npl.legend()\npl.title(\"Wasserstein barycenter computed by gradient descent\")\npl.show()\n\npl.figure(4)\npl.plot(range(nb_iter_max), loss_iter, lw=3)\npl.title(\"Evolution of the loss along iterations\", fontsize=16)\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": 
"python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }