import torch
import torch_geometric.graphgym.models.head # noqa, register module
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network
from opengt.layer.gatedgcn_layer import GatedGCNLayer
from opengt.layer.gine_conv_layer import GINEConvLayer
# Registered with GraphGym under the network name 'custom_gnn'.
@register_network('custom_gnn')
class CustomGNN(torch.nn.Module):
    """
    GNN model that customizes torch_geometric.graphgym.models.gnn.GNN
    to support specific handling of new conv layers.

    The model is built from up to four sub-modules that are applied in the
    order they are created: a feature encoder, optional pre-message-passing
    layers, a stack of message-passing layers, and an output head.

    Parameters:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features.

    Configuration (read from the global GraphGym ``cfg``; it is NOT a
    constructor argument):
        - cfg.gnn.layers_pre_mp: Number of pre-message-passing layers.
        - cfg.gnn.dim_inner: Inner dimension for GNN layers.
        - cfg.gnn.layers_mp: Number of message-passing layers.
        - cfg.gnn.dropout: Dropout rate for GNN layers.
        - cfg.gnn.residual: Whether to use residual connections in GNN layers.
        - cfg.gnn.layer_type: Type of GNN layer to use
          ('gatedgcnconv' or 'gineconv').
        - cfg.gnn.head: Type of head to use for the final output layer.

    Input:
        batch (torch_geometric.data.Batch): Input batch containing node
            features and graph structure.
            - batch.x (torch.Tensor): Input node features.
            - batch.edge_index (torch.Tensor): Edge indices of the graph.

    Output:
        batch (task dependent type, see output head): Output after model
            processing.

    Raises:
        AssertionError: If the feature width entering the message-passing
            layers differs from cfg.gnn.dim_inner.
        ValueError: If cfg.gnn.layer_type is not a supported layer type.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # NOTE: forward() iterates self.children(), so the order in which
        # sub-modules are assigned here is the order in which they execute.
        self.encoder = FeatureEncoder(dim_in)
        # Continue with the feature width reported by the encoder.
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        # Explicit raise instead of `assert` so the check is not stripped
        # under `python -O`; AssertionError is kept so the exception type
        # seen by existing callers does not change.
        if cfg.gnn.dim_inner != dim_in:
            raise AssertionError("The inner and hidden dims must match.")

        conv_model = self.build_conv_model(cfg.gnn.layer_type)
        # Every message-passing layer keeps the feature width unchanged.
        layers = [
            conv_model(dim_in,
                       dim_in,
                       dropout=cfg.gnn.dropout,
                       residual=cfg.gnn.residual)
            for _ in range(cfg.gnn.layers_mp)
        ]
        self.gnn_layers = torch.nn.Sequential(*layers)

        # The head class is looked up by name in GraphGym's head registry.
        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def build_conv_model(self, model_type):
        """
        Return the conv layer class that corresponds to ``model_type``.

        Parameters:
            model_type (str): 'gatedgcnconv' or 'gineconv'.

        Returns:
            The layer class (not an instance): GatedGCNLayer or
            GINEConvLayer.

        Raises:
            ValueError: If ``model_type`` is neither of the supported names.
        """
        if model_type == 'gatedgcnconv':
            return GatedGCNLayer
        if model_type == 'gineconv':
            return GINEConvLayer
        raise ValueError(f"Model {model_type} unavailable")

    def forward(self, batch):
        """
        Pass ``batch`` through every child module in registration order.

        Each child receives the previous child's output; the output of the
        last child is returned.
        """
        for module in self.children():
            batch = module(batch)
        return batch