{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# 2D free support Wasserstein barycenters of distributions\n\nIllustration of 2D Wasserstein and Sinkhorn barycenters when the distributions are weighted\nsums of Diracs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Authors: Vivien Seguy \n# R\u00e9mi Flamary \n# Eduardo Fernandes Montesuma \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport ot" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N = 2\nd = 2\n\nI1 = pl.imread(\"../../data/redcross.png\").astype(np.float64)[::4, ::4, 2]\nI2 = pl.imread(\"../../data/duck.png\").astype(np.float64)[::4, ::4, 2]\n\nsz = I2.shape[0]\nXX, YY = np.meshgrid(np.arange(sz), np.arange(sz))\n\nx1 = np.stack((XX[I1 == 0], YY[I1 == 0]), 1) * 1.0\nx2 = np.stack((XX[I2 == 0] + 80, -YY[I2 == 0] + 32), 1) * 1.0\nx3 = np.stack((XX[I2 == 0], -YY[I2 == 0] + 32), 1) * 1.0\n\nmeasures_locations = [x1, x2]\nmeasures_weights = [ot.unif(x1.shape[0]), ot.unif(x2.shape[0])]\n\npl.figure(1, (12, 4))\npl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)\npl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)\npl.title(\"Distributions\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute free support Wasserstein barycenter\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "k = 200 # number of Diracs of the barycenter\nX_init = np.random.normal(0.0, 1.0, (k, d)) # initial Dirac locations\nb = (\n np.ones((k,)) / k\n) # weights of the barycenter (it will not be optimized, only the locations are optimized)\n\nX = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## Plot the Wasserstein barycenter\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, (8, 3))\npl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)\npl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)\npl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker=\"s\", label=\"2-Wasserstein barycenter\")\npl.title(\"Data measures and their barycenter\")\npl.legend(loc=\"lower right\")\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute free support Sinkhorn barycenter\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "k = 200 # number of Diracs of the barycenter\nX_init = np.random.normal(0.0, 1.0, (k, d)) # initial Dirac locations\nb = (\n np.ones((k,)) / k\n) # weights of the barycenter (it will not be optimized, only the locations are optimized)\n\nX = ot.bregman.free_support_sinkhorn_barycenter(\n measures_locations, measures_weights, X_init, 20, b, numItermax=15\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the Sinkhorn barycenter\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, (8, 3))\npl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)\npl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)\npl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker=\"s\", label=\"2-Wasserstein barycenter\")\npl.title(\"Data measures and their barycenter\")\npl.legend(loc=\"lower right\")\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }