{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Low rank Gromov-Wasserstein between samples\n\n

Note

Example added in release: 0.9.4.

\n\nComparison between entropic Gromov-Wasserstein and Low Rank Gromov-Wasserstein [67]\non two curves in 2D and 3D, both sampled with 200 points.\n\nThe squared Euclidean distance is considered as the ground cost for both samples.\n\n[67] Scetbon, M., Peyr\u00e9, G. & Cuturi, M. (2022).\n\"Linear-Time Gromov Wasserstein Distances using Low Rank Couplings and Costs\".\nIn International Conference on Machine Learning (ICML), 2022.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: Laur\u00e8ne David \n#\n# License: MIT License\n#\n# sphinx_gallery_thumbnail_number = 3" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport matplotlib.pylab as pl\nimport ot.plot\nimport time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate data\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_samples = 200\n\n# Generate 2D and 3D curves\ntheta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)\nz = np.linspace(1, 2, n_samples)\nr = z**2 + 1\nx = r * np.sin(theta)\ny = r * np.cos(theta)\n\n# Source and target distribution\nX = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)\nY = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot data\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the source and target samples\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig = pl.figure(1, figsize=(10, 4))\n\nax = fig.add_subplot(121)\nax.plot(X[:, 0], X[:, 1], color=\"blue\", linewidth=6)\nax.tick_params(\n left=False, right=False, labelleft=False, labelbottom=False, bottom=False\n)\nax.set_title(\"2D curve (source)\")\n\nax2 = 
fig.add_subplot(122, projection=\"3d\")\nax2.plot(Y[:, 0], Y[:, 1], Y[:, 2], c=\"red\", linewidth=6)\nax2.tick_params(\n left=False, right=False, labelleft=False, labelbottom=False, bottom=False\n)\nax2.view_init(15, -50)\nax2.set_title(\"3D curve (target)\")\n\npl.tight_layout()\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Entropic Gromov-Wasserstein\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute cost matrices\nC1 = ot.dist(X, X, metric=\"sqeuclidean\")\nC2 = ot.dist(Y, Y, metric=\"sqeuclidean\")\n\n# Scale cost matrices\nr1 = C1.max()\nr2 = C2.max()\n\nC1 = C1 / r1\nC2 = C2 / r2\n\n\n# Solve entropic gw\nreg = 5 * 1e-3\n\nstart = time.time()\ngw, log = ot.gromov.entropic_gromov_wasserstein(\n C1, C2, tol=1e-3, epsilon=reg, log=True, verbose=False\n)\n\nend = time.time()\ntime_entropic = end - start\n\nentropic_gw_loss = np.round(log[\"gw_dist\"], 3)\n\n# Plot entropic gw\npl.figure(2)\npl.imshow(gw, interpolation=\"nearest\", aspect=\"auto\", cmap=\"gray_r\")\npl.title(\"Entropic Gromov-Wasserstein (loss={})\".format(entropic_gw_loss))\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Low rank squared Euclidean cost matrices\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Compute the low rank sqeuclidean cost decompositions\nA1, A2 = ot.lowrank.compute_lr_sqeuclidean_matrix(X, X, rescale_cost=False)\nB1, B2 = ot.lowrank.compute_lr_sqeuclidean_matrix(Y, Y, rescale_cost=False)\n\n# Scale the low rank cost matrices\nA1, A2 = A1 / np.sqrt(r1), A2 / np.sqrt(r1)\nB1, B2 = B1 / np.sqrt(r2), B2 / np.sqrt(r2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Low rank Gromov-Wasserstein\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Solve low rank 
gromov-wasserstein with different ranks\nlist_rank = [10, 50]\nlist_P_GW = []\nlist_loss_GW = []\nlist_time_GW = []\n\nfor rank in list_rank:\n start = time.time()\n\n Q, R, g, log = ot.lowrank_gromov_wasserstein_samples(\n X,\n Y,\n reg=0,\n rank=rank,\n rescale_cost=False,\n cost_factorized_Xs=(A1, A2),\n cost_factorized_Xt=(B1, B2),\n seed_init=49,\n numItermax=1000,\n log=True,\n stopThr=1e-6,\n )\n end = time.time()\n\n P = log[\"lazy_plan\"][:]\n loss = log[\"value\"]\n\n list_P_GW.append(P)\n list_loss_GW.append(np.round(loss, 3))\n list_time_GW.append(end - start)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot low rank GW with different ranks\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(3, figsize=(10, 4))\n\npl.subplot(1, 2, 1)\npl.imshow(list_P_GW[0], interpolation=\"nearest\", aspect=\"auto\", cmap=\"gray_r\")\npl.title(\"Low rank GW (rank=10, loss={})\".format(list_loss_GW[0]))\n\npl.subplot(1, 2, 2)\npl.imshow(list_P_GW[1], interpolation=\"nearest\", aspect=\"auto\", cmap=\"gray_r\")\npl.title(\"Low rank GW (rank=50, loss={})\".format(list_loss_GW[1]))\n\npl.tight_layout()\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compare computation time between entropic GW and low rank GW\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"Entropic GW: {:.2f}s\".format(time_entropic))\nprint(\"Low rank GW (rank=10): {:.2f}s\".format(list_time_GW[0]))\nprint(\"Low rank GW (rank=50): {:.2f}s\".format(list_time_GW[1]))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, 
"nbformat": 4, "nbformat_minor": 0 }