{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Supervised graph classification with GCN\n" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to train a graph classification model in a supervised setting using graph convolutional layers followed by a mean pooling layer as well as any number of fully connected layers.\n", "\n", "The graph convolutional classification model architecture is based on the one proposed in [1] (see Figure 5 in [1]) using the graph convolutional layers from [2]. This demo differs from [1] in the dataset, MUTAG, used here; MUTAG is a collection of static graphs representing chemical compounds with each graph associated with a binary label. Furthermore, none of the graph convolutional layers in our model utilise an attention head as proposed in [1].\n", "\n", "Evaluation data for graph kernel-based approaches shown in the very last cell in this notebook are taken from [3].\n", "\n", "**References**\n", "\n", "[1] Fake News Detection on Social Media using Geometric Deep Learning, F. Monti, F. Frasca, D. Eynard, D. Mannion, and M. M. Bronstein, ICLR 2019. ([link](https://arxiv.org/abs/1902.06673))\n", "\n", "[2] Semi-supervised Classification with Graph Convolutional Networks, T. N. Kipf and M. Welling, ICLR 2017. ([link](https://arxiv.org/abs/1609.02907))\n", "\n", "[3] An End-to-End Deep Learning Architecture for Graph Classification, M. Zhang, Z. Cui, M. Neumann, Y. Chen, AAAI-18. 
([link](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewPaper/17146))" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "outputs": [], "source": [ "# install StellarGraph if running on Google Colab\n", "import sys\n", "if 'google.colab' in sys.modules:\n", " %pip install -q stellargraph[demos]==1.2.1" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [ "VersionCheck" ] }, "outputs": [], "source": [ "# verify that we're using the correct version of StellarGraph for this notebook\n", "import stellargraph as sg\n", "\n", "try:\n", " sg.utils.validate_notebook_version(\"1.2.1\")\n", "except AttributeError:\n", " raise ValueError(\n", " f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed. Please see .\"\n", " ) from None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "import stellargraph as sg\n", "from stellargraph.mapper import PaddedGraphGenerator\n", "from stellargraph.layer import GCNSupervisedGraphClassification\n", "from stellargraph import StellarGraph\n", "\n", "from stellargraph import datasets\n", "\n", "from sklearn import model_selection\n", "from IPython.display import display, HTML\n", "\n", "from tensorflow.keras import Model\n", "from tensorflow.keras.optimizers import Adam\n", "from tensorflow.keras.layers import Dense\n", "from tensorflow.keras.losses import binary_crossentropy\n", "from tensorflow.keras.callbacks import EarlyStopping\n", "import tensorflow as tf\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import the data" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "DataLoadingLinks" ] }, "source": [ "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be 
loaded.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "DataLoading" ] }, "outputs": [ { "data": { "text/html": [ "Each graph represents a chemical compound and graph labels represent 'their mutagenic effect on a specific gram negative bacterium.'The dataset includes 188 graphs with 18 nodes and 20 edges on average for each graph. Graph nodes have 7 labels and each graph is labelled as belonging to 1 of 2 classes." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = datasets.MUTAG()\n", "display(HTML(dataset.description))\n", "graphs, graph_labels = dataset.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `graphs` value is a list of many `StellarGraph` instances, each of which has a few node features:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 17, Edges: 38\n", "\n", " Node types:\n", " default: [17]\n", " Features: float32 vector, length 7\n", " Edge types: default-default->default\n", "\n", " Edge types:\n", " default-default->default: [38]\n", " Weights: all 1 (default)\n", " Features: none\n" ] } ], "source": [ "print(graphs[0].info())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 13, Edges: 28\n", "\n", " Node types:\n", " default: [13]\n", " Features: float32 vector, length 7\n", " Edge types: default-default->default\n", "\n", " Edge types:\n", " default-default->default: [28]\n", " Weights: all 1 (default)\n", " Features: none\n" ] } ], "source": [ "print(graphs[1].info())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Summary statistics of the sizes of the graphs:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { 
"text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
       nodes  edges
count  188.0  188.0
mean    17.9   39.6
std      4.6   11.4
min     10.0   20.0
25%     14.0   28.0
50%     17.5   38.0
75%     22.0   50.0
max     28.0   66.0
\n", "
" ], "text/plain": [ " nodes edges\n", "count 188.0 188.0\n", "mean 17.9 39.6\n", "std 4.6 11.4\n", "min 10.0 20.0\n", "25% 14.0 28.0\n", "50% 17.5 38.0\n", "75% 22.0 50.0\n", "max 28.0 66.0" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summary = pd.DataFrame(\n", " [(g.number_of_nodes(), g.number_of_edges()) for g in graphs],\n", " columns=[\"nodes\", \"edges\"],\n", ")\n", "summary.describe().round(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The labels are `1` or `-1`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
    label
1     125
-1     63
\n", "
" ], "text/plain": [ " label\n", "1 125\n", "-1 63" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph_labels.value_counts().to_frame()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "graph_labels = pd.get_dummies(graph_labels, drop_first=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare graph generator\n", "\n", "To feed data to the `tf.keras` model that we will create later, we need a data generator. For supervised graph classification, we create an instance of `StellarGraph`'s `PaddedGraphGenerator` class. Note that `graphs` is a list of `StellarGraph` graph objects." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "generator = PaddedGraphGenerator(graphs=graphs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the Keras graph classification model\n", "\n", "We are now ready to create a `tf.keras` graph classification model using `StellarGraph`'s `GCNSupervisedGraphClassification` class together with standard `tf.keras` layers, e.g., `Dense`.\n", "\n", "The input is the graph represented by its adjacency and node feature matrices. The first two layers are graph convolutional layers as in [2], each with 64 units and `relu` activations. The next layer is a mean pooling layer, where the learned node representations are averaged to create a graph representation. The graph representation is input to two fully connected layers with 32 and 16 units, respectively, and `relu` activations. The last layer is the output layer with a single unit and `sigmoid` activation." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](graph_classification_architecture.png)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "def create_graph_classification_model(generator):\n", "    gc_model = GCNSupervisedGraphClassification(\n", "        layer_sizes=[64, 64],\n", "        activations=[\"relu\", \"relu\"],\n", "        generator=generator,\n", "        dropout=0.5,\n", "    )\n", "    x_inp, x_out = gc_model.in_out_tensors()\n", "    predictions = Dense(units=32, activation=\"relu\")(x_out)\n", "    predictions = Dense(units=16, activation=\"relu\")(predictions)\n", "    predictions = Dense(units=1, activation=\"sigmoid\")(predictions)\n", "\n", "    # Let's create the Keras model and prepare it for training\n", "    model = Model(inputs=x_inp, outputs=predictions)\n", "    model.compile(optimizer=Adam(0.005), loss=binary_crossentropy, metrics=[\"acc\"])\n", "\n", "    return model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model\n", "\n", "We can now train the model using the model's `fit` method. First, we specify some important training parameters: the maximum number of training epochs, the number of folds for cross validation, and the number of times to repeat cross validation." ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "epochs = 200 # maximum number of training epochs\n", "folds = 10 # the number of folds for k-fold cross validation\n", "n_repeats = 5 # the number of repeats for repeated k-fold cross validation" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "es = EarlyStopping(\n", "    monitor=\"val_loss\", min_delta=0, patience=25, restore_best_weights=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function `train_fold` is used to train a graph classification model for a single fold of the data." 
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def train_fold(model, train_gen, test_gen, es, epochs):\n", "    history = model.fit(\n", "        train_gen, epochs=epochs, validation_data=test_gen, verbose=0, callbacks=[es],\n", "    )\n", "    # calculate performance on the test data and return along with history\n", "    test_metrics = model.evaluate(test_gen, verbose=0)\n", "    test_acc = test_metrics[model.metrics_names.index(\"acc\")]\n", "\n", "    return history, test_acc" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def get_generators(train_index, test_index, graph_labels, batch_size):\n", "    train_gen = generator.flow(\n", "        train_index, targets=graph_labels.iloc[train_index].values, batch_size=batch_size\n", "    )\n", "    test_gen = generator.flow(\n", "        test_index, targets=graph_labels.iloc[test_index].values, batch_size=batch_size\n", "    )\n", "\n", "    return train_gen, test_gen" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The code below puts all the above functionality together in a training loop for repeated k-fold cross-validation where the number of folds is 10, `folds=10`; that is, we do 10-fold cross validation `n_repeats` times, where `n_repeats=5`.\n", "\n", "**Note**: The code below may take a long time to run depending on the value set for `n_repeats`. The larger that value, the longer it takes, since for each repeat we train and evaluate 10 graph classification models, one for each fold of the data. For progress updates, we recommend that you set `verbose=2` in the call to the `fit` method inside the `train_fold` function defined above." 
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training and evaluating on fold 1 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 2 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 3 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 4 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 5 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 6 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 7 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 8 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 9 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 10 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 11 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 12 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 13 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 14 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 15 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 16 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 17 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 18 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 19 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 20 out of 50...\n", " ['...']\n", " 
['...']\n", " ['...']\n", "Training and evaluating on fold 21 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 22 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 23 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 24 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 25 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 26 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 27 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 28 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 29 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 30 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 31 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 32 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 33 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 34 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 35 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 36 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 37 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 38 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 39 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 40 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 41 out of 50...\n", " ['...']\n", " 
['...']\n", " ['...']\n", "Training and evaluating on fold 42 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 43 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 44 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 45 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 46 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 47 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 48 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 49 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n", "Training and evaluating on fold 50 out of 50...\n", " ['...']\n", " ['...']\n", " ['...']\n" ] } ], "source": [ "test_accs = []\n", "\n", "stratified_folds = model_selection.RepeatedStratifiedKFold(\n", " n_splits=folds, n_repeats=n_repeats\n", ").split(graph_labels, graph_labels)\n", "\n", "for i, (train_index, test_index) in enumerate(stratified_folds):\n", " print(f\"Training and evaluating on fold {i+1} out of {folds * n_repeats}...\")\n", " train_gen, test_gen = get_generators(\n", " train_index, test_index, graph_labels, batch_size=30\n", " )\n", "\n", " model = create_graph_classification_model(generator)\n", "\n", " history, acc = train_fold(model, train_gen, test_gen, es, epochs)\n", "\n", " test_accs.append(acc)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Accuracy over all folds mean: 76.4% and std: 6.7%\n" ] } ], "source": [ "print(\n", " f\"Accuracy over all folds mean: {np.mean(test_accs)*100:.3}% and std: {np.std(test_accs)*100:.2}%\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we plot a histogram of the accuracy of all `n_repeats x folds` models trained 
(50 in total)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'Count')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(8, 6))\n", "plt.hist(test_accs)\n", "plt.xlabel(\"Accuracy\")\n", "plt.ylabel(\"Count\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The histogram above illustrates how difficult it is to train a good model on the MUTAG dataset, for the following reasons:\n", "- the amount of available data is small, i.e., only 188 graphs\n", "- the amount of validation data is small, since each fold holds out only about 19 graphs for validation\n", "- the data are unbalanced, since the majority class is roughly twice as prevalent as the minority class (125 vs. 63 graphs)\n", "\n", "Given the above, performance as estimated using repeated 10-fold cross validation displays high variance across folds, but the average is good for a straightforward application of graph convolutional neural networks to supervised graph classification. The high variance is likely the result of the small dataset size.\n", "\n", "Generally, performance is a bit lower than the state of the art in recent literature. However, we have not tuned the model for the best possible performance, so some improvement over the current baseline may be attainable.\n", "\n", "When compared with graph kernel-based approaches, our straightforward GCN with mean pooling graph classification model is competitive, with the WL kernel being the exception.\n", "\n", "For comparison, some performance numbers repeated from [3] for graph kernel-based approaches are:\n", "- Graphlet Kernel (GK): $81.39\\pm1.74$\n", "- Random Walk Kernel (RW): $79.17\\pm2.07$\n", "- Propagation Kernel (PK): $76.00\\pm2.69$\n", "- Weisfeiler-Lehman Subtree Kernel (WL): $84.11\\pm1.91$" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }