import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.utils import to_dense_adj, to_networkx
# Axis order handed to ``Tensor.permute``: moves the trailing head axis of a
# (batch, node, node, head) tensor forward, giving (batch, head, node, node).
BATCH_HEAD_NODE_NODE = (0, 3, 1, 2)
# Padding spec handed to ``F.pad`` for the last two axes, in the order
# (left, right, top, bottom): inserts one leading zero column and one leading
# zero row, which become the graph token's column and row in the bias matrix.
INSERT_GRAPH_TOKEN = (1, 0, 1, 0)
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def graphormer_pre_processing(data, distance):
    """Augment a single graph with the inputs needed by the Graphormer encoders.

    Computes in- and out-degrees for the node encodings, spatial types (via
    shortest-path lengths) and the edge types along every shortest path.
    The following attributes are added to ``data``:

    - ``in_degrees`` / ``out_degrees``: per-node degree, ``long`` tensors.
    - ``spatial_types``: flat ``long`` tensor of length ``N ** 2``. Entry
      ``i * N + j`` is the number of edges on the shortest path ``i -> j``
      after the path has been truncated to at most ``distance`` nodes. Pairs
      for which no path was found keep the fill value ``distance``.
    - ``graph_index``: ``(2, N ** 2)`` edge_index-style tensor listing every
      ordered node pair ``(i, j)`` in row-major order.
    - ``shortest_path_types`` (only when ``data.edge_attr`` is present):
      ``(N ** 2, distance)`` ``long`` tensor whose row ``i * N + j`` holds the
      edge type of each edge on the (truncated) shortest path ``i -> j``,
      padded with zeros on the right.

    Similar to the adjacency matrix, any matrix can be batched in PyG by
    decomposing it into a 1D tensor of values and a 2D tensor of indices.
    Once batched, the graph-specific matrix can be recovered (appropriately
    padded) via ``to_dense_adj``. ``graph_index`` plays the role of the index
    tensor for both ``spatial_types`` and ``shortest_path_types``.

    Side effects: if a degree at least as large as the configured
    ``cfg.posenc_GraphormerBias.num_in_degrees`` / ``num_out_degrees`` is
    encountered, a message is printed and the config value is raised to
    ``max_degree + 1``.

    Parameters:
        data (torch_geometric.data.Data): A PyG data object holding a single graph
        distance (int): The distance up to which types are calculated

    Returns:
        data (torch_geometric.data.Data): The augmented data object. When
        ``cfg.posenc_GraphormerBias.node_degrees_only`` is set, only the
        degree attributes are added.
    """
    graph: nx.DiGraph = to_networkx(data)

    # Explicit dtype: an empty list would otherwise produce a float tensor,
    # which cannot be used as embedding indices later on.
    data.in_degrees = torch.tensor(
        [d for _, d in graph.in_degree()], dtype=torch.long)
    data.out_degrees = torch.tensor(
        [d for _, d in graph.out_degree()], dtype=torch.long)

    # ``max`` over an empty tensor raises, so a graph without nodes counts as
    # having maximum degree 0.
    max_in_degree = (
        int(data.in_degrees.max()) if data.in_degrees.numel() else 0)
    max_out_degree = (
        int(data.out_degrees.max()) if data.out_degrees.numel() else 0)

    if max_in_degree >= cfg.posenc_GraphormerBias.num_in_degrees:
        print(
            f"Encountered in_degree: {max_in_degree}, setting posenc_"
            f"GraphormerBias.num_in_degrees to at least {max_in_degree + 1}"
        )
    if max_out_degree >= cfg.posenc_GraphormerBias.num_out_degrees:
        print(
            f"Encountered out_degree: {max_out_degree}, setting posenc_"
            f"GraphormerBias.num_out_degrees to at least {max_out_degree + 1}"
        )
    # Grow the configured embedding sizes so every observed degree is a valid
    # index; a no-op when the configured size is already large enough.
    cfg.posenc_GraphormerBias.num_in_degrees = max(
        cfg.posenc_GraphormerBias.num_in_degrees, max_in_degree + 1)
    cfg.posenc_GraphormerBias.num_out_degrees = max(
        cfg.posenc_GraphormerBias.num_out_degrees, max_out_degree + 1)

    if cfg.posenc_GraphormerBias.node_degrees_only:
        return data

    N = len(graph.nodes)
    # dict{source: dict{target: [nodes on the shortest source->target path]}}
    shortest_paths = dict(nx.shortest_path(graph))

    # Pairs that never show up in ``shortest_paths`` keep the value ``distance``.
    spatial_types = torch.full((N ** 2,), distance, dtype=torch.long)

    # All ordered pairs (i, j) in row-major order: column i * N + j is (i, j).
    node_ids = torch.arange(N, dtype=torch.long)
    graph_index = torch.stack(
        [node_ids.repeat_interleave(N), node_ids.repeat(N)])

    has_edge_attr = hasattr(data, "edge_attr") and data.edge_attr is not None
    if has_edge_attr:
        shortest_path_types = torch.zeros(N ** 2, distance, dtype=torch.long)
        # Dense (N, N) lookup table of edge types; only the first attribute
        # column is used when ``edge_attr`` is two-dimensional.
        edge_attr = torch.zeros(N, N, dtype=torch.long)
        if len(data.edge_attr.shape) == 1:
            edge_attr[data.edge_index[0], data.edge_index[1]] = data.edge_attr
        else:
            edge_attr[data.edge_index[0],
                      data.edge_index[1]] = data.edge_attr[:, 0]

    for i, paths in shortest_paths.items():
        for j, path in paths.items():
            # Keep at most ``distance`` nodes of the path.
            if len(path) > distance:
                path = path[:distance]

            assert len(path) >= 1
            num_hops = len(path) - 1
            spatial_types[i * N + j] = num_hops

            if has_edge_attr and num_hops > 0:
                # Edge type of every consecutive node pair on the path,
                # gathered in one indexing operation.
                shortest_path_types[i * N + j, :num_hops] = edge_attr[
                    path[:-1], path[1:]]

    data.spatial_types = spatial_types
    data.graph_index = graph_index
    if has_edge_attr:
        data.shortest_path_types = shortest_path_types
    return data
