{ "cells": [ { "cell_type": "markdown", "id": "93e10022", "metadata": { "id": "93e10022" }, "source": [ "# torch geometric + skorch @ CORA dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "f8b30bc0", "metadata": { "id": "f8b30bc0", "outputId": "20f01f79-f21b-4555-dad3-791d06ebf432", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Thu Jan 5 13:34:48 UTC 2023\n" ] } ], "source": [ "!date" ] }, { "cell_type": "markdown", "id": "4e144398", "metadata": { "id": "4e144398" }, "source": [ "This is an example of how to use skorch with [torch geometric](https://pytorch-geometric.readthedocs.io/). The code is based on the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) but modified to use a proper train/valid/test split. The example uses a small dataset that can be trained efficiently without batching. Batching with skorch + torch geometric is not covered here since it is non-trivial and quite dataset-specific. If you need it and are stuck, feel free to open [an issue](https://github.com/skorch-dev/skorch/issues) so that we can support you as best we can.\n", "\n", "Dependencies of this notebook besides the skorch base installation:" ] }, { "cell_type": "raw", "id": "3748b2a0", "metadata": { "id": "3748b2a0" }, "source": [ "!pip install torch_geometric==2.0.4" ] }, { "cell_type": "markdown", "id": "7ce31902", "metadata": { "id": "7ce31902" }, "source": [ "It is recommended to install the dependencies [as documented by pytorch geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html)." 
] }, { "cell_type": "markdown", "id": "ff707ef1", "metadata": { "id": "ff707ef1" }, "source": [ "---" ] }, { "cell_type": "code", "execution_count": 2, "id": "jkLKYeyyiAMe", "metadata": { "id": "jkLKYeyyiAMe" }, "outputs": [], "source": [ "import subprocess\n", "\n", "# Installation on Google Colab\n", "try:\n", " import google.colab\n", " import torch\n", " subprocess.run(['python', '-m', 'pip', 'install', 'skorch' , 'torch_geometric'])\n", " subprocess.run(['python', '-m', 'pip', 'install', 'torch-sparse' , '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html'])\n", " subprocess.run(['python', '-m', 'pip', 'install', 'torch-scatter' , '-f', f'https://data.pyg.org/whl/torch-{torch.__version__}.html'])\n", "except ImportError:\n", " pass" ] }, { "cell_type": "code", "execution_count": 3, "id": "efa2b5a5", "metadata": { "id": "efa2b5a5" }, "outputs": [], "source": [ "import skorch\n", "import torch" ] }, { "cell_type": "markdown", "id": "934fa438", "metadata": { "id": "934fa438" }, "source": [ "### Data Loading" ] }, { "cell_type": "code", "execution_count": 4, "id": "f1505f8f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "f1505f8f", "outputId": "c5a9543b-e565-4d50-9a7c-4e37a4c5e9e7" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n", "Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n", 
"Processing...\n", "Done!\n" ] } ], "source": [ "from torch_geometric.datasets import Planetoid\n", "\n", "dataset = Planetoid(root='/tmp/Cora', name='Cora')" ] }, { "cell_type": "code", "execution_count": 5, "id": "4979ba6f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4979ba6f", "outputId": "4b50548c-ec03-4495-ee2f-fc78301527cd" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708]),\n", " 7)" ] }, "metadata": {}, "execution_count": 5 } ], "source": [ "dataset.data, dataset.num_classes" ] }, { "cell_type": "markdown", "id": "9cd7e104", "metadata": { "id": "9cd7e104" }, "source": [ "In order to use pytorch geometric / the cora dataset with skorch\n", "we need to address the following things:\n", " \n", "1. graph convolutions cannot handle missing nodes (=> splitting node attributes but keeping edge_index intact will lead to errors)\n", "2. cora dataset has different attributes for the different split masks (i.e. `train_mask`, `val_mask`, `test_mask`)\n", "3. skorch expects to have (X, y) pairs for classification tasks\n", "\n", "To deal with (1) we will split the data into three datasets, creating three sub-graphs in the process; these complete sub-graphs can then be convolved over without errors. \n", "We use the masks mentioned in (2) to identify the nodes and edges of the subgraphs.\n", "\n", "(3) will be handled by specifying our own `XYDataset` which will just have length 1 and return the dataset and the respective y values. We will therefore basically simulate a `batch_size=1` scenario." ] }, { "cell_type": "code", "execution_count": 6, "id": "35aa1770", "metadata": { "id": "35aa1770" }, "outputs": [], "source": [ "from torch_geometric.data import Data\n", "\n", "# simulating batch_size=1 by returning the whole dataset and the\n", "# y-values. 
this way, the data loader can iterate over the 'batches'\n", "# and produce X/y values for us.\n", "class XYDataset(torch.utils.data.Dataset):\n", "    def __init__(self, data: Data, y: torch.Tensor):\n", "        self.data = data\n", "        self.y = y\n", "\n", "    def __len__(self):\n", "        # the dataset holds exactly one item: the whole (sub-)graph\n", "        return 1\n", "\n", "    def __getitem__(self, i):\n", "        return self.data, self.y" ] }, { "cell_type": "markdown", "id": "7cd78cd1", "metadata": { "id": "7cd78cd1" }, "source": [ "### Data Splitting" ] }, { "cell_type": "markdown", "id": "468d8d6f", "metadata": { "id": "468d8d6f" }, "source": [ "Split the graph into train, validation and test sub-graphs.\n", "This ensures that no information leaks between the splits when we apply graph\n", "convolution operators, since each split has its own sub-graph.\n", "\n", "We use `relabel_nodes=True` to make the node indices in the edge tensor\n", "zero-based for each sub-graph. If we did not do this, the node subsets\n", "(now zero-based after applying the mask) would not match the indices in the\n", "edge tensor." 
] }, { "cell_type": "code", "execution_count": 7, "id": "f3d908e4", "metadata": { "id": "f3d908e4" }, "outputs": [], "source": [ "from torch_geometric.utils import subgraph\n", "\n", "data = dataset[0]\n", "\n", "edge_index_train, _ = subgraph(\n", " subset=data.train_mask, \n", " edge_index=data.edge_index, \n", " relabel_nodes=True\n", ")\n", "ds_train = XYDataset(\n", " Data(x=data.x[data.train_mask], edge_index=edge_index_train),\n", " data.y[data.train_mask],\n", ")\n", "\n", "edge_index_valid, _ = subgraph(\n", " subset=data.val_mask, \n", " edge_index=data.edge_index, \n", " relabel_nodes=True\n", ")\n", "ds_valid = XYDataset(\n", " Data(x=data.x[data.val_mask], edge_index=edge_index_valid),\n", " data.y[data.val_mask],\n", ")\n", "\n", "edge_index_test, _ = subgraph(\n", " subset=data.test_mask, \n", " edge_index=data.edge_index, \n", " relabel_nodes=True\n", ")\n", "ds_test = XYDataset(\n", " Data(x=data.x[data.test_mask], edge_index=edge_index_test),\n", " data.y[data.test_mask],\n", ")" ] }, { "cell_type": "markdown", "id": "bf59f7a9", "metadata": { "id": "bf59f7a9" }, "source": [ "### Data Feeding" ] }, { "cell_type": "markdown", "id": "4d23039d", "metadata": { "id": "4d23039d" }, "source": [ "Our \"batch\" consists of the whole dataset so if we unpack the\n", "batch into `(X, y)` we will have `X = Data(...)` and `y = [y_true]`.\n", "The `DataLoader` does not modify `X` but `y` gets a new batch dimension.\n", "This will lead to a shape mismatch as `y.shape` would then be `(1, #num_samples)`. Therefore, we need our own loader that strips the first dimension to \n", "match the predicted `y` and the labelled `y` in length.\n", "\n", "Note: It is possible to avoid this by stripping this dimension by overriding `get_loss` in the `NeuralNet` class. For brevity we won't do this in this example. 
It is also possible to use [one of the many `DataLoader` classes](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html) provided by torch geometric with the approach outlined below (just base `RawLoader` on one of the other classes). Chances are, though, that if you are doing this you need to deal with batching anyway, which is not covered here since it is non-trivial." ] }, { "cell_type": "code", "execution_count": 8, "id": "83594638", "metadata": { "id": "83594638" }, "outputs": [], "source": [ "from torch_geometric.loader import DataLoader\n", "\n", "class RawLoader(DataLoader):\n", "    def __iter__(self):\n", "        it = super().__iter__()\n", "        for X, y in it:\n", "            # strip the batch dimension that collation added to y\n", "            yield X, y[0]" ] }, { "cell_type": "markdown", "id": "d51e7140", "metadata": { "id": "d51e7140" }, "source": [ "### Modelling" ] }, { "cell_type": "markdown", "id": "67e6ff9e", "metadata": { "id": "67e6ff9e" }, "source": [ "This is the CORA example module as seen in the [torch geometric introduction](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html), except that it returns probabilities (`softmax`) instead of log-probabilities, since `NeuralNetClassifier` applies the log itself when used with its default `NLLLoss` criterion." 
] }, { "cell_type": "code", "execution_count": 9, "id": "10de0020", "metadata": { "id": "10de0020" }, "outputs": [], "source": [ "import torch\n", "import torch.nn.functional as F\n", "from torch_geometric.nn import GCNConv\n", "\n", "class GCN(torch.nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.conv1 = GCNConv(dataset.num_node_features, 16)\n", "        self.conv2 = GCNConv(16, dataset.num_classes)\n", "\n", "    def forward(self, data):\n", "        x, edge_index = data.x, data.edge_index\n", "\n", "        x = self.conv1(x, edge_index)\n", "        x = F.relu(x)\n", "        x = F.dropout(x, training=self.training)\n", "        x = self.conv2(x, edge_index)\n", "\n", "        # return probabilities, as expected by NeuralNetClassifier\n", "        return F.softmax(x, dim=1)" ] }, { "cell_type": "markdown", "id": "f58e93d8", "metadata": { "id": "f58e93d8" }, "source": [ "### Fitting" ] }, { "cell_type": "code", "execution_count": 10, "id": "b3f5ef3c", "metadata": { "id": "b3f5ef3c" }, "outputs": [], "source": [ "from skorch.helper import predefined_split\n", "\n", "torch.manual_seed(42)\n", "\n", "net = skorch.NeuralNetClassifier(\n", "    module=GCN,\n", "    lr=0.1,\n", "    optimizer__weight_decay=5e-4,\n", "    max_epochs=200,\n", "    train_split=predefined_split(ds_valid),\n", "    batch_size=1,\n", "    iterator_train=RawLoader,\n", "    iterator_valid=RawLoader,\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "id": "c2cf8bbd", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "c2cf8bbd", "outputId": "57c6aba6-37a2-4c33-ee90-75ddd05ac33b", "scrolled": false }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " epoch train_loss valid_acc valid_loss dur\n", "------- ------------ ----------- ------------ ------\n", " 1 \u001b[36m1.9724\u001b[0m \u001b[32m0.1680\u001b[0m \u001b[35m1.9398\u001b[0m 0.0963\n", " 2 \u001b[36m1.9625\u001b[0m \u001b[32m0.1740\u001b[0m \u001b[35m1.9376\u001b[0m 0.0091\n", " 3 \u001b[36m1.9327\u001b[0m 0.1720 \u001b[35m1.9342\u001b[0m 0.0115\n", " 4 \u001b[36m1.9321\u001b[0m 
\u001b[32m0.1760\u001b[0m \u001b[35m1.9324\u001b[0m 0.0096\n", " 5 \u001b[36m1.9142\u001b[0m \u001b[32m0.1800\u001b[0m \u001b[35m1.9307\u001b[0m 0.0069\n", " 6 \u001b[36m1.8923\u001b[0m 0.1800 \u001b[35m1.9290\u001b[0m 0.0090\n", " 7 \u001b[36m1.8848\u001b[0m \u001b[32m0.1880\u001b[0m \u001b[35m1.9269\u001b[0m 0.0092\n", " 8 1.8936 \u001b[32m0.1920\u001b[0m \u001b[35m1.9247\u001b[0m 0.0152\n", " 9 \u001b[36m1.8783\u001b[0m \u001b[32m0.1960\u001b[0m \u001b[35m1.9219\u001b[0m 0.0082\n", " 10 \u001b[36m1.8737\u001b[0m \u001b[32m0.2040\u001b[0m \u001b[35m1.9192\u001b[0m 0.0082\n", " 11 \u001b[36m1.8542\u001b[0m \u001b[32m0.2060\u001b[0m \u001b[35m1.9176\u001b[0m 0.0070\n", " 12 \u001b[36m1.8489\u001b[0m \u001b[32m0.2100\u001b[0m \u001b[35m1.9156\u001b[0m 0.0078\n", " 13 \u001b[36m1.8314\u001b[0m \u001b[32m0.2120\u001b[0m \u001b[35m1.9121\u001b[0m 0.0061\n", " 14 1.8334 \u001b[32m0.2220\u001b[0m \u001b[35m1.9100\u001b[0m 0.0076\n", " 15 \u001b[36m1.8041\u001b[0m \u001b[32m0.2240\u001b[0m \u001b[35m1.9085\u001b[0m 0.0080\n", " 16 1.8089 0.2220 \u001b[35m1.9065\u001b[0m 0.0101\n", " 17 1.8082 0.2200 \u001b[35m1.9043\u001b[0m 0.0078\n", " 18 \u001b[36m1.7759\u001b[0m 0.2220 \u001b[35m1.9023\u001b[0m 0.0064\n", " 19 \u001b[36m1.7745\u001b[0m 0.2220 \u001b[35m1.8992\u001b[0m 0.0076\n", " 20 \u001b[36m1.7630\u001b[0m \u001b[32m0.2280\u001b[0m \u001b[35m1.8970\u001b[0m 0.0089\n", " 21 \u001b[36m1.7411\u001b[0m \u001b[32m0.2340\u001b[0m \u001b[35m1.8946\u001b[0m 0.0087\n", " 22 1.7732 \u001b[32m0.2360\u001b[0m \u001b[35m1.8921\u001b[0m 0.0086\n", " 23 \u001b[36m1.7407\u001b[0m \u001b[32m0.2420\u001b[0m \u001b[35m1.8893\u001b[0m 0.0077\n", " 24 \u001b[36m1.7259\u001b[0m 0.2420 \u001b[35m1.8857\u001b[0m 0.0092\n", " 25 \u001b[36m1.6920\u001b[0m \u001b[32m0.2520\u001b[0m \u001b[35m1.8836\u001b[0m 0.0069\n", " 26 1.7033 \u001b[32m0.2540\u001b[0m \u001b[35m1.8805\u001b[0m 0.0133\n", " 27 1.7080 \u001b[32m0.2580\u001b[0m \u001b[35m1.8767\u001b[0m 0.0090\n", " 28 1.6924 
\u001b[32m0.2620\u001b[0m \u001b[35m1.8741\u001b[0m 0.0072\n", " 29 \u001b[36m1.6882\u001b[0m 0.2620 \u001b[35m1.8703\u001b[0m 0.0075\n", " 30 \u001b[36m1.6850\u001b[0m \u001b[32m0.2660\u001b[0m \u001b[35m1.8679\u001b[0m 0.0074\n", " 31 \u001b[36m1.6438\u001b[0m 0.2580 \u001b[35m1.8650\u001b[0m 0.0073\n", " 32 \u001b[36m1.6345\u001b[0m \u001b[32m0.2680\u001b[0m \u001b[35m1.8618\u001b[0m 0.0083\n", " 33 1.6816 0.2660 \u001b[35m1.8579\u001b[0m 0.0084\n", " 34 \u001b[36m1.6169\u001b[0m 0.2620 \u001b[35m1.8559\u001b[0m 0.0085\n", " 35 1.6373 \u001b[32m0.2720\u001b[0m \u001b[35m1.8522\u001b[0m 0.0093\n", " 36 \u001b[36m1.6107\u001b[0m 0.2700 \u001b[35m1.8491\u001b[0m 0.0084\n", " 37 \u001b[36m1.6035\u001b[0m \u001b[32m0.2800\u001b[0m \u001b[35m1.8449\u001b[0m 0.0080\n", " 38 1.6060 \u001b[32m0.2840\u001b[0m \u001b[35m1.8421\u001b[0m 0.0079\n", " 39 \u001b[36m1.5604\u001b[0m \u001b[32m0.2960\u001b[0m \u001b[35m1.8389\u001b[0m 0.0080\n", " 40 1.5724 \u001b[32m0.3060\u001b[0m \u001b[35m1.8354\u001b[0m 0.0102\n", " 41 \u001b[36m1.5371\u001b[0m \u001b[32m0.3160\u001b[0m \u001b[35m1.8319\u001b[0m 0.0071\n", " 42 \u001b[36m1.5246\u001b[0m \u001b[32m0.3240\u001b[0m \u001b[35m1.8281\u001b[0m 0.0073\n", " 43 1.5524 0.3200 \u001b[35m1.8241\u001b[0m 0.0076\n", " 44 1.5282 \u001b[32m0.3300\u001b[0m \u001b[35m1.8211\u001b[0m 0.0098\n", " 45 1.5356 \u001b[32m0.3380\u001b[0m \u001b[35m1.8169\u001b[0m 0.0109\n", " 46 \u001b[36m1.5079\u001b[0m \u001b[32m0.3440\u001b[0m \u001b[35m1.8137\u001b[0m 0.0082\n", " 47 1.5192 \u001b[32m0.3500\u001b[0m \u001b[35m1.8090\u001b[0m 0.0077\n", " 48 \u001b[36m1.4991\u001b[0m \u001b[32m0.3540\u001b[0m \u001b[35m1.8063\u001b[0m 0.0078\n", " 49 \u001b[36m1.4949\u001b[0m 0.3460 \u001b[35m1.8036\u001b[0m 0.0086\n", " 50 \u001b[36m1.4892\u001b[0m \u001b[32m0.3640\u001b[0m \u001b[35m1.8000\u001b[0m 0.0100\n", " 51 1.5165 \u001b[32m0.3760\u001b[0m \u001b[35m1.7968\u001b[0m 0.0082\n", " 52 \u001b[36m1.4367\u001b[0m 0.3740 \u001b[35m1.7931\u001b[0m 0.0081\n", " 
53 1.4473 0.3700 \u001b[35m1.7894\u001b[0m 0.0081\n", " 54 1.4387 \u001b[32m0.3840\u001b[0m \u001b[35m1.7855\u001b[0m 0.0100\n", " 55 \u001b[36m1.4261\u001b[0m 0.3840 \u001b[35m1.7825\u001b[0m 0.0081\n", " 56 1.4355 \u001b[32m0.4040\u001b[0m \u001b[35m1.7768\u001b[0m 0.0085\n", " 57 1.4270 0.3900 \u001b[35m1.7749\u001b[0m 0.0082\n", " 58 \u001b[36m1.4029\u001b[0m 0.4000 \u001b[35m1.7714\u001b[0m 0.0085\n", " 59 \u001b[36m1.3793\u001b[0m 0.4040 \u001b[35m1.7679\u001b[0m 0.0085\n", " 60 \u001b[36m1.3493\u001b[0m 0.4020 \u001b[35m1.7629\u001b[0m 0.0086\n", " 61 1.3624 \u001b[32m0.4160\u001b[0m \u001b[35m1.7597\u001b[0m 0.0081\n", " 62 1.3970 \u001b[32m0.4180\u001b[0m \u001b[35m1.7562\u001b[0m 0.0085\n", " 63 1.3552 \u001b[32m0.4220\u001b[0m \u001b[35m1.7516\u001b[0m 0.0110\n", " 64 1.3745 \u001b[32m0.4240\u001b[0m \u001b[35m1.7480\u001b[0m 0.0064\n", " 65 1.4002 \u001b[32m0.4260\u001b[0m \u001b[35m1.7448\u001b[0m 0.0087\n", " 66 \u001b[36m1.2924\u001b[0m \u001b[32m0.4280\u001b[0m \u001b[35m1.7405\u001b[0m 0.0083\n", " 67 1.2954 \u001b[32m0.4300\u001b[0m \u001b[35m1.7375\u001b[0m 0.0106\n", " 68 \u001b[36m1.2785\u001b[0m \u001b[32m0.4320\u001b[0m \u001b[35m1.7319\u001b[0m 0.0074\n", " 69 1.3192 0.4300 \u001b[35m1.7290\u001b[0m 0.0089\n", " 70 1.3049 \u001b[32m0.4360\u001b[0m \u001b[35m1.7246\u001b[0m 0.0063\n", " 71 \u001b[36m1.2504\u001b[0m \u001b[32m0.4420\u001b[0m \u001b[35m1.7198\u001b[0m 0.0094\n", " 72 1.2841 0.4340 \u001b[35m1.7165\u001b[0m 0.0085\n", " 73 \u001b[36m1.2304\u001b[0m \u001b[32m0.4460\u001b[0m \u001b[35m1.7120\u001b[0m 0.0071\n", " 74 1.2414 \u001b[32m0.4540\u001b[0m \u001b[35m1.7070\u001b[0m 0.0062\n", " 75 \u001b[36m1.1753\u001b[0m 0.4520 \u001b[35m1.7020\u001b[0m 0.0089\n", " 76 1.2608 \u001b[32m0.4580\u001b[0m \u001b[35m1.6981\u001b[0m 0.0086\n", " 77 1.2053 0.4580 \u001b[35m1.6935\u001b[0m 0.0083\n", " 78 1.2640 \u001b[32m0.4600\u001b[0m \u001b[35m1.6910\u001b[0m 0.0077\n", " 79 1.2251 \u001b[32m0.4700\u001b[0m \u001b[35m1.6845\u001b[0m 
0.0082\n", " 80 1.2221 \u001b[32m0.4780\u001b[0m \u001b[35m1.6801\u001b[0m 0.0063\n", " 81 \u001b[36m1.1499\u001b[0m 0.4760 \u001b[35m1.6761\u001b[0m 0.0076\n", " 82 1.1761 \u001b[32m0.4820\u001b[0m \u001b[35m1.6727\u001b[0m 0.0081\n", " 83 \u001b[36m1.1286\u001b[0m \u001b[32m0.4880\u001b[0m \u001b[35m1.6673\u001b[0m 0.0073\n", " 84 1.1338 \u001b[32m0.4920\u001b[0m \u001b[35m1.6634\u001b[0m 0.0074\n", " 85 \u001b[36m1.1273\u001b[0m \u001b[32m0.4940\u001b[0m \u001b[35m1.6593\u001b[0m 0.0076\n", " 86 1.1289 0.4900 \u001b[35m1.6548\u001b[0m 0.0085\n", " 87 1.1618 \u001b[32m0.4960\u001b[0m \u001b[35m1.6512\u001b[0m 0.0083\n", " 88 1.1306 \u001b[32m0.4980\u001b[0m \u001b[35m1.6474\u001b[0m 0.0085\n", " 89 1.1436 \u001b[32m0.5000\u001b[0m \u001b[35m1.6438\u001b[0m 0.0102\n", " 90 \u001b[36m1.0675\u001b[0m \u001b[32m0.5020\u001b[0m \u001b[35m1.6397\u001b[0m 0.0087\n", " 91 1.0798 0.5000 \u001b[35m1.6360\u001b[0m 0.0086\n", " 92 1.1148 0.4980 \u001b[35m1.6330\u001b[0m 0.0079\n", " 93 1.0830 \u001b[32m0.5040\u001b[0m \u001b[35m1.6276\u001b[0m 0.0077\n", " 94 1.1569 0.5020 \u001b[35m1.6246\u001b[0m 0.0076\n", " 95 \u001b[36m1.0338\u001b[0m 0.5020 \u001b[35m1.6197\u001b[0m 0.0075\n", " 96 1.0800 \u001b[32m0.5100\u001b[0m \u001b[35m1.6139\u001b[0m 0.0085\n", " 97 1.0869 0.5080 \u001b[35m1.6121\u001b[0m 0.0085\n", " 98 1.1144 0.5100 \u001b[35m1.6074\u001b[0m 0.0087\n", " 99 \u001b[36m1.0271\u001b[0m 0.5060 \u001b[35m1.6045\u001b[0m 0.0083\n", " 100 1.0465 0.5100 \u001b[35m1.6020\u001b[0m 0.0141\n", " 101 1.0348 \u001b[32m0.5200\u001b[0m \u001b[35m1.5963\u001b[0m 0.0105\n", " 102 \u001b[36m1.0045\u001b[0m \u001b[32m0.5220\u001b[0m \u001b[35m1.5936\u001b[0m 0.0086\n", " 103 1.0307 \u001b[32m0.5260\u001b[0m \u001b[35m1.5895\u001b[0m 0.0103\n", " 104 \u001b[36m0.9970\u001b[0m \u001b[32m0.5300\u001b[0m \u001b[35m1.5839\u001b[0m 0.0097\n", " 105 \u001b[36m0.9644\u001b[0m 0.5300 \u001b[35m1.5814\u001b[0m 0.0093\n", " 106 0.9879 \u001b[32m0.5320\u001b[0m \u001b[35m1.5770\u001b[0m 
0.0094\n", " 107 0.9986 0.5320 \u001b[35m1.5732\u001b[0m 0.0094\n", " 108 \u001b[36m0.9234\u001b[0m 0.5320 \u001b[35m1.5692\u001b[0m 0.0096\n", " 109 0.9704 0.5300 \u001b[35m1.5642\u001b[0m 0.0064\n", " 110 1.0256 0.5300 \u001b[35m1.5621\u001b[0m 0.0060\n", " 111 0.9590 0.5240 \u001b[35m1.5585\u001b[0m 0.0078\n", " 112 1.0168 0.5300 \u001b[35m1.5565\u001b[0m 0.0087\n", " 113 0.9994 0.5320 \u001b[35m1.5534\u001b[0m 0.0078\n", " 114 0.9635 0.5320 \u001b[35m1.5492\u001b[0m 0.0113\n", " 115 0.9872 \u001b[32m0.5340\u001b[0m \u001b[35m1.5452\u001b[0m 0.0079\n", " 116 0.9749 0.5340 \u001b[35m1.5411\u001b[0m 0.0126\n", " 117 0.9667 0.5340 \u001b[35m1.5392\u001b[0m 0.0104\n", " 118 \u001b[36m0.8757\u001b[0m 0.5300 \u001b[35m1.5351\u001b[0m 0.0087\n", " 119 0.9306 0.5340 \u001b[35m1.5340\u001b[0m 0.0087\n", " 120 \u001b[36m0.8284\u001b[0m \u001b[32m0.5380\u001b[0m \u001b[35m1.5300\u001b[0m 0.0084\n", " 121 0.8389 \u001b[32m0.5400\u001b[0m \u001b[35m1.5254\u001b[0m 0.0127\n", " 122 0.9347 \u001b[32m0.5440\u001b[0m \u001b[35m1.5226\u001b[0m 0.0165\n", " 123 0.8502 0.5340 \u001b[35m1.5207\u001b[0m 0.0126\n", " 124 0.8519 \u001b[32m0.5480\u001b[0m \u001b[35m1.5163\u001b[0m 0.0096\n", " 125 0.8536 0.5460 \u001b[35m1.5127\u001b[0m 0.0094\n", " 126 0.8926 0.5480 \u001b[35m1.5082\u001b[0m 0.0112\n", " 127 0.8605 0.5480 \u001b[35m1.5050\u001b[0m 0.0105\n", " 128 0.8853 \u001b[32m0.5540\u001b[0m \u001b[35m1.5037\u001b[0m 0.0120\n", " 129 0.8483 0.5540 \u001b[35m1.4992\u001b[0m 0.0102\n", " 130 0.8745 0.5540 \u001b[35m1.4954\u001b[0m 0.0103\n", " 131 \u001b[36m0.7866\u001b[0m 0.5500 \u001b[35m1.4940\u001b[0m 0.0107\n", " 132 0.8322 0.5520 \u001b[35m1.4901\u001b[0m 0.0108\n", " 133 0.8019 0.5520 \u001b[35m1.4864\u001b[0m 0.0105\n", " 134 0.8829 0.5540 \u001b[35m1.4846\u001b[0m 0.0103\n", " 135 0.8545 \u001b[32m0.5560\u001b[0m \u001b[35m1.4829\u001b[0m 0.0088\n", " 136 0.9028 0.5560 \u001b[35m1.4802\u001b[0m 0.0090\n", " 137 0.8797 0.5540 \u001b[35m1.4794\u001b[0m 0.0102\n", " 138 0.7967 
0.5560 \u001b[35m1.4744\u001b[0m 0.0087\n", " 139 \u001b[36m0.7614\u001b[0m 0.5560 \u001b[35m1.4723\u001b[0m 0.0088\n", " 140 0.8399 \u001b[32m0.5580\u001b[0m \u001b[35m1.4696\u001b[0m 0.0097\n", " 141 0.8502 0.5580 \u001b[35m1.4671\u001b[0m 0.0094\n", " 142 \u001b[36m0.7301\u001b[0m 0.5580 \u001b[35m1.4648\u001b[0m 0.0087\n", " 143 0.7543 \u001b[32m0.5640\u001b[0m \u001b[35m1.4617\u001b[0m 0.0102\n", " 144 \u001b[36m0.7023\u001b[0m 0.5620 \u001b[35m1.4580\u001b[0m 0.0090\n", " 145 0.7329 0.5640 \u001b[35m1.4561\u001b[0m 0.0090\n", " 146 0.7820 0.5640 \u001b[35m1.4526\u001b[0m 0.0100\n", " 147 0.8137 0.5640 \u001b[35m1.4521\u001b[0m 0.0079\n", " 148 0.7950 0.5600 \u001b[35m1.4489\u001b[0m 0.0083\n", " 149 0.7702 0.5600 \u001b[35m1.4468\u001b[0m 0.0082\n", " 150 0.7851 0.5580 1.4469 0.0079\n", " 151 0.7881 0.5600 \u001b[35m1.4456\u001b[0m 0.0078\n", " 152 0.7375 0.5600 \u001b[35m1.4405\u001b[0m 0.0102\n", " 153 0.7888 0.5580 \u001b[35m1.4401\u001b[0m 0.0079\n", " 154 0.8128 0.5580 \u001b[35m1.4376\u001b[0m 0.0082\n", " 155 \u001b[36m0.6960\u001b[0m 0.5600 \u001b[35m1.4345\u001b[0m 0.0085\n", " 156 0.7073 0.5600 \u001b[35m1.4328\u001b[0m 0.0081\n", " 157 0.7129 0.5620 \u001b[35m1.4301\u001b[0m 0.0083\n", " 158 0.7282 0.5620 \u001b[35m1.4283\u001b[0m 0.0078\n", " 159 0.7855 0.5580 \u001b[35m1.4263\u001b[0m 0.0083\n", " 160 0.7444 0.5620 \u001b[35m1.4237\u001b[0m 0.0081\n", " 161 0.7081 0.5620 \u001b[35m1.4204\u001b[0m 0.0095\n", " 162 \u001b[36m0.6947\u001b[0m 0.5620 \u001b[35m1.4188\u001b[0m 0.0085\n", " 163 0.7121 0.5620 \u001b[35m1.4152\u001b[0m 0.0067\n", " 164 0.7374 0.5620 \u001b[35m1.4139\u001b[0m 0.0095\n", " 165 \u001b[36m0.6866\u001b[0m 0.5600 \u001b[35m1.4113\u001b[0m 0.0090\n", " 166 0.7360 0.5640 \u001b[35m1.4102\u001b[0m 0.0083\n", " 167 \u001b[36m0.6596\u001b[0m 0.5600 \u001b[35m1.4073\u001b[0m 0.0085\n", " 168 \u001b[36m0.6245\u001b[0m 0.5600 \u001b[35m1.4048\u001b[0m 0.0145\n", " 169 0.6546 0.5600 \u001b[35m1.4032\u001b[0m 0.0069\n", " 170 0.6534 
0.5640 \u001b[35m1.4001\u001b[0m 0.0074\n", " 171 0.7578 0.5640 \u001b[35m1.3992\u001b[0m 0.0082\n", " 172 \u001b[36m0.5842\u001b[0m 0.5640 \u001b[35m1.3976\u001b[0m 0.0075\n", " 173 0.5937 0.5600 \u001b[35m1.3950\u001b[0m 0.0068\n", " 174 0.6404 0.5620 \u001b[35m1.3937\u001b[0m 0.0088\n", " 175 0.6674 \u001b[32m0.5700\u001b[0m \u001b[35m1.3930\u001b[0m 0.0090\n", " 176 0.6230 0.5660 \u001b[35m1.3901\u001b[0m 0.0082\n", " 177 0.7345 0.5660 \u001b[35m1.3891\u001b[0m 0.0083\n", " 178 0.6696 0.5700 \u001b[35m1.3872\u001b[0m 0.0084\n", " 179 \u001b[36m0.5714\u001b[0m \u001b[32m0.5720\u001b[0m \u001b[35m1.3838\u001b[0m 0.0070\n", " 180 0.5873 0.5720 \u001b[35m1.3810\u001b[0m 0.0098\n", " 181 0.6660 \u001b[32m0.5780\u001b[0m \u001b[35m1.3787\u001b[0m 0.0095\n", " 182 0.6348 0.5760 \u001b[35m1.3771\u001b[0m 0.0083\n", " 183 0.6988 0.5780 \u001b[35m1.3746\u001b[0m 0.0076\n", " 184 0.5842 0.5760 \u001b[35m1.3719\u001b[0m 0.0074\n", " 185 0.5969 0.5760 \u001b[35m1.3698\u001b[0m 0.0075\n", " 186 0.6382 0.5780 \u001b[35m1.3682\u001b[0m 0.0077\n", " 187 0.5822 0.5780 1.3693 0.0063\n", " 188 0.6263 0.5760 \u001b[35m1.3677\u001b[0m 0.0077\n", " 189 0.6212 0.5740 \u001b[35m1.3666\u001b[0m 0.0077\n", " 190 0.6059 0.5740 \u001b[35m1.3641\u001b[0m 0.0115\n", " 191 0.5797 0.5760 \u001b[35m1.3610\u001b[0m 0.0080\n", " 192 0.6375 0.5780 \u001b[35m1.3598\u001b[0m 0.0088\n", " 193 0.6777 0.5740 1.3610 0.0103\n", " 194 0.6331 0.5700 1.3609 0.0083\n", " 195 0.6305 0.5720 1.3601 0.0075\n", " 196 \u001b[36m0.5502\u001b[0m 0.5700 \u001b[35m1.3568\u001b[0m 0.0080\n", " 197 0.7014 0.5720 \u001b[35m1.3562\u001b[0m 0.0079\n", " 198 0.5605 0.5720 \u001b[35m1.3551\u001b[0m 0.0119\n", " 199 0.5541 0.5760 1.3557 0.0071\n", " 200 \u001b[36m0.5496\u001b[0m 0.5720 \u001b[35m1.3506\u001b[0m 0.0074\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "[initialized](\n", " module_=GCN(\n", " (conv1): GCNConv(1433, 16)\n", " (conv2): GCNConv(16, 7)\n", " ),\n", ")" ] }, "metadata": {}, 
"execution_count": 11 } ], "source": [ "net.fit(ds_train, None)" ] }, { "cell_type": "markdown", "id": "fed65b00", "metadata": { "id": "fed65b00" }, "source": [ "### Evaluation" ] }, { "cell_type": "code", "execution_count": 12, "id": "ae562c9e", "metadata": { "id": "ae562c9e" }, "outputs": [], "source": [ "from sklearn.metrics import accuracy_score" ] }, { "cell_type": "code", "execution_count": 13, "id": "ef5f2409", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ef5f2409", "outputId": "b2d30bee-eb51-452a-b450-3342b05cc858" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "0.682" ] }, "metadata": {}, "execution_count": 13 } ], "source": [ "accuracy_score(ds_test.y, net.predict(ds_test))" ] }, { "cell_type": "markdown", "id": "2fb7c99b", "metadata": { "id": "2fb7c99b" }, "source": [ "In conclusion this example showed you how to use a basic data graph dataset using pytorch geometric in conjunction with skorch. The final test score is lower than the ~80% accuracy in the [introduction example](https://pytorch-geometric.readthedocs.io/en/latest/notes/introduction.html) which can be explained by the reduced leakage between train and validation sets due to our splitting the data into subgraphs beforehand.\n", "\n", "The model is now incorporated into the sklearn world (as you could already see, you can simply use sklearn metrics to evaluate the model). Thus, tools like grid and random search are available to you and it is easily possible to include a graph neural net as a feature transformer in your next ML pipeline!" 
] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13 (default, Mar 28 2022, 08:03:21) [MSC v.1916 64 bit (AMD64)]" }, "vscode": { "interpreter": { "hash": "bd97b8bffa4d3737e84826bc3d37be3046061822757ce35137ab82ad4c5a2016" } }, "accelerator": "GPU", "gpuClass": "standard" }, "nbformat": 4, "nbformat_minor": 5 }