---
name: torch-geometric-graph-neural-networks
description: "PyTorch Geometric (PyG) for graph neural networks: node/graph classification, link prediction with GCN, GAT, GraphSAGE, GIN. Message passing, mini-batches, heterogeneous graphs, neighbor sampling, explainability. Supports molecules (QM9, MoleculeNet), social/knowledge graphs, 3D point clouds. For non-graph DL use PyTorch; for classical graph algorithms use NetworkX."
license: MIT
---

# PyTorch Geometric (PyG) — Graph Neural Networks

## Overview

PyTorch Geometric is a library built on PyTorch for developing and training Graph Neural Networks (GNNs). It provides 40+ convolutional layers, mini-batch processing via block-diagonal adjacency matrices, neighbor sampling for large-scale graphs, and heterogeneous graph support for multi-type node/edge networks.

## When to Use

- Node classification on citation, social, or biological networks
- Graph-level classification (molecular activity, protein function)
- Link prediction (knowledge graphs, recommendation systems)
- Molecular property prediction (drug discovery, quantum chemistry)
- 3D point cloud processing and mesh analysis
- Large-scale graph learning with neighbor sampling (>100K nodes)
- Heterogeneous graphs with multiple node/edge types
- **For non-graph deep learning** → use PyTorch directly
- **For traditional graph algorithms (shortest path, centrality)** → use NetworkX

## Prerequisites

```bash
pip install torch torch_geometric
# Optional sparse operations (recommended):
# pip install pyg_lib torch_scatter torch_sparse torch_cluster
```

```python
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
```

## Quick Start

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x = F.relu(self.conv1(data.x, data.edge_index))
        return self.conv2(x, data.edge_index)

model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(data)[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
pred = model(data).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
print(f'Test Accuracy: {acc:.4f}')  # ~0.81
```
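The Quick Start runs on CPU. A minimal sketch of the same setup on GPU (the only change is device placement; see Best Practices below):

```python
# Sketch: same training setup on GPU when available. Both the model and
# the Data object must be moved; forgetting either raises a device mismatch.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# ...training loop as above...
```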
## Core API

### 1. Data Representation

```python
import torch
from torch_geometric.data import Data

# Create a graph: 3 nodes, 2 undirected edges (stored as 4 directed edges)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.randn(3, 16)       # Node features [num_nodes, features]
y = torch.tensor([0, 1, 0])  # Node labels

data = Data(x=x, edge_index=edge_index, y=y)
print(f'Nodes: {data.num_nodes}, Edges: {data.num_edges}')
print(f'Features: {data.num_node_features}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')

# Optional attributes
data.edge_attr = torch.randn(4, 8)                   # Edge features [num_edges, features]
data.pos = torch.randn(3, 3)                         # Node positions (3D)
data.train_mask = torch.tensor([True, True, False])  # Custom masks
```

```python
# Mini-batch processing — graphs concatenated as block-diagonal
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)
for batch in loader:
    print(f'Graphs: {batch.num_graphs}, Nodes: {batch.num_nodes}')
    # batch.batch maps each node → its source graph index
    # No padding needed — computationally efficient
```

### 2. Convolutional Layers

```python
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv
import torch.nn as nn

# GCNConv — spectral graph convolution (baseline)
conv = GCNConv(in_channels=16, out_channels=32)
# Supports: edge_weight, SparseTensor, lazy init

# GATConv — attention-based neighbor weighting
conv = GATConv(16, 32, heads=8, dropout=0.6)
# Output: [N, heads * out_channels] (concat) or [N, out_channels] (concat=False)

# SAGEConv — inductive learning via sampling
conv = SAGEConv(16, 32, aggr='mean')  # 'mean', 'max', 'lstm'

# GINConv — maximally powerful for graph isomorphism
nn_module = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 32))
conv = GINConv(nn_module)

# TransformerConv — graph transformer
from torch_geometric.nn import TransformerConv
conv = TransformerConv(16, 32, heads=8, beta=True)

# All layers share the call signature: x_out = conv(x, edge_index)
x_out = conv(x, edge_index)
print(f'Output shape: {x_out.shape}')  # [num_nodes, heads * out_channels] here
```

### 3. Custom Message Passing

```python
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class CustomConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # 'add', 'mean', 'max'
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        # Degree-based (symmetric) normalization
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        norm = deg.pow(-0.5)
        norm = norm[row] * norm[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j: source node features (automatic via _j suffix)
        return norm.view(-1, 1) * x_j

# Key methods: forward(), message(), aggregate(), update()
# _i suffix → target node, _j suffix → source node
```
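A quick sanity check for the custom layer above (toy 4-node graph, illustrative only):

```python
# Toy check: output width equals out_channels regardless of graph size
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]], dtype=torch.long)
conv = CustomConv(16, 32)
print(conv(x, edge_index).shape)  # torch.Size([4, 32])
```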
### 4. Pooling & Graph-Level Readout

```python
from torch_geometric.nn import (
    global_mean_pool, global_max_pool, global_add_pool,
    TopKPooling, SAGPooling
)

# Global pooling: node features → graph-level representation
x_graph = global_mean_pool(x, batch)  # [num_graphs, features]

# Hierarchical pooling: coarsen graph
pool = TopKPooling(64, ratio=0.8)  # Keep top 80% of nodes
x, edge_index, _, batch, _, _ = pool(x, edge_index, None, batch)

# Graph classification model
class GraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.pool = TopKPooling(64, ratio=0.8)
        self.lin = torch.nn.Linear(64, num_classes)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, None, batch)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)
```

### 5. Heterogeneous Graphs

```python
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, to_hetero

# Create heterogeneous graph
data = HeteroData()
data['paper'].x = torch.randn(100, 128)
data['author'].x = torch.randn(200, 64)
# Row 0 indexes source nodes, row 1 target nodes — each within its own type
data['author', 'writes', 'paper'].edge_index = torch.stack([
    torch.randint(0, 200, (500,)),  # source: author indices
    torch.randint(0, 100, (500,)),  # target: paper indices
])
data['paper', 'cites', 'paper'].edge_index = torch.randint(0, 100, (2, 300))
print(data)  # Shows all node/edge types

# Method 1: Auto-convert a homogeneous model
# Note: its layers must support bipartite message passing (e.g. SAGEConv);
# GCNConv-based models cannot be converted with to_hetero
model = GNN(...)  # placeholder: any homogeneous model with such layers
model = to_hetero(model, data.metadata(), aggr='sum')
out = model(data.x_dict, data.edge_index_dict)
```

```python
# Method 2: Custom per-edge-type convolutions
class HeteroGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = HeteroConv({
            ('paper', 'cites', 'paper'): GCNConv(-1, 64),
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        return {k: F.relu(v) for k, v in x_dict.items()}
```
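Both layers above use lazy `-1` input channels, so parameter shapes materialize on the first forward pass. A minimal sketch of the recommended dummy pass (using the `data` and `HeteroGNN` defined above):

```python
# Lazy (-1) channels are resolved on the first call, so run one dummy
# forward pass before building an optimizer or counting parameters.
model = HeteroGNN()
with torch.no_grad():
    out = model(data.x_dict, data.edge_index_dict)
print({node_type: h.shape for node_type, h in out.items()})
# Only destination types with incoming edges appear (here: 'paper')
```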
### 6. Transforms & Preprocessing

```python
from torch_geometric.transforms import (
    NormalizeFeatures, AddSelfLoops, ToUndirected,
    RandomNodeSplit, RandomLinkSplit, Compose,
    KNNGraph, RadiusGraph, AddLaplacianEigenvectorPE
)

# Single transform
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())

# Compose multiple transforms
transform = Compose([
    ToUndirected(),
    AddSelfLoops(),
    NormalizeFeatures(),
])

# Data splitting
node_split = RandomNodeSplit(num_val=0.1, num_test=0.2)
link_split = RandomLinkSplit(num_val=0.1, num_test=0.2, is_undirected=True)

# Point cloud → graph
pc_transform = Compose([KNNGraph(k=6), NormalizeFeatures()])

# Positional encodings (for graph transformers)
pe_transform = AddLaplacianEigenvectorPE(k=10)
```

## Key Concepts

### Layer Selection Guide

| Task | Layer | Key Feature |
|------|-------|-------------|
| Baseline / general | `GCNConv` | Spectral, cached, edge_weight |
| Variable neighbor importance | `GATConv` / `GATv2Conv` | Multi-head attention |
| Large-scale inductive | `SAGEConv` | Sampling-friendly; mean/max/lstm aggr |
| Graph classification | `GINConv` | Matches 1-WL expressive power |
| Long-range dependencies | `TransformerConv` | Graph transformer |
| Spectral filtering | `ChebConv` | Chebyshev polynomials, K hops |
| Rich edge features | `NNConv` | Edge NN processes edge_attr |
| Molecular / 3D structures | `SchNet`, `DimeNet` | Continuous filters, angles |
| Heterogeneous / multi-relation | `RGCNConv`, `HGTConv` | Multiple edge types |
| Point clouds | `EdgeConv`, `PointNetConv` | Dynamic graphs, local features |
| Deep GNNs (avoid oversmoothing) | `APPNP` + `PairNorm` | Decoupled propagation |

### Data Flow Architecture

- **edge_index**: `[2, num_edges]` COO format. Row 0 = source, row 1 = target
- **Mini-batch**: block-diagonal adjacency + `batch` vector mapping nodes → graphs. No padding
- **Neighbor sampling**: `NeighborLoader` samples K-hop subgraphs per seed node. Output is directed and relabeled
- **Heterogeneous**: `x_dict` (per-type features), `edge_index_dict` (per-relation edges), `metadata()` for the schema

### Aggregation Options

| Aggregation | Class | Use Case |
|-------------|-------|----------|
| Sum | `SumAggregation` | Counting-sensitive tasks |
| Mean | `MeanAggregation` | Degree-invariant |
| Max | `MaxAggregation` | Salient feature detection |
| Softmax | `SoftmaxAggregation(learn=True)` | Learnable attention |
| Multi | `MultiAggregation(['mean', 'max', 'std'])` | Combined signals |

## Common Workflows

### 1. Node Classification (Full Graph)

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x = F.relu(self.conv1(data.x, data.edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, data.edge_index)

model = GCN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Training
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    F.cross_entropy(out[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()

# Evaluation
model.eval()
pred = model(data).argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
print(f'Test Accuracy: {acc:.4f}')
```
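Link prediction (listed under When to Use) follows the same pattern with an encoder plus a dot-product decoder. A minimal sketch, assuming a GCN encoder on Cora and illustrative hyperparameters:

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import negative_sampling

data = Planetoid(root='/tmp/Cora', name='Cora')[0]
train_data, val_data, test_data = RandomLinkSplit(
    num_val=0.05, num_test=0.1, is_undirected=True,
    add_negative_train_samples=False)(data)

class Encoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(data.num_features, 64)
        self.conv2 = GCNConv(64, 32)

    def forward(self, x, edge_index):
        return self.conv2(F.relu(self.conv1(x, edge_index)), edge_index)

model = Encoder()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    z = model(train_data.x, train_data.edge_index)
    # Resample one negative edge per positive training edge each epoch
    neg = negative_sampling(train_data.edge_index, num_nodes=data.num_nodes,
                            num_neg_samples=train_data.edge_label_index.size(1))
    edge_index = torch.cat([train_data.edge_label_index, neg], dim=1)
    label = torch.cat([train_data.edge_label,
                       train_data.edge_label.new_zeros(neg.size(1))], dim=0)
    # Dot-product decoder: score(u, v) = <z_u, z_v>
    score = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
    F.binary_cross_entropy_with_logits(score, label).backward()
    optimizer.step()
```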
### 2. Graph Classification (Mini-Batch)

```python
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool

# Shuffle before splitting (TUDataset graphs may be grouped by class)
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES').shuffle()
train_dataset = dataset[:int(0.8 * len(dataset))]
test_dataset = dataset[int(0.8 * len(dataset)):]
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

class GraphNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 64)
        self.conv2 = GCNConv(64, 64)
        self.lin = torch.nn.Linear(64, dataset.num_classes)

    def forward(self, data):
        x = F.relu(self.conv1(data.x, data.edge_index))
        x = F.relu(self.conv2(x, data.edge_index))
        x = global_mean_pool(x, data.batch)
        return self.lin(x)

model = GraphNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        F.cross_entropy(model(batch), batch.y).backward()
        optimizer.step()
```

### 3. Large-Scale with Neighbor Sampling

```python
from torch_geometric.loader import NeighborLoader

# Sample 25 first-hop and 10 second-hop neighbors per seed node
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],
    batch_size=128,
    input_nodes=data.train_mask,
)

model.train()
for batch in train_loader:
    optimizer.zero_grad()
    out = model(batch)
    # Only compute loss on seed nodes (the first batch_size nodes)
    loss = F.cross_entropy(out[:batch.batch_size], batch.y[:batch.batch_size])
    loss.backward()
    optimizer.step()

# Note: sampled subgraphs are directed; node indices are relabeled 0..N-1
```

## Key Parameters

| Parameter | Module | Default | Range | Effect |
|-----------|--------|---------|-------|--------|
| `in_channels` | All conv layers | — | int | Input feature dimension |
| `out_channels` | All conv layers | — | int | Output feature dimension |
| `heads` | GATConv | 1 | 1-16 | Number of attention heads |
| `dropout` | GATConv | 0.0 | 0-0.8 | Attention weight dropout |
| `aggr` | MessagePassing | 'add' | add/mean/max | Neighbor aggregation |
| `K` | ChebConv | — | 2-5 | Chebyshev polynomial order |
| `num_neighbors` | NeighborLoader | — | list[int] | Neighbors per hop (e.g., [25, 10]) |
| `batch_size` | DataLoader | — | 16-512 | Graphs or seed nodes per batch |
| `ratio` | TopKPooling | 0.5 | 0.1-0.9 | Fraction of nodes to keep |
| `lr` | Adam | — | 1e-4 to 0.01 | Learning rate |
| `weight_decay` | Adam | 0 | 0 to 5e-3 | L2 regularization |
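To make the `heads` row concrete, a small sketch showing how `concat` changes GATConv's output width (toy tensors; see also Troubleshooting):

```python
import torch
from torch_geometric.nn import GATConv

x = torch.randn(5, 128)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])

# Default concat=True: heads are concatenated → [N, heads * out_channels]
print(GATConv(128, 32, heads=8)(x, edge_index).shape)                # [5, 256]
# concat=False: heads are averaged → [N, out_channels]
print(GATConv(128, 32, heads=8, concat=False)(x, edge_index).shape)  # [5, 32]
```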
## Best Practices

1. **Start with GCNConv**: Use a 2-layer GCN as the baseline before trying more complex architectures
2. **Use lazy initialization**: Pass `-1` as `in_channels` to infer dimensions automatically: `GCNConv(-1, 64)`
3. **Normalize features**: Apply the `NormalizeFeatures()` transform for citation/social networks
4. **Anti-pattern — too many layers**: GNNs typically need only 2-3 layers; going deeper causes oversmoothing. Use `JumpingKnowledge` or `PairNorm` if you need depth
5. **GPU transfer**: Move both the model AND the data to the GPU: `model.to(device)`, `data.to(device)`
6. **Anti-pattern — ignoring the batch vector**: In graph classification, always use `global_mean_pool(x, batch)` — omitting `batch` pools all graphs in the batch into a single vector

## Common Recipes

### Recipe: Model Explainability (GNNExplainer)

```python
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',  # match your model's output (the models above return logits)
    ),
)

explanation = explainer(data.x, data.edge_index, index=10)
print(f'Important edges: {explanation.edge_mask.topk(5).indices}')
print(f'Important features: {explanation.node_mask[10].topk(5).indices}')
```

### Recipe: Custom InMemoryDataset

```python
import torch
from torch_geometric.data import InMemoryDataset, Data

class MyDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['data.csv']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        data_list = []
        # Build Data objects from raw files
        edge_index = torch.tensor([[0, 1], [1, 0]], dtype=torch.long)
        x = torch.randn(2, 16)
        data_list.append(Data(x=x, edge_index=edge_index, y=torch.tensor([0])))

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])
```

### Recipe: Deep GNN with JumpingKnowledge

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, JumpingKnowledge, LayerNorm, global_mean_pool

class DeepGNN(torch.nn.Module):
    def __init__(self, in_ch, hidden, num_layers, out_ch):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_ch, hidden))
        self.norms.append(LayerNorm(hidden))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden, hidden))
            self.norms.append(LayerNorm(hidden))
        self.convs.append(GCNConv(hidden, hidden))
        self.jk = JumpingKnowledge(mode='cat')
        self.lin = torch.nn.Linear(hidden * num_layers, out_ch)

    def forward(self, x, edge_index, batch):
        xs = []
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = F.relu(norm(conv(x, edge_index)))
            xs.append(x)
        xs.append(self.convs[-1](x, edge_index))
        return self.lin(global_mean_pool(self.jk(xs), batch))
```
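A quick shape check for the recipe above (toy tensors; sizes are illustrative):

```python
import torch

# Toy batch of two 5-node graphs. With num_layers=4 and mode='cat', the
# JumpingKnowledge output is hidden * num_layers = 256 features per node
# before pooling, so the final output is [num_graphs, out_ch].
model = DeepGNN(in_ch=16, hidden=64, num_layers=4, out_ch=3)
x = torch.randn(10, 16)
edge_index = torch.tensor([[0, 1, 5, 6],
                           [1, 0, 6, 5]])
batch = torch.tensor([0] * 5 + [1] * 5)
print(model(x, edge_index, batch).shape)  # torch.Size([2, 3])
```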
## Troubleshooting

| Problem | Cause | Solution |
|---------|-------|----------|
| `edge_index` shape error | Wrong format (should be `[2, E]`) | Ensure COO format: `torch.tensor([[src...], [dst...]], dtype=torch.long)` |
| OOM on large graph | Full-graph forward pass | Use `NeighborLoader` for mini-batch training |
| Low accuracy | Oversmoothing (too many layers) | Reduce to 2-3 layers; add `JumpingKnowledge` or `PairNorm` |
| NaN in training | Exploding gradients | Add gradient clipping, reduce learning rate, check feature scale |
| Wrong graph-level output | Missing `batch` in pooling | Pass the `batch` tensor to `global_mean_pool(x, batch)` |
| Heterogeneous type error | Mismatched node/edge types | Check that `data.metadata()` matches the model definition |
| Slow DataLoader | Large graph, no sampling | Use `NeighborLoader` with reasonable `num_neighbors` (e.g., [25, 10]) |
| `x` dimension mismatch | Multi-head attention output | GATConv output is `heads * out_channels` unless `concat=False` |
| Import error for sparse ops | Missing optional dependencies | Install `torch_scatter`, `torch_sparse` from the PyG wheels |
| Pre-transform not applied | Dataset already processed | Delete the `processed/` directory and reload |

## Bundled Resources

- **`references/layers_transforms_reference.md`** — Complete catalog of 40+ convolutional layers (GCN, GAT, SAGE, GIN, molecular layers, hypergraph), aggregation operators, pooling (global + hierarchical), normalization layers, pre-built models, auto-encoders, knowledge graph embeddings, and utility layers, plus a transform catalog covering structure, feature, spatial, augmentation, mesh, and specialized transforms
- **`references/datasets_catalog.md`** — Dataset catalog organized by domain: citation networks (Planetoid, Coauthor, Amazon), graph classification (TUDataset, 120+ benchmarks), molecular (QM9, ZINC, MoleculeNet), social (Reddit, Twitch), knowledge graphs (WordNet, FB15k), heterogeneous (OGB_MAG, MovieLens, DBLP), temporal (JODIE), 3D meshes (ShapeNet, ModelNet), and OGB integration

## Related Skills

- **matplotlib-scientific-plotting** — Visualize graph structures, training curves, attention weights

## References

- PyTorch Geometric documentation: https://pytorch-geometric.readthedocs.io/
- PyG GitHub: https://github.com/pyg-team/pytorch_geometric
- Fey & Lenssen (2019), "Fast Graph Representation Learning with PyTorch Geometric", ICLR Workshop