# Source code for opengt.encoder.kernel_pos_encoder

import torch
import torch.nn as nn
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder


class KernelPENodeEncoder(torch.nn.Module):
    """Configurable kernel-based Positional Encoding node encoder.

    The choice of which kernel-based statistics to use is configurable
    through setting of the `kernel_type` class variable. Based on this, the
    appropriate config (``cfg.posenc_<kernel_type>``) is selected, and also
    the appropriate variable with precomputed kernel stats
    (``pestat_<kernel_type>``) is then selected from PyG Data graphs in the
    `forward` function. E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'.

    PE of size `dim_pe` will get appended to each node feature vector.
    If `expand_x` is set True, original node features will be first linearly
    projected to (dim_emb - dim_pe) size and then concatenated with PE.

    Parameters:
        dim_emb (int): Size of final node embedding.
        expand_x (bool): Expand node features `x` from dim_in to
            (dim_emb - dim_pe). Only takes effect when dim_emb - dim_pe > 0.

    Raises:
        ValueError: if `kernel_type` has not been set on the class, if
            `dim_pe` is larger than `dim_emb`, or if the configured PE
            encoder model is neither 'mlp' nor 'linear'.
    """

    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        if self.kernel_type is None:
            raise ValueError(f"{self.__class__.__name__} has to be "
                             f"preconfigured by setting 'kernel_type' class "
                             f"variable before calling the constructor.")

        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = getattr(cfg, f"posenc_{self.kernel_type}")
        dim_pe = pecfg.dim_pe  # Size of the kernel-based PE embedding
        # Width of the raw precomputed stats: one column per kernel time.
        num_kernel_times = len(pecfg.kernel.times)
        model_type = pecfg.model.lower()  # Encoder NN model type for PEs
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        norm_type = pecfg.raw_norm_type.lower()  # Raw PE normalization type
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as separate var

        # Width left over for the (projected) original node features.
        dim_x = dim_emb - dim_pe
        # A zero-width remainder is allowed (formerly the minimum was 1).
        if dim_x < 0:
            raise ValueError(f"PE dim size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        # Projecting x is only possible when there is room left for it.
        if expand_x and dim_x > 0:
            self.linear_x = nn.Linear(dim_in, dim_x)
        self.expand_x = expand_x and dim_x > 0

        if norm_type == 'batchnorm':
            self.raw_norm = nn.BatchNorm1d(num_kernel_times)
        else:
            self.raw_norm = None

        activation = nn.ReLU  # register.act_dict[cfg.gnn.act]
        if model_type == 'mlp':
            layers = []
            if n_layers == 1:
                layers.append(nn.Linear(num_kernel_times, dim_pe))
                layers.append(activation())
            else:
                # NOTE(review): any n_layers != 1 (including 0) takes this
                # branch and yields at least two Linear layers.
                layers.append(nn.Linear(num_kernel_times, 2 * dim_pe))
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(num_kernel_times, dim_pe)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, batch):
        """Append the encoded kernel PE to ``batch.x``.

        Reads ``batch.pestat_<kernel_type>``, optionally normalizes it with
        ``raw_norm``, encodes it with ``pe_encoder`` and concatenates the
        result to the (optionally projected) node features along dim 1.
        If ``pass_as_var`` is set, the encoded PE is also stored as
        ``batch.pe_<kernel_type>``.

        Returns:
            The same ``batch`` object, modified in place.

        Raises:
            ValueError: if the precomputed PE statistics are missing.
        """
        pestat_var = f"pestat_{self.kernel_type}"
        if not hasattr(batch, pestat_var):
            raise ValueError(f"Precomputed '{pestat_var}' variable is "
                             f"required for {self.__class__.__name__}; set "
                             f"config 'posenc_{self.kernel_type}.enable' to "
                             f"True, and also set "
                             f"'posenc_{self.kernel_type}.kernel.times' "
                             f"values")
        pos_enc = getattr(batch, pestat_var)  # (Num nodes) x (Num kernel times)
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        h = self.linear_x(batch.x) if self.expand_x else batch.x

        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to
        # input)
        if self.pass_as_var:
            setattr(batch, f'pe_{self.kernel_type}', pos_enc)
        return batch
@register_node_encoder('RWSE')
class RWSENodeEncoder(KernelPENodeEncoder):
    """Random Walk Structural Encoding node encoder.

    Registered under the node-encoder name ``'RWSE'``. Through
    :class:`KernelPENodeEncoder` it reads the ``posenc_RWSE`` config and the
    precomputed ``pestat_RWSE`` batch attribute.
    """

    # Selects the posenc_RWSE config and the pestat_RWSE batch attribute.
    kernel_type = 'RWSE'
@register_node_encoder('HKdiagSE')
class HKdiagSENodeEncoder(KernelPENodeEncoder):
    """Heat kernel (diagonal) Structural Encoding node encoder.

    Registered under the node-encoder name ``'HKdiagSE'``. Through
    :class:`KernelPENodeEncoder` it reads the ``posenc_HKdiagSE`` config and
    the precomputed ``pestat_HKdiagSE`` batch attribute.
    """

    # Selects the posenc_HKdiagSE config and the pestat_HKdiagSE attribute.
    kernel_type = 'HKdiagSE'
@register_node_encoder('ElstaticSE')
class ElstaticSENodeEncoder(KernelPENodeEncoder):
    """Electrostatic interactions Structural Encoding node encoder.

    Registered under the node-encoder name ``'ElstaticSE'``. Through
    :class:`KernelPENodeEncoder` it reads the ``posenc_ElstaticSE`` config
    and the precomputed ``pestat_ElstaticSE`` batch attribute.
    """

    # Selects the posenc_ElstaticSE config and the pestat_ElstaticSE attribute.
    kernel_type = 'ElstaticSE'