{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Correct & Smooth Example (DGL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "\n",
    "from ogb.nodeproppred import DglNodePropPredDataset, Evaluator\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from node_pred import train, test\n",
    "\n",
    "from utils import EarlyStopping, seed_everything\n",
    "from model import GNN\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import CorrectAndSmooth (DGL backend of gtrick)\n",
    "from gtrick.dgl import CorrectAndSmooth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define train process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_node_pred(args, model, dataset):\n",
    "    \"\"\"Train a GNN for node property prediction, then apply Correct & Smooth.\n",
    "\n",
    "    For each of ``args.runs`` runs the model is re-initialised and trained\n",
    "    with Adam + early stopping; the logits from the epoch with the best\n",
    "    validation accuracy are post-processed with Correct & Smooth, and the\n",
    "    test accuracy is reported both before and after C & S.\n",
    "\n",
    "    Args:\n",
    "        args: parsed hyper-parameter namespace (see the argument cell below).\n",
    "        model: GNN exposing ``reset_parameters()``; moved to device in place.\n",
    "        dataset: a ``DglNodePropPredDataset``.\n",
    "    \"\"\"\n",
    "    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'\n",
    "    device = torch.device(device)\n",
    "\n",
    "    model.to(device)\n",
    "\n",
    "    evaluator = Evaluator(name=args.dataset)\n",
    "\n",
    "    g, y = dataset[0]\n",
    "\n",
    "    # make the graph undirected by adding reverse edges\n",
    "    srcs, dsts = g.all_edges()\n",
    "    g.add_edges(dsts, srcs)\n",
    "\n",
    "    # add self-loop (remove existing ones first so no node gets two)\n",
    "    print(f'Total edges before adding self-loop {g.number_of_edges()}')\n",
    "    g = g.remove_self_loop().add_self_loop()\n",
    "    print(f'Total edges after adding self-loop {g.number_of_edges()}')\n",
    "\n",
    "    g, y = g.to(device), y.to(device)\n",
    "\n",
    "    if args.dataset == 'ogbn-proteins':\n",
    "        x = g.ndata['species']\n",
    "    else:\n",
    "        x = g.ndata['feat']\n",
    "\n",
    "    split_idx = dataset.get_idx_split()\n",
    "\n",
    "    train_idx, valid_idx = split_idx['train'], split_idx['valid']\n",
    "\n",
    "    final_test_acc, final_test_acc_cs = [], []\n",
    "    for run in range(args.runs):\n",
    "        model.reset_parameters()\n",
    "        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n",
    "        early_stopping = EarlyStopping(\n",
    "            patience=args.patience, verbose=True, mode='max')\n",
    "\n",
    "        # track the logits ('best_out') of the best-validation epoch; these\n",
    "        # are the soft predictions that C & S propagates afterwards\n",
    "        best_test_acc, best_val_acc = 0, 0\n",
    "        best_out = None\n",
    "\n",
    "        test_acc_cs = 0\n",
    "        for epoch in range(1, 1 + args.epochs):\n",
    "            loss = train(model, g, x, y, train_idx,\n",
    "                         optimizer, dataset.task_type)\n",
    "            result = test(model, g, x, y, split_idx,\n",
    "                          evaluator, dataset.eval_metric)\n",
    "\n",
    "            train_acc, valid_acc, test_acc, out = result\n",
    "\n",
    "            if epoch % args.log_steps == 0:\n",
    "                print(f'Run: {run + 1:02d}, '\n",
    "                      f'Epoch: {epoch:02d}, '\n",
    "                      f'Loss: {loss:.4f}, '\n",
    "                      f'Train: {100 * train_acc:.2f}%, '\n",
    "                      f'Valid: {100 * valid_acc:.2f}% '\n",
    "                      f'Test: {100 * test_acc:.2f}%')\n",
    "\n",
    "            if valid_acc > best_val_acc:\n",
    "                best_val_acc = valid_acc\n",
    "                best_test_acc = test_acc\n",
    "                best_out = out\n",
    "\n",
    "            if early_stopping(valid_acc, model):\n",
    "                break\n",
    "\n",
    "        # define the C & S post-processor\n",
    "        cs = CorrectAndSmooth(num_correction_layers=args.num_correction_layers,\n",
    "                              correction_alpha=args.correction_alpha,\n",
    "                              num_smoothing_layers=args.num_smoothing_layers,\n",
    "                              smoothing_alpha=args.smoothing_alpha,\n",
    "                              autoscale=args.autoscale)\n",
    "\n",
    "        # use ground-truth labels of the train and valid sets to propagate\n",
    "        mask_idx = torch.cat([train_idx, valid_idx])\n",
    "        y_soft = cs.correct(g, best_out, y[mask_idx], mask_idx)\n",
    "        y_soft = cs.smooth(g, y_soft, y[mask_idx], mask_idx)\n",
    "        y_pred = y_soft.argmax(dim=-1, keepdim=True)\n",
    "\n",
    "        test_acc_cs = evaluator.eval({\n",
    "            'y_true': y[split_idx['test']],\n",
    "            'y_pred': y_pred[split_idx['test']],\n",
    "        })[dataset.eval_metric]\n",
    "\n",
    "        print('Best Test Acc: {:.4f}, Best Test Acc with C & S: {:.4f}'.format(best_test_acc, test_acc_cs))\n",
    "\n",
    "        final_test_acc.append(best_test_acc)\n",
    "        final_test_acc_cs.append(test_acc_cs)\n",
    "\n",
    "    print('Test Acc: {:.4f} ± {:.4f}, Test Acc with C & S: {:.4f} ± {:.4f}'.format(np.mean(final_test_acc), np.std(final_test_acc),\n",
    "                                                                                   np.mean(final_test_acc_cs), np.std(final_test_acc_cs)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Run Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser(\n",
    "    description='train node property prediction')\n",
    "parser.add_argument('--dataset', type=str, default='ogbn-arxiv',\n",
    "                    choices=['ogbn-arxiv'])\n",
    "parser.add_argument('--dataset_path', type=str, default='/dev/dataset',\n",
    "                    help='path to dataset')\n",
    "parser.add_argument('--device', type=int, default=1)\n",
    "parser.add_argument('--log_steps', type=int, default=1)\n",
    "parser.add_argument('--lr', type=float, default=0.01)\n",
    "parser.add_argument('--epochs', type=int, default=500)\n",
    "parser.add_argument('--runs', type=int, default=3)\n",
    "parser.add_argument('--patience', type=int, default=30)\n",
    "\n",
    "# params for GNN\n",
    "parser.add_argument('--model', type=str, default='gcn')\n",
    "parser.add_argument('--num_layers', type=int, default=3)\n",
    "parser.add_argument('--hidden_channels', type=int, default=256)\n",
    "parser.add_argument('--dropout', type=float, default=0.5)\n",
    "\n",
    "# params for C & S\n",
    "parser.add_argument('--num-correction-layers', type=int, default=50)\n",
    "parser.add_argument('--correction-alpha', type=float, default=0.979)\n",
    "parser.add_argument('--num-smoothing-layers', type=int, default=50)\n",
    "parser.add_argument('--smoothing-alpha', type=float, default=0.756)\n",
    "# autoscale is on by default; a bare store_true flag with default=True was\n",
    "# a no-op, so also provide '--no-autoscale' to actually allow disabling it\n",
    "parser.add_argument('--autoscale', action='store_true', default=True)\n",
    "parser.add_argument('--no-autoscale', dest='autoscale', action='store_false')\n",
    "args = parser.parse_args(args=[])\n",
    "print(args)\n",
    "\n",
    "# fix all random seeds for reproducibility\n",
    "seed_everything(3042)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = DglNodePropPredDataset(name=args.dataset, root=args.dataset_path)\n",
    "g, _ = dataset[0]\n",
    "\n",
    "num_features = g.ndata['feat'].shape[1]\n",
    "\n",
    "model = GNN(num_features, args.hidden_channels,\n",
    "            dataset.num_classes, args.num_layers,\n",
    "            args.dropout, args.model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "run_node_pred(args, model, dataset)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}