{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# (Fused) Gromov-Wasserstein Linear Dictionary Learning\n\n

Note

Example added in release: 0.8.2.

\n\nIn this example, we illustrate how to learn a Gromov-Wasserstein dictionary on\na dataset of structured data such as graphs, denoted\n$\\{ \\mathbf{C_s} \\}_{s \\in [S]}$, where the nodes of each graph have uniform weights.\nGiven a dictionary $\\mathbf{C_{dict}}$ composed of D structures of a fixed\nsize nt, each graph $(\\mathbf{C_s}, \\mathbf{p_s})$\nis modeled as a convex combination, with weights $\\mathbf{w_s} \\in \\Sigma_D$, of these\ndictionary atoms: $\\sum_d w_{s,d} \\mathbf{C_{dict}[d]}$.\n\n\nFirst, we consider a dataset composed of graphs generated by Stochastic Block models\nwith variable sizes taken in $\\{30, ... , 50\\}$ and a number of clusters\nvarying in $\\{ 1, 2, 3\\}$. We learn a dictionary of 3 atoms by minimizing,\nwith respect to the dictionary atoms, the Gromov-Wasserstein distance from each\nsample to its model in the dictionary.\n\nSecond, we illustrate the extension of this dictionary learning framework to\nstructured data endowed with node features by using the Fused Gromov-Wasserstein\ndistance. Starting from the aforementioned dataset of unattributed graphs, we\nassign the same discrete label to all nodes of a graph, depending on its number of clusters. Then we learn\nand visualize attributed graph atoms where each sample is modeled as a joint convex\ncombination of atom structures and features.\n\n\n[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. 
Courty, Online Graph\nDictionary Learning, International Conference on Machine Learning (ICML), 2021.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: C\u00e9dric Vincent-Cuaz\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 4\n\nimport numpy as np\nimport matplotlib.pylab as pl\nfrom sklearn.manifold import MDS\nfrom ot.gromov import (\n    gromov_wasserstein_linear_unmixing,\n    gromov_wasserstein_dictionary_learning,\n    fused_gromov_wasserstein_linear_unmixing,\n    fused_gromov_wasserstein_dictionary_learning,\n)\nimport ot\nimport networkx\nfrom networkx.generators.community import stochastic_block_model as sbm" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(42)\n\nN = 60  # number of graphs in the dataset\n# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability.\nclusters = [1, 2, 3]\nNc = N // len(clusters)  # number of graphs by cluster\nnlabels = len(clusters)\ndataset = []\nlabels = []\n\np_inter = 0.1\np_intra = 0.9\nfor n_cluster in clusters:\n    for i in range(Nc):\n        n_nodes = int(np.random.uniform(low=30, high=50))\n\n        if n_cluster > 1:\n            P = p_inter * np.ones((n_cluster, n_cluster))\n            np.fill_diagonal(P, p_intra)\n        else:\n            P = p_intra * np.eye(1)\n        sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32)\n        G = sbm(sizes, P, seed=i, directed=False)\n        C = networkx.to_numpy_array(G)\n        dataset.append(C)\n        labels.append(n_cluster)\n\n\n# Visualize samples\n\n\ndef plot_graph(x, C, binary=True, color=\"C0\", s=None):\n    for j in range(C.shape[0]):\n        for i in range(j):\n            if binary:\n                if C[i, j] > 0:\n                    pl.plot(\n                        [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color=\"k\"\n                    )\n            else:  # connection intensity proportional to C[i,j]\n                pl.plot(\n                    [x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color=\"k\"\n                )\n\n    pl.scatter(\n        x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors=\"k\", cmap=\"tab10\", vmax=9\n    )\n\n\npl.figure(1, (12, 8))\npl.clf()\nfor idx_c, c in enumerate(clusters):\n    C = dataset[(c - 1) * Nc]  # sample with c clusters\n    # get 2d position for nodes\n    x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - C)\n    pl.subplot(2, nlabels, c)\n    pl.title(\"(graph) sample from label \" + str(c), fontsize=14)\n    plot_graph(x, C, binary=True, color=\"C0\", s=50.0)\n    pl.axis(\"off\")\n    pl.subplot(2, nlabels, nlabels + c)\n    pl.title(\"(matrix) sample from label %s \\n\" % c, fontsize=14)\n    pl.imshow(C, interpolation=\"nearest\")\n    pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate the Gromov-Wasserstein dictionary from the dataset\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(0)\nps = [ot.unif(C.shape[0]) for C in dataset]\n\nD = 3  # 3 atoms in the dictionary\nnt = 6  # of 6 nodes each\n\nq = ot.unif(nt)\nreg = 0.0  # regularization coefficient to promote sparsity of unmixings {w_s}\n\nCdict_GW, log = gromov_wasserstein_dictionary_learning(\n    Cs=dataset,\n    D=D,\n    nt=nt,\n    ps=ps,\n    q=q,\n    epochs=10,\n    batch_size=16,\n    learning_rate=0.1,\n    reg=reg,\n    projection=\"nonnegative_symmetric\",\n    tol_outer=10 ** (-5),\n    tol_inner=10 ** (-5),\n    max_iter_outer=30,\n    max_iter_inner=300,\n    use_log=True,\n    use_adam_optimizer=True,\n    verbose=True,\n)\n# visualize loss evolution over epochs\npl.figure(2, (4, 3))\npl.clf()\npl.title(\"loss evolution by epoch\", fontsize=14)\npl.plot(log[\"loss_epochs\"])\npl.xlabel(\"epochs\", fontsize=12)\npl.ylabel(\"loss\", fontsize=12)\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, 
"source": [ "## Visualization of the estimated dictionary atoms\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Continuous connections between nodes of the atoms are colored in shades of grey (1: dark / 0: white)\n\npl.figure(3, (12, 8))\npl.clf()\nfor idx_atom, atom in enumerate(Cdict_GW):\n    scaled_atom = (atom - atom.min()) / (atom.max() - atom.min())\n    x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - scaled_atom)\n    pl.subplot(2, D, idx_atom + 1)\n    pl.title(\"(graph) atom \" + str(idx_atom + 1), fontsize=14)\n    plot_graph(x, atom / atom.max(), binary=False, color=\"C0\", s=100.0)\n    pl.axis(\"off\")\n    pl.subplot(2, D, D + idx_atom + 1)\n    pl.title(\"(matrix) atom %s \\n\" % (idx_atom + 1), fontsize=14)\n    pl.imshow(scaled_atom, interpolation=\"nearest\")\n    pl.colorbar()\n    pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the embedding space\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "unmixings = []\nreconstruction_errors = []\nfor C in dataset:\n    p = ot.unif(C.shape[0])\n    unmixing, Cembedded, OT, reconstruction_error = gromov_wasserstein_linear_unmixing(\n        C,\n        Cdict_GW,\n        p=p,\n        q=q,\n        reg=reg,\n        tol_outer=10 ** (-5),\n        tol_inner=10 ** (-5),\n        max_iter_outer=30,\n        max_iter_inner=300,\n    )\n    unmixings.append(unmixing)\n    reconstruction_errors.append(reconstruction_error)\nunmixings = np.array(unmixings)\nprint(\"cumulated reconstruction error:\", np.array(reconstruction_errors).sum())\n\n\n# Compute the 2D representation of the unmixings living in the probability 2-simplex\nunmixings2D = np.zeros(shape=(N, 2))\nfor i, w in enumerate(unmixings):\n    unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0\n    unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0\nx = [0.0, 0.0]\ny = [1.0, 0.0]\nz = [0.5, np.sqrt(3) / 2.0]\nextremities = 
np.stack([x, y, z])\n\npl.figure(4, (4, 4))\npl.clf()\npl.title(\"Embedding space\", fontsize=14)\nfor cluster in range(nlabels):\n    start, end = Nc * cluster, Nc * (cluster + 1)\n    if cluster == 0:\n        pl.scatter(\n            unmixings2D[start:end, 0],\n            unmixings2D[start:end, 1],\n            c=\"C\" + str(cluster),\n            marker=\"o\",\n            s=40.0,\n            label=\"1 cluster\",\n        )\n    else:\n        pl.scatter(\n            unmixings2D[start:end, 0],\n            unmixings2D[start:end, 1],\n            c=\"C\" + str(cluster),\n            marker=\"o\",\n            s=40.0,\n            label=\"%s clusters\" % (cluster + 1),\n        )\npl.scatter(\n    extremities[:, 0], extremities[:, 1], c=\"black\", marker=\"x\", s=80.0, label=\"atoms\"\n)\npl.plot([x[0], y[0]], [x[1], y[1]], color=\"black\", linewidth=2.0)\npl.plot([x[0], z[0]], [x[1], z[1]], color=\"black\", linewidth=2.0)\npl.plot([y[0], z[0]], [y[1], z[1]], color=\"black\", linewidth=2.0)\npl.axis(\"off\")\npl.legend(fontsize=11)\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Endow the dataset with node features\nAll nodes of a graph receive the same feature, determined by the graph's label (its number of clusters):\n\n- 1 cluster --> 0 as node feature\n- 2 clusters --> 1 as node feature\n- 3 clusters --> 2 as node feature\n\nThe features are then one-hot encoded following these assignments.\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dataset_features = []\nfor i in range(len(dataset)):\n    n = dataset[i].shape[0]\n    F = np.zeros((n, 3))\n    if i < Nc:  # graph with 1 cluster\n        F[:, 0] = 1.0\n    elif i < 2 * Nc:  # graph with 2 clusters\n        F[:, 1] = 1.0\n    else:  # graph with 3 clusters\n        F[:, 2] = 1.0\n    dataset_features.append(F)\n\npl.figure(5, (12, 8))\npl.clf()\nfor idx_c, c in enumerate(clusters):\n    C = dataset[(c - 1) * Nc]  # sample with c clusters\n    F = dataset_features[(c - 1) * Nc]\n    colors = [\"C\" + str(np.argmax(F[i])) for i in range(F.shape[0])]\n    # get 2d position for nodes\n    x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - C)\n    pl.subplot(2, nlabels, c)\n    pl.title(\"(graph) sample from label \" + str(c), fontsize=14)\n    plot_graph(x, C, binary=True, color=colors, s=50)\n    pl.axis(\"off\")\n    pl.subplot(2, nlabels, nlabels + c)\n    pl.title(\"(matrix) sample from label %s \\n\" % c, fontsize=14)\n    pl.imshow(C, interpolation=\"nearest\")\n    pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate a Fused Gromov-Wasserstein dictionary from the dataset of attributed graphs\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "np.random.seed(0)\nps = [ot.unif(C.shape[0]) for C in dataset]\nD = 3  # 3 atoms in the dictionary\nnt = 6  # of 6 nodes each\nq = ot.unif(nt)\nreg = 0.001\nalpha = 0.5  # trade-off parameter between structure and feature information of Fused Gromov-Wasserstein\n\n\nCdict_FGW, Ydict_FGW, log = fused_gromov_wasserstein_dictionary_learning(\n    Cs=dataset,\n    Ys=dataset_features,\n    D=D,\n    nt=nt,\n    ps=ps,\n    q=q,\n    alpha=alpha,\n    epochs=10,\n    batch_size=16,\n    learning_rate_C=0.1,\n    learning_rate_Y=0.1,\n    reg=reg,\n    tol_outer=10 ** (-5),\n    tol_inner=10 ** (-5),\n    max_iter_outer=30,\n    max_iter_inner=300,\n    projection=\"nonnegative_symmetric\",\n    use_log=True,\n    use_adam_optimizer=True,\n    verbose=True,\n)\n# visualize loss evolution\npl.figure(6, (4, 3))\npl.clf()\npl.title(\"loss evolution by epoch\", fontsize=14)\npl.plot(log[\"loss_epochs\"])\npl.xlabel(\"epochs\", fontsize=12)\npl.ylabel(\"loss\", fontsize=12)\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the estimated dictionary atoms\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "pl.figure(7, (12, 8))\npl.clf()\nmax_features = Ydict_FGW.max()\nmin_features = Ydict_FGW.min()\n\nfor idx_atom, (Catom, Fatom) in enumerate(zip(Cdict_FGW, Ydict_FGW)):\n    scaled_atom = (Catom - Catom.min()) / (Catom.max() - Catom.min())\n    # scaled_F = 2 * (Fatom - min_features) / (max_features - min_features)\n    colors = [\"C%s\" % np.argmax(Fatom[i]) for i in range(Fatom.shape[0])]\n    x = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - scaled_atom)\n    pl.subplot(2, D, idx_atom + 1)\n    pl.title(\"(attributed graph) atom \" + str(idx_atom + 1), fontsize=14)\n    plot_graph(x, Catom / Catom.max(), binary=False, color=colors, s=100)\n    pl.axis(\"off\")\n    pl.subplot(2, D, D + idx_atom + 1)\n    pl.title(\"(matrix) atom %s \\n\" % (idx_atom + 1), fontsize=14)\n    pl.imshow(scaled_atom, interpolation=\"nearest\")\n    pl.colorbar()\n    pl.axis(\"off\")\npl.tight_layout()\npl.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization of the embedding space\n\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "unmixings = []\nreconstruction_errors = []\nfor i in range(len(dataset)):\n    C = dataset[i]\n    Y = dataset_features[i]\n    p = ot.unif(C.shape[0])\n    unmixing, Cembedded, Yembedded, OT, reconstruction_error = (\n        fused_gromov_wasserstein_linear_unmixing(\n            C,\n            Y,\n            Cdict_FGW,\n            Ydict_FGW,\n            p=p,\n            q=q,\n            alpha=alpha,\n            reg=reg,\n            tol_outer=10 ** (-6),\n            tol_inner=10 ** (-6),\n            max_iter_outer=30,\n            max_iter_inner=300,\n        )\n    )\n    unmixings.append(unmixing)\n    reconstruction_errors.append(reconstruction_error)\nunmixings = np.array(unmixings)\nprint(\"cumulated reconstruction error:\", np.array(reconstruction_errors).sum())\n\n# Visualize unmixings in the probability 2-simplex\nunmixings2D = np.zeros(shape=(N, 2))\nfor i, w in enumerate(unmixings):\n    unmixings2D[i, 0] = (2.0 * w[1] + w[2]) / 2.0\n    unmixings2D[i, 1] = (np.sqrt(3.0) * w[2]) / 2.0\nx = [0.0, 0.0]\ny = [1.0, 0.0]\nz = [0.5, np.sqrt(3) / 2.0]\nextremities = np.stack([x, y, z])\n\npl.figure(8, (4, 4))\npl.clf()\npl.title(\"Embedding space\", fontsize=14)\nfor cluster in range(nlabels):\n    start, end = Nc * cluster, Nc * (cluster + 1)\n    if cluster == 0:\n        pl.scatter(\n            unmixings2D[start:end, 0],\n            unmixings2D[start:end, 1],\n            c=\"C\" + str(cluster),\n            marker=\"o\",\n            s=40.0,\n            label=\"1 cluster\",\n        )\n    else:\n        pl.scatter(\n            unmixings2D[start:end, 0],\n            unmixings2D[start:end, 1],\n            c=\"C\" + str(cluster),\n            marker=\"o\",\n            s=40.0,\n            label=\"%s clusters\" % (cluster + 1),\n        )\n\npl.scatter(\n    extremities[:, 0], extremities[:, 1], c=\"black\", marker=\"x\", s=80.0, label=\"atoms\"\n)\npl.plot([x[0], y[0]], [x[1], y[1]], color=\"black\", linewidth=2.0)\npl.plot([x[0], z[0]], [x[1], z[1]], color=\"black\", linewidth=2.0)\npl.plot([y[0], z[0]], [y[1], z[1]], color=\"black\", linewidth=2.0)\npl.axis(\"off\")\npl.legend(fontsize=11)\npl.tight_layout()\npl.show()" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }