{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "fUjv-SxfWgnY"
},
"source": [
"# Exercise 16.1 - Solution\n",
"## Zachary’s karate club - semi-supervised node classification\n",
"In this exercise, we investigate semi-supervised node classification using Graph Convolutional Networks (GCNs) on Zachary’s Karate Club dataset, introduced in Example 10.2.\n",
"Some time ago, a dispute between the manager and the coach of the karate club led to a split of the club into four groups.\n",
"\n",
"Can we use Graph Convolutional Networks to predict the affiliation of each member given the social network of the community and the memberships of only four people?"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FiBWdBxBWgnb"
},
"source": [
"The exercise uses spektral and networkx, plus gdown to download the data set. If you have not installed these packages yet, do so by executing:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "uej1SzZXWgnb"
},
"outputs": [],
"source": [
"import sys\n",
"!{sys.executable} -m pip install spektral\n",
"!{sys.executable} -m pip install networkx\n",
"!{sys.executable} -m pip install gdown"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YKarg1plWgnd"
},
"source": [
"### Download data: Zachary’s karate club\n",
"You can find the original data set [here](http://vlado.fmf.uni-lj.si/pub/networks/data/ucinet/ucidata.htm#zachary)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gt5X2R30Wgnc",
"outputId": "6120bc4d-b196-4f4c-f344-55b45fcd7026"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"spektral 1.3.0\n",
"keras 2.13.1\n"
]
}
],
"source": [
"import keras\n",
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"import numpy as np\n",
"import spektral\n",
"\n",
"print(\"spektral\", spektral.__version__)\n",
"print(\"keras\", keras.__version__)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "k-ZB4SkCWgne"
},
"outputs": [],
"source": [
"import os\n",
"\n",
"import gdown\n",
"\n",
"# Google Drive file holding the karate club data (adjacency matrix, features, one-hot labels)\n",
"url = \"https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1OugMZz6VVBjWy0uxsG_rrPdrYdPLklzD\"\n",
"output = \"karate_club.npz\"\n",
"\n",
"# Download only if the file is not already present\n",
"if not os.path.exists(output):\n",
"    gdown.download(url, output, quiet=True)\n",
"\n",
"f = np.load(output)\n",
"\n",
"adj, features = f[\"adj\"], f[\"features\"]"
]
},
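{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building a model, it is worth sanity-checking the loaded arrays. The next cell is a minimal sketch that assumes `adj` is a dense, symmetric NumPy array (an undirected graph) and that `features` has one row per node, as the printed arrays suggest; these are assumptions about `karate_club.npz`, not documented properties of the file. The names `n_nodes` and `n_edges` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity checks on the loaded graph (assumes a dense, undirected adjacency matrix)\n",
"n_nodes = adj.shape[0]\n",
"assert adj.shape == (n_nodes, n_nodes), \"adjacency matrix should be square\"\n",
"assert np.allclose(adj, adj.T), \"adjacency matrix should be symmetric (undirected graph)\"\n",
"assert features.shape[0] == n_nodes, \"expected one feature row per node\"\n",
"\n",
"# Each undirected edge appears twice in a symmetric adjacency matrix\n",
"n_edges = int(adj.sum()) // 2\n",
"print(\"nodes:\", n_nodes, \"edges:\", n_edges, \"feature dim:\", features.shape[1])"
]
},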
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WqmOQfS7Wgnf",
"outputId": "fe893e51-794c-4ba8-f71e-84e6520a4017"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"adjacency matrix\n",
" [[0. 1. 1. ... 1. 0. 0.]\n",
" [1. 0. 1. ... 0. 0. 0.]\n",
" [1. 1. 0. ... 0. 1. 0.]\n",
" ...\n",
" [1. 0. 0. ... 0. 1. 1.]\n",
" [0. 0. 1. ... 1. 0. 1.]\n",
" [0. 0. 0. ... 1. 1. 0.]]\n"
]
}
],
"source": [
"print(\"adjacency matrix\\n\", adj)"
]
},
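{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the GCN layer of Kipf and Welling does not propagate features over `adj` directly but over the renormalized matrix $\\hat{A} = \\tilde{D}^{-1/2}(A + I)\\tilde{D}^{-1/2}$, where $\\tilde{D}$ is the diagonal degree matrix of $A + I$; a layer then computes $\\sigma(\\hat{A} X W)$ for node features $X$ and trainable weights $W$. The next cell is a plain NumPy sketch of this preprocessing and of one untrained propagation step with ReLU as $\\sigma$. The random weight matrix, the hidden size of 4, and the names `a_hat`, `w_sketch`, and `h_sketch` are arbitrary illustrative choices, not necessarily what the model trained in this exercise uses."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Renormalized adjacency used by GCN layers: A_hat = D^-1/2 (A + I) D^-1/2, D = degree matrix of A + I\n",
"a_tilde = adj + np.eye(adj.shape[0])  # add a self-loop to every node\n",
"deg = a_tilde.sum(axis=1)  # node degrees, self-loop included\n",
"d_inv_sqrt = np.diag(1.0 / np.sqrt(deg))\n",
"a_hat = d_inv_sqrt @ a_tilde @ d_inv_sqrt\n",
"\n",
"# One untrained propagation step relu(A_hat X W) with a random weight matrix\n",
"rng = np.random.default_rng(0)\n",
"w_sketch = rng.normal(size=(features.shape[1], 4))\n",
"h_sketch = np.maximum(a_hat @ features @ w_sketch, 0.0)\n",
"\n",
"print(\"A_hat symmetric:\", np.allclose(a_hat, a_hat.T))\n",
"print(\"hidden representation shape:\", h_sketch.shape)"
]
},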
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OKsPVMUJWgng",
"outputId": "f03de89b-4abf-4061-aa99-119aea8e7a64"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"features\n",
" [[1. 0. 0. ... 0. 0. 0.]\n",
" [0. 1. 0. ... 0. 0. 0.]\n",
" [0. 0. 1. ... 0. 0. 0.]\n",
" ...\n",
" [0. 0. 0. ... 1. 0. 0.]\n",
" [0. 0. 0. ... 0. 1. 0.]\n",
" [0. 0. 0. ... 0. 0. 1.]]\n"
]
}
],
"source": [
"print(\"features\\n\", features)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YPAB-kFnWgnh",
"outputId": "9bd4d3c0-7fb3-48bd-fee1-5fe5136230e8"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"labels: [2 2 3 2 1 1 1 2 4 3 1 2 2 2 4 4 1 2 4 2 4 2 4 4 3 3 4 3 3 4 4 3 4 4]\n"
]
}
],
"source": [
"labels_one_hot = f[\"labels_one_hot\"]\n",
"\n",
"def one_hot_to_labels(labels_one_hot):\n",
"    \"\"\"Convert one-hot rows to integer group labels 1, 2, ..., n_classes.\"\"\"\n",
"    return np.argmax(labels_one_hot, axis=1) + 1\n",
"\n",
"labels = one_hot_to_labels(labels_one_hot)\n",
"\n",
"print(\"labels:\", labels)"
]
},
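{
"cell_type": "markdown",
"metadata": {},
"source": [
"The task assumes that the memberships of only four people are known. In code, such a semi-supervised setup is commonly expressed as a boolean mask over the nodes. The next cell is a sketch that, purely for illustration, treats the first member of each of the four groups as labelled; this choice is hypothetical and may differ from the four members used for training in this exercise, and the `*_sketch` names are introduced only here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical choice: treat the first member of each group (classes 1..n_classes) as labelled\n",
"n_classes = labels_one_hot.shape[1]\n",
"train_idx_sketch = np.array([np.flatnonzero(labels == c)[0] for c in range(1, n_classes + 1)])\n",
"\n",
"# Boolean masks: True for the labelled nodes, False for the rest\n",
"train_mask_sketch = np.zeros(len(labels), dtype=bool)\n",
"train_mask_sketch[train_idx_sketch] = True\n",
"test_mask_sketch = ~train_mask_sketch\n",
"\n",
"print(\"labelled nodes:\", train_idx_sketch, \"groups:\", labels[train_idx_sketch])\n",
"print(\"number of unlabelled nodes:\", test_mask_sketch.sum())"
]
},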
{
"cell_type": "markdown",
"metadata": {
"id": "dj6uvZ4-Wgnj"
},
"source": [
"### Plot data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 487
},
"id": "BINTXRNlWgnj",
"outputId": "efa4e4e5-ab69-48b0-b50f-ea728d0b7bbe"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"