# Source code for opengt.encoder.ER_edge_encoder

import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_edge_encoder


@register_edge_encoder('ERE')
class EREdgeEncoder(torch.nn.Module):
    """
    Effective Resistance Edge Encoder

    This encoder calculates the effective resistance edge embedding mentioned
    in https://arxiv.org/pdf/2206.11941

    Configuration read from ``cfg``:
        cfg.gt.dim_edge: expected final edge-feature width (``dim_in`` below).
        cfg.posenc_ERE.layers: number of Linear+ReLU layers in the ER MLP
            (at least one layer is always built).
        cfg.posenc_ERE.pass_as_var: if True, ``batch.er_edge`` is replaced by
            the computed ER embedding in ``forward``.

    Parameters:
        emb_dim (int): Size of the effective resistance embedding.
        use_edge_attr (bool): If True, the existing edge attributes are
            concatenated in front of the ER embedding. Otherwise the ER
            embedding replaces them, and ``emb_dim`` must equal ``dim_in``.
        expand_edge_attr (bool): If True, the edge attributes are linearly
            projected from ``dim_in`` to ``dim_in - emb_dim`` before the
            concatenation, so the result is ``dim_in`` wide. The projection
            is only applied when ``use_edge_attr`` is True.

    Raises:
        ValueError: if ``use_edge_attr`` is False and ``emb_dim != dim_in``.
    """

    def __init__(self, emb_dim, use_edge_attr=False, expand_edge_attr=False):
        super().__init__()
        dim_in = cfg.gt.dim_edge  # Expected final edge_dim
        pecfg = cfg.posenc_ERE
        n_layers = pecfg.layers  # Num. layers in PE encoder model
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable
        self.use_edge_attr = use_edge_attr
        self.expand_edge_attr = expand_edge_attr

        if expand_edge_attr:
            # Created whenever requested, even if unused in forward, so that
            # the module's state_dict keys stay the same as before.
            self.linear_x = nn.Linear(dim_in, dim_in - emb_dim)

        if not use_edge_attr and emb_dim != dim_in:
            # The ER embedding becomes the whole edge_attr, so it must already
            # have the final edge width.
            raise ValueError(
                f"EREdgeEncoder: emb_dim ({emb_dim}) must equal "
                f"cfg.gt.dim_edge ({dim_in}) when use_edge_attr is False"
            )

        # MLP on the scalar resistance: Linear(1, emb) -> ReLU, followed by
        # (n_layers - 1) x [Linear(emb, emb) -> ReLU]. The layer order is
        # unchanged, so Sequential indices (and checkpoints) are preserved.
        layers = [nn.Linear(1, emb_dim), nn.ReLU()]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(emb_dim, emb_dim), nn.ReLU()]
        self.er_encoder = nn.Sequential(*layers)

    def forward(self, batch):
        """Embed ``batch.er_edge`` and write the result into ``batch.edge_attr``.

        Returns the same ``batch`` object, modified in place.
        """
        er = batch.er_edge
        if er.dim() == 1:
            # Accept a flat [num_edges] tensor of scalar resistances; the
            # first Linear layer expects a trailing feature dim of size 1.
            er = er.unsqueeze(-1)
        ere = self.er_encoder(er)

        if self.use_edge_attr:
            edge_attr = batch.edge_attr
            if self.expand_edge_attr:
                edge_attr = self.linear_x(edge_attr)
            batch.edge_attr = torch.cat([edge_attr, ere], dim=1)
        else:
            # The existing edge attributes are discarded, so projecting them
            # first would be wasted work (and would fail if they are None).
            batch.edge_attr = ere

        if self.pass_as_var:
            batch.er_edge = ere
        return batch