{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Gaussian Bures-Wasserstein barycenters\n\n

**Note:** Example added in release 0.9.2.

\n\nIllustration of Gaussian Bures-Wasserstein barycenters.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# Authors: R\u00e9mi Flamary \n",
"#\n",
"# License: MIT License\n",
"\n",
"# sphinx_gallery_thumbnail_number = 2"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"from matplotlib import colors\n",
"from matplotlib.patches import Ellipse\n",
"import numpy as np\n",
"import matplotlib.pylab as pl\n",
"import ot"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Define Gaussian Covariances and distributions\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# covariance matrices of the four input Gaussians\n",
"C1 = np.array([[0.5, -0.4], [-0.4, 0.5]])\n",
"C2 = np.array([[1, 0.3], [0.3, 1]])\n",
"C3 = np.array([[1.5, 0], [0, 0.5]])\n",
"C4 = np.array([[0.5, 0], [0, 1.5]])\n",
"\n",
"C = np.stack((C1, C2, C3, C4))\n",
"\n",
"# means of the four input Gaussians (corners of a square of side 4)\n",
"m1 = np.array([0, 0])\n",
"m2 = np.array([0, 4])\n",
"m3 = np.array([4, 0])\n",
"m4 = np.array([4, 4])\n",
"\n",
"m = np.stack((m1, m2, m3, m4))"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the distributions\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"def draw_cov(mu, C, color=None, label=None, nstd=1):\n",
"    \"\"\"Draw the nstd-sigma ellipse of a 2D Gaussian on the current axes.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    mu : array-like, shape (2,)\n",
"        Center of the ellipse (mean of the Gaussian).\n",
"    C : array-like, shape (2, 2)\n",
"        Symmetric covariance matrix.\n",
"    color : color, optional\n",
"        Face and edge color of the ellipse.\n",
"    label : str, optional\n",
"        Label attached to the ellipse artist.\n",
"    nstd : float, optional\n",
"        Half-axis lengths, in units of standard deviation.\n",
"    \"\"\"\n",
"    # eigh returns eigenvalues in ascending order: flip both outputs so\n",
"    # that the principal axis comes first\n",
"    vals, vecs = np.linalg.eigh(C)\n",
"    vals, vecs = vals[::-1], vecs[:, ::-1]\n",
"\n",
"    # orientation of the principal axis, in degrees\n",
"    theta = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))\n",
"    # clip round-off negatives so that sqrt never yields NaN\n",
"    w, h = 2 * nstd * np.sqrt(np.maximum(vals, 0))\n",
"    ell = Ellipse(\n",
"        xy=(mu[0], mu[1]),\n",
"        width=w,\n",
"        height=h,\n",
"        alpha=0.5,\n",
"        angle=theta,\n",
"        facecolor=color,\n",
"        edgecolor=color,\n",
"        label=label,\n",
"        fill=True,\n",
"    )\n",
"    pl.gca().add_artist(ell)\n",
"\n",
"\n",
"axis = [-1.5, 5.5, -1.5, 5.5]\n",
"\n",
"pl.figure(1, (8, 2))\n",
"pl.clf()\n",
"\n",
"# one panel per input Gaussian, drawn with the color cycle entries C0..C3\n",
"for k, (mk, Ck) in enumerate(zip(m, C)):\n",
"    pl.subplot(1, len(m), k + 1)\n",
"    draw_cov(mk, Ck, color=f\"C{k}\")\n",
"    pl.axis(axis)\n",
"    # raw f-string: keeps the LaTeX backslashes literal\n",
"    pl.title(rf\"$\\mathcal{{N}}(m_{k + 1},\\Sigma_{k + 1})$\")"
] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Compute Bures-Wasserstein barycenters and plot them\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [
"# basis for bilinear interpolation: one one-hot weight vector per Gaussian\n",
"v1, v2, v3, v4 = np.eye(4)\n",
"\n",
"# RGB color of each input Gaussian; deliberately NOT named `colors`, which\n",
"# would shadow the imported matplotlib module and break re-running this cell\n",
"corner_colors = np.stack([colors.to_rgb(f\"C{k}\") for k in range(4)])\n",
"\n",
"pl.figure(2, (8, 8))\n",
"\n",
"nb_interp = 6\n",
"\n",
"for i in range(nb_interp):\n",
"    tx = i / (nb_interp - 1)\n",
"    # these two interpolants depend on tx only, so compute them once per i\n",
"    w_edge12 = (1 - tx) * v1 + tx * v2\n",
"    w_edge34 = (1 - tx) * v3 + tx * v4\n",
"\n",
"    for j in range(nb_interp):\n",
"        ty = j / (nb_interp - 1)\n",
"\n",
"        # weights are constructed by bilinear interpolation\n",
"        weights = (1 - ty) * w_edge12 + ty * w_edge34\n",
"\n",
"        # blend the corner colors with the same weights\n",
"        color = np.dot(corner_colors.T, weights)\n",
"\n",
"        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)\n",
"\n",
"        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)\n",
"\n",
"pl.axis(axis)\n",
"pl.axis(\"off\")\n",
"pl.tight_layout()"
] }
], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }