# Source code for opengt.encoder.feature_encoder

import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import (
    BatchNorm1dEdge,
    BatchNorm1dNode,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_layer


@register_layer("feature_encoder")
class FeatureEncoder(torch.nn.Module):
    """Encode node and edge features as configured in ``cfg.dataset``.

    The encoder types are read from the global GraphGym config. Sub-modules
    are registered in the order: node encoder, node batch-norm, edge encoder,
    edge batch-norm. :meth:`forward` applies them in exactly that order.

    Side effect: when ``cfg.dataset.edge_encoder`` is enabled, the constructor
    writes ``cfg.gnn.dim_edge`` on the global config object.

    Parameters:
        dim_in (int): Input feature dimension. Stored as ``self.dim_in``.
            It is replaced by ``cfg.gnn.dim_inner`` when a node encoder is
            configured.
    """

    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Encode integer node features via nn.Embeddings
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Hard-limit max edge dim for PNA.
            if 'PNA' in cfg.gt.layer_type:
                cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner)
            else:
                cfg.gnn.dim_edge = cfg.gnn.dim_inner
            # Encode integer edge features via nn.Embeddings
            EdgeEncoder = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge)
            if cfg.dataset.edge_encoder_bn:
                # Edge features need the *edge* batch-norm wrapper.
                # BatchNorm1dNode acts on batch.x, so using it here would
                # re-normalize node features and leave batch.edge_attr
                # untouched. It would also fail with a size mismatch whenever
                # dim_edge != dim_inner (e.g. the PNA cap above).
                self.edge_encoder_bn = BatchNorm1dEdge(
                    new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))

    def forward(self, batch):
        """Pass ``batch`` through every registered sub-module in order.

        Parameters:
            batch: Graph batch object handed to each sub-module in turn.

        Returns:
            The object returned by the last sub-module. If no encoder was
            configured, this is the input ``batch`` unchanged.
        """
        for module in self.children():
            batch = module(batch)
        return batch