# Source code for opengt.encoder.composed_encoders

import torch
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.encoder import AtomEncoder
from torch_geometric.graphgym.register import register_node_encoder

from opengt.encoder.ast_encoder import ASTNodeEncoder
from opengt.encoder.kernel_pos_encoder import RWSENodeEncoder, \
    HKdiagSENodeEncoder, ElstaticSENodeEncoder
from opengt.encoder.laplace_pos_encoder import LapPENodeEncoder
from opengt.encoder.ppa_encoder import PPANodeEncoder
from opengt.encoder.signnet_pos_encoder import SignNetNodeEncoder
from opengt.encoder.voc_superpixels_encoder import VOCNodeEncoder
from opengt.encoder.type_dict_encoder import TypeDictNodeEncoder
from opengt.encoder.linear_node_encoder import LinearNodeEncoder
from opengt.encoder.equivstable_laplace_pos_encoder import EquivStableLapPENodeEncoder
from opengt.encoder.graphormer_encoder import GraphormerEncoder
from opengt.encoder.wl_encoder import WLSENodeEncoder


def concat_node_encoders(encoder_classes, pe_enc_names):
    """
    A factory that creates a new Encoder class that concatenates functionality
    of the given list of two or three Encoder classes. First Encoder is
    expected to be a dataset-specific encoder, and the rest PE Encoders.

    Parameters:
        encoder_classes: List of node encoder classes
        pe_enc_names: List of PE embedding Encoder names, used to query a dict
            with their desired PE embedding dims. That dict can only be
            created during the runtime, once the config is loaded.

    Returns:
        new node encoder class

    Raises:
        ValueError: If the number of encoder classes is not two or three, or
            if fewer PE names are given than there are PE encoder classes.
    """
    # Validate up front so that a malformed call fails with a clear message
    # instead of a bare IndexError while configuring the class below.
    num_encoders = len(encoder_classes)
    if num_encoders not in (2, 3):
        raise ValueError(f"Does not support concatenation of "
                         f"{len(encoder_classes)} encoder classes.")
    if len(pe_enc_names) < num_encoders - 1:
        raise ValueError(f"Expected at least {num_encoders - 1} PE encoder "
                         f"names, but got {len(pe_enc_names)}.")

    class Concat2NodeEncoder(torch.nn.Module):
        """
        Encoder that concatenates two node encoders.
        """
        # Filled in by the factory before the class is returned.
        enc1_cls = None
        enc2_cls = None
        enc2_name = None

        def __init__(self, dim_emb):
            super().__init__()

            if cfg.posenc_EquivStableLapPE.enable:
                # Special handling for Equiv_Stable LapPE where node feats
                # and PE are not concat: both encoders get the full dim_emb.
                self.encoder1 = self.enc1_cls(dim_emb)
                self.encoder2 = self.enc2_cls(dim_emb)
            else:
                # PE dims can only be gathered once the cfg is loaded.
                enc2_dim_pe = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe

                # The first encoder is given dim_emb minus the PE width; the
                # PE encoder is given the full dim_emb with expand_x=False.
                self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe)
                self.encoder2 = self.enc2_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            # Apply the encoders in order; each receives the previous output.
            batch = self.encoder1(batch)
            batch = self.encoder2(batch)
            return batch

    class Concat3NodeEncoder(torch.nn.Module):
        """
        Encoder that concatenates three node encoders.
        """
        # Filled in by the factory before the class is returned.
        enc1_cls = None
        enc2_cls = None
        enc2_name = None
        enc3_cls = None
        enc3_name = None

        def __init__(self, dim_emb):
            super().__init__()
            # PE dims can only be gathered once the cfg is loaded.
            enc2_dim_pe = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe
            enc3_dim_pe = getattr(cfg, f"posenc_{self.enc3_name}").dim_pe

            # Each successive encoder is given a larger target dimension: the
            # widths of the PE encoders that come after it are subtracted.
            self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe - enc3_dim_pe)
            self.encoder2 = self.enc2_cls(dim_emb - enc3_dim_pe,
                                          expand_x=False)
            self.encoder3 = self.enc3_cls(dim_emb, expand_x=False)

        def forward(self, batch):
            # Apply the encoders in order; each receives the previous output.
            batch = self.encoder1(batch)
            batch = self.encoder2(batch)
            batch = self.encoder3(batch)
            return batch

    # Configure the correct concatenation class and return it. The classes
    # are re-created on every call, so configuring class attributes here does
    # not affect encoders returned by earlier calls.
    if num_encoders == 2:
        Concat2NodeEncoder.enc1_cls = encoder_classes[0]
        Concat2NodeEncoder.enc2_cls = encoder_classes[1]
        Concat2NodeEncoder.enc2_name = pe_enc_names[0]
        return Concat2NodeEncoder

    Concat3NodeEncoder.enc1_cls = encoder_classes[0]
    Concat3NodeEncoder.enc2_cls = encoder_classes[1]
    Concat3NodeEncoder.enc3_cls = encoder_classes[2]
    Concat3NodeEncoder.enc2_name = pe_enc_names[0]
    Concat3NodeEncoder.enc3_name = pe_enc_names[1]
    return Concat3NodeEncoder
# Dataset-specific node encoders.
ds_encs = {'Atom': AtomEncoder,
           'ASTNode': ASTNodeEncoder,
           'PPANode': PPANodeEncoder,
           'TypeDictNode': TypeDictNodeEncoder,
           'VOCNode': VOCNodeEncoder,
           'LinearNode': LinearNodeEncoder}

# Positional Encoding node encoders.
pe_encs = {'LapPE': LapPENodeEncoder,
           'RWSE': RWSENodeEncoder,
           'HKdiagSE': HKdiagSENodeEncoder,
           'ElstaticSE': ElstaticSENodeEncoder,
           'SignNet': SignNetNodeEncoder,
           'EquivStableLapPE': EquivStableLapPENodeEncoder,
           'GraphormerBias': GraphormerEncoder,
           'WLSE': WLSENodeEncoder}

# Concat dataset-specific and PE encoders. Every (dataset encoder, PE encoder)
# pair is registered under the name "<dataset>+<PE>".
for ds_enc_name, ds_enc_cls in ds_encs.items():
    for pe_enc_name, pe_enc_cls in pe_encs.items():
        register_node_encoder(
            f"{ds_enc_name}+{pe_enc_name}",
            concat_node_encoders([ds_enc_cls, pe_enc_cls],
                                 [pe_enc_name])
        )

# Combine both LapPE and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+LapPE+RWSE",
        concat_node_encoders([ds_enc_cls, LapPENodeEncoder, RWSENodeEncoder],
                             ['LapPE', 'RWSE'])
    )

# Combine both SignNet and RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+SignNet+RWSE",
        concat_node_encoders([ds_enc_cls, SignNetNodeEncoder, RWSENodeEncoder],
                             ['SignNet', 'RWSE'])
    )

# Combine GraphormerBias with LapPE or RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
    register_node_encoder(
        f"{ds_enc_name}+GraphormerBias+LapPE",
        concat_node_encoders([ds_enc_cls, GraphormerEncoder, LapPENodeEncoder],
                             ['GraphormerBias', 'LapPE'])
    )
    register_node_encoder(
        f"{ds_enc_name}+GraphormerBias+RWSE",
        concat_node_encoders([ds_enc_cls, GraphormerEncoder, RWSENodeEncoder],
                             ['GraphormerBias', 'RWSE'])
    )