{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Optimizing the Gromov-Wasserstein distance with PyTorch\n\n

Note: Example added in release: 0.8.0.

\n\nIn this example, we use the PyTorch backend to optimize the Gromov-Wasserstein\n(GW) loss between two graphs expressed as empirical distributions.\n\nIn the first part, we optimize the weights on the nodes of a simple template\ngraph so that it minimizes the GW loss with respect to a given Stochastic Block Model (SBM) graph.\nWe can see that this recovers the proportion of classes in the SBM\nand allows for an accurate clustering of the nodes using the GW optimal plan.\n\nIn the second part, we simultaneously optimize the weights and the structure of\nthe template graph, which allows us to perform graph compression and to recover\nother properties of the SBM.\n\nThe backend uses the gradients expressed in [38] to optimize the\nweights.\n\n[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph\nDictionary Learning, International Conference on Machine Learning (ICML), 2021.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Author: R\u00e9mi Flamary\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 3\n\nfrom sklearn.manifold import MDS\nimport numpy as np\nimport matplotlib.pylab as pl\nimport torch\n\nimport ot\nfrom ot.gromov import gromov_wasserstein2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph generation\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rng = np.random.RandomState(42)\n\n\ndef get_sbm(n, nc, ratio, P):\n    nbpc = np.round(n * ratio).astype(int)\n    n = np.sum(nbpc)\n    C = np.zeros((n, n))\n    for c1 in range(nc):\n        for c2 in range(c1 + 1):\n            if c1 == c2:\n                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):\n                    for j in range(np.sum(nbpc[:c2]), i):\n                        if rng.rand() <= P[c1, c2]:\n                            C[i, j] = 1\n            else:\n                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):\n                    for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])):\n                        if rng.rand() <= P[c1, c2]:\n                            C[i, j] = 1\n\n    return C + C.T\n\n\nn = 100\nnc = 3\nratio = np.array([0.5, 0.3, 0.2])\nP = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))\nC1 = get_sbm(n, nc, ratio, P)\n\n# get 2d position for nodes\nx1 = MDS(dissimilarity=\"precomputed\", random_state=0).fit_transform(1 - C1)\n\n\ndef plot_graph(x, C, color=\"C0\", s=None):\n    for j in range(C.shape[0]):\n        for i in range(j):\n            if C[i, j] > 0:\n                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color=\"k\")\n    pl.scatter(\n        x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors=\"k\", cmap=\"tab10\", vmax=9\n    )\n\n\npl.figure(1, (10, 5))\npl.clf()\npl.subplot(1, 2, 1)\nplot_graph(x1, C1, color=\"C0\")\npl.title(\"SBM Graph\")\npl.axis(\"off\")\npl.subplot(1, 2, 2)\npl.imshow(C1, interpolation=\"nearest\")\npl.title(\"Adjacency matrix\")\npl.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Optimizing GW w.r.t. the weights on a template structure\nThe adjacency matrix C1 is approximately block diagonal with 3 blocks. 
We want to\noptimize the weights of a simple template C0=eye(3) and see if we can\nrecover the proportion of classes from the SBM (up to a permutation).\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "C0 = np.eye(3)\n\n\ndef min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):\n    \"\"\"Solve min_a GW(C1, C2, a, a2) by projected gradient descent\"\"\"\n\n    # use pyTorch for our data\n    C1_torch = torch.tensor(C1)\n    C2_torch = torch.tensor(C2)\n\n    a0 = rng.rand(C1.shape[0])  # random_init\n    a0 /= a0.sum()  # on simplex\n    a1_torch = torch.tensor(a0).requires_grad_(True)\n    a2_torch = torch.tensor(a2)\n\n    loss_iter = []\n\n    for i in range(nb_iter_max):\n        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)\n\n        loss_iter.append(loss.clone().detach().cpu().numpy())\n        loss.backward()\n\n        # print(\"{:03d} | {}\".format(i, loss_iter[-1]))\n\n        # performs a step of projected gradient descent\n        with torch.no_grad():\n            grad = a1_torch.grad\n            a1_torch -= grad * lr  # step\n            a1_torch.grad.zero_()\n            a1_torch.data = ot.utils.proj_simplex(a1_torch)\n\n    a1 = a1_torch.clone().detach().cpu().numpy()\n\n    return a1, loss_iter\n\n\na0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)\n\npl.figure(2)\npl.plot(loss_iter0)\npl.title(\"Loss along iterations\")\n\nprint(\"Estimated weights : \", a0_est)\nprint(\"True proportions : \", ratio)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loss curve shows that the optimization has converged, and we recover the\nproportions of the different classes in the SBM graph up to a permutation.\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Community clustering with uniform and estimated weights\nThe GW OT plan can be used to cluster the nodes of a graph:\nwhen computing GW against a simple template like C0, each node of\nthe original graph is labeled with the index of the template node that receives\nthe most of its mass.\n\nWe show here the result of such a clustering when using uniform weights on\nthe template C0 and when using the optimal weights previously estimated.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))\nlabel_unif = T_unif.argmax(1)\n\nT_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)\nlabel_est = T_est.argmax(1)\n\npl.figure(3, (10, 5))\npl.clf()\npl.subplot(1, 2, 1)\nplot_graph(x1, C1, color=label_unif)\npl.title(\"Graph clustering unif. weights\")\npl.axis(\"off\")\npl.subplot(1, 2, 2)\nplot_graph(x1, C1, color=label_est)\npl.title(\"Graph clustering est. weights\")\npl.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph compression with GW\nNow we optimize both the weights and the structure of a small graph so that it\nminimizes the GW distance with respect to our data graph. This can be seen as graph\ncompression, but it can also recover important properties of an SBM, such\nas its class proportions and its matrix of link probabilities between\nclasses.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):\n    \"\"\"Solve min_{a, C1} GW(C1, C2, a, a2) by projected gradient descent\"\"\"\n\n    # use pyTorch for our data\n\n    C2_torch = torch.tensor(C2)\n    a2_torch = torch.tensor(a2)\n\n    a0 = rng.rand(nb_nodes)  # random_init\n    a0 /= a0.sum()  # on simplex\n    a1_torch = torch.tensor(a0).requires_grad_(True)\n    C0 = np.eye(nb_nodes)\n    C1_torch = torch.tensor(C0).requires_grad_(True)\n\n    loss_iter = []\n\n    for i in range(nb_iter_max):\n        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)\n\n        loss_iter.append(loss.clone().detach().cpu().numpy())\n        loss.backward()\n\n        # print(\"{:03d} | {}\".format(i, loss_iter[-1]))\n\n        # performs a step of projected gradient descent\n        with torch.no_grad():\n            # weights: gradient step, then projection onto the simplex\n            grad = a1_torch.grad\n            a1_torch -= grad * lr  # step\n            a1_torch.grad.zero_()\n            a1_torch.data = ot.utils.proj_simplex(a1_torch)\n\n            # structure: gradient step, then clamping of the entries to [0, 1]\n            grad = C1_torch.grad\n            C1_torch -= grad * lr  # step\n            C1_torch.grad.zero_()\n            C1_torch.data = torch.clamp(C1_torch, 0, 1)\n\n    a1 = a1_torch.clone().detach().cpu().numpy()\n    C1 = C1_torch.clone().detach().cpu().numpy()\n\n    return a1, C1, loss_iter\n\n\nnb_nodes = 3\na0_est2, C0_est2, loss_iter2 = graph_compression_gw(\n    nb_nodes, C1, ot.unif(n), nb_iter_max=100, lr=5e-2\n)\n\npl.figure(4)\npl.plot(loss_iter2)\npl.title(\"Loss along iterations\")\n\n\nprint(\"Estimated weights : \", a0_est2)\nprint(\"True proportions : \", ratio)\n\npl.figure(6, (10, 3.5))\npl.clf()\npl.subplot(1, 2, 1)\npl.imshow(P, vmin=0, vmax=1)\npl.title(\"True SBM P matrix\")\npl.subplot(1, 2, 2)\npl.imshow(C0_est2, vmin=0, vmax=1)\npl.title(\"Estimated C0 matrix\")\npl.colorbar()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }