# Source code for opengt.encoder.ER_node_encoder

import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder('ERN')
class ERNodeEncoder(torch.nn.Module):
    """Effective Resistance (ER) node encoder.

    A positional encoding (PE) of size ``dim_pe`` is computed from the
    precomputed per-node ER embedding ``batch.er_emb`` and appended to each
    node feature vector. If ``expand_x`` is True, the original node features
    are first linearly projected to ``dim_emb - dim_pe`` and then
    concatenated with the PE; otherwise the raw node features are
    concatenated with the PE unchanged.

    Hyper-parameters are read from ``cfg.share.dim_in`` and
    ``cfg.posenc_ERN`` (``dim_pe``, ``model``, ``layers``, ``n_heads``,
    ``post_layers``, ``pass_as_var``, ``er_dim``).

    Parameters:
        dim_emb (int): Size of final node embedding
        expand_x (bool): Expand node features `x` from dim_in to
            (dim_emb - dim_pe)

    Raises:
        ValueError: If ``cfg.posenc_ERN.model`` is not one of
            'Transformer', 'DeepSet' or 'Linear', or if
            ``dim_emb - dim_pe < 1``.
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_ERN
        dim_pe = pecfg.dim_pe  # Size of the ER PE embedding
        model_type = pecfg.model  # Encoder NN model type for the ER PE
        if model_type not in ['Transformer', 'DeepSet', 'Linear']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        n_heads = pecfg.n_heads  # Num. attention heads in Trf PE encoder
        post_n_layers = pecfg.post_layers  # Num. layers to apply after pooling
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable
        er_dim = pecfg.er_dim  # Width of the precomputed ER embedding (Linear model input)

        if dim_emb - dim_pe < 1:
            raise ValueError(f"ER_Node size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        if model_type == 'Linear':
            # Map each node's whole ER vector directly to dim_pe.
            self.pe_encoder = nn.Linear(er_dim, dim_pe)
        else:
            # Initial projection of each scalar ER value to dim_pe.
            # forward() applies this on both the Transformer and DeepSet
            # paths, so it must exist for every DeepSet depth (including a
            # single layer).
            self.linear_A = nn.Linear(1, dim_pe)
            if model_type == 'Transformer':
                # Transformer model for ER_Node, inputs are
                # (num nodes) x (er_dim) x dim_pe
                encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
                                                           nhead=n_heads,
                                                           batch_first=True)
                self.pe_encoder = nn.TransformerEncoder(encoder_layer,
                                                        num_layers=n_layers)
            else:
                # DeepSet model for ER_Node: linear_A acts as the first
                # layer, followed by (n_layers - 1) Linear+ReLU layers.
                layers = [nn.ReLU()]
                for _ in range(n_layers - 1):
                    layers.append(nn.Linear(dim_pe, dim_pe))
                    layers.append(nn.ReLU())
                self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP applied to the node-level PE (after pooling).
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(nn.ReLU())
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)

    def forward(self, batch):
        """Append the ER positional encoding to ``batch.x``.

        Args:
            batch: Graph batch carrying node features ``x`` and the
                precomputed ``er_emb`` (treated as N x er_dim below).

        Returns:
            The same ``batch``, with ``batch.x`` replaced by the
            concatenation of the (optionally projected) node features and
            the ER PE. If ``pass_as_var`` is set, the PE is also stored in
            ``batch.pe_ern``.

        Raises:
            ValueError: If ``batch`` has no ``er_emb`` attribute.
        """
        if not hasattr(batch, 'er_emb'):
            raise ValueError("Precomputed ER embeddings required for calculating ER Node Encodings")

        pos_enc = batch.er_emb  # N x er_dim
        if self.training:
            # Randomly permute the ER embedding columns (training only).
            # The permutation is created on the embedding's own device so
            # no index tensor has to be copied from the host each step.
            perm = torch.randperm(pos_enc.size(1), device=pos_enc.device)
            pos_enc = pos_enc[:, perm]

        if self.model_type == 'Linear':
            pos_enc = self.pe_encoder(pos_enc)  # N x er_dim -> N x dim_pe
        else:
            pos_enc = torch.unsqueeze(pos_enc, 2)  # N x er_dim x 1
            pos_enc = self.linear_A(pos_enc)  # N x er_dim x dim_pe

            # PE encoder: a Transformer or DeepSet model
            if self.model_type == 'Transformer':
                pos_enc = self.pe_encoder(src=pos_enc)
            else:
                pos_enc = self.pe_encoder(pos_enc)

            # Sum pooling over the er_dim axis
            pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # N x dim_pe

        # Optional MLP on the node-level PE
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # N x dim_pe

        # Expand node features if needed
        h = self.linear_x(batch.x) if self.expand_x else batch.x

        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_ern = pos_enc
        return batch