{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Geometry of OT distances\n",
    "\n",
    "Shows how to compute multiple Wasserstein and Sinkhorn distances with two different\n",
    "ground metrics and plot their values for different target distributions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Remi Flamary\n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as pl  # pyplot replaces the discouraged pylab interface\n",
    "import ot\n",
    "from ot.datasets import make_1D_gauss as gauss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n",
    "\n",
    "The source is a Gaussian histogram with mean 20. The targets are Gaussian\n",
    "histograms with the same standard deviation whose means sweep from 20 to 90.\n",
    "Two ground costs are built on the bin positions: Euclidean and squared Euclidean.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 100  # nb bins\n",
    "n_target = 20  # nb target distributions\n",
    "\n",
    "# bin positions\n",
    "x = np.arange(n, dtype=np.float64)\n",
    "\n",
    "# means of the target distributions\n",
    "lst_m = np.linspace(20, 90, n_target)\n",
    "\n",
    "# Gaussian distributions\n",
    "a = gauss(n, m=20, s=5)  # m= mean, s= std\n",
    "\n",
    "# one target histogram per column of B\n",
    "B = np.zeros((n, n_target))\n",
    "for i, m in enumerate(lst_m):\n",
    "    B[:, i] = gauss(n, m=m, s=5)\n",
    "\n",
    "# loss matrices, each rescaled so that its largest entry equals 10\n",
    "x_col = x.reshape((n, 1))\n",
    "M = ot.dist(x_col, x_col, \"euclidean\")\n",
    "M /= M.max() * 0.1\n",
    "M2 = ot.dist(x_col, x_col, \"sqeuclidean\")\n",
    "M2 /= M2.max() * 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helpers\n",
    "\n",
    "The same distribution panels are drawn in several figures, so they are defined once here.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_targets(x, B, label=\"Target distributions\"):\n",
    "    \"\"\"Plot each column of ``B`` against ``x`` in blue on the current axes.\n",
    "\n",
    "    Column ``i`` is drawn with opacity ``(i + 1) / B.shape[1]``, so the first\n",
    "    curve stays visible and the last one is fully opaque. Only the last curve\n",
    "    carries ``label`` so that a legend shows a single entry for all targets.\n",
    "    \"\"\"\n",
    "    n_curves = B.shape[1]\n",
    "    for i in range(n_curves):\n",
    "        curve_label = label if i == n_curves - 1 else None\n",
    "        pl.plot(x, B[:, i], \"b\", alpha=(i + 1) / n_curves, label=curve_label)\n",
    "\n",
    "\n",
    "def plot_distributions(x, a, B):\n",
    "    \"\"\"Draw the source ``a`` (red) and the targets ``B`` (blue) on the current axes.\n",
    "\n",
    "    Also sets the title, the y-limits, hides the x ticks and adds the legend.\n",
    "    \"\"\"\n",
    "    pl.plot(x, a, \"r\", label=\"Source distribution\")\n",
    "    pl.title(\"Distributions\")\n",
    "    plot_targets(x, B)\n",
    "    pl.ylim((-0.01, 0.13))\n",
    "    pl.xticks(())\n",
    "    pl.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pl.figure(1)\n",
    "pl.subplot(2, 1, 1)\n",
    "pl.plot(x, a, \"r\")\n",
    "pl.title(\"Source distribution\")\n",
    "pl.subplot(2, 1, 2)\n",
    "plot_targets(x, B)\n",
    "pl.title(\"Target distributions\")\n",
    "pl.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute EMD for the different losses\n",
    "\n",
    "Each target is compared with the source using exact OT under both ground costs,\n",
    "and using the sum of absolute differences between the two histograms (plotted as TV).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "d_emd = ot.emd2(a, B, M)  # direct computation of OT loss\n",
    "d_emd2 = ot.emd2(a, B, M2)  # direct computation of OT loss with metric M2\n",
    "# sum of absolute differences between the source and each target column\n",
    "d_tv = np.abs(a[:, None] - B).sum(axis=0)\n",
    "\n",
    "pl.figure(2)\n",
    "pl.subplot(2, 1, 1)\n",
    "plot_distributions(x, a, B)\n",
    "pl.subplot(2, 1, 2)\n",
    "pl.plot(d_emd, label=\"Euclidean OT\")\n",
    "pl.plot(d_emd2, label=\"Squared Euclidean OT\")\n",
    "pl.plot(d_tv, label=\"Total Variation (TV)\")\n",
    "pl.xlabel(\"Displacement\")\n",
    "pl.title(\"Divergences\")\n",
    "pl.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compute Sinkhorn for the different losses\n",
    "\n",
    "The entropic-regularized losses are drawn as markers on top of the exact OT curves.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "reg = 1e-1  # entropic regularization strength passed to Sinkhorn\n",
    "d_sinkhorn = ot.sinkhorn2(a, B, M, reg)\n",
    "d_sinkhorn2 = ot.sinkhorn2(a, B, M2, reg)\n",
    "\n",
    "pl.figure(3)\n",
    "pl.clf()\n",
    "\n",
    "pl.subplot(2, 1, 1)\n",
    "plot_distributions(x, a, B)\n",
    "pl.subplot(2, 1, 2)\n",
    "pl.plot(d_emd, label=\"Euclidean OT\")\n",
    "pl.plot(d_emd2, label=\"Squared Euclidean OT\")\n",
    "pl.plot(d_sinkhorn, \"+\", label=\"Euclidean Sinkhorn\")\n",
    "pl.plot(d_sinkhorn2, \"+\", label=\"Squared Euclidean Sinkhorn\")\n",
    "pl.plot(d_tv, label=\"Total Variation (TV)\")\n",
    "pl.xlabel(\"Displacement\")\n",
    "pl.title(\"Divergences\")\n",
    "pl.legend()\n",
    "pl.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}