Source code for opengt.transform.posenc_stats

from copy import deepcopy

import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.utils import (get_laplacian, to_scipy_sparse_matrix,
                                   to_undirected, to_dense_adj, scatter, to_networkx)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from functools import partial
from opengt.encoder.graphormer_encoder import graphormer_pre_processing
from .rrwp import add_full_rrwp

def custom_eigh(L):
    """Eigen-decompose a symmetric (Laplacian) matrix.

    Matrices with more than 1000 rows are decomposed on the CPU with NumPy's
    ``np.linalg.eigh`` and the results are moved back to the input's device;
    smaller ones go straight through ``torch.linalg.eigh``. The original note
    on this function attributes the split to a PyTorch bug on large matrices.

    Args:
        L: Laplacian matrix (Tensor)

    Returns:
        EigVals: Eigenvalues
        EigVecs: Eigenvectors
    """
    # Small matrices: PyTorch handles them directly.
    if L.shape[0] <= 1000:
        return torch.linalg.eigh(L)

    # Large matrices: round-trip through NumPy, then restore the device.
    target_device = L.device
    vals_np, vecs_np = np.linalg.eigh(L.cpu().numpy())
    vals = torch.from_numpy(vals_np).to(target_device)
    vecs = torch.from_numpy(vecs_np).to(target_device)
    return vals, vecs
def compute_posenc_stats(data, pe_types, is_undirected, cfg):
    """Precompute positional encodings for the given graph.

    Supported PE statistics to precompute, selected by `pe_types`:
    'LapPE': Laplacian eigen-decomposition.
    'EquivStableLapPE': Laplacian eigen-decomposition, configured from
        `cfg.posenc_EquivStableLapPE` (only used when 'LapPE' is absent).
    'SignNet': Laplacian eigen-decomposition stored separately as
        `eigvals_sn` / `eigvecs_sn`.
    'RWSE': Random walk landing probabilities (diagonals of RW matrices).
    'HKfullPE': Full heat kernels and their diagonals. (NOT IMPLEMENTED)
    'HKdiagSE': Diagonals of heat kernel diffusion.
    'ElstaticSE': Kernel based on the electrostatic interaction between nodes.
    'GraphormerBias': Computes spatial types and optionally edges along
        shortest paths.
    'LapRaw': Laplacian eigen-decomposition without further processing.
    'RRWP': Relative Random Walk Probabilities PE (for GRIT)
    'WLSE': Weisfeiler-Lehman encoding.

    Args:
        data: PyG graph
        pe_types: Positional encoding types to precompute statistics for.
            This can also be a combination, e.g. 'eigen+rw_landing'
        is_undirected: True if the graph is expected to be undirected
        cfg: Main configuration node

    Returns:
        Extended PyG Data object.

    Raises:
        ValueError: If `pe_types` contains an unsupported name, or a kernel
            based encoding is requested with an empty list of times.
        NotImplementedError: If 'HKfullPE' is requested.
    """
    # Verify PE types.
    supported_types = ('LapPE', 'EquivStableLapPE', 'SignNet', 'RWSE',
                       'HKdiagSE', 'HKfullPE', 'ElstaticSE', 'GraphormerBias',
                       'LapRaw', 'RRWP', 'WLSE')
    for t in pe_types:
        if t not in supported_types:
            raise ValueError(f"Unexpected PE stats selection {t} in {pe_types}")

    # Basic preprocessing of the input graph.
    if hasattr(data, 'num_nodes'):
        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
    else:
        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.
    laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower()
    if laplacian_norm_type == 'none':
        laplacian_norm_type = None
    if is_undirected:
        undir_edge_index = data.edge_index
    else:
        undir_edge_index = to_undirected(data.edge_index)

    # Eigen values and vectors.
    evals, evects = None, None
    if 'LapPE' in pe_types or 'EquivStableLapPE' in pe_types:
        # Eigen-decomposition with numpy, can be reused for Heat kernels.
        L = to_scipy_sparse_matrix(
            *get_laplacian(undir_edge_index, normalization=laplacian_norm_type,
                           num_nodes=N)
        )
        evals, evects = np.linalg.eigh(L.toarray())

        # 'LapPE' settings take precedence when both types are requested.
        if 'LapPE' in pe_types:
            max_freqs = cfg.posenc_LapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_LapPE.eigen.eigvec_norm
        elif 'EquivStableLapPE' in pe_types:
            max_freqs = cfg.posenc_EquivStableLapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_EquivStableLapPE.eigen.eigvec_norm

        data.EigVals, data.EigVecs = get_lap_decomp_stats(
            evals=evals, evects=evects,
            max_freqs=max_freqs,
            eigvec_norm=eigvec_norm)

    if 'SignNet' in pe_types:
        # Eigen-decomposition with numpy for SignNet.
        norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower()
        if norm_type == 'none':
            norm_type = None
        L = to_scipy_sparse_matrix(
            *get_laplacian(undir_edge_index, normalization=norm_type,
                           num_nodes=N)
        )
        evals_sn, evects_sn = np.linalg.eigh(L.toarray())
        data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats(
            evals=evals_sn, evects=evects_sn,
            max_freqs=cfg.posenc_SignNet.eigen.max_freqs,
            eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm)

    # Random Walks.
    if 'RWSE' in pe_types:
        kernel_param = cfg.posenc_RWSE.kernel
        if len(kernel_param.times) == 0:
            raise ValueError("List of kernel times required for RWSE")
        rw_landing = get_rw_landing_probs(ksteps=kernel_param.times,
                                          edge_index=data.edge_index,
                                          num_nodes=N)
        data.pestat_RWSE = rw_landing

    # Heat Kernels.
    if 'HKdiagSE' in pe_types or 'HKfullPE' in pe_types:
        # Get the eigenvalues and eigenvectors of the regular Laplacian,
        # if they have not yet been computed for 'eigen'.
        if laplacian_norm_type is not None or evals is None or evects is None:
            L_heat = to_scipy_sparse_matrix(
                *get_laplacian(undir_edge_index, normalization=None, num_nodes=N)
            )
            evals_heat, evects_heat = np.linalg.eigh(L_heat.toarray())
        else:
            evals_heat, evects_heat = evals, evects
        evals_heat = torch.from_numpy(evals_heat)
        evects_heat = torch.from_numpy(evects_heat)

        # Get the full heat kernels.
        if 'HKfullPE' in pe_types:
            # The heat kernels can't be stored in the Data object without
            # additional padding because in PyG's collation of the graphs the
            # sizes of tensors must match except in dimension 0. Do this when
            # the full heat kernels are actually used downstream by an Encoder.
            raise NotImplementedError()

        # Get heat kernel diagonals in more efficient way.
        if 'HKdiagSE' in pe_types:
            kernel_param = cfg.posenc_HKdiagSE.kernel
            if len(kernel_param.times) == 0:
                raise ValueError("Diffusion times are required for heat kernel")
            hk_diag = get_heat_kernels_diag(evects_heat, evals_heat,
                                            kernel_times=kernel_param.times,
                                            space_dim=0)
            data.pestat_HKdiagSE = hk_diag

    # Electrostatic interaction inspired kernel.
    if 'ElstaticSE' in pe_types:
        elstatic = get_electrostatic_function_encoding(undir_edge_index, N)
        data.pestat_ElstaticSE = elstatic

    if 'GraphormerBias' in pe_types:
        data = graphormer_pre_processing(
            data,
            cfg.posenc_GraphormerBias.num_spatial_types
        )

    if 'LapRaw' in pe_types:
        def normalize_graph(g):
            # Symmetrize and binarize the adjacency, then build the
            # symmetric normalized Laplacian I - D^-1/2 A D^-1/2.
            g = g + g.T
            g[g > 0.] = 1.0
            deg = g.sum(axis=1).reshape(-1)
            deg[deg == 0.] = 1.0  # Avoid division by zero on isolated nodes.
            deg = torch.diag(deg ** -0.5)
            adj = deg @ g @ deg
            return torch.eye(g.shape[0], device=g.device) - adj

        def eigen_decomposition(g):
            """Return (eigenvalues, eigenvectors) of the normalized Laplacian.

            The eigenvectors are normalized (unit "length"), such that the
            column v[:, i] is the eigenvector corresponding to the
            eigenvalue w[i].
            """
            return custom_eigh(normalize_graph(g))

        # Pass N explicitly so that trailing isolated nodes (absent from
        # edge_index) still get a row in the dense adjacency.
        # NOTE: this writes to the same attributes as 'LapPE' above.
        data.EigVals, data.EigVecs = eigen_decomposition(
            to_dense_adj(undir_edge_index, max_num_nodes=N)[0])

    if 'RRWP' in pe_types:
        param = cfg.posenc_RRWP
        transform = partial(add_full_rrwp,
                            walk_length=param.ksteps,
                            attr_name_abs="rrwp",
                            attr_name_rel="rrwp",
                            add_identity=True,
                            spd=param.spd,  # by default False
                            )
        data = transform(data)

    if 'WLSE' in pe_types:
        # Add the WLSE encoding to the graph
        pecfg = cfg.posenc_WLSE
        G = to_networkx(data, to_undirected=True)
        # networkx returns {node: [hash after iteration 1, 2, ...]}. Its
        # `edge_attr` argument is the *name* of an edge attribute stored in
        # G, not a tensor; no edge attributes are copied into G above, so
        # the hashes are computed from the graph structure only.
        hashes_by_node = nx.weisfeiler_lehman_subgraph_hashes(
            G, iterations=pecfg.iterations)

        # Map each distinct hash sequence to a consecutive integer type and
        # tag every node with the type of its own sequence.
        hash_to_type = {}
        tags = []
        for node in G.nodes:
            key = tuple(hashes_by_node[node])
            tags.append(hash_to_type.setdefault(key, len(hash_to_type)))

        data.WLTag = torch.tensor(tags, dtype=torch.long).view(-1, 1)
        # NOTE(review): type ids are assigned per graph and this config
        # value is overwritten on every call - confirm downstream usage.
        pecfg.num_types = len(hash_to_type)

    return data
