Source code for opengt.network.custom_gnn

import torch
import torch_geometric.graphgym.models.head  # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from opengt.layer.gatedgcn_layer import GatedGCNLayer
from opengt.layer.gine_conv_layer import GINEConvLayer


@register_network('custom_gnn')
class CustomGNN(torch.nn.Module):
    """GraphGym network whose message-passing stack uses custom conv layers.

    Follows the layout of ``torch_geometric.graphgym.models.gnn.GNN`` but
    builds the message-passing stack from this project's ``GatedGCNLayer``
    or ``GINEConvLayer``. Child modules are registered in this order, which
    is also the order ``forward`` applies them:

    1. ``encoder``    -- ``FeatureEncoder(dim_in)``.
    2. ``pre_mp``     -- ``GNNPreMP`` stack; only created when
       ``cfg.gnn.layers_pre_mp > 0``.
    3. ``gnn_layers`` -- ``torch.nn.Sequential`` of ``cfg.gnn.layers_mp``
       conv layers, each built with equal input and output width.
    4. ``post_mp``    -- head class taken from ``register.head_dict``.

    Args:
        dim_in (int): Input feature dimension handed to ``FeatureEncoder``.
        dim_out (int): Output dimension handed to the head.

    Settings read from the global GraphGym ``cfg`` object (it is not a
    constructor argument):
        cfg.gnn.layers_pre_mp: number of pre-message-passing layers.
        cfg.gnn.dim_inner: hidden width; also the head's input width.
        cfg.gnn.layers_mp: number of conv layers in ``gnn_layers``.
        cfg.gnn.dropout: forwarded to every conv layer as ``dropout``.
        cfg.gnn.residual: forwarded to every conv layer as ``residual``.
        cfg.gnn.layer_type: ``'gatedgcnconv'`` or ``'gineconv'``.
        cfg.gnn.head: key used to look up the head in
            ``register.head_dict``.

    Raises:
        AssertionError: if the width reaching the conv stack is not
            ``cfg.gnn.dim_inner``. This is checked with ``assert``, so the
            check is skipped under ``python -O``.
        ValueError: from ``build_conv_model`` for an unsupported
            ``cfg.gnn.layer_type``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # Assignment order matters: forward() walks self.children() in
        # registration order.
        self.encoder = FeatureEncoder(dim_in)
        # Continue from the width the encoder reports, not the raw dim_in.
        hidden = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                hidden, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            hidden = cfg.gnn.dim_inner

        # Without a pre-MP stack, the encoder width must already equal
        # dim_inner, because the head below is built with dim_inner.
        assert cfg.gnn.dim_inner == hidden, \
            "The inner and hidden dims must match."

        layer_cls = self.build_conv_model(cfg.gnn.layer_type)
        self.gnn_layers = torch.nn.Sequential(*(
            layer_cls(hidden, hidden,
                      dropout=cfg.gnn.dropout,
                      residual=cfg.gnn.residual)
            for _ in range(cfg.gnn.layers_mp)
        ))

        head_cls = register.head_dict[cfg.gnn.head]
        self.post_mp = head_cls(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """Return the conv layer class registered under ``model_type``.

        Args:
            model_type: layer identifier, ``'gatedgcnconv'`` or
                ``'gineconv'``.

        Returns:
            The layer class itself (not an instance).

        Raises:
            ValueError: if ``model_type`` matches neither supported name.
        """
        candidates = (
            ('gatedgcnconv', GatedGCNLayer),
            ('gineconv', GINEConvLayer),
        )
        for name, layer_cls in candidates:
            if model_type == name:
                return layer_cls
        raise ValueError(f"Model {model_type} unavailable")

    def forward(self, batch):
        """Pass ``batch`` through every child module in registration order.

        Each child receives the previous child's output. The attributes
        read from the input depend on the encoder, conv layers and head;
        presumably at least ``batch.x`` and ``batch.edge_index``, so verify
        this against the layer implementations.

        Returns:
            Whatever the final head returns (task dependent).
        """
        out = batch
        for stage in self.children():
            out = stage(out)
        return out