import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_edge_encoder
# NOTE: a "[docs]" Sphinx source-link marker from HTML extraction stood here; it is not code.
@register_edge_encoder('LinearEdge')
class LinearEdgeEncoder(torch.nn.Module):
    """Edge encoder that projects edge features through one linear layer.

    The class is registered with GraphGym under the name ``'LinearEdge'``.
    The input feature width is not a constructor argument: it is fixed to 1
    when ``cfg.dataset.name`` is ``'MNIST'`` or ``'CIFAR10'``, and
    construction is refused for every other dataset name.

    Args:
        emb_dim (int): Size of the output edge embedding, i.e. the
            ``out_features`` of the underlying ``torch.nn.Linear``.

    Raises:
        ValueError: If ``cfg.dataset.name`` is not one of the datasets for
            which the input edge feature dimension is hard-set.
    """

    def __init__(self, emb_dim):
        super().__init__()

        # Guard clause: only datasets with a known edge feature width are
        # accepted; everything else must be wired up explicitly first.
        dataset_name = cfg.dataset.name
        if dataset_name not in ('MNIST', 'CIFAR10'):
            raise ValueError("Input edge feature dim is required to be hardset "
                             "or refactored to use a cfg option.")

        self.in_dim = 1
        self.encoder = torch.nn.Linear(self.in_dim, emb_dim)

    def forward(self, batch):
        """Encode ``batch.edge_attr`` in place and return the same batch.

        The edge attributes are viewed as ``(-1, self.in_dim)`` before being
        passed to the linear layer, and the result overwrites
        ``batch.edge_attr``.
        """
        edge_features = batch.edge_attr.view(-1, self.in_dim)
        batch.edge_attr = self.encoder(edge_features)
        return batch