"""Inductive edge/link prediction head (module ``opengt.head.inductive_edge``)."""

import numpy as np
import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import new_layer_config, MLP
from torch_geometric.graphgym.register import register_head


@register_head('inductive_edge')
class GNNInductiveEdgeHead(nn.Module):
    """GNN prediction head for inductive edge/link prediction tasks.

    Implementation adapted from the transductive GraphGym's GNNEdgeHead.

    Parameters:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.

    Input:
        batch.x (torch.Tensor): Node features.
        batch.edge_index_labeled (torch.Tensor): Endpoint indices of the
            edges to score; row 0 holds source nodes, row 1 target nodes.
        batch.edge_label (torch.Tensor): Edge labels (1 marks a positive
            edge in :meth:`compute_mrr`).

    Output:
        pred (torch.Tensor): Predicted edge labels.
        true (torch.Tensor): True edge labels.
        stats (dict): Returned only in evaluation mode; batch-averaged
            ``'hits@1'``, ``'hits@3'``, ``'hits@10'`` and ``'mrr'``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        decoding = cfg.model.edge_decoding
        # Module to decode edges from node embeddings.
        if decoding == 'concat':
            # The MLP maps the concatenated endpoint embeddings directly to
            # dim_out, so it is itself the (parametric) decoder.
            self.layer_post_mp = MLP(
                new_layer_config(dim_in * 2, dim_out, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg))
            self.decode_module = lambda v1, v2: \
                self.layer_post_mp(torch.cat((v1, v2), dim=-1))
        else:
            # Parameter-free similarity decoders yield one score per edge.
            if dim_out > 1:
                raise ValueError(
                    'Binary edge decoding ({}) is used for multi-class '
                    'edge/link prediction.'.format(decoding))
            # Node-wise MLP applied before the similarity is taken.
            self.layer_post_mp = MLP(
                new_layer_config(dim_in, dim_in, cfg.gnn.layers_post_mp,
                                 has_act=False, has_bias=True, cfg=cfg))
            if decoding == 'dot':
                self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1)
            elif decoding == 'cosine_similarity':
                self.decode_module = nn.CosineSimilarity(dim=-1)
            else:
                raise ValueError(f'Unknown edge decoding {decoding}.')

    def _apply_index(self, batch):
        """Gather endpoint embeddings of the labeled edges and their labels."""
        return batch.x[batch.edge_index_labeled], batch.edge_label

    def forward(self, batch):
        """Score the labeled edges of ``batch``.

        Returns ``(pred, true)`` in training mode and
        ``(pred, true, stats)`` in evaluation mode.
        """
        if cfg.model.edge_decoding != 'concat':
            # Assumes GraphGym's MLP transforms ``batch.x`` and returns the
            # batch, since ``_apply_index`` reads ``batch.x`` afterwards.
            batch = self.layer_post_mp(batch)
        pred, true = self._apply_index(batch)
        # Index 0 / 1 select source / target embeddings of every edge.
        pred = self.decode_module(pred[0], pred[1])
        if self.training:
            return pred, true
        # Compute extra ranking stats when in evaluation mode.
        stats = self.compute_mrr(batch)
        return pred, true, stats

    def compute_mrr(self, batch):
        """Compute Hits@{1,3,10} and MRR per graph, averaged over the batch.

        For each positive labeled edge (u, v) of a graph, the score of v is
        ranked against the scores of every other node of that graph as a
        candidate target of u. All-pairs scores come from the node
        embeddings (dot product, or cosine similarity for that decoder).

        Raises:
            ValueError: if ``cfg.model.edge_decoding`` is neither ``'dot'``
                nor ``'cosine_similarity'``.
        """
        decoding = cfg.model.edge_decoding
        if decoding not in ('dot', 'cosine_similarity'):
            raise ValueError(f'Unsupported edge decoding {decoding}.')

        stats = {}
        for data in batch.to_data_list():
            x = data.x
            if decoding == 'cosine_similarity':
                # Dot products of unit-normalised embeddings are the cosine
                # similarities used by forward() (up to eps handling for
                # near-zero vectors).
                x = nn.functional.normalize(x, dim=-1, eps=1e-8)
            pred = x @ x.transpose(0, 1)  # all-pairs scores
            num_nodes = pred.size(1)

            pos_edge_index = data.edge_index_labeled[:, data.edge_label == 1]
            num_pos_edges = pos_edge_index.size(1)
            pred_pos = pred[pos_edge_index[0], pos_edge_index[1]]

            if num_pos_edges > 0:
                # Negatives of edge (u, v): every node except the true target
                # v. The mask must be on pred's device; indexing a CPU mask
                # with CUDA edge indices raises.
                neg_mask = torch.ones((num_pos_edges, num_nodes),
                                      dtype=torch.bool, device=pred.device)
                neg_mask[torch.arange(num_pos_edges, device=pred.device),
                         pos_edge_index[1]] = False
                pred_neg = pred[pos_edge_index[0]][neg_mask].view(
                    num_pos_edges, -1)
            else:
                # No positives: all metric lists come back empty.
                pred_neg = pred.new_empty((0, 0))
            mrr_list = self._eval_mrr(pred_pos, pred_neg, 'torch')

            for key, val in mrr_list.items():
                if key.endswith('_list'):
                    key = key[:-len('_list')]
                    val = float(val.mean().item())
                if np.isnan(val):
                    # NOTE(review): the mean of an empty list is NaN, so
                    # graphs without positive edges contribute 0 to every
                    # metric and lower the batch average - confirm intended.
                    val = 0.
                stats.setdefault(key, []).append(val)

        return {key: sum(vals) / len(vals) for key, vals in stats.items()}

    def _eval_mrr(self, y_pred_pos, y_pred_neg, type_info):
        """Compute Hits@k and Mean Reciprocal Rank (MRR).

        Implementation from OGB:
        https://github.com/snap-stanford/ogb/blob/master/ogb/linkproppred/evaluate.py

        Each positive's rank is the mean of its optimistic rank (ranked
        first among equal scores) and pessimistic rank (ranked last), so
        ties give a deterministic result instead of depending on an
        unstable sort order.

        Args:
            y_pred_pos: array with shape (batch size, ).
            y_pred_neg: array with shape (batch size, num_entities_neg).
            type_info: ``'torch'`` for torch tensors; any other value uses
                the NumPy implementation.

        Returns:
            dict with per-positive ``'hits@1_list'``, ``'hits@3_list'``,
            ``'hits@10_list'`` and ``'mrr_list'`` float arrays.
        """
        if type_info == 'torch':
            y_pred_pos = y_pred_pos.view(-1, 1)
            # Negatives scoring strictly higher / at least as high.
            optimistic_rank = (y_pred_neg > y_pred_pos).sum(dim=1)
            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(dim=1)
            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
            hits1_list = (ranking_list <= 1).to(torch.float)
            hits3_list = (ranking_list <= 3).to(torch.float)
            hits10_list = (ranking_list <= 10).to(torch.float)
            mrr_list = 1. / ranking_list.to(torch.float)
        else:
            y_pred_pos = y_pred_pos.reshape(-1, 1)
            optimistic_rank = (y_pred_neg > y_pred_pos).sum(axis=1)
            pessimistic_rank = (y_pred_neg >= y_pred_pos).sum(axis=1)
            ranking_list = 0.5 * (optimistic_rank + pessimistic_rank) + 1
            hits1_list = (ranking_list <= 1).astype(np.float32)
            hits3_list = (ranking_list <= 3).astype(np.float32)
            hits10_list = (ranking_list <= 10).astype(np.float32)
            mrr_list = 1. / ranking_list.astype(np.float32)
        return {'hits@1_list': hits1_list,
                'hits@3_list': hits3_list,
                'hits@10_list': hits10_list,
                'mrr_list': mrr_list}