# Source code for opengt.transform.task_preprocessing

import torch


def shuffle(tensor):
    """Return the entries of ``tensor`` in a random order.

    A random permutation of ``range(len(tensor))`` is drawn with
    ``torch.randperm`` and used to index ``tensor``, so only the first
    dimension is permuted. The input itself is not modified.

    Args:
        tensor: Indexable tensor whose first dimension should be permuted.

    Returns:
        The result of indexing ``tensor`` with the random permutation.
    """
    permutation = torch.randperm(len(tensor))
    shuffled = tensor[permutation]
    return shuffled


def task_specific_preprocessing(data, cfg):
    """Task-specific preprocessing before the dataset is logged and finalized.

    Up to three adjustments are applied in place, depending on ``cfg``:

    * ``cfg.gnn.head == "infer_links"``: ``data.edge_attr`` is dropped,
      ``data.complete_edge_index`` is set to a shuffled sample of node pairs
      (every distinct edge of ``data.edge_index`` plus up to
      ``data.edge_index.size(1)`` randomly drawn non-edges), and ``data.y``
      is set to the matching 0/1 labels (1 = pair is an edge).
    * ``cfg.dataset.name == "TRIANGLES"``: ``data.y`` is shifted down by one
      and cast to ``torch.long``; ``data.x`` is replaced by a single zero
      column when no supported positional encoding is enabled.
    * ``cfg.dataset.name == "CSL"``: ``data.x`` is replaced by zeros with one
      column when no supported encoding is enabled, and with zero columns
      otherwise.

    Args:
        data: PyG graph
        cfg: Main configuration node

    Returns:
        Extended PyG Data object (the same ``data`` instance that was passed
        in).

    Raises:
        ValueError: If the link-inference head is selected and
            ``cfg.dataset.infer_link_label`` is not ``"edge"``.
        AssertionError: If the sampled node pairs are not exactly half edges
            and half non-edges (e.g. ``data.edge_index`` holds duplicate
            edges, or the graph has fewer non-edges than edges).
    """
    if cfg.gnn.head == "infer_links":
        num_nodes = data.x.size(0)

        # Edge attributes are discarded before the label type is validated,
        # so they are gone even when the ValueError below is raised.
        data.edge_attr = None

        if cfg.dataset.infer_link_label != "edge":
            raise ValueError(f"Infer-link task {cfg.dataset.infer_link_label} not available.")

        # Column k of complete_index is the ordered node pair
        # (k // num_nodes, k % num_nodes), i.e. the row-major flattening of
        # a num_nodes x num_nodes adjacency matrix.
        node_ids = torch.arange(num_nodes, dtype=torch.long)
        complete_index = torch.stack(
            [node_ids.repeat_interleave(num_nodes), node_ids.repeat(num_nodes)], 0
        )

        # Dense adjacency labels: 1 where an edge exists, 0 elsewhere.
        # Assumes edge_index holds valid node ids in [0, num_nodes).
        labels = torch.zeros(num_nodes, num_nodes, dtype=torch.long)
        labels[data.edge_index[0], data.edge_index[1]] = 1
        flat_labels = labels.flatten()

        # Because flat_labels uses the same row-major layout as
        # complete_index, edge / non-edge positions can be read straight from
        # it. This needs O(N^2) memory instead of materialising an
        # (N^2, E, 2) broadcast comparison against every edge.
        edge_index = flat_labels.nonzero()[:, 0]
        non_edge_index = (flat_labels == 0).nonzero()[:, 0]

        # Keep as many random non-edges as there are entries in edge_index.
        non_edge_index = shuffle(non_edge_index)[:data.edge_index.size(1)]

        final_index = shuffle(torch.cat([edge_index, non_edge_index]))
        data.complete_edge_index = complete_index[:, final_index]

        pair_labels = flat_labels[final_index]
        assert pair_labels.mean(dtype=torch.float) == 0.5, (
            "Sampled node pairs are not balanced between edges and non-edges."
        )
        data.y = pair_labels

    supported_encoding_available = (
        cfg.posenc_LapPE.enable
        or cfg.posenc_RWSE.enable
        or cfg.posenc_GraphormerBias.enable
    )

    if cfg.dataset.name == "TRIANGLES":
        # If encodings are present they can append to the existing data.x;
        # otherwise a constant zero feature column is used.
        if not supported_encoding_available:
            data.x = torch.zeros((data.x.size(0), 1))
        # Shift the targets down by one and make them integer class indices.
        data.y = data.y.sub(1).to(torch.long)

    if cfg.dataset.name == "CSL":
        # If encodings are present they can append to the empty data.x
        if not supported_encoding_available:
            data.x = torch.zeros((data.num_nodes, 1))
        else:
            data.x = torch.zeros((data.num_nodes, 0))

    return data