Source code for opengt.network.specformer

import torch
import torch.nn as nn
from opengt.layer.spec_layer import SpecLayer
from opengt.layer.multi_head_attention import MultiHeadAttention
from opengt.encoder.sine_encoder import SineEncoder
from opengt.layer.layer_norm import LayerNorm
from opengt.encoder.feature_encoder import FeatureEncoder

from torch_geometric.nn import Sequential
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.models.layer import new_layer_config , MLP, GCNConv, Linear
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_network


class swapex(nn.Module):
    """
    Swaps the x and EigVals attributes of the input batch. 

    Input: 
        batch (torch_geometric.data.Batch): Input batch containing node features and graph structure.
            - batch.x (torch.Tensor): Input node features.
            - batch.EigVals (torch.Tensor): Eigenvalues of the graph Laplacian.

    Output:
        batch (torch_geometric.data.Batch): Output batch with swapped x and EigVals.
            
    """
    def __init__(self):
        super().__init__()
    def forward(self, batch):
        batch.x, batch.EigVals = batch.EigVals, batch.x
        return batch
    

[docs] @register_network("SpecFormer") class SpecFormer(nn.Module): """ SpecFormer model. Adapted from https://github.com/DSL-Lab/Specformer Only supports the case where the input is a batch of graphs with the same number of nodes. Needs preprocessing for LapRaw positional encoding. Parameters: dim_in (int): Number of input features. dim_out (int): Number of output features. cfg (dict): Configuration dictionary containing model parameters from GraphGym. - cfg.gt.dim_hidden: Hidden dimension for GNN layers and SpecFormer layers. - cfg.gt.n_heads: Number of attention heads. - cfg.gt.dropout: Dropout rate for the model. - cfg.gt.attn_dropout: Dropout rate for the attention mechanism. - cfg.gnn.head: Type of head to use for the final output layer. Input: batch (torch_geometric.data.Batch): Input batch containing node features and graph structure. - batch.x (torch.Tensor): Input node features. - batch.edge_index (torch.Tensor): Edge indices of the graph. - batch.EigVals (torch.Tensor): Eigenvalues of the graph Laplacian. - batch.EigVecs (torch.Tensor): Eigenvectors of the graph Laplacian. Output: batch (task dependent type, see output head): Output after model processing. """ def __init__(self, dim_in, dim_out): super().__init__() self.encoder = FeatureEncoder(dim_in) dim_in = self.encoder.dim_in self.pre_mp = MLP(new_layer_config(dim_in = dim_in, dim_out = cfg.gt.dim_hidden, num_layers = 2, has_act = True, has_bias = True, cfg = cfg)) ### dirty hack self.swap1 = swapex() ### eig v self.eig_encoder = SineEncoder(cfg.gt.dim_hidden) self.mha_eig=Sequential('x',[ (LayerNorm(cfg.gt.dim_hidden), 'x -> x1'), (MultiHeadAttention(dim_hidden = cfg.gt.dim_hidden, n_heads = cfg.gt.n_heads, dropout = cfg.gt.attn_dropout), 'x1 -> x1'), (lambda x1, x2: self.aggregate_batches_add(x1, x2), 'x, x1 -> x') ]) self.ffn_eig=Sequential('x',[ (LayerNorm(cfg.gt.dim_hidden), 'x -> x1'), (MLP(new_layer_config(dim_in = cfg.gt.dim_hidden, dim_out = cfg.gt.dim_hidden, num_layers = 2, has_act = True, has_bias = True, cfg = cfg)), 'x1-> x1'), (lambda x1, x2: self.aggregate_batches_add(x1, x2), 'x, x1 -> x') ]) self.decoder=Linear(new_layer_config(dim_in = cfg.gt.dim_hidden, dim_out = cfg.gt.n_heads, num_layers = 0, has_act = True, has_bias = True, cfg = cfg)) ### eig ^ ### dirty hack self.swap2 = swapex() if cfg.gt.layer_norm: norm = 'layer' elif cfg.gt.batch_norm: norm = 'batch' else: norm = 'none' self.spec_layers = SpecLayer(dim_out = cfg.gt.dim_hidden, n_heads = cfg.gt.n_heads, dropout = cfg.gt.dropout, norm = norm) GNNHead = register.head_dict[cfg.gnn.head] self.post_mp = GNNHead(dim_in=cfg.gt.dim_hidden, dim_out=dim_out) def aggregate_batches_add(self, x1, x2): new_batch = x1.clone() new_batch.x = x1.x + x2.x return new_batch def forward(self, batch): for module in self.children(): batch = module(batch) return batch