{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Dual OT solvers for entropic and quadratic regularized OT with PyTorch\n\n

**Note:** Example added in release: 0.8.2.

\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# Author: Remi Flamary \n",
"#\n",
"# License: MIT License\n",
"\n",
"# sphinx_gallery_thumbnail_number = 3\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as pl\n",
"import torch\n",
"import ot\n",
"import ot.plot"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Data generation\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# Seed both generators so that Restart & Run All reproduces the figures:\n",
"# torch drives the random initial potentials, while the dataset helper is\n",
"# expected to draw from NumPy's global RNG when no random_state is passed\n",
"# (see the POT docs for ot.datasets.make_data_classif).\n",
"torch.manual_seed(1)\n",
"np.random.seed(1)\n",
"\n",
"n_source_samples = 100\n",
"n_target_samples = 100\n",
"theta = 2 * np.pi / 20\n",
"noise_level = 0.1\n",
"\n",
"Xs, ys = ot.datasets.make_data_classif(\"gaussrot\", n_source_samples, nz=noise_level)\n",
"Xt, yt = ot.datasets.make_data_classif(\n",
"    \"gaussrot\", n_target_samples, theta=theta, nz=noise_level\n",
")\n",
"\n",
"# one of the target modes changes its variance (no linear mapping)\n",
"Xt[yt == 2] *= 3\n",
"Xt = Xt + 4"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"pl.figure(1, (10, 5))\n",
"pl.clf()\n",
"pl.scatter(Xs[:, 0], Xs[:, 1], marker=\"+\", label=\"Source samples\")\n",
"pl.scatter(Xt[:, 0], Xt[:, 1], marker=\"o\", label=\"Target samples\")\n",
"pl.legend(loc=0)\n",
"pl.title(\"Source and target distributions\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Convert data to torch tensors\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"xs = torch.tensor(Xs)\n",
"xt = torch.tensor(Xt)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [
"## Dual ascent helper\n",
"\n",
"The entropic and quadratic problems are optimized with the same Adam loop; only the dual loss differs.\n",
"\n"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"def fit_dual_potentials(dual_loss, xs, xt, reg, n_iter=200, lr=1, log_every=10):\n",
"    \"\"\"Fit dual potentials (u, v) by maximizing ``dual_loss`` with Adam.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    dual_loss : callable\n",
"        Called as ``dual_loss(u, v, xs, xt, reg=reg)`` and expected to return a\n",
"        scalar tensor, e.g. ``ot.stochastic.loss_dual_entropic``.\n",
"    xs, xt : torch.Tensor\n",
"        Source and target samples, one sample per row.\n",
"    reg : float\n",
"        Regularization weight forwarded to ``dual_loss``.\n",
"    n_iter : int\n",
"        Number of Adam steps.\n",
"    lr : float\n",
"        Adam learning rate.\n",
"    log_every : int\n",
"        The loss is printed every ``log_every`` iterations.\n",
"\n",
"    Returns\n",
"    -------\n",
"    u, v : torch.Tensor\n",
"        Potentials with one entry per source / target sample. They still\n",
"        require grad, so ``.detach()`` anything derived from them before\n",
"        converting to NumPy.\n",
"    losses : list of float\n",
"        Negated dual objective, recorded before each optimizer step.\n",
"    \"\"\"\n",
"    # one potential per sample: u follows the source set, v the target set\n",
"    u = torch.randn(xs.shape[0], requires_grad=True)\n",
"    v = torch.randn(xt.shape[0], requires_grad=True)\n",
"\n",
"    optimizer = torch.optim.Adam([u, v], lr=lr)\n",
"    losses = []\n",
"\n",
"    for i in range(n_iter):\n",
"        # minus because we maximize the dual loss\n",
"        loss = -dual_loss(u, v, xs, xt, reg=reg)\n",
"        losses.append(loss.item())\n",
"\n",
"        if i % log_every == 0:\n",
"            print(\"Iter: {:3d}, loss={}\".format(i, losses[-1]))\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        optimizer.zero_grad()\n",
"\n",
"    return u, v, losses"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Estimating dual variables for entropic OT\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"reg = 0.5\n",
"\n",
"# number of iterations\n",
"n_iter = 200\n",
"\n",
"u, v, losses = fit_dual_potentials(\n",
"    ot.stochastic.loss_dual_entropic, xs, xt, reg, n_iter=n_iter\n",
")\n",
"\n",
"pl.figure(2)\n",
"pl.plot(losses)\n",
"pl.grid()\n",
"pl.title(\"Dual objective (negative)\")\n",
"pl.xlabel(\"Iterations\")\n",
"\n",
"Ge = ot.stochastic.plan_dual_entropic(u, v, xs, xt, reg=reg)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the estimated entropic OT plan\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"pl.figure(3, (10, 5))\n",
"pl.clf()\n",
"ot.plot.plot2D_samples_mat(Xs, Xt, Ge.detach().numpy(), alpha=0.1)\n",
"pl.scatter(Xs[:, 0], Xs[:, 1], marker=\"+\", label=\"Source samples\", zorder=2)\n",
"pl.scatter(Xt[:, 0], Xt[:, 1], marker=\"o\", label=\"Target samples\", zorder=2)\n",
"pl.legend(loc=0)\n",
"pl.title(\"OT plan with entropic regularization\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Estimating dual variables for quadratic OT\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"reg = 0.01\n",
"\n",
"# number of iterations\n",
"n_iter = 200\n",
"\n",
"u, v, losses = fit_dual_potentials(\n",
"    ot.stochastic.loss_dual_quadratic, xs, xt, reg, n_iter=n_iter\n",
")\n",
"\n",
"pl.figure(4)\n",
"pl.plot(losses)\n",
"pl.grid()\n",
"pl.title(\"Dual objective (negative)\")\n",
"pl.xlabel(\"Iterations\")\n",
"\n",
"Gq = ot.stochastic.plan_dual_quadratic(u, v, xs, xt, reg=reg)"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the estimated quadratic OT plan\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"pl.figure(5, (10, 5))\n",
"pl.clf()\n",
"ot.plot.plot2D_samples_mat(Xs, Xt, Gq.detach().numpy(), alpha=0.1)\n",
"pl.scatter(Xs[:, 0], Xs[:, 1], marker=\"+\", label=\"Source samples\", zorder=2)\n",
"pl.scatter(Xt[:, 0], Xt[:, 1], marker=\"o\", label=\"Target samples\", zorder=2)\n",
"pl.legend(loc=0)\n",
"pl.title(\"OT plan with quadratic regularization\")"
] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }