{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Supervised graph classification with Deep Graph CNN" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook demonstrates how to train a graph classification model in a supervised setting using the Deep Graph Convolutional Neural Network (DGCNN) [1] algorithm.\n", "\n", "In supervised graph classification, we are given a collection of graphs, each with an attached categorical label. For example, the PROTEINS dataset we use for this demo is a collection of graphs, each representing a protein and labelled as either an enzyme or a non-enzyme. Our goal is to train a machine learning model that uses the graph structure of the data together with any information available for the graph's nodes, e.g., the node attributes provided with PROTEINS, to predict the correct label for a previously unseen graph, that is, a graph that was not used for training or validating the model.\n", "\n", "The DGCNN architecture was proposed in [1] (see Figure 5 in [1]) using the graph convolutional layers from [2] but with a modified propagation rule (see [1] for details). DGCNN introduces a new `SortPooling` layer that generates a fixed-size representation (also known as an embedding) for each graph from the node representations learned by a stack of graph convolutional layers: it sorts the nodes by their learned features and keeps the first `k` of them, zero-padding graphs with fewer than `k` nodes, so that graphs of different sizes map to tensors of the same shape. The output of the `SortPooling` layer is then used as input to one-dimensional convolutional, max pooling, and dense layers that learn graph-level features suitable for predicting graph labels.\n", "\n", "**References**\n", "\n", "[1] An End-to-End Deep Learning Architecture for Graph Classification, M. Zhang, Z. Cui, M. Neumann, Y. Chen, AAAI-18. ([link](https://www.aaai.org/ocs/index.php/AAAI/AAAI18/paper/viewPaper/17146))\n", "\n", "[2] Semi-supervised Classification with Graph Convolutional Networks, T. N. Kipf and M. Welling, ICLR 2017. 
([link](https://arxiv.org/abs/1609.02907))\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "outputs": [], "source": [ "# install StellarGraph if running on Google Colab\n", "import sys\n", "if 'google.colab' in sys.modules:\n", " %pip install -q stellargraph[demos]==1.2.1" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [ "VersionCheck" ] }, "outputs": [], "source": [ "# verify that we're using the correct version of StellarGraph for this notebook\n", "import stellargraph as sg\n", "\n", "try:\n", " sg.utils.validate_notebook_version(\"1.2.1\")\n", "except AttributeError:\n", " raise ValueError(\n", " f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed. Please see .\"\n", " ) from None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "\n", "import stellargraph as sg\n", "from stellargraph.mapper import PaddedGraphGenerator\n", "from stellargraph.layer import DeepGraphCNN\n", "from stellargraph import StellarGraph\n", "\n", "from stellargraph import datasets\n", "\n", "from sklearn import model_selection\n", "from IPython.display import display, HTML\n", "\n", "from tensorflow.keras import Model\n", "from tensorflow.keras.optimizers import Adam\n", "from tensorflow.keras.layers import Dense, Conv1D, MaxPool1D, Dropout, Flatten\n", "from tensorflow.keras.losses import binary_crossentropy\n", "import tensorflow as tf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import the data" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "DataLoadingLinks" ] }, "source": [ "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "DataLoading" ] }, "outputs": [ { 
"data": { "text/html": [ "Each graph represents a protein and graph labels represent whether they are are enzymes or non-enzymes. The dataset includes 1113 graphs with 39 nodes and 73 edges on average for each graph. Graph nodes have 4 attributes (including a one-hot encoding of their label), and each graph is labelled as belonging to 1 of 2 classes." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = datasets.PROTEINS()\n", "display(HTML(dataset.description))\n", "graphs, graph_labels = dataset.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `graphs` value is a list of many `StellarGraph` instances, each of which has a few node features:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 42, Edges: 162\n", "\n", " Node types:\n", " default: [42]\n", " Features: float32 vector, length 4\n", " Edge types: default-default->default\n", "\n", " Edge types:\n", " default-default->default: [162]\n", " Weights: all 1 (default)\n", " Features: none\n" ] } ], "source": [ "print(graphs[0].info())" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 27, Edges: 92\n", "\n", " Node types:\n", " default: [27]\n", " Features: float32 vector, length 4\n", " Edge types: default-default->default\n", "\n", " Edge types:\n", " default-default->default: [92]\n", " Weights: all 1 (default)\n", " Features: none\n" ] } ], "source": [ "print(graphs[1].info())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Summary statistics of the sizes of the graphs:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table><thead><tr><th></th><th>nodes</th><th>edges</th></tr></thead><tbody><tr><th>count</th><td>1113.0</td><td>1113.0</td></tr><tr><th>mean</th><td>39.1</td><td>145.6</td></tr><tr><th>std</th><td>45.8</td><td>169.3</td></tr><tr><th>min</th><td>4.0</td><td>10.0</td></tr><tr><th>25%</th><td>15.0</td><td>56.0</td></tr><tr><th>50%</th><td>26.0</td><td>98.0</td></tr><tr><th>75%</th><td>45.0</td><td>174.0</td></tr><tr><th>max</th><td>620.0</td><td>2098.0</td></tr></tbody></table>
" ], "text/plain": [ " nodes edges\n", "count 1113.0 1113.0\n", "mean 39.1 145.6\n", "std 45.8 169.3\n", "min 4.0 10.0\n", "25% 15.0 56.0\n", "50% 26.0 98.0\n", "75% 45.0 174.0\n", "max 620.0 2098.0" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "summary = pd.DataFrame(\n", " [(g.number_of_nodes(), g.number_of_edges()) for g in graphs],\n", " columns=[\"nodes\", \"edges\"],\n", ")\n", "summary.describe().round(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The labels are `1` or `2`:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table><thead><tr><th></th><th>label</th></tr></thead><tbody><tr><th>1</th><td>663</td></tr><tr><th>2</th><td>450</td></tr></tbody></table>
" ], "text/plain": [ " label\n", "1 663\n", "2 450" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "graph_labels.value_counts().to_frame()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# one-hot encode the labels and drop the first column, leaving a single 0/1 column (1 for label 2)\n", "graph_labels = pd.get_dummies(graph_labels, drop_first=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare graph generator\n", "\n", "To feed data to the `tf.keras` model that we will create later, we need a data generator. For supervised graph classification, we create an instance of `StellarGraph`'s `PaddedGraphGenerator` class, which resolves differences in the number of nodes by padding the adjacency and node feature matrices of each batch." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "generator = PaddedGraphGenerator(graphs=graphs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the Keras graph classification model\n", "\n", "We are now ready to create a `tf.keras` graph classification model using `StellarGraph`'s `DeepGraphCNN` class together with the standard `tf.keras` layers `Conv1D`, `MaxPool1D`, `Flatten`, `Dropout`, and `Dense`.\n", "\n", "The model's input is the graph, represented by its adjacency matrix and node feature matrix. The first four layers are graph convolutional layers as in [2], but they use the adjacency normalisation from [1], $D^{-1}A$, where $A$ is the adjacency matrix with self loops and $D$ is the corresponding degree matrix. The graph convolutional layers have 32, 32, 32, and 1 units, respectively, all with `tanh` activations. Their outputs are concatenated, and the `SortPooling` layer sorts the nodes by the output of the final 1-unit layer and keeps the top `k` nodes, giving a fixed-size output for every graph.\n", "\n", "Next comes a one-dimensional convolutional layer, `Conv1D`, followed by a max pooling layer, `MaxPool1D`. A second `Conv1D` layer follows, and its flattened output feeds two `Dense` layers, the second of which is used for binary classification. The convolutional and dense layers use `relu` activations, except for the last dense layer, which uses `sigmoid` for classification. As described in [1], we add a `Dropout` layer after the first `Dense` layer."
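 ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The two graph-specific steps described above, the $D^{-1}A$ propagation rule and `SortPooling`, can be illustrated with a short NumPy sketch. It is a simplified, non-authoritative illustration, not StellarGraph's implementation: the helpers `dgcnn_propagate`, `sort_pooling` and `dgcnn_embedding` are hypothetical names, the 5-node toy graph and random weights are made up, biases are omitted (as with `bias=False` in this notebook), and ties in the sort are broken by node order rather than by the preceding channels as described in [1]. It shows how a graph of any size becomes a vector of `k * sum(layer_sizes)` values, which is why the first `Conv1D` layer uses `kernel_size` and `strides` equal to `sum(layer_sizes)`: each window covers exactly one sorted node.

```python
import numpy as np


def dgcnn_propagate(adj, features, weights):
    # One graph convolution with the D^{-1}A rule: a_tilde is the A of the text
    # (adjacency with self loops); each row is divided by its degree, then X W
    # is applied and tanh is taken elementwise.
    a_tilde = adj + np.eye(adj.shape[0])
    d_inv = 1.0 / a_tilde.sum(axis=1)
    return np.tanh(d_inv[:, None] * (a_tilde @ features @ weights))


def sort_pooling(node_reps, k):
    # Sort nodes by their last channel (the 1-unit final layer), largest first.
    # Ties keep the original node order (a simplification of [1]).
    order = np.argsort(-node_reps[:, -1], kind='stable')
    kept = node_reps[order][:k]
    # Zero-pad graphs with fewer than k nodes so every graph yields k rows.
    pad = np.zeros((k - kept.shape[0], node_reps.shape[1]))
    return np.vstack([kept, pad])


def dgcnn_embedding(adj, features, weight_list, k):
    # Run the stacked graph convolutions and concatenate every layer's output,
    # giving each node sum(layer_sizes) channels.
    h = features
    outputs = []
    for weights in weight_list:
        h = dgcnn_propagate(adj, h, weights)
        outputs.append(h)
    node_reps = np.concatenate(outputs, axis=1)
    # Flatten the (k, sum(layer_sizes)) SortPooling output node by node.
    return sort_pooling(node_reps, k).reshape(-1)


# Hypothetical toy input: a 5-node path graph with 4 random features per node.
rng = np.random.default_rng(0)
layer_sizes = [32, 32, 32, 1]
adj = np.zeros((5, 5))
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
features = rng.normal(size=(5, 4))
dims = [features.shape[1]] + layer_sizes
weight_list = [
    rng.normal(scale=0.1, size=(dims[i], dims[i + 1]))
    for i in range(len(layer_sizes))
]
embedding = dgcnn_embedding(adj, features, weight_list, k=35)
print(embedding.shape)  # (3395,), i.e. 35 * 97
```

With `k = 35` and `layer_sizes = [32, 32, 32, 1]` as in this notebook, every graph maps to 35 × 97 = 3395 values: graphs with fewer than `k` nodes are zero-padded (as in this toy example) and larger graphs are truncated to their top `k` nodes."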
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "![](dgcnn_architecture.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we create the base DGCNN model that includes the graph convolutional and `SortPooling` layers." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "k = 35 # the number of nodes kept by SortPooling, i.e. the number of rows of its output tensor\n", "layer_sizes = [32, 32, 32, 1]\n", "\n", "dgcnn_model = DeepGraphCNN(\n", " layer_sizes=layer_sizes,\n", " activations=[\"tanh\", \"tanh\", \"tanh\", \"tanh\"],\n", " k=k,\n", " bias=False,\n", " generator=generator,\n", ")\n", "x_inp, x_out = dgcnn_model.in_out_tensors()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we add the convolutional, max pooling, and dense layers." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# kernel_size = strides = sum(layer_sizes), so each window covers exactly one sorted node\n", "x_out = Conv1D(\n", " filters=16,\n", " kernel_size=sum(layer_sizes),\n", " strides=sum(layer_sizes),\n", " activation=\"relu\",\n", ")(x_out)\n", "x_out = MaxPool1D(pool_size=2)(x_out)\n", "\n", "x_out = Conv1D(filters=32, kernel_size=5, strides=1, activation=\"relu\")(x_out)\n", "\n", "x_out = Flatten()(x_out)\n", "\n", "x_out = Dense(units=128, activation=\"relu\")(x_out)\n", "x_out = Dropout(rate=0.5)(x_out)\n", "\n", "predictions = Dense(units=1, activation=\"sigmoid\")(x_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we create the `Keras` model and prepare it for training by specifying the loss function and the optimisation algorithm." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "model = Model(inputs=x_inp, outputs=predictions)\n", "\n", "model.compile(\n", " optimizer=Adam(learning_rate=0.0001), loss=binary_crossentropy, metrics=[\"acc\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the model\n", "\n", "We can now train the model using the model's `fit` method.\n", "\n", "But first we need to split our data into training and test sets. 
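We are going to use 90% of the data for training and the remaining 10% for testing. This 90/10 split is the equivalent of a single fold in the 10-fold cross validation scheme used in [1]." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "train_graphs, test_graphs = model_selection.train_test_split(\n", " graph_labels, train_size=0.9, test_size=None, stratify=graph_labels,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given the train/test split, we use a `PaddedGraphGenerator` object to prepare the data for training. Calling its `flow` method with the indices and targets of the train and test graphs produces data generators suitable for training and evaluating a `tf.keras` model. Setting `symmetric_normalization=False` selects the $D^{-1}A$ normalisation used by DGCNN [1] instead of the symmetric normalisation of [2]." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "gen = PaddedGraphGenerator(graphs=graphs)\n", "\n", "# graph_labels is indexed from 1, whereas positions in `graphs` start at 0\n", "train_gen = gen.flow(\n", " list(train_graphs.index - 1),\n", " targets=train_graphs.values,\n", " batch_size=50,\n", " symmetric_normalization=False,\n", ")\n", "\n", "test_gen = gen.flow(\n", " list(test_graphs.index - 1),\n", " targets=test_graphs.values,\n", " batch_size=1,\n", " symmetric_normalization=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Note**: We set the number of epochs to a large value, so the call to `model.fit(...)` below might take a long time to complete. To speed it up, set `epochs` to a smaller value; note, however, that the resulting model may then be less accurate." ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "epochs = 100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now train the model by calling its `fit` method." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", "Train for 21 steps, validate for 112 steps\n", "Epoch 1/100\n", "21/21 [==============================] - 3s 139ms/step - loss: 0.6640 - acc: 0.5824 - val_loss: 0.6188 - val_acc: 0.5982\n", "Epoch 2/100\n", "21/21 [==============================] - 2s 74ms/step - loss: 0.6526 - acc: 0.6234 - val_loss: 0.6003 - val_acc: 0.6429\n", "Epoch 3/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6468 - acc: 0.6643 - val_loss: 0.5987 - val_acc: 0.7411\n", "Epoch 4/100\n", "21/21 [==============================] - 2s 76ms/step - loss: 0.6361 - acc: 0.7123 - val_loss: 0.5843 - val_acc: 0.7321\n", "Epoch 5/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.6301 - acc: 0.7143 - val_loss: 0.5786 - val_acc: 0.7500\n", "Epoch 6/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6061 - acc: 0.7073 - val_loss: 0.5716 - val_acc: 0.7500\n", "Epoch 7/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 0.6129 - acc: 0.7173 - val_loss: 0.5626 - val_acc: 0.7500\n", "Epoch 8/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.6274 - acc: 0.7163 - val_loss: 0.5637 - val_acc: 0.7411\n", "Epoch 9/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5985 - acc: 0.7243 - val_loss: 0.5606 - val_acc: 0.7411\n", "Epoch 10/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6066 - acc: 0.7223 - val_loss: 0.5568 - val_acc: 0.7411\n", "Epoch 11/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5956 - acc: 0.7273 - val_loss: 0.5530 - val_acc: 0.7411\n", "Epoch 12/100\n", "21/21 [==============================] - 2s 75ms/step - loss: 0.5852 - acc: 0.7203 - val_loss: 0.5493 - val_acc: 0.7500\n", "Epoch 13/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 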
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", "Train for 21 steps, validate for 112 steps\n", "Epoch 1/100\n", "21/21 [==============================] - 3s 139ms/step - loss: 0.6640 - acc: 0.5824 - val_loss: 0.6188 - val_acc: 0.5982\n", "Epoch 2/100\n", "21/21 [==============================] - 2s 74ms/step - loss: 0.6526 - acc: 0.6234 - val_loss: 0.6003 - val_acc: 0.6429\n", "Epoch 3/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6468 - acc: 0.6643 - val_loss: 0.5987 - val_acc: 0.7411\n", "Epoch 4/100\n", "21/21 [==============================] - 2s 76ms/step - loss: 0.6361 - acc: 0.7123 - val_loss: 0.5843 - val_acc: 0.7321\n", "Epoch 5/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.6301 - acc: 0.7143 - val_loss: 0.5786 - val_acc: 0.7500\n", "Epoch 6/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6061 - acc: 0.7073 - val_loss: 0.5716 - val_acc: 0.7500\n", "Epoch 7/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 0.6129 - acc: 0.7173 - val_loss: 0.5626 - val_acc: 0.7500\n", "Epoch 8/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.6274 - acc: 0.7163 - val_loss: 0.5637 - val_acc: 0.7411\n", "Epoch 9/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5985 - acc: 0.7243 - val_loss: 0.5606 - val_acc: 0.7411\n", "Epoch 10/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.6066 - acc: 0.7223 - val_loss: 0.5568 - val_acc: 0.7411\n", "Epoch 11/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5956 - acc: 0.7273 - val_loss: 0.5530 - val_acc: 0.7411\n", "Epoch 12/100\n", "21/21 [==============================] - 2s 75ms/step - loss: 0.5852 - acc: 0.7203 - val_loss: 0.5493 - val_acc: 0.7500\n", "Epoch 13/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 
0.5995 - acc: 0.7233 - val_loss: 0.5482 - val_acc: 0.7500\n", "Epoch 14/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5898 - acc: 0.7303 - val_loss: 0.5452 - val_acc: 0.7411\n", "Epoch 15/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.6028 - acc: 0.7233 - val_loss: 0.5467 - val_acc: 0.7589\n", "Epoch 16/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5850 - acc: 0.7223 - val_loss: 0.5444 - val_acc: 0.7500\n", "Epoch 17/100\n", "21/21 [==============================] - 2s 80ms/step - loss: 0.5793 - acc: 0.7243 - val_loss: 0.5436 - val_acc: 0.7589\n", "Epoch 18/100\n", "21/21 [==============================] - 2s 87ms/step - loss: 0.5705 - acc: 0.7133 - val_loss: 0.5413 - val_acc: 0.7500\n", "Epoch 19/100\n", "21/21 [==============================] - 2s 78ms/step - loss: 0.5829 - acc: 0.7263 - val_loss: 0.5426 - val_acc: 0.7411\n", "Epoch 20/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5796 - acc: 0.7133 - val_loss: 0.5423 - val_acc: 0.7411\n", "Epoch 21/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5772 - acc: 0.7053 - val_loss: 0.5397 - val_acc: 0.7321\n", "Epoch 22/100\n", "21/21 [==============================] - 2s 79ms/step - loss: 0.5818 - acc: 0.7143 - val_loss: 0.5378 - val_acc: 0.7500\n", "Epoch 23/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5733 - acc: 0.7133 - val_loss: 0.5381 - val_acc: 0.7321\n", "Epoch 24/100\n", "21/21 [==============================] - 2s 85ms/step - loss: 0.5670 - acc: 0.7143 - val_loss: 0.5390 - val_acc: 0.7321\n", "Epoch 25/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 0.5688 - acc: 0.7143 - val_loss: 0.5374 - val_acc: 0.7321\n", "Epoch 26/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5671 - acc: 0.7103 - val_loss: 0.5372 - val_acc: 0.7232\n", "Epoch 27/100\n", "21/21 [==============================] - 2s 89ms/step - 
loss: 0.5639 - acc: 0.7103 - val_loss: 0.5362 - val_acc: 0.7232\n", "Epoch 28/100\n", "21/21 [==============================] - 2s 96ms/step - loss: 0.5732 - acc: 0.7143 - val_loss: 0.5377 - val_acc: 0.7321\n", "Epoch 29/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5655 - acc: 0.7073 - val_loss: 0.5363 - val_acc: 0.7232\n", "Epoch 30/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5683 - acc: 0.7153 - val_loss: 0.5366 - val_acc: 0.7321\n", "Epoch 31/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5752 - acc: 0.7203 - val_loss: 0.5345 - val_acc: 0.7232\n", "Epoch 32/100\n", "21/21 [==============================] - 2s 96ms/step - loss: 0.5778 - acc: 0.7183 - val_loss: 0.5392 - val_acc: 0.7321\n", "Epoch 33/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5649 - acc: 0.7253 - val_loss: 0.5352 - val_acc: 0.7500\n", "Epoch 34/100\n", "21/21 [==============================] - 2s 87ms/step - loss: 0.5700 - acc: 0.7153 - val_loss: 0.5337 - val_acc: 0.7321\n", "Epoch 35/100\n", "21/21 [==============================] - 2s 74ms/step - loss: 0.5621 - acc: 0.7083 - val_loss: 0.5358 - val_acc: 0.7411\n", "Epoch 36/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5729 - acc: 0.7273 - val_loss: 0.5371 - val_acc: 0.7232\n", "Epoch 37/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5735 - acc: 0.7153 - val_loss: 0.5316 - val_acc: 0.7321\n", "Epoch 38/100\n", "21/21 [==============================] - 2s 92ms/step - loss: 0.5694 - acc: 0.7043 - val_loss: 0.5309 - val_acc: 0.7411\n", "Epoch 39/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5589 - acc: 0.7173 - val_loss: 0.5315 - val_acc: 0.7411\n", "Epoch 40/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5687 - acc: 0.7163 - val_loss: 0.5314 - val_acc: 0.7321\n", "Epoch 41/100\n", "21/21 [==============================] - ETA: 0s - 
loss: 0.5534 - acc: 0.728 - 2s 86ms/step - loss: 0.5523 - acc: 0.7283 - val_loss: 0.5301 - val_acc: 0.7411\n", "Epoch 42/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5596 - acc: 0.7113 - val_loss: 0.5306 - val_acc: 0.7411\n", "Epoch 43/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5518 - acc: 0.7193 - val_loss: 0.5293 - val_acc: 0.7500\n", "Epoch 44/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5579 - acc: 0.7153 - val_loss: 0.5299 - val_acc: 0.7500\n", "Epoch 45/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5565 - acc: 0.7253 - val_loss: 0.5276 - val_acc: 0.7500\n", "Epoch 46/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5576 - acc: 0.7113 - val_loss: 0.5294 - val_acc: 0.7500\n", "Epoch 47/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5624 - acc: 0.7203 - val_loss: 0.5291 - val_acc: 0.7500\n", "Epoch 48/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5552 - acc: 0.7223 - val_loss: 0.5268 - val_acc: 0.7500\n", "Epoch 49/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5536 - acc: 0.7223 - val_loss: 0.5250 - val_acc: 0.7589\n", "Epoch 50/100\n", "21/21 [==============================] - 2s 98ms/step - loss: 0.5693 - acc: 0.7153 - val_loss: 0.5281 - val_acc: 0.7589\n", "Epoch 51/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5521 - acc: 0.7243 - val_loss: 0.5256 - val_acc: 0.7589\n", "Epoch 52/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5536 - acc: 0.7203 - val_loss: 0.5217 - val_acc: 0.7589\n", "Epoch 53/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5489 - acc: 0.7143 - val_loss: 0.5197 - val_acc: 0.7679\n", "Epoch 54/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5478 - acc: 0.7283 - val_loss: 0.5211 - val_acc: 0.7679\n", "Epoch 55/100\n", "21/21 
[==============================] - 2s 90ms/step - loss: 0.5569 - acc: 0.7263 - val_loss: 0.5201 - val_acc: 0.7589\n", "Epoch 56/100\n", "21/21 [==============================] - 2s 101ms/step - loss: 0.5530 - acc: 0.7183 - val_loss: 0.5204 - val_acc: 0.7857\n", "Epoch 57/100\n", "21/21 [==============================] - 2s 91ms/step - loss: 0.5453 - acc: 0.7183 - val_loss: 0.5171 - val_acc: 0.7768\n", "Epoch 58/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5390 - acc: 0.7303 - val_loss: 0.5161 - val_acc: 0.7857\n", "Epoch 59/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5410 - acc: 0.7283 - val_loss: 0.5128 - val_acc: 0.7857\n", "Epoch 60/100\n", "21/21 [==============================] - 2s 97ms/step - loss: 0.5602 - acc: 0.7213 - val_loss: 0.5173 - val_acc: 0.7679\n", "Epoch 61/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5449 - acc: 0.7243 - val_loss: 0.5138 - val_acc: 0.7768\n", "Epoch 62/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5492 - acc: 0.7243 - val_loss: 0.5125 - val_acc: 0.7768\n", "Epoch 63/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5466 - acc: 0.7213 - val_loss: 0.5161 - val_acc: 0.7768\n", "Epoch 64/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5475 - acc: 0.7213 - val_loss: 0.5135 - val_acc: 0.7768\n", "Epoch 65/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5409 - acc: 0.7243 - val_loss: 0.5125 - val_acc: 0.7857\n", "Epoch 66/100\n", "21/21 [==============================] - 2s 95ms/step - loss: 0.5404 - acc: 0.7303 - val_loss: 0.5095 - val_acc: 0.7857\n", "Epoch 67/100\n", "21/21 [==============================] - 2s 85ms/step - loss: 0.5453 - acc: 0.7213 - val_loss: 0.5029 - val_acc: 0.7857\n", "Epoch 68/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5374 - acc: 0.7293 - val_loss: 0.5086 - val_acc: 0.7768\n", "Epoch 69/100\n", 
"21/21 [==============================] - 2s 97ms/step - loss: 0.5409 - acc: 0.7353 - val_loss: 0.5077 - val_acc: 0.7768\n", "Epoch 70/100\n", "21/21 [==============================] - 2s 92ms/step - loss: 0.5439 - acc: 0.7293 - val_loss: 0.5043 - val_acc: 0.7857\n", "Epoch 71/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5330 - acc: 0.7313 - val_loss: 0.5090 - val_acc: 0.7768\n", "Epoch 72/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5328 - acc: 0.7303 - val_loss: 0.5092 - val_acc: 0.7768\n", "Epoch 73/100\n", "21/21 [==============================] - 2s 84ms/step - loss: 0.5333 - acc: 0.7273 - val_loss: 0.5098 - val_acc: 0.7857\n", "Epoch 74/100\n", "21/21 [==============================] - 2s 96ms/step - loss: 0.5384 - acc: 0.7313 - val_loss: 0.5049 - val_acc: 0.7679\n", "Epoch 75/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5417 - acc: 0.7233 - val_loss: 0.5086 - val_acc: 0.7768\n", "Epoch 76/100\n", "21/21 [==============================] - 2s 81ms/step - loss: 0.5364 - acc: 0.7253 - val_loss: 0.5088 - val_acc: 0.7589\n", "Epoch 77/100\n", "21/21 [==============================] - 2s 89ms/step - loss: 0.5365 - acc: 0.7313 - val_loss: 0.5083 - val_acc: 0.7768\n", "Epoch 78/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5378 - acc: 0.7363 - val_loss: 0.5084 - val_acc: 0.7679\n", "Epoch 79/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5373 - acc: 0.7293 - val_loss: 0.5049 - val_acc: 0.7768\n", "Epoch 80/100\n", "21/21 [==============================] - 2s 87ms/step - loss: 0.5344 - acc: 0.7373 - val_loss: 0.5063 - val_acc: 0.7679\n", "Epoch 81/100\n", "21/21 [==============================] - 2s 87ms/step - loss: 0.5344 - acc: 0.7313 - val_loss: 0.5039 - val_acc: 0.7679\n", "Epoch 82/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5304 - acc: 0.7363 - val_loss: 0.5078 - val_acc: 0.7589\n", "Epoch 
83/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5382 - acc: 0.7303 - val_loss: 0.5116 - val_acc: 0.7589\n", "Epoch 84/100\n", "21/21 [==============================] - 2s 79ms/step - loss: 0.5315 - acc: 0.7293 - val_loss: 0.4988 - val_acc: 0.7500\n", "Epoch 85/100\n", "21/21 [==============================] - 2s 91ms/step - loss: 0.5358 - acc: 0.7293 - val_loss: 0.4974 - val_acc: 0.7679\n", "Epoch 86/100\n", "21/21 [==============================] - 2s 77ms/step - loss: 0.5424 - acc: 0.7283 - val_loss: 0.5009 - val_acc: 0.7679\n", "Epoch 87/100\n", "21/21 [==============================] - 2s 88ms/step - loss: 0.5300 - acc: 0.7403 - val_loss: 0.5085 - val_acc: 0.7768\n", "Epoch 88/100\n", "21/21 [==============================] - 2s 82ms/step - loss: 0.5436 - acc: 0.7253 - val_loss: 0.5046 - val_acc: 0.7500\n", "Epoch 89/100\n", "21/21 [==============================] - 2s 90ms/step - loss: 0.5346 - acc: 0.7323 - val_loss: 0.5002 - val_acc: 0.7589\n", "Epoch 90/100\n", "21/21 [==============================] - 2s 91ms/step - loss: 0.5323 - acc: 0.7373 - val_loss: 0.5056 - val_acc: 0.7679\n", "Epoch 91/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5290 - acc: 0.7313 - val_loss: 0.5071 - val_acc: 0.7589\n", "Epoch 92/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5340 - acc: 0.7313 - val_loss: 0.5086 - val_acc: 0.7679\n", "Epoch 93/100\n", "21/21 [==============================] - 2s 98ms/step - loss: 0.5271 - acc: 0.7313 - val_loss: 0.5063 - val_acc: 0.7679\n", "Epoch 94/100\n", "21/21 [==============================] - 2s 83ms/step - loss: 0.5236 - acc: 0.7413 - val_loss: 0.5102 - val_acc: 0.7679\n", "Epoch 95/100\n", "21/21 [==============================] - 2s 86ms/step - loss: 0.5237 - acc: 0.7333 - val_loss: 0.5103 - val_acc: 0.7411\n", "Epoch 96/100\n", "21/21 [==============================] - 2s 95ms/step - loss: 0.5196 - acc: 0.7353 - val_loss: 0.5110 - val_acc: 0.7768\n", 
"Epoch 97/100\n", "21/21 [==============================] - 2s 94ms/step - loss: 0.5250 - acc: 0.7293 - val_loss: 0.5076 - val_acc: 0.7411\n", "Epoch 98/100\n", "21/21 [==============================] - 2s 87ms/step - loss: 0.5259 - acc: 0.7403 - val_loss: 0.5087 - val_acc: 0.7679\n", "Epoch 99/100\n", "21/21 [==============================] - 2s 99ms/step - loss: 0.5315 - acc: 0.7413 - val_loss: 0.5080 - val_acc: 0.7679\n", "Epoch 100/100\n", "21/21 [==============================] - 2s 93ms/step - loss: 0.5292 - acc: 0.7313 - val_loss: 0.5223 - val_acc: 0.7589\n" ] } ], "source": [ "history = model.fit(\n", " train_gen, epochs=epochs, verbose=1, validation_data=test_gen, shuffle=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us plot the training history (losses and accuracies for the train and test data)." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sg.utils.plot_history(history)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let us calculate the performance of the trained model on the test data." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", "112/112 [==============================] - 0s 1ms/step - loss: 0.5223 - acc: 0.7589\n", "\n", "Test Set Metrics:\n", "\tloss: 0.5223\n", "\tacc: 0.7589\n" ] } ], "source": [ "test_metrics = model.evaluate(test_gen)\n", "print(\"\\nTest Set Metrics:\")\n", "for name, val in zip(model.metrics_names, test_metrics):\n", " print(\"\\t{}: {:0.4f}\".format(name, val))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Conclusion\n", "\n", "We demonstrated the use of `StellarGraph`'s `DeepGraphCNN` implementation for supervised graph classification. More specifically, we showed how to predict whether a protein, represented as a graph, is an enzyme or not.\n", "\n", "Performance is similar to that reported in [1], although a small difference exists. This difference may be attributed to the following factors:\n", "- We use a different training scheme, that is, a single 90/10 split of the data as opposed to the repeated 10-fold cross validation scheme used in [1]. We use a single fold for ease of exposition; with only 112 test graphs, a single split also gives a noisier accuracy estimate than averaging over folds.\n", "- The experimental evaluation scheme in [1] does not specify some important details, such as the regularisation used for the neural network layers, whether a bias term is included, the weight initialization method, and the batch size." ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }