{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Debiased Sinkhorn barycenter demo\n\n

**Note:** Example added in release 0.8.0.

\n\nThis example illustrates the computation of the debiased Sinkhorn barycenter\nas proposed in [37].\n\n\n[37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.\nProceedings of the 37th International Conference on Machine Learning,\nPMLR 119:4692-4701, 2020\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Hicham Janati \n",
    "#\n",
    "# License: MIT License\n",
    "# sphinx_gallery_thumbnail_number = 3\n",
    "\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import ot\n",
    "from ot.bregman import (\n",
    "    barycenter,\n",
    "    barycenter_debiased,\n",
    "    convolutional_barycenter2d,\n",
    "    convolutional_barycenter2d_debiased,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debiased barycenter of 1D Gaussians\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n = 100  # number of histogram bins\n",
    "\n",
    "# bin positions (used as the x-axis of the 1D plots)\n",
    "x = np.arange(n, dtype=np.float64)\n",
    "\n",
    "# two Gaussian histograms defined on the same n bins\n",
    "a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m = mean, s = std\n",
    "a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)\n",
    "\n",
    "# matrix A holds one distribution per column\n",
    "A = np.vstack((a1, a2)).T\n",
    "n_distributions = A.shape[1]\n",
    "\n",
    "# loss matrix, rescaled so that its largest entry is 1\n",
    "M = ot.utils.dist0(n)\n",
    "M /= M.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "alpha = 0.2  # 0<=alpha<=1\n",
    "weights = np.array([1 - alpha, alpha])\n",
    "\n",
    "# regularization strengths compared side by side (one subplot each)\n",
    "epsilons = [5e-3, 1e-2, 5e-2]\n",
    "\n",
    "bars = [barycenter(A, M, reg, weights) for reg in epsilons]\n",
    "bars_debiased = [barycenter_debiased(A, M, reg, weights) for reg in epsilons]\n",
    "labels = [\"Sinkhorn barycenter\", \"Debiased barycenter\"]\n",
    "colors = [\"indianred\", \"gold\"]\n",
    "\n",
    "f, axes = plt.subplots(\n",
    "    1, len(epsilons), tight_layout=True, sharey=True, figsize=(12, 4), num=1\n",
    ")\n",
    "for ax, eps, bar, bar_debiased in zip(axes, epsilons, bars, bars_debiased):\n",
    "    # input histograms, drawn faintly behind the barycenters\n",
    "    ax.plot(x, A[:, 0], color=\"k\", ls=\"--\", label=\"Input data\", alpha=0.3)\n",
    "    ax.plot(x, A[:, 1], color=\"k\", ls=\"--\", alpha=0.3)\n",
    "    for data, label, color in zip([bar, bar_debiased], labels, colors):\n",
    "        ax.plot(x, data, color=color, label=label, lw=2)\n",
    "    ax.set_title(r\"$\\varepsilon = %.3f$\" % eps)\n",
    "# a single legend on the last subplot; every subplot shares the same labels\n",
    "axes[-1].legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Debiased barycenter of 2D images\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# NOTE: realpath receives the string literal, not the __file__ variable, so\n",
    "# the path is resolved against the current working directory. The data\n",
    "# folder is therefore looked up two levels above the working directory.\n",
    "this_file = os.path.realpath(\"__file__\")\n",
    "data_path = os.path.join(Path(this_file).parent.parent.parent, \"data\")\n",
    "\n",
    "# keep channel index 2 of each picture and invert it\n",
    "f1 = 1 - plt.imread(os.path.join(data_path, \"heart.png\"))[:, :, 2]\n",
    "f2 = 1 - plt.imread(os.path.join(data_path, \"duck.png\"))[:, :, 2]\n",
    "\n",
    "# stack the images, add a small constant offset, then normalize each image\n",
    "# so that it sums to 1. A dedicated name is used so that the 1D matrix A\n",
    "# defined earlier is not overwritten.\n",
    "images = np.asarray([f1, f2]) + 1e-2\n",
    "images /= images.sum(axis=(1, 2))[:, None, None]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Display the input images\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(7, 4), num=2)\n",
    "for ax, img in zip(axes, images):\n",
    "    ax.imshow(img, cmap=\"Greys\")\n",
    "    ax.axis(\"off\")\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Barycenter computation and visualization\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# regularization strengths for the image experiment (one column each)\n",
    "epsilons_2d = [5e-3, 7e-3, 1e-2]\n",
    "\n",
    "bars_sinkhorn_2d = [convolutional_barycenter2d(images, eps) for eps in epsilons_2d]\n",
    "# the optional log output is not needed here, so it is not requested\n",
    "bars_debiased_2d = [\n",
    "    convolutional_barycenter2d_debiased(images, eps) for eps in epsilons_2d\n",
    "]\n",
    "\n",
    "titles = [\"Sinkhorn\", \"Debiased\"]\n",
    "all_bars = [bars_sinkhorn_2d, bars_debiased_2d]\n",
    "# one row per method, one column per regularization value\n",
    "fig, axes = plt.subplots(len(titles), len(epsilons_2d), figsize=(8, 6), num=3)\n",
    "for jj, (method, ax_row, row_bars) in enumerate(zip(titles, axes, all_bars)):\n",
    "    for ii, (ax, img, eps) in enumerate(zip(ax_row, row_bars, epsilons_2d)):\n",
    "        ax.imshow(img, cmap=\"Greys\")\n",
    "        if jj == 0:\n",
    "            # epsilon is only written above the top row\n",
    "            ax.set_title(r\"$\\varepsilon = %.3f$\" % eps, fontsize=13)\n",
    "        # hide ticks and frame but keep the axis so that ylabel still shows\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        for side in (\"top\", \"right\", \"bottom\", \"left\"):\n",
    "            ax.spines[side].set_visible(False)\n",
    "        if ii == 0:\n",
    "            # method name labels the first column of each row\n",
    "            ax.set_ylabel(method, fontsize=15)\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}