# Source code for opengt.network.grit_model

import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.models.layer import (new_layer_config,
                                                   BatchNorm1dNode)
from torch_geometric.graphgym.register import register_network

from opengt.encoder.feature_encoder import FeatureEncoder

@register_network('GritTransformer')
class GritTransformer(torch.nn.Module):
    """Graph Inductive Bias Transformer (GRIT) model.

    Adapted from https://github.com/LiamMa/GRIT

    The model is built from the global GraphGym ``cfg``. ``forward`` applies
    the child modules one after another in the order they are created here:
    feature encoder, optional RRWP encoders, optional pre-message-passing
    layers, the stack of GRIT layers, and the prediction head.

    Args:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features (passed to the head).

    Configuration read from the global ``cfg``:
        - cfg.posenc_RRWP.enable (bool): Whether to add the RRWP encoders
          (one from ``register.node_encoder_dict``, one from
          ``register.edge_encoder_dict``).
        - cfg.posenc_RRWP.ksteps (int): Passed as the first argument
          (presumably the input dimension) to both RRWP encoders.
        - cfg.gt.layers (int): Number of GRIT layers.
        - cfg.gt.n_heads (int): Number of attention heads.
        - cfg.gt.dropout (float): Dropout rate for the GRIT layers.
        - cfg.gt.attn_dropout (float): Dropout rate for the attention.
        - cfg.gt.dim_hidden (int): Hidden dimension of the GRIT layers.
          Must equal cfg.gnn.dim_inner.
        - cfg.gt.layer_type (str): Name of the registered layer used for the
          GRIT layers (defaults to "GritTransformer" if absent).
        - cfg.gt.layer_norm (bool): Whether to use layer normalization.
        - cfg.gt.batch_norm (bool): Whether to use batch normalization.
        - cfg.gt.attn.full_attn: Forwarded to the edge RRWP encoder as
          ``pad_to_full_graph``.
        - cfg.gt.attn.norm_e, cfg.gt.attn.O_e: Forwarded to every GRIT layer.
        - cfg.gnn.head (str): Name of the registered output head.
        - cfg.gnn.layers_pre_mp (int): Number of pre-message-passing layers.
        - cfg.gnn.dim_inner (int): Inner dimension. Must equal
          cfg.gt.dim_hidden.
        - cfg.gnn.dim_edge (int): Second argument of the edge RRWP encoder.
        - cfg.gnn.act: Activation setting forwarded to every GRIT layer.

    Raises:
        AssertionError: If cfg.gt.dim_hidden, cfg.gnn.dim_inner and the
            encoder / pre-MP output dimension do not all match.
        ValueError: If cfg.gt.layers > 0 and cfg.gt.layer_type does not name
            a layer registered in ``register.layer_dict``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # NOTE: forward() iterates self.children(), so the order in which
        # submodules are assigned below IS the order they run in.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        # Public flag kept for compatibility; this model never enables it.
        self.ablation = False

        if cfg.posenc_RRWP.enable:
            self.rrwp_abs_encoder = register.node_encoder_dict["rrwp_linear"](
                cfg.posenc_RRWP.ksteps, cfg.gnn.dim_inner)
            rel_pe_dim = cfg.posenc_RRWP.ksteps
            self.rrwp_rel_encoder = register.edge_encoder_dict["rrwp_linear"](
                rel_pe_dim,
                cfg.gnn.dim_edge,
                pad_to_full_graph=cfg.gt.attn.full_attn,
                add_node_attr_as_self_loop=False,
                fill_value=0.,
            )

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        layer_type = cfg.gt.get('layer_type', "GritTransformer")
        TransformerLayer = register.layer_dict.get(layer_type)
        # Fail with a readable message instead of "'NoneType' object is not
        # callable" when the configured layer type was never registered.
        if TransformerLayer is None and cfg.gt.layers > 0:
            raise ValueError(
                f"Unknown cfg.gt.layer_type '{layer_type}': no layer with "
                f"this name is registered in register.layer_dict.")

        layers = [
            TransformerLayer(
                in_dim=cfg.gt.dim_hidden,
                out_dim=cfg.gt.dim_hidden,
                num_heads=cfg.gt.n_heads,
                dropout=cfg.gt.dropout,
                act=cfg.gnn.act,
                attn_dropout=cfg.gt.attn_dropout,
                layer_norm=cfg.gt.layer_norm,
                batch_norm=cfg.gt.batch_norm,
                residual=True,
                norm_e=cfg.gt.attn.norm_e,
                O_e=cfg.gt.attn.O_e,
                cfg=cfg.gt,
            )
            for _ in range(cfg.gt.layers)
        ]
        self.layers = torch.nn.Sequential(*layers)

        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Run ``batch`` through every child module in registration order.

        Args:
            batch: Input graph batch. Each child module receives the output
                of the previous one.

        Returns:
            The output of the last child module (``post_mp``).
        """
        for module in self.children():
            batch = module(batch)
        return batch