{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# 1D Wasserstein barycenter demo\n",
    "\n",
    "This example illustrates the computation of the regularized Wasserstein\n",
    "barycenter of two 1D distributions with iterative Bregman projections, as\n",
    "proposed in [3], and compares it with the l2 (weighted-average) barycenter.\n",
    "\n",
    "\n",
    "[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyr\u00e9, G. (2015).\n",
    "Iterative Bregman projections for regularized transportation problems.\n",
    "SIAM Journal on Scientific Computing, 37(2), A1111-A1138.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Remi Flamary\n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 1\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import ot\n",
    "\n",
    "# registers the \"3d\" projection on older matplotlib versions (< 3.2);\n",
    "# kept for compatibility even though the name is not used directly\n",
    "from mpl_toolkits.mplot3d import Axes3D  # noqa\n",
    "from matplotlib.collections import PolyCollection"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 100  # number of bins\n",
    "\n",
    "# bin positions\n",
    "x = np.arange(n, dtype=np.float64)\n",
    "\n",
    "# Gaussian histograms on the n bins (m = mean, s = std)\n",
    "a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)\n",
    "a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n",
    "\n",
    "# matrix A stacks the distributions column-wise: shape (n, n_distributions)\n",
    "A = np.vstack((a1, a2)).T\n",
    "n_distributions = A.shape[1]\n",
    "\n",
    "# ground cost between bins, normalized so its largest entry is 1\n",
    "M = ot.utils.dist0(n)\n",
    "M /= M.max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Barycenter computation\n",
    "\n",
    "The l2 barycenter is the weighted average of the histograms, whereas the\n",
    "Wasserstein barycenter minimizes the weighted sum of (regularized) optimal\n",
    "transport costs to each input distribution.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "alpha = 0.2  # weight given to a2 (0 <= alpha <= 1)\n",
    "weights = np.array([1 - alpha, alpha])\n",
    "\n",
    "# l2 barycenter: weighted average of the columns of A\n",
    "bary_l2 = A.dot(weights)\n",
    "\n",
    "# regularized Wasserstein barycenter (reg is the regularization strength)\n",
    "reg = 1e-3\n",
    "bary_wass = ot.bregman.barycenter(A, M, reg, weights)\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)\n",
    "ax1.plot(x, A, color=\"black\")\n",
    "ax1.set_title(\"Distributions\")\n",
    "\n",
    "ax2.plot(x, bary_l2, \"r\", label=\"l2\")\n",
    "ax2.plot(x, bary_wass, \"g\", label=\"Wasserstein\")\n",
    "ax2.set_title(\"Barycenters\")\n",
    "ax2.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Barycentric interpolation\n",
    "\n",
    "Both barycenters are computed on a grid of weights $\\alpha \\in [0, 1]$,\n",
    "going from `a1` ($\\alpha = 0$) to `a2` ($\\alpha = 1$).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n_alpha = 11\n",
    "alpha_list = np.linspace(0, 1, n_alpha)\n",
    "\n",
    "# column i holds the barycenter for weights [1 - alpha_list[i], alpha_list[i]]\n",
    "B_l2 = np.zeros((n, n_alpha))\n",
    "B_wass = np.zeros_like(B_l2)\n",
    "\n",
    "for i, alpha in enumerate(alpha_list):\n",
    "    weights = np.array([1 - alpha, alpha])\n",
    "    B_l2[:, i] = A.dot(weights)\n",
    "    B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def plot_interpolation_3d(x, B, alpha_list, title, zmax, num=None):\n",
    "    \"\"\"Draw the columns of ``B`` as filled curves stacked along an alpha axis.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : ndarray, shape (n_bins,)\n",
    "        Bin positions; the x-axis spans ``[0, len(x)]``.\n",
    "    B : ndarray, shape (n_bins, n_alpha)\n",
    "        Column ``i`` is the curve drawn at depth ``alpha_list[i]``.\n",
    "    alpha_list : ndarray, shape (n_alpha,)\n",
    "        Interpolation weights, used as depth positions and colormap values.\n",
    "    title : str\n",
    "        Axes title.\n",
    "    zmax : float\n",
    "        Upper limit of the vertical axis; pass the same value to several\n",
    "        calls to make the figures directly comparable.\n",
    "    num : int or str, optional\n",
    "        Figure identifier forwarded to ``plt.figure``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    ax : mpl_toolkits.mplot3d.axes3d.Axes3D\n",
    "        The 3D axes that was drawn on.\n",
    "    \"\"\"\n",
    "    cmap = plt.get_cmap(\"viridis\")\n",
    "    # one polygon per interpolation step, built from (x, B[:, i]) pairs\n",
    "    verts = [list(zip(x, B[:, i])) for i in range(len(alpha_list))]\n",
    "\n",
    "    fig = plt.figure(num)\n",
    "    ax = fig.add_subplot(projection=\"3d\")\n",
    "\n",
    "    poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])\n",
    "    poly.set_alpha(0.7)\n",
    "    # place polygon i at y = alpha_list[i]\n",
    "    ax.add_collection3d(poly, zs=alpha_list, zdir=\"y\")\n",
    "    ax.set_xlabel(\"x\")\n",
    "    ax.set_xlim3d(0, len(x))\n",
    "    ax.set_ylabel(\"$\\\\alpha$\")\n",
    "    ax.set_ylim3d(0, 1)\n",
    "    ax.set_zlabel(\"\")\n",
    "    ax.set_zlim3d(0, zmax)\n",
    "    ax.set_title(title)\n",
    "    fig.tight_layout()\n",
    "    return ax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# share the vertical scale between both figures; taking the max over both\n",
    "# arrays guarantees neither set of curves is clipped\n",
    "zmax = max(B_l2.max(), B_wass.max()) * 1.01\n",
    "\n",
    "plot_interpolation_3d(x, B_l2, alpha_list, \"Barycenter interpolation with l2\", zmax, num=2)\n",
    "plot_interpolation_3d(\n",
    "    x, B_wass, alpha_list, \"Barycenter interpolation with Wasserstein\", zmax, num=3\n",
    ")\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}