{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# Graph classification with Template Based Fused Gromov Wasserstein\n",
    "\n",
    "<div class=\"alert alert-info\"><h4>Note</h4><p>Example added in release: 0.9.1.</p></div>\n",
    "\n",
    "This example illustrates how to train a graph classification GNN based on the Template Fused Gromov Wasserstein (TFGW) layer as proposed in [53], and then visualizes the learned latent embeddings with TSNE.\n",
    "\n",
    "[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). Template based graph neural network with optimal transport distances. Advances in Neural Information Processing Systems, 35.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Author: Sonia Mazelet\n",
    "# R\u00e9mi Flamary\n",
    "#\n",
    "# License: MIT License\n",
    "\n",
    "# sphinx_gallery_thumbnail_number = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as pl\n",
    "import torch\n",
    "import networkx as nx\n",
    "from torch.utils.data import random_split\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.utils import to_networkx, one_hot\n",
    "from torch_geometric.utils import stochastic_blockmodel_graph as sbm\n",
    "from torch_geometric.data import Data as GraphData\n",
    "import torch.nn as nn\n",
    "from torch_geometric.nn import Linear, GCNConv\n",
    "from ot.gnn import TFGWPooling\n",
    "from sklearn.manifold import TSNE"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# parameters\n",
    "\n",
    "# We create 2 classes of stochastic block models (SBM) graphs with 1 block and 2 blocks respectively.\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "n_graphs = 50\n",
    "n_nodes = 10\n",
    "n_node_classes = 2\n",
    "\n",
    "# edge probabilities for the SBMs\n",
    "P1 = [[0.8]]\n",
    "P2 = [[0.9, 0.1], [0.1, 0.9]]\n",
    "\n",
    "# block sizes\n",
    "block_sizes1 = [n_nodes]\n",
    "block_sizes2 = [n_nodes // 2, n_nodes // 2]\n",
    "\n",
    "\n",
    "def block_features(block_sizes):\n",
    "    \"\"\"Return one-hot node features encoding the block index of each node.\n",
    "\n",
    "    Nodes are ordered block by block, so the result has ``sum(block_sizes)``\n",
    "    rows and ``n_node_classes`` columns.\n",
    "    \"\"\"\n",
    "    block_ids = torch.repeat_interleave(\n",
    "        torch.arange(len(block_sizes)), torch.tensor(block_sizes)\n",
    "    )\n",
    "    return one_hot(block_ids, num_classes=n_node_classes)\n",
    "\n",
    "\n",
    "# node features (derived from the block sizes so they always match the graphs)\n",
    "x1 = block_features(block_sizes1)\n",
    "x2 = block_features(block_sizes2)\n",
    "\n",
    "graphs1 = [\n",
    "    GraphData(x=x1, edge_index=sbm(block_sizes1, P1), y=torch.tensor([0]))\n",
    "    for _ in range(n_graphs)\n",
    "]\n",
    "graphs2 = [\n",
    "    GraphData(x=x2, edge_index=sbm(block_sizes2, P2), y=torch.tensor([1]))\n",
    "    for _ in range(n_graphs)\n",
    "]\n",
    "\n",
    "graphs = graphs1 + graphs2\n",
    "\n",
    "# split the data into train and test sets\n",
    "train_graphs, test_graphs = random_split(graphs, [n_graphs, n_graphs])\n",
    "\n",
    "train_loader = DataLoader(train_graphs, batch_size=10, shuffle=True)\n",
    "test_loader = DataLoader(test_graphs, batch_size=10, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot data\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# plot one graph of each class\n",
    "\n",
    "fontsize = 10\n",
    "\n",
    "pl.figure(0, figsize=(8, 2.5))\n",
    "pl.clf()\n",
    "pl.subplot(121)\n",
    "pl.axis(\"off\")\n",
    "pl.title(\"Graph of class 1\", fontsize=fontsize)\n",
    "G = to_networkx(graphs1[0], to_undirected=True)\n",
    "pos = nx.spring_layout(G, seed=0)\n",
    "nx.draw_networkx(G, pos, with_labels=False, node_color=\"tab:blue\")\n",
    "\n",
    "# node indices of the two blocks of the second SBM (nodes are ordered by block)\n",
    "first_block = list(range(block_sizes2[0]))\n",
    "second_block = list(range(block_sizes2[0], sum(block_sizes2)))\n",
    "\n",
    "pl.subplot(122)\n",
    "pl.axis(\"off\")\n",
    "pl.title(\"Graph of class 2\", fontsize=fontsize)\n",
    "G = to_networkx(graphs2[0], to_undirected=True)\n",
    "pos = nx.spring_layout(G, seed=0)\n",
    "nx.draw_networkx(\n",
    "    G, pos, with_labels=False, nodelist=first_block, node_color=\"tab:blue\"\n",
    ")\n",
    "nx.draw_networkx(\n",
    "    G, pos, with_labels=False, nodelist=second_block, node_color=\"tab:red\"\n",
    ")\n",
    "\n",
    "pl.tight_layout()\n",
    "pl.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pooling architecture using the TFGW layer\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class pooling_TFGW(nn.Module):\n",
    "    \"\"\"\n",
    "    Pooling architecture using the TFGW layer.\n",
    "\n",
    "    The model chains one GCN layer (node embeddings), the TFGW pooling layer\n",
    "    (graph-level latent embedding) and a linear layer (class scores).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        n_features,\n",
    "        n_templates,\n",
    "        n_template_nodes,\n",
    "        n_classes,\n",
    "        n_hidden_layers,\n",
    "        feature_init_mean=0.0,\n",
    "        feature_init_std=1.0,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Pooling architecture using the TFGW layer.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        n_features : int\n",
    "            Dimension of the input node features (input size of the GCN layer).\n",
    "        n_templates : int\n",
    "            Number of templates of the TFGW layer; also the input size of the\n",
    "            final linear layer.\n",
    "        n_template_nodes : int\n",
    "            Number of nodes of each template.\n",
    "        n_classes : int\n",
    "            Output size of the final linear layer.\n",
    "        n_hidden_layers : int\n",
    "            Output dimension of the GCN layer, i.e. the node feature dimension\n",
    "            given to the TFGW layer (despite its name, not a number of layers).\n",
    "        feature_init_mean : float, optional\n",
    "            Forwarded to ``TFGWPooling`` as ``feature_init_mean``.\n",
    "        feature_init_std : float, optional\n",
    "            Forwarded to ``TFGWPooling`` as ``feature_init_std``.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        self.n_templates = n_templates\n",
    "        self.n_template_nodes = n_template_nodes\n",
    "        self.n_hidden_layers = n_hidden_layers\n",
    "        self.n_features = n_features\n",
    "\n",
    "        self.conv = GCNConv(self.n_features, self.n_hidden_layers)\n",
    "\n",
    "        self.TFGW = TFGWPooling(\n",
    "            self.n_hidden_layers,\n",
    "            self.n_templates,\n",
    "            self.n_template_nodes,\n",
    "            feature_init_mean=feature_init_mean,\n",
    "            feature_init_std=feature_init_std,\n",
    "        )\n",
    "\n",
    "        self.linear = Linear(self.n_templates, n_classes)\n",
    "\n",
    "    def forward(self, x, edge_index, batch=None):\n",
    "        \"\"\"\n",
    "        Compute the class scores and the latent embedding of the input graphs.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        x : torch.Tensor\n",
    "            Node features.\n",
    "        edge_index : torch.Tensor\n",
    "            Edge indices of the graph(s).\n",
    "        batch : torch.Tensor, optional\n",
    "            Batch assignment vector forwarded to the TFGW layer.\n",
    "\n",
    "        Returns\n",
    "        -------\n",
    "        x : torch.Tensor\n",
    "            Output of the final linear layer (class scores).\n",
    "        x_latent : torch.Tensor\n",
    "            Output of the TFGW layer, kept for visualization.\n",
    "        \"\"\"\n",
    "        x = self.conv(x, edge_index)\n",
    "\n",
    "        x = self.TFGW(x, edge_index, batch)\n",
    "\n",
    "        x_latent = x  # save latent embeddings for visualization\n",
    "\n",
    "        x = self.linear(x)\n",
    "\n",
    "        return x, x_latent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Graph classification training\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "n_epochs = 25\n",
    "\n",
    "# store latent embeddings and classes for TSNE visualization\n",
    "embeddings_for_TSNE = []\n",
    "classes = []\n",
    "\n",
    "model = pooling_TFGW(\n",
    "    n_features=2,\n",
    "    n_templates=2,\n",
    "    n_template_nodes=2,\n",
    "    n_classes=2,\n",
    "    n_hidden_layers=2,\n",
    "    feature_init_mean=0.5,\n",
    "    feature_init_std=0.5,\n",
    ")\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "all_accuracy = []\n",
    "all_loss = []\n",
    "\n",
    "model.train()\n",
    "for epoch in range(n_epochs):\n",
    "    losses = []\n",
    "    accs = []\n",
    "\n",
    "    for data in train_loader:\n",
    "        # clear the gradients of the previous step, otherwise they accumulate\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        out, latent_embedding = model(data.x, data.edge_index, data.batch)\n",
    "        loss = criterion(out, data.y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pred = out.argmax(dim=1)\n",
    "        train_correct = pred == data.y\n",
    "        train_acc = int(train_correct.sum()) / data.num_graphs\n",
    "\n",
    "        accs.append(train_acc)\n",
    "        losses.append(loss.item())\n",
    "\n",
    "        # store last classes and embeddings for TSNE visualization\n",
    "        if epoch == n_epochs - 1:\n",
    "            # detach so the stored embeddings do not keep the autograd graph alive\n",
    "            embeddings_for_TSNE.append(latent_embedding.detach())\n",
    "            classes.append(data.y)\n",
    "\n",
    "    epoch_loss = torch.mean(torch.tensor(losses)).item()\n",
    "    epoch_acc = torch.mean(torch.tensor(accs)).item()\n",
    "\n",
    "    print(\n",
    "        f\"Epoch: {epoch:03d}, Loss: {epoch_loss:.4f},Train Accuracy: {epoch_acc:.4f}\"\n",
    "    )\n",
    "\n",
    "    all_accuracy.append(epoch_acc)\n",
    "    all_loss.append(epoch_loss)\n",
    "\n",
    "\n",
    "pl.figure(1, figsize=(8, 2.5))\n",
    "pl.clf()\n",
    "pl.subplot(121)\n",
    "pl.plot(all_loss)\n",
    "pl.xlabel(\"epochs\")\n",
    "pl.title(\"Loss\")\n",
    "\n",
    "pl.subplot(122)\n",
    "pl.plot(all_accuracy)\n",
    "pl.xlabel(\"epochs\")\n",
    "pl.title(\"Accuracy\")\n",
    "\n",
    "pl.tight_layout()\n",
    "pl.show()\n",
    "\n",
    "# Test\n",
    "\n",
    "model.eval()\n",
    "test_accs = []\n",
    "\n",
    "for data in test_loader:\n",
    "    out, latent_embedding = model(data.x, data.edge_index, data.batch)\n",
    "    pred = out.argmax(dim=1)\n",
    "    test_correct = pred == data.y\n",
    "    test_acc = int(test_correct.sum()) / data.num_graphs\n",
    "    test_accs.append(test_acc)\n",
    "    embeddings_for_TSNE.append(latent_embedding.detach())\n",
    "    classes.append(data.y)\n",
    "\n",
    "classes = torch.hstack(classes)\n",
    "\n",
    "# average over all test batches (not only the last one)\n",
    "print(f\"Test Accuracy: {torch.mean(torch.tensor(test_accs)):.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TSNE visualization of graph classification\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# select a random subset of the stored embeddings (without repetition)\n",
    "# for TSNE visualization\n",
    "n_tsne_points = 60\n",
    "indices = torch.randperm(classes.numel())[:n_tsne_points]\n",
    "\n",
    "# index with torch first, then convert to numpy for scikit-learn\n",
    "latent_embeddings = torch.vstack(embeddings_for_TSNE).detach()[indices].numpy()\n",
    "\n",
    "TSNE_embeddings = TSNE(n_components=2, perplexity=20, random_state=1).fit_transform(\n",
    "    latent_embeddings\n",
    ")\n",
    "\n",
    "class_0 = (classes[indices] == 0).numpy()\n",
    "class_1 = (classes[indices] == 1).numpy()\n",
    "\n",
    "TSNE_embeddings_0 = TSNE_embeddings[class_0, :]\n",
    "TSNE_embeddings_1 = TSNE_embeddings[class_1, :]\n",
    "\n",
    "pl.figure(2, figsize=(6, 2.5))\n",
    "pl.scatter(\n",
    "    TSNE_embeddings_0[:, 0],\n",
    "    TSNE_embeddings_0[:, 1],\n",
    "    alpha=0.5,\n",
    "    marker=\"o\",\n",
    "    label=\"class 1\",\n",
    ")\n",
    "pl.scatter(\n",
    "    TSNE_embeddings_1[:, 0],\n",
    "    TSNE_embeddings_1[:, 1],\n",
    "    alpha=0.5,\n",
    "    marker=\"o\",\n",
    "    label=\"class 2\",\n",
    ")\n",
    "pl.legend()\n",
    "pl.title(\"TSNE in the latent space after training\")\n",
    "pl.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}