Source code for opengt.network.graphormer

import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import GNNPreMP
from torch_geometric.graphgym.register import register_network

from opengt.layer.graphormer_layer import GraphormerLayer
from opengt.encoder.feature_encoder import FeatureEncoder



@register_network('Graphormer')
class GraphormerModel(torch.nn.Module):
    """Graphormer network, ported to the GraphGPS / GraphGym model layout.

    Reference:
        Ying, C., Cai, T., Luo, S., Zheng, S., Ke, G., He, D., ... & Liu, T. Y.
        Do transformers really perform badly for graph representation?
        (NeurIPS 2021) https://arxiv.org/abs/2106.05234

    Adapted from https://github.com/rampasek/GraphGPS

    The model is assembled from up to four child modules, created in this
    order: a ``FeatureEncoder``, an optional ``GNNPreMP`` stack, a
    ``torch.nn.Sequential`` of ``GraphormerLayer`` instances, and an output
    head looked up in ``register.head_dict``.

    Parameters:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features.

    Configuration read from the global GraphGym ``cfg`` (not passed as an
    argument):
        - cfg.graphormer.num_layers (int): Number of Graphormer layers.
        - cfg.graphormer.embed_dim (int): Embedding dimension of the
          Graphormer layers. Must equal cfg.gnn.dim_inner.
        - cfg.graphormer.num_heads (int): Number of attention heads.
        - cfg.graphormer.dropout (float): Dropout rate.
        - cfg.graphormer.attention_dropout (float): Attention dropout rate.
        - cfg.graphormer.mlp_dropout (float): MLP dropout rate.
        - cfg.gnn.head (str): Key of the output head in ``register.head_dict``.
        - cfg.gnn.layers_pre_mp (int): Number of pre-message-passing layers;
          the pre-MP stack is only built when this is greater than zero.
        - cfg.gnn.dim_inner (int): Inner dimension. Must equal
          cfg.graphormer.embed_dim.

    Input:
        batch (torch_geometric.data.Batch): Input batch containing node
        features and graph structure.
            - batch.x (torch.Tensor): Input node features.
            - batch.edge_index (torch.Tensor): Edge indices of the graph.

    Output:
        batch (task dependent type, see output head): Output after model
        processing.

    Raises:
        ValueError: If cfg.graphormer.embed_dim, cfg.gnn.dim_inner and the
            feature width entering the Graphormer layers are not all equal.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # NOTE: forward() walks self.children(), so the order in which the
        # sub-modules are assigned below defines the execution order.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in,
                cfg.gnn.dim_inner,
                cfg.gnn.layers_pre_mp,
            )
            dim_in = cfg.gnn.dim_inner

        # Without a pre-MP stack, dim_in is still the encoder's output width,
        # so all three sizes have to agree before the layers are built.
        dims_agree = (
            cfg.graphormer.embed_dim == cfg.gnn.dim_inner
            and cfg.gnn.dim_inner == dim_in
        )
        if not dims_agree:
            raise ValueError(
                f"The inner and embed dims must match: "
                f"embed_dim={cfg.graphormer.embed_dim} "
                f"dim_inner={cfg.gnn.dim_inner} dim_in={dim_in}"
            )

        # Every Graphormer layer is built from the same hyper-parameters.
        layer_kwargs = {
            'embed_dim': cfg.graphormer.embed_dim,
            'num_heads': cfg.graphormer.num_heads,
            'dropout': cfg.graphormer.dropout,
            'attention_dropout': cfg.graphormer.attention_dropout,
            'mlp_dropout': cfg.graphormer.mlp_dropout,
        }
        self.layers = torch.nn.Sequential(*(
            GraphormerLayer(**layer_kwargs)
            for _ in range(cfg.graphormer.num_layers)
        ))

        head_cls = register.head_dict[cfg.gnn.head]
        self.post_mp = head_cls(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Feed ``batch`` through each child module in turn and return the result."""
        out = batch
        for stage in self.children():
            out = stage(out)
        return out