{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Plot partial FGW for subgraph matching\n\n> **Note:** Example added in release: 0.9.6.\n\nThis example illustrates the computation of partial (Fused) Gromov-Wasserstein\ndivergences for subgraph matching tasks, using the exact formulation $p(F)GW$ and\nthe entropically regularized one $p(F)GW_e$ [18, 29].\n\nWe first create a clean circular graph of 15 nodes with node features correlated with\nnode positions on the unit circle, and a noisy version in which 5 nodes lying off the\ncircle are added. Then, knowing the proportion $m=3/4$ of clean nodes in the target\ngraph, we show how to identify them using:\n\n- the partial GW matching and its entropic counterpart, omitting node features;\n- the partial Fused GW matching and its entropic counterpart.\n\n[18] Vayer Titouan, Chapel Laetitia, Flamary R\u00e9mi, Tavenard Romain\nand Courty Nicolas\n\"Optimal Transport for structured data with application on graphs\"\nInternational Conference on Machine Learning (ICML). 2019.\n\n[29] Chapel, L., Alaya, M., Gasso, G. (2020). \"Partial Optimal\nTransport with Applications on Positive-Unlabeled Learning\". 
NeurIPS.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Author: C\u00e9dric Vincent-Cuaz \n",
        "#\n",
        "# License: MIT License\n",
        "\n",
        "# sphinx_gallery_thumbnail_number = 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import pylab as pl\n",
        "import networkx as nx\n",
        "import math\n",
        "from scipy.sparse.csgraph import shortest_path\n",
        "import matplotlib.colors as mcol\n",
        "from matplotlib import cm\n",
        "from ot.gromov import (\n",
        "    partial_gromov_wasserstein,\n",
        "    entropic_partial_gromov_wasserstein,\n",
        "    partial_fused_gromov_wasserstein,\n",
        "    entropic_partial_fused_gromov_wasserstein,\n",
        ")\n",
        "from ot import unif, dist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Utils for generation and visualization\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def build_noisy_circular_graph(n_clean=15, n_noise=5, random_seed=0):\n",
        "    \"\"\"Create a circular graph, optionally with off-circle noise nodes.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    n_clean : int\n",
        "        Number of nodes on the cycle. Node ``i`` stores the feature\n",
        "        ``sin(2 * pi * i / n_clean)`` in its ``weight`` attribute.\n",
        "    n_noise : int\n",
        "        Number of structure-noise nodes. Each one is linked to a randomly\n",
        "        drawn cycle node ``j`` and to its successor on the cycle, and\n",
        "        copies the feature of ``j``.\n",
        "    random_seed : int\n",
        "        Seed given to ``np.random.seed`` before the noise is drawn.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    g : networkx.Graph\n",
        "        Graph with nodes ``0, ..., n_clean + n_noise - 1``; the noise nodes\n",
        "        come last.\n",
        "    \"\"\"\n",
        "    np.random.seed(random_seed)\n",
        "    g = nx.Graph()\n",
        "    g.add_nodes_from(np.arange(n_clean + n_noise))\n",
        "    # create clean circle: node i is linked to its successor modulo n_clean\n",
        "    for i in range(n_clean):\n",
        "        g.add_node(i, weight=math.sin(2 * i * math.pi / n_clean))\n",
        "        g.add_edge(i, (i + 1) % n_clean)\n",
        "    # add nodes out of the circle as structure noise\n",
        "    if n_noise > 0:\n",
        "        noisy_nodes = np.random.choice(np.arange(n_clean), n_noise)\n",
        "        for i, j in enumerate(noisy_nodes):\n",
        "            g.add_node(i + n_clean, weight=math.sin(2 * j * math.pi / n_clean))\n",
        "            g.add_edge(i + n_clean, j)\n",
        "            g.add_edge(i + n_clean, (j + 1) % n_clean)\n",
        "    return g\n",
        "\n",
        "\n",
        "def graph_colors(nx_graph, vmin=0, vmax=7):\n",
        "    \"\"\"Return one viridis RGBA color per node, computed from its ``weight``.\n",
        "\n",
        "    ``vmin`` and ``vmax`` are the feature values mapped to the two ends of\n",
        "    the colormap. Colors are listed in the order of ``nx_graph.nodes()``.\n",
        "    \"\"\"\n",
        "    cnorm = mcol.Normalize(vmin=vmin, vmax=vmax)\n",
        "    cpick = cm.ScalarMappable(norm=cnorm, cmap=\"viridis\")\n",
        "    cpick.set_array([])\n",
        "    weights = nx.get_node_attributes(nx_graph, \"weight\")\n",
        "    return [cpick.to_rgba(weights[node]) for node in nx_graph.nodes()]\n",
        "\n",
        "\n",
        "def draw_graph(\n",
        "    G,\n",
        "    C,\n",
        "    nodes_color_part,\n",
        "    Gweights=None,\n",
        "    pos=None,\n",
        "    edge_color=\"black\",\n",
        "    node_size=None,\n",
        "    shiftx=0,\n",
        "):\n",
        "    \"\"\"Draw ``G`` and return the node positions that were used.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    G : networkx.Graph\n",
        "        Graph to draw. Nodes are indexed as ``0, ..., n - 1`` when looking\n",
        "        up ``C``, ``Gweights`` and ``nodes_color_part``.\n",
        "    C : array, shape (n, n)\n",
        "        Structure matrix of ``G``; ``C[i, j] > 0`` is drawn as an edge when\n",
        "        ``Gweights`` is given.\n",
        "    nodes_color_part : list\n",
        "        One color per node.\n",
        "    Gweights : array-like, shape (n,), optional\n",
        "        Mass carried by each node. When given, node sizes scale with the\n",
        "        mass and edges joining two nodes with positive mass are emphasised.\n",
        "    pos : dict, optional\n",
        "        Node positions; computed with ``nx.kamada_kawai_layout`` when None.\n",
        "        The dictionary passed in is never modified.\n",
        "    edge_color : color\n",
        "        Color of the edges.\n",
        "    node_size : float, optional\n",
        "        Reference node size. With ``Gweights``, None falls back to 300\n",
        "        (the networkx default).\n",
        "    shiftx : float\n",
        "        Horizontal offset added to every position.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    pos : dict\n",
        "        Positions used for drawing, including the horizontal shift.\n",
        "    \"\"\"\n",
        "    if pos is None:\n",
        "        pos = nx.kamada_kawai_layout(G)\n",
        "\n",
        "    if shiftx != 0:\n",
        "        # shift copies of the coordinates so that a layout supplied by the\n",
        "        # caller is not modified (and not shifted again on a later call)\n",
        "        shifted_pos = {}\n",
        "        for k, v in pos.items():\n",
        "            xy = np.array(v, dtype=float)\n",
        "            xy[0] += shiftx\n",
        "            shifted_pos[k] = xy\n",
        "        pos = shifted_pos\n",
        "\n",
        "    alpha_edge = 0.7\n",
        "    width_edge = 1.8\n",
        "    if Gweights is None:\n",
        "        # no masses: plain drawing with uniform node sizes\n",
        "        nx.draw_networkx_edges(\n",
        "            G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n",
        "        )\n",
        "        for node, node_color in enumerate(nodes_color_part):\n",
        "            nx.draw_networkx_nodes(\n",
        "                G,\n",
        "                pos,\n",
        "                nodelist=[node],\n",
        "                node_size=node_size,\n",
        "                alpha=1,\n",
        "                node_color=node_color,\n",
        "            )\n",
        "        return pos\n",
        "\n",
        "    Gweights = np.asarray(Gweights, dtype=float)\n",
        "    if node_size is None:\n",
        "        # the sizes below are multiples of node_size, so it must be a number\n",
        "        node_size = 300\n",
        "\n",
        "    # We make more visible connections between activated nodes\n",
        "    n = len(Gweights)\n",
        "    edgelist_activated = []\n",
        "    edgelist_deactivated = []\n",
        "    for i in range(n):\n",
        "        for j in range(n):\n",
        "            if Gweights[i] * Gweights[j] * C[i, j] > 0:\n",
        "                edgelist_activated.append((i, j))\n",
        "            elif C[i, j] > 0:\n",
        "                edgelist_deactivated.append((i, j))\n",
        "\n",
        "    nx.draw_networkx_edges(\n",
        "        G,\n",
        "        pos,\n",
        "        edgelist=edgelist_activated,\n",
        "        width=width_edge,\n",
        "        alpha=alpha_edge,\n",
        "        edge_color=edge_color,\n",
        "    )\n",
        "    nx.draw_networkx_edges(\n",
        "        G,\n",
        "        pos,\n",
        "        edgelist=edgelist_deactivated,\n",
        "        width=width_edge,\n",
        "        alpha=0.1,\n",
        "        edge_color=edge_color,\n",
        "    )\n",
        "\n",
        "    max_weight = Gweights.max()\n",
        "    if max_weight > 0:\n",
        "        # the heaviest node is drawn at twice the reference size\n",
        "        nodes_size = node_size * (Gweights / (0.5 * max_weight))\n",
        "    else:\n",
        "        # no mass anywhere: avoid dividing by zero, every node gets size 0\n",
        "        nodes_size = np.zeros(n)\n",
        "    for node, node_color in enumerate(nodes_color_part):\n",
        "        if nodes_size[node] == 0:\n",
        "            local_node_size = 0\n",
        "        else:\n",
        "            # keep nodes with a tiny but positive mass visible\n",
        "            local_node_size = max(0.1 * node_size, nodes_size[node])\n",
        "        nx.draw_networkx_nodes(\n",
        "            G,\n",
        "            pos,\n",
        "            nodelist=[node],\n",
        "            node_size=local_node_size,\n",
        "            alpha=1,\n",
        "            node_color=node_color,\n",
        "        )\n",
        "    return pos\n",
        "\n",
        "\n",
        "def draw_transp_colored(\n",
        "    G1,\n",
        "    C1,\n",
        "    G2,\n",
        "    C2,\n",
        "    p1,\n",
        "    p2,\n",
        "    T,\n",
        "    pos1=None,\n",
        "    pos2=None,\n",
        "    shiftx=4,\n",
        "    switchx=False,\n",
        "    node_size=70,\n",
        "    color_features=False,\n",
        "):\n",
        "    \"\"\"Draw two graphs side by side with the transport plan between them.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    G1, G2 : networkx.Graph\n",
        "        Source and target graphs.\n",
        "    C1, C2 : array\n",
        "        Structure matrices of ``G1`` and ``G2``.\n",
        "    p1, p2 : array-like or None\n",
        "        Node masses forwarded to ``draw_graph`` as ``Gweights``.\n",
        "    T : array, shape (n1, n2)\n",
        "        Transport plan; a line is drawn for every positive entry, with an\n",
        "        opacity proportional to the transported mass.\n",
        "    pos1, pos2 : dict, optional\n",
        "        Node positions of each graph; computed when None.\n",
        "    shiftx : float\n",
        "        Horizontal offset applied to the second graph.\n",
        "    switchx : bool\n",
        "        Unused; kept for backward compatibility.\n",
        "    node_size : float\n",
        "        Reference node size.\n",
        "    color_features : bool\n",
        "        If True, nodes are colored from their ``weight`` attribute,\n",
        "        otherwise all nodes use the color ``C0``.\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    pos1, pos2 : dict\n",
        "        Positions used to draw each graph.\n",
        "    \"\"\"\n",
        "    if color_features:\n",
        "        nodes_color_part1 = graph_colors(G1, vmin=-1, vmax=1)\n",
        "        nodes_color_part2 = graph_colors(G2, vmin=-1, vmax=1)\n",
        "    else:\n",
        "        nodes_color_part1 = C1.shape[0] * [\"C0\"]\n",
        "        nodes_color_part2 = C2.shape[0] * [\"C0\"]\n",
        "\n",
        "    pos1 = draw_graph(\n",
        "        G1,\n",
        "        C1,\n",
        "        nodes_color_part1,\n",
        "        Gweights=p1,\n",
        "        pos=pos1,\n",
        "        node_size=node_size,\n",
        "        shiftx=0,\n",
        "    )\n",
        "    pos2 = draw_graph(\n",
        "        G2,\n",
        "        C2,\n",
        "        nodes_color_part2,\n",
        "        Gweights=p2,\n",
        "        pos=pos2,\n",
        "        node_size=node_size,\n",
        "        shiftx=shiftx,\n",
        "    )\n",
        "    T_max = T.max()\n",
        "    for k1, xy1 in pos1.items():\n",
        "        for k2, xy2 in pos2.items():\n",
        "            if T[k1, k2] > 0:\n",
        "                pl.plot(\n",
        "                    [xy1[0], xy2[0]],\n",
        "                    [xy1[1], xy2[1]],\n",
        "                    \"-\",\n",
        "                    lw=0.8,\n",
        "                    alpha=max(0.05, 0.8 * T[k1, k2] / T_max),\n",
        "                    color=nodes_color_part1[k1],\n",
        "                )\n",
        "    return pos1, pos2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Generate and visualize data\n",
        "We build a clean circular graph that will be matched to a noisy circular graph.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "clean_graph = build_noisy_circular_graph(n_clean=15, n_noise=0)\n",
        "\n",
        "noisy_graph = build_noisy_circular_graph(n_clean=15, n_noise=5)\n",
        "\n",
        "graphs = [clean_graph, noisy_graph]\n",
        "graph_titles = [\"clean graph\", \"noisy graph\"]\n",
        "list_pos = []\n",
        "pl.figure(figsize=(6, 3))\n",
        "for i, (g, title) in enumerate(zip(graphs, graph_titles)):\n",
        "    pl.subplot(1, 2, i + 1)\n",
        "    pl.title(title, fontsize=16)\n",
        "    pos = nx.kamada_kawai_layout(g)\n",
        "    list_pos.append(pos)\n",
        "    nx.draw_networkx(\n",
        "        g,\n",
        "        pos=pos,\n",
        "        node_color=graph_colors(g, vmin=-1, vmax=1),\n",
        "        with_labels=False,\n",
        "        node_size=100,\n",
        "    )\n",
        "pl.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Partial (Entropic) Gromov-Wasserstein computation and visualization\n",
        "Adjacency matrices are compared using both exact and entropic partial GW,\n",
        "discarding node features for now.\n",
        "Then, for illustration, the node sizes are proportional to their optimized masses\n",
        "and the intensity of the link between two nodes across graphs is set proportionally\n",
        "to the corresponding transported mass.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "Cs = [nx.adjacency_matrix(G).toarray().astype(np.float64) for G in graphs]\n",
        "ps = [unif(C.shape[0]) for C in Cs]\n",
        "\n",
        "# provide an informative initialization for better visualization:\n",
        "# average of the independent coupling and of the identity matching between\n",
        "# the clean nodes and the first nodes of the noisy graph, rescaled to mass m\n",
        "m = 3.0 / 4.0\n",
        "n_source, n_target = Cs[0].shape[0], Cs[1].shape[0]\n",
        "partial_id = np.zeros((n_source, n_target))\n",
        "partial_id[:n_source, :n_source] = np.eye(n_source) / n_source\n",
        "G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2\n",
        "\n",
        "# compute exact partial GW\n",
        "T, log = partial_gromov_wasserstein(\n",
        "    Cs[0], Cs[1], ps[0], ps[1], m=m, G0=G0, symmetric=True, log=True\n",
        ")\n",
        "\n",
        "# compute entropic partial GW leading to dense transport plans\n",
        "Tent, logent = entropic_partial_gromov_wasserstein(\n",
        "    Cs[0], Cs[1], ps[0], ps[1], reg=0.01, m=m, G0=G0, symmetric=True, log=True\n",
        ")\n",
        "\n",
        "# Plot matchings\n",
        "list_T = [T, Tent]\n",
        "list_dist = [\n",
        "    np.round(log[\"partial_gw_dist\"], 3),\n",
        "    np.round(logent[\"partial_gw_dist\"], 3),\n",
        "]\n",
        "list_dist_str = [\"pGW\", \"pGW_e\"]\n",
        "\n",
        "pl.figure(2, figsize=(10, 3))\n",
        "pl.clf()\n",
        "for i in range(2):\n",
        "    pl.subplot(1, 2, i + 1)\n",
        "    pl.axis(\"off\")\n",
        "    pl.title(\n",
        "        r\"$%s(\\mathbf{C_1},\\mathbf{p_1}^\\star,\\mathbf{C_2},\\mathbf{p_2}^\\star) =%s$\"\n",
        "        % (list_dist_str[i], list_dist[i]),\n",
        "        fontsize=14,\n",
        "    )\n",
        "\n",
        "    # mass received by each node of the noisy graph\n",
        "    p2 = list_T[i].sum(0)\n",
        "\n",
        "    pos1, pos2 = draw_transp_colored(\n",
        "        clean_graph,\n",
        "        Cs[0],\n",
        "        noisy_graph,\n",
        "        Cs[1],\n",
        "        p1=None,\n",
        "        p2=p2,\n",
        "        T=list_T[i],\n",
        "        shiftx=3,\n",
        "        node_size=50,\n",
        "    )\n",
        "\n",
        "pl.tight_layout()\n",
        "pl.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Partial (Entropic) Fused Gromov-Wasserstein computation and visualization\n",
        "We now add node features, compared using pairwise squared Euclidean distances\n",
        "(the default metric of `ot.dist`), to illustrate partial FGW computation with\n",
        "trade-off parameter alpha=0.5.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# node features of each graph as a column vector, in node order\n",
        "Ys = [\n",
        "    np.array(list(nx.get_node_attributes(G, \"weight\").values())).reshape(-1, 1)\n",
        "    for G in graphs\n",
        "]\n",
        "M = dist(Ys[0], Ys[1])\n",
        "# provide an informative initialization 
for better visualization\n",
        "m = 3.0 / 4.0\n",
        "n_source, n_target = Cs[0].shape[0], Cs[1].shape[0]\n",
        "partial_id = np.zeros((n_source, n_target))\n",
        "partial_id[:n_source, :n_source] = np.eye(n_source) / n_source\n",
        "G0 = (np.outer(ps[0], ps[1]) + partial_id) * m / 2\n",
        "\n",
        "# compute exact partial FGW\n",
        "T, log = partial_fused_gromov_wasserstein(\n",
        "    M,\n",
        "    Cs[0],\n",
        "    Cs[1],\n",
        "    ps[0],\n",
        "    ps[1],\n",
        "    alpha=0.5,\n",
        "    m=m,\n",
        "    G0=G0,\n",
        "    symmetric=True,\n",
        "    log=True,\n",
        ")\n",
        "\n",
        "# compute entropic partial FGW leading to dense transport plans\n",
        "Tent, logent = entropic_partial_fused_gromov_wasserstein(\n",
        "    M,\n",
        "    Cs[0],\n",
        "    Cs[1],\n",
        "    ps[0],\n",
        "    ps[1],\n",
        "    reg=0.01,\n",
        "    alpha=0.5,\n",
        "    m=m,\n",
        "    G0=G0,\n",
        "    symmetric=True,\n",
        "    log=True,\n",
        ")\n",
        "\n",
        "# Plot matchings\n",
        "list_T = [T, Tent]\n",
        "list_dist = [\n",
        "    np.round(log[\"partial_fgw_dist\"], 3),\n",
        "    np.round(logent[\"partial_fgw_dist\"], 3),\n",
        "]\n",
        "list_dist_str = [\"pFGW\", \"pFGW_e\"]\n",
        "\n",
        "pl.figure(3, figsize=(10, 3))\n",
        "pl.clf()\n",
        "for i in range(2):\n",
        "    pl.subplot(1, 2, i + 1)\n",
        "    pl.axis(\"off\")\n",
        "    pl.title(\n",
        "        r\"$%s(\\mathbf{C_1},\\mathbf{p_1}^\\star,\\mathbf{C_2}, \\mathbf{p_2}^\\star) =%s$\"\n",
        "        % (list_dist_str[i], list_dist[i]),\n",
        "        fontsize=14,\n",
        "    )\n",
        "\n",
        "    # mass received by each node of the noisy graph\n",
        "    p2 = list_T[i].sum(0)\n",
        "    pos1, pos2 = draw_transp_colored(\n",
        "        clean_graph,\n",
        "        Cs[0],\n",
        "        noisy_graph,\n",
        "        Cs[1],\n",
        "        p1=None,\n",
        "        p2=p2,\n",
        "        T=list_T[i],\n",
        "        shiftx=3,\n",
        "        node_size=50,\n",
        "        color_features=True,\n",
        "    )\n",
        "\n",
        "pl.tight_layout()\n",
        "pl.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.18"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}