{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Generalized Wasserstein Barycenter Demo\n\n

**Note:** Example added in release: 0.9.1.

\n\nThis example illustrates the computation of Generalized Wasserstein Barycenter\nas proposed in [42].\n\n\n[42] Delon, J., Gozlan, N., and Saint-Dizier, A..\nGeneralized Wasserstein barycenters between probability measures living on different subspaces.\narXiv preprint arXiv:2105.09755, 2021.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Eloi Tanguy \n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 2\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.pylab as pl\n",
    "import ot\n",
    "import matplotlib.animation as animation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting helper\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def scatter_3d(axis, point_clouds, elev=None, azim=None):\n",
    "    \"\"\"Scatter several 3D point clouds on one axis, set the view and hide the ticks.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    axis : 3D matplotlib axis (created with ``projection=\"3d\"``)\n",
    "    point_clouds : iterable of arrays indexed as ``points[:, 0..2]``\n",
    "        Each array is drawn with its own colour, in iteration order.\n",
    "    elev, azim : float or None\n",
    "        Forwarded to ``axis.view_init``; ``None`` leaves that angle to matplotlib.\n",
    "    \"\"\"\n",
    "    for points in point_clouds:\n",
    "        axis.scatter(points[:, 0], points[:, 1], points[:, 2], marker=\"o\", alpha=0.6)\n",
    "    axis.view_init(elev=elev, azim=azim)\n",
    "    axis.set_xticks([])\n",
    "    axis.set_yticks([])\n",
    "    axis.set_zticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate and plot data\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Input measures\n",
    "sub_sample_factor = 8\n",
    "I1 = pl.imread(\"../../data/redcross.png\").astype(np.float64)[\n",
    "    ::sub_sample_factor, ::sub_sample_factor, 2\n",
    "]\n",
    "I2 = pl.imread(\"../../data/tooth.png\").astype(np.float64)[\n",
    "    ::-sub_sample_factor, ::sub_sample_factor, 2\n",
    "]\n",
    "I3 = pl.imread(\"../../data/heart.png\").astype(np.float64)[\n",
    "    ::-sub_sample_factor, ::sub_sample_factor, 2\n",
    "]\n",
    "\n",
    "sz = I1.shape[0]\n",
    "UU, VV = np.meshgrid(np.arange(sz), np.arange(sz))\n",
    "\n",
    "# Input measure locations in their respective 2D spaces\n",
    "# (grid coordinates of the pixels whose value is exactly 0)\n",
    "X_list = [np.stack((UU[im == 0], VV[im == 0]), 1) * 1.0 for im in [I1, I2, I3]]\n",
    "\n",
    "# Input measure weights\n",
    "a_list = [ot.unif(x.shape[0]) for x in X_list]\n",
    "\n",
    "# Projections 3D -> 2D\n",
    "P1 = np.array([[1, 0, 0], [0, 1, 0]])\n",
    "P2 = np.array([[0, 1, 0], [0, 0, 1]])\n",
    "P3 = np.array([[1, 0, 0], [0, 0, 1]])\n",
    "P_list = [P1, P2, P3]\n",
    "\n",
    "# Barycenter weights\n",
    "weights = np.array([1 / 3, 1 / 3, 1 / 3])\n",
    "\n",
    "# Number of barycenter points to compute\n",
    "n_samples_bary = 150\n",
    "\n",
    "# Send the input measures into 3D space for visualization\n",
    "X_visu = [Xi @ Pi for (Xi, Pi) in zip(X_list, P_list)]\n",
    "\n",
    "# Plot the input data\n",
    "fig = plt.figure(figsize=(3, 3))\n",
    "axis = fig.add_subplot(1, 1, 1, projection=\"3d\")\n",
    "scatter_3d(axis, X_visu, azim=45)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Barycenter computation and plot\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Seeded initial support so that re-running the notebook gives the same barycenter\n",
    "# (the solver's default initialization is documented as random).\n",
    "rng = np.random.RandomState(0)\n",
    "Y_init = rng.randn(n_samples_bary, P_list[0].shape[1])\n",
    "\n",
    "# Pass the barycenter weights explicitly instead of relying on the solver default\n",
    "Y = ot.lp.generalized_free_support_barycenter(\n",
    "    X_list, a_list, P_list, n_samples_bary, Y_init=Y_init, weights=weights\n",
    ")\n",
    "\n",
    "# Input measures (first three colours) together with the barycenter (last colour)\n",
    "fig = plt.figure(figsize=(3, 3))\n",
    "axis = fig.add_subplot(1, 1, 1, projection=\"3d\")\n",
    "scatter_3d(axis, X_visu + [Y], azim=45)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plotting projection matches\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# One panel per viewing direction, given as (elev, azim)\n",
    "views = [(0, 0), (0, 90), (90, 0)]\n",
    "\n",
    "fig = plt.figure(figsize=(9, 3))\n",
    "for panel, (elev, azim) in enumerate(views, start=1):\n",
    "    ax = fig.add_subplot(1, 3, panel, projection=\"3d\")\n",
    "    scatter_3d(ax, X_visu + [Y], elev=elev, azim=azim)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Rotation animation\n\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(7, 7))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=\"3d\")\n",
    "\n",
    "# Draw the point clouds once, up front: init_func may be invoked again\n",
    "# (for instance when the animation restarts), and scattering inside it\n",
    "# would stack duplicate artists on the axis.\n",
    "scatter_3d(ax, X_visu + [Y], elev=0, azim=0)\n",
    "\n",
    "\n",
    "def _init():\n",
    "    \"\"\"Reset the camera to the viewpoint of the first frame.\"\"\"\n",
    "    ax.view_init(elev=0, azim=0)\n",
    "    return (fig,)\n",
    "\n",
    "\n",
    "def _update_plot(i):\n",
    "    \"\"\"Move the camera for frame ``i``.\n",
    "\n",
    "    The azimuth advances by 4 degrees per frame. The elevation stays at 0 for\n",
    "    frames 0-44, then rises by one degree per frame (90 at the last frame, 135).\n",
    "    \"\"\"\n",
    "    if i < 45:\n",
    "        ax.view_init(elev=0, azim=4 * i)\n",
    "    else:\n",
    "        ax.view_init(elev=i - 45, azim=4 * i)\n",
    "    return (fig,)\n",
    "\n",
    "\n",
    "ani = animation.FuncAnimation(\n",
    "    fig,\n",
    "    _update_plot,\n",
    "    init_func=_init,\n",
    "    frames=136,\n",
    "    interval=50,\n",
    "    blit=True,\n",
    "    repeat_delay=2000,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}