Source code for opengt.network.san_transformer

import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from opengt.layer.san_layer import SANLayer
from opengt.layer.san2_layer import SAN2Layer


@register_network('SANTransformer')
class SANTransformer(torch.nn.Module):
    """Spectral Attention Network (SAN) Graph Transformer.

    https://arxiv.org/abs/2106.03893

    Adapted from https://github.com/rampasek/GraphGPS

    The model is a pipeline of registered submodules applied in order:
    feature encoder -> optional pre-message-passing layers -> a stack of
    SAN layers -> task head.

    Parameters:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features.
        cfg (dict): Configuration dictionary containing model parameters
            from GraphGym. Read from the global GraphGym ``cfg``, not
            passed as a constructor argument.

            - cfg.gt.layer_type (str): Which SAN layer implementation to
              stack; one of ``'SANLayer'`` or ``'SAN2Layer'``.
            - cfg.gt.layers (int): Number of SAN layers.
            - cfg.gt.dim_hidden (int): Hidden dimension for GNN layers and
              SAN layers. Need to match cfg.gnn.dim_inner.
            - cfg.gt.gamma (float): Gamma parameter for SAN layers.
            - cfg.gt.n_heads (int): Number of attention heads.
            - cfg.gt.full_graph (bool): Whether to use full graph attention.
            - cfg.gt.dropout (float): Dropout rate for the SAN layers.
            - cfg.gt.layer_norm (bool): Whether to use layer normalization.
            - cfg.gt.batch_norm (bool): Whether to use batch normalization.
            - cfg.gt.residual (bool): Whether to use residual connections.
            - cfg.gnn.head (str): Type of head to use for the final output
              layer.
            - cfg.gnn.layers_pre_mp (int): Number of pre-message-passing
              layers.
            - cfg.gnn.dim_inner (int): Inner dimension for GNN layers. Need
              to match cfg.gt.dim_hidden.

    Input:
        batch (torch_geometric.data.Batch): Input batch containing node
            features and graph structure.

            - batch.x (torch.Tensor): Input node features.
            - batch.edge_index (torch.Tensor): Edge indices of the graph.

    Output:
        batch (task dependent type, see output head): Output after model
            processing.

    Raises:
        AssertionError: If cfg.gt.dim_hidden, cfg.gnn.dim_inner and the
            feature dimension entering the SAN layers do not all match.
        ValueError: If SAN layers are requested (cfg.gt.layers > 0) but
            cfg.gt.layer_type is not a supported layer name.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # NOTE: submodules must be assigned in pipeline order, because
        # ``forward`` simply walks ``self.children()``.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            dim_in = cfg.gnn.dim_inner

        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
            "The inner and hidden dims must match."

        # A single embedding instance is handed to every SAN layer, so its
        # weights are shared across the stack. It is kept as a local (not
        # assigned to ``self``) so that it does not become a pipeline stage
        # in ``forward``. It is left at the default ``nn.Embedding``
        # initialization.
        # NOTE(review): presumably this embeds the non-real ("fake") edges
        # used by full-graph attention -- confirm against SANLayer.
        fake_edge_emb = torch.nn.Embedding(1, cfg.gt.dim_hidden)

        layer_types = {
            'SANLayer': SANLayer,
            'SAN2Layer': SAN2Layer,
        }
        layer_cls = layer_types.get(cfg.gt.layer_type)
        # Fail with a clear message instead of the opaque
        # "'NoneType' object is not callable" that an unknown layer type
        # would otherwise trigger on the first layer construction. Only
        # enforced when layers are actually built, so a zero-layer
        # configuration keeps working as before.
        if layer_cls is None and cfg.gt.layers > 0:
            raise ValueError(
                f"Unsupported cfg.gt.layer_type: {cfg.gt.layer_type!r}; "
                f"expected one of {sorted(layer_types)}.")

        layers = [
            layer_cls(gamma=cfg.gt.gamma,
                      in_dim=cfg.gt.dim_hidden,
                      out_dim=cfg.gt.dim_hidden,
                      num_heads=cfg.gt.n_heads,
                      full_graph=cfg.gt.full_graph,
                      fake_edge_emb=fake_edge_emb,
                      dropout=cfg.gt.dropout,
                      layer_norm=cfg.gt.layer_norm,
                      batch_norm=cfg.gt.batch_norm,
                      residual=cfg.gt.residual)
            for _ in range(cfg.gt.layers)
        ]
        self.trf_layers = torch.nn.Sequential(*layers)

        # The task head is looked up by name in the GraphGym head registry.
        GNNHead = register.head_dict[cfg.gnn.head]
        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Run ``batch`` through every registered submodule in order.

        The stages are the direct children of this module in the order
        they were assigned in ``__init__``: ``encoder``, ``pre_mp`` (only
        when configured), ``trf_layers``, ``post_mp``. Each stage receives
        the previous stage's output.

        Parameters:
            batch (torch_geometric.data.Batch): Input batch.

        Returns:
            The output of the last stage (the task head).
        """
        for module in self.children():
            batch = module(batch)
        return batch