{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Ensemble models for link prediction" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this example, we use the `BaggingEnsemble` class from `stellargraph` to build an ensemble of [GraphSAGE](http://snap.stanford.edu/graphsage/) models that predict citation links in the Cora dataset (see below). The `BaggingEnsemble` class brings ensemble learning to `stellargraph`'s graph neural network models, such as `GraphSAGE`. This makes it possible to quantify prediction variance and can improve prediction accuracy. \n", "\n", "We treat this as a supervised link prediction problem on a homogeneous citation network. Nodes represent papers, with attributes such as binary keyword indicators and categorical subject, and links correspond to paper-paper citations. \n", "\n", "To address this problem, we build a base `GraphSAGE` model with the following architecture. First, a two-layer GraphSAGE model takes labeled `(paper1, paper2)` node pairs corresponding to possible citation links, and outputs a pair of node embeddings for the `paper1` and `paper2` nodes of each pair. These embeddings are then fed into a link classification layer. This layer first applies a binary operator to the two node embeddings (e.g., their inner product or their concatenation) to construct the embedding of the potential link. It then maps the resulting link embeddings to link predictions, that is, the probability that each candidate link actually exists in the network. The entire model is trained end-to-end by minimizing a loss function of choice (e.g., binary cross-entropy between predicted link probabilities and true link labels, with true/false citation links having labels 1/0). Training uses stochastic gradient descent (SGD) updates of the model parameters, with minibatches of training links fed into the model.\n", "\n", "Finally, we use the base model to create an ensemble in which each model is trained on a bootstrap sample of the training data. \n", "\n", "**References**\n", "\n", "1. 
Inductive Representation Learning on Large Graphs. W.L. Hamilton, R. Ying, and J. Leskovec arXiv:1706.02216 \n", "[cs.SI], 2017." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "outputs": [], "source": [ "# install StellarGraph if running on Google Colab\n", "import sys\n", "if 'google.colab' in sys.modules:\n", " %pip install -q stellargraph[demos]==1.1.0" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "nbsphinx": "hidden", "tags": [ "VersionCheck" ] }, "outputs": [], "source": [ "# verify that we're using the correct version of StellarGraph for this notebook\n", "import stellargraph as sg\n", "\n", "try:\n", " sg.utils.validate_notebook_version(\"1.1.0\")\n", "except AttributeError:\n", " raise ValueError(\n", " f\"This notebook requires StellarGraph version 1.1.0, but a different version {sg.__version__} is installed. Please see .\"\n", " ) from None" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import networkx as nx\n", "import pandas as pd\n", "import numpy as np\n", "from tensorflow import keras\n", "import os\n", "\n", "import stellargraph as sg\n", "from stellargraph.data import EdgeSplitter\n", "from stellargraph.mapper import GraphSAGELinkGenerator\n", "from stellargraph.layer import GraphSAGE, link_classification\n", "from stellargraph.ensemble import BaggingEnsemble\n", "\n", "from sklearn import preprocessing, feature_extraction, model_selection\n", "\n", "from stellargraph import globalvar\n", "from stellargraph import datasets\n", "from IPython.display import display, HTML\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Loading the CORA network data" ] }, { "cell_type": "markdown", "metadata": { "tags": [ "DataLoadingLinks" ] }, "source": [ "(See [the \"Loading from Pandas\" demo](../basics/loading-pandas.ipynb) for details on how data 
can be loaded.)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [ "DataLoading" ] }, "outputs": [ { "data": { "text/html": [ "The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "dataset = datasets.Cora()\n", "display(HTML(dataset.description))\n", "G, _subjects = dataset.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We aim to train a link prediction model, so we need to prepare the train, validation and test sets of links and the corresponding graphs with those links removed.\n", "\n", "We are going to split our input graph into train, validation and test graphs using the `EdgeSplitter` class in `stellargraph.data`. The model is a binary classifier that, given two nodes, predicts whether a link between them should exist. We will use the train graph to train it, the validation graph for early stopping, and the test graph to evaluate its performance on held-out data.\n", "\n", "Each of these graphs will have the same number of nodes as the input graph, but fewer links. At each split, some of the links are removed and used as the positive samples for training, validating or testing the link prediction classifier." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From the original graph G, extract a randomly sampled subset of test edges (true and false citation links) and the reduced graph G_test with the positive test edges removed:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "** Sampled 542 positive and 542 negative edges. 
**\n" ] } ], "source": [ "# Define an edge splitter on the original graph G:\n", "edge_splitter_test = EdgeSplitter(G)\n", "\n", "# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G, and obtain the\n", "# reduced graph G_test with the sampled links removed:\n", "G_test, edge_ids_test, edge_labels_test = edge_splitter_test.train_test_split(\n", " p=0.1, method=\"global\", keep_connected=True, seed=42\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The reduced graph G_test, together with the test ground truth set of links (edge_ids_test, edge_labels_test), will be used for testing the model.\n", "\n", "Now, repeat this procedure to obtain validation data, which we will use for early stopping to prevent overfitting. From the reduced graph G_test, extract a randomly sampled subset of validation edges (true and false citation links) and the reduced graph G_val with the positive validation edges removed." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "** Sampled 488 positive and 488 negative edges. 
**\n" ] } ], "source": [ "# Define an edge splitter on the reduced graph G_test:\n", "edge_splitter_val = EdgeSplitter(G_test)\n", "\n", "# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G_test, and obtain the\n", "# reduced graph G_val with the sampled links removed:\n", "G_val, edge_ids_val, edge_labels_val = edge_splitter_val.train_test_split(\n", " p=0.1, method=\"global\", keep_connected=True, seed=100\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We repeat this procedure one last time in order to obtain the training data for the model.\n", "From the reduced graph G_val, extract a randomly sampled subset of train edges (true and false citation links) and the reduced graph G_train with the positive train edges removed:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "** Sampled 488 positive and 488 negative edges. **\n" ] } ], "source": [ "# Define an edge splitter on the reduced graph G_val:\n", "edge_splitter_train = EdgeSplitter(G_val)\n", "\n", "# Randomly sample a fraction p=0.1 of all positive links, and the same number of negative links, from G_val, and obtain the\n", "# reduced graph G_train with the sampled links removed:\n", "G_train, edge_ids_train, edge_labels_train = edge_splitter_train.train_test_split(\n", " p=0.1, method=\"global\", keep_connected=True, seed=42\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "G_train, together with the train ground truth set of links (edge_ids_train, edge_labels_train), will be used for training the model." 
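, "\n", "\n", "The nested splitting above can be sketched in plain Python. The helper `split_edges` and the toy random graph are illustrative assumptions. They form a simplified, hypothetical stand-in for `EdgeSplitter.train_test_split`, not the `stellargraph` implementation: the sketch ignores `keep_connected` and samples negatives uniformly from node pairs that are not edges of the graph being split. It shows why each split should start from the graph reduced by the previous one, because the test, validation and training positive edges then end up disjoint.
```python
import random


def split_edges(edges, nodes, p, rng):
    # Hypothetical helper (not the stellargraph API): remove a fraction p of the
    # positive edges and sample the same number of negative node pairs, i.e.
    # pairs that are not edges of the graph being split.
    edge_list = sorted(edges)
    n = round(p * len(edge_list))
    positives = set(rng.sample(edge_list, n))
    reduced = set(edge_list) - positives
    negatives = set()
    while len(negatives) < n:
        u, v = rng.sample(nodes, 2)
        pair = (min(u, v), max(u, v))
        if pair not in edges:
            negatives.add(pair)
    return reduced, positives, negatives


rng = random.Random(42)
nodes = list(range(50))
all_pairs = [(u, v) for u in nodes for v in nodes if u < v]
graph = set(rng.sample(all_pairs, 300))  # toy undirected graph with 300 edges

# Each split starts from the graph reduced by the previous split.
g_test, test_pos, test_neg = split_edges(graph, nodes, 0.1, rng)
g_val, val_pos, val_neg = split_edges(g_test, nodes, 0.1, rng)
g_train, train_pos, train_neg = split_edges(g_val, nodes, 0.1, rng)

print(len(graph), len(g_test), len(g_val), len(g_train))
```
Whatever the random seed, the printed edge counts are 300 270 243 219, because each split removes 10% (rounded) of the edges remaining at that stage."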
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Summaries of G_train, G_test and G_val: note that all three graphs have the same set of nodes and differ only in their edge sets:" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 2708, Edges: 4399\n", "\n", " Node types:\n", " paper: [2708]\n", " Features: float32 vector, length 1433\n", " Edge types: paper-cites->paper\n", "\n", " Edge types:\n", " paper-cites->paper: [4399]\n" ] } ], "source": [ "print(G_train.info())" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 2708, Edges: 4887\n", "\n", " Node types:\n", " paper: [2708]\n", " Features: float32 vector, length 1433\n", " Edge types: paper-cites->paper\n", "\n", " Edge types:\n", " paper-cites->paper: [4887]\n" ] } ], "source": [ "print(G_test.info())" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "StellarGraph: Undirected multigraph\n", " Nodes: 2708, Edges: 4399\n", "\n", " Node types:\n", " paper: [2708]\n", " Features: float32 vector, length 1433\n", " Edge types: paper-cites->paper\n", "\n", " Edge types:\n", " paper-cites->paper: [4399]\n" ] } ], "source": [ "print(G_val.info())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Specify global parameters\n", "\n", "Here we specify some important parameters that control the type of ensemble model we are going to use. For example, we specify the number of models in the ensemble and the number of predictions per query point per model." 
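, "\n", "\n", "With these settings, each query point receives n_estimators × n_predictions = 5 × 10 = 50 predictions, which can be summarised by their mean and spread. The sketch below illustrates only that bookkeeping. `hypothetical_predict` is an invented stand-in for one stochastic forward pass of one estimator, not a `stellargraph` function, and its numbers are made up. In the real model, repeated passes differ because GraphSAGE samples neighbourhoods at random.
```python
import random
import statistics

n_estimators = 5  # number of models in the ensemble
n_predictions = 10  # number of predictions per query point per model


def hypothetical_predict(estimator_id, rng):
    # Invented stand-in for one stochastic forward pass: each estimator has its
    # own bias, and each pass adds sampling noise. Output is clipped to [0, 1].
    base = 0.6 + 0.05 * estimator_id
    return min(1.0, max(0.0, base + rng.gauss(0.0, 0.05)))


rng = random.Random(0)
predictions = [
    [hypothetical_predict(m, rng) for _ in range(n_predictions)]
    for m in range(n_estimators)
]
flat = [p for per_model in predictions for p in per_model]
print(len(flat), round(statistics.mean(flat), 3), round(statistics.stdev(flat), 3))
```
A larger spread across these 50 values signals less agreement within the ensemble for that query point."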
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "n_estimators = 5 # Number of models in the ensemble\n", "n_predictions = 10 # Number of predictions per query point per model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we create link generators for sampling and streaming train, validation and test link examples to the model. The link generators essentially \"map\" pairs of nodes `(paper1, paper2)` to the input of GraphSAGE. They take minibatches of node pairs and sample 2-hop subgraphs with `(paper1, paper2)` head nodes extracted from those pairs. They then feed these subgraphs, together with binary labels indicating whether each pair is a true or false citation link, to the input layer of the GraphSAGE model for SGD updates of the model parameters.\n", "\n", "Specify the minibatch size (number of node pairs per minibatch) and the number of epochs for training the model:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "batch_size = 20\n", "epochs = 20" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specify the sizes of 1- and 2-hop neighbour samples for GraphSAGE. Note that the length of the `num_samples` list defines the number of layers/iterations in the GraphSAGE model. In this example, we are defining a 2-layer GraphSAGE model:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "num_samples = [20, 10]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the generators for training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For training we create a generator on the `G_train` graph. We pass `shuffle=True` to the `flow` method so that the training links are presented in a new random order in each epoch, rather than as the same sequence of minibatches every time." 
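, "\n", "\n", "As a worked example of the minibatch arithmetic: with, say, 976 training links and a `batch_size` of 20, an epoch consists of ceil(976 / 20) = 49 minibatches (48 full ones and a final one of 16 links). The helper `minibatches` below is hypothetical, not the `stellargraph` implementation. It only illustrates what shuffling changes: the grouping of links into minibatches is redrawn, while every link is still used exactly once per epoch.
```python
import math
import random


def minibatches(pairs, labels, batch_size, shuffle, rng):
    # Hypothetical sketch: yield (node_pairs, labels) minibatches, optionally
    # visiting the examples in a freshly shuffled order.
    order = list(range(len(pairs)))
    if shuffle:
        rng.shuffle(order)
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [pairs[i] for i in idx], [labels[i] for i in idx]


pairs = [(i, i + 1) for i in range(976)]  # made-up (paper1, paper2) node pairs
labels = [1] * 488 + [0] * 488  # made-up true/false link labels
rng = random.Random(1)

batches = list(minibatches(pairs, labels, 20, True, rng))
print(len(batches), math.ceil(976 / 20), len(batches[-1][0]))
```
"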
] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "generator = GraphSAGELinkGenerator(G_train, batch_size, num_samples)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "train_gen = generator.flow(edge_ids_train, edge_labels_train, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At test and validation time we use the `G_test` and `G_val` graphs, respectively, and don't specify the `shuffle` argument (it defaults to `False`)." ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "test_gen = GraphSAGELinkGenerator(G_test, batch_size, num_samples).flow(\n", " edge_ids_test, edge_labels_test\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "val_gen = GraphSAGELinkGenerator(G_val, batch_size, num_samples).flow(\n", " edge_ids_val, edge_labels_val\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create the base GraphSAGE model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Build the model: a 2-layer GraphSAGE model acting as a node representation learner, with a link classification layer that combines the `(paper1, paper2)` node embeddings (here via their inner product).\n", "\n", "The GraphSAGE part of the model has hidden layer sizes of 20 for both GraphSAGE layers, a bias term, and a dropout rate of 0.5. (Dropout can be switched off with `dropout=0`; any rate 0 < dropout < 1 switches it on.)\n", "\n", "Note that the length of the `layer_sizes` list must be equal to the length of `num_samples`, as `len(num_samples)` defines the number of hops (layers) in the GraphSAGE model." 
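, "\n", "\n", "The following sketch assumes fixed-size neighbour sampling as in the GraphSAGE paper, where each node draws exactly `num_samples[k]` neighbours at hop k + 1. Under that assumption, the amount of data fed to the model per minibatch can be worked out directly. The hypothetical helper `nodes_per_head` below counts the sampled nodes for one head node: the node itself, 20 one-hop neighbours and 20 × 10 = 200 two-hop neighbours, i.e. 221 in total. Each link has two head nodes, so a minibatch of 20 links carries 20 × 2 × 221 = 8840 node feature vectors.
```python
def nodes_per_head(num_samples):
    # Hypothetical helper: number of nodes sampled for one head node when hop k
    # draws num_samples[k - 1] neighbours for every node sampled at hop k - 1.
    total, width = 1, 1  # start with the head node itself
    for k in num_samples:
        width *= k  # number of nodes sampled at this hop
        total += width
    return total


num_samples = [20, 10]
layer_sizes = [20, 20]
# One GraphSAGE layer per hop, so the two lists must have the same length.
assert len(layer_sizes) == len(num_samples)

batch_size = 20
per_head = nodes_per_head(num_samples)
per_batch = batch_size * 2 * per_head  # two head nodes per candidate link
print(per_head, per_batch)
```
"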
] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "layer_sizes = [20, 20]\n", "assert len(layer_sizes) == len(num_samples)\n", "\n", "graphsage = GraphSAGE(\n", " layer_sizes=layer_sizes, generator=generator, bias=True, dropout=0.5\n", ")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# Build the model and expose the input and output tensors.\n", "x_inp, x_out = graphsage.in_out_tensors()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The final link classification layer takes the pair of node embeddings produced by `graphsage` and applies a binary operator to them to produce the corresponding link embedding. Here we use 'ip', the inner product; other options for the binary operator can be seen by running a cell with `?link_classification` in it. A `sigmoid` output activation turns the result into a link probability, as expected by the binary cross-entropy loss:" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "link_classification: using 'ip' method to combine node embeddings into edge embeddings\n" ] } ], "source": [ "prediction = link_classification(\n", " output_dim=1, output_act=\"sigmoid\", edge_embedding_method=\"ip\"\n", ")(x_out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Stack the GraphSAGE and prediction layers into a Keras model." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "base_model = keras.Model(inputs=x_inp, outputs=prediction)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we create the ensemble based on the `base_model` we just created." 
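, "\n", "\n", "Each estimator in the ensemble shares the architecture of `base_model`. The plain-Python sketch below shows what its head computes for one candidate link, assuming the 'ip' operator followed by a sigmoid output activation. `ip_link_probability` and `ip_link_relu` are hypothetical helpers, and the embedding values are made up. The inner product of the two node embeddings is squashed into (0, 1); a sigmoid keeps the output in that range, as binary cross-entropy expects, whereas a ReLU output is unbounded above.
```python
import math


def ip_link_probability(emb_u, emb_v):
    # Inner-product edge embedding followed by a sigmoid output activation.
    score = sum(a * b for a, b in zip(emb_u, emb_v))
    return 1.0 / (1.0 + math.exp(-score))


def ip_link_relu(emb_u, emb_v):
    # Same inner product, but with a ReLU output activation for comparison.
    return max(0.0, sum(a * b for a, b in zip(emb_u, emb_v)))


paper1 = [0.5, -1.0, 2.0]  # made-up node embeddings
paper2 = [1.0, 0.5, 0.25]
print(ip_link_probability(paper1, paper2))  # sigmoid of the inner product 0.5
print(ip_link_relu([2.0, 2.0], [1.0, 1.0]))  # 4.0, not a valid probability
```
"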
] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "model = BaggingEnsemble(\n", " model=base_model, n_estimators=n_estimators, n_predictions=n_predictions\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to `compile` the model specifying the optimiser, loss function, and metrics to use." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "model.compile(\n", " optimizer=keras.optimizers.Adam(lr=1e-3),\n", " loss=keras.losses.binary_crossentropy,\n", " weighted_metrics=[\"acc\"],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate the initial (untrained) ensemble of models on the train and test set:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " 
['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", "\n", "Train Set Metrics of the initial (untrained) model:\n", "\tloss: 0.6999±0.0411\n", "\tacc: 0.6063±0.0191\n", "\n", "Test Set Metrics of the initial (untrained) model:\n", "\tloss: 0.7105±0.0467\n", "\tacc: 0.6152±0.0269\n" ] } ], "source": [ "init_train_metrics_mean, init_train_metrics_std = model.evaluate(train_gen)\n", "init_test_metrics_mean, init_test_metrics_std = model.evaluate(test_gen)\n", "\n", "print(\"\\nTrain Set Metrics of the initial (untrained) model:\")\n", "for name, m, s in zip(\n", " model.metrics_names, init_train_metrics_mean, init_train_metrics_std\n", "):\n", " print(\"\\t{}: {:0.4f}±{:0.4f}\".format(name, m, s))\n", "\n", "print(\"\\nTest Set Metrics of the initial (untrained) model:\")\n", "for name, m, s in zip(model.metrics_names, init_test_metrics_mean, init_test_metrics_std):\n", " print(\"\\t{}: {:0.4f}±{:0.4f}\".format(name, m, s))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Train the ensemble model\n", "\n", "We are going to use **bootstrap samples** of the training dataset to train each model in the ensemble. For this purpose, we need to pass `generator`, `edge_ids_train`, and `edge_labels_train` to the `fit` method.\n", "\n", "Note that training time will vary based on computer speed. Set `verbose=1` for reporting of training progress." 
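, "\n", "\n", "A bootstrap sample of a training set with n links is conventionally drawn as n links picked uniformly at random with replacement. The sketch below uses the hypothetical helper `bootstrap_indices` to illustrate standard bootstrap sampling, not the `BaggingEnsemble` internals. On average, each estimator sees only about 1 - 1/e ≈ 63.2% of the distinct training links, so the estimators are trained on different subsets of the data.
```python
import random


def bootstrap_indices(n, rng):
    # Draw n indices uniformly at random with replacement (a bootstrap sample).
    return [rng.randrange(n) for _ in range(n)]


rng = random.Random(7)
n_links = 1000  # hypothetical size of the training set
n_estimators = 5

samples = [bootstrap_indices(n_links, rng) for _ in range(n_estimators)]
unique_fractions = [len(set(s)) / n_links for s in samples]
print([round(f, 3) for f in unique_fractions])
```
"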
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n" ] } ], "source": [ "history = model.fit(\n", " generator=generator,\n", " train_data=edge_ids_train,\n", " train_targets=edge_labels_train,\n", " epochs=epochs,\n", " validation_data=val_gen,\n", " verbose=0,\n", " use_early_stopping=True, # Enable early stopping\n", " early_stopping_monitor=\"val_acc\",\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the training history:" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "sg.utils.plot_history(history)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluate the trained model on test citation links. After training the model, performance should be better than before training (shown above):" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", " ['...']\n", "\n", "Train Set Metrics of the trained model:\n", "\tloss: 0.3113±0.0447\n", "\tacc: 0.9111±0.0099\n", "\n", "Test Set Metrics 
of the trained model:\n", "\tloss: 0.6903±0.1061\n", "\tacc: 0.7785±0.0141\n" ] } ], "source": [ "train_metrics_mean, train_metrics_std = model.evaluate(train_gen)\n", "test_metrics_mean, test_metrics_std = model.evaluate(test_gen)\n", "\n", "print(\"\\nTrain Set Metrics of the trained model:\")\n", "for name, m, s in zip(model.metrics_names, train_metrics_mean, train_metrics_std):\n", " print(\"\\t{}: {:0.4f}±{:0.4f}\".format(name, m, s))\n", "\n", "print(\"\\nTest Set Metrics of the trained model:\")\n", "for name, m, s in zip(model.metrics_names, test_metrics_mean, test_metrics_std):\n", " print(\"\\t{}: {:0.4f}±{:0.4f}\".format(name, m, s))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Make predictions with the model\n", "\n", "Now let's get the predictions for all the edges in the test set." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "test_predictions = model.predict(generator=test_gen)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These predictions will be the output of the last layer in the model with `sigmoid` activation.\n", "\n", "The array `test_predictions` has dimensionality $M\\times K\\times N\\times F$, where:\n", "\n", "- $M$ is the number of estimators in the ensemble (`n_estimators`);\n", "- $K$ is the number of predictions per query point per estimator (`n_predictions`);\n", "- $N$ is the number of query points, i.e. test links (`test_predictions.shape[2]`; note that `len(test_predictions)` returns $M$, not $N$);\n", "- $F$ is the output dimensionality of the model's output layer (in this case it is equal to 1, since we are performing binary classification)." 
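, "\n", "\n", "To make the axis order concrete, the sketch below builds a nested-list stand-in with the same (M, K, N, F) layout, using small made-up sizes and random values rather than real model output. It shows that `len()` returns the size of the first axis, M, whereas the number of query points is the size of the third axis (`test_predictions.shape[2]` for the NumPy array). It also averages over the first two axes to get one ensemble-mean probability per query point.
```python
import random

M, K, N, F = 5, 10, 4, 1  # made-up sizes: estimators, predictions, query points, outputs
rng = random.Random(3)
preds = [
    [[[rng.random() for _ in range(F)] for _ in range(N)] for _ in range(K)]
    for _ in range(M)
]

print(len(preds))  # size of the first axis, M
print(len(preds[0][0]))  # size of the third axis, N

# All M x K predictions for one query point j (shape M x K x F).
j = 2
point = [[preds[m][k][j] for k in range(K)] for m in range(M)]

# Ensemble mean of the first output for every query point.
means = [
    sum(preds[m][k][n][0] for m in range(M) for k in range(K)) / (M * K)
    for n in range(N)
]
```
"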
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(numpy.ndarray, (5, 10, 1084, 1))" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(test_predictions), test_predictions.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For demonstration, we are going to select one of the edges in the test set and plot the ensemble's predictions for that edge.\n", "\n", "Change the value of `selected_query_point` to visualise the results for another test point. Valid values are the integers from `0` to `test_predictions.shape[2] - 1`; negative values, such as the `-10` used in the next cell, count back from the last test link." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "selected_query_point = -10" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5, 10, 1)" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Select the predictions for the point specified by selected_query_point\n", "qp_predictions = test_predictions[:, :, selected_query_point, :]\n", "# The shape should be n_estimators x n_predictions x size_output_layer\n", "qp_predictions.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, to facilitate plotting the predictions using either a density plot or a box plot, we are going to reshape `qp_predictions` to $R\\times F$, where $R = M\\times K$ as above and $F$ is the output dimensionality of the output layer." 
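, "\n", "\n", "The flattening from M × K × F to R × F, followed by adding the complementary no-edge column, can be sketched in plain Python as below. The predictions are made-up random values. `statistics.quantiles` approximates the quartiles that a box plot draws, although its quantile convention may differ slightly from matplotlib's.
```python
import random
import statistics

M, K = 5, 10  # n_estimators and n_predictions
rng = random.Random(5)
# Made-up predictions for one query point, shape M x K x F with F = 1.
qp = [[[min(1.0, max(0.0, rng.gauss(0.7, 0.1)))] for _ in range(K)] for _ in range(M)]

# Flatten the first two axes: R x F with R = M * K.
flat = [row for per_model in qp for row in per_model]

# Probability of an edge and its complement, the probability of no edge.
edge = [row[0] for row in flat]
no_edge = [1.0 - p for p in edge]

q1, median, q3 = statistics.quantiles(edge, n=4)
print(len(flat), round(median, 3), round(q3 - q1, 3))
```
"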
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(50, 1)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qp_predictions = qp_predictions.reshape(\n", " np.prod(qp_predictions.shape[0:-1]), qp_predictions.shape[-1]\n", ")\n", "qp_predictions.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model returns the probability that an edge exists, which is the class we predict. The probability of no edge is simply its complement. Let's compute it so that we can plot the distribution of predictions for both outcomes." ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "qp_predictions = np.hstack((qp_predictions, 1.0 - qp_predictions))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'd like to assess the ensemble's confidence in its predictions in order to decide whether we can trust them. Using a box plot, we can visually inspect the ensemble's distribution of predicted probabilities for a point in the test set.\n", "\n", "If the spread of values for the predicted class is well separated from that of the other class, with little overlap, then the ensemble is confident in its prediction, although a confident prediction can still be wrong." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0.5, 0, 'Class')" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "correct_label = \"Edge\"\n", "if edge_labels_test[selected_query_point] == 0:\n", " correct_label = \"No Edge\"\n", "\n", "fig, ax = plt.subplots(figsize=(12, 6))\n", "ax.boxplot(x=qp_predictions)\n", "ax.set_xticklabels([\"Edge\", \"No Edge\"])\n", "ax.tick_params(axis=\"x\", rotation=45)\n", "plt.title(\"Correct label is \" + correct_label)\n", "plt.ylabel(\"Predicted Probability\")\n", "plt.xlabel(\"Class\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the selected pair of nodes (query point), the ensemble is not certain whether an edge between these two nodes should exist. This can be inferred from the large spread of values in the figure above.\n", "\n", "(Note that due to the stochastic nature of training neural network algorithms, the above conclusion may not hold if you re-run the notebook; however, the general point that ensemble learning can be used to quantify the model's uncertainty about its predictions still holds.)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The image below shows an example of the classifier making a correct prediction with higher confidence than the example above. The result is for the setting `selected_query_point=0`." ] }, { "attachments": { "image.png": { "image/png": "" } }, "cell_type": "markdown", "metadata": {}, "source": [ "![image.png](attachment:image.png)" ] }, { "cell_type": "markdown", "metadata": { "nbsphinx": "hidden", "tags": [ "CloudRunner" ] }, "source": [ "
Run the latest release of this notebook:
" ] } ], "metadata": { "file_extension": ".py", "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.5" }, "mimetype": "text/x-python", "name": "python", "npconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 3 }, "nbformat": 4, "nbformat_minor": 4 }