Source code for opengt.encoder.rrwp_encoder

'''
    The RRWP encoder for GRIT
    Adapted from https://github.com/LiamMa/GRIT
'''
import torch
from torch import nn
from torch.nn import functional as F
from ogb.utils.features import get_bond_feature_dims
import torch_sparse

import torch_geometric as pyg
from torch_geometric.graphgym.register import (
    register_edge_encoder,
    register_node_encoder,
)

from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, add_self_loops
from torch_scatter import scatter
import warnings

[docs] def full_edge_index(edge_index, batch=None): """ Returns the Full batched sparse adjacency matrices given by edge indices. Returns batched sparse adjacency matrices with exactly those edges that are not in the input `edge_index` while ignoring self-loops. Implementation inspired by `torch_geometric.utils.to_dense_adj` Parameters: edge_index (torch.Tensor): The edge indices. batch: Batch vector, which assigns each node to a specific example. Returns: Complementary edge index. """ if batch is None: batch = edge_index.new_zeros(edge_index.max().item() + 1) batch_size = batch.max().item() + 1 one = batch.new_ones(batch.size(0)) num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='add') cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)]) negative_index_list = [] for i in range(batch_size): n = num_nodes[i].item() size = [n, n] adj = torch.ones(size, dtype=torch.short, device=edge_index.device) adj = adj.view(size) _edge_index = adj.nonzero(as_tuple=False).t().contiguous() # _edge_index, _ = remove_self_loops(_edge_index) negative_index_list.append(_edge_index + cum_nodes[i]) edge_index_full = torch.cat(negative_index_list, dim=1).contiguous() return edge_index_full
[docs] @register_node_encoder('rrwp_linear') class RRWPLinearNodeEncoder(torch.nn.Module): """ FC_1(RRWP) + FC_2 (Node-attr) note: FC_2 is given by the Typedict encoder of node-attr in some cases Parameters: num_classes (int) - the number of classes for the embedding mapping to learn """ def __init__(self, emb_dim, out_dim, use_bias=False, batchnorm=False, layernorm=False, pe_name="rrwp"): super().__init__() self.batchnorm = batchnorm self.layernorm = layernorm self.name = pe_name self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) if self.batchnorm: self.bn = nn.BatchNorm1d(out_dim) if self.layernorm: self.ln = nn.LayerNorm(out_dim) def forward(self, batch): # Encode just the first dimension if more exist rrwp = batch[f"{self.name}"] rrwp = self.fc(rrwp) if self.batchnorm: rrwp = self.bn(rrwp) if self.layernorm: rrwp = self.ln(rrwp) if "x" in batch: batch.x = batch.x + rrwp else: batch.x = rrwp return batch
[docs] @register_edge_encoder('rrwp_linear') class RRWPLinearEdgeEncoder(torch.nn.Module): ''' Merge RRWP with given edge-attr and Zero-padding to all pairs of node FC_1(RRWP) + FC_2(edge-attr) - FC_2 given by the TypedictEncoder in same cases - Zero-padding for non-existing edges in fully-connected graph - (optional) add node-attr as the E_{i,i}'s attr note: assuming node-attr and edge-attr is with the same dimension after Encoders ''' def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, pad_to_full_graph=True, fill_value=0., add_node_attr_as_self_loop=False, overwrite_old_attr=False): super().__init__() # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info self.emb_dim = emb_dim self.out_dim = out_dim self.add_node_attr_as_self_loop = add_node_attr_as_self_loop self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr self.batchnorm = batchnorm self.layernorm = layernorm if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.pad_to_full_graph = pad_to_full_graph self.fill_value = 0. padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value self.register_buffer("padding", padding) if self.batchnorm: self.bn = nn.BatchNorm1d(out_dim) if self.layernorm: self.ln = nn.LayerNorm(out_dim) def forward(self, batch): rrwp_idx = batch.rrwp_index rrwp_val = batch.rrwp_val edge_index = batch.edge_index edge_attr = batch.edge_attr rrwp_val = self.fc(rrwp_val) if edge_attr is None: edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) # zero padding for non-existing edges if self.overwrite_old_attr: out_idx, out_val = rrwp_idx, rrwp_val else: # edge_index, edge_attr = add_remaining_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) edge_index, edge_attr = add_self_loops(edge_index, edge_attr, num_nodes=batch.num_nodes, fill_value=0.) out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), batch.num_nodes, batch.num_nodes, op="add" ) if self.pad_to_full_graph: edge_index_full = full_edge_index(out_idx, batch=batch.batch) edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) # zero padding to fully-connected graphs out_idx = torch.cat([out_idx, edge_index_full], dim=1) out_val = torch.cat([out_val, edge_attr_pad], dim=0) out_idx, out_val = torch_sparse.coalesce( out_idx, out_val, batch.num_nodes, batch.num_nodes, op="add" ) if self.batchnorm: out_val = self.bn(out_val) if self.layernorm: out_val = self.ln(out_val) batch.edge_index, batch.edge_attr = out_idx, out_val return batch def __repr__(self): return f"{self.__class__.__name__}" \ f"(pad_to_full_graph={self.pad_to_full_graph}," \ f"fill_value={self.fill_value}," \ f"{self.fc.__repr__()})"
[docs] @register_edge_encoder('masked_rrwp_linear') class RRWPLinearEdgeMaskedEncoder(torch.nn.Module): ''' RRWP Linear + Sparse-Attention Masking ''' def __init__(self, emb_dim, out_dim, batchnorm=False, layernorm=False, use_bias=False, fill_value=0., add_node_attr_as_self_loop=False, overwrite_old_attr=False, mask_index_name="edge_index", ): super().__init__() # note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info self.emb_dim = emb_dim self.out_dim = out_dim self.add_node_attr_as_self_loop = add_node_attr_as_self_loop self.overwrite_old_attr=overwrite_old_attr # remove the old edge-attr self.mask_index_name = mask_index_name self.batchnorm = batchnorm self.layernorm = layernorm if self.batchnorm or self.layernorm: warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ") self.fc = nn.Linear(emb_dim, out_dim, bias=use_bias) torch.nn.init.xavier_uniform_(self.fc.weight) self.fill_value = 0. padding = torch.ones(1, out_dim, dtype=torch.float) * fill_value self.register_buffer("padding", padding) if self.batchnorm: self.bn = nn.BatchNorm1d(out_dim) if self.layernorm: self.ln = nn.LayerNorm(out_dim) def forward(self, batch): rrwp_idx = batch.rrwp_index rrwp_val = batch.rrwp_val edge_index = batch.edge_index edge_attr = batch.edge_attr rrwp_val = self.fc(rrwp_val) mask_index = batch.get(self.mask_index_name, None) num_nodes = batch.num_nodes if edge_attr is None: edge_attr = edge_index.new_zeros(edge_index.size(1), rrwp_val.size(1)) # zero padding for non-existing edges if self.overwrite_old_attr: out_idx, out_val = rrwp_idx, rrwp_val else: out_idx, out_val = torch_sparse.coalesce( torch.cat([edge_index, rrwp_idx], dim=1), torch.cat([edge_attr, rrwp_val], dim=0), batch.num_nodes, batch.num_nodes, op="add" ) if mask_index is not None: mask_index, _ = add_remaining_self_loops(mask_index, None, num_nodes=batch.num_nodes) mask_val = mask_index.new_full((mask_index.size(1), ), 1) mask_comp = mask_index.new_full((out_idx.size(1), ), 0) mask_pad = mask_index.new_full((mask_index.size(1), out_val.size(1)), 0) _, masking = torch_sparse.coalesce( torch.cat([mask_index, out_idx], dim=1), torch.cat([mask_val, mask_comp], dim=0), m=num_nodes, n=num_nodes, op="max", ) out_idx, out_val = torch_sparse.coalesce( torch.cat([mask_index, out_idx], dim=1), torch.cat([mask_pad, out_val], dim=0), batch.num_nodes, batch.num_nodes, op="add" ) masking = masking.type(torch.bool) out_idx, out_val = out_idx[:, masking], out_val[masking] if self.batchnorm: out_val = self.bn(out_val) if self.layernorm: out_val = self.ln(out_val) batch.edge_index, batch.edge_attr = out_idx, out_val return batch def __repr__(self): return f"{self.__class__.__name__}" \ f"(pad_to_full_graph={self.pad_to_full_graph}," \ f"fill_value={self.fill_value}," \ f"{self.fc.__repr__()})"
[docs] @register_edge_encoder('pad_to_full_graph') class PadToFullGraphEdgeEncoder(torch.nn.Module): ''' Padding to Full Attention ''' def __init__(self,**kwargs): super().__init__() # note: batchnorm/layernorm might damage some properties of pe on providing shortest-path distance info self.pad_to_full_graph = True def forward(self, batch): edge_index = batch.edge_index edge_attr = batch.edge_attr out_idx, out_val = edge_index, edge_attr if self.pad_to_full_graph: edge_index_full = full_edge_index(out_idx, batch=batch.batch) # edge_attr_pad = self.padding.repeat(edge_index_full.size(1), 1) edge_attr_pad = edge_attr.new_zeros(edge_index_full.size(1), edge_attr.size(1)) # zero padding to fully-connected graphs out_idx = torch.cat([out_idx, edge_index_full], dim=1) out_val = torch.cat([out_val, edge_attr_pad], dim=0) out_idx, out_val = torch_sparse.coalesce( out_idx, out_val, batch.num_nodes, batch.num_nodes, op="add" ) batch.edge_index, batch.edge_attr = out_idx, out_val return batch def __repr__(self): return f"{self.__class__.__name__}" \ f"(pad_to_full_graph={self.pad_to_full_graph}," \ f"fill_value={self.fill_value}," \ f"{self.fc.__repr__()})"