{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "# Gradient Flow for GMM-OT distance\n",
        "\n",
        "Illustration of the gradient flow of a Gaussian Mixture Model (GMM): the means,\n",
        "covariances and weights of a source GMM are optimised to minimise its GMM-OT\n",
        "distance to a fixed target GMM.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: Eloi Tanguy\n",
        "#         Remi Flamary\n",
        "#         Julie Delon\n",
        "#\n",
        "# License: MIT License\n",
        "\n",
        "# sphinx_gallery_thumbnail_number = 4\n",
        "\n",
        "import numpy as np\n",
        "import matplotlib.pylab as pl\n",
        "from matplotlib import colormaps as cm\n",
        "import ot\n",
        "import ot.plot\n",
        "from ot.utils import proj_SDP, proj_simplex\n",
        "from ot.gmm import gmm_ot_loss\n",
        "import torch\n",
        "from torch.optim import Adam\n",
        "from matplotlib.patches import Ellipse"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generate data and plot it\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "torch.manual_seed(3)\n",
        "ks = 3  # number of source components\n",
        "kt = 2  # number of target components\n",
        "d = 2  # dimension\n",
        "eps = 0.1  # diagonal shift added to the random covariances\n",
        "\n",
        "# Source GMM: these tensors are the optimisation variables.\n",
        "m_s = torch.randn(ks, d)\n",
        "m_s.requires_grad_()\n",
        "m_t = torch.randn(kt, d)\n",
        "# A @ A^T + eps * I is symmetric positive definite.\n",
        "C_s = torch.randn(ks, d, d)\n",
        "C_s = torch.matmul(C_s, torch.transpose(C_s, 2, 1))\n",
        "C_s += eps * torch.eye(d)[None, :, :] * torch.ones(ks, 1, 1)\n",
        "C_s.requires_grad_()\n",
        "C_t = torch.randn(kt, d, d)\n",
        "C_t = torch.matmul(C_t, torch.transpose(C_t, 2, 1))\n",
        "C_t += eps * torch.eye(d)[None, :, :] * torch.ones(kt, 1, 1)\n",
        "# w_s is always passed through a softmax before being used as weights,\n",
        "# so from here on it plays the role of unconstrained logits.\n",
        "w_s = torch.randn(ks)\n",
        "w_s = proj_simplex(w_s)\n",
        "w_s.requires_grad_()\n",
        "w_t = torch.tensor(ot.unif(kt))\n",
        "\n",
        "\n",
        "def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5):\n",
        "    \"\"\"Add an ellipse representing the covariance ``C`` to the current axes.\n",
        "\n",
        "    The ellipse is centred at ``mu``, its axis lengths are\n",
        "    ``2 * nstd * sqrt(eigenvalue)`` and it is rotated to follow the\n",
        "    eigenvector of the largest eigenvalue.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    mu : indexable with at least two entries\n",
        "        Centre of the ellipse; ``mu[0]`` and ``mu[1]`` are used.\n",
        "    C : torch.Tensor or array-like\n",
        "        Symmetric covariance matrix (torch tensors are detached and\n",
        "        converted to numpy).\n",
        "    color : matplotlib color, optional\n",
        "        Used for both the face and the edge of the ellipse.\n",
        "    label : str, optional\n",
        "        Label attached to the ellipse artist.\n",
        "    nstd : float\n",
        "        Scaling of the ellipse axes, in standard deviations.\n",
        "    alpha : float\n",
        "        Opacity of the ellipse.\n",
        "    \"\"\"\n",
        "\n",
        "    def eigsorted(cov):\n",
        "        \"\"\"Return eigenvalues (descending) and the matching eigenvectors.\"\"\"\n",
        "        if torch.is_tensor(cov):\n",
        "            cov = cov.detach().numpy()\n",
        "        vals, vecs = np.linalg.eigh(cov)\n",
        "        order = vals.argsort()[::-1].copy()\n",
        "        return vals[order], vecs[:, order]\n",
        "\n",
        "    vals, vecs = eigsorted(C)\n",
        "    theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))\n",
        "    # Clamp at 0 so that round-off in eigh cannot turn an axis length into NaN.\n",
        "    w, h = 2 * nstd * np.sqrt(np.maximum(vals, 0))\n",
        "    ell = Ellipse(\n",
        "        xy=(mu[0], mu[1]),\n",
        "        width=w,\n",
        "        height=h,\n",
        "        alpha=alpha,\n",
        "        angle=theta,\n",
        "        facecolor=color,\n",
        "        edgecolor=color,\n",
        "        label=label,\n",
        "        fill=True,\n",
        "    )\n",
        "    pl.gca().add_artist(ell)\n",
        "\n",
        "\n",
        "def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1):\n",
        "    \"\"\"Draw one covariance ellipse per GMM component on the current axes.\n",
        "\n",
        "    Component ``k`` is drawn from ``ms[k]`` and ``Cs[k]`` with opacity\n",
        "    ``alpha * ws[k]``, so heavier components appear more opaque.\n",
        "    \"\"\"\n",
        "    for k in range(ms.shape[0]):\n",
        "        draw_cov(ms[k], Cs[k], color, None, nstd, alpha * ws[k])\n",
        "\n",
        "\n",
        "axis = [-3, 3, -3, 3]\n",
        "pl.figure(1, (20, 10))\n",
        "pl.clf()\n",
        "\n",
        "pl.subplot(1, 2, 1)\n",
        "pl.scatter(m_s[:, 0].detach(), m_s[:, 1].detach(), color=\"C0\")\n",
        "draw_gmm(m_s.detach(), C_s.detach(), torch.softmax(w_s, 0).detach().numpy(), color=\"C0\")\n",
        "pl.axis(axis)\n",
        "pl.title(\"Source GMM\")\n",
        "\n",
        "pl.subplot(1, 2, 2)\n",
        "pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color=\"C1\")\n",
        "draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color=\"C1\")\n",
        "pl.axis(axis)\n",
        "pl.title(\"Target GMM\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Gradient descent loop\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "n_gd_its = 100\n",
        "lr = 3e-2\n",
        "# The means use twice the base learning rate.\n",
        "opt = Adam(\n",
        "    [\n",
        "        {\"params\": m_s, \"lr\": 2 * lr},\n",
        "        {\"params\": C_s, \"lr\": lr},\n",
        "        {\"params\": w_s, \"lr\": lr},\n",
        "    ]\n",
        ")\n",
        "# Index 0 of each history list holds the initial state; index i holds the\n",
        "# state after i optimisation steps.\n",
        "m_list = [m_s.data.numpy().copy()]\n",
        "C_list = [C_s.data.numpy().copy()]\n",
        "w_list = [torch.softmax(w_s, 0).data.numpy().copy()]\n",
        "loss_list = []\n",
        "\n",
        "for _ in range(n_gd_its):\n",
        "    opt.zero_grad()\n",
        "    loss = gmm_ot_loss(m_s, m_t, C_s, C_t, torch.softmax(w_s, 0), w_t)\n",
        "    loss.backward()\n",
        "    opt.step()\n",
        "    with torch.no_grad():\n",
        "        # The unconstrained step may leave the PSD cone: project back.\n",
        "        C_s.data = proj_SDP(C_s.data, vmin=1e-6)\n",
        "        m_list.append(m_s.data.numpy().copy())\n",
        "        C_list.append(C_s.data.numpy().copy())\n",
        "        w_list.append(torch.softmax(w_s, 0).data.numpy().copy())\n",
        "        # Loss evaluated before this step's update.\n",
        "        loss_list.append(loss.item())\n",
        "\n",
        "pl.figure(2)\n",
        "pl.clf()\n",
        "pl.plot(loss_list)\n",
        "pl.title(\"Loss\")\n",
        "pl.xlabel(\"its\")\n",
        "pl.ylabel(\"loss\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Last step visualisation\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "axis = [-3, 3, -3, 3]\n",
        "pl.figure(3, (10, 10))\n",
        "pl.clf()\n",
        "pl.title(\"GMM flow, last step\")\n",
        "pl.scatter(m_list[0][:, 0], m_list[0][:, 1], color=\"C0\", label=\"Source\")\n",
        "draw_gmm(m_list[0], C_list[0], w_list[0], color=\"C0\")\n",
        "pl.axis(axis)\n",
        "\n",
        "pl.scatter(m_t[:, 0].detach(), m_t[:, 1].detach(), color=\"C1\", label=\"Target\")\n",
        "draw_gmm(m_t.detach(), C_t.detach(), w_t.numpy(), color=\"C1\")\n",
        "pl.axis(axis)\n",
        "\n",
        "k = -1\n",
        "pl.scatter(m_list[k][:, 0], m_list[k][:, 1], color=\"C2\", alpha=1, label=\"Last step\")\n",
        "# Use the weights of the same iterate (not the initial ones): they are optimised too.\n",
        "draw_gmm(m_list[k], C_list[k], w_list[k], color=\"C2\", alpha=1)\n",
        "\n",
        "pl.axis(axis)\n",
        "pl.legend(fontsize=15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Steps visualisation\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def index_to_color(i):\n",
        "    \"\"\"Map an iteration index to a colormap bin, ``int(sqrt(i))``.\n",
        "\n",
        "    The square root gives early iterations a finer colour resolution\n",
        "    than late ones.\n",
        "    \"\"\"\n",
        "    return int(i**0.5)\n",
        "\n",
        "\n",
        "n_steps_visu = 100\n",
        "# Own figure number, so the last-step figure (3) is not cleared when the\n",
        "# code runs as a plain script.\n",
        "pl.figure(4, (10, 10))\n",
        "pl.clf()\n",
        "pl.title(\"GMM flow, all steps\")\n",
        "\n",
        "# int() truncation can produce repeated indices; drop them so that no step is\n",
        "# drawn twice (which would double its opacity).\n",
        "its_to_show = sorted({int(x) for x in np.linspace(1, n_gd_its - 1, n_steps_visu)})\n",
        "cmp = cm[\"plasma\"].resampled(index_to_color(n_steps_visu))\n",
        "\n",
        "pl.scatter(\n",
        "    m_list[0][:, 0], m_list[0][:, 1], color=cmp(index_to_color(0)), label=\"Source\"\n",
        ")\n",
        "draw_gmm(m_list[0], C_list[0], w_list[0], color=cmp(index_to_color(0)))\n",
        "\n",
        "pl.scatter(\n",
        "    m_t[:, 0].detach(),\n",
        "    m_t[:, 1].detach(),\n",
        "    color=cmp(index_to_color(n_steps_visu - 1)),\n",
        "    label=\"Target\",\n",
        ")\n",
        "draw_gmm(\n",
        "    m_t.detach(), C_t.detach(), w_t.numpy(), color=cmp(index_to_color(n_steps_visu - 1))\n",
        ")\n",
        "\n",
        "\n",
        "for k in its_to_show:\n",
        "    pl.scatter(\n",
        "        m_list[k][:, 0], m_list[k][:, 1], color=cmp(index_to_color(k)), alpha=0.8\n",
        "    )\n",
        "    # Weights of iterate k, so the opacity follows the optimised weights.\n",
        "    draw_gmm(m_list[k], C_list[k], w_list[k], color=cmp(index_to_color(k)), alpha=0.04)\n",
        "\n",
        "pl.axis(axis)\n",
        "pl.legend(fontsize=15)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}