import torch
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head
@register_head('infer_links')
class InferLinksHead(torch.nn.Module):
    """Link-inference head that scores node pairs of a complete edge index.

    For every index pair in ``batch.complete_edge_index`` the head takes the
    dot product of the two gathered node embeddings and maps that single
    scalar through ``torch.nn.Linear(1, 2)``. The result is one row of raw
    (un-normalized) scores per pair; no activation is applied.

    Only ``cfg.dataset.infer_link_label == "edge"`` is supported. For that
    task the output width is always 2, regardless of the ``dim_out`` passed in.

    Args:
        dim_in (int): Input dimension. Not used by this head; presumably
            accepted so the constructor matches the signature expected of
            registered heads (NOTE(review): confirm against the caller).
        dim_out (int): Requested output dimension. Ignored: it is overridden
            to 2 for the ``"edge"`` task.

    Input (attributes read from ``batch``):
        batch.x (torch.Tensor): Node features; assumed shape ``(N, d)``.
        batch.complete_edge_index (torch.Tensor): Node-pair indices; assumed
            shape ``(2, E)``. Only rows 0 and 1 are read.
        batch.y (torch.Tensor): Edge labels, passed through unchanged.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(pred, true)``.
        ``pred`` is the raw predictor output, shape ``(E, 2)`` under the
        shape assumptions above. ``true`` is ``batch.y``.

    Raises:
        ValueError: If ``cfg.dataset.infer_link_label`` is not ``"edge"``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        if cfg.dataset.infer_link_label != "edge":
            raise ValueError(f"Infer-link task {cfg.dataset.infer_link_label} not available.")
        # The "edge" task always uses a 2-wide output. The dim_out argument is
        # overridden here, so changing it has no effect on the layer shape.
        dim_out = 2
        # forward() feeds one scalar per pair (a dot product), hence in_features=1.
        self.predictor = torch.nn.Linear(1, dim_out)

    def forward(self, batch):
        """Score every node pair in ``batch.complete_edge_index``.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(pred, batch.y)``, where
            ``pred`` is the raw ``Linear`` output for each pair.
        """
        edge_index = batch.complete_edge_index
        # Gather the embeddings of both endpoints of each pair
        # (assumes batch.x is (N, d), so each gather is (E, d)).
        h_src = batch.x[edge_index[0]]
        h_dst = batch.x[edge_index[1]]
        # Per-pair dot product -> (E,), then (E, 1) as input to Linear(1, 2).
        score = (h_src * h_dst).sum(dim=1)
        pred = self.predictor(score.unsqueeze(1))
        return pred, batch.y