{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Regularization path of l2-penalized unbalanced optimal transport\n",
    "\n",
    "> **Note:** Example added in release: 0.8.0.\n",
    "\n",
    "This example illustrates the regularization path for 2D unbalanced\n",
    "optimal transport. We present here both the fully relaxed case\n",
    "and the semi-relaxed case.\n",
    "\n",
    "[Chapel et al., 2021] Chapel, L., Flamary, R., Wu, H., F\u00e9votte, C.,\n",
    "and Gasso, G. (2021). Unbalanced optimal transport through non-negative\n",
    "penalized linear regression.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Haoran Wu\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pylab as pl\n",
    "import ot\n",
    "import matplotlib.animation as animation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 20  # nb samples\n",
    "\n",
    "mu_s = np.array([-1, -1])\n",
    "cov_s = np.array([[1, 0], [0, 1]])\n",
    "\n",
    "mu_t = np.array([4, 4])\n",
    "cov_t = np.array([[1, -0.8], [-0.8, 1]])\n",
    "\n",
    "np.random.seed(0)\n",
    "xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)\n",
    "xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)\n",
    "\n",
    "a, b = np.ones((n,)) / n, np.ones((n,)) / n  # uniform distribution on samples\n",
    "\n",
    "# loss matrix\n",
    "M = ot.dist(xs, xt)\n",
    "M /= M.max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(1)\n",
    "pl.scatter(xs[:, 0], xs[:, 1], c=\"C0\", label=\"Source\")\n",
    "pl.scatter(xt[:, 0], xt[:, 1], c=\"C1\", label=\"Target\")\n",
    "pl.legend(loc=2)\n",
    "pl.title(\"Source and target distributions\")\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute semi-relaxed and fully relaxed regularization paths\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "final_gamma = 1e-6\n",
    "t, t_list, g_list = ot.regpath.regularization_path(\n",
    "    a, b, M, reg=final_gamma, semi_relaxed=False\n",
    ")\n",
    "t2, t_list2, g_list2 = ot.regpath.regularization_path(\n",
    "    a, b, M, reg=final_gamma, semi_relaxed=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helpers\n",
    "\n",
    "All figures below draw the transport plan the same way, so the shared steps\n",
    "live in a few functions. Each animation binds its own figure, $\\gamma$ grid and\n",
    "path, so it does not depend on loop variables or on names rebound by later cells.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Marker-size factor of the animation frames: (1 + p) * 4 with p = 3,\n",
    "# the index of the last panel of the static figures.\n",
    "ANIM_MARKER_SCALE = 16\n",
    "\n",
    "\n",
    "def normalized_plan(gamma, g_path, t_path):\n",
    "    \"\"\"Return the transport plan at `gamma`, scaled so its largest entry is 1.\n",
    "\n",
    "    The flat plan given by `ot.regpath.compute_transport_plan` for the path\n",
    "    (`g_path`, `t_path`) is reshaped to (len(xs), len(xt)). A plan whose\n",
    "    entries do not sum to a positive value is returned unscaled.\n",
    "    \"\"\"\n",
    "    plan = ot.regpath.compute_transport_plan(gamma, g_path, t_path)\n",
    "    plan = plan.reshape((xs.shape[0], xt.shape[0]))\n",
    "    if plan.sum() > 0:\n",
    "        plan = plan / plan.max()\n",
    "    return plan\n",
    "\n",
    "\n",
    "def draw_plan_lines(ax, plan, alpha_scale):\n",
    "    \"\"\"Draw a segment from xs[i] to xt[j] on `ax` for every plan[i, j] > 0.\n",
    "\n",
    "    The opacity of each segment is plan[i, j] * alpha_scale.\n",
    "    \"\"\"\n",
    "    # np.nonzero walks the mask in row-major order, like a nested i/j loop.\n",
    "    for i, j in zip(*np.nonzero(plan > 0)):\n",
    "        ax.plot(\n",
    "            [xs[i, 0], xt[j, 0]],\n",
    "            [xs[i, 1], xt[j, 1]],\n",
    "            color=\"C2\",\n",
    "            alpha=plan[i, j] * alpha_scale,\n",
    "        )\n",
    "\n",
    "\n",
    "def draw_reweighted_plan(ax, plan, line_alpha, marker_scale):\n",
    "    \"\"\"Draw `plan` on `ax` with both sample sets re-weighted by its marginals.\n",
    "\n",
    "    Faded markers show every sample; opaque markers are sized by the row sums\n",
    "    (source) and column sums (target) of `plan`, times `marker_scale`.\n",
    "    \"\"\"\n",
    "    draw_plan_lines(ax, plan, line_alpha)\n",
    "    ax.scatter(xs[:, 0], xs[:, 1], c=\"C0\", alpha=0.2)\n",
    "    ax.scatter(xt[:, 0], xt[:, 1], c=\"C1\", alpha=0.2)\n",
    "    ax.scatter(\n",
    "        xs[:, 0],\n",
    "        xs[:, 1],\n",
    "        c=\"C0\",\n",
    "        s=plan.sum(1).ravel() * marker_scale,\n",
    "        label=\"Re-weighted source\",\n",
    "        alpha=1,\n",
    "    )\n",
    "    ax.scatter(\n",
    "        xt[:, 0],\n",
    "        xt[:, 1],\n",
    "        c=\"C1\",\n",
    "        s=plan.sum(0).ravel() * marker_scale,\n",
    "        label=\"Re-weighted target\",\n",
    "        alpha=1,\n",
    "    )\n",
    "    # Empty line that only carries the label of the plan segments.\n",
    "    ax.plot([], [], color=\"C2\", alpha=0.8, label=\"OT plan\")\n",
    "\n",
    "\n",
    "def animate_regpath(fig, gammas, g_path, t_path, title_fmt):\n",
    "    \"\"\"Animate the transport plan over `gammas` on `fig`.\n",
    "\n",
    "    Every frame clears `fig` and redraws the plan of the path\n",
    "    (`g_path`, `t_path`) for one value of gamma; `title_fmt` is formatted\n",
    "    with that value. The first frame is drawn immediately.\n",
    "\n",
    "    Returns the `FuncAnimation`. Keep a reference to it: matplotlib stops an\n",
    "    animation whose object has been garbage-collected.\n",
    "    \"\"\"\n",
    "\n",
    "    def _update_plot(iv):\n",
    "        # Draw on this animation's own figure, not on the current pyplot figure.\n",
    "        fig.clf()\n",
    "        ax = fig.gca()\n",
    "        plan = normalized_plan(gammas[iv], g_path, t_path)\n",
    "        draw_reweighted_plan(ax, plan, 0.5, ANIM_MARKER_SCALE)\n",
    "        ax.set_title(title_fmt.format(gammas[iv]), fontsize=11)\n",
    "\n",
    "    _update_plot(0)\n",
    "    return animation.FuncAnimation(\n",
    "        fig, _update_plot, len(gammas), interval=100, repeat_delay=2000\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the regularization path\n",
    "\n",
    "The OT plan is plotted as a function of $\\gamma$ that is the inverse of the\n",
    "weight on the marginal relaxations.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(2)\n",
    "selected_gamma = [2e-1, 1e-1, 5e-2, 1e-3]\n",
    "for p, gamma in enumerate(selected_gamma):\n",
    "    ax = pl.subplot(2, 2, p + 1)\n",
    "    plan = normalized_plan(gamma, g_list, t_list)\n",
    "    # markers grow with the panel index\n",
    "    draw_reweighted_plan(ax, plan, 0.3, (1 + p) * 2)\n",
    "    ax.set_title(r\"$\\ell_2$ UOT $\\gamma$={}\".format(gamma), fontsize=11)\n",
    "    if p < 2:\n",
    "        ax.set_xticks(())  # top row: no x ticks\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Animation of the regpath for UOT l2\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "nv = 50\n",
    "g_list_v = np.logspace(-0.5, -2.5, nv)\n",
    "\n",
    "ani = animate_regpath(\n",
    "    pl.figure(3), g_list_v, g_list, t_list, r\"$\\ell_2$ UOT $\\gamma$={:1.3f}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot the semi-relaxed regularization path\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(4)\n",
    "selected_gamma = [10, 1, 1e-1, 1e-2]\n",
    "for p, gamma in enumerate(selected_gamma):\n",
    "    ax = pl.subplot(2, 2, p + 1)\n",
    "    plan = normalized_plan(gamma, g_list2, t_list2)\n",
    "    draw_plan_lines(ax, plan, 0.3)\n",
    "    ax.scatter(xs[:, 0], xs[:, 1], c=\"C0\", alpha=0.2)\n",
    "    ax.scatter(xt[:, 0], xt[:, 1], c=\"C1\", alpha=1, label=\"Target marginal\")\n",
    "    ax.scatter(\n",
    "        xs[:, 0],\n",
    "        xs[:, 1],\n",
    "        c=\"C0\",\n",
    "        s=plan.sum(1).ravel() * 2 * (1 + p),\n",
    "        label=\"Source marginal\",\n",
    "        alpha=1,\n",
    "    )\n",
    "    # Empty line that only carries the label of the plan segments.\n",
    "    ax.plot([], [], color=\"C2\", alpha=0.8, label=\"OT plan\")\n",
    "    ax.set_title(\n",
    "        r\"Semi-relaxed $l_2$ UOT $\\gamma$={}\".format(gamma), fontsize=11\n",
    "    )\n",
    "    if p < 2:\n",
    "        ax.set_xticks(())  # top row: no x ticks\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Animation of the regpath for semi-relaxed UOT l2\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "g_list_v2 = np.logspace(2, -2, nv)\n",
    "\n",
    "ani2 = animate_regpath(\n",
    "    pl.figure(5),\n",
    "    g_list_v2,\n",
    "    g_list2,\n",
    "    t_list2,\n",
    "    r\"Semi-relaxed $\\ell_2$ UOT $\\gamma$={:1.3f}\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}