{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Optimal Transport solvers comparison\n\nThis example illustrates the solutions returned by different variants of exact,\nregularized and unbalanced OT solvers.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Remi Flamary \n#\n# License: MIT License\n# sphinx_gallery_thumbnail_number = 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport matplotlib.pylab as pl\nimport ot\nimport ot.plot\nfrom ot.datasets import make_1D_gauss as gauss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n = 50 # nb bins\n\n# bin positions\nx = np.arange(n, dtype=np.float64)\n\n# Gaussian distributions\na = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std\nb = gauss(n, m=25, s=5)\n\n# loss matrix\nM = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))\nM /= M.max()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot distributions and loss matrix\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(1, figsize=(6.4, 3))\npl.plot(x, a, \"b\", label=\"Source distribution\")\npl.plot(x, b, \"r\", label=\"Target distribution\")\npl.legend()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, figsize=(5, 5))\not.plot.plot1D_mat(a, b, M, \"Cost matrix M\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Group lasso regularization and gradient\nEach column of G is split into two groups: the first half and the second half of its rows.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": 
false }, "outputs": [], "source": [ "def reg_gl(G): # group lasso + small l2 reg\n G1 = G[: n // 2, :] ** 2\n G2 = G[n // 2 :, :] ** 2\n gl1 = np.sum(np.sqrt(np.sum(G1, 0)))\n gl2 = np.sum(np.sqrt(np.sum(G2, 0)))\n return gl1 + gl2 + 0.1 * np.sum(G**2)\n\n\ndef grad_gl(G): # gradient of group lasso + small l2 reg\n G1 = G[: n // 2, :]\n G2 = G[n // 2 :, :]\n gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8)\n gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8)\n return np.concatenate((gl1, gl2), axis=0) + 0.2 * G\n\n\nreg_type_gl = (reg_gl, grad_gl)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up parameters for solvers and solve\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "lst_regs = [\"No Reg.\", \"Entropic\", \"L2\", \"Group Lasso + L2\"]\nlst_unbalanced = [\n \"Balanced\",\n \"Unbalanced KL\",\n \"Unbalanced L2\",\n \"Unb. TV (Partial)\",\n] # [\"Balanced\", \"Unb. KL\", \"Unb. L2\", \"Unb L1 (partial)\"]\n\nlst_solvers = [ # name, param for ot.solve function\n # balanced OT\n (\"Exact OT\", dict()),\n (\"Entropic Reg. OT\", dict(reg=0.005)),\n (\"L2 Reg OT\", dict(reg=1, reg_type=\"l2\")),\n (\"Group Lasso Reg. 
OT\", dict(reg=0.1, reg_type=reg_type_gl)),\n # unbalanced OT KL\n (\"Unbalanced KL No Reg.\", dict(unbalanced=0.005)),\n (\n \"Unbalanced KL with KL Reg.\",\n dict(reg=0.0005, unbalanced=0.005, unbalanced_type=\"kl\", reg_type=\"kl\"),\n ),\n (\n \"Unbalanced KL with L2 Reg.\",\n dict(reg=0.5, reg_type=\"l2\", unbalanced=0.005, unbalanced_type=\"kl\"),\n ),\n (\n \"Unbalanced KL with Group Lasso Reg.\",\n dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type=\"kl\"),\n ),\n # unbalanced OT L2\n (\"Unbalanced L2 No Reg.\", dict(unbalanced=0.5, unbalanced_type=\"l2\")),\n (\n \"Unbalanced L2 with KL Reg.\",\n dict(reg=0.001, unbalanced=0.2, unbalanced_type=\"l2\"),\n ),\n (\n \"Unbalanced L2 with L2 Reg.\",\n dict(reg=0.1, reg_type=\"l2\", unbalanced=0.2, unbalanced_type=\"l2\"),\n ),\n (\n \"Unbalanced L2 with Group Lasso Reg.\",\n dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type=\"l2\"),\n ),\n # unbalanced OT TV\n (\"Unbalanced TV No Reg.\", dict(unbalanced=0.1, unbalanced_type=\"tv\")),\n (\n \"Unbalanced TV with KL Reg.\",\n dict(reg=0.001, unbalanced=0.01, unbalanced_type=\"tv\"),\n ),\n (\n \"Unbalanced TV with L2 Reg.\",\n dict(reg=0.1, reg_type=\"l2\", unbalanced=0.01, unbalanced_type=\"tv\"),\n ),\n (\n \"Unbalanced TV with Group Lasso Reg.\",\n dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type=\"tv\"),\n ),\n]\n\nlst_plans = []\nfor name, param in lst_solvers:\n G = ot.solve(M, a, b, **param).plan\n lst_plans.append(G)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot plans\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(3, figsize=(9, 9))\n\nfor i, bname in enumerate(lst_unbalanced):\n for j, rname in enumerate(lst_regs):\n pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)\n\n plan = lst_plans[i * len(lst_regs) + j]\n m2 = plan.sum(0)\n m1 = plan.sum(1)\n m1, m2 = m1 / 
a.max(), m2 / b.max()\n pl.imshow(plan, cmap=\"Greys\")\n pl.plot(x, m2 * 10, \"r\")\n pl.plot(m1 * 10, x, \"b\")\n pl.plot(x, b / b.max() * 10, \"r\", alpha=0.3)\n pl.plot(a / a.max() * 10, x, \"b\", alpha=0.3)\n # pl.axis('off')\n pl.tick_params(\n left=False, right=False, labelleft=False, labelbottom=False, bottom=False\n )\n if i == 0:\n pl.title(rname)\n if j == 0:\n pl.ylabel(bname, fontsize=14)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }