{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning\n\n

**Note**: Example added in release 0.9.5.

\n\nIn this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein\n(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of\nstructured data such as graphs, denoted $\\{ \\mathbf{C_s} \\}_{s \\in [S]}$\nwhere all nodes have uniform weights $\\{ \\mathbf{p_s} \\}_{s \\in [S]}$.\nGiven a barycenter structure matrix $\\mathbf{C}$ with $N$ nodes,\neach graph $(\\mathbf{C_s}, \\mathbf{p_s})$ is modeled as a reweighted subgraph\nwith structure $\\mathbf{C}$ and weights $\\mathbf{w_s} \\in \\Sigma_N$,\nwhere each $\\mathbf{w_s}$ corresponds to the second marginal of the OT plan\n$\\mathbf{T_s}$ (s.t. $\\mathbf{w_s} = \\mathbf{T_s}^\\top \\mathbf{1}$)\nminimizing the srGW loss between the $s^{th}$ input and the barycenter.\n\n\nFirst, we consider a dataset composed of graphs generated by Stochastic Block models\nwith variable sizes taken in $\\{30, \\dots, 50\\}$ and number of clusters\nvarying in $\\{ 1, 2, 3\\}$ with random proportions. We learn an srGW barycenter\nwith 3 nodes and visualize the learned structure and the embeddings for some inputs.\n\nSecond, we illustrate the extension of this framework to graphs endowed\nwith node features by using the semi-relaxed Fused Gromov-Wasserstein\ndivergence (srFGW). Starting from the aforementioned dataset of unattributed graphs, we\nadd discrete labels uniformly depending on the number of clusters. 
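As a quick aside, the reweighting identity `w_s = T_s^T 1` can be checked in a few lines of pure NumPy (the sizes below are toy values for this sketch only, not taken from the example): the first marginal of an admissible plan is constrained to the input weights, while the free second marginal automatically lies in the simplex.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (illustration only): a 5-node input graph, a 3-node barycenter.
n, N = 5, 3
p = np.full(n, 1.0 / n)  # uniform input node weights

# Build a random admissible plan: each row i is a distribution scaled by p[i],
# so the first marginal T @ 1 equals p, as the semi-relaxed problem requires.
T = rng.random((n, N))
T = T / T.sum(axis=1, keepdims=True) * p[:, None]
assert np.allclose(T.sum(axis=1), p)  # fixed first marginal

# The second marginal is left free and gives the barycenter node weights:
w = T.T @ np.ones(n)  # equivalently T.sum(axis=0)
assert np.all(w >= 0) and np.isclose(w.sum(), 1.0)  # w lies in the simplex
```
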
Then we conduct\nthe analogous analysis.\n\n\n[48] C\u00e9dric Vincent-Cuaz, R\u00e9mi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.\n\"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs\".\nInternational Conference on Learning Representations (ICLR), 2022.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz \n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom sklearn.manifold import MDS\nfrom ot.gromov import semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters\nimport ot\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(42)\n\nn_samples = 60 # number of graphs in the dataset\n# For every number of clusters, we generate SBMs with fixed inter-/intra-cluster probabilities\n# and variable cluster proportions.\nclusters = [1, 2, 3]\nNc = n_samples // len(clusters) # number of graphs per cluster\nnlabels = len(clusters)\ndataset = []\nnode_labels = []\nlabels = []\n\np_inter = 0.1\np_intra = 0.9\nfor n_cluster in clusters:\n for i in range(Nc):\n n_nodes = int(np.random.uniform(low=30, high=50))\n\n if n_cluster > 1:\n P = p_inter * np.ones((n_cluster, n_cluster))\n np.fill_diagonal(P, p_intra)\n props = np.random.uniform(0.2, 1, size=(n_cluster,))\n props /= props.sum()\n sizes = np.round(n_nodes * props).astype(np.int32)\n else:\n P = p_intra * np.eye(1)\n sizes = [n_nodes]\n\n G = sbm(sizes, P, seed=i, directed=False)\n part = np.array([G.nodes[i][\"block\"] for i in range(np.sum(sizes))])\n C = networkx.to_numpy_array(G)\n 
dataset.append(C)\n node_labels.append(part)\n labels.append(n_cluster)\n\n\n# Visualize samples\n\n\ndef plot_graph(x, C, binary=True, color=\"C0\", s=None):\n for j in range(C.shape[0]):\n for i in range(j):\n if binary:\n if C[i, j] > 0:\n pl.plot(\n [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color=\"k\"\n )\n else: # connection intensity proportional to C[i,j]\n pl.plot(\n [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color=\"k\"\n )\n\n pl.scatter(\n x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors=\"k\", cmap=\"tab10\", vmax=9\n )\n\n\npl.figure(1, (12, 8))\npl.clf()\nfor idx_c, c in enumerate(clusters):\n C = dataset[(c - 1) * Nc] # sample with c clusters\n # get 2d position for nodes\n x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - C)\n pl.subplot(2, nlabels, c)\n pl.title(\"(graph) sample from label \" + str(c), fontsize=14)\n plot_graph(x, C, binary=True, color=\"C0\", s=50.0)\n pl.axis(\"off\")\n pl.subplot(2, nlabels, nlabels + c)\n pl.title(\"(matrix) sample from label %s \\n\" % c, fontsize=14)\n pl.imshow(C, interpolation=\"nearest\")\n pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate the srGW barycenter from the dataset and visualize embeddings\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(0)\nps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on input nodes\nlambdas = [1.0 / n_samples for _ in range(n_samples)] # uniform barycenter\nN = 3 # 3 nodes in the barycenter\n\n# Here we use the Fluid partitioning method to deduce initial transport plans\n# for the barycenter problem. An initial structure is also deduced from these\n# initial transport plans. 
Then a warmstart strategy is used iteratively to\n# init each individual srGW problem within the BCD algorithm.\n\ninit_plan = \"fluid\" # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan`\nwarmstartT = True\n\nC, log = semirelaxed_gromov_barycenters(\n N=N,\n Cs=dataset,\n ps=ps,\n lambdas=lambdas,\n loss_fun=\"square_loss\",\n tol=1e-6,\n stop_criterion=\"loss\",\n warmstartT=warmstartT,\n log=True,\n G0=init_plan,\n verbose=False,\n)\n\nprint(\"barycenter structure:\", C)\n\nunmixings = log[\"p\"]\n# Compute the 2D representation of the embeddings living in the 2-simplex of probability\nunmixings2D = np.zeros(shape=(n_samples, 2))\nfor i, w in enumerate(unmixings):\n unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0\n unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0\nx = [0.0, 0.0]\ny = [1.0, 0.0]\nz = [0.5, np.sqrt(3) / 2.0]\nextremities = np.stack([x, y, z])\n\npl.figure(2, (4, 4))\npl.clf()\npl.title(\"Embedding space\", fontsize=14)\nfor cluster in range(nlabels):\n start, end = Nc * cluster, Nc * (cluster + 1)\n if cluster == 0:\n pl.scatter(\n unmixings2D[start:end, 0],\n unmixings2D[start:end, 1],\n c=\"C\" + str(cluster),\n marker=\"o\",\n s=80.0,\n label=\"1 cluster\",\n )\n else:\n pl.scatter(\n unmixings2D[start:end, 0],\n unmixings2D[start:end, 1],\n c=\"C\" + str(cluster),\n marker=\"o\",\n s=80.0,\n label=\"%s clusters\" % (cluster + 1),\n )\npl.scatter(\n extremities[:, 0],\n extremities[:, 1],\n c=\"black\",\n marker=\"x\",\n s=100.0,\n label=\"bary. 
nodes\",\n)\npl.plot([x[0], y[0]], [x[1], y[1]], color=\"black\", linewidth=2.0)\npl.plot([x[0], z[0]], [x[1], z[1]], color=\"black\", linewidth=2.0)\npl.plot([y[0], z[0]], [y[1], z[1]], color=\"black\", linewidth=2.0)\npl.axis(\"off\")\npl.legend(fontsize=11)\npl.tight_layout()\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Endow the dataset with node features\nNode labels, corresponding to the true SBM cluster assignments,\nare set for each graph as one-hot encoded node features.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dataset_features = []\nfor i in range(len(dataset)):\n n = dataset[i].shape[0]\n F = np.zeros((n, 3))\n F[np.arange(n), node_labels[i]] = 1.0\n dataset_features.append(F)\n\npl.figure(3, (12, 8))\npl.clf()\nfor idx_c, c in enumerate(clusters):\n C = dataset[(c - 1) * Nc] # sample with c clusters\n F = dataset_features[(c - 1) * Nc]\n # color each node by its SBM block (recovered from its one-hot feature)\n colors = [f\"C{np.argmax(F[i])}\" for i in range(F.shape[0])]\n # get 2d position for nodes\n x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - C)\n pl.subplot(2, nlabels, c)\n pl.title(\"(graph) sample from label \" + str(c), fontsize=14)\n plot_graph(x, C, binary=True, color=colors, s=50)\n pl.axis(\"off\")\n pl.subplot(2, nlabels, nlabels + c)\n pl.title(\"(matrix) sample from label %s \\n\" % c, fontsize=14)\n pl.imshow(C, interpolation=\"nearest\")\n pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate the srFGW barycenter from the attributed graphs and visualize embeddings\nWe emphasize the dependence on the trade-off parameter alpha that weights the\nrelative importance between structures (alpha=1) and features (alpha=0),\nknowing that embeddings that perfectly cluster graphs w.r.t. their features\nshould ease the identification of the number of clusters in the graphs.\n\n" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "list_alphas = [0.0001, 0.5, 0.9999]\nlist_unmixings2D = []\n\nfor ialpha, alpha in enumerate(list_alphas):\n print(\"--- alpha:\", alpha)\n C, F, log = semirelaxed_fgw_barycenters(\n N=N,\n Ys=dataset_features,\n Cs=dataset,\n ps=ps,\n lambdas=lambdas,\n alpha=alpha,\n loss_fun=\"square_loss\",\n tol=1e-6,\n stop_criterion=\"loss\",\n warmstartT=warmstartT,\n log=True,\n G0=init_plan,\n )\n\n print(\"barycenter structure:\", C)\n print(\"barycenter features:\", F)\n\n unmixings = log[\"p\"]\n # Compute the 2D representation of the embeddings living in the 2-simplex of probability\n unmixings2D = np.zeros(shape=(n_samples, 2))\n for i, w in enumerate(unmixings):\n unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0\n unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0\n list_unmixings2D.append(unmixings2D.copy())\n\nx = [0.0, 0.0]\ny = [1.0, 0.0]\nz = [0.5, np.sqrt(3) / 2.0]\nextremities = np.stack([x, y, z])\n\npl.figure(4, (12, 4))\npl.clf()\npl.suptitle(\"Embedding spaces\", fontsize=14)\nfor ialpha, alpha in enumerate(list_alphas):\n pl.subplot(1, len(list_alphas), ialpha + 1)\n pl.title(f\"alpha = {alpha}\", fontsize=14)\n for cluster in range(nlabels):\n start, end = Nc * cluster, Nc * (cluster + 1)\n if cluster == 0:\n pl.scatter(\n list_unmixings2D[ialpha][start:end, 0],\n list_unmixings2D[ialpha][start:end, 1],\n c=\"C\" + str(cluster),\n marker=\"o\",\n s=80.0,\n label=\"1 cluster\",\n )\n else:\n pl.scatter(\n list_unmixings2D[ialpha][start:end, 0],\n list_unmixings2D[ialpha][start:end, 1],\n c=\"C\" + str(cluster),\n marker=\"o\",\n s=80.0,\n label=\"%s clusters\" % (cluster + 1),\n )\n pl.scatter(\n extremities[:, 0],\n extremities[:, 1],\n c=\"black\",\n marker=\"x\",\n s=100.0,\n label=\"bary. 
nodes\",\n )\n pl.plot([x[0], y[0]], [x[1], y[1]], color=\"black\", linewidth=2.0)\n pl.plot([x[0], z[0]], [x[1], z[1]], color=\"black\", linewidth=2.0)\n pl.plot([y[0], z[0]], [y[1], z[1]], color=\"black\", linewidth=2.0)\n pl.axis(\"off\")\n pl.legend(fontsize=11)\npl.tight_layout()\npl.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }