{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# OT between GMMs: plan and maps in 1D\n\nIllustration of the GMM OT plan for\nthe Mixture Wasserstein distance between two GMMs in 1D,\nas well as the two maps T_mean and T_rand.\nT_mean is the barycentric projection of the GMM coupling,\nand T_rand applies the Gaussian map between a pair of components drawn at random\naccording to the coupling and the GMMs.\nSee [69] for details.\n\n.. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Eloi Tanguy\n#         Remi Flamary\n#         Julie Delon\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate GMMOT plan and plot it\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# number of source and target components, and dimension\nks = 2\nkt = 3\nd = 1\neps = 0.1\n# GMM parameters: means (k, d), covariances (k, d, d) and weights (k,)\nm_s = np.array([[1], [2]])\nm_t = np.array([[3], [4.2], [5]])\nC_s = np.array([[[0.05]], [[0.06]]])\nC_t = np.array([[[0.03]], [[0.07]], [[0.04]]])\nw_s = np.array([0.4, 0.6])\nw_t = np.array([0.4, 0.2, 0.4])\n\n# evaluation grids for the source (x) and target (y) axes\nn = 500\na_x, b_x = 0, 3\nx = np.linspace(a_x, b_x, n)\na_y, b_y = 2, 6\ny = np.linspace(a_y, b_y, n)\n# density of the GMM OT plan evaluated on the (x, y) grid\nplan_density = gmm_ot_plan_density(\n    x[:, None], y[:, None], m_s, m_t, C_s, C_t, w_s, w_t, plan=None, atol=2e-2\n)\n\n# source and target GMM densities on their grids\na = gmm_pdf(x[:, None], m_s, C_s, w_s)\nb = gmm_pdf(y[:, None], m_t, C_t, w_t)\nplt.figure(figsize=(8, 8))\nplot1D_mat(\n    a,\n    b,\n    plan_density,\n    title=\"GMM OT plan\",\n    plot_style=\"xy\",\n    a_label=\"Source distribution\",\n    b_label=\"Target distribution\",\n)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "## Generate GMMOT maps and plot them over the plan\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "plt.figure(figsize=(8, 8))\nax_s, ax_t, ax_M = plot1D_mat(\n    a,\n    b,\n    plan_density,\n    plot_style=\"xy\",\n    title=\"GMM OT plan with T_mean and T_rand maps\",\n    a_label=\"Source distribution\",\n    b_label=\"Target distribution\",\n)\n# T_mean: barycentric projection of the GMM coupling\nT_mean = gmm_ot_apply_map(x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method=\"bary\")[:, 0]\n# rescale so the map can be drawn over the plan image\nx_rescaled, T_mean_rescaled = rescale_for_imshow_plot(x, T_mean, n, a_y=a_y, b_y=b_y)\n\nax_M.plot(\n    x_rescaled, T_mean_rescaled, label=\"T_mean\", alpha=0.5, linewidth=5, color=\"aqua\"\n)\n\n# T_rand: random map, with a fixed seed for reproducibility\nT_rand = gmm_ot_apply_map(\n    x[:, None], m_s, m_t, C_s, C_t, w_s, w_t, method=\"rand\", seed=0\n)[:, 0]\nx_rescaled, T_rand_rescaled = rescale_for_imshow_plot(x, T_rand, n, a_y=a_y, b_y=b_y)\n\nax_M.scatter(\n    x_rescaled, T_rand_rescaled, label=\"T_rand\", alpha=0.5, s=20, color=\"orange\"\n)\n\nax_M.legend(loc=\"upper left\", fontsize=13)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }