{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Continuous OT plan estimation with PyTorch\n\n

**Note:** Example added in release: 0.8.2.

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 3\n\nimport numpy as np\nimport matplotlib.pyplot as pl\nimport torch\nfrom torch import nn\nimport ot\nimport ot.plot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data generation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(42)\nnp.random.seed(42)\n\nn_source_samples = 1000\nn_target_samples = 1000\ntheta = 2 * np.pi / 20\nnoise_level = 0.1\n\nXs = np.random.randn(n_source_samples, 2) * 0.5\nXt = np.random.randn(n_target_samples, 2) * 2\n\n# shift the target samples by 4 along both axes so they are centered around (4, 4)\nXt = Xt + 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nvisu = 300\npl.figure(1, (5, 5))\npl.clf()\npl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker=\"+\", label=\"Source samples\", alpha=0.5)\npl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker=\"o\", label=\"Target samples\", alpha=0.5)\npl.legend(loc=0)\nax_bounds = pl.axis()\npl.title(\"Source and target distributions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convert data to torch tensors\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "xs = torch.tensor(Xs)\nxt = torch.tensor(Xt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimating deep dual variables for entropic OT\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(42)\n\n# define the MLP model\n\n\nclass Potential(torch.nn.Module):\n def __init__(self):\n super(Potential, 
self).__init__()\n self.fc1 = nn.Linear(2, 200)\n self.fc2 = nn.Linear(200, 1)\n self.relu = torch.nn.ReLU() # instead of Heaviside step fn\n\n def forward(self, x):\n output = self.fc1(x)\n output = self.relu(output) # instead of Heaviside step fn\n output = self.fc2(output)\n return output.ravel()\n\n\nu = Potential().double()\nv = Potential().double()\n\nreg = 1\n\noptimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=0.005)\n\n# number of iterations and minibatch size\nn_iter = 500\nn_batch = 500\n\n\nlosses = []\n\nfor i in range(n_iter):\n # draw a random minibatch of source and target samples (indices sampled with replacement)\n\n iperms = torch.randint(0, n_source_samples, (n_batch,))\n ipermt = torch.randint(0, n_target_samples, (n_batch,))\n\n xsi = xs[iperms]\n xti = xt[ipermt]\n\n # minus because we maximize the dual loss\n loss = -ot.stochastic.loss_dual_entropic(u(xsi), v(xti), xsi, xti, reg=reg)\n losses.append(float(loss.detach()))\n\n if i % 10 == 0:\n print(\"Iter: {:3d}, loss={}\".format(i, losses[-1]))\n\n loss.backward()\n optimizer.step()\n optimizer.zero_grad()\n\n\npl.figure(2)\npl.plot(losses)\npl.grid()\npl.title(\"Dual objective (negative)\")\npl.xlabel(\"Iterations\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the density on target for a given source sample\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nv = 100\nxl = np.linspace(ax_bounds[0], ax_bounds[1], nv)\nyl = np.linspace(ax_bounds[2], ax_bounds[3], nv)\n\nXX, YY = np.meshgrid(xl, yl)\n\nxg = np.concatenate((XX.ravel()[:, None], YY.ravel()[:, None]), axis=1)\n\nwxg = np.exp(-((xg[:, 0] - 4) ** 2 + (xg[:, 1] - 4) ** 2) / (2 * 2))\nwxg = wxg / np.sum(wxg)\n\nxg = torch.tensor(xg)\nwxg = torch.tensor(wxg)\n\n\npl.figure(4, (12, 4))\npl.clf()\npl.subplot(1, 3, 1)\n\niv = 2\nGg = ot.stochastic.plan_dual_entropic(\n u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg\n)\nGg = Gg.reshape((nv, 
nv)).detach().numpy()\n\npl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker=\"+\", zorder=2, alpha=0.05)\npl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker=\"o\", zorder=2, alpha=0.05)\npl.scatter(\n Xs[iv : iv + 1, 0],\n Xs[iv : iv + 1, 1],\n s=100,\n marker=\"+\",\n label=\"Source sample\",\n zorder=2,\n alpha=1,\n color=\"C0\",\n)\npl.pcolormesh(XX, YY, Gg, cmap=\"Greens\", label=\"Density of transported source sample\")\npl.legend(loc=0)\nax_bounds = pl.axis()\npl.title(\"Density of transported source sample\")\n\npl.subplot(1, 3, 2)\n\niv = 3\nGg = ot.stochastic.plan_dual_entropic(\n u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg\n)\nGg = Gg.reshape((nv, nv)).detach().numpy()\n\npl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker=\"+\", zorder=2, alpha=0.05)\npl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker=\"o\", zorder=2, alpha=0.05)\npl.scatter(\n Xs[iv : iv + 1, 0],\n Xs[iv : iv + 1, 1],\n s=100,\n marker=\"+\",\n label=\"Source sample\",\n zorder=2,\n alpha=1,\n color=\"C0\",\n)\npl.pcolormesh(XX, YY, Gg, cmap=\"Greens\", label=\"Density of transported source sample\")\npl.legend(loc=0)\nax_bounds = pl.axis()\npl.title(\"Density of transported source sample\")\n\npl.subplot(1, 3, 3)\n\niv = 6\nGg = ot.stochastic.plan_dual_entropic(\n u(xs[iv : iv + 1, :]), v(xg), xs[iv : iv + 1, :], xg, reg=reg, wt=wxg\n)\nGg = Gg.reshape((nv, nv)).detach().numpy()\n\npl.scatter(Xs[:nvisu, 0], Xs[:nvisu, 1], marker=\"+\", zorder=2, alpha=0.05)\npl.scatter(Xt[:nvisu, 0], Xt[:nvisu, 1], marker=\"o\", zorder=2, alpha=0.05)\npl.scatter(\n Xs[iv : iv + 1, 0],\n Xs[iv : iv + 1, 1],\n s=100,\n marker=\"+\",\n label=\"Source sample\",\n zorder=2,\n alpha=1,\n color=\"C0\",\n)\npl.pcolormesh(XX, YY, Gg, cmap=\"Greens\", label=\"Density of transported source sample\")\npl.legend(loc=0)\nax_bounds = pl.axis()\npl.title(\"Density of transported source sample\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }