{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Entropic-regularized semi-relaxed (Fused) Gromov-Wasserstein example\n\nThis example is designed to show how to use the entropic semi-relaxed Gromov-Wasserstein\nand the entropic semi-relaxed Fused Gromov-Wasserstein divergences.\n\nEntropic-regularized sr(F)GW between two graphs G1 and G2 searches for a reweighing of the nodes of\nG2 at a minimal entropic-regularized (F)GW distance from G1.\n\nFirst, we generate two graphs following Stochastic Block Models, then show\nhow to compute their srGW matchings and illustrate them. These graphs are then\nendowed with node features and we follow the same process with srFGW.\n\n[48] C\u00e9dric Vincent-Cuaz, R\u00e9mi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.\n\"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs\"\nInternational Conference on Learning Representations (ICLR), 2021.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom ot.gromov import (\n entropic_semirelaxed_gromov_wasserstein,\n entropic_semirelaxed_fused_gromov_wasserstein,\n gromov_wasserstein,\n fused_gromov_wasserstein,\n)\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate two graphs following Stochastic Block models of 2 and 3 clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N2 = 20 # 2 communities\nN3 = 30 # 3 communities\np2 = [[1.0, 0.1], [0.1, 0.9]]\np3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]\nG2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)\nG3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)\n\n\nC2 = 
networkx.to_numpy_array(G2)\nC3 = networkx.to_numpy_array(G3)\n\nh2 = np.ones(C2.shape[0]) / C2.shape[0]\nh3 = np.ones(C3.shape[0]) / C3.shape[0]\n\n# Add weights on the edges for visualization later on\nweight_intra_G2 = 5\nweight_inter_G2 = 0.5\nweight_intra_G3 = 1.0\nweight_inter_G3 = 1.5\n\nweightedG2 = networkx.Graph()\npart_G2 = [G2.nodes[i][\"block\"] for i in range(N2)]\n\nfor node in G2.nodes():\n    weightedG2.add_node(node)\nfor i, j in G2.edges():\n    if part_G2[i] == part_G2[j]:\n        weightedG2.add_edge(i, j, weight=weight_intra_G2)\n    else:\n        weightedG2.add_edge(i, j, weight=weight_inter_G2)\n\nweightedG3 = networkx.Graph()\npart_G3 = [G3.nodes[i][\"block\"] for i in range(N3)]\n\nfor node in G3.nodes():\n    weightedG3.add_node(node)\nfor i, j in G3.edges():\n    if part_G3[i] == part_G3[j]:\n        weightedG3.add_edge(i, j, weight=weight_intra_G3)\n    else:\n        weightedG3.add_edge(i, j, weight=weight_inter_G3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their entropic-regularized semi-relaxed Gromov-Wasserstein divergences\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# 0) GW(C2, h2, C3, h3) for reference\nOT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)\ngw = log[\"gw_dist\"]\n\n# 1) srGW_e(C2, h2, C3)\nOT_23, log_23 = entropic_semirelaxed_gromov_wasserstein(\n    C2, C3, h2, symmetric=True, epsilon=1.0, G0=None, log=True\n)\nsrgw_23 = log_23[\"srgw_dist\"]\n\n# 2) srGW_e(C3, h3, C2)\n\nOT_32, log_32 = entropic_semirelaxed_gromov_wasserstein(\n    C3, C2, h3, symmetric=None, epsilon=1.0, G0=None, log=True\n)\nsrgw_32 = log_32[\"srgw_dist\"]\n\nprint(\"GW(C2, C3) = \", gw)\nprint(\"srGW_e(C2, h2, C3) = \", srgw_23)\nprint(\"srGW_e(C3, h3, C2) = \", srgw_32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the entropic-regularized semi-relaxed Gromov-Wasserstein matchings\n\nWe color nodes of the graph on the left, then 
project its node colors\nbased on the optimal transport plan from the entropic srGW matching.\nWe adjust the intensity of links across domains proportionally to the mass\nsent, adding a minimal intensity of 0.1 if the mass sent is not zero.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def draw_graph(\n    G,\n    C,\n    nodes_color_part,\n    Gweights=None,\n    pos=None,\n    edge_color=\"black\",\n    node_size=None,\n    shiftx=0,\n    seed=0,\n):\n    if pos is None:\n        pos = networkx.spring_layout(G, scale=1.0, seed=seed)\n\n    if shiftx != 0:\n        for k, v in pos.items():\n            v[0] = v[0] + shiftx\n\n    alpha_edge = 0.7\n    width_edge = 1.8\n    if Gweights is None:\n        networkx.draw_networkx_edges(\n            G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n        )\n    else:\n        # We make more visible connections between activated nodes\n        n = len(Gweights)\n        edgelist_activated = []\n        edgelist_deactivated = []\n        for i in range(n):\n            for j in range(n):\n                if Gweights[i] * Gweights[j] * C[i, j] > 0:\n                    edgelist_activated.append((i, j))\n                elif C[i, j] > 0:\n                    edgelist_deactivated.append((i, j))\n\n        networkx.draw_networkx_edges(\n            G,\n            pos,\n            edgelist=edgelist_activated,\n            width=width_edge,\n            alpha=alpha_edge,\n            edge_color=edge_color,\n        )\n        networkx.draw_networkx_edges(\n            G,\n            pos,\n            edgelist=edgelist_deactivated,\n            width=width_edge,\n            alpha=0.1,\n            edge_color=edge_color,\n        )\n\n    if Gweights is None:\n        for node, node_color in enumerate(nodes_color_part):\n            networkx.draw_networkx_nodes(\n                G,\n                pos,\n                nodelist=[node],\n                node_size=node_size,\n                alpha=1,\n                node_color=node_color,\n            )\n    else:\n        scaled_Gweights = Gweights / (0.5 * Gweights.max())\n        nodes_size = node_size * scaled_Gweights\n        for node, node_color in enumerate(nodes_color_part):\n            networkx.draw_networkx_nodes(\n                G,\n                pos,\n                nodelist=[node],\n                node_size=nodes_size[node],\n                alpha=1,\n                node_color=node_color,\n            )\n    return pos\n\n\ndef draw_transp_colored_srGW(\n    G1,\n    C1,\n    G2,\n    C2,\n    
part_G1,\n    p1,\n    p2,\n    T,\n    pos1=None,\n    pos2=None,\n    shiftx=4,\n    switchx=False,\n    node_size=70,\n    seed_G1=0,\n    seed_G2=0,\n):\n    starting_color = 0\n    # get graphs partition and their coloring\n    part1 = part_G1.copy()\n    unique_colors = [\"C%s\" % (starting_color + i) for i in np.unique(part1)]\n    nodes_color_part1 = []\n    for cluster in part1:\n        nodes_color_part1.append(unique_colors[cluster])\n\n    nodes_color_part2 = []\n    # T: getting colors assignment from argmax of columns\n    for i in range(len(G2.nodes())):\n        j = np.argmax(T[:, i])\n        nodes_color_part2.append(nodes_color_part1[j])\n    pos1 = draw_graph(\n        G1,\n        C1,\n        nodes_color_part1,\n        Gweights=p1,\n        pos=pos1,\n        node_size=node_size,\n        shiftx=0,\n        seed=seed_G1,\n    )\n    pos2 = draw_graph(\n        G2,\n        C2,\n        nodes_color_part2,\n        Gweights=p2,\n        pos=pos2,\n        node_size=node_size,\n        shiftx=shiftx,\n        seed=seed_G2,\n    )\n    for k1, v1 in pos1.items():\n        max_Tk1 = np.max(T[k1, :])\n        for k2, v2 in pos2.items():\n            if T[k1, k2] > 0:\n                pl.plot(\n                    [pos1[k1][0], pos2[k2][0]],\n                    [pos1[k1][1], pos2[k2][1]],\n                    \"-\",\n                    lw=0.6,\n                    alpha=min(T[k1, k2] / max_Tk1 + 0.1, 1.0),\n                    color=nodes_color_part1[k1],\n                )\n    return pos1, pos2\n\n\nnode_size = 40\nfontsize = 10\nseed_G2 = 0\nseed_G3 = 4\n\npl.figure(1, figsize=(8, 2.5))\npl.clf()\npl.subplot(121)\npl.axis(\"off\")\npl.title(\n    r\"$srGW_e(\\mathbf{C_2},\\mathbf{h_2},\\mathbf{C_3}) =%s$\" % (np.round(srgw_23, 3)),\n    fontsize=fontsize,\n)\n\nhbar2 = OT_23.sum(axis=0)\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG2,\n    C2,\n    weightedG3,\n    C3,\n    part_G2,\n    p1=None,\n    p2=hbar2,\n    T=OT_23,\n    shiftx=1.5,\n    node_size=node_size,\n    seed_G1=seed_G2,\n    seed_G2=seed_G3,\n)\npl.subplot(122)\npl.axis(\"off\")\nhbar3 = OT_32.sum(axis=0)\npl.title(\n    r\"$srGW_e(\\mathbf{C_3}, \\mathbf{h_3},\\mathbf{C_2}) =%s$\" % (np.round(srgw_32, 3)),\n    fontsize=fontsize,\n)\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG3,\n    C3,\n    weightedG2,\n    C2,\n    part_G3,\n    p1=None,\n    p2=hbar3,\n    T=OT_32,\n    pos1=pos2,\n    
pos2=pos1,\n    shiftx=3.0,\n    node_size=node_size,\n    seed_G1=0,\n    seed_G2=0,\n)\npl.tight_layout()\n\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Add node features\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# We add node features with given mean - by clusters\n# and inversely proportional to clusters' intra-connectivity\n\nF2 = np.zeros((N2, 1))\nfor i, c in enumerate(part_G2):\n    F2[i, 0] = np.random.normal(loc=c, scale=0.01)\n\nF3 = np.zeros((N3, 1))\nfor i, c in enumerate(part_G3):\n    F3[i, 0] = np.random.normal(loc=2.0 - c, scale=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their entropic semi-relaxed Fused Gromov-Wasserstein divergences\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "alpha = 0.5\n# Compute pairwise squared Euclidean distances between node features\nM = (F2**2).dot(np.ones((1, N3))) + np.ones((N2, 1)).dot((F3**2).T) - 2 * F2.dot(F3.T)\n\n# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference\n\nOT, log = fused_gromov_wasserstein(\n    M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True\n)\nfgw = log[\"fgw_dist\"]\n\n# 1) srFGW_e(C2, F2, h2, C3, F3)\nOT_23, log_23 = entropic_semirelaxed_fused_gromov_wasserstein(\n    M, C2, C3, h2, symmetric=True, epsilon=1.0, alpha=alpha, log=True, G0=None\n)\nsrfgw_23 = log_23[\"srfgw_dist\"]\n\n# 2) srFGW_e(C3, F3, h3, C2, F2)\n\nOT_32, log_32 = entropic_semirelaxed_fused_gromov_wasserstein(\n    M.T, C3, C2, h3, symmetric=None, epsilon=1.0, alpha=alpha, log=True, G0=None\n)\nsrfgw_32 = log_32[\"srfgw_dist\"]\n\nprint(\"FGW(C2, F2, C3, F3) = \", fgw)\nprint(\"srFGW_e(C2, F2, h2, C3, F3) = \", srfgw_23)\nprint(\"srFGW_e(C3, F3, h3, C2, F2) = \", srfgw_32)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the entropic semi-relaxed Fused Gromov-Wasserstein matchings\n\nWe color nodes of the 
graph on the left, then project its node colors\nbased on the optimal transport plan from the entropic srFGW matching.\nNB: colors refer to clusters, not to node features.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, figsize=(8, 2.5))\npl.clf()\npl.subplot(121)\npl.axis(\"off\")\npl.title(\n    r\"$srFGW_e(\\mathbf{C_2},\\mathbf{F_2},\\mathbf{h_2},\\mathbf{C_3},\\mathbf{F_3}) =%s$\"\n    % (np.round(srfgw_23, 3)),\n    fontsize=fontsize,\n)\n\nhbar2 = OT_23.sum(axis=0)\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG2,\n    C2,\n    weightedG3,\n    C3,\n    part_G2,\n    p1=None,\n    p2=hbar2,\n    T=OT_23,\n    shiftx=1.5,\n    node_size=node_size,\n    seed_G1=seed_G2,\n    seed_G2=seed_G3,\n)\npl.subplot(122)\npl.axis(\"off\")\nhbar3 = OT_32.sum(axis=0)\npl.title(\n    r\"$srFGW_e(\\mathbf{C_3}, \\mathbf{F_3}, \\mathbf{h_3}, \\mathbf{C_2}, \\mathbf{F_2}) =%s$\"\n    % (np.round(srfgw_32, 3)),\n    fontsize=fontsize,\n)\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG3,\n    C3,\n    weightedG2,\n    C2,\n    part_G3,\n    p1=None,\n    p2=hbar3,\n    T=OT_32,\n    pos1=pos2,\n    pos2=pos1,\n    shiftx=3.0,\n    node_size=node_size,\n    seed_G1=0,\n    seed_G2=0,\n)\npl.tight_layout()\n\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }