from copy import deepcopy
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix,
to_undirected, to_dense_adj, scatter, to_networkx)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from functools import partial
from opengt.encoder.graphormer_encoder import graphormer_pre_processing
from .rrwp import add_full_rrwp
[docs]
def custom_eigh(L):
    """Eigen-decompose a (Laplacian) matrix, choosing the backend by size.

    Matrices with at most 1000 rows are decomposed with ``torch.linalg.eigh``.
    Larger matrices are copied to the CPU, decomposed with NumPy's
    ``np.linalg.eigh`` (the original authors cite a PyTorch bug for large
    inputs), and the results are moved back to the device of ``L``.

    Args:
        L: Laplacian matrix (Tensor)

    Returns:
        EigVals: Eigenvalues
        EigVecs: Eigenvectors
    """
    if L.shape[0] <= 1000:
        # Small matrix: stay in PyTorch.
        return torch.linalg.eigh(L)

    # Large matrix: round-trip through NumPy, then restore the device.
    target_device = L.device
    vals_np, vecs_np = np.linalg.eigh(L.cpu().numpy())
    vals = torch.from_numpy(vals_np).to(target_device)
    vecs = torch.from_numpy(vecs_np).to(target_device)
    return vals, vecs
[docs]
def compute_posenc_stats(data, pe_types, is_undirected, cfg):
    """Precompute positional encodings for the given graph.

    Supported PE statistics to precompute, selected by `pe_types`:
    'LapPE' / 'EquivStableLapPE': Laplacian eigen-decomposition, stored in
        `data.EigVals` / `data.EigVecs`.
    'SignNet': Laplacian eigen-decomposition with the SignNet settings,
        stored in `data.eigvals_sn` / `data.eigvecs_sn`.
    'RWSE': Random walk landing probabilities (diagonals of RW matrices).
    'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED)
    'HKdiagSE': Diagonals of heat kernel diffusion.
    'ElstaticSE': Kernel based on the electrostatic interaction between nodes.
    'GraphormerBias': Computes spatial types and optionally edges along
        shortest paths.
    'LapRaw': Laplacian eigen-decomposition without further processing.
    'RRWP': Relative Random Walk Probabilities PE (for GRIT)
    'WLSE': Weisfeiler-Lehman encoding, stored in `data.WLTag`.

    Args:
        data: PyG graph
        pe_types: Collection of positional encoding names to precompute
            statistics for, e.g. ['LapPE', 'RWSE'].
        is_undirected: True if the graph is expected to be undirected
        cfg: Main configuration node

    Returns:
        Extended PyG Data object.

    Raises:
        ValueError: On an unknown PE name, or when a kernel-based PE is
            requested without any kernel times.
        NotImplementedError: If 'HKfullPE' is requested.
    """
    # Verify PE types.
    supported_types = ('LapPE', 'EquivStableLapPE', 'SignNet', 'RWSE',
                       'HKdiagSE', 'HKfullPE', 'ElstaticSE', 'GraphormerBias',
                       'LapRaw', 'RRWP', 'WLSE')
    for t in pe_types:
        if t not in supported_types:
            raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}")

    # Basic preprocessing of the input graph.
    if hasattr(data, 'num_nodes'):
        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
    else:
        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
    laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower()
    if laplacian_norm_type == 'none':
        laplacian_norm_type = None
    if is_undirected:
        undir_edge_index = data.edge_index
    else:
        undir_edge_index = to_undirected(data.edge_index)

    # Eigen values and vectors.
    evals, evects = None, None
    if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types:
        # Eigen-decomposition with numpy, can be reused for Heat kernels.
        L = to_scipy_sparse_matrix(
            *get_laplacian(undir_edge_index, normalization=laplacian_norm_type,
                           num_nodes=N)
        )
        evals, evects = np.linalg.eigh(L.toarray())

        # When both are requested, the LapPE settings take precedence.
        if 'LapPE' in pe_types:
            max_freqs = cfg.posenc_LapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_LapPE.eigen.eigvec_norm
        else:
            max_freqs = cfg.posenc_EquivStableLapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_EquivStableLapPE.eigen.eigvec_norm

        data.EigVals, data.EigVecs = get_lap_decomp_stats(
            evals=evals, evects=evects,
            max_freqs=max_freqs,
            eigvec_norm=eigvec_norm)

    if 'SignNet' in pe_types:
        # Eigen-decomposition with numpy for SignNet.
        norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower()
        if norm_type == 'none':
            norm_type = None
        L = to_scipy_sparse_matrix(
            *get_laplacian(undir_edge_index, normalization=norm_type,
                           num_nodes=N)
        )
        evals_sn, evects_sn = np.linalg.eigh(L.toarray())
        data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats(
            evals=evals_sn, evects=evects_sn,
            max_freqs=cfg.posenc_SignNet.eigen.max_freqs,
            eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm)

    # Random Walks.
    if 'RWSE' in pe_types:
        kernel_param = cfg.posenc_RWSE.kernel
        if len(kernel_param.times) == 0:
            raise ValueError("List of kernel times required for RWSE")
        rw_landing = get_rw_landing_probs(ksteps=kernel_param.times,
                                          edge_index=data.edge_index,
                                          num_nodes=N)
        data.pestat_RWSE = rw_landing

    # Heat Kernels.
    if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types:
        # Get the eigenvalues and eigenvectors of the regular (unnormalized)
        # Laplacian, if they have not yet been computed for 'eigen'.
        if laplacian_norm_type is not None or evals is None or evects is None:
            L_heat = to_scipy_sparse_matrix(
                *get_laplacian(undir_edge_index, normalization=None, num_nodes=N)
            )
            evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray())
        else:
            evals_heat, evects_heat = evals, evects
        evals_heat = torch.from_numpy(evals_heat)
        evects_heat = torch.from_numpy(evects_heat)

        # Get the full heat kernels.
        if 'HKfullPE' in pe_types:
            # The heat kernels can't be stored in the Data object without
            # additional padding because in PyG's collation of the graphs the
            # sizes of tensors must match except in dimension 0. Do this when
            # the full heat kernels are actually used downstream by an Encoder.
            raise NotImplementedError()

        # Get heat kernel diagonals in more efficient way.
        if 'HKdiagSE' in pe_types:
            kernel_param = cfg.posenc_HKdiagSE.kernel
            if len(kernel_param.times) == 0:
                raise ValueError("Diffusion times are required for heat kernel")
            hk_diag = get_heat_kernels_diag(evects_heat, evals_heat,
                                            kernel_times=kernel_param.times,
                                            space_dim=0)
            data.pestat_HKdiagSE = hk_diag

    # Electrostatic interaction inspired kernel.
    if 'ElstaticSE' in pe_types:
        elstatic = get_electrostatic_function_encoding(undir_edge_index, N)
        data.pestat_ElstaticSE = elstatic

    if 'GraphormerBias' in pe_types:
        data = graphormer_pre_processing(
            data,
            cfg.posenc_GraphormerBias.num_spatial_types
        )

    if 'LapRaw' in pe_types:
        # Dense binary symmetric adjacency. `max_num_nodes=N` keeps isolated
        # nodes with the highest indices, which would otherwise be dropped
        # because the size is inferred from the largest index in edge_index.
        adj = to_dense_adj(undir_edge_index, max_num_nodes=N)[0]
        adj = adj + adj.T
        adj[adj > 0.] = 1.0
        deg = adj.sum(dim=1)
        deg[deg == 0.] = 1.0  # Avoid division by zero for isolated nodes.
        deg_inv_sqrt = torch.diag(deg ** -0.5)
        # Symmetrically normalized Laplacian: I - D^-1/2 A D^-1/2
        lap = (torch.eye(adj.shape[0], device=adj.device)
               - deg_inv_sqrt @ adj @ deg_inv_sqrt)
        # Column v[:, i] is the eigenvector for eigenvalue w[i].
        data.EigVals, data.EigVecs = custom_eigh(lap)

    if 'RRWP' in pe_types:
        param = cfg.posenc_RRWP
        transform = partial(add_full_rrwp,
                            walk_length=param.ksteps,
                            attr_name_abs="rrwp",
                            attr_name_rel="rrwp",
                            add_identity=True,
                            spd=param.spd,  # by default False
                            )
        data = transform(data)

    if 'WLSE' in pe_types:
        # Add the WLSE encoding to the graph
        pecfg = cfg.posenc_WLSE
        # networkx expects the *name* of an edge attribute, so the attribute
        # has to be copied onto the networkx graph and referenced by key.
        has_edge_attr = getattr(data, 'edge_attr', None) is not None
        G = to_networkx(data,
                        edge_attrs=['edge_attr'] if has_edge_attr else None,
                        to_undirected=True)
        node_hashes = nx.weisfeiler_lehman_subgraph_hashes(
            G,
            edge_attr='edge_attr' if has_edge_attr else None,
            iterations=pecfg.iterations)
        # The result maps node -> list of hashes, one per depth. Use the
        # deepest hash as the node's WL label, in node-index order.
        final_hashes = [node_hashes[node][-1] for node in sorted(node_hashes)]
        # Create a mapping from the hashes to the node types
        hash_to_type = {}
        for h in final_hashes:
            if h not in hash_to_type:
                hash_to_type[h] = len(hash_to_type)
        data.WLTag = torch.tensor([hash_to_type[h] for h in final_hashes],
                                  dtype=torch.long).view(-1, 1)
        # NOTE: writes back into the config; reflects this graph only.
        pecfg.num_types = len(hash_to_type)

    return data
[docs]
def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'):
    """Turn a Laplacian eigen-decomposition into per-node PE statistics.

    The `max_freqs` smallest eigenvalues (and their eigenvectors) are kept,
    the eigenvalues are clamped to be non-negative, and the eigenvectors are
    normalized with `eigvec_normalizer`. If the graph has fewer nodes than
    `max_freqs`, the missing frequencies are padded with NaN.

    Args:
        evals, evects: Precomputed eigen-decomposition (NumPy arrays)
        max_freqs: Maximum number of top smallest frequencies / eigenvecs to use
        eigvec_norm: Normalization for the eigen vectors of the Laplacian

    Returns:
        Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node
        Tensor (num_nodes, max_freqs) of eigenvector values per node
    """
    num_nodes = len(evals)  # Includes disconnected nodes.

    # Select the lowest frequencies and drop any imaginary residue.
    keep = evals.argsort()[:max_freqs]
    low_vals = torch.from_numpy(np.real(evals[keep])).clamp_min(0)
    low_vecs = torch.from_numpy(np.real(evects[:, keep])).float()
    low_vecs = eigvec_normalizer(low_vecs, low_vals, normalization=eigvec_norm)

    # Pad the frequency axis with NaN when the graph is too small.
    n_missing = max_freqs - num_nodes
    if n_missing > 0:
        EigVecs = F.pad(low_vecs, (0, n_missing), value=float('nan'))
        vals_row = F.pad(low_vals, (0, n_missing), value=float('nan')).unsqueeze(0)
    else:
        EigVecs = low_vecs
        vals_row = low_vals.unsqueeze(0)

    # Same eigenvalue row for every node, with a trailing singleton dim.
    EigVals = vals_row.repeat(num_nodes, 1).unsqueeze(2)
    return EigVals, EigVecs
[docs]
def get_rw_landing_probs(ksteps, edge_index, edge_weight=None,
                         num_nodes=None, space_dim=0):
    """Compute Random Walk landing probabilities for given list of K steps.

    Args:
        ksteps: List of k-steps for which to compute the RW landings
        edge_index: PyG sparse representation of the graph
        edge_weight: (optional) Edge weights; defaults to all ones. The
            weights are used for both the degrees and the adjacency matrix.
        num_nodes: (optional) Number of nodes in the graph
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
            In euclidean space, this correction means that the height of
            the gaussian distribution stays almost constant across the number of
            steps, if `space_dim` is the dimension of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    source = edge_index[0]
    deg = scatter(edge_weight, source, dim=0, dim_size=num_nodes, reduce='sum')  # Out degrees.
    deg_inv = deg.pow(-1.)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0)  # Isolated nodes.

    if edge_index.numel() == 0:
        # Floating-point zeros (not the integer dtype of edge_index).
        P = deg_inv.new_zeros((1, num_nodes, num_nodes))
    else:
        # P = D^-1 * A, with A carrying the same weights as the degrees.
        adj = to_dense_adj(edge_index, edge_attr=edge_weight,
                           max_num_nodes=num_nodes)
        P = torch.diag(deg_inv) @ adj  # 1 x (Num nodes) x (Num nodes)

    rws = []
    k_min, k_max = min(ksteps), max(ksteps)
    if list(ksteps) == list(range(k_min, k_max + 1)):
        # Efficient way if ksteps are a consecutive sequence (most of the time the case)
        Pk = P.clone().detach().matrix_power(k_min)
        for k in range(k_min, k_max + 1):
            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1)
                       * (k ** (space_dim / 2)))
            Pk = Pk @ P
    else:
        # Explicitly raising P to power k for each k \in ksteps.
        for k in ksteps:
            rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1)
                       * (k ** (space_dim / 2)))
    rw_landing = torch.cat(rws, dim=0).transpose(0, 1)  # (Num nodes) x (K steps)

    return rw_landing
[docs]
def get_heat_kernels_diag(evects, evals, kernel_times=(), space_dim=0):
    """Compute Heat kernel diagonal.

    This is a continuous function that represents a Gaussian in the Euclidean
    space, and is the solution to the diffusion equation.
    The random-walk diagonal should converge to this.

    Args:
        evects: Eigenvectors of the Laplacian matrix, one per column
        evals: Eigenvalues of the Laplacian matrix
        kernel_times: Time for the diffusion. Analogous to the k-steps in random
            walk. The time is equivalent to the variance of the kernel.
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the diffusion diagonal by a factor `t^(space_dim/2)`. In
            euclidean space, this correction means that the height of the
            gaussian stays constant across time, if `space_dim` is the dimension
            of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(kernel_times)) with the heat
        kernel diagonals. With no kernel times the tensor has zero columns.
    """
    if len(kernel_times) == 0:
        # Keep the return type a tensor even when nothing is requested.
        return evects.new_zeros((evects.shape[0], 0))

    evects = F.normalize(evects, p=2., dim=0)

    # Remove eigenvalues == 0 from the computation of the heat kernel
    idx_remove = evals < 1e-8
    evals = evals[~idx_remove]
    evects = evects[:, ~idx_remove]

    # Change the shapes for the computations
    evals = evals.unsqueeze(-1)  # lambda_{i, ...}
    evects = evects.transpose(0, 1)  # phi_{i,j}: i-th eigvec X j-th node

    # Compute the heat kernels diagonal only for each time
    eigvec_mul = evects ** 2
    heat_kernels_diag = []
    for t in kernel_times:
        # sum_{i>0}(exp(-t lambda_i) * phi_{i, j} * phi_{i, j})
        this_kernel = torch.sum(torch.exp(-t * evals) * eigvec_mul,
                                dim=0, keepdim=False)

        # Rescale by `t^(space_dim/2)` to stabilize the values across times.
        heat_kernels_diag.append(this_kernel * (t ** (space_dim / 2)))

    return torch.stack(heat_kernels_diag, dim=0).transpose(0, 1)
[docs]
def get_heat_kernels(evects, evals, kernel_times=()):
    """Compute full Heat diffusion kernels.

    Args:
        evects: Eigenvectors of the Laplacian matrix, one per column
        evals: Eigenvalues of the Laplacian matrix
        kernel_times: Time for the diffusion. Analogous to the k-steps in random
            walk. The time is equivalent to the variance of the kernel.

    Returns:
        Tuple of two tensors:
        heat kernels with shape (len(kernel_times), num_nodes, num_nodes), and
        their diagonals with shape (num_nodes, len(kernel_times)).
        With no kernel times both tensors are empty along the time dimension.
    """
    if len(kernel_times) == 0:
        # Keep the return types tensors even when nothing is requested.
        num_nodes = evects.shape[0]
        return (evects.new_zeros((0, num_nodes, num_nodes)),
                evects.new_zeros((num_nodes, 0)))

    evects = F.normalize(evects, p=2., dim=0)

    # Remove eigenvalues == 0 from the computation of the heat kernel
    idx_remove = evals < 1e-8
    evals = evals[~idx_remove]
    evects = evects[:, ~idx_remove]

    # Change the shapes for the computations
    evals = evals.unsqueeze(-1).unsqueeze(-1)  # lambda_{i, ..., ...}
    evects = evects.transpose(0, 1)  # phi_{i,j}: i-th eigvec X j-th node

    # Compute the heat kernels for each time
    eigvec_mul = (evects.unsqueeze(2) * evects.unsqueeze(1))  # (phi_{i, j1, ...} * phi_{i, ..., j2})
    heat_kernels = [
        # sum_{i>0}(exp(-t lambda_i) * phi_{i, j1, ...} * phi_{i, ..., j2})
        torch.sum(torch.exp(-t * evals) * eigvec_mul, dim=0, keepdim=False)
        for t in kernel_times
    ]
    heat_kernels = torch.stack(heat_kernels, dim=0)  # (Num kernel times) x (Num nodes) x (Num nodes)

    # Take the diagonal of each heat kernel,
    # i.e. the landing probability of each of the random walks
    rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1)  # (Num nodes) x (Num kernel times)

    return heat_kernels, rw_landing
[docs]
def get_electrostatic_function_encoding(edge_index, num_nodes):
    """Kernel based on the electrostatic interaction between nodes.

    Builds the unnormalized graph Laplacian, takes its pseudo-inverse and
    subtracts the diagonal entry of each column, then summarizes the
    resulting matrix per node with ten statistics.

    Args:
        edge_index: PyG sparse representation of the graph
        num_nodes: Number of nodes in the graph

    Returns:
        2D Tensor with shape (num_nodes, 10): min / max / mean / std along
        dim 0, min / max / mean / std along dim 1, and the two
        neighbour-weighted interaction sums.
    """
    L = to_scipy_sparse_matrix(
        *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes)
    ).toarray()
    L = torch.as_tensor(L)

    # D^-1 with zeros for isolated nodes. Multiplying an identity matrix by
    # the raw reciprocal would yield NaN (0 * inf) in off-diagonal entries.
    deg_inv = L.diag() ** -1
    deg_inv[torch.isinf(deg_inv)] = 0
    Dinv = torch.diag(deg_inv)

    # Adjacency: magnitude of the off-diagonal Laplacian entries.
    A = L.abs()  # abs() already returns a new tensor
    A.fill_diagonal_(0)
    DinvA = Dinv.matmul(A)

    electrostatic = torch.pinverse(L)
    electrostatic = electrostatic - electrostatic.diag()
    green_encoding = torch.stack([
        electrostatic.min(dim=0)[0],  # Min of Vi -> j
        electrostatic.max(dim=0)[0],  # Max of Vi -> j
        electrostatic.mean(dim=0),  # Mean of Vi -> j
        electrostatic.std(dim=0),  # Std of Vi -> j
        electrostatic.min(dim=1)[0],  # Min of Vj -> i
        electrostatic.max(dim=1)[0],  # Max of Vj -> i (was dim=0, duplicating feature 1)
        electrostatic.mean(dim=1),  # Mean of Vj -> i
        electrostatic.std(dim=1),  # Std of Vj -> i
        (DinvA * electrostatic).sum(dim=0),  # Mean of interaction on direct neighbour
        (DinvA * electrostatic).sum(dim=1),  # Mean of interaction from direct neighbour
    ], dim=1)

    return green_encoding
[docs]
def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12):
    """Normalize eigenvectors column-wise with the selected scheme.

    Supported schemes: "L1", "L2", "abs-max", "wavelength",
    "wavelength-asin" and "wavelength-soft". The denominator is clamped
    from below by `eps` before the division.

    Args:
        EigVecs: Tensor of eigenvectors, one per column
        EigVals: 1D Tensor of the matching eigenvalues
        normalization: Name of the normalization scheme
        eps: Lower bound for denominators; eigenvalues below it are treated
            as having a square root of 1

    Returns:
        Tensor of normalized eigenvectors with the shape of `EigVecs`.

    Raises:
        ValueError: If `normalization` is not a supported scheme.
    """
    EigVals = EigVals.unsqueeze(0)

    def _column_abs_max(vecs):
        # Largest absolute entry of every eigenvector.
        return torch.max(vecs.abs(), dim=0, keepdim=True).values

    def _sqrt_eigvals():
        # sqrt(eigval), with 1 substituted where eigval is (near) zero.
        root = torch.sqrt(EigVals)
        root[EigVals < eps] = 1  # Problem with eigval = 0
        return root

    if normalization == "L1":
        # eigvec / sum(abs(eigvec))
        denom = EigVecs.norm(p=1, dim=0, keepdim=True)

    elif normalization == "L2":
        # eigvec / sqrt(sum(eigvec^2))
        denom = EigVecs.norm(p=2, dim=0, keepdim=True)

    elif normalization == "abs-max":
        # eigvec / max|eigvec|
        denom = _column_abs_max(EigVecs)

    elif normalization == "wavelength":
        # eigvec * pi / (2 * max|eigvec| * sqrt(eigval))
        denom = _column_abs_max(EigVecs) * _sqrt_eigvals() * 2 / np.pi

    elif normalization == "wavelength-asin":
        # arcsin(eigvec / max|eigvec|) / sqrt(eigval)
        peak = _column_abs_max(EigVecs).clamp_min(eps).expand_as(EigVecs)
        EigVecs = torch.asin(EigVecs / peak)
        denom = _sqrt_eigvals()

    elif normalization == "wavelength-soft":
        # eigvec / (softmax|eigvec| * sqrt(eigval))
        magnitudes = EigVecs.abs()
        soft_peak = (F.softmax(magnitudes, dim=0) * magnitudes).sum(dim=0, keepdim=True)
        denom = soft_peak * _sqrt_eigvals()

    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    return EigVecs / denom.clamp_min(eps).expand_as(EigVecs)