class BiasEncoder(torch.nn.Module):
    """Attention-bias encoder of Graphormer (v1).

    Turns the ``spatial_types`` and, when present, ``shortest_path_types``
    attributes of a batch into a per-head additive attention bias that is
    stored on the batch as ``attn_bias``.
    """

    def __init__(self, num_heads: int, num_spatial_types: int,
                 num_edge_types: int, use_graph_token: bool = True):
        """Implementation of the bias encoder of Graphormer.
        This encoder is based on the implementation at:
        https://github.com/microsoft/Graphormer/tree/v1.0
        Note that this refers to v1 of Graphormer.

        Parameters:
            num_heads (int): The number of heads of the Graphormer model
            num_spatial_types (int): The total number of different spatial types
            num_edge_types (int): The total number of different edge types
            use_graph_token (bool): If True, pads the attn_bias to account for the
                additional graph token that can be added by the ``NodeEncoder``.
        """
        super().__init__()
        self.num_heads = num_heads

        # Takes into account disconnected nodes (one extra embedding row)
        self.spatial_encoder = torch.nn.Embedding(
            num_spatial_types + 1, num_heads)
        # Stored flat; ``forward`` views it as a stack of
        # (num_heads, num_heads) matrices, one per position along a path.
        self.edge_dis_encoder = torch.nn.Embedding(
            num_spatial_types * num_heads * num_heads, 1)
        self.edge_encoder = torch.nn.Embedding(num_edge_types, num_heads)

        self.use_graph_token = use_graph_token
        if self.use_graph_token:
            self.graph_token = torch.nn.Parameter(torch.zeros(1, num_heads, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all learnable weights from a normal distribution (std 0.02)."""
        self.spatial_encoder.weight.data.normal_(std=0.02)
        self.edge_encoder.weight.data.normal_(std=0.02)
        self.edge_dis_encoder.weight.data.normal_(std=0.02)
        if self.use_graph_token:
            self.graph_token.data.normal_(std=0.02)

    def forward(self, data):
        """Computes the bias matrix that can be induced into multi-head attention
        via the attention mask.

        Adds the tensor ``attn_bias`` of shape
        ``(batch * num_heads, nodes, nodes)`` to the data object, where
        ``nodes`` includes the graph token when ``use_graph_token`` is set.

        Parameters:
            data: Batch providing ``spatial_types``, ``graph_index``,
                ``batch`` and optionally ``shortest_path_types``.

        Returns:
            The same ``data`` object with ``attn_bias`` set.
        """
        # To convert 2D matrices to dense-batch mode, one needs to decompose
        # them into index and value. One example is the adjacency matrix
        # but this generalizes actually to any 2D matrix
        spatial_types: torch.Tensor = self.spatial_encoder(data.spatial_types)
        spatial_encodings = to_dense_adj(data.graph_index,
                                         data.batch,
                                         spatial_types)
        bias = spatial_encodings.permute(BATCH_HEAD_NODE_NODE)

        if hasattr(data, "shortest_path_types"):
            edge_types: torch.Tensor = self.edge_encoder(
                data.shortest_path_types)
            edge_encodings = to_dense_adj(data.graph_index,
                                          data.batch,
                                          edge_types)

            # Path length per node pair, used to average the summed edge
            # encodings; clamped so zero-length paths do not divide by zero.
            spatial_distances = to_dense_adj(data.graph_index,
                                             data.batch,
                                             data.spatial_types)
            spatial_distances = (
                spatial_distances.float().clamp(min=1.0).unsqueeze(1))

            B, N, _, max_dist, _ = edge_encodings.shape

            # One (heads x heads) matrix per path position. Only the first
            # ``max_dist`` matrices are used so the batch dimension of
            # ``bmm`` matches even when the paths in the data are shorter
            # than ``num_spatial_types`` (as in the reference implementation).
            position_weights = self.edge_dis_encoder.weight.reshape(
                -1, self.num_heads, self.num_heads)[:max_dist]

            edge_encodings = edge_encodings.permute(3, 0, 1, 2, 4).reshape(
                max_dist, -1, self.num_heads)
            edge_encodings = torch.bmm(edge_encodings, position_weights)
            edge_encodings = edge_encodings.reshape(
                max_dist, B, N, N, self.num_heads).permute(1, 2, 3, 0, 4)
            # Sum over path positions, then average by path length.
            edge_encodings = edge_encodings.sum(-2).permute(
                BATCH_HEAD_NODE_NODE) / spatial_distances
            bias += edge_encodings

        if self.use_graph_token:
            # Make room for the graph token at index 0 of both node axes and
            # fill its row and column with the learned per-head bias.
            bias = F.pad(bias, INSERT_GRAPH_TOKEN)
            bias[:, :, 1:, 0] = self.graph_token
            bias[:, :, 0, :] = self.graph_token

        B, H, N, _ = bias.shape
        data.attn_bias = bias.reshape(B * H, N, N)
        return data
# [docs]  (Sphinx source-link marker left over from HTML extraction; not code)
def add_graph_token(data, token):
    """Helper function to augment a batch of PyG graphs
    with a graph token each. Note that the token is
    automatically replicated to fit the batch.

    The token row is placed directly in front of the nodes of its graph and
    the relative order of the nodes is preserved. ``data.x`` and
    ``data.batch`` are replaced on the passed object.

    Parameters:
        data (torch_geometric.data.Data): A PyG data object holding a batch
            of graphs (needs ``x`` and ``batch``)
        token (torch.Tensor): A tensor containing the graph token values,
            one row that is repeated once per graph

    Returns:
        data (torch_geometric.data.Data): The augmented data object.
    """
    B = len(data.batch.unique())
    tokens = torch.repeat_interleave(token, B, 0)
    data.x = torch.cat([tokens, data.x], 0)
    data.batch = torch.cat(
        [torch.arange(0, B, device=data.x.device, dtype=torch.long), data.batch]
    )
    # The tokens were prepended, so a *stable* sort by graph id is what puts
    # each token in front of its graph's nodes and keeps the node order.
    # The default sort gives no such guarantee.
    data.batch, sort_idx = torch.sort(data.batch, stable=True)
    data.x = data.x[sort_idx]
    return data
class NodeEncoder(torch.nn.Module):
    """Node encoder of Graphormer (v1): adds learned degree embeddings to
    the node features and optionally prepends a graph token per graph."""

    def __init__(self, embed_dim, num_in_degree, num_out_degree,
                 input_dropout=0.0, use_graph_token: bool = True):
        """Implementation of the node encoder of Graphormer.
        This encoder is based on the implementation at:
        https://github.com/microsoft/Graphormer/tree/v1.0
        Note that this refers to v1 of Graphormer.

        Parameters:
            embed_dim (int): The number of hidden dimensions of the model
            num_in_degree (int): Maximum size of in-degree to encode
            num_out_degree (int): Maximum size of out-degree to encode
            input_dropout (float): Dropout applied to the input features
            use_graph_token (bool): If True, adds the graph token to the incoming batch.
        """
        super().__init__()
        self.in_degree_encoder = torch.nn.Embedding(num_in_degree, embed_dim)
        self.out_degree_encoder = torch.nn.Embedding(num_out_degree, embed_dim)

        self.use_graph_token = use_graph_token
        if use_graph_token:
            self.graph_token = torch.nn.Parameter(torch.zeros(1, embed_dim))

        self.input_dropout = torch.nn.Dropout(input_dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw all learnable weights from a normal distribution (std 0.02)."""
        for embedding in (self.in_degree_encoder, self.out_degree_encoder):
            embedding.weight.data.normal_(std=0.02)
        if self.use_graph_token:
            self.graph_token.data.normal_(std=0.02)

    def forward(self, data):
        """Encode the nodes of ``data`` in place and return ``data``.

        The in- and out-degree embeddings are added to ``data.x``; if
        ``data.x`` has no feature columns they replace it. Afterwards the
        graph token is inserted (when enabled) and input dropout is applied.
        """
        from_in = self.in_degree_encoder(data.in_degrees)
        from_out = self.out_degree_encoder(data.out_degrees)

        has_features = data.x.size(1) > 0
        data.x = (data.x + from_in + from_out) if has_features \
            else (from_in + from_out)

        if self.use_graph_token:
            data = add_graph_token(data, self.graph_token)

        data.x = self.input_dropout(data.x)
        return data
@register_node_encoder("GraphormerBias")
class GraphormerEncoder(torch.nn.Sequential):
    """Graphormer encoder registered under the name ``"GraphormerBias"``.

    Chains the attention-bias encoder and the node encoder, both configured
    from the global GraphGym ``cfg``. When
    ``cfg.posenc_GraphormerBias.node_degrees_only`` is set, only the node
    encoder is kept.
    """

    def __init__(self, dim_emb, *args, **kwargs):
        """Build the encoder pipeline.

        Parameters:
            dim_emb (int): Embedding dimension passed to the node encoder.
            *args, **kwargs: Accepted and ignored.
        """
        graphormer_cfg = cfg.graphormer
        posenc_cfg = cfg.posenc_GraphormerBias

        bias_encoder = BiasEncoder(
            graphormer_cfg.num_heads,
            posenc_cfg.num_spatial_types,
            cfg.dataset.edge_encoder_num_types,
            graphormer_cfg.use_graph_token,
        )
        node_encoder = NodeEncoder(
            dim_emb,
            posenc_cfg.num_in_degrees,
            posenc_cfg.num_out_degrees,
            graphormer_cfg.input_dropout,
            graphormer_cfg.use_graph_token,
        )

        if cfg.posenc_GraphormerBias.node_degrees_only:
            # No attn. bias encoder
            stages = (node_encoder,)
        else:
            stages = (bias_encoder, node_encoder)
        super().__init__(*stages)