{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Comparison of Fused Gromov-Wasserstein solvers\n\n

Note

Example added in release: 0.9.1.

\n\nThis example illustrates the computation of FGW for attributed graphs\nusing 4 different solvers to estimate the distance based on Conditional\nGradient [24], Sinkhorn projections [12, 51] and alternated Bregman\nprojections [63, 64].\n\nWe generate two graphs following Stochastic Block Models further endowed with\nnode features and compute their FGW matchings.\n\n[12] Gabriel Peyr\u00e9, Marco Cuturi, and Justin Solomon (2016),\n\"Gromov-Wasserstein averaging of kernel and distance matrices\".\nInternational Conference on Machine Learning (ICML).\n\n[24] Vayer Titouan, Chapel Laetitia, Flamary R\u00e9mi, Tavenard Romain\nand Courty Nicolas\n\"Optimal Transport for structured data with application on graphs\"\nInternational Conference on Machine Learning (ICML). 2019.\n\n[51] Xu, H., Luo, D., Zha, H., & Carin, L. (2019).\n\"Gromov-wasserstein learning for graph matching and node embedding\".\nIn International Conference on Machine Learning (ICML), 2019.\n\n[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J.\n\"A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in\nGraph Data\". 
International Conference on Learning Representations (ICLR), 2023.\n\n[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W.\n\"Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications\".\nIn Thirty-seventh Conference on Neural Information Processing Systems\n(NeurIPS), 2023.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom ot.gromov import (\n fused_gromov_wasserstein,\n entropic_fused_gromov_wasserstein,\n BAPG_fused_gromov_wasserstein,\n)\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm\nfrom time import time" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate two graphs following Stochastic Block Models of 2 and 3 clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(0)\n\nN2 = 20 # 2 communities\nN3 = 30 # 3 communities\np2 = [[1.0, 0.1], [0.1, 0.9]]\np3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]\nG2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)\nG3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)\npart_G2 = [G2.nodes[i][\"block\"] for i in range(N2)]\npart_G3 = [G3.nodes[i][\"block\"] for i in range(N3)]\n\nC2 = networkx.to_numpy_array(G2)\nC3 = networkx.to_numpy_array(G3)\n\n\n# Add 1D node features drawn around a cluster-dependent mean:\n# the mean is c for cluster c of G2 and 2 - c for cluster c of G3\n\nF2 = np.zeros((N2, 1))\nfor i, c in enumerate(part_G2):\n F2[i, 0] = np.random.normal(loc=c, scale=0.01)\n\nF3 = np.zeros((N3, 1))\nfor i, c in enumerate(part_G3):\n F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01)\n\n# Compute pairwise squared Euclidean distances between node features\nM = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 
1)).dot((F3**2).T) - 2 * F2.dot(F3.T)\n\nh2 = np.ones(C2.shape[0]) / C2.shape[0]\nh3 = np.ones(C3.shape[0]) / C3.shape[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their Fused Gromov-Wasserstein distances\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "alpha = 0.5\n\n\n# Conditional Gradient algorithm\nprint(\"Conditional Gradient \\n\")\nstart_cg = time()\nT_cg, log_cg = fused_gromov_wasserstein(\n M, C2, C3, h2, h3, \"square_loss\", alpha=alpha, tol_rel=1e-9, verbose=True, log=True\n)\nend_cg = time()\ntime_cg = 1000 * (end_cg - start_cg)\n\n# Proximal Point algorithm with Kullback-Leibler as proximal operator\nprint(\"Proximal Point Algorithm \\n\")\nstart_ppa = time()\nT_ppa, log_ppa = entropic_fused_gromov_wasserstein(\n M,\n C2,\n C3,\n h2,\n h3,\n \"square_loss\",\n alpha=alpha,\n epsilon=1.0,\n solver=\"PPA\",\n tol=1e-9,\n log=True,\n verbose=True,\n warmstart=False,\n numItermax=10,\n)\nend_ppa = time()\ntime_ppa = 1000 * (end_ppa - start_ppa)\n\n# Projected Gradient algorithm with entropic regularization\nprint(\"Projected Gradient Descent \\n\")\nstart_pgd = time()\nT_pgd, log_pgd = entropic_fused_gromov_wasserstein(\n M,\n C2,\n C3,\n h2,\n h3,\n \"square_loss\",\n alpha=alpha,\n epsilon=0.01,\n solver=\"PGD\",\n tol=1e-9,\n log=True,\n verbose=True,\n warmstart=False,\n numItermax=10,\n)\nend_pgd = time()\ntime_pgd = 1000 * (end_pgd - start_pgd)\n\n# Alternated Bregman Projected Gradient algorithm with Kullback-Leibler as proximal operator\nprint(\"Bregman Alternated Projected Gradient \\n\")\nstart_bapg = time()\nT_bapg, log_bapg = BAPG_fused_gromov_wasserstein(\n M,\n C2,\n C3,\n h2,\n h3,\n \"square_loss\",\n alpha=alpha,\n epsilon=1.0,\n tol=1e-9,\n marginal_loss=True,\n verbose=True,\n log=True,\n)\nend_bapg = time()\ntime_bapg = 1000 * (end_bapg - start_bapg)\n\nprint(\n \"Fused Gromov-Wasserstein distance estimated with Conditional Gradient 
solver: \"\n + str(log_cg[\"fgw_dist\"])\n)\nprint(\n \"Fused Gromov-Wasserstein distance estimated with Proximal Point solver: \"\n + str(log_ppa[\"fgw_dist\"])\n)\nprint(\n \"Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: \"\n + str(log_pgd[\"fgw_dist\"])\n)\nprint(\n \"Fused Gromov-Wasserstein distance estimated with Bregman Alternated Projected Gradient solver: \"\n + str(log_bapg[\"fgw_dist\"])\n)\n\n# Compute the sparsity of each OT plan (percentage of exactly-zero entries)\nT_cg_sparsity = 100 * (T_cg == 0.0).astype(np.float64).sum() / (N2 * N3)\nT_ppa_sparsity = 100 * (T_ppa == 0.0).astype(np.float64).sum() / (N2 * N3)\nT_pgd_sparsity = 100 * (T_pgd == 0.0).astype(np.float64).sum() / (N2 * N3)\nT_bapg_sparsity = 100 * (T_bapg == 0.0).astype(np.float64).sum() / (N2 * N3)\n\n# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the\n# marginal constraints, so we measure the l2 violation of both marginals\n\nerr_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3)\nerr_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3)\nerr_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3)\nerr_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the Fused Gromov-Wasserstein matchings\n\nWe color the nodes of the graph on the left according to their clusters, then\nproject these colors onto the nodes of the graph on the right based on the\noptimal transport plan from each FGW matching.\nWe adjust the intensity of links across domains proportionally to the mass\nsent, adding a minimal intensity of 0.1 if the mass sent is not zero.\nFor each matching, node sizes are proportional to their mass computed\nfrom the marginals of the OT plan, to illustrate potential feasibility errors.\nNB: colors refer to clusters, not to node features.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Add weights on the edges for visualization later on\nweight_intra_G2 = 
5\nweight_inter_G2 = 0.5\nweight_intra_G3 = 1.0\nweight_inter_G3 = 1.5\n\nweightedG2 = networkx.Graph()\npart_G2 = [G2.nodes[i][\"block\"] for i in range(N2)]\n\nfor node in G2.nodes():\n weightedG2.add_node(node)\nfor i, j in G2.edges():\n if part_G2[i] == part_G2[j]:\n weightedG2.add_edge(i, j, weight=weight_intra_G2)\n else:\n weightedG2.add_edge(i, j, weight=weight_inter_G2)\n\nweightedG3 = networkx.Graph()\npart_G3 = [G3.nodes[i][\"block\"] for i in range(N3)]\n\nfor node in G3.nodes():\n weightedG3.add_node(node)\nfor i, j in G3.edges():\n if part_G3[i] == part_G3[j]:\n weightedG3.add_edge(i, j, weight=weight_intra_G3)\n else:\n weightedG3.add_edge(i, j, weight=weight_inter_G3)\n\n\ndef draw_graph(\n G,\n C,\n nodes_color_part,\n Gweights=None,\n pos=None,\n edge_color=\"black\",\n node_size=None,\n shiftx=0,\n seed=0,\n):\n if pos is None:\n pos = networkx.spring_layout(G, scale=1.0, seed=seed)\n\n if shiftx != 0:\n for k, v in pos.items():\n v[0] = v[0] + shiftx\n\n alpha_edge = 0.7\n width_edge = 1.8\n if Gweights is None:\n networkx.draw_networkx_edges(\n G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n )\n else:\n # We make more visible connections between activated nodes\n n = len(Gweights)\n edgelist_activated = []\n edgelist_deactivated = []\n for i in range(n):\n for j in range(n):\n if Gweights[i] * Gweights[j] * C[i, j] > 0:\n edgelist_activated.append((i, j))\n elif C[i, j] > 0:\n edgelist_deactivated.append((i, j))\n\n networkx.draw_networkx_edges(\n G,\n pos,\n edgelist=edgelist_activated,\n width=width_edge,\n alpha=alpha_edge,\n edge_color=edge_color,\n )\n networkx.draw_networkx_edges(\n G,\n pos,\n edgelist=edgelist_deactivated,\n width=width_edge,\n alpha=0.1,\n edge_color=edge_color,\n )\n\n if Gweights is None:\n for node, node_color in enumerate(nodes_color_part):\n networkx.draw_networkx_nodes(\n G,\n pos,\n nodelist=[node],\n node_size=node_size,\n alpha=1,\n node_color=node_color,\n )\n else:\n scaled_Gweights = 
Gweights / (0.5 * Gweights.max())\n nodes_size = node_size * scaled_Gweights\n for node, node_color in enumerate(nodes_color_part):\n networkx.draw_networkx_nodes(\n G,\n pos,\n nodelist=[node],\n node_size=nodes_size[node],\n alpha=1,\n node_color=node_color,\n )\n return pos\n\n\ndef draw_transp_colored_GW(\n G1,\n C1,\n G2,\n C2,\n part_G1,\n p1,\n p2,\n T,\n pos1=None,\n pos2=None,\n shiftx=4,\n switchx=False,\n node_size=70,\n seed_G1=0,\n seed_G2=0,\n):\n starting_color = 0\n # Get the partition of G1 and the color of each of its nodes\n part1 = part_G1.copy()\n unique_colors = [\"C%s\" % (starting_color + i) for i in np.unique(part1)]\n nodes_color_part1 = []\n for cluster in part1:\n nodes_color_part1.append(unique_colors[cluster])\n\n nodes_color_part2 = []\n # T: color each node of G2 like the G1 node sending it the most mass (argmax of each column)\n for i in range(len(G2.nodes())):\n j = np.argmax(T[:, i])\n nodes_color_part2.append(nodes_color_part1[j])\n pos1 = draw_graph(\n G1,\n C1,\n nodes_color_part1,\n Gweights=p1,\n pos=pos1,\n node_size=node_size,\n shiftx=0,\n seed=seed_G1,\n )\n pos2 = draw_graph(\n G2,\n C2,\n nodes_color_part2,\n Gweights=p2,\n pos=pos2,\n node_size=node_size,\n shiftx=shiftx,\n seed=seed_G2,\n )\n\n for k1, v1 in pos1.items():\n max_Tk1 = np.max(T[k1, :])\n for k2, v2 in pos2.items():\n if T[k1, k2] > 0:\n pl.plot(\n [pos1[k1][0], pos2[k2][0]],\n [pos1[k1][1], pos2[k2][1]],\n \"-\",\n lw=0.7,\n alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.0),\n color=nodes_color_part1[k1],\n )\n return pos1, pos2\n\n\nnode_size = 40\nfontsize = 13\nseed_G2 = 0\nseed_G3 = 4\n\npl.figure(2, figsize=(15, 3.5))\npl.clf()\npl.subplot(141)\npl.axis(\"off\")\n\npl.title(\n \"(CG) FGW=%s\\n \\n OT sparsity = %s \\n marg. 
error = %s \\n runtime = %s\"\n % (\n np.round(log_cg[\"fgw_dist\"], 3),\n str(np.round(T_cg_sparsity, 2)) + \" %\",\n np.round(err_cg, 4),\n str(np.round(time_cg, 2)) + \" ms\",\n ),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_GW(\n weightedG2,\n C2,\n weightedG3,\n C3,\n part_G2,\n p1=T_cg.sum(1),\n p2=T_cg.sum(0),\n T=T_cg,\n shiftx=1.5,\n node_size=node_size,\n seed_G1=seed_G2,\n seed_G2=seed_G3,\n)\n\npl.subplot(142)\npl.axis(\"off\")\n\npl.title(\n \"(PPA) FGW=%s\\n \\n OT sparsity = %s \\n marg. error = %s \\n runtime = %s\"\n % (\n np.round(log_ppa[\"fgw_dist\"], 3),\n str(np.round(T_ppa_sparsity, 2)) + \" %\",\n np.round(err_ppa, 4),\n str(np.round(time_ppa, 2)) + \" ms\",\n ),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_GW(\n weightedG2,\n C2,\n weightedG3,\n C3,\n part_G2,\n p1=T_ppa.sum(1),\n p2=T_ppa.sum(0),\n T=T_ppa,\n pos1=pos1,\n pos2=pos2,\n shiftx=0.0,\n node_size=node_size,\n seed_G1=0,\n seed_G2=0,\n)\n\npl.subplot(143)\npl.axis(\"off\")\n\npl.title(\n \"(PGD) Entropic FGW=%s\\n \\n OT sparsity = %s \\n marg. error = %s \\n runtime = %s\"\n % (\n np.round(log_pgd[\"fgw_dist\"], 3),\n str(np.round(T_pgd_sparsity, 2)) + \" %\",\n np.round(err_pgd, 4),\n str(np.round(time_pgd, 2)) + \" ms\",\n ),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_GW(\n weightedG2,\n C2,\n weightedG3,\n C3,\n part_G2,\n p1=T_pgd.sum(1),\n p2=T_pgd.sum(0),\n T=T_pgd,\n pos1=pos1,\n pos2=pos2,\n shiftx=0.0,\n node_size=node_size,\n seed_G1=0,\n seed_G2=0,\n)\n\n\npl.subplot(144)\npl.axis(\"off\")\n\npl.title(\n \"(BAPG) FGW=%s\\n \\n OT sparsity = %s \\n marg. 
error = %s \\n runtime = %s\"\n % (\n np.round(log_bapg[\"fgw_dist\"], 3),\n str(np.round(T_bapg_sparsity, 2)) + \" %\",\n np.round(err_bapg, 4),\n str(np.round(time_bapg, 2)) + \" ms\",\n ),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_GW(\n weightedG2,\n C2,\n weightedG3,\n C3,\n part_G2,\n p1=T_bapg.sum(1),\n p2=T_bapg.sum(0),\n T=T_bapg,\n pos1=pos1,\n pos2=pos2,\n shiftx=0.0,\n node_size=node_size,\n seed_G1=0,\n seed_G2=0,\n)\n\npl.tight_layout()\n\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }