{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Quickstart Guide\n\nQuickstart guide to the POT toolbox.\n\nFor better readability, only the use of POT is provided and the plotting code\nwith matplotlib is hidden (but is available in the source file of the example).\n\n

Note

We use here the unified API of POT, which is more flexible and makes it possible to solve a wider range of problems with just a few functions. The classical API is still available (the unified API\n is a convenient wrapper around the classical one) and we provide pointers to the\n classical API when needed.

\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary\n#\n# License: MIT License\n# sphinx_gallery_thumbnail_number = 18\n\n# Import necessary libraries\n\nimport numpy as np\nimport pylab as pl\n\nimport ot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2D data example\n\nWe first generate two sets of samples in 2D, with 25 and 50\nsamples respectively, located on circles. The weights of the samples are\nuniform.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Problem size\nn1 = 25\nn2 = 50\n\n# Generate random data\nnp.random.seed(0)\na = ot.utils.unif(n1) # weights of points in the source domain\nb = ot.utils.unif(n2) # weights of points in the target domain\n\nx1 = np.random.randn(n1, 2)\nx1 /= np.sqrt(np.sum(x1**2, 1, keepdims=True)) / 2\n\nx2 = np.random.randn(n2, 2)\nx2 /= np.sqrt(np.sum(x2**2, 1, keepdims=True)) / 4\n\n# Compute the cost matrix\nC = ot.dist(x1, x2) # Squared Euclidean cost matrix by default" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solving exact Optimal Transport\n\n### Solve the Optimal Transport problem between the samples\n\nThe :func:`ot.solve_sample` function can be used to solve the Optimal Transport problem\nbetween two sets of samples. 
The function takes as its first two arguments the\npositions of the source and target samples, and returns an :class:`ot.utils.OTResult` object.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the OT problem\nsol = ot.solve_sample(x1, x2, a, b)\n\n# get the OT plan\nP = sol.plan\n\n# get the OT loss\nloss = sol.value\n\n# get the dual potentials\nalpha, beta = sol.potentials\n\nprint(f\"OT loss = {loss:1.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The figure above shows the Optimal Transport plan between the source and target\nsamples. The color intensity represents the amount of mass transported\nbetween the samples. The dual potentials of the OT problem are also shown.\n\nThe weights of the samples in the source and target domains :code:`a` and\n:code:`b` are given to the function. If not provided, the weights are assumed\nto be uniform. See :func:`ot.solve_sample` for more details.\n\nThe :class:`ot.utils.OTResult` object contains the following attributes:\n\n- :code:`value`: the value of the OT problem\n- :code:`plan`: the OT matrix\n- :code:`potentials`: Dual potentials of the OT problem\n- :code:`log`: log dictionary of the solver\n\nThe OT matrix $P$ is a matrix of size :code:`(n1, n2)` where\n:code:`P[i,j]` is the amount of mass\ntransported from :code:`x1[i]` to :code:`x2[j]`.\n\nThe OT loss is the sum of the element-wise product of the OT matrix and the\ncost matrix taken by default as the Squared Euclidean distance.\n\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimal Transport problem with a custom cost matrix\n\nThe cost matrix can be customized by passing it to the more general\n:func:`ot.solve` function. 
The cost matrix should be a matrix of size\n:code:`(n1, n2)` where :code:`C[i,j]` is the cost of transporting mass from\n:code:`x1[i]` to :code:`x2[j]`.\n\nIn this example, we use the Cityblock distance as the cost matrix.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute the cost matrix\nC_city = ot.dist(x1, x2, metric=\"cityblock\")\n\n# Solve the OT problem with the custom cost matrix\nsol = ot.solve(C_city)\n# the parameters a and b are not provided so uniform weights are assumed\nP_city = sol.plan\n# on empirical data the same can be done with ot.solve_sample :\n# sol = ot.solve_sample(x1, x2, metric='cityblock')\n\n# Compute the OT loss (equivalent to ot.solve(C_city).value)\nloss_city = sol.value # same as np.sum(P_city * C_city)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we show here how to solve the OT problem with a custom cost matrix\nwith the more general :func:`ot.solve` function.\nBut the same can be done with the :func:`ot.solve_sample` function by passing\n:code:`metric='cityblock'` as argument.\n\nThe cost matrix can be computed with the :func:`ot.dist` function which\ncomputes the pairwise distance between two sets of samples or can be provided\ndirectly as a matrix by the user when no samples are available.\n\n
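
As a rough illustration of what such a cost matrix contains, the sketch below builds the squared Euclidean and cityblock costs with plain NumPy broadcasting. The helper name `pairwise_cost` and the sample arrays `xs` and `xt` are hypothetical and not part of POT; the helper only follows the convention that entry `[i, j]` is the cost of moving mass from the i-th source sample to the j-th target sample.

```python
import numpy as np


def pairwise_cost(xs, xt, metric='sqeuclidean'):
    # Hypothetical helper (not a POT function): cost[i, j] compares xs[i] with xt[j]
    diff = xs[:, None, :] - xt[None, :, :]  # shape (n1, n2, d)
    if metric == 'sqeuclidean':
        return np.sum(diff**2, axis=-1)
    if metric == 'cityblock':
        return np.sum(np.abs(diff), axis=-1)
    raise ValueError(f'unsupported metric: {metric}')


rng = np.random.default_rng(0)
xs = rng.standard_normal((25, 2))  # illustrative source samples
xt = rng.standard_normal((50, 2))  # illustrative target samples

C_sq = pairwise_cost(xs, xt)
C_l1 = pairwise_cost(xs, xt, metric='cityblock')
print(C_sq.shape, C_l1.shape)
```

Any matrix of this shape can be used as a cost, which is why the solver accepts a precomputed matrix when no samples are available.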

Note

The examples above use the unified API of POT. The classic API is still available\n and the OT plan and loss can be computed with the :func:`ot.emd` and\n the :func:`ot.emd2` functions as below:\n\n```python\nP = ot.emd(a, b, C)\nloss = ot.emd2(a, b, C) # same as np.sum(P*C) but differentiable wrt a/b

\n```\n## Sinkhorn and Regularized OT\n\n### Entropic OT with Sinkhorn algorithm\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the Sinkhorn problem (just add reg parameter value)\nsol = ot.solve_sample(x1, x2, a, b, reg=1e-1)\n\n# get the OT plan and loss\nP_sink = sol.plan\nloss_sink = sol.value # objective value of the Sinkhorn problem (incl. entropy)\nloss_sink_linear = sol.value_linear # np.sum(P_sink * C) linear part of loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Sinkhorn algorithm solves the Entropic Regularized OT problem. The\nregularization strength can be controlled with the :code:`reg` parameter.\nThe Sinkhorn algorithm can be faster than the exact OT solver for large\nregularization strength but the solution is only an approximation of the\nexact OT problem and the OT plan is not sparse.\n\n### Quadratic Regularized OT\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Use quadratic regularization\nP_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=\"L2\").plan\n\nloss_quad = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=\"L2\").value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We plot above the OT plans obtained with different regularizations. The\nquadratic regularization is another common choice for regularized OT and\npreserves the sparsity of the OT plan.\n\n### Solve the Regularized OT problem with user-defined regularization\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Define a custom regularization function\ndef f(G):\n return 0.5 * np.sum(G**2)\n\n\ndef df(G):\n return G\n\n\nP_reg = ot.solve_sample(x1, x2, a, b, reg=3, reg_type=(f, df)).plan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The examples above use the unified API of POT. The classic API is still available\n and the entropic OT plan and loss can be computed with the\n :func:`ot.sinkhorn` and :func:`ot.sinkhorn2` functions as below:\n\n```python\nGs = ot.sinkhorn(a, b, C, reg=1e-1)\nloss_sink = ot.sinkhorn2(a, b, C, reg=1e-1)\n```\n For quadratic regularization, the :func:`ot.smooth.smooth_ot_dual` function\n can be used to compute the solution of the regularized OT problem. For\n user-defined regularization, the :func:`ot.optim.cg` function can be used\n to solve the regularized OT problem with the Conditional Gradient algorithm.
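
To make the entropic regularization discussed in this section more concrete, here is a minimal NumPy sketch of the Sinkhorn scaling iterations. It is an illustration only and not POT's implementation: the function name `sinkhorn_sketch`, the fixed iteration count and the small random problem are all assumptions chosen for readability.

```python
import numpy as np


def sinkhorn_sketch(a, b, C, reg, n_iter=1000):
    # Plain Sinkhorn scaling iterations; an illustration, not POT's solver
    K = np.exp(-C / reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # rescale rows to match the source weights
        v = b / (K.T @ u)  # rescale columns to match the target weights
    return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
xs = rng.standard_normal((5, 2))
xt = rng.standard_normal((7, 2))
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)
C = ((xs[:, None, :] - xt[None, :, :]) ** 2).sum(-1)
C = C / C.max()  # normalized so that reg=0.1 is numerically comfortable

P = sinkhorn_sketch(a, b, C, reg=0.1)
linear_loss = np.sum(P * C)  # linear part of the loss
```

Every entry of `P` is strictly positive because the kernel is, which matches the remark above that the entropic plan is not sparse.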

\n\n## Unbalanced and Partial OT\n\n### Unbalanced Optimal Transport\n\nUnbalanced OT relaxes the marginal constraints and allows for the source and\ntarget total weights to be different. The :func:`ot.solve_sample` function can be\nused to solve the unbalanced OT problem by setting the marginal penalization\n:code:`unbalanced` parameter to a positive value.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the unbalanced OT problem with KL penalization\nP_unb_kl = ot.solve_sample(x1, x2, a, b, unbalanced=5e-2).plan\n\n# Unbalanced with KL penalization and KL regularization\nP_unb_kl_reg = ot.solve_sample(\n x1, x2, a, b, unbalanced=5e-2, reg=1e-1\n).plan # also regularized\n\n# Unbalanced with L2 penalization\nP_unb_l2 = ot.solve_sample(x1, x2, a, b, unbalanced=7e1, unbalanced_type=\"L2\").plan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

Solving the unbalanced OT problem with the classic API can be done with the\n :func:`ot.unbalanced.sinkhorn_unbalanced` function as below:\n\n```python\nG_unb_kl = ot.unbalanced.sinkhorn_unbalanced(a, b, C, reg=reg, reg_m=unbalanced)

\n```\n### Partial Optimal Transport\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the Unbalanced OT problem with TV penalization (equivalent)\nP_part_pen = ot.solve_sample(x1, x2, a, b, unbalanced=3, unbalanced_type=\"TV\").plan\n\n# Solve the Partial OT problem with mass constraints (only classic API)\nP_part_const = ot.partial.partial_wasserstein(a, b, C, m=0.5) # 50% mass transported" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Gromov-Wasserstein and Fused Gromov-Wasserstein\n\n### Gromov-Wasserstein and Entropic GW\n\nThe Gromov-Wasserstein distance is a similarity measure between metric\nmeasure spaces, so it does not require the samples to be in the same space.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Define the metric cost matrices in each space\n\nC1 = ot.dist(x1, x1, metric=\"sqeuclidean\")\nC2 = ot.dist(x2, x2, metric=\"sqeuclidean\")\n\nC1 /= C1.max()\nC2 /= C2.max()\n\n# Solve the Gromov-Wasserstein problem\nsol_gw = ot.solve_gromov(C1, C2, a=a, b=b)\nP_gw = sol_gw.plan\nloss_gw = sol_gw.value # quadratic + reg if reg>0\nloss_gw_quad = sol_gw.value_quad # quadratic part of loss\n\n# Solve the Entropic Gromov-Wasserstein problem\nP_egw = ot.solve_gromov(C1, C2, a=a, b=b, reg=1e-2).plan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The Gromov-Wasserstein problem can be solved with the classic API using the\n :func:`ot.gromov.gromov_wasserstein` function and the Entropic\n Gromov-Wasserstein problem can be solved with the\n :func:`ot.gromov.entropic_gromov_wasserstein` function.\n\n```python\nP_gw = ot.gromov.gromov_wasserstein(C1, C2, a, b)\nP_egw = ot.gromov.entropic_gromov_wasserstein(C1, C2, a, b, epsilon=reg)\n\nloss_gw = ot.gromov.gromov_wasserstein2(C1, C2, a, b)\nloss_egw = ot.gromov.entropic_gromov_wasserstein2(C1, C2, a, b, epsilon=reg)

\n```\n### Fused Gromov-Wasserstein\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Cost matrix\nM = C / np.max(C)\n\n# Solve FGW problem with alpha=0.1\nsol = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1)\nP_fgw = sol.plan\nloss_fgw = sol.value\nloss_fgw_linear = sol.value_linear # linear part of loss (wrt M)\nloss_fgw_quad = sol.value_quad # quadratic part of loss (wrt C1 and C2)\n\n# Solve entropic FGW problem with alpha=0.1\nP_efgw = ot.solve_gromov(C1, C2, M, a=a, b=b, alpha=0.1, reg=1e-3).plan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The Fused Gromov-Wasserstein problem can be solved with the classic API using\n the :func:`ot.gromov.fused_gromov_wasserstein` function and the Entropic\n Fused Gromov-Wasserstein problem can be solved with the\n :func:`ot.gromov.entropic_fused_gromov_wasserstein` function.\n Note that in the classic API the feature cost matrix :code:`M` is the first argument.\n\n```python\nP_fgw = ot.gromov.fused_gromov_wasserstein(M, C1, C2, a, b, alpha=0.1)\nP_efgw = ot.gromov.entropic_fused_gromov_wasserstein(M, C1, C2, a, b, alpha=0.1, epsilon=reg)\n\nloss_fgw = ot.gromov.fused_gromov_wasserstein2(M, C1, C2, a, b, alpha=0.1)\nloss_efgw = ot.gromov.entropic_fused_gromov_wasserstein2(M, C1, C2, a, b, alpha=0.1, epsilon=reg)

\n```\n## Large scale OT\n\nWe discuss here strategies to solve large scale OT problems using approximations\nof the exact OT problem.\n\n### Large scale Sinkhorn\n\nWhen the samples contain a large number of points, the Sinkhorn algorithm can\nbe implemented in a lazy version which is more memory efficient and avoids\nthe computation of the $n \\times m$ cost matrix.\n\nPOT provides two implementations of the lazy Sinkhorn algorithm that return their\nresults in a lazy form of type :class:`ot.utils.LazyTensor`. This object can be\nused to compute the loss or the OT plan in a lazy way or to recover its values\nin a dense form.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the Sinkhorn problem in a lazy way\nsol = ot.solve_sample(x1, x2, a, b, reg=1e-1, lazy=True)\n\n# Solve the Sinkhorn problem in a lazy way with geomloss\nsol_geo = ot.solve_sample(x1, x2, a, b, reg=1e-1, method=\"geomloss\", lazy=True)\n\n# get the OT lazy plan and loss\nP_sink_lazy = sol.lazy_plan\n\n# recover values for Lazy plan\nP12 = P_sink_lazy[1, 2]\nP1dots = P_sink_lazy[1, :]\n# convert to dense matrix !!warning this can be memory consuming\nP_sink_lazy_dense = P_sink_lazy[:]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The lazy Sinkhorn algorithm can be found in the classic API with the\n :func:`ot.bregman.empirical_sinkhorn` function with parameter\n :code:`lazy=True`. Similarly, the geomloss implementation is available\n with the :func:`ot.bregman.empirical_sinkhorn2_geomloss` function.

\n\n\nThe first example shows how to solve the Sinkhorn problem in a lazy way with\nthe default POT implementation. The second example shows how to solve the\nSinkhorn problem in a lazy way with the PyKeops/Geomloss implementation that provides\na very efficient way to solve large scale problems on low-dimensional\nsamples.\n\n### Factored and Low rank OT\n\nThe Sinkhorn algorithm can be implemented in a low rank version that\napproximates the OT plan with a low rank matrix. This can be useful to\naccelerate the computation of the OT plan for large scale problems.\nA similar non-regularized version of low rank factorization is also available.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve the Factored OT problem (use lazy=True for large scale)\nP_fact = ot.solve_sample(x1, x2, a, b, method=\"factored\", rank=15).plan\n\n# Solve the low rank Sinkhorn problem\nP_lowrank = ot.solve_sample(x1, x2, a, b, reg=0.1, method=\"lowrank\", rank=10).plan" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The factored OT problem can be solved with the classic API using the\n :func:`ot.factored.factored_optimal_transport` function and the low rank\n OT problem can be solved with the :func:`ot.lowrank.lowrank_sinkhorn` function.
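
To give an idea of why a low rank plan helps at scale, the sketch below assumes a plan written as `Q diag(1/g) R^T`, a parameterization common in the low rank OT literature. This form, the helper name `apply_lowrank_plan` and the trivial factors are assumptions made for illustration; they are not taken from POT and say nothing about how its solvers store their results.

```python
import numpy as np


def apply_lowrank_plan(Q, R, g, z):
    # Compute (Q diag(1/g) R^T) z without forming the full n x m plan
    return Q @ ((R.T @ z) / g)


n, m, r = 6, 8, 3
a = np.full(n, 1 / n)  # source weights
b = np.full(m, 1 / m)  # target weights
g = np.full(r, 1 / r)  # inner weights of the factorization

# Trivial feasible factors: they reproduce the independent coupling a b^T
Q = np.outer(a, g)
R = np.outer(b, g)

z = np.arange(m, dtype=float)
fast = apply_lowrank_plan(Q, R, g, z)
dense = (Q / g) @ R.T  # full plan, only affordable for small problems
```

The factored product costs on the order of `(n + m) * r` operations instead of `n * m`, which is where the speed-up for large sample sizes comes from.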

\n\n### Gaussian OT with Bures-Wasserstein\n\nThe Gaussian Wasserstein or Bures-Wasserstein distance is the Wasserstein distance\nbetween Gaussian distributions. It can be used as an approximation of the\nWasserstein distance between empirical distributions by estimating the\ncovariance matrices of the samples.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute the Bures-Wasserstein distance\nbw_value = ot.solve_sample(x1, x2, a, b, method=\"gaussian\").value\n\nprint(f\"Exact OT loss = {loss:1.3f}\")\nprint(f\"Bures-Wasserstein distance = {bw_value:1.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "

Note

The Gaussian Wasserstein problem can be solved with the classic API using the\n :func:`ot.gaussian.empirical_bures_wasserstein_distance` function.
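
For reference, the closed form behind the Bures-Wasserstein distance can be sketched in a few lines of NumPy. The helper names `sqrtm_psd` and `bures_wasserstein_sketch` are hypothetical, and the covariance estimate uses the default normalization of `np.cov`, which may differ from the one used inside POT, so the values are not expected to match exactly.

```python
import numpy as np


def sqrtm_psd(S):
    # Square root of a symmetric positive semi-definite matrix via eigendecomposition
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0, None))) @ V.T


def bures_wasserstein_sketch(xs, xt):
    # Hypothetical helper: fit one Gaussian per sample set, then apply the closed form
    ms, mt = xs.mean(axis=0), xt.mean(axis=0)
    Ss = np.cov(xs, rowvar=False)
    St = np.cov(xt, rowvar=False)
    Ss_half = sqrtm_psd(Ss)
    cross = sqrtm_psd(Ss_half @ St @ Ss_half)
    bw2 = np.sum((ms - mt) ** 2) + np.trace(Ss + St - 2 * cross)
    return np.sqrt(max(bw2, 0.0))  # guard against tiny negative round-off


rng = np.random.default_rng(0)
xs = rng.standard_normal((40, 2))
xt = xs + np.array([3.0, 4.0])  # same shape of cloud, shifted

d_shift = bures_wasserstein_sketch(xs, xt)
d_same = bures_wasserstein_sketch(xs, xs)
```

When the two clouds differ only by a translation, the covariance term vanishes and the distance reduces to the norm of the shift between the means.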

\n\n## Comparing all OT plans\n\nThe figure below shows all the OT plans computed in this example.\nThe color intensity represents the amount of mass transported\nbetween the samples.\n\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# plot all plans" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }