{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Translation Invariant Sinkhorn for Unbalanced Optimal Transport\n\n

Note

Example added in release: 0.9.5.

\n\nThis example illustrates the better convergence of the translation\ninvariant Sinkhorn algorithm proposed in [73] compared to the classical\nSinkhorn algorithm.\n\n[73] S\u00e9journ\u00e9, T., Vialard, F. X., & Peyr\u00e9, G. (2022).\nFaster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.\nIn International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.\n" ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Author: Cl\u00e9ment Bonet \n",
    "# License: MIT License\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pylab as pl\n",
    "import ot"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting parameters\n\n" ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "n_iter = 50  # nb of repetitions (one random seed each)\n",
    "n = 40  # nb of Gaussian samples per measure\n",
    "\n",
    "num_iter_max = 100  # nb of solver iterations per run\n",
    "n_noise = 10  # nb of uniform noise samples appended to each measure\n",
    "\n",
    "reg = 0.005  # passed as `reg` to sinkhorn_unbalanced\n",
    "reg_m_kl = 0.05  # passed as `reg_m` to sinkhorn_unbalanced\n",
    "\n",
    "mu_s = np.array([-1, -1])\n",
    "cov_s = np.array([[1, 0], [0, 1]])\n",
    "\n",
    "mu_t = np.array([4, 4])\n",
    "cov_t = np.array([[1, -0.8], [-0.8, 1]])"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn\n\n" ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# Total number of points per measure (Gaussian samples + noise samples).\n",
    "# Computed once, outside the loop, and kept separate from `n` so that every\n",
    "# repetition solves a problem of the same size and the cell can be re-run.\n",
    "n_total = n + n_noise\n",
    "\n",
    "# One row per seed; each row receives the error history logged by the solver,\n",
    "# which is expected to contain num_iter_max entries.\n",
    "err_sinkhorn_uot = np.empty((n_iter, num_iter_max))\n",
    "err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))\n",
    "\n",
    "\n",
    "for seed in range(n_iter):\n",
    "    np.random.seed(seed)\n",
    "    xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\n",
    "    xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n",
    "\n",
    "    # append uniform noise points, shifted away from the Gaussian means\n",
    "    xs = np.concatenate((xs, (np.random.rand(n_noise, 2) - 4)), axis=0)\n",
    "    xt = np.concatenate((xt, (np.random.rand(n_noise, 2) + 6)), axis=0)\n",
    "\n",
    "    # uniform distribution on samples\n",
    "    a, b = np.ones((n_total,)) / n_total, np.ones((n_total,)) / n_total\n",
    "\n",
    "    # loss matrix\n",
    "    M = ot.dist(xs, xt)\n",
    "    M /= M.max()\n",
    "\n",
    "    # stopThr=0 is passed so that both solvers are asked for the full\n",
    "    # num_iter_max iterations\n",
    "    entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(\n",
    "        a,\n",
    "        b,\n",
    "        M,\n",
    "        reg,\n",
    "        reg_m_kl,\n",
    "        reg_type=\"kl\",\n",
    "        log=True,\n",
    "        numItermax=num_iter_max,\n",
    "        stopThr=0,\n",
    "    )\n",
    "    entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(\n",
    "        a,\n",
    "        b,\n",
    "        M,\n",
    "        reg,\n",
    "        reg_m_kl,\n",
    "        reg_type=\"kl\",\n",
    "        method=\"sinkhorn_translation_invariant\",\n",
    "        log=True,\n",
    "        numItermax=num_iter_max,\n",
    "        stopThr=0,\n",
    "    )\n",
    "\n",
    "    err_sinkhorn_uot[seed] = log_uot[\"err\"]\n",
    "    err_sinkhorn_uot_ti[seed] = log_uot_ti[\"err\"]"
  ] },
  { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the results\n\n" ] },
  { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
    "# mean and standard deviation of the logged error across seeds, per iteration\n",
    "mean_sinkh = np.mean(err_sinkhorn_uot, axis=0)\n",
    "std_sinkh = np.std(err_sinkhorn_uot, axis=0)\n",
    "\n",
    "mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0)\n",
    "std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0)\n",
    "\n",
    "absc = list(range(num_iter_max))\n",
    "\n",
    "# shaded bands show mean +/- 2 standard deviations\n",
    "pl.plot(absc, mean_sinkh, label=\"Sinkhorn\")\n",
    "pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5)\n",
    "\n",
    "pl.plot(absc, mean_sinkh_ti, label=\"Translation Invariant Sinkhorn\")\n",
    "pl.fill_between(\n",
    "    absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5\n",
    ")\n",
    "\n",
    "pl.yscale(\"log\")\n",
    "pl.legend()\n",
    "pl.xlabel(\"Number of Iterations\")\n",
    "pl.ylabel(r\"$\\|u-v\\|_\\infty$\")\n",
    "pl.grid(True)\n",
    "pl.show()"
  ] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }