{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Semi-relaxed (Fused) Gromov-Wasserstein example\n\nThis example is designed to show how to use the semi-relaxed Gromov-Wasserstein\nand the semi-relaxed Fused Gromov-Wasserstein divergences.\n\nsr(F)GW between two graphs G1 and G2 searches for a reweighting of the nodes of\nG2 at a minimal (F)GW distance from G1.\n\nFirst, we generate two graphs following Stochastic Block Models, then show\nhow to compute their srGW matchings and illustrate them. These graphs are then\nendowed with node features and we follow the same process with srFGW.\n\n[48] C\u00e9dric Vincent-Cuaz, R\u00e9mi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.\n\"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs\"\nInternational Conference on Learning Representations (ICLR), 2022.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom ot.gromov import (\n semirelaxed_gromov_wasserstein,\n semirelaxed_fused_gromov_wasserstein,\n gromov_wasserstein,\n fused_gromov_wasserstein,\n)\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate two graphs following Stochastic Block Models of 2 and 3 clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N2 = 20 # 2 communities\nN3 = 30 # 3 communities\np2 = [[1.0, 0.1], [0.1, 0.9]]\np3 = [[1.0, 0.1, 0.0], [0.1, 0.95, 0.1], [0.0, 0.1, 0.9]]\nG2 = sbm(seed=0, sizes=[N2 // 2, N2 // 2], p=p2)\nG3 = sbm(seed=0, sizes=[N3 // 3, N3 // 3, N3 // 3], p=p3)\n\n\nC2 = networkx.to_numpy_array(G2)\nC3 = networkx.to_numpy_array(G3)\n\nh2 = np.ones(C2.shape[0]) / 
C2.shape[0]\nh3 = np.ones(C3.shape[0]) / C3.shape[0]\n\n# Add weights on the edges for visualization later on:\n# intra-cluster and inter-cluster edges get different weights.\nweight_intra_G2 = 5\nweight_inter_G2 = 0.5\nweight_intra_G3 = 1.0\nweight_inter_G3 = 1.5\n\n\ndef weight_edges_by_block(G, part, weight_intra, weight_inter):\n    \"\"\"Return a weighted copy of G whose edge weights depend on node blocks.\n\n    Nodes are added in the order of ``G.nodes()``; an edge (i, j) gets\n    ``weight_intra`` if ``part[i] == part[j]`` and ``weight_inter`` otherwise.\n    \"\"\"\n    weightedG = networkx.Graph()\n    weightedG.add_nodes_from(G.nodes())\n    for i, j in G.edges():\n        weight = weight_intra if part[i] == part[j] else weight_inter\n        weightedG.add_edge(i, j, weight=weight)\n    return weightedG\n\n\n# block (cluster) label of every node, as stored by the SBM generator\npart_G2 = [G2.nodes[i][\"block\"] for i in range(N2)]\npart_G3 = [G3.nodes[i][\"block\"] for i in range(N3)]\n\nweightedG2 = weight_edges_by_block(G2, part_G2, weight_intra_G2, weight_inter_G2)\nweightedG3 = weight_edges_by_block(G3, part_G3, weight_intra_G3, weight_inter_G3)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their semi-relaxed Gromov-Wasserstein divergences\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# 0) GW(C2, h2, C3, h3) for reference\nOT, log = gromov_wasserstein(C2, C3, h2, h3, symmetric=True, log=True)\ngw = log[\"gw_dist\"]\n\n# 1) srGW(C2, h2, C3)\nOT_23, log_23 = semirelaxed_gromov_wasserstein(\n    C2, C3, h2, symmetric=True, log=True, G0=None\n)\nsrgw_23 = log_23[\"srgw_dist\"]\n\n# 2) srGW(C3, h3, C2), initialized with the transposed GW plan OT.T\nOT_32, log_32 = semirelaxed_gromov_wasserstein(\n    C3, C2, h3, symmetric=None, log=True, G0=OT.T\n)\nsrgw_32 = log_32[\"srgw_dist\"]\n\nprint(\"GW(C2, C3) = \", gw)\nprint(\"srGW(C2, h2, C3) = \", srgw_23)\nprint(\"srGW(C3, h3, C2) = \", srgw_32)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the semi-relaxed Gromov-Wasserstein matchings\n\nWe color the nodes of the graph on the left by cluster, then project these\ncolors onto the graph on the right using the optimal transport plan of the\nsrGW matching. Node sizes on the right are proportional to the mass they receive.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs":
[], "source": [ "def draw_graph(\n    G,\n    C,\n    nodes_color_part,\n    Gweights=None,\n    pos=None,\n    edge_color=\"black\",\n    node_size=None,\n    shiftx=0,\n    seed=0,\n):\n    \"\"\"Draw G with one color per node and return the node positions used.\n\n    G is assumed to have nodes labeled 0..n-1 and C to be its symmetric\n    (n, n) structure matrix, C[i, j] > 0 marking an edge.\n    nodes_color_part lists one matplotlib color per node.\n    Gweights, when given, holds one mass per node: edges between two nodes\n    with mass are drawn opaque, the other edges faded, and node areas are\n    scaled so that the heaviest node gets 2 * node_size.\n    pos maps each node to (x, y); when None, a spring layout seeded by\n    seed is computed. shiftx is added to the x coordinates of a copy of\n    pos, so the caller's dict is left untouched; the positions actually\n    used for drawing are returned.\n    \"\"\"\n    if pos is None:\n        pos = networkx.spring_layout(G, scale=1.0, seed=seed)\n\n    if shiftx != 0:\n        offset = np.array([shiftx, 0.0])\n        pos = {node: np.asarray(xy) + offset for node, xy in pos.items()}\n\n    alpha_edge = 0.7\n    width_edge = 1.8\n    if Gweights is None:\n        networkx.draw_networkx_edges(\n            G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n        )\n    else:\n        # We make more visible connections between activated nodes.\n        # Only the upper triangle of the symmetric C is scanned so that each\n        # edge is listed once: listing both (i, j) and (j, i) drew every edge\n        # twice and compounded its alpha.\n        n = len(Gweights)\n        edgelist_activated = []\n        edgelist_deactivated = []\n        for i in range(n):\n            for j in range(i, n):\n                if C[i, j] > 0:\n                    if Gweights[i] * Gweights[j] > 0:\n                        edgelist_activated.append((i, j))\n                    else:\n                        edgelist_deactivated.append((i, j))\n\n        networkx.draw_networkx_edges(\n            G,\n            pos,\n            edgelist=edgelist_activated,\n            width=width_edge,\n            alpha=alpha_edge,\n            edge_color=edge_color,\n        )\n        networkx.draw_networkx_edges(\n            G,\n            pos,\n            edgelist=edgelist_deactivated,\n            width=width_edge,\n            alpha=0.1,\n            edge_color=edge_color,\n        )\n\n    if Gweights is None:\n        nodes_size = [node_size] * len(nodes_color_part)\n    else:\n        # node area proportional to its mass; massless nodes get size 0\n        nodes_size = node_size * (Gweights / (0.5 * Gweights.max()))\n    for node, node_color in enumerate(nodes_color_part):\n        networkx.draw_networkx_nodes(\n            G,\n            pos,\n            nodelist=[node],\n            node_size=nodes_size[node],\n            alpha=1,\n            node_color=node_color,\n        )\n    return pos\n\n\ndef draw_transp_colored_srGW(\n    G1,\n    C1,\n    G2,\n    C2,\n    part_G1,\n    p1,\n    p2,\n    T,\n    pos1=None,\n    pos2=None,\n    shiftx=4,\n    switchx=False,\n    node_size=70,\n    seed_G1=0,\n    seed_G2=0,\n):\n    \"\"\"Draw G1 (left) and G2 (shifted right by shiftx) linked by the plan T.\n\n    Nodes of G1 are colored by their cluster in part_G1; node j of G2 takes\n    the color of the G1 node i maximizing T[i, j], and a segment joins every\n    pair (i, j) with T[i, j] > 0. p1 / p2 are forwarded to draw_graph as\n    Gweights (None keeps uniform node sizes). switchx is unused and only\n    kept for backward compatibility. Returns the positions used for G1 and G2.\n    \"\"\"\n    starting_color = 0\n    # get graphs partition and their coloring: one matplotlib cycle color\n    # (C0, C1, ...) per distinct cluster label, looked up by the label itself\n    part1 = part_G1.copy()\n    unique_colors = {\n        label: \"C%d\" % (starting_color + k)\n        for k, label in enumerate(np.unique(part1))\n    }\n    nodes_color_part1 = []\n    for cluster in part1:\n        
nodes_color_part1.append(unique_colors[cluster])\n\n    nodes_color_part2 = []\n    # T: getting colors assignment from argmax of columns, i.e. each node of G2\n    # takes the color of the G1 node that sends it the most mass\n    for i in range(len(G2.nodes())):\n        j = np.argmax(T[:, i])\n        nodes_color_part2.append(nodes_color_part1[j])\n    pos1 = draw_graph(\n        G1,\n        C1,\n        nodes_color_part1,\n        Gweights=p1,\n        pos=pos1,\n        node_size=node_size,\n        shiftx=0,\n        seed=seed_G1,\n    )\n    pos2 = draw_graph(\n        G2,\n        C2,\n        nodes_color_part2,\n        Gweights=p2,\n        pos=pos2,\n        node_size=node_size,\n        shiftx=shiftx,\n        seed=seed_G2,\n    )\n    # one segment per pair (k1, k2) exchanging mass, colored like the G1 node\n    for k1, k2 in zip(*np.nonzero(T > 0)):\n        pl.plot(\n            [pos1[k1][0], pos2[k2][0]],\n            [pos1[k1][1], pos2[k2][1]],\n            \"-\",\n            lw=0.8,\n            alpha=0.5,\n            color=nodes_color_part1[k1],\n        )\n    return pos1, pos2\n\n\nnode_size = 40\nfontsize = 10\nseed_G2 = 0\nseed_G3 = 4\n\npl.figure(1, figsize=(8, 2.5))\npl.clf()\npl.subplot(121)\npl.axis(\"off\")\npl.title(\n    r\"srGW$(\\mathbf{C_2},\\mathbf{h_2},\\mathbf{C_3}) =%s$\" % (np.round(srgw_23, 3)),\n    fontsize=fontsize,\n)\n\n# mass received by each node of G3 under the srGW plan (column sums)\nhbar2 = OT_23.sum(axis=0)\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG2,\n    C2,\n    weightedG3,\n    C3,\n    part_G2,\n    p1=None,\n    p2=hbar2,\n    T=OT_23,\n    shiftx=1.5,\n    node_size=node_size,\n    seed_G1=seed_G2,\n    seed_G2=seed_G3,\n)\npl.subplot(122)\npl.axis(\"off\")\nhbar3 = OT_32.sum(axis=0)\npl.title(\n    r\"srGW$(\\mathbf{C_3}, \\mathbf{h_3},\\mathbf{C_2}) =%s$\" % (np.round(srgw_32, 3)),\n    fontsize=fontsize,\n)\n# reuse the layouts of the left panel, with the two graphs swapped\npos1, pos2 = draw_transp_colored_srGW(\n    weightedG3,\n    C3,\n    weightedG2,\n    C2,\n    part_G3,\n    p1=None,\n    p2=hbar3,\n    T=OT_32,\n    pos1=pos2,\n    pos2=pos1,\n    shiftx=3.0,\n    node_size=node_size,\n    seed_G1=0,\n    seed_G2=0,\n)\npl.tight_layout()\n\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Add node features\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# We add node features with given mean - by clusters\n# and inversely proportional to clusters' 
intra-connectivity\n\n# fixed seed so that the node features are reproducible\nrng = np.random.default_rng(0)\n\n# one scalar feature per node, drawn around a cluster-dependent mean;\n# F2 has shape (N2, 1) and F3 has shape (N3, 1)\nF2 = rng.normal(loc=np.asarray(part_G2, dtype=float), scale=0.01)[:, None]\nF3 = rng.normal(loc=2.0 - np.asarray(part_G3, dtype=float), scale=0.01)[:, None]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their semi-relaxed Fused Gromov-Wasserstein divergences\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "alpha = 0.5\n# Pairwise squared Euclidean distance between node features\n# (computed by broadcasting, valid for features of any dimension)\nM = ((F2[:, None, :] - F3[None, :, :]) ** 2).sum(axis=-1)\n\n# 0) FGW_alpha(C2, F2, h2, C3, F3, h3) for reference\nOT, log = fused_gromov_wasserstein(\n    M, C2, C3, h2, h3, symmetric=True, alpha=alpha, log=True\n)\nfgw = log[\"fgw_dist\"]\n\n# 1) srFGW(C2, F2, h2, C3, F3)\nOT_23, log_23 = semirelaxed_fused_gromov_wasserstein(\n    M, C2, C3, h2, symmetric=True, alpha=alpha, log=True, G0=None\n)\nsrfgw_23 = log_23[\"srfgw_dist\"]\n\n# 2) srFGW(C3, F3, h3, C2, F2): the feature cost matrix is transposed\nOT_32, log_32 = semirelaxed_fused_gromov_wasserstein(\n    M.T, C3, C2, h3, symmetric=None, alpha=alpha, log=True, G0=None\n)\nsrfgw_32 = log_32[\"srfgw_dist\"]\n\nprint(\"FGW(C2, F2, C3, F3) = \", fgw)\nprint(\"srFGW(C2, F2, h2, C3, F3) = \", srfgw_23)\nprint(\"srFGW(C3, F3, h3, C2, F2) = \", srfgw_32)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the semi-relaxed Fused Gromov-Wasserstein matchings\n\nWe color the nodes of the graph on the left by cluster, then project these\ncolors onto the graph on the right using the optimal transport plan of the\nsrFGW matching.\nNB: colors refer to clusters - not to node features\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(2, figsize=(8, 2.5))\npl.clf()\npl.subplot(121)\npl.axis(\"off\")\npl.title(\n    r\"srFGW$(\\mathbf{C_2},\\mathbf{F_2},\\mathbf{h_2},\\mathbf{C_3},\\mathbf{F_3}) =%s$\"\n    % 
(np.round(srfgw_23, 3)),\n    fontsize=fontsize,\n)\n\n# Left panel: G2, colored by cluster, matched onto G3 whose node sizes\n# follow the mass received under the srFGW plan (column sums of OT_23)\nhbar2 = OT_23.sum(axis=0)\npos1, pos2 = draw_transp_colored_srGW(\n    G1=weightedG2,\n    C1=C2,\n    G2=weightedG3,\n    C2=C3,\n    part_G1=part_G2,\n    p1=None,\n    p2=hbar2,\n    T=OT_23,\n    shiftx=1.5,\n    node_size=node_size,\n    seed_G1=seed_G2,\n    seed_G2=seed_G3,\n)\n\n# Right panel: G3 matched onto G2, reusing the layouts of the left panel\npl.subplot(122)\npl.axis(\"off\")\nhbar3 = OT_32.sum(axis=0)\npl.title(\n    r\"srFGW$(\\mathbf{C_3}, \\mathbf{F_3}, \\mathbf{h_3}, \\mathbf{C_2}, \\mathbf{F_2}) =%s$\"\n    % (np.round(srfgw_32, 3)),\n    fontsize=fontsize,\n)\npos1, pos2 = draw_transp_colored_srGW(\n    G1=weightedG3,\n    C1=C3,\n    G2=weightedG2,\n    C2=C2,\n    part_G1=part_G3,\n    p1=None,\n    p2=hbar3,\n    T=OT_32,\n    pos1=pos2,\n    pos2=pos1,\n    shiftx=3.0,\n    node_size=node_size,\n    seed_G1=0,\n    seed_G2=0,\n)\npl.tight_layout()\n\npl.show()" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }