import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import (
    BatchNorm1dEdge,
    BatchNorm1dNode,
    new_layer_config,
)
from torch_geometric.graphgym.register import register_layer
@register_layer("feature_encoder")
class FeatureEncoder(torch.nn.Module):
    """
    Encodes node and edge features.

    Which encoders are built is read from the global GraphGym ``cfg``
    (``cfg.dataset.node_encoder*`` and ``cfg.dataset.edge_encoder*``); the
    concrete encoder classes are looked up in the GraphGym registries.

    Parameters:
        dim_in (int): Input feature dimension

    Attributes:
        dim_in (int): Node feature dimension after encoding. Equals
            ``cfg.gnn.dim_inner`` when a node encoder is enabled, otherwise
            the ``dim_in`` passed to the constructor.

    Note:
        When an edge encoder is enabled, the constructor overwrites the
        global ``cfg.gnn.dim_edge`` (capped at 128 when ``cfg.gt.layer_type``
        contains "PNA", otherwise set to ``cfg.gnn.dim_inner``), so anything
        built from ``cfg`` afterwards sees the encoded edge dimension.
    """

    def __init__(self, dim_in):
        super().__init__()
        self.dim_in = dim_in
        if cfg.dataset.node_encoder:
            # Encoder class comes from the node-encoder registry (presumably
            # an nn.Embedding-style encoder for integer node features).
            NodeEncoder = register.node_encoder_dict[
                cfg.dataset.node_encoder_name]
            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)
            if cfg.dataset.node_encoder_bn:
                self.node_encoder_bn = BatchNorm1dNode(
                    new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))
            # Update dim_in to reflect the new dimension of the node features
            self.dim_in = cfg.gnn.dim_inner
        if cfg.dataset.edge_encoder:
            # Hard-limit max edge dim for PNA. This deliberately mutates the
            # global cfg so downstream layers pick up the chosen edge dim.
            if 'PNA' in cfg.gt.layer_type:
                cfg.gnn.dim_edge = min(128, cfg.gnn.dim_inner)
            else:
                cfg.gnn.dim_edge = cfg.gnn.dim_inner
            # Encoder class comes from the edge-encoder registry (presumably
            # an nn.Embedding-style encoder for integer edge features).
            EdgeEncoder = register.edge_encoder_dict[
                cfg.dataset.edge_encoder_name]
            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_edge)
            if cfg.dataset.edge_encoder_bn:
                # Must be the *edge* batch-norm: BatchNorm1dNode normalizes
                # batch.x, which would leave edge_attr unnormalized and
                # mismatch whenever dim_edge differs from the node dim.
                self.edge_encoder_bn = BatchNorm1dEdge(
                    new_layer_config(cfg.gnn.dim_edge, -1, -1, has_act=False,
                                     has_bias=False, cfg=cfg))

    def forward(self, batch):
        """Apply each sub-module created in ``__init__`` to ``batch``.

        Sub-modules run in registration order: node encoder, node BN,
        edge encoder, edge BN (only those that were enabled). Each one
        receives and returns the batch object.
        """
        for module in self.children():
            batch = module(batch)
        return batch