import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.graphgym.register import register_layer
from torch_sparse import SparseTensor, matmul
from torch_geometric.utils import degree
import math
import numpy as np
# Large numeric constant. It is not referenced anywhere in the visible part of
# this file -- presumably kept for masking / numerical-stability purposes
# elsewhere (TODO confirm before removing).
BIG_CONSTANT = 1e8
def create_projection_matrix(m, d, seed=0, scaling=0, struct_mode=False):
    """Build an ``(m, d)`` random projection matrix from stacked orthogonal blocks.

    The matrix is assembled from ``m // d`` full ``d x d`` blocks plus, when
    ``m`` is not a multiple of ``d``, the first ``m % d`` rows of one extra
    block.  Every block is generated after re-seeding torch's global RNG with
    a consecutive seed (``seed``, ``seed + 1``, ...), so the result is
    reproducible for a given ``seed``.  Note that this re-seeding is a side
    effect on the global torch RNG state.

    Args:
        m: number of rows (random features) of the returned matrix.
        d: number of columns (input feature dimension).
        seed: base seed; block ``i`` uses ``seed + i``.
        scaling: ``0`` scales each row by the norm of a fresh Gaussian vector
            of length ``d``; ``1`` scales each row by ``sqrt(d)``.
        struct_mode: if true, blocks come from
            ``create_products_of_givens_rotations``; otherwise from the
            transposed Q factor of a QR decomposition of a Gaussian matrix.

    Returns:
        A ``torch.Tensor`` of shape ``(m, d)``.

    Raises:
        ValueError: if ``scaling`` is neither 0 nor 1.
    """

    def _orthogonal_block(block_seed):
        # One d x d block; the global RNG is re-seeded first in both modes.
        torch.manual_seed(block_seed)
        if struct_mode:
            return create_products_of_givens_rotations(d, block_seed)
        unstructured_block = torch.randn((d, d))
        # torch.linalg.qr replaces the deprecated torch.qr (same reduced mode).
        q, _ = torch.linalg.qr(unstructured_block)
        return torch.t(q)

    nb_full_blocks, remaining_rows = divmod(m, d)
    block_list = [_orthogonal_block(seed + i) for i in range(nb_full_blocks)]
    current_seed = seed + len(block_list)
    if remaining_rows > 0:
        # The partial block reuses the next seed; only its leading rows are kept.
        block_list.append(_orthogonal_block(current_seed)[0:remaining_rows])
    final_matrix = torch.vstack(block_list)

    current_seed += 1
    torch.manual_seed(current_seed)
    if scaling == 0:
        multiplier = torch.norm(torch.randn((m, d)), dim=1)
    elif scaling == 1:
        multiplier = torch.sqrt(torch.tensor(float(d))) * torch.ones(m)
    else:
        raise ValueError("Scaling must be one of {0, 1}. Was %s" % scaling)

    # Row-wise scaling; equivalent to diag(multiplier) @ final_matrix without
    # materialising the m x m diagonal matrix.
    return multiplier.unsqueeze(1) * final_matrix
def create_products_of_givens_rotations(dim, seed):
    """Return a ``dim x dim`` orthogonal matrix built from random Givens rotations.

    Starting from the identity, ``dim * ceil(log(dim))`` random plane
    rotations are applied to pairs of rows.  NumPy's global RNG is seeded with
    ``seed`` (a side effect on the global NumPy RNG state), so the result is
    reproducible.

    Args:
        dim: size of the square matrix.
        seed: seed passed to ``np.random.seed``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(dim, dim)``.
    """
    nb_givens_rotations = dim * int(math.ceil(math.log(float(dim))))
    q = np.eye(dim, dim)
    np.random.seed(seed)
    for _ in range(nb_givens_rotations):
        random_angle = math.pi * np.random.uniform()
        random_indices = np.random.choice(dim, 2)
        index_i = min(random_indices[0], random_indices[1])
        index_j = max(random_indices[0], random_indices[1])
        if index_i == index_j:
            # The two indices are drawn with replacement; a "rotation" of a
            # row with itself would rescale it and break orthogonality, so
            # treat the degenerate draw as the identity.
            continue
        cos_a = math.cos(random_angle)
        sin_a = math.sin(random_angle)
        slice_i = q[index_i]
        slice_j = q[index_j]
        # Both new rows are computed from the old rows before either is
        # written back (slice_i / slice_j are views into q).
        new_slice_i = cos_a * slice_i + sin_a * slice_j
        new_slice_j = -sin_a * slice_i + cos_a * slice_j
        q[index_i] = new_slice_i
        q[index_j] = new_slice_j
    return torch.tensor(q, dtype=torch.float32)
def relu_kernel_transformation(data, is_query, projection_matrix=None, numerical_stabilizer=0.001):
    """Map ``data`` to non-negative features with a ReLU kernel.

    Args:
        data: input tensor; when a projection matrix is given it is contracted
            as ``"bnhd,md->bnhm"``, i.e. it must be 4-D with last size ``d``.
        is_query: accepted for signature compatibility with the other kernel
            transformations; unused.
        projection_matrix: optional ``(m, d)`` matrix.  If ``None`` the ReLU is
            applied to ``data`` directly.
        numerical_stabilizer: small constant added to every output entry.

    Returns:
        ``relu(data) + numerical_stabilizer`` when no projection is given,
        otherwise ``relu(data @ projection_matrix^T / sqrt(m)) + numerical_stabilizer``.
    """
    del is_query
    if projection_matrix is None:
        return data.relu() + numerical_stabilizer
    # dtype is keyword-only in torch.tensor; passing it positionally raises
    # TypeError, which made this branch unusable.
    ratio = 1.0 / torch.sqrt(
        torch.tensor(projection_matrix.shape[0], dtype=torch.float32)
    )
    data_dash = ratio * torch.einsum("bnhd,md->bnhm", data, projection_matrix)
    return data_dash.relu() + numerical_stabilizer
def softmax_kernel_transformation(data, is_query, projection_matrix=None, numerical_stabilizer=0.000001):
    """Map ``data`` to positive random features via an exponential kernel.

    Steps performed:
      1. scale ``data`` by ``d ** -0.25`` where ``d`` is its last dimension;
      2. project with ``"bnhd,md->bnhm"`` onto ``projection_matrix``;
      3. subtract half the squared norm of the scaled data and a running
         maximum (for numerical stability) before exponentiating;
      4. add ``numerical_stabilizer`` and scale by ``1 / sqrt(m)``.

    For queries the maximum is taken over the feature axis only; for keys it
    is additionally taken over axis ``rank - 3`` (the ``n`` axis of the
    einsum), so all keys share one shift.

    Args:
        data: 4-D tensor contracted as ``bnhd``.
        is_query: selects the query or key stabilisation described above.
        projection_matrix: ``(m, d)`` matrix; although it defaults to ``None``
            the code reads ``projection_matrix.shape``, so it must be given.
        numerical_stabilizer: small constant added after the exponential.

    Returns:
        Tensor with the layout ``bnhm``.
    """
    head_dim = torch.tensor(data.shape[-1], dtype=torch.float32)
    data = (1.0 / torch.sqrt(torch.sqrt(head_dim))) * data

    num_features = torch.tensor(projection_matrix.shape[0], dtype=torch.float32)
    ratio = 1.0 / torch.sqrt(num_features)

    projected = torch.einsum("bnhd,md->bnhm", data, projection_matrix)

    data_last_axis = data.dim() - 1
    half_sq_norm = torch.sum(torch.square(data), dim=data_last_axis) / 2.0
    half_sq_norm = torch.unsqueeze(half_sq_norm, dim=data_last_axis)

    feature_axis = projected.dim() - 1
    node_axis = projected.dim() - 3

    # Per-row maximum over the random-feature axis ...
    shift = torch.max(projected, dim=feature_axis, keepdim=True)[0]
    if not is_query:
        # ... and, for keys, a further maximum across the node axis.
        shift = torch.max(shift, dim=node_axis, keepdim=True)[0]

    return ratio * (torch.exp(projected - half_sq_norm - shift) + numerical_stabilizer)
def numerator(qs, ks, vs):
    """Un-normalised attention output computed in two contractions.

    ``ks`` (``nbhm``) and ``vs`` (``nbhd``) are first summed over ``n`` into a
    ``bhmd`` summary (called U_k in the paper, per the original comment),
    which is then contracted with ``qs`` (``nbhm``) over ``m``.

    Returns:
        Tensor with layout ``nbhd``.
    """
    key_value_summary = torch.einsum("nbhm,nbhd->bhmd", ks, vs)
    weighted_values = torch.einsum("nbhm,bhmd->nbhd", qs, key_value_summary)
    return weighted_values
