# Source code for opengt.layer.spec_layer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.register import register_layer
from torch_scatter import scatter




@register_layer('SpecLayer')
class SpecLayer(nn.Module):
    """
    SpecFormer Layer.

    Adapted from https://github.com/DSL-Lab/Specformer

    The layer builds ``n_heads + 1`` candidate feature sets ("bases"): the
    raw input features plus one spectrally filtered copy per head. The bases
    are combined with a learnable per-basis, per-channel weight and summed.

    Parameters:
        dim_out (int): Number of output features. The learnable weight has
            shape ``(1, n_heads + 1, dim_out)`` and is multiplied elementwise
            with the stacked bases, so the feature width of ``batch.x`` must
            broadcast against ``dim_out``.
        n_heads (int): Number of attention heads, i.e. number of columns of
            ``batch.EigVals`` that are used as spectral filters.
        dropout (float): Dropout rate applied to the stacked bases before
            they are weighted. Default: 0.0.
        norm (str): Normalization type. Options are 'none', 'layer', 'batch'.
            Default: 'none'. With 'none' the weight is initialised to ones;
            with any other value it is drawn from N(0, 0.01^2). Values other
            than 'layer' and 'batch' install no normalization module.

    Input:
        batch.x (Tensor): Input node features.
        batch.EigVecs (Tensor): Eigenvectors of the graph Laplacian
            (2-D; transposed via ``permute(1, 0)``).
        batch.EigVals (Tensor): Eigenvalues of the graph Laplacian; column
            ``i`` is used as the filter for head ``i``.

    Output:
        ret.x (Tensor): Output node features after applying the SpecLayer.
            ``ret`` is a clone of ``batch``; the input batch is not modified.
    """

    def __init__(self, dim_out: int, n_heads: int, dropout: float = 0.0,
                 norm: str = 'none'):
        # Naming relative to the upstream Specformer implementation:
        # nbases = n_heads + 1, ncombines = dim_out.
        super().__init__()
        self.prop_dropout = nn.Dropout(dropout)

        n_bases = n_heads + 1
        if norm == 'none':
            # Without normalization, start as a plain sum of all bases.
            self.weight = nn.Parameter(torch.ones((1, n_bases, dim_out)))
        else:
            self.weight = nn.Parameter(torch.empty((1, n_bases, dim_out)))
            nn.init.normal_(self.weight, mean=0.0, std=0.01)

        if norm == 'layer':    # Arxiv
            self.norm = nn.LayerNorm(dim_out)
        elif norm == 'batch':  # Penn
            self.norm = nn.BatchNorm1d(dim_out)
        else:                  # Others
            self.norm = None

        self.n_heads = n_heads

    def forward(self, batch):
        """
        Apply the spectral filters and combine the resulting bases.

        Args:
            batch: Object exposing ``x``, ``EigVecs``, ``EigVals`` and a
                ``clone()`` method.

        Returns:
            A clone of ``batch`` whose ``x`` holds the combined features.
        """
        eig_vecs = batch.EigVecs
        eig_vals = batch.EigVals

        # Project the node features onto the eigenvector basis once; the
        # projection is shared by every head.
        utx = eig_vecs.permute(1, 0) @ batch.x

        # Basis 0 is the unfiltered input. Basis i + 1 rescales the projected
        # features with column i of EigVals and maps them back.
        bases = [batch.x]
        bases.extend(
            eig_vecs @ (eig_vals[:, i].unsqueeze(1) * utx)  # [N, d]
            for i in range(self.n_heads)
        )
        stacked = torch.stack(bases, dim=1)  # [N, m, d]

        combined = self.prop_dropout(stacked) * self.weight  # [N, m, d] * [1, m, d]
        combined = torch.sum(combined, dim=1)

        # NOTE(review): the nesting of the ReLU could not be read from the
        # collapsed source; it is placed inside this branch to match the
        # upstream Specformer layer (activation only when a norm is
        # configured) -- confirm.
        if self.norm is not None:
            combined = self.norm(combined)
            combined = F.relu(combined)

        ret = batch.clone()
        ret.x = combined
        return ret