{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# 2D examples of exact and entropic unbalanced optimal transport\n\n**Note:** Example added in release: 0.8.2.\n\nThis example shows how to compute unbalanced and\npartial OT in POT.\n\nUOT solves the following optimization problem:\n\n$$\nW = \\min_{\\gamma} \\langle \\gamma, \\mathbf{M} \\rangle_F +\n\\mathrm{reg} \\cdot \\Omega(\\gamma) +\n\\mathrm{reg_m} \\cdot \\mathrm{div}(\\gamma \\mathbf{1}, \\mathbf{a}) +\n\\mathrm{reg_m} \\cdot \\mathrm{div}(\\gamma^T \\mathbf{1}, \\mathbf{b})\n\\quad \\text{s.t.} \\quad \\gamma \\geq 0\n$$\n\nwhere $\\Omega$ is the regularization term and $\\mathrm{div}$ is a divergence.\nFor entropic UOT, $\\mathrm{reg}>0$ and $\\mathrm{div}$\nshould be the Kullback-Leibler divergence.\nFor exact UOT, $\\mathrm{reg}=0$ and $\\mathrm{div}$\ncan be either the Kullback-Leibler or the quadratic divergence.\nUsing the $\\ell_1$ norm as the divergence gives the so-called partial OT.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Laetitia Chapel \n# License: MIT License\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = 40  # number of Gaussian samples per distribution\n\nmu_s = np.array([-1, -1])\ncov_s = np.array([[1, 0], [0, 1]])\n\nmu_t = np.array([4, 4])\ncov_t = np.array([[1, -0.8], [-0.8, 1]])\n\nnp.random.seed(0)\nxs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\nxt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n\n# append uniform noise samples (outliers) to both point clouds\nn_noise = 10\n\nxs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0)\nxt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0)\n\nn = n + n_noise\n\na, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples\n\n# loss matrix (squared Euclidean by default), rescaled to [0, 1]\nM = ot.dist(xs, xt)\nM /= M.max()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Compute entropic KL-regularized UOT, KL- and l2-regularized UOT, and partial OT\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reg = 0.005  # entropic regularization strength\nreg_m_kl = 0.05  # marginal relaxation weight for the KL divergence\nreg_m_l2 = 5  # marginal relaxation weight for the quadratic divergence\nmass = 0.7  # amount of mass to transport in partial OT\n\nentropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl)\nkl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div=\"kl\")\nl2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div=\"l2\")\npartial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the results\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2)\ntransp = [partial_ot, l2_uot, kl_uot, entropic_kl_uot]\n# backslashes in the LaTeX labels are escaped so that only \\n is a string escape\ntitle = [\n    \"partial OT \\n m=\" + str(mass),\n    \"$\\\\ell_2$-UOT \\n $\\\\mathrm{reg_m}$=\" + str(reg_m_l2),\n    \"kl-UOT \\n $\\\\mathrm{reg_m}$=\" + str(reg_m_kl),\n    \"entropic kl-UOT \\n $\\\\mathrm{reg_m}$=\" + str(reg_m_kl),\n]\n\nfor p in range(4):\n    # top row: mappings between source and target samples\n    pl.subplot(2, 4, p + 1)\n    P = transp[p]\n    if P.sum() > 0:\n        P = P / P.max()\n    for i in range(n):\n        for j in range(n):\n            if P[i, j] > 0:\n                pl.plot(\n                    [xs[i, 0], xt[j, 0]],\n                    [xs[i, 1], xt[j, 1]],\n                    color=\"C2\",\n                    alpha=P[i, j] * 0.3,\n                )\n    pl.scatter(xs[:, 0], xs[:, 1], c=\"C0\", alpha=0.2)\n    pl.scatter(xt[:, 0], xt[:, 1], c=\"C1\", alpha=0.2)\n    # marker size reflects the mass transported from / to each sample\n    pl.scatter(xs[:, 0], xs[:, 1], c=\"C0\", s=P.sum(1).ravel() * (1 + p) * 2)\n    pl.scatter(xt[:, 0], xt[:, 1], c=\"C1\", s=P.sum(0).ravel() * (1 + p) * 2)\n    pl.title(title[p])\n    pl.yticks(())\n    pl.xticks(())\n    if p < 1:\n        pl.ylabel(\"mappings\")\n    # bottom row: transport plans shown as images\n    pl.subplot(2, 4, p + 5)\n    pl.imshow(P, cmap=\"jet\")\n    pl.yticks(())\n    pl.xticks(())\n    if p < 1:\n        pl.ylabel(\"transport plans\")\npl.show()" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4,
"nbformat_minor": 0 }