{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Node classification with Graph Convolutional Network (GCN)\n", "\n", "> This demo explains how to do node classification using the StellarGraph library. [See all other demos](../README.md).\n" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[The StellarGraph library](https://github.com/stellargraph/stellargraph) supports many state-of-the-art machine learning (ML) algorithms on [graphs](https://en.wikipedia.org/wiki/Graph_%28discrete_mathematics%29). In this notebook, we'll be training a model to predict the class or label of a node, commonly known as node classification. We will also use the resulting model to compute vector embeddings for each node.\n", "\n", "There are two necessary parts to this task:\n", "\n", "- a graph: this notebook uses the Cora dataset from . The dataset consists of academic publications as the nodes and the citations between them as the links: if publication A cites publication B, then the graph has an edge from A to B. The nodes are classified into one of seven subjects, and our model will learn to predict this subject.\n", "- an algorithm: this notebook uses a Graph Convolutional Network (GCN) [1]. The core of the GCN neural network model is a \"graph convolution\" layer. This layer is similar to a conventional dense layer, augmented by the graph adjacency matrix to use information about a node's connections. This algorithm is discussed in more detail in [\"Knowing Your Neighbours: Machine Learning on Graphs\"](https://medium.com/stellargraph/knowing-your-neighbours-machine-learning-on-graphs-9b7c3d0d5896).\n", "\n", "The notebook walks through three sections:\n", "\n", "1. **Data preparation** using [Pandas](https://pandas.pydata.org) and [scikit-learn](https://scikit-learn.org/): loading the graph from CSV files, doing some basic introspection, and splitting it into train, validation and test sets for ML\n", "2. **Creating the GCN layers** and data input using [StellarGraph](https://github.com/stellargraph/stellargraph)\n", "3. **Training and evaluating** the model using [TensorFlow Keras](https://www.tensorflow.org/guide/keras), Pandas and scikit-learn\n", "\n",
"Notably, only section 2 needs StellarGraph: sections 1 and 3 are driven by existing, flexible functionality in popular data science libraries. Most of the algorithms supported by StellarGraph follow this pattern, where the custom StellarGraph functionality integrates smoothly with the conventional data science workflow.\n", "\n", "> StellarGraph supports other algorithms for doing [node classification](README.md), as well as many [other tasks](../README.md) such as [link prediction](../link-prediction/README.md) and [representation learning](../embeddings/README.md).\n", "\n", "[1]: [Graph Convolutional Networks (GCN): Semi-Supervised Classification with Graph Convolutional Networks](https://github.com/tkipf/gcn). Thomas N. Kipf, Max Welling.\n", "International Conference on Learning Representations (ICLR), 2017." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first step is to import the Python libraries that we'll need. We import `stellargraph` under the `sg` name for convenience, similar to `pandas` often being imported as `pd`."
] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "outputs": [], "source": [ "# install StellarGraph if running on Google Colab\n", "import sys\n", "if 'google.colab' in sys.modules:\n", " %pip install -q stellargraph[demos]==1.2.1" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [ "VersionCheck" ] }, "outputs": [], "source": [ "# verify that we're using the correct version of StellarGraph for this notebook\n", "import stellargraph as sg\n", "\n", "try:\n", " sg.utils.validate_notebook_version(\"1.2.1\")\n", "except AttributeError:\n", " raise ValueError(\n", " f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed. Please see .\"\n", " ) from None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import os\n", "\n", "import stellargraph as sg\n", "from stellargraph.mapper import FullBatchNodeGenerator\n", "from stellargraph.layer import GCN\n", "\n", "from tensorflow.keras import layers, optimizers, losses, metrics, Model\n", "from sklearn import preprocessing, model_selection\n", "from IPython.display import display, HTML\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Data Preparation\n", "\n", "### Loading the CORA network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can retrieve a `StellarGraph` graph object holding this Cora dataset using the `Cora` loader ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.datasets.Cora)) from the `datasets` submodule ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#module-stellargraph.datasets)). It also provides us with the ground-truth node subject classes. 
This function is implemented using Pandas; see [the \"Loading data into StellarGraph from Pandas\" notebook](../basics/loading-pandas.ipynb) for details.\n", "\n", "(Note: Cora is a citation network and therefore a directed graph, but, like most users of this graph, we ignore the edge direction and treat it as undirected.)" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "DataLoadingLinks" ] }, "source": [ "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data can be loaded.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "DataLoading" ] }, "outputs": [ { "data": { "text/html": [ "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = sg.datasets.Cora()\n", "display(HTML(dataset.description))\n", "G, node_subjects = dataset.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `info` method can help us verify that our loaded graph matches the description:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 2708, Edges: 5429\n", "\n", " Node types:\n", " paper: [2708]\n", " Features: float32 vector, length 1433\n", " Edge types: paper-cites->paper\n", "\n", " Edge types:\n", " paper-cites->paper: [5429]\n" ] } ], "source": [ "print(G.info())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We aim to train a graph-ML model that will predict the \"subject\" attribute on the nodes. 
The subject is one of 7 categories, with some categories more common than others:" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table>
<thead><tr><th></th><th>subject</th></tr></thead>
<tbody>
<tr><th>Neural_Networks</th><td>818</td></tr>
<tr><th>Probabilistic_Methods</th><td>426</td></tr>
<tr><th>Genetic_Algorithms</th><td>418</td></tr>
<tr><th>Theory</th><td>351</td></tr>
<tr><th>Case_Based</th><td>298</td></tr>
<tr><th>Reinforcement_Learning</th><td>217</td></tr>
<tr><th>Rule_Learning</th><td>180</td></tr>
</tbody>
</table></div>
" ], "text/plain": [ " subject\n", "Neural_Networks 818\n", "Probabilistic_Methods 426\n", "Genetic_Algorithms 418\n", "Theory 351\n", "Case_Based 298\n", "Reinforcement_Learning 217\n", "Rule_Learning 180" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "node_subjects.value_counts().to_frame()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Splitting the data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For machine learning, we want to take a subset of the nodes for training and use the rest for validation and testing. We'll use scikit-learn's `train_test_split` function ([docs](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)) to do this.\n", "\n", "Here we're taking 140 node labels for training, 500 for validation, and the remaining 2068 for testing." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "train_subjects, test_subjects = model_selection.train_test_split(\n", " node_subjects, train_size=140, test_size=None, stratify=node_subjects\n", ")\n", "val_subjects, test_subjects = model_selection.train_test_split(\n", " test_subjects, train_size=500, test_size=None, stratify=test_subjects\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that stratified sampling gives the following class counts in the training set:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table>
<thead><tr><th></th><th>subject</th></tr></thead>
<tbody>
<tr><th>Neural_Networks</th><td>42</td></tr>
<tr><th>Genetic_Algorithms</th><td>22</td></tr>
<tr><th>Probabilistic_Methods</th><td>22</td></tr>
<tr><th>Theory</th><td>18</td></tr>
<tr><th>Case_Based</th><td>16</td></tr>
<tr><th>Reinforcement_Learning</th><td>11</td></tr>
<tr><th>Rule_Learning</th><td>9</td></tr>
</tbody>
</table></div>
" ], "text/plain": [ " subject\n", "Neural_Networks 42\n", "Genetic_Algorithms 22\n", "Probabilistic_Methods 22\n", "Theory 18\n", "Case_Based 16\n", "Reinforcement_Learning 11\n", "Rule_Learning 9" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_subjects.value_counts().to_frame()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The training set has class imbalance that might need to be compensated, e.g., by using a weighted cross-entropy loss in model training, with class weights inversely proportional to class support. However, we will ignore the class imbalance in this example, for simplicity." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Converting to numeric arrays" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For our categorical target, we will use one-hot vectors that will be compared against the model's softmax output. To do this conversion, we can use the `LabelBinarizer` transform ([docs](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelBinarizer.html)) from scikit-learn. Another option would be the `pandas.get_dummies` function ([docs](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.get_dummies.html)), but the scikit-learn transform allows us to do the inverse transform easily later in the notebook, to interpret the predictions." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "target_encoding = preprocessing.LabelBinarizer()\n", "\n", "train_targets = target_encoding.fit_transform(train_subjects)\n", "val_targets = target_encoding.transform(val_subjects)\n", "test_targets = target_encoding.transform(test_subjects)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Cora dataset contains attributes `w_x` that correspond to words found in that publication. 
If a word occurs at least once in a publication, the relevant attribute is set to one; otherwise, it is zero. These numeric attributes have been automatically included in the `StellarGraph` instance `G`, so we do not have to do any further conversion." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Creating the GCN layers\n", "\n", "A machine learning model in StellarGraph consists of a pair of items:\n", "\n", "- the layers themselves, such as graph convolution, [dropout](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout) and even [conventional dense layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense)\n", "- a data generator to convert the core graph structure and node features into a format that can be fed into the Keras model for training or prediction\n", "\n", "GCN is a full-batch model and we're doing node classification here, which means the `FullBatchNodeGenerator` class ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.mapper.FullBatchNodeGenerator)) is the appropriate generator for our task. StellarGraph provides many generators to support [its many models and tasks](../README.md).\n", "\n", "Specifying the `method='gcn'` argument to the `FullBatchNodeGenerator` means it will yield data appropriate for the GCN algorithm specifically, by using the symmetrically normalized adjacency matrix with self-loops (closely related to the [normalized graph Laplacian matrix](https://en.wikipedia.org/wiki/Laplacian_matrix#Symmetric_normalized_Laplacian)) to capture the graph structure." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using GCN (local pooling) filters...\n" ] } ], "source": [ "generator = FullBatchNodeGenerator(G, method=\"gcn\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A generator just encodes the information required to produce the model inputs. 
Calling the `flow` method ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.mapper.FullBatchNodeGenerator.flow)) with a set of nodes and their true labels produces an object that can be used to train the model on those specific nodes and labels. We created a training set above, so that's what we're going to use here." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "train_gen = generator.flow(train_subjects.index, train_targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can specify our machine learning model by building a stack of layers. We can use StellarGraph's `GCN` class ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.layer.GCN)), which packages up the creation of this stack of [graph convolution](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.layer.GraphConvolution) and [dropout](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout) layers. We can specify a few parameters to control this:\n", "\n", " * `layer_sizes`: the number of hidden GCN layers and their sizes. In this case, two GCN layers with 16 units each.\n", " * `activations`: the activation to apply to each GCN layer's output. In this case, [ReLU](https://en.wikipedia.org/wiki/Rectifier_\\(neural_networks\\)) for both layers.\n", " * `dropout`: the rate of dropout for the input of each GCN layer. In this case, 50%."
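] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For intuition about what these stacked layers compute, here is a minimal NumPy sketch of the layer-wise propagation rule from the GCN paper [1], written independently of StellarGraph. It adds self-loops to the adjacency matrix `A`, normalizes it symmetrically as `D^(-1/2) (A + I) D^(-1/2)` (with `D` the degree matrix of `A + I`), and applies `ReLU(A_hat @ H @ W)` once per layer. The helper names `normalize_adjacency` and `gcn_forward`, the toy graph and the random weights are hypothetical; dropout, biases, sparse matrices and training are deliberately left out.

```python
import numpy as np


def normalize_adjacency(adj):
    # Add self-loops, then normalize symmetrically: D^(-1/2) (A + I) D^(-1/2),
    # where D is the diagonal degree matrix of A + I.
    a_tilde = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_tilde.sum(axis=1))
    return a_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def gcn_forward(adj, features, weights):
    # Apply one graph convolution per weight matrix: H_next = ReLU(A_hat @ H @ W).
    # Dropout and bias terms are omitted to keep the sketch short.
    a_hat = normalize_adjacency(adj)
    hidden = features
    for w in weights:
        hidden = np.maximum(a_hat @ hidden @ w, 0.0)
    return hidden


# Hypothetical toy data: a 4-node path graph with 5 binary features per node,
# and two layers of 16 units each (mirroring layer_sizes=[16, 16]).
rng = np.random.default_rng(0)
adj = np.array(
    [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]],
    dtype=float,
)
features = rng.integers(0, 2, size=(4, 5)).astype(float)
weights = [rng.normal(size=(5, 16)), rng.normal(size=(16, 16))]

embeddings = gcn_forward(adj, features, weights)
print(embeddings.shape)  # (4, 16)
```

With two weight matrices of output width 16, every node ends up with a 16-dimensional vector, which matches the shape of the `x_out` tensor used below."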
] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "gcn = GCN(\n", " layer_sizes=[16, 16], activations=[\"relu\", \"relu\"], generator=generator, dropout=0.5\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To create a Keras model, we now expose the input and output tensors of the GCN model for node prediction via the `GCN.in_out_tensors` method:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_inp, x_out = gcn.in_out_tensors()\n", "\n", "x_out" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `x_out` value is a TensorFlow tensor that holds a 16-dimensional vector for the nodes requested when training or predicting. The actual predictions of each node's class/subject need to be computed from this vector. StellarGraph is built using Keras functionality, so this can be done with standard Keras functionality: an additional dense layer (with one unit per class) using a softmax activation. This activation function ensures that the final outputs for each input node will be a vector of \"probabilities\", where every value is between 0 and 1, and the whole vector sums to 1. The predicted class is the element with the highest value." ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "predictions = layers.Dense(units=train_targets.shape[1], activation=\"softmax\")(x_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. Training and evaluating\n", "\n", "### Training the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's create the actual [Keras model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) with the input tensors `x_inp` and, as the output, the `predictions` tensor from the final dense layer. 
Our task is a categorical prediction task, so a categorical cross-entropy loss function is appropriate. Even though we're doing graph ML with StellarGraph, we're still working with conventional Keras prediction values, so we can use [the loss function from Keras](https://www.tensorflow.org/api_docs/python/tf/keras/losses/categorical_crossentropy) directly." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "model = Model(inputs=x_inp, outputs=predictions)\n", "model.compile(\n", " optimizer=optimizers.Adam(learning_rate=0.01),\n", " loss=losses.categorical_crossentropy,\n", " metrics=[\"acc\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we're training the model, we also want to keep track of its generalisation performance on the validation set, which means creating another data generator with the `FullBatchNodeGenerator` we created above." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "val_gen = generator.flow(val_subjects.index, val_targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can directly use the `EarlyStopping` functionality ([docs](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/EarlyStopping)) offered by Keras to stop training once the validation accuracy has not improved for `patience` (here, 50) epochs; with `restore_best_weights=True`, the model is then rolled back to the weights from its best epoch." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.callbacks import EarlyStopping\n", "\n", "es_callback = EarlyStopping(monitor=\"val_acc\", patience=50, restore_best_weights=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We've now set up our model layers, our training data, our validation data and even our training callbacks, so we can train the model using the model's `fit` method ([docs](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit)). Like most things in this section, this is all built into Keras."
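] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The way `patience=50` and `restore_best_weights=True` interact can be illustrated with a simplified, plain-Python re-implementation of the stopping rule. This is a sketch for intuition, not the Keras source: it assumes the `EarlyStopping` defaults `min_delta=0` and a maximizing mode for an accuracy metric, and the function name `early_stopping_epochs` and the scores in it are hypothetical.

```python
def early_stopping_epochs(val_scores, patience):
    # Simplified stopping rule for a metric that should increase (mode max),
    # with min_delta=0: a score only counts as an improvement if it is
    # strictly greater than the best score seen so far.
    # Returns (best_epoch, stopped_epoch), both 1-indexed; stopped_epoch is
    # None if the patience budget is never exhausted.
    best_score = float('-inf')
    best_epoch = None
    wait = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_score = score
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Hypothetical validation accuracies, with patience=2 for brevity.
scores = [0.50, 0.62, 0.61, 0.64, 0.63, 0.60]
print(early_stopping_epochs(scores, patience=2))  # (4, 6)
```

Applied to the `val_acc` values in the training log of the next cell, the best value (0.8240) is reached at epoch 38, and training halts 50 epochs later, at epoch 88; because `restore_best_weights=True`, the model keeps the weights from epoch 38."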
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", "Train for 1 steps, validate for 1 steps\n", "Epoch 1/200\n", "1/1 - 1s - loss: 1.9505 - acc: 0.1000 - val_loss: 1.9182 - val_acc: 0.2820\n", "Epoch 2/200\n", "1/1 - 0s - loss: 1.9004 - acc: 0.3143 - val_loss: 1.8831 - val_acc: 0.3560\n", "Epoch 3/200\n", "1/1 - 0s - loss: 1.8493 - acc: 0.3571 - val_loss: 1.8297 - val_acc: 0.3940\n", "Epoch 4/200\n", "1/1 - 0s - loss: 1.7679 - acc: 0.4500 - val_loss: 1.7643 - val_acc: 0.3700\n", "Epoch 5/200\n", "1/1 - 0s - loss: 1.6747 - acc: 0.4500 - val_loss: 1.7046 - val_acc: 0.3580\n", "Epoch 6/200\n", "1/1 - 0s - loss: 1.5794 - acc: 0.4643 - val_loss: 1.6489 - val_acc: 0.3780\n", "Epoch 7/200\n", "1/1 - 0s - loss: 1.5086 - acc: 0.4714 - val_loss: 1.5843 - val_acc: 0.4440\n", "Epoch 8/200\n", "1/1 - 0s - loss: 1.4128 - acc: 0.5071 - val_loss: 1.5189 - val_acc: 0.5180\n", "Epoch 9/200\n", "1/1 - 0s - loss: 1.2905 - acc: 0.5929 - val_loss: 1.4558 - val_acc: 0.5900\n", "Epoch 10/200\n", "1/1 - 0s - loss: 1.1587 - acc: 0.6714 - val_loss: 1.3988 - val_acc: 0.6320\n", "Epoch 11/200\n", "1/1 - 0s - loss: 1.1166 - acc: 0.7143 - val_loss: 1.3416 - val_acc: 0.6620\n", "Epoch 12/200\n", "1/1 - 0s - loss: 1.0452 - acc: 0.7500 - val_loss: 1.2856 - val_acc: 0.6740\n", "Epoch 13/200\n", "1/1 - 0s - loss: 1.0205 - acc: 0.7286 - val_loss: 1.2315 - val_acc: 0.6880\n", "Epoch 14/200\n", "1/1 - 0s - loss: 0.8734 - acc: 0.7786 - val_loss: 1.1815 - val_acc: 0.6880\n", "Epoch 15/200\n", "1/1 - 0s - loss: 0.7818 - acc: 0.7857 - val_loss: 1.1342 - val_acc: 0.6940\n", "Epoch 16/200\n", "1/1 - 0s - loss: 0.7580 - acc: 0.8143 - val_loss: 1.0892 - val_acc: 0.7020\n", "Epoch 17/200\n", "1/1 - 0s - loss: 0.6956 - acc: 0.8143 - val_loss: 1.0459 - val_acc: 0.7120\n", "Epoch 18/200\n", "1/1 - 0s - loss: 0.5902 - acc: 0.8214 - val_loss: 1.0059 - val_acc: 0.7180\n", "Epoch 19/200\n", "1/1 - 0s - loss: 
0.5497 - acc: 0.8786 - val_loss: 0.9683 - val_acc: 0.7420\n", "Epoch 20/200\n", "1/1 - 0s - loss: 0.4658 - acc: 0.8929 - val_loss: 0.9342 - val_acc: 0.7520\n", "Epoch 21/200\n", "1/1 - 0s - loss: 0.4416 - acc: 0.8857 - val_loss: 0.9039 - val_acc: 0.7760\n", "Epoch 22/200\n", "1/1 - 0s - loss: 0.4374 - acc: 0.9071 - val_loss: 0.8786 - val_acc: 0.7860\n", "Epoch 23/200\n", "1/1 - 0s - loss: 0.3275 - acc: 0.9500 - val_loss: 0.8585 - val_acc: 0.7860\n", "Epoch 24/200\n", "1/1 - 0s - loss: 0.3131 - acc: 0.9429 - val_loss: 0.8451 - val_acc: 0.7920\n", "Epoch 25/200\n", "1/1 - 0s - loss: 0.3186 - acc: 0.9357 - val_loss: 0.8369 - val_acc: 0.8000\n", "Epoch 26/200\n", "1/1 - 0s - loss: 0.2150 - acc: 0.9786 - val_loss: 0.8352 - val_acc: 0.7940\n", "Epoch 27/200\n", "1/1 - 0s - loss: 0.2385 - acc: 0.9643 - val_loss: 0.8335 - val_acc: 0.7940\n", "Epoch 28/200\n", "1/1 - 0s - loss: 0.2191 - acc: 0.9500 - val_loss: 0.8330 - val_acc: 0.7940\n", "Epoch 29/200\n", "1/1 - 0s - loss: 0.1988 - acc: 0.9643 - val_loss: 0.8297 - val_acc: 0.7940\n", "Epoch 30/200\n", "1/1 - 0s - loss: 0.1957 - acc: 0.9500 - val_loss: 0.8282 - val_acc: 0.8040\n", "Epoch 31/200\n", "1/1 - 0s - loss: 0.1622 - acc: 0.9500 - val_loss: 0.8281 - val_acc: 0.8020\n", "Epoch 32/200\n", "1/1 - 0s - loss: 0.1748 - acc: 0.9571 - val_loss: 0.8307 - val_acc: 0.8100\n", "Epoch 33/200\n", "1/1 - 0s - loss: 0.1223 - acc: 0.9714 - val_loss: 0.8360 - val_acc: 0.8120\n", "Epoch 34/200\n", "1/1 - 0s - loss: 0.1208 - acc: 0.9857 - val_loss: 0.8433 - val_acc: 0.8160\n", "Epoch 35/200\n", "1/1 - 0s - loss: 0.1331 - acc: 0.9714 - val_loss: 0.8526 - val_acc: 0.8120\n", "Epoch 36/200\n", "1/1 - 0s - loss: 0.1015 - acc: 0.9714 - val_loss: 0.8610 - val_acc: 0.8140\n", "Epoch 37/200\n", "1/1 - 0s - loss: 0.1253 - acc: 0.9714 - val_loss: 0.8680 - val_acc: 0.8180\n", "Epoch 38/200\n", "1/1 - 0s - loss: 0.0815 - acc: 0.9857 - val_loss: 0.8766 - val_acc: 0.8240\n", "Epoch 39/200\n", "1/1 - 0s - loss: 0.0822 - acc: 0.9857 - val_loss: 0.8847 
- val_acc: 0.8200\n", "Epoch 40/200\n", "1/1 - 0s - loss: 0.0677 - acc: 0.9857 - val_loss: 0.8942 - val_acc: 0.8160\n", "Epoch 41/200\n", "1/1 - 0s - loss: 0.0633 - acc: 0.9786 - val_loss: 0.9061 - val_acc: 0.8140\n", "Epoch 42/200\n", "1/1 - 0s - loss: 0.0767 - acc: 0.9857 - val_loss: 0.9204 - val_acc: 0.8140\n", "Epoch 43/200\n", "1/1 - 0s - loss: 0.0427 - acc: 0.9929 - val_loss: 0.9353 - val_acc: 0.8120\n", "Epoch 44/200\n", "1/1 - 0s - loss: 0.1346 - acc: 0.9429 - val_loss: 0.9500 - val_acc: 0.8080\n", "Epoch 45/200\n", "1/1 - 0s - loss: 0.0318 - acc: 1.0000 - val_loss: 0.9651 - val_acc: 0.8100\n", "Epoch 46/200\n", "1/1 - 0s - loss: 0.0409 - acc: 0.9929 - val_loss: 0.9797 - val_acc: 0.8020\n", "Epoch 47/200\n", "1/1 - 0s - loss: 0.0551 - acc: 0.9786 - val_loss: 0.9891 - val_acc: 0.8040\n", "Epoch 48/200\n", "1/1 - 0s - loss: 0.0645 - acc: 0.9714 - val_loss: 0.9956 - val_acc: 0.8040\n", "Epoch 49/200\n", "1/1 - 0s - loss: 0.0550 - acc: 0.9857 - val_loss: 0.9981 - val_acc: 0.8020\n", "Epoch 50/200\n", "1/1 - 0s - loss: 0.0223 - acc: 1.0000 - val_loss: 0.9984 - val_acc: 0.8020\n", "Epoch 51/200\n", "1/1 - 0s - loss: 0.0533 - acc: 0.9857 - val_loss: 0.9987 - val_acc: 0.8040\n", "Epoch 52/200\n", "1/1 - 0s - loss: 0.0389 - acc: 1.0000 - val_loss: 0.9986 - val_acc: 0.8060\n", "Epoch 53/200\n", "1/1 - 0s - loss: 0.0559 - acc: 0.9929 - val_loss: 0.9956 - val_acc: 0.8060\n", "Epoch 54/200\n", "1/1 - 0s - loss: 0.0316 - acc: 0.9929 - val_loss: 0.9950 - val_acc: 0.8080\n", "Epoch 55/200\n", "1/1 - 0s - loss: 0.0392 - acc: 0.9857 - val_loss: 0.9925 - val_acc: 0.8060\n", "Epoch 56/200\n", "1/1 - 0s - loss: 0.0476 - acc: 0.9857 - val_loss: 0.9934 - val_acc: 0.8060\n", "Epoch 57/200\n", "1/1 - 0s - loss: 0.0574 - acc: 0.9857 - val_loss: 0.9916 - val_acc: 0.8080\n", "Epoch 58/200\n", "1/1 - 0s - loss: 0.0727 - acc: 0.9714 - val_loss: 0.9905 - val_acc: 0.8120\n", "Epoch 59/200\n", "1/1 - 0s - loss: 0.0540 - acc: 0.9857 - val_loss: 0.9890 - val_acc: 0.8080\n", "Epoch 60/200\n", 
"1/1 - 0s - loss: 0.0544 - acc: 0.9786 - val_loss: 0.9886 - val_acc: 0.8100\n", "Epoch 61/200\n", "1/1 - 0s - loss: 0.0553 - acc: 0.9929 - val_loss: 0.9901 - val_acc: 0.8100\n", "Epoch 62/200\n", "1/1 - 0s - loss: 0.0402 - acc: 0.9929 - val_loss: 0.9908 - val_acc: 0.8080\n", "Epoch 63/200\n", "1/1 - 0s - loss: 0.0172 - acc: 1.0000 - val_loss: 0.9922 - val_acc: 0.8100\n", "Epoch 64/200\n", "1/1 - 0s - loss: 0.0376 - acc: 0.9929 - val_loss: 0.9929 - val_acc: 0.8080\n", "Epoch 65/200\n", "1/1 - 0s - loss: 0.0247 - acc: 0.9929 - val_loss: 0.9941 - val_acc: 0.8100\n", "Epoch 66/200\n", "1/1 - 0s - loss: 0.1193 - acc: 0.9571 - val_loss: 0.9894 - val_acc: 0.8100\n", "Epoch 67/200\n", "1/1 - 0s - loss: 0.0259 - acc: 0.9929 - val_loss: 0.9872 - val_acc: 0.8080\n", "Epoch 68/200\n", "1/1 - 0s - loss: 0.0136 - acc: 1.0000 - val_loss: 0.9872 - val_acc: 0.8140\n", "Epoch 69/200\n", "1/1 - 0s - loss: 0.0250 - acc: 1.0000 - val_loss: 0.9908 - val_acc: 0.8160\n", "Epoch 70/200\n", "1/1 - 0s - loss: 0.0392 - acc: 0.9929 - val_loss: 0.9970 - val_acc: 0.8220\n", "Epoch 71/200\n", "1/1 - 0s - loss: 0.0253 - acc: 1.0000 - val_loss: 1.0030 - val_acc: 0.8140\n", "Epoch 72/200\n", "1/1 - 0s - loss: 0.0219 - acc: 1.0000 - val_loss: 1.0105 - val_acc: 0.8140\n", "Epoch 73/200\n", "1/1 - 0s - loss: 0.0206 - acc: 0.9929 - val_loss: 1.0190 - val_acc: 0.8080\n", "Epoch 74/200\n", "1/1 - 0s - loss: 0.0228 - acc: 1.0000 - val_loss: 1.0272 - val_acc: 0.8060\n", "Epoch 75/200\n", "1/1 - 0s - loss: 0.0211 - acc: 0.9929 - val_loss: 1.0353 - val_acc: 0.8040\n", "Epoch 76/200\n", "1/1 - 0s - loss: 0.0355 - acc: 0.9857 - val_loss: 1.0439 - val_acc: 0.8020\n", "Epoch 77/200\n", "1/1 - 0s - loss: 0.0325 - acc: 0.9857 - val_loss: 1.0548 - val_acc: 0.7980\n", "Epoch 78/200\n", "1/1 - 0s - loss: 0.0235 - acc: 1.0000 - val_loss: 1.0655 - val_acc: 0.8000\n", "Epoch 79/200\n", "1/1 - 0s - loss: 0.0266 - acc: 0.9929 - val_loss: 1.0742 - val_acc: 0.8000\n", "Epoch 80/200\n", "1/1 - 0s - loss: 0.0585 - acc: 0.9857 
- val_loss: 1.0839 - val_acc: 0.8040\n", "Epoch 81/200\n", "1/1 - 0s - loss: 0.0626 - acc: 0.9857 - val_loss: 1.0925 - val_acc: 0.7980\n", "Epoch 82/200\n", "1/1 - 0s - loss: 0.0198 - acc: 1.0000 - val_loss: 1.1006 - val_acc: 0.7980\n", "Epoch 83/200\n", "1/1 - 0s - loss: 0.0259 - acc: 0.9929 - val_loss: 1.1047 - val_acc: 0.8000\n", "Epoch 84/200\n", "1/1 - 0s - loss: 0.0296 - acc: 0.9929 - val_loss: 1.1079 - val_acc: 0.8020\n", "Epoch 85/200\n", "1/1 - 0s - loss: 0.0236 - acc: 0.9929 - val_loss: 1.1077 - val_acc: 0.8060\n", "Epoch 86/200\n", "1/1 - 0s - loss: 0.0440 - acc: 0.9714 - val_loss: 1.1033 - val_acc: 0.8040\n", "Epoch 87/200\n", "1/1 - 0s - loss: 0.0324 - acc: 0.9929 - val_loss: 1.0994 - val_acc: 0.8020\n", "Epoch 88/200\n", "1/1 - 0s - loss: 0.0359 - acc: 0.9857 - val_loss: 1.0955 - val_acc: 0.8040\n" ] } ], "source": [ "history = model.fit(\n", " train_gen,\n", " epochs=200,\n", " validation_data=val_gen,\n", " verbose=2,\n", " shuffle=False, # this should be False, since shuffling data means shuffling the whole graph\n", " callbacks=[es_callback],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we've trained the model, we can view the behaviour of the loss function and any other metrics using the `plot_history` function ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.utils.plot_history)). In this case, we can see the loss and accuracy on both the training and validation sets." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sg.utils.plot_history(history)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As the final part of our evaluation, let's check the model against the test set. We again create the data required for this using the `flow` method on our `FullBatchNodeGenerator` from above, and can use the model's `evaluate` method ([docs](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate)) to compute the metric values for the trained model.\n", "\n", "As expected, the model performs similarly on the validation set during training and on the test set here." ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "test_gen = generator.flow(test_subjects.index, test_targets)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", "1/1 [==============================] - 0s 11ms/step - loss: 0.6904 - acc: 0.8298\n", "\n", "Test Set Metrics:\n", "\tloss: 0.6904\n", "\tacc: 0.8298\n" ] } ], "source": [ "test_metrics = model.evaluate(test_gen)\n", "print(\"\\nTest Set Metrics:\")\n", "for name, val in zip(model.metrics_names, test_metrics):\n", " print(\"\\t{}: {:0.4f}\".format(name, val))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Making predictions with the model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's get the predictions for all nodes. You're probably getting used to it by now, but we use our `FullBatchNodeGenerator` to create the input required and then use one of the model's methods: `predict` ([docs](https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict)). This time we _don't_ provide the labels to `flow`, and instead just the nodes, because we're trying to predict these classes without knowing them." 
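] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before running the prediction cell, it may help to see what its output looks like and how it is decoded. A full-batch model returns an array of shape `(1, n_nodes, n_classes)`, and for multiclass targets, `LabelBinarizer.inverse_transform` picks the class with the largest score in each row. The NumPy sketch below mimics both steps on made-up toy probabilities; the three class names and all variable names are hypothetical, not Cora outputs.

```python
import numpy as np

# Hypothetical class labels, in sorted order as LabelBinarizer keeps them in classes_.
classes = np.array(['Case_Based', 'Neural_Networks', 'Theory'])

# Toy softmax output for 3 nodes with a leading batch dimension of size 1,
# i.e. shape (1, n_nodes, n_classes) as produced by a full-batch model.
toy_predictions = np.array([[
    [0.10, 0.80, 0.10],
    [0.70, 0.20, 0.10],
    [0.05, 0.15, 0.80],
]])

# Drop only the batch dimension: shape (n_nodes, n_classes).
probs = toy_predictions.squeeze(axis=0)

# Pick the highest-probability class for each node.
predicted = classes[probs.argmax(axis=1)]
print(predicted.tolist())  # ['Neural_Networks', 'Case_Based', 'Theory']
```

Using `squeeze(axis=0)` rather than a bare `squeeze()` removes only the batch dimension, so it would not accidentally drop another axis of size 1 (for example, when predicting for a single node)."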
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "all_nodes = node_subjects.index\n", "all_gen = generator.flow(all_nodes)\n", "all_predictions = model.predict(all_gen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These predictions are the outputs of the softmax layer, so to get final categories we'll use the `inverse_transform` method of our `target_encoding` label binarizer to turn these values back into the original categories.\n", "\n", "Note that for full-batch methods the batch size is 1 and the predictions have shape $(1, N_{nodes}, N_{classes})$, so we remove the batch dimension to obtain predictions of shape $(N_{nodes}, N_{classes})$ using the NumPy `squeeze` method." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "node_predictions = target_encoding.inverse_transform(all_predictions.squeeze())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's have a look at a few predictions after training the model:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div><table>
<thead><tr><th></th><th>Predicted</th><th>True</th></tr></thead>
<tbody>
<tr><th>31336</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>1061127</th><td>Rule_Learning</td><td>Rule_Learning</td></tr>
<tr><th>1106406</th><td>Reinforcement_Learning</td><td>Reinforcement_Learning</td></tr>
<tr><th>13195</th><td>Reinforcement_Learning</td><td>Reinforcement_Learning</td></tr>
<tr><th>37879</th><td>Probabilistic_Methods</td><td>Probabilistic_Methods</td></tr>
<tr><th>1126012</th><td>Probabilistic_Methods</td><td>Probabilistic_Methods</td></tr>
<tr><th>1107140</th><td>Reinforcement_Learning</td><td>Theory</td></tr>
<tr><th>1102850</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>31349</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>1106418</th><td>Theory</td><td>Theory</td></tr>
<tr><th>1123188</th><td>Probabilistic_Methods</td><td>Neural_Networks</td></tr>
<tr><th>1128990</th><td>Reinforcement_Learning</td><td>Genetic_Algorithms</td></tr>
<tr><th>109323</th><td>Probabilistic_Methods</td><td>Probabilistic_Methods</td></tr>
<tr><th>217139</th><td>Case_Based</td><td>Case_Based</td></tr>
<tr><th>31353</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>32083</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>1126029</th><td>Reinforcement_Learning</td><td>Reinforcement_Learning</td></tr>
<tr><th>1118017</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>49482</th><td>Neural_Networks</td><td>Neural_Networks</td></tr>
<tr><th>753265</th><td>Theory</td><td>Neural_Networks</td></tr>
</tbody>
</table></div>
" ], "text/plain": [ " Predicted True\n", "31336 Neural_Networks Neural_Networks\n", "1061127 Rule_Learning Rule_Learning\n", "1106406 Reinforcement_Learning Reinforcement_Learning\n", "13195 Reinforcement_Learning Reinforcement_Learning\n", "37879 Probabilistic_Methods Probabilistic_Methods\n", "1126012 Probabilistic_Methods Probabilistic_Methods\n", "1107140 Reinforcement_Learning Theory\n", "1102850 Neural_Networks Neural_Networks\n", "31349 Neural_Networks Neural_Networks\n", "1106418 Theory Theory\n", "1123188 Probabilistic_Methods Neural_Networks\n", "1128990 Reinforcement_Learning Genetic_Algorithms\n", "109323 Probabilistic_Methods Probabilistic_Methods\n", "217139 Case_Based Case_Based\n", "31353 Neural_Networks Neural_Networks\n", "32083 Neural_Networks Neural_Networks\n", "1126029 Reinforcement_Learning Reinforcement_Learning\n", "1118017 Neural_Networks Neural_Networks\n", "49482 Neural_Networks Neural_Networks\n", "753265 Theory Neural_Networks" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.DataFrame({\"Predicted\": node_predictions, \"True\": node_subjects})\n", "df.head(20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Node embeddings\n", "\n", "In addition to predicting the node class, it can be useful to get a more detailed picture of what information the model has learnt about the nodes and their neighbourhoods. Here, this means an embedding of each node (also called a \"representation\") into a latent vector space that captures that information. Such an embedding comes in the form of either a look-up table mapping each node to a vector of numbers, or a neural network that produces those vectors. For GCN, we use the second option: the output of the last graph convolution layer of the GCN model (called `x_out` above), taken before the prediction layer is applied.\n", "\n", "We can visualise these embeddings as points on a plot, coloured by their true subject labels. 
If the model has learned information that distinguishes the classes, we expect to see clear clusters of papers in the node embedding space, with papers of the same subject belonging to the same cluster.\n", "\n", "To create a model that computes node embeddings, we use the same input tensors (`x_inp`) as the prediction model above, but take the GCN output tensor (`x_out`) as the output instead of the prediction layer. These tensors are connected to the same layers and weights that we trained for the predictions above, so this model needs no further training: we only use it to compute (\"predict\") the node embedding vectors. As with the predictions, we compute embeddings for every node using the `all_gen` data." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "embedding_model = Model(inputs=x_inp, outputs=x_out)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(1, 2708, 16)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "emb = embedding_model.predict(all_gen)\n", "emb.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The last GCN layer has output dimension 16, so each embedding consists of 16 numbers. Plotting these directly would require a 16-dimensional plot, which is hard for humans to visualise. Instead, we first project each vector down to just 2 numbers, which can be shown on an ordinary 2D scatter plot.\n", "\n", "There are many tools for this [dimensionality reduction](https://en.wikipedia.org/wiki/Dimensionality_reduction) task, and scikit-learn offers several of them. 
Two of the more common ones are [principal component analysis (PCA)](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html) (which is linear) and [t-distributed Stochastic Neighbor Embedding (t-SNE or TSNE)](https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html) (non-linear). t-SNE is slower, but because it focuses on preserving local neighbourhoods it typically gives clearer cluster plots." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "from sklearn.decomposition import PCA\n", "from sklearn.manifold import TSNE\n", "\n", "transform = TSNE  # or PCA" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the embeddings from the GCN model have a batch dimension of size 1, so we `squeeze` it out to get a matrix of shape $N_{nodes} \\times N_{emb}$." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2708, 16)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X = emb.squeeze(0)\n", "X.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have now prepared our high-dimensional embeddings and chosen our dimensionality-reduction transform, so we can compute the reduced vectors: an array with two columns and one row per node." ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(2708, 2)" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trans = transform(n_components=2)\n", "X_reduced = trans.fit_transform(X)\n", "X_reduced.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `X_reduced` array contains a pair of numbers for each node, in the same order as the `node_subjects` Series of ground-truth labels (because that is how `all_gen` was created). This is enough to draw a scatter plot of the nodes, coloured by subject. 
We can let matplotlib pick the colours by mapping the seven subjects to the integers 0, 1, ..., 6, using [Pandas' support for categorical data](https://pandas.pydata.org/pandas-docs/stable/user_guide/categorical.html).\n", "\n", "Qualitatively, the plot shows good clustering: nodes of a single colour are mostly grouped together." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Text(0, 0.5, '$X_2$'),\n", " Text(0.5, 0, '$X_1$'),\n", " Text(0.5, 1.0, 'TSNE visualization of GCN embeddings for cora dataset'),\n", " None]" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, ax = plt.subplots(figsize=(7, 7))\n", "ax.scatter(\n", "    X_reduced[:, 0],\n", "    X_reduced[:, 1],\n", "    c=node_subjects.astype(\"category\").cat.codes,\n", "    cmap=\"jet\",\n", "    alpha=0.7,\n", ")\n", "ax.set(\n", "    aspect=\"equal\",\n", "    xlabel=\"$X_1$\",\n", "    ylabel=\"$X_2$\",\n", "    title=f\"{transform.__name__} visualization of GCN embeddings for cora dataset\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "This notebook gave an example of using the GCN algorithm to predict the class of nodes, specifically the subject of an academic paper in the Cora dataset. Our model used:\n", "\n", "- the graph structure of the dataset, in the form of citation links between papers\n", "- the 1433-dimensional feature vectors associated with each paper\n", "\n", "Once we had trained a model for prediction, we could:\n", "\n", "- predict the classes of nodes\n", "- use the model's weights to compute vector embeddings for nodes\n", "\n", "This notebook ran through the following steps:\n", "\n", "1. prepared the data using common data science libraries\n", "2. built a TensorFlow Keras model and data generator with [the StellarGraph library](https://github.com/stellargraph/stellargraph)\n", "3. trained and evaluated it using TensorFlow and other libraries\n", "\n", "For problems with only small amounts of labelled data, model performance can be improved by semi-supervised training. See [the GCN + Deep Graph Infomax fine-tuning demo](gcn-deep-graph-infomax-fine-tuning-node-classification.ipynb) for more details on how to do this.\n", "\n", "StellarGraph includes [other algorithms for node classification](README.md) and [algorithms and demos for other tasks](../README.md). Most can be applied with the same basic structure as this GCN demo."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 4 }