{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"2022-01-05-information-aggregation.ipynb","provenance":[{"file_id":"https://github.com/recohut/nbs/blob/main/raw/A687546%20%7C%20Understanding%20Information%20Aggregation%20by%20applying%20SAGE%20convolution%20layer%20on%20Cora%20dataset.ipynb","timestamp":1644438961854},{"file_id":"1n8Qp9JVZ6p4MHhPc5XAErlo3otRhJedE","timestamp":1638731049131},{"file_id":"1noqvdndIt6Y6hwWizVg3PlYbOVul_jSw","timestamp":1628355127982}],"collapsed_sections":[],"toc_visible":true,"mount_file_id":"1n8Qp9JVZ6p4MHhPc5XAErlo3otRhJedE","authorship_tag":"ABX9TyOZqIvZIsiNn+RaGjsu08vl"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[
{"cell_type":"markdown","metadata":{"id":"ij48zhee8Gfb"},"source":["# Information Aggregation\n","> Understanding information aggregation by applying a SAGE convolution layer to the Cora dataset"]},
{"cell_type":"code","metadata":{"id":"33HhJOZ_6wxl"},"source":["# torch geometric\n","try:\n","    import torch_geometric\n","except ModuleNotFoundError:\n","    # Install the torch geometric extension wheels built for this exact PyTorch + CUDA combination.\n","    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details\n","    import torch\n","    TORCH = torch.__version__.split('+')[0]\n","    # torch.version.cuda is None on CPU-only builds; the wheel index uses the 'cpu' tag for those.\n","    CUDA = 'cu' + torch.version.cuda.replace('.', '') if torch.version.cuda else 'cpu'\n","\n","    !pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n","    !pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n","    !pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n","    !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html\n","    !pip install torch-geometric\n","    import torch_geometric\n","import torch_geometric.nn as geom_nn\n","import torch_geometric.data as geom_data"],"execution_count":null,"outputs":[]},
{"cell_type":"markdown","metadata":{"id":"VHKkioYa8K5F"},"source":["### Prototype"]},
{"cell_type":"code","metadata":{"id":"5qiU3KMV6zFr"},"source":["import os.path as osp\n","\n","import torch\n","import torch.nn.functional as F\n","\n","import torch_geometric\n","from torch_geometric.datasets import Planetoid\n","from torch_geometric.nn import SAGEConv"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"GBx6zeSa7B4P"},"source":["SEED = 42\n","torch.manual_seed(SEED)  # fixed seed so re-running the notebook reproduces the model init and accuracies below\n","use_cuda_if_available = False"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"6FVfZweF7DT8","executionInfo":{"status":"ok","timestamp":1638730071488,"user_tz":-330,"elapsed":3985,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"14cb5517-dc09-43a7-b5ab-c8481df45c58"},"source":["dataset = Planetoid(root=\"/content/cora\", name=\"Cora\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stderr","text":["Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x\n","Downloading 
https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph\n","Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index\n","Processing...\n","Done!\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ljx3DrLS7GrR","executionInfo":{"status":"ok","timestamp":1638730071490,"user_tz":-330,"elapsed":48,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"1a129c18-9b06-4e8b-d5bc-d0bb02492e00"},"source":["print(dataset)\n","print(f\"number of graphs:\\t\\t {len(dataset)}\")\n","print(f\"number of classes:\\t\\t {dataset.num_classes}\")\n","print(f\"number of node features:\\t {dataset.num_node_features}\")\n","print(f\"number of edge features:\\t {dataset.num_edge_features}\")"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Cora()\n","number of graphs:\t\t 1\n","number of classes:\t\t 7\n","number of node features:\t 1433\n","number of edge features:\t 0\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"mSgjzQ1G7G9p","executionInfo":{"status":"ok","timestamp":1638730087701,"user_tz":-330,"elapsed":516,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"4876c74e-6b09-4887-9797-02182e59ab8c"},"source":["# Index the single Cora graph; newer PyG versions warn when `dataset.data` is read directly.\n","print(dataset[0])"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Data(x=[2708, 
1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"_2cP4-IH7Lfr","executionInfo":{"status":"ok","timestamp":1638730116565,"user_tz":-330,"elapsed":499,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"6dbae112-a82f-443a-c7d6-ce88e41cd031"},"source":["# Show the shape and contents of the main tensors of the (single) Cora graph.\n","cora_graph = dataset[0]\n","for i, attr in enumerate((\"edge_index\", \"train_mask\", \"x\", \"y\")):\n","    if i > 0:\n","        print(\"\\n\")\n","    value = getattr(cora_graph, attr)\n","    print(f\"{attr}:\\t\\t\", value.shape)\n","    print(value)"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["edge_index:\t\t torch.Size([2, 10556])\n","tensor([[ 0, 0, 0, ..., 2707, 2707, 2707],\n"," [ 633, 1862, 2582, ..., 598, 1473, 2706]])\n","\n","\n","train_mask:\t\t torch.Size([2708])\n","tensor([ True, True, True, ..., False, False, False])\n","\n","\n","x:\t\t torch.Size([2708, 1433])\n","tensor([[0., 0., 0., ..., 0., 0., 0.],\n"," [0., 0., 0., ..., 0., 0., 0.],\n"," [0., 0., 0., ..., 0., 0., 0.],\n"," ...,\n"," [0., 0., 0., ..., 0., 0., 0.],\n"," [0., 0., 0., ..., 0., 0., 0.],\n"," [0., 0., 0., ..., 0., 0., 0.]])\n","\n","\n","y:\t\t torch.Size([2708])\n","tensor([3, 4, 4, ..., 3, 3, 3])\n"]}]},
{"cell_type":"code","metadata":{"id":"4lRviHFBbXIo"},"source":["data = dataset[0]"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"AIm93HpdbcTF"},"source":["class Net(torch.nn.Module):\n","    \"\"\"One SAGEConv layer mapping node features straight to class log-probabilities.\"\"\"\n","\n","    def __init__(self):\n","        super().__init__()\n","        # `aggr` chooses how neighbour messages are combined: \"max\", \"mean\", \"add\", ...\n","        self.conv = SAGEConv(dataset.num_features,\n","                             dataset.num_classes,\n","                             aggr=\"max\")\n","\n","    def forward(self, x=None, edge_index=None):\n","        \"\"\"Return log-softmax class scores for every node.\n","\n","        `x` / `edge_index` default to the global `data` graph, so `model()` keeps working.\n","        \"\"\"\n","        if x is None:\n","            x = data.x\n","        if edge_index is None:\n","            edge_index = data.edge_index\n","        out = self.conv(x, edge_index)\n","        return F.log_softmax(out, dim=1)"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"id":"KZfAEl7ubcO1"},"source":["device = torch.device('cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')\n","model, data = Net().to(device), data.to(device)\n","optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"Bq6K-QAvbcH-","executionInfo":{"status":"ok","timestamp":1638730214359,"user_tz":-330,"elapsed":16,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"8d846ec9-8fbd-41ff-a012-a66ec56dabc1"},"source":["device"],"execution_count":null,"outputs":[{"output_type":"execute_result","data":{"text/plain":["device(type='cpu')"]},"metadata":{},"execution_count":13}]},
{"cell_type":"code","metadata":{"id":"eXwCdpHycB-h"},"source":["def train():\n","    \"\"\"Run one full-batch optimisation step on the training nodes.\"\"\"\n","    model.train()\n","    optimizer.zero_grad()\n","    loss = F.nll_loss(model()[data.train_mask], data.y[data.train_mask])\n","    loss.backward()\n","    optimizer.step()\n","\n","\n","@torch.no_grad()  # evaluation only: no autograd graph needed\n","def test():\n","    \"\"\"Return the [train, val, test] accuracies of the current model.\"\"\"\n","    model.eval()\n","    logits, accs = model(), []\n","    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n","        pred = logits[mask].argmax(dim=1)\n","        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n","        accs.append(acc)\n","    return accs"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"ISVIQ-qFcyoa","executionInfo":{"status":"ok","timestamp":1638730263431,"user_tz":-330,"elapsed":46858,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"9c8ebdca-d970-454b-f58b-cb1adf1d2f7e"},"source":["NUM_EPOCHS = 100\n","LOG_EVERY = 10\n","log = 'Epoch: {:03d}, Val: {:.4f}, Test: {:.4f}'\n","\n","best_val_acc = test_acc = 0\n","for epoch in range(1, NUM_EPOCHS + 1):\n","    train()\n","    _, val_acc, tmp_test_acc = test()\n","    # Model selection: keep the test accuracy of the epoch with the best validation accuracy.\n","    if val_acc > best_val_acc:\n","        best_val_acc = val_acc\n","        test_acc = tmp_test_acc\n","    if epoch % LOG_EVERY == 0:\n","        print(log.format(epoch, best_val_acc, test_acc))"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Epoch: 010, Val: 0.7220, Test: 0.7220\n","Epoch: 020, Val: 0.7220, Test: 0.7220\n","Epoch: 030, Val: 0.7220, Test: 0.7220\n","Epoch: 040, Val: 0.7220, Test: 0.7220\n","Epoch: 050, Val: 0.7220, Test: 0.7220\n","Epoch: 060, Val: 0.7220, Test: 0.7220\n","Epoch: 070, Val: 0.7280, Test: 0.7130\n","Epoch: 080, Val: 0.7300, Test: 0.7150\n","Epoch: 090, Val: 0.7300, Test: 0.7150\n"]}]},
{"cell_type":"markdown","metadata":{"id":"Wwi-MOgkgS7U"},"source":["### Scripting"]},
{"cell_type":"code","metadata":{"id":"mkdirSrcPkgs"},"source":["# %%writefile does not create parent directories, so create the package folders first.\n","!mkdir -p src/datasets src/models"],"execution_count":null,"outputs":[]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"8MtmyZrVfTvA","executionInfo":{"status":"ok","timestamp":1631525453623,"user_tz":-330,"elapsed":1072,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"8398a4bb-2b4b-417c-fbe1-07b2ce90a1f3"},"source":["%%writefile src/datasets/vectorial.py\n","import torch.nn as nn\n","import torch\n","\n","\n","#%% Dataset to manage vector to vector data\n","class VectorialDataset(torch.utils.data.Dataset):\n","    \"\"\"Map-style dataset pairing row i of `input_data` with row i of `output_data`.\n","\n","    Both arrays are converted to float32 tensors. Samples are taken with `[idx, :]`,\n","    so both need at least two dimensions.\n","    \"\"\"\n","\n","    def __init__(self, input_data, output_data):\n","        super().__init__()\n","        # `.astype('f')` casts to float32; presumably NumPy arrays are passed in (TODO confirm).\n","        self.input_data = torch.tensor(input_data.astype('f'))\n","        self.output_data = torch.tensor(output_data.astype('f'))\n","\n","    def __len__(self):\n","        return self.input_data.shape[0]\n","\n","    def __getitem__(self, idx):\n","        # Tensor indices are converted to plain Python ints/lists before slicing.\n","        if torch.is_tensor(idx):\n","            idx = idx.tolist()\n","        return (self.input_data[idx, :],\n","                self.output_data[idx, :])"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing 
src/datasets/vectorial.py\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"P94yQLIDgA_S","executionInfo":{"status":"ok","timestamp":1631525593073,"user_tz":-330,"elapsed":1603,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"973f21cd-d82e-4c7d-e84d-6434549e11f1"},"source":["%%writefile src/datasets/__init__.py\n","from .vectorial import VectorialDataset"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing src/datasets/__init__.py\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"58I29mSRc-yv","executionInfo":{"status":"ok","timestamp":1631525536407,"user_tz":-330,"elapsed":1300,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"6929a0d0-1612-46bd-e7c1-e4f56b16978a"},"source":["%%writefile src/models/linear.py\n","import torch.nn as nn\n","import torch\n","\n","#%% Linear layer\n","class LinearModel(nn.Module):\n","    \"\"\"Thin wrapper around a single biased `nn.Linear(input_dim, output_dim)` layer.\"\"\"\n","\n","    def __init__(self, input_dim, output_dim):\n","        super().__init__()\n","        self.input_dim, self.output_dim = input_dim, output_dim\n","        self.linear = nn.Linear(input_dim, output_dim, bias=True)\n","\n","    def forward(self, x):\n","        \"\"\"Apply the affine map to `x`.\"\"\"\n","        return self.linear(x)\n","\n","    def reset(self):\n","        \"\"\"Re-initialise the layer's weight and bias in place.\"\"\"\n","        self.linear.reset_parameters()"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing src/models/linear.py\n"]}]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"uWGKn9Tjf9kW","executionInfo":{"status":"ok","timestamp":1631525608126,"user_tz":-330,"elapsed":8,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"78e1bac4-3209-43a3-e453-8fafef38d9dc"},"source":["%%writefile src/models/__init__.py\n","from .linear import LinearModel"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Writing src/models/__init__.py\n"]}]},
{"cell_type":"markdown","metadata":{"id":"n83l4fRt9luz"},"source":["---"]},
{"cell_type":"code","metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"trEfQuxo9lu5","executionInfo":{"status":"ok","timestamp":1638730732578,"user_tz":-330,"elapsed":3418,"user":{"displayName":"Sparsh Agarwal","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"13037694610922482904"}},"outputId":"86f4bbba-495f-4d90-b6f1-e5303bfd5d13"},"source":["!pip install -q watermark\n","%reload_ext watermark\n","%watermark -a \"Sparsh A.\" -m -iv -u -t -d"],"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Author: Sparsh A.\n","\n","Last updated: 2021-12-05 18:58:56\n","\n","Compiler : GCC 7.5.0\n","OS : Linux\n","Release : 5.4.104+\n","Machine : x86_64\n","Processor : x86_64\n","CPU cores : 2\n","Architecture: 64bit\n","\n","IPython : 5.5.0\n","torch_geometric: 2.0.2\n","torch : 1.10.0+cu111\n","\n"]}]},
{"cell_type":"markdown","metadata":{"id":"xe6ycIM89lu6"},"source":["---"]},
{"cell_type":"markdown","metadata":{"id":"Cy6OjZxd9lu6"},"source":["**END**"]}]}