{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Gaussian Mixture Model OT Barycenters\n\nThis example illustrates the computation of a barycenter of Gaussian\nmixtures in the sense of GMM-OT [69]. The computation uses the fixed-point\nmethod for OT barycenters with generic costs [77], for which POT provides both\na general solver and a GMM-specific solver. Note that this is a\n'free-support' method: the number of components of the barycenter GMM and\ntheir weights are fixed, while the component means and covariances are\noptimized.\n\nThe idea behind GMM-OT barycenters is to see the GMMs as discrete measures over\nthe space of Gaussian distributions $\\mathcal{N}$ (or equivalently the\nBures-Wasserstein manifold), and to compute barycenters with respect to the\n2-Wasserstein distance between measures in $\\mathcal{P}(\\mathcal{N})$. A\nGaussian mixture is a finite weighted combination of Diracs at specific\nGaussians, and two mixtures are compared with the 2-Wasserstein distance on\nthis space, where the ground cost is the squared Bures-Wasserstein distance\nbetween Gaussians.\n\n[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space\nof Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.\n\n[77] Tanguy, E., Delon, J., & Gozlan, N. (2024). Computing\nBarycentres of Measures for Generic Transport Costs. 
arXiv preprint arXiv:2501.04016.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Eloi Tanguy\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Ellipse\nimport ot\nfrom ot.gmm import gmm_barycenter_fixed_point\n\n\nK = 3  # number of GMMs\nd = 2  # dimension\nn = 6  # number of components of the desired barycenter\n\n\ndef get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):\n    # here K is the number of components of the generated GMM\n    rng = np.random.RandomState(seed=seed)\n    means = rng.randn(K, d)\n    P = rng.randn(K, d, d) * cov_scale\n    # C[k] = P[k] @ P[k]^T + min_cov_eig * I\n    covariances = np.einsum(\"kab,kcb->kac\", P, P)\n    covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)])\n    weights = rng.random(K)\n    weights /= np.sum(weights)\n    return means, covariances, weights\n\n\nm_list = [5, 6, 7]  # number of components in each GMM\noffsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]\nmeans_list = []  # list of means for each GMM\ncovs_list = []  # list of covariances for each GMM\nw_list = []  # list of weights for each GMM\n\n# generate GMMs\nfor k in range(K):\n    means, covs, b = get_random_gmm(\n        m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5\n    )\n    means = means / 2 + offsets[k][None, :]\n    means_list.append(means)\n    covs_list.append(covs)\n    w_list.append(b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute the barycenter using the fixed-point method\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "init_means, init_covs, _ = get_random_gmm(n, d, seed=0)\nweights = ot.unif(K)  # barycenter coefficients\nmeans_bar, 
covs_bar, log = gmm_barycenter_fixed_point(\n    means_list,\n    covs_list,\n    w_list,\n    init_means,\n    init_covs,\n    weights,\n    iterations=3,\n    log=True,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define plotting functions\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# draw a covariance ellipse\ndef draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):\n    def eigsorted(cov):\n        vals, vecs = np.linalg.eigh(cov)\n        order = vals.argsort()[::-1].copy()\n        return vals[order], vecs[:, order]\n\n    vals, vecs = eigsorted(C)\n    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))\n    w, h = 2 * nstd * np.sqrt(vals)\n    ell = Ellipse(\n        xy=(mu[0], mu[1]),\n        width=w,\n        height=h,\n        alpha=alpha,\n        angle=theta,\n        facecolor=color,\n        edgecolor=color,\n        label=label,\n        fill=True,\n    )\n    if ax is None:\n        ax = plt.gca()\n    ax.add_artist(ell)\n\n\n# draw a gmm as a set of ellipses with weights shown in alpha value\ndef draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):\n    for k in range(ms.shape[0]):\n        draw_cov(\n            ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax\n        )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the results\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "c_list = [\"#7ED321\", \"#4A90E2\", \"#9013FE\", \"#F5A623\"]\nc_bar = \"#D0021B\"\nfig, ax = plt.subplots(figsize=(6, 6))\naxis = [-4, 4, -2, 6]\nax.set_title(\"Fixed Point Barycenter (3 Iterations)\", fontsize=16)\nfor k in range(K):\n    draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax)\ndraw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax)\nax.axis(axis)\nax.axis(\"off\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }