{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Node representations with GraphWave\n" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This demo features the algorithm GraphWave published in \"Learning Structural Node Embeddings via Diffusion Wavelets\" [https://arxiv.org/pdf/1710.10321.pdf]. GraphWave embeds the structural features of a node in a dense embeddings. We will demonstrate the use of GraphWave on a barbell graph demonstrating that structurally equivalent nodes have similar embeddings." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we load the required libraries." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "outputs": [], "source": [ "# install StellarGraph if running on Google Colab\n", "import sys\n", "if 'google.colab' in sys.modules:\n", " %pip install -q stellargraph[demos]==1.2.1" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [ "VersionCheck" ] }, "outputs": [], "source": [ "# verify that we're using the correct version of StellarGraph for this notebook\n", "import stellargraph as sg\n", "\n", "try:\n", " sg.utils.validate_notebook_version(\"1.2.1\")\n", "except AttributeError:\n", " raise ValueError(\n", " f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed. Please see .\"\n", " ) from None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import networkx as nx\n", "from stellargraph.mapper import GraphWaveGenerator\n", "from stellargraph import StellarGraph\n", "from sklearn.decomposition import PCA\n", "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from scipy.sparse.linalg import eigs\n", "import tensorflow as tf\n", "from tensorflow.keras import backend as K" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Graph construction\n", "\n", "Next, we construct the barbell graph, shown below. 
It consists of two fully connected graphs (at the 'ends' of the graph) connected by a chain of nodes. All nodes in the fully connected ends are structurally equivalent, apart from the two that attach to the chain, and nodes at mirrored positions along the chain are equivalent to each other. A good structural embedding algorithm should embed equivalent nodes close together in the embedding space." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "m1 = 9\n", "m2 = 11\n", "gnx = nx.barbell_graph(m1=m1, m2=m2)\n", "\n", "classes = [0,] * len(gnx.nodes)\n", "# number of nodes with a non-zero class (the path, plus the nodes it connects to on either end)\n", "nonzero = m2 + 2\n", "# count up to the halfway point (rounded up)\n", "first = range(1, (nonzero + 1) // 2 + 1)\n", "# and down for the rest\n", "second = reversed(range(1, nonzero - len(first) + 1))\n", "classes[m1 - 1 : (m1 + m2) + 1] = list(first) + list(second)\n", "\n", "nx.draw(gnx, node_color=classes, cmap=\"jet\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## GraphWave embedding calculation\n", "\n", "Now, we're ready to calculate the GraphWave embeddings. We need to specify some information about the approximation to use: \n", "\n", "- an iterable of wavelet `scales` to use. This is a graph and task dependent hyperparameter. Larger scales extract larger scale features and smaller scales extract more local structural features. Experiment with different values. \n", "- the `sample_points` at which to sample the characteristic function. This should be of the form `sample_points=np.linspace(0, max_val, number_of_samples)`. 
The best values for `max_val` and `number_of_samples` depend on the graph.\n", "- the `degree` of the Chebyshev polynomial used to approximate the wavelets\n", "\n", "The dimension of the embeddings is `2 * len(scales) * len(sample_points)`; with the values used below, that is `2 * 2 * 50 = 200`." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "G = StellarGraph.from_networkx(gnx)\n", "sample_points = np.linspace(0, 100, 50).astype(np.float32)\n", "degree = 20\n", "scales = [5, 10]\n", "\n", "generator = GraphWaveGenerator(G, scales=scales, degree=degree)\n", "\n", "embeddings_dataset = generator.flow(\n", " node_ids=G.nodes(), sample_points=sample_points, batch_size=1, repeat=False\n", ")\n", "\n", "embeddings = [x.numpy() for x in embeddings_dataset]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualisation\n", "\n", "The nodes are coloured by their structural role, e.g. in the fully connected sections, first node in the chain, second node in the chain, and so on. All nodes of the same colour completely overlap in this visualisation, indicating that structurally equivalent nodes are very close in the embedding space.\n", "\n", "The plot here doesn't exactly match the one in the paper. We think this is because the details of approximating the wavelet diffusion differ: the paper uses `pygsp` to calculate the Chebyshev coefficients, while `StellarGraph` uses `numpy`. Some brief experiments have shown that the `numpy` Chebyshev coefficients are more accurate than the `pygsp` coefficients." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "trans_emb = PCA(n_components=2).fit_transform(np.vstack(embeddings))\n", "\n", "plt.scatter(\n", " trans_emb[:, 0], trans_emb[:, 1], c=classes, cmap=\"jet\", alpha=0.7,\n", ")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }