Source code for opengt.head.infer_links

import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('infer_links')
class InferLinksHead(torch.nn.Module):
    """Node-pair scoring head for link-inference tasks.

    For every node pair listed in ``batch.complete_edge_index`` the head
    computes the dot product of the two endpoint feature vectors and maps that
    single scalar to ``dim_out`` unnormalized scores with a
    ``torch.nn.Linear(1, dim_out)`` layer. The output therefore has one row per
    node pair, not one row per graph.

    Only ``cfg.dataset.infer_link_label == "edge"`` is supported. In that case
    ``dim_out`` is always overridden to 2, so the ``dim_out`` passed by the
    caller is ignored. ``dim_in`` is accepted to match the registered-head
    constructor signature but is not used.

    Parameters:
        dim_in (int): Input feature dimension (unused).
        dim_out (int): Requested output dimension (ignored; forced to 2).

    Input:
        batch.x (torch.Tensor): Node features, indexed by node id along the
            first dimension. Assumed shape ``(num_nodes, hidden)`` -- TODO
            confirm against the encoder that produces it.
        batch.complete_edge_index (torch.Tensor): Node-pair indices. Row 0
            holds the first endpoint and row 1 the second (assumed shape
            ``(2, num_pairs)``).
        batch.y (torch.Tensor): Targets for those pairs, returned unchanged.

    Output:
        pred (torch.Tensor): Scores of shape ``(num_pairs, 2)`` when
            ``batch.x`` is 2-D (presumably two-class logits).
        true (torch.Tensor): ``batch.y``, unchanged.

    Raises:
        ValueError: If ``cfg.dataset.infer_link_label`` is not ``"edge"``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        label = cfg.dataset.infer_link_label
        if label != "edge":
            raise ValueError(f"Infer-link task {label} not available.")
        # The "edge" task always uses two output scores per pair, whatever
        # dim_out the caller requested.
        dim_out = 2
        # The per-pair score fed to the layer is a single scalar (a dot
        # product), hence in_features=1.
        self.predictor = torch.nn.Linear(1, dim_out)

    def forward(self, batch):
        # Integer-array indexing with a (2, P) index tensor gathers endpoint
        # features as (2, P, hidden); slot 0 / slot 1 are the two endpoints.
        pair_x = batch.x[batch.complete_edge_index]
        src_x, dst_x = pair_x[0], pair_x[1]
        # Per-pair dot product over the feature axis -> (P,).
        score = (src_x * dst_x).sum(1)
        # Add a trailing feature axis so Linear(1, 2) sees (P, 1) -> (P, 2).
        pred = self.predictor(score.unsqueeze(1))
        return pred, batch.y