'''
The SPD encoder for GRIT
Adapted from https://github.com/LiamMa/GRIT
'''
import torch
from torch import nn
from torch.nn import functional as F
from ogb.utils.features import get_bond_feature_dims
import torch_sparse
from torch_geometric.graphgym.register import (
register_edge_encoder,
register_node_encoder,
)
from torch_geometric.utils import remove_self_loops
from torch_scatter import scatter
import warnings
def full_edge_index(edge_index, batch=None, total_nodes=None):
"""
Retunr the Full batched sparse adjacency matrices given by edge indices.
Returns batched sparse adjacency matrices with exactly those edges that
are not in the input `edge_index` while ignoring self-loops.
Implementation inspired by `torch_geometric.utils.to_dense_adj`
Parameters:
edge_index (torch.Tensor): The edge indices.
batch: Batch vector, which assigns each node to a specific example.
Returns:
Complementary edge index.
"""
if batch is None:
if total_nodes is None: total_nodes = edge_index.max().item() + 1
batch = edge_index.new_zeros(total_nodes)
batch_size = batch.max().item() + 1
one = batch.new_ones(batch.size(0))
num_nodes = scatter(one, batch,
dim=0, dim_size=batch_size, reduce='add')
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
negative_index_list = []
for i in range(batch_size):
n = num_nodes[i].item()
size = [n, n]
adj = torch.ones(size, dtype=torch.short,
device=edge_index.device)
adj = adj.view(size)
_edge_index = adj.nonzero(as_tuple=False).t().contiguous()
# _edge_index, _ = remove_self_loops(_edge_index)
negative_index_list.append(_edge_index + cum_nodes[i])
edge_index_full = torch.cat(negative_index_list, dim=1).contiguous()
return edge_index_full
[docs]
@register_edge_encoder('spd_emb')
class SPDEdgeEncoder(torch.nn.Module):
'''
Shortest-path-distance (SPD) Embedding Encoder
Args:
in_dim (int): The input dimension of the edge features.
out_dim (int): The output dimension of the edge features.
batchnorm (bool): Whether to apply batch normalization. Default: False.
layernorm (bool): Whether to apply layer normalization. Default: False.
use_bias (bool): Whether to use bias in the linear layer. Default: False.
pad_to_full_graph (bool): Whether to pad to a full graph. Default: True.
pe_name (str): The name of the positional encoding. Default: "spd".
pad_0th (bool): Whether to pad the 0-th embedding with a zero vector. Default: False.
overwrite_old_attr (bool): Whether to overwrite old attributes. Default: False.
'''
def __init__(self, in_dim, out_dim,
batchnorm=False, layernorm=False, use_bias=False,
pad_to_full_graph=True,
pe_name="spd",
pad_0th=False, # 0-th embdding is a fixed zero-vector; For the case that 0 indicate the padding.
overwrite_old_attr=False,
):
super().__init__()
# note: batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info
self.in_dim = in_dim
self.out_dim = out_dim
self.pe_name = pe_name
if pad_0th:
self.spd_emb = nn.Embedding(in_dim+1, out_dim, padding_idx=0)
else:
self.spd_emb = nn.Embedding(in_dim, out_dim)
self.batchnorm = batchnorm
self.layernorm = layernorm
if self.batchnorm or self.layernorm:
warnings.warn("batchnorm/layernorm might ruin some properties of pe on providing shortest-path distance info ")
self.pad_to_full_graph = pad_to_full_graph
self.overwrite_old_attr = overwrite_old_attr
if self.batchnorm:
self.bn = nn.BatchNorm1d(out_dim)
if self.layernorm:
self.ln = nn.LayerNorm(out_dim)
def forward(self, batch):
k_idx, k_val = f'{self.pe_name}_index', f'{self.pe_name}_val'
rel_idx = batch[k_idx].type(torch.long)
rel_val = batch[k_val].type(torch.long)
edge_index = batch.edge_index.type(torch.long)
edge_attr = batch.get('edge_attr', None)
rel_val = self.spd_emb(rel_val)
if edge_attr is None:
edge_attr = rel_val.new_full((edge_index.size(1), rel_val.size(1)), 0)
# zero padding for non-existing edges
# if self.overwrite_old_attr:
out_idx, out_val = rel_idx, rel_val
# else:
# out_idx, out_val = torch_sparse.coalesce(
# torch.cat([edge_index, rel_idx], dim=1),
# torch.cat([edge_attr, rel_val], dim=0),
# batch.num_nodes, batch.num_nodes,
# op="add"
# )
if self.pad_to_full_graph:
edge_index_full = full_edge_index(out_idx, batch=batch.batch, total_nodes=batch.num_nodes)
edge_attr_pad = rel_val.new_full((edge_index_full.size(1), rel_val.size(1)), 0)
# zero padding to fully-connected graphs
out_idx = torch.cat([out_idx, edge_index_full], dim=1)
out_val = torch.cat([out_val, edge_attr_pad], dim=0)
out_idx, out_val = torch_sparse.coalesce(
out_idx, out_val, batch.num_nodes, batch.num_nodes,
op="add"
)
if self.batchnorm:
out_val = self.bn(out_val)
if self.layernorm:
out_val = self.ln(out_val)
batch.edge_index, batch.edge_attr = out_idx, out_val
return batch
def __repr__(self):
return f"{self.__class__.__name__}" \
f"(pad_to_full_graph={self.pad_to_full_graph}," \
f"{self.fc.__repr__()})"