{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Quantized Fused Gromov-Wasserstein examples\n\n
Note: Examples added in release: 0.9.4.
\n\nThese examples show how to use the quantized (Fused) Gromov-Wasserstein\nsolvers (qFGW) [68]. POT provides a generic solver `quantized_fused_gromov_wasserstein_partitioned`\nthat takes as inputs partitioned graphs, potentially endowed with node features,\nwhich have to be built by the user. On top of that, POT provides two wrappers:\n i) `quantized_fused_gromov_wasserstein` operating over generic graphs, whose\n partitioning is performed via `get_graph_partition` using e.g. the Louvain algorithm,\n and where a representant for each partition is selected via `get_graph_representants`\n using e.g. the PageRank algorithm.\n\n ii) `quantized_fused_gromov_wasserstein_samples` operating over point clouds,\n e.g. $X_1 \\in \\mathbb{R}^{n_1 \\times d_1}$ and $X_2 \\in \\mathbb{R}^{n_2 \\times d_2}$\n endowed with their respective Euclidean geometry, whose partitioning and\n representant selection are performed jointly using e.g. the K-means algorithm\n via the function `get_partition_and_representants_samples`.\n\n\nWe illustrate next how to compute the qGW and qFGW distances on both types of data by:\n\n i) Generating two graphs following Stochastic Block Models, encoded as shortest\n path matrices since qGW solvers tend to require dense structures to achieve a good\n approximation of the GW distance (qGW is an upper bound on GW). Along the way,\n we illustrate an optional feature of our solvers, namely the use of auxiliary\n structures, e.g. adjacency matrices, to perform the graph partitioning.\n\n ii) Generating two point clouds representing curves in 2D and 3D respectively.\n We augment these point clouds with additional features of the same\n dimensionality, $F_1 \\in \\mathbb{R}^{n_1 \\times d}$ and $F_2 \\in \\mathbb{R}^{n_2 \\times d}$,\n representing the color intensity associated with each sample of both distributions.\n Then we compute the qFGW distance between these attributed point clouds.\n\n\n[68] Chowdhury, S., Miller, D., & Needham, T. (2021). Quantized Gromov-Wasserstein.\nECML PKDD 2021. 
Springer International Publishing.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport matplotlib.pyplot as plt\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm\nfrom scipy.sparse.csgraph import shortest_path\n\nfrom ot.gromov import (\n quantized_fused_gromov_wasserstein_partitioned,\n quantized_fused_gromov_wasserstein,\n get_graph_partition,\n get_graph_representants,\n format_partitioned_graph,\n quantized_fused_gromov_wasserstein_samples,\n get_partition_and_representants_samples,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate graphs\n\nCreate two graphs following Stochastic Block models of 2 and 3 clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N1 = 30 # 2 communities\nN2 = 45 # 3 communities\np1 = [[0.8, 0.1], [0.1, 0.7]]\np2 = [[0.8, 0.1, 0.0], [0.1, 0.75, 0.1], [0.0, 0.1, 0.7]]\nG1 = sbm(seed=0, sizes=[N1 // 2, N1 // 2], p=p1)\nG2 = sbm(seed=0, sizes=[N2 // 3, N2 // 3, N2 // 3], p=p2)\n\n\nC1 = networkx.to_numpy_array(G1)\nC2 = networkx.to_numpy_array(G2)\n\nspC1 = shortest_path(C1)\nspC2 = shortest_path(C2)\n\nh1 = np.ones(C1.shape[0]) / C1.shape[0]\nh2 = np.ones(C2.shape[0]) / C2.shape[0]\n\n# Add weights on the edges for visualization later on\nweight_intra_G1 = 5\nweight_inter_G1 = 0.5\nweight_intra_G2 = 1.0\nweight_inter_G2 = 1.5\n\nweightedG1 = networkx.Graph()\npart_G1 = [G1.nodes[i][\"block\"] for i in range(N1)]\n\nfor node in G1.nodes():\n weightedG1.add_node(node)\nfor i, j in G1.edges():\n if part_G1[i] == part_G1[j]:\n weightedG1.add_edge(i, j, weight=weight_intra_G1)\n else:\n weightedG1.add_edge(i, j, weight=weight_inter_G1)\n\nweightedG2 = 
networkx.Graph()\npart_G2 = [G2.nodes[i][\"block\"] for i in range(N2)]\n\nfor node in G2.nodes():\n weightedG2.add_node(node)\nfor i, j in G2.edges():\n if part_G2[i] == part_G2[j]:\n weightedG2.add_edge(i, j, weight=weight_intra_G2)\n else:\n weightedG2.add_edge(i, j, weight=weight_inter_G2)\n\n\n# setup for graph visualization\n\n\ndef node_coloring(part, starting_color=0):\n # get graphs partition and their coloring\n unique_colors = [\"C%s\" % (starting_color + i) for i in np.unique(part)]\n nodes_color_part = []\n for cluster in part:\n nodes_color_part.append(unique_colors[cluster])\n\n return nodes_color_part\n\n\ndef draw_graph(\n G,\n C,\n nodes_color_part,\n rep_indices,\n node_alphas=None,\n pos=None,\n edge_color=\"black\",\n alpha_edge=0.7,\n node_size=None,\n shiftx=0,\n seed=0,\n highlight_rep=False,\n):\n if pos is None:\n pos = networkx.spring_layout(G, scale=1.0, seed=seed)\n\n if shiftx != 0:\n for k, v in pos.items():\n v[0] = v[0] + shiftx\n\n width_edge = 1.5\n\n if not highlight_rep:\n networkx.draw_networkx_edges(\n G, pos, width=width_edge, alpha=alpha_edge, edge_color=edge_color\n )\n else:\n for edge in G.edges:\n if (edge[0] in rep_indices) and (edge[1] in rep_indices):\n networkx.draw_networkx_edges(\n G,\n pos,\n edgelist=[edge],\n width=width_edge,\n alpha=alpha_edge,\n edge_color=edge_color,\n )\n else:\n networkx.draw_networkx_edges(\n G,\n pos,\n edgelist=[edge],\n width=width_edge,\n alpha=0.2,\n edge_color=edge_color,\n )\n\n for node, node_color in enumerate(nodes_color_part):\n local_node_shape, local_node_size = \"o\", node_size\n\n if highlight_rep:\n if node in rep_indices:\n local_node_shape, local_node_size = \"*\", 6 * node_size\n\n if node_alphas is None:\n alpha = 0.9\n if highlight_rep:\n alpha = 0.9 if node in rep_indices else 0.1\n\n else:\n alpha = node_alphas[node]\n\n networkx.draw_networkx_nodes(\n G,\n pos,\n nodelist=[node],\n alpha=alpha,\n node_shape=local_node_shape,\n node_size=local_node_size,\n 
node_color=node_color,\n )\n\n return pos" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute their quantized Gromov-Wasserstein distance without using the wrapper\n\nWe detail next the steps implemented within the wrapper that preprocess graphs\nto form partitioned graphs, which are then passed as input to the generic qFGW solver.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# 1-a) Partition C1 and C2 in 2 and 3 clusters respectively using Louvain\n# algorithm from Networkx. Then encode these partitions via vectors of assignments.\n\npart_method = \"louvain\"\nrep_method = \"pagerank\"\n\nnpart_1 = 2 # 2 clusters used to describe C1\nnpart_2 = 3 # 3 clusters used to describe C2\n\npart1 = get_graph_partition(\n C1, npart=npart_1, part_method=part_method, F=None, alpha=1.0\n)\npart2 = get_graph_partition(\n C2, npart=npart_2, part_method=part_method, F=None, alpha=1.0\n)\n\n# 1-b) Select representant in each partition using the Pagerank algorithm\n# implementation from networkx.\n\nrep_indices1 = get_graph_representants(C1, part1, rep_method=rep_method)\nrep_indices2 = get_graph_representants(C2, part2, rep_method=rep_method)\n\n# 1-c) Format partitions such that:\n# CR contains relations between representants in each space.\n# list_R contains relations between samples and representants within each partition.\n# list_h contains samples relative importance within each partition.\n\nCR1, list_R1, list_h1 = format_partitioned_graph(\n spC1, h1, part1, rep_indices1, F=None, M=None, alpha=1.0\n)\n\nCR2, list_R2, list_h2 = format_partitioned_graph(\n spC2, h2, part2, rep_indices2, F=None, M=None, alpha=1.0\n)\n\n# 1-d) call to partitioned quantized gromov-wasserstein solver\n\nOT_global_, OTs_local_, OT_, log_ = quantized_fused_gromov_wasserstein_partitioned(\n CR1,\n CR2,\n list_R1,\n list_R2,\n list_h1,\n list_h2,\n MR=None,\n alpha=1.0,\n build_OT=True,\n 
log=True,\n)\n\n\n# Visualization of the graph pre-processing\n\nnode_size = 40\nfontsize = 10\nseed_G1 = 0\nseed_G2 = 3\n\npart1_ = part1.astype(np.int32)\npart2_ = part2.astype(np.int32)\n\n\nnodes_color_part1 = node_coloring(part1_, starting_color=0)\nnodes_color_part2 = node_coloring(\n part2_, starting_color=np.unique(nodes_color_part1).shape[0]\n)\n\n\npl.figure(1, figsize=(6, 5))\npl.clf()\npl.axis(\"off\")\npl.subplot(2, 3, 1)\npl.title(r\"Input graph: $\\mathbf{spC_1}$\", fontsize=fontsize)\n\npos1 = draw_graph(\n G1, C1, [\"C0\" for _ in part1_], rep_indices1, node_size=node_size, seed=seed_G1\n)\n\npl.subplot(2, 3, 2)\npl.title(\"Partitioning\", fontsize=fontsize)\n\n_ = draw_graph(\n G1, C1, nodes_color_part1, rep_indices1, pos=pos1, node_size=node_size, seed=seed_G1\n)\n\npl.subplot(2, 3, 3)\npl.title(\"Representant selection\", fontsize=fontsize)\n\n_ = draw_graph(\n G1,\n C1,\n nodes_color_part1,\n rep_indices1,\n pos=pos1,\n node_size=node_size,\n seed=seed_G1,\n highlight_rep=True,\n)\n\npl.subplot(2, 3, 4)\npl.title(r\"Input graph: $\\mathbf{spC_2}$\", fontsize=fontsize)\n\npos2 = draw_graph(\n G2, C2, [\"C0\" for _ in part2_], rep_indices2, node_size=node_size, seed=seed_G2\n)\n\npl.subplot(2, 3, 5)\npl.title(r\"Partitioning\", fontsize=fontsize)\n\n_ = draw_graph(\n G2, C2, nodes_color_part2, rep_indices2, pos=pos2, node_size=node_size, seed=seed_G2\n)\n\npl.subplot(2, 3, 6)\npl.title(r\"Representant selection\", fontsize=fontsize)\n\n_ = draw_graph(\n G2,\n C2,\n nodes_color_part2,\n rep_indices2,\n pos=pos2,\n node_size=node_size,\n seed=seed_G2,\n highlight_rep=True,\n)\npl.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the quantized Gromov-Wasserstein distance using the wrapper\n\nCompute qGW(spC1, h1, spC2, h2). 
We also illustrate the use of auxiliary matrices,\nsuch as the adjacency matrices `C1_aux=C1` and `C2_aux=C2`, to partition the graphs using the\nLouvain algorithm and to select a representant within each partition using the\nPageRank algorithm. Note that `C1_aux` and `C2_aux` are optional: if they are not\nspecified, these pre-processing algorithms are applied to spC1 and spC2.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# No node features are considered on this synthetic dataset. Hence we simply\n# let F1, F2 = None and set alpha = 1.\nOT_global, OTs_local, OT, log = quantized_fused_gromov_wasserstein(\n spC1,\n spC2,\n npart_1,\n npart_2,\n h1,\n h2,\n C1_aux=C1,\n C2_aux=C2,\n F1=None,\n F2=None,\n alpha=1.0,\n part_method=part_method,\n rep_method=rep_method,\n log=True,\n)\n\nqGW_dist = log[\"qFGW_dist\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the quantized Gromov-Wasserstein matching\n\nWe color the nodes of each graph according to its partition.\nThe left plot shows the qGW matching between both shortest path matrices,\nwhile the right plot shows the GW matching across the representants of each space.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def draw_transp_colored_qGW(\n G1,\n C1,\n G2,\n C2,\n part1,\n part2,\n rep_indices1,\n rep_indices2,\n T,\n pos1=None,\n pos2=None,\n shiftx=4,\n switchx=False,\n node_size=70,\n seed_G1=0,\n seed_G2=0,\n highlight_rep=False,\n):\n starting_color = 0\n # get graphs partition and their coloring\n unique_colors1 = [\"C%s\" % (starting_color + i) for i in np.unique(part1)]\n nodes_color_part1 = []\n for cluster in part1:\n nodes_color_part1.append(unique_colors1[cluster])\n\n starting_color = len(unique_colors1) + 1\n unique_colors2 = [\"C%s\" % (starting_color + i) for i in np.unique(part2)]\n 
nodes_color_part2 = []\n for cluster in part2:\n nodes_color_part2.append(unique_colors2[cluster])\n\n pos1 = draw_graph(\n G1,\n C1,\n nodes_color_part1,\n rep_indices1,\n pos=pos1,\n node_size=node_size,\n shiftx=0,\n seed=seed_G1,\n highlight_rep=highlight_rep,\n )\n # the second graph uses its own layout seed\n pos2 = draw_graph(\n G2,\n C2,\n nodes_color_part2,\n rep_indices2,\n pos=pos2,\n node_size=node_size,\n shiftx=shiftx,\n seed=seed_G2,\n highlight_rep=highlight_rep,\n )\n\n if not highlight_rep:\n for k1, v1 in pos1.items():\n max_Tk1 = np.max(T[k1, :])\n for k2, v2 in pos2.items():\n if T[k1, k2] > 0:\n pl.plot(\n [pos1[k1][0], pos2[k2][0]],\n [pos1[k1][1], pos2[k2][1]],\n \"-\",\n lw=0.7,\n alpha=T[k1, k2] / max_Tk1,\n color=nodes_color_part1[k1],\n )\n\n else: # OT is only between representants\n for id1, node_id1 in enumerate(rep_indices1):\n max_Tk1 = np.max(T[id1, :])\n for id2, node_id2 in enumerate(rep_indices2):\n if T[id1, id2] > 0:\n pl.plot(\n [pos1[node_id1][0], pos2[node_id2][0]],\n [pos1[node_id1][1], pos2[node_id2][1]],\n \"-\",\n lw=0.8,\n alpha=T[id1, id2] / max_Tk1,\n color=nodes_color_part1[node_id1],\n )\n return pos1, pos2\n\n\npl.figure(2, figsize=(5, 2.5))\npl.clf()\npl.axis(\"off\")\npl.subplot(1, 2, 1)\npl.title(\n r\"qGW$(\\mathbf{spC_1}, \\mathbf{spC_2}) =%s$\" % (np.round(qGW_dist, 3)),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_qGW(\n weightedG1,\n C1,\n weightedG2,\n C2,\n part1_,\n part2_,\n rep_indices1,\n rep_indices2,\n T=OT_,\n shiftx=1.5,\n node_size=node_size,\n seed_G1=seed_G1,\n seed_G2=seed_G2,\n)\n\npl.tight_layout()\n\npl.subplot(1, 2, 2)\npl.title(\n r\" GW$(\\mathbf{CR_1}, \\mathbf{CR_2}) =%s$\" % (np.round(log_[\"global dist\"], 3)),\n fontsize=fontsize,\n)\n\npos1, pos2 = draw_transp_colored_qGW(\n weightedG1,\n C1,\n weightedG2,\n C2,\n part1_,\n part2_,\n rep_indices1,\n rep_indices2,\n T=OT_global,\n shiftx=1.5,\n node_size=node_size,\n seed_G1=seed_G1,\n seed_G2=seed_G2,\n highlight_rep=True,\n)\n\npl.tight_layout()\npl.show()" ] 
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate attributed point clouds\n\nCreate two attributed point clouds representing curves in 2D and 3D respectively,\nwhose samples are further associated with various color intensities.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_samples = 100\n\n# Generate 2D and 3D curves\ntheta = np.linspace(-4 * np.pi, 4 * np.pi, n_samples)\nz = np.linspace(1, 2, n_samples)\nr = z**2 + 1\nx = r * np.sin(theta)\ny = r * np.cos(theta)\n\n# Source and target distributions across spaces, encoded respectively via their\n# squared Euclidean distance matrices.\n\nX = np.concatenate([x.reshape(-1, 1), z.reshape(-1, 1)], axis=1)\nY = np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)], axis=1)\n\n# Further associated with color intensity features derived from z\n\n# min-max normalization of z (parentheses needed for the intended precedence)\nFX = (z - z.min()) / (z.max() - z.min())\nFX = np.clip(0.8 * FX + 0.2, a_min=0.2, a_max=1.0) # for numerical issues\nFY = FX" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize partitioned attributed point clouds\n\nCompute the partitioning and representant selection further used within\nthe qFGW wrapper, both provided by a K-means algorithm. 
Then visualize partitioned spaces.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "part1, rep_indices1 = get_partition_and_representants_samples(X, 4, \"kmeans\", 0)\npart2, rep_indices2 = get_partition_and_representants_samples(Y, 4, \"kmeans\", 0)\n\nupart1 = np.unique(part1)\nupart2 = np.unique(part2)\n\n# Plot the source and target samples as distributions\ns = 20\nfig = plt.figure(3, figsize=(6, 3))\n\nax1 = fig.add_subplot(1, 3, 1)\nax1.set_title(\"2D curve\")\nax1.scatter(X[:, 0], X[:, 1], color=\"C0\", alpha=FX, s=s)\nplt.axis(\"off\")\n\n\nax2 = fig.add_subplot(1, 3, 2)\nax2.set_title(\"Partitioning\")\nfor i, elem in enumerate(upart1):\n idx = np.argwhere(part1 == elem)[:, 0]\n ax2.scatter(X[idx, 0], X[idx, 1], color=\"C%s\" % i, alpha=FX[idx], s=s)\nplt.axis(\"off\")\n\nax3 = fig.add_subplot(1, 3, 3)\nax3.set_title(\"Representant selection\")\nfor i, elem in enumerate(upart1):\n idx = np.argwhere(part1 == elem)[:, 0]\n ax3.scatter(X[idx, 0], X[idx, 1], color=\"C%s\" % i, alpha=FX[idx], s=10)\n rep_idx = rep_indices1[i]\n ax3.scatter(\n [X[rep_idx, 0]], [X[rep_idx, 1]], color=\"C%s\" % i, alpha=1, s=6 * s, marker=\"*\"\n )\nplt.axis(\"off\")\nplt.tight_layout()\nplt.show()\n\nstart_color = upart1.shape[0] + 1\n\nfig = plt.figure(4, figsize=(6, 5))\n\nax4 = fig.add_subplot(1, 3, 1, projection=\"3d\")\nax4.set_title(\"3D curve\")\nax4.scatter(Y[:, 0], Y[:, 1], Y[:, 2], c=\"C0\", alpha=FY, s=s)\nplt.axis(\"off\")\n\nax5 = fig.add_subplot(1, 3, 2, projection=\"3d\")\nax5.set_title(\"Partitioning\")\nfor i, elem in enumerate(upart2):\n idx = np.argwhere(part2 == elem)[:, 0]\n color = \"C%s\" % (start_color + i)\n ax5.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s)\nplt.axis(\"off\")\n\nax6 = fig.add_subplot(1, 3, 3, projection=\"3d\")\nax6.set_title(\"Representant selection\")\nfor i, elem in enumerate(upart2):\n idx = np.argwhere(part2 == elem)[:, 0]\n color = 
\"C%s\" % (start_color + i)\n rep_idx = rep_indices2[i]\n ax6.scatter(Y[idx, 0], Y[idx, 1], Y[idx, 2], c=color, alpha=FY[idx], s=s)\n ax6.scatter(\n [Y[rep_idx, 0]],\n [Y[rep_idx, 1]],\n [Y[rep_idx, 2]],\n c=color,\n alpha=1,\n s=6 * s,\n marker=\"*\",\n )\nplt.axis(\"off\")\nplt.tight_layout()\nplt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the quantized Fused Gromov-Wasserstein distance between samples using the wrapper\n\nCompute qFGW(X, FX, hX, Y, FY, hY), setting the trade-off parameter between\nstructures and features to `alpha=0.5`. This solver considers a squared Euclidean structure\nfor each distribution X and Y, and partitions each of them into 4 clusters using\nthe K-means algorithm before computing qFGW.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "T_global, Ts_local, T, log = quantized_fused_gromov_wasserstein_samples(\n X,\n Y,\n 4,\n 4,\n p=None,\n q=None,\n F1=FX[:, None],\n F2=FY[:, None],\n alpha=0.5,\n method=\"kmeans\",\n log=True,\n)\n\n# Plot the qFGW transport plans between samples and between representants\npl.figure(5, figsize=(6, 3))\npl.subplot(1, 2, 1)\npl.title(\"OT between distributions\")\npl.imshow(T, interpolation=\"nearest\", aspect=\"auto\")\npl.colorbar()\npl.axis(\"off\")\n\npl.subplot(1, 2, 2)\npl.title(\"OT between representants\")\npl.imshow(T_global, interpolation=\"nearest\", aspect=\"auto\")\npl.axis(\"off\")\npl.colorbar()\n\npl.tight_layout()\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }