"""Link prediction on the Cora citation graph with a two-layer GCN encoder.

Positive edges come from the dataset split; negative edges are re-sampled
every epoch. Each candidate edge is scored as the dot product of its two
endpoint embeddings, trained with binary cross-entropy and evaluated with
ROC-AUC on the held-out validation/test edge splits.
"""
from logging import root  # NOTE(review): unused — looks accidental; confirm before removing
import os.path as osp  # NOTE(review): unused in this file

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from sklearn.metrics import roc_auc_score
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling, train_test_split_edges


class Net(torch.nn.Module):
    """Two-layer GCN encoder producing node embeddings for link prediction."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)

    def encode(self, x, edge_index):
        """Return node embeddings: GCN -> ReLU -> GCN."""
        x = self.conv1(x, edge_index)
        x = x.relu()
        return self.conv2(x, edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        """Score the concatenated positive+negative candidate edges.

        Each edge's logit is the dot product of its endpoints' embeddings;
        the output order matches ``cat([pos, neg], dim=-1)``, which is what
        ``get_link_labels`` assumes.
        """
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        """Return edge_index of every node pair with a positive score.

        Builds the dense N x N score matrix, so this is O(N^2) memory —
        fine for Cora-sized graphs only.
        """
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


def get_link_labels(pos_edge_index, neg_edge_index):
    """Build the 0/1 target vector: 1 for positive edges, 0 for negatives.

    The labels follow the ordering produced by ``Net.decode`` (positives
    first). The tensor is created on ``pos_edge_index``'s device so it can
    be compared against the logits without an implicit CPU round-trip.
    """
    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(num_links, dtype=torch.float,
                              device=pos_edge_index.device)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels


def train(data, model, optimizer):
    """Run one training epoch and return the BCE loss tensor."""
    model.train()

    # Sample one negative edge per positive training edge, fresh each epoch.
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    # Sanity check: training negatives must not coincide with positive
    # validation/test edges, otherwise evaluation would be contaminated.
    train_neg_edge_set = set(map(tuple, neg_edge_index.T.tolist()))
    val_pos_edge_set = set(map(tuple, data.val_pos_edge_index.T.tolist()))
    test_pos_edge_set = set(map(tuple, data.test_pos_edge_index.T.tolist()))
    if (len(train_neg_edge_set & val_pos_edge_set) > 0) or \
            (len(train_neg_edge_set & test_pos_edge_set) > 0):
        print('wrong!')

    optimizer.zero_grad()
    z = model.encode(data.x, data.train_pos_edge_index)
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index,
                                  neg_edge_index).to(data.x.device)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(data, model):
    """Evaluate ROC-AUC on the validation and test edge splits.

    Returns:
        list of two floats: ``[val_auc, test_auc]``.
    """
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    results = []
    for prefix in ['val', 'test']:
        pos_edge_index = data[f'{prefix}_pos_edge_index']
        neg_edge_index = data[f'{prefix}_neg_edge_index']
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        results.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return results


def main():
    """Train the link predictor on Cora for 100 epochs and report AUC."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = Planetoid(root='data', name='Cora',
                        transform=T.NormalizeFeatures())
    data = dataset[0]
    # Full original edge set, kept for reference; not used below.
    ground_truth_edge_index = data.edge_index.to(device)
    # Node-classification fields are irrelevant here and must be cleared
    # before the edge-level split.
    data.train_mask = data.val_mask = data.test_mask = data.y = None
    data = train_test_split_edges(data)
    data = data.to(device)

    model = Net(dataset.num_features, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

    best_val_auc = test_auc = 0
    for epoch in range(1, 101):
        loss = train(data, model, optimizer)
        val_auc, tmp_test_auc = test(data, model)
        # Report the test AUC from the epoch with the best validation AUC.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            test_auc = tmp_test_auc
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

    # Predict edges over all node pairs with the trained encoder; the
    # result is computed but not consumed in this script.
    z = model.encode(data.x, data.train_pos_edge_index)
    final_edge_index = model.decode_all(z)


if __name__ == "__main__":
    main()