def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'):
    """Build Laplacian-eigendecomposition PE statistics for one graph.

    The `max_freqs` smallest eigenvalues and their eigenvectors are kept,
    the eigenvectors are normalized with `eigvec_normalizer`, and both are
    NaN-padded along the frequency axis when the graph has fewer nodes than
    `max_freqs`.

    Args:
        evals, evects: Precomputed eigen-decomposition (NumPy arrays)
        max_freqs: Maximum number of top smallest frequencies / eigenvecs to use
        eigvec_norm: Normalization for the eigen vectors of the Laplacian

    Returns:
        Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node
        Tensor (num_nodes, max_freqs) of eigenvector values per node
    """
    num_nodes = len(evals)  # Includes disconnected nodes.

    # Select the lowest frequencies, dropping any imaginary residue.
    keep = evals.argsort()[:max_freqs]
    kept_vals = evals[keep]
    kept_vecs = np.real(evects[:, keep])
    kept_vals = torch.from_numpy(np.real(kept_vals)).clamp_min(0)

    kept_vecs = eigvec_normalizer(torch.from_numpy(kept_vecs).float(),
                                  kept_vals,
                                  normalization=eigvec_norm)

    # Pad both outputs with NaN up to `max_freqs` columns when needed.
    n_missing = max_freqs - num_nodes
    if n_missing > 0:
        nan = float('nan')
        EigVecs = F.pad(kept_vecs, (0, n_missing), value=nan)
        EigVals = F.pad(kept_vals, (0, n_missing), value=nan).unsqueeze(0)
    else:
        EigVecs = kept_vecs
        EigVals = kept_vals.unsqueeze(0)

    # Repeat the eigenvalue row once per node and add a trailing axis.
    EigVals = EigVals.repeat(num_nodes, 1).unsqueeze(2)
    return EigVals, EigVecs
def get_rw_landing_probs(ksteps, edge_index, edge_weight=None,
                         num_nodes=None, space_dim=0):
    """Compute Random Walk landing probabilities for given list of K steps.

    Args:
        ksteps: List of k-steps for which to compute the RW landings
        edge_index: PyG sparse representation of the graph
        edge_weight: (optional) Edge weights. Used for both the out-degrees
            and the adjacency, so that P = D^-1 A stays row-stochastic.
        num_nodes: (optional) Number of nodes in the graph
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
            In euclidean space, this correction means that the height of
            the gaussian distribution stays almost constant across the number
            of steps, if `space_dim` is the dimension of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    source = edge_index[0]
    deg = scatter(edge_weight, source, dim=0, dim_size=num_nodes,
                  reduce='sum')  # Out degrees.
    deg_inv = deg.pow(-1.)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0)  # Nodes without out-edges.

    if edge_index.numel() == 0:
        # Edgeless graph: all-zero transition matrix. Allocate it from
        # `deg_inv` so it has the same floating dtype as the general case
        # (allocating from `edge_index` would give an integer matrix).
        P = deg_inv.new_zeros((1, num_nodes, num_nodes))
    else:
        # P = D^-1 * A, with A carrying the same edge weights as D.
        # With the default all-ones weights this equals the unweighted
        # adjacency.
        adj = to_dense_adj(edge_index, edge_attr=edge_weight,
                           max_num_nodes=num_nodes)
        P = torch.diag(deg_inv) @ adj  # 1 x (Num nodes) x (Num nodes)

    ksteps = list(ksteps)
    k_min, k_max = min(ksteps), max(ksteps)
    rws = []
    if ksteps == list(range(k_min, k_max + 1)):
        # Efficient way if ksteps are a consecutive sequence (most of the
        # time the case): one extra matmul per step.
        Pk = P.clone().detach().matrix_power(k_min)
        for k in range(k_min, k_max + 1):
            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1)
                       * (k ** (space_dim / 2)))
            Pk = Pk @ P
    else:
        # Explicitly raising P to power k for each k in ksteps.
        for k in ksteps:
            rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1)
                       * (k ** (space_dim / 2)))
    rw_landing = torch.cat(rws, dim=0).transpose(0, 1)  # (Num nodes) x (K steps)

    return rw_landing
def get_heat_kernels_diag(evects, evals, kernel_times=[], space_dim=0):
    """Compute the diagonal of the heat kernel for each diffusion time.

    For every time `t` the value at node j is
    `t^(space_dim/2) * sum_i exp(-t * lambda_i) * phi_{i,j}^2`, where the sum
    runs over eigenpairs whose eigenvalue is not below 1e-8 and the
    eigenvectors are first L2-normalized column-wise.

    Args:
        evects: Eigenvectors of the Laplacian matrix (one per column)
        evals: Eigenvalues of the Laplacian matrix
        kernel_times: Time for the diffusion. Analogous to the k-steps in
            random walk. The time is equivalent to the variance of the kernel.
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the diffusion diagonal by a factor `t^(space_dim/2)`. In
            euclidean space, this correction means that the height of the
            gaussian stays constant across time, if `space_dim` is the
            dimension of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(kernel_times)), or an empty list
        when `kernel_times` is empty.
    """
    if len(kernel_times) == 0:
        return []

    unit_vecs = F.normalize(evects, p=2., dim=0)

    # Drop the (near-)zero eigenvalues and their eigenvectors.
    keep = ~(evals < 1e-8)
    lam = evals[keep].unsqueeze(-1)                     # (kept, 1)
    phi_sq = unit_vecs[:, keep].transpose(0, 1) ** 2    # (kept, nodes)

    per_time = []
    for t in kernel_times:
        diag_t = torch.sum(torch.exp(-t * lam) * phi_sq, dim=0, keepdim=False)
        # Rescale by t^(space_dim/2) to stabilize the values across times.
        per_time.append(diag_t * (t ** (space_dim / 2)))

    return torch.stack(per_time, dim=0).transpose(0, 1)
def get_heat_kernels(evects, evals, kernel_times=[]):
    """Compute full heat diffusion kernels and their diagonals.

    For every time `t` the kernel entry (j1, j2) is
    `sum_i exp(-t * lambda_i) * phi_{i,j1} * phi_{i,j2}`, summed over
    eigenpairs whose eigenvalue is not below 1e-8, with the eigenvectors
    L2-normalized column-wise beforehand.

    Args:
        evects: Eigenvectors of the Laplacian matrix (one per column)
        evals: Eigenvalues of the Laplacian matrix
        kernel_times: Time for the diffusion. Analogous to the k-steps in
            random walk. The time is equivalent to the variance of the kernel.

    Returns:
        Tuple of
        - Tensor (num_times, num_nodes, num_nodes) of heat kernels,
        - Tensor (num_nodes, num_times) of their diagonals;
        both are empty lists when `kernel_times` is empty.
    """
    if len(kernel_times) == 0:
        return [], []

    unit_vecs = F.normalize(evects, p=2., dim=0)

    # Drop the (near-)zero eigenvalues and their eigenvectors.
    keep = ~(evals < 1e-8)
    lam = evals[keep].unsqueeze(-1).unsqueeze(-1)   # (kept, 1, 1)
    phi = unit_vecs[:, keep].transpose(0, 1)        # (kept, nodes)

    # Outer product of each eigenvector with itself: (kept, nodes, nodes).
    outer = phi.unsqueeze(2) * phi.unsqueeze(1)

    kernels = [
        torch.sum(torch.exp(-t * lam) * outer, dim=0, keepdim=False)
        for t in kernel_times
    ]
    heat_kernels = torch.stack(kernels, dim=0)  # (times) x (nodes) x (nodes)

    # Diagonal of each kernel, arranged as (nodes) x (times).
    rw_landing = torch.diagonal(heat_kernels, dim1=-2, dim2=-1).transpose(0, 1)

    return heat_kernels, rw_landing
def get_electrostatic_function_encoding(edge_index, num_nodes):
    """Kernel based on the electrostatic interaction between nodes.

    Takes the pseudo-inverse of the unnormalized graph Laplacian, subtracts
    its diagonal (broadcast across rows) and summarizes the resulting matrix
    per node with min / max / mean / std along both axes, plus two
    degree-normalized neighbour-interaction sums.

    Args:
        edge_index: PyG sparse representation of the graph
        num_nodes: Number of nodes in the graph

    Returns:
        2D Tensor with shape (num_nodes, 10), one column per statistic.
    """
    L = to_scipy_sparse_matrix(
        *get_laplacian(edge_index, normalization=None, num_nodes=num_nodes)
    ).todense()
    L = torch.as_tensor(L)

    # D^-1 from the Laplacian diagonal (the degrees). Zero-degree nodes get
    # 0 instead of inf so that they do not turn DinvA into NaN for every node.
    deg_inv = L.diag() ** -1
    deg_inv = deg_inv.masked_fill(torch.isinf(deg_inv), 0)
    Dinv = torch.diag(deg_inv)

    # Adjacency: absolute off-diagonal part of L (abs() returns a new tensor,
    # so L itself is left untouched).
    A = L.abs()
    A.fill_diagonal_(0)
    DinvA = Dinv.matmul(A)

    electrostatic = torch.pinverse(L)
    electrostatic = electrostatic - electrostatic.diag()

    green_encoding = torch.stack([
        electrostatic.min(dim=0)[0],  # Min of Vi -> j
        electrostatic.max(dim=0)[0],  # Max of Vi -> j
        electrostatic.mean(dim=0),  # Mean of Vi -> j
        electrostatic.std(dim=0),  # Std of Vi -> j
        electrostatic.min(dim=1)[0],  # Min of Vj -> i
        electrostatic.max(dim=1)[0],  # Max of Vj -> i (dim=1, not a repeat of dim=0)
        electrostatic.mean(dim=1),  # Mean of Vj -> i
        electrostatic.std(dim=1),  # Std of Vj -> i
        (DinvA * electrostatic).sum(dim=0),  # Mean of interaction on direct neighbour
        (DinvA * electrostatic).sum(dim=1),  # Mean of interaction from direct neighbour
    ], dim=1)

    return green_encoding
def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12):
    """Normalize eigenvectors column-wise with the selected scheme.

    Every scheme computes a per-column denominator (reductions run over
    dim 0 of `EigVecs`), clamps it from below by `eps` and divides.

    Args:
        EigVecs: Tensor of eigenvectors, one per column.
        EigVals: 1D Tensor of the matching eigenvalues.
        normalization: One of "L1", "L2", "abs-max", "wavelength",
            "wavelength-asin", "wavelength-soft".
        eps: Lower bound for denominators; eigenvalues below it are treated
            as 1 under the square root.

    Returns:
        Tensor of normalized eigenvectors, same shape as `EigVecs`.

    Raises:
        ValueError: If `normalization` is not one of the supported names.
    """
    EigVals = EigVals.unsqueeze(0)

    def _col_abs_max(vecs):
        # Largest absolute entry of every column, kept as a row vector.
        return torch.max(vecs.abs(), dim=0, keepdim=True).values

    def _sqrt_eigvals():
        # sqrt(eigval), with (near-)zero eigenvalues replaced by 1.
        root = torch.sqrt(EigVals)
        root[EigVals < eps] = 1  # Problem with eigval = 0
        return root

    if normalization == "L1":
        # eigvec / sum(abs(eigvec))
        scale = EigVecs.norm(p=1, dim=0, keepdim=True)

    elif normalization == "L2":
        # eigvec / sqrt(sum(eigvec^2))
        scale = EigVecs.norm(p=2, dim=0, keepdim=True)

    elif normalization == "abs-max":
        # eigvec / max|eigvec|
        scale = _col_abs_max(EigVecs)

    elif normalization == "wavelength":
        # eigvec * pi / (2 * max|eigvec| * sqrt(eigval))
        scale = _col_abs_max(EigVecs)
        scale = scale * _sqrt_eigvals() * 2 / np.pi

    elif normalization == "wavelength-asin":
        # arcsin(eigvec / max|eigvec|) / sqrt(eigval)
        peak = _col_abs_max(EigVecs).clamp_min(eps).expand_as(EigVecs)
        EigVecs = torch.asin(EigVecs / peak)
        scale = _sqrt_eigvals()

    elif normalization == "wavelength-soft":
        # eigvec / (softmax-weighted |eigvec| * sqrt(eigval))
        magnitudes = EigVecs.abs()
        scale = (F.softmax(magnitudes, dim=0) * magnitudes).sum(dim=0, keepdim=True)
        scale = scale * _sqrt_eigvals()

    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    scale = scale.clamp_min(eps).expand_as(EigVecs)
    return EigVecs / scale