{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Label propagation and label spreading on graphs\n",
    "\n",
    "This notebook implements two graph-based semi-supervised learning methods in PyTorch, label propagation and label spreading, and compares them on a synthetic connected caveman graph in which only one node per clique is labeled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from abc import ABC, abstractmethod\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Base class\n",
    "\n",
    "The base class holds the machinery shared by both methods:\n",
    "\n",
    "- one-hot encoding of the known labels;\n",
    "- the iteration loop, which stops once the total absolute change of the predictions between two consecutive steps falls below `tol`.\n",
    "\n",
    "Subclasses only define how the adjacency matrix is normalized and what a single propagation step does."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BaseLabelPropagation(ABC):\n",
    "    \"\"\"Base class for label propagation models.\n",
    "\n",
    "    Subclasses implement `_normalize` (adjacency normalization) and `_propagate`\n",
    "    (one propagation step); this class handles label encoding and the iteration loop.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    adj_matrix: torch.FloatTensor\n",
    "        Adjacency matrix of the graph, of size (n_nodes, n_nodes).\n",
    "    \"\"\"\n",
    "    def __init__(self, adj_matrix):\n",
    "        self.norm_adj_matrix = self._normalize(adj_matrix)\n",
    "        self.n_nodes = adj_matrix.size(0)\n",
    "        self.one_hot_labels = None\n",
    "        self.n_classes = None\n",
    "        self.labeled_mask = None\n",
    "        self.predictions = None\n",
    "\n",
    "    @staticmethod\n",
    "    @abstractmethod\n",
    "    def _normalize(adj_matrix):\n",
    "        \"\"\"Returns the normalized adjacency matrix used by `_propagate`.\"\"\"\n",
    "        raise NotImplementedError(\"_normalize must be implemented\")\n",
    "\n",
    "    @abstractmethod\n",
    "    def _propagate(self):\n",
    "        \"\"\"Performs one propagation step by rebinding `self.predictions`.\"\"\"\n",
    "        raise NotImplementedError(\"_propagate must be implemented\")\n",
    "\n",
    "    def _one_hot_encode(self, labels):\n",
    "        \"\"\"Sets `one_hot_labels`, `n_classes` and `labeled_mask` from `labels`.\n",
    "\n",
    "        Class ids are used directly as column indices, so `n_classes` is the largest\n",
    "        class id + 1 and ids do not need to be contiguous. Rows of unlabeled nodes\n",
    "        (label -1) are all zeros.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If `labels` does not have one entry per node, if no node is labeled,\n",
    "            or if a class id other than -1 is negative.\n",
    "        \"\"\"\n",
    "        if labels.size(0) != self.n_nodes:\n",
    "            raise ValueError(\n",
    "                f\"labels has {labels.size(0)} entries but the graph has {self.n_nodes} nodes.\"\n",
    "            )\n",
    "        labeled_mask = labels != -1\n",
    "        if not labeled_mask.any():\n",
    "            raise ValueError(\"At least one node must be labeled (label != -1).\")\n",
    "        known_labels = labels[labeled_mask]\n",
    "        if (known_labels < 0).any():\n",
    "            raise ValueError(\"Class ids must be non-negative; use -1 for unlabeled nodes.\")\n",
    "\n",
    "        # Sizing by max id (not by the number of distinct ids) keeps every id a valid column index\n",
    "        self.n_classes = int(known_labels.max().item()) + 1\n",
    "        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)\n",
    "        labeled_rows = torch.arange(self.n_nodes)[labeled_mask]\n",
    "        self.one_hot_labels[labeled_rows, known_labels] = 1\n",
    "        self.labeled_mask = labeled_mask\n",
    "\n",
    "    def fit(self, labels, max_iter, tol):\n",
    "        \"\"\"Fits a semi-supervised learning label propagation model.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        labels: torch.LongTensor\n",
    "            Tensor of size n_nodes indicating the class number of each node.\n",
    "            Unlabeled nodes are denoted with -1.\n",
    "        max_iter: int\n",
    "            Maximum number of iterations allowed.\n",
    "        tol: float\n",
    "            Convergence tolerance: threshold to consider the system at steady state.\n",
    "        \"\"\"\n",
    "        self._one_hot_encode(labels)\n",
    "\n",
    "        self.predictions = self.one_hot_labels.clone()\n",
    "        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)\n",
    "\n",
    "        for i in range(max_iter):\n",
    "            # Stop iterations if the system is considered at a steady state\n",
    "            variation = torch.abs(self.predictions - prev_predictions).sum().item()\n",
    "\n",
    "            if variation < tol:\n",
    "                print(f\"The method stopped after {i} iterations, variation={variation:.4f}.\")\n",
    "                break\n",
    "\n",
    "            # No copy needed: `_propagate` rebinds `self.predictions` to a new tensor\n",
    "            prev_predictions = self.predictions\n",
    "            self._propagate()\n",
    "        else:\n",
    "            print(f\"The method did not converge within {max_iter} iterations.\")\n",
    "\n",
    "    def predict(self):\n",
    "        \"\"\"Returns the (n_nodes, n_classes) score matrix computed by `fit` (None before fitting).\"\"\"\n",
    "        return self.predictions\n",
    "\n",
    "    def predict_classes(self):\n",
    "        \"\"\"Returns, for each node, the index of its highest-scoring class.\n",
    "\n",
    "        Ties (e.g. all-zero rows of nodes no label reached) are broken by `torch.max`.\n",
    "        \"\"\"\n",
    "        return self.predictions.max(dim=1).indices"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label propagation\n",
    "\n",
    "Each step computes $Y \\leftarrow D^{-1} W Y$, i.e. every node takes the degree-weighted average of its neighbours' scores. The rows of labeled nodes are then reset to their known one-hot labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelPropagation(BaseLabelPropagation):\n",
    "    \"\"\"Label propagation: row-normalized propagation with labeled nodes clamped to their labels.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def _normalize(adj_matrix):\n",
    "        \"\"\"Computes D^-1 * W\"\"\"\n",
    "        degs = adj_matrix.sum(dim=1)\n",
    "        degs[degs == 0] = 1  # avoid division by 0 error for isolated nodes\n",
    "        return adj_matrix / degs[:, None]\n",
    "\n",
    "    def _propagate(self):\n",
    "        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)\n",
    "\n",
    "        # Put back already known labels\n",
    "        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]\n",
    "\n",
    "    def fit(self, labels, max_iter=1000, tol=1e-3):\n",
    "        \"\"\"Fits the model; see `BaseLabelPropagation.fit` for the parameters.\"\"\"\n",
    "        super().fit(labels, max_iter, tol)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label spreading\n",
    "\n",
    "Each step computes $Y \\leftarrow \\alpha S Y + (1 - \\alpha) Y_0$. Here $S = D^{-1/2} W D^{-1/2}$ is the symmetrically normalized adjacency matrix and $Y_0$ holds the initial one-hot labels.\n",
    "\n",
    "Labeled nodes are not reset, so $\\alpha$ controls how strongly the predictions are pulled back towards the initial labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LabelSpreading(BaseLabelPropagation):\n",
    "    \"\"\"Label spreading: symmetric-normalized propagation blended with the initial labels.\"\"\"\n",
    "\n",
    "    def __init__(self, adj_matrix):\n",
    "        super().__init__(adj_matrix)\n",
    "        self.alpha = None\n",
    "\n",
    "    @staticmethod\n",
    "    def _normalize(adj_matrix):\n",
    "        \"\"\"Computes D^-1/2 * W * D^-1/2\"\"\"\n",
    "        degs = adj_matrix.sum(dim=1)\n",
    "        norm = torch.pow(degs, -0.5)\n",
    "        norm[torch.isinf(norm)] = 1  # zero-degree nodes: their rows/columns of W are zero anyway\n",
    "        return adj_matrix * norm[:, None] * norm[None, :]\n",
    "\n",
    "    def _propagate(self):\n",
    "        self.predictions = (\n",
    "            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)\n",
    "            + (1 - self.alpha) * self.one_hot_labels\n",
    "        )\n",
    "\n",
    "    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):\n",
    "        \"\"\"Fits the model; see `BaseLabelPropagation.fit` for the other parameters.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        alpha: float\n",
    "            Clamping factor in [0, 1]: weight of the propagated scores versus the\n",
    "            initial labels at each step.\n",
    "        \"\"\"\n",
    "        if not 0 <= alpha <= 1:\n",
    "            raise ValueError(f\"alpha must be in [0, 1], got {alpha}.\")\n",
    "        self.alpha = alpha\n",
    "        super().fit(labels, max_iter, tol)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Testing models on synthetic data\n",
    "\n",
    "The test graph is a connected caveman graph of 4 cliques of 10 nodes. Only the first node of each clique is labeled, and each clique is given a different class. A good model should recover one class per clique."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Label Propagation: The method stopped after 73 iterations, variation=0.0010.\n",
      "Label Spreading: The method stopped after 20 iterations, variation=0.0009.\n"
     ]
    }
   ],
   "source": [
    "# Create caveman graph\n",
    "N_CLIQUES = 4\n",
    "SIZE_CLIQUES = 10\n",
    "caveman_graph = nx.connected_caveman_graph(N_CLIQUES, SIZE_CLIQUES)\n",
    "adj_matrix = nx.to_numpy_array(caveman_graph)\n",
    "\n",
    "# Create labels: -1 marks an unlabeled node\n",
    "labels = np.full(N_CLIQUES * SIZE_CLIQUES, -1, dtype=np.int64)\n",
    "\n",
    "# Only one node per clique is labeled. Each clique belongs to a different class.\n",
    "labels[np.arange(N_CLIQUES) * SIZE_CLIQUES] = np.arange(N_CLIQUES)\n",
    "\n",
    "# Create input tensors\n",
    "adj_matrix_t = torch.tensor(adj_matrix, dtype=torch.float)\n",
    "labels_t = torch.tensor(labels, dtype=torch.long)\n",
    "\n",
    "# Learn with Label Propagation\n",
    "label_propagation = LabelPropagation(adj_matrix_t)\n",
    "print(\"Label Propagation: \", end=\"\")\n",
    "label_propagation.fit(labels_t)\n",
    "label_propagation_output_labels = label_propagation.predict_classes()\n",
    "\n",
    "# Learn with Label Spreading\n",
    "label_spreading = LabelSpreading(adj_matrix_t)\n",
    "print(\"Label Spreading: \", end=\"\")\n",
    "label_spreading.fit(labels_t, alpha=0.8)\n",
    "label_spreading_output_labels = label_spreading.predict_classes()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot graphs\n",
    "LAYOUT_SEED = 42  # fixed seed so the spring layout is identical on every run\n",
    "color_map = {-1: \"orange\", 0: \"blue\", 1: \"green\", 2: \"red\", 3: \"cyan\"}\n",
    "input_labels_colors = [color_map[label] for label in labels.tolist()]\n",
    "lprop_labels_colors = [color_map[label] for label in label_propagation_output_labels.tolist()]\n",
    "lspread_labels_colors = [color_map[label] for label in label_spreading_output_labels.tolist()]\n",
    "\n",
    "fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(14, 6))\n",
    "\n",
    "ax1.set_title(f\"Raw data ({N_CLIQUES} classes)\")\n",
    "ax2.set_title(\"Label Propagation\")\n",
    "ax3.set_title(\"Label Spreading\")\n",
    "\n",
    "pos = nx.spring_layout(caveman_graph, seed=LAYOUT_SEED)\n",
    "nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)\n",
    "nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)\n",
    "nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)\n",
    "\n",
    "# Legend: empty plots act as legend handles in the fourth (hidden) panel\n",
    "ax4.axis(\"off\")\n",
    "legend_colors = [\"orange\", \"blue\", \"green\", \"red\", \"cyan\"]\n",
    "legend_labels = [\"unlabeled\", \"class 0\", \"class 1\", \"class 2\", \"class 3\"]\n",
    "legend_handles = [ax4.plot([], [], ls=\"-\", c=c)[0] for c in legend_colors]\n",
    "ax4.legend(legend_handles, legend_labels)\n",
    "\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}