def denominator(qs, ks):
    """Attention normaliser: each query contracted with the sum of all keys.

    ``ks`` (``nbhm``) is summed over ``n`` by contracting with a vector of
    ones placed on ``qs``'s device (the sum is called O_k in the paper, per
    the original comment); the result is contracted with ``qs`` over ``m``.

    Returns:
        Tensor with layout ``nbh``.
    """
    num_keys = ks.shape[0]
    ones_vec = torch.ones([num_keys]).to(qs.device)
    key_total = torch.einsum("nbhm,n->bhm", ks, ones_vec)
    normaliser = torch.einsum("nbhm,bhm->nbh", qs, key_total)
    return normaliser
def numerator_gumbel(qs, ks, vs):
    """Un-normalised attention output with an extra sample axis ``k`` on the keys.

    ``ks`` (``nbhkm``) and ``vs`` (``nbhd``) are summed over ``n`` into a
    ``bhkmd`` summary (U_k in the paper, per the original comment), which is
    then contracted with ``qs`` (``nbhm``) over ``m``.

    Returns:
        Tensor with layout ``nbhkd``.
    """
    key_value_summary = torch.einsum("nbhkm,nbhd->bhkmd", ks, vs)
    weighted_values = torch.einsum("nbhm,bhkmd->nbhkd", qs, key_value_summary)
    return weighted_values
