# Source code for opengt.encoder.linear_node_encoder

import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder('LinearNode')
class LinearNodeEncoder(torch.nn.Module):
    """
    Linear node encoder that applies a linear transformation to the input
    features.

    The layer's input width is read from ``cfg.share.dim_in`` at construction
    time, so that config value must be set before this module is instantiated.

    Parameters:
        emb_dim (int): The dimension of the output node features.
    """

    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        # Input width comes from the shared GraphGym config rather than from
        # a constructor argument; only the output width is caller-supplied.
        self.encoder = torch.nn.Linear(cfg.share.dim_in, emb_dim)

    def forward(self, batch):
        """
        Replace ``batch.x`` with its linear projection.

        Parameters:
            batch: Object exposing the node features as the attribute ``x``.
                It is modified in place.

        Returns:
            The same ``batch`` object, with ``batch.x`` overwritten by
            ``self.encoder(batch.x)``.
        """
        batch.x = self.encoder(batch.x)
        return batch