def denominator_gumbel(qs, ks):
    """Attention normaliser with an extra sample axis ``k`` on the keys.

    ``ks`` (``nbhkm``) is summed over ``n`` by contracting with a vector of
    ones placed on ``qs``'s device (O_k in the paper, per the original
    comment); the result is contracted with ``qs`` (``nbhm``) over ``m``.

    Returns:
        Tensor with layout ``nbhk``.
    """
    num_keys = ks.shape[0]
    ones_vec = torch.ones([num_keys]).to(qs.device)
    key_total = torch.einsum("nbhkm,n->bhkm", ks, ones_vec)
    normaliser = torch.einsum("nbhm,bhkm->nbhk", qs, key_total)
    return normaliser
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def kernelized_softmax(query, key, value, kernel_transformation, projection_matrix=None, edge_index=None, tau=0.25, return_weight=True):
    '''
    fast computation of all-pair attentive aggregation with linear complexity
    input: query/key/value [B, N, H, D]
    return: updated node emb, attention weight (for computing edge loss)
    B = graph number (always equal to 1 in Node Classification), N = node number, H = head number,
    M = random feature dimension, D = hidden size

    kernel_transformation is called as kernel_transformation(x, is_query, projection_matrix)
    on query / key after both were divided by sqrt(tau).
    edge_index is unpacked as (start, end) and used to index the node axis; it is only
    read when return_weight is True, in which case it must be provided (the default
    None cannot be unpacked).
    When return_weight is False only the updated node embedding is returned.
    '''
    query = query / math.sqrt(tau)
    key = key / math.sqrt(tau)
    query_prime = kernel_transformation(query, True, projection_matrix)  # [B, N, H, M]
    key_prime = kernel_transformation(key, False, projection_matrix)  # [B, N, H, M]
    query_prime = query_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    key_prime = key_prime.permute(1, 0, 2, 3)  # [N, B, H, M]
    value = value.permute(1, 0, 2, 3)  # [N, B, H, D]

    # compute updated node emb, this step requires O(N)
    z_num = numerator(query_prime, key_prime, value)  # [N, B, H, D]
    # The normaliser is needed both for the node update and for the edge
    # weights; compute it once and reuse it.
    attn_normalizer = denominator(query_prime, key_prime)  # [N, B, H]

    z_num = z_num.permute(1, 0, 2, 3)  # [B, N, H, D]
    z_den = attn_normalizer.permute(1, 0, 2).unsqueeze(-1)  # [B, N, H, 1]
    z_output = z_num / z_den  # [B, N, H, D]

    if not return_weight:
        return z_output

    # query edge prob for computing edge-level reg loss, this step requires O(E)
    start, end = edge_index
    query_end, key_start = query_prime[end], key_prime[start]  # [E, B, H, M]
    edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, key_start)  # [E, B, H]
    edge_attn_num = edge_attn_num.permute(1, 0, 2)  # [B, E, H]
    edge_attn_dem = attn_normalizer[end]  # [E, B, H]
    edge_attn_dem = edge_attn_dem.permute(1, 0, 2)  # [B, E, H]
    A_weight = edge_attn_num / edge_attn_dem  # [B, E, H]
    return z_output, A_weight
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def kernelized_gumbel_softmax(query, key, value, kernel_transformation, projection_matrix=None, edge_index=None,
                              K=10, tau=0.25, return_weight=True):
    '''
    Linear-complexity all-pair attentive aggregation with Gumbel perturbation of the keys.
    input: query/key/value [B, N, H, D]
    return: updated node emb, and (when return_weight is True) attention weight per edge
    B = graph number, N = node number, H = head number,
    M = random feature dimension, D = hidden size, K = number of Gumbel samples

    query / key are divided by sqrt(tau) and mapped through
    kernel_transformation(x, is_query, projection_matrix). K noise samples per key are
    drawn as -log(Exponential) on the default device, moved to query.device and divided
    by tau; their exponentials multiply the key features. The node update is the mean
    over the K samples of numerator / denominator.
    edge_index is unpacked as (start, end); it is only read when return_weight is True.
    The edge weights use the unperturbed key features.
    '''
    inv_temp_scale = math.sqrt(tau)
    query = query / inv_temp_scale
    key = key / inv_temp_scale

    # Kernel features, moved to node-major layout.
    query_prime = kernel_transformation(query, True, projection_matrix).permute(1, 0, 2, 3)  # [N, B, H, M]
    key_prime = kernel_transformation(key, False, projection_matrix).permute(1, 0, 2, 3)  # [N, B, H, M]
    value = value.permute(1, 0, 2, 3)  # [N, B, H, D]

    # Gumbel noise: one value per (node, batch, head, sample).
    noise_shape = key_prime.shape[:-1] + (K, )
    exp_samples = torch.empty(noise_shape, memory_format=torch.legacy_contiguous_format).exponential_()
    gumbels = (-exp_samples.log()).to(query.device) / tau  # [N, B, H, K]

    # Perturbed keys, then the O(N) aggregation for every sample.
    key_t_gumbel = key_prime.unsqueeze(3) * gumbels.exp().unsqueeze(4)  # [N, B, H, K, M]
    z_num = numerator_gumbel(query_prime, key_t_gumbel, value).permute(1, 0, 2, 3, 4)  # [B, N, H, K, D]
    z_den = denominator_gumbel(query_prime, key_t_gumbel).permute(1, 0, 2, 3)  # [B, N, H, K]
    z_den = torch.unsqueeze(z_den, len(z_den.shape))  # [B, N, H, K, 1]
    z_output = torch.mean(z_num / z_den, dim=3)  # [B, N, H, D]

    if not return_weight:
        return z_output

    # Edge-level attention weights, O(E).
    start, end = edge_index
    query_end = query_prime[end]  # [E, B, H, M]
    key_start = key_prime[start]  # [E, B, H, M]
    edge_attn_num = torch.einsum("ebhm,ebhm->ebh", query_end, key_start).permute(1, 0, 2)  # [B, E, H]
    attn_normalizer = denominator(query_prime, key_prime)  # [N, B, H]
    edge_attn_dem = attn_normalizer[end].permute(1, 0, 2)  # [B, E, H]
    A_weight = edge_attn_num / edge_attn_dem  # [B, E, H]
    return z_output, A_weight
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def add_conv_relational_bias(x, edge_index, b, trans='sigmoid'):
    '''
    Propagate x over the input graph with one scalar edge weight per head.

    edge_index is unpacked as (row, col). Each edge gets the weight
    b_i / sqrt(in_degree[col] * out_degree[row]), where b_i is b[i] passed through a
    sigmoid (trans='sigmoid') or used as is (trans='identity'); any other trans raises
    NotImplementedError. For every index i along x.shape[2] a sparse N x N matrix
    (N = x.shape[1]) with rows=col and cols=row is multiplied with x[:, :, i], and the
    results are stacked along dim 2.
    Shapes in the inline comments follow the original annotations ([B, N, H, D] for x).
    '''
    row, col = edge_index
    num_nodes = x.shape[1]
    num_heads = x.shape[2]

    # Symmetric-style degree normalisation, gathered per edge.
    in_deg = degree(col, num_nodes).float()
    d_norm_in = (1. / in_deg[col]).sqrt()
    out_deg = degree(row, num_nodes).float()
    d_norm_out = (1. / out_deg[row]).sqrt()

    per_head = []
    for head in range(num_heads):
        if trans == 'sigmoid':
            b_i = b[head].sigmoid()
        elif trans == 'identity':
            b_i = b[head]
        else:
            raise NotImplementedError
        edge_weight = torch.ones_like(row) * b_i * d_norm_in * d_norm_out
        adj_i = SparseTensor(row=col, col=row, value=edge_weight, sparse_sizes=(num_nodes, num_nodes))
        per_head.append(matmul(adj_i, x[:, :, head]))  # [B, N, D]

    return torch.stack(per_head, dim=2)  # [B, N, H, D]