# Source code for opengt.encoder.graphormer_encoder

import networkx as nx
import torch
import torch.nn.functional as F
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.utils import to_dense_adj, to_networkx

# Axis order for ``Tensor.permute`` that moves the trailing head axis of a
# (batch, node, node, head) tensor in front of the node axes, yielding a
# (batch, head, node, node) tensor.
BATCH_HEAD_NODE_NODE = (0, 3, 1, 2)

# ``F.pad`` spec for the last two dimensions, read as
# (last_before, last_after, second_last_before, second_last_after):
# prepends one zero column and one zero row, leaving room at index 0 for
# the graph-token entries.
INSERT_GRAPH_TOKEN = (1, 0, 1, 0)


def graphormer_pre_processing(data, distance):
    """Implementation of Graphormer pre-processing.

    Computes in- and out-degrees for node encodings, as well as spatial types
    (via shortest-path lengths) and prepares edge encodings along shortest
    paths. The function adds the following properties to the data object:

    - in_degrees / out_degrees: per-node degree tensors (``torch.long``)
    - spatial_types: for every ordered node pair ``(i, j)``, stored at
      position ``i * N + j``, the number of hops on a shortest path from
      ``i`` to ``j``, capped at ``distance - 1``; pairs without a path keep
      the value ``distance``
    - graph_index: An edge_index type tensor that contains all possible
      directed edges ``(i, j)``, column ``i * N + j`` holding ``(i, j)``
    - shortest_path_types: only if ``data.edge_attr`` is present; for every
      pair, the edge attribute (first feature column for 2-D ``edge_attr``)
      of each edge along the (truncated) shortest path, zero-padded to
      ``distance`` entries

    As a side effect, ``cfg.posenc_GraphormerBias.num_in_degrees`` and
    ``cfg.posenc_GraphormerBias.num_out_degrees`` are raised to cover the
    degrees of this graph. If ``cfg.posenc_GraphormerBias.node_degrees_only``
    is set, only the degree attributes are added.

    Similar to the adjacency matrix, any matrix can be batched in PyG by
    decomposing it into a 1D tensor of values and a 2D tensor of indices.
    Once batched, the graph-specific matrix can be recovered (while
    appropriately padded) via ``to_dense_adj``. We use this concept to
    decompose the spatial type matrix and the shortest path edge type tensor
    via the ``graph_index`` tensor.

    Parameters:
        data (torch_geometric.data.Data): A PyG data object holding a single graph
        distance (int): The distance up to which types are calculated

    Returns:
        data (torch_geometric.data.Data): The augmented data object.
    """
    graph: nx.DiGraph = to_networkx(data)

    # Explicit dtype: for a graph without nodes torch.tensor([]) would be
    # float, which cannot be used as embedding indices.
    data.in_degrees = torch.tensor(
        [d for _, d in graph.in_degree()], dtype=torch.long)
    data.out_degrees = torch.tensor(
        [d for _, d in graph.out_degree()], dtype=torch.long)
    # torch.max raises on empty tensors, so fall back to 0 for empty graphs.
    max_in_degree = int(data.in_degrees.max()) if data.in_degrees.numel() else 0
    max_out_degree = int(data.out_degrees.max()) if data.out_degrees.numel() else 0

    bias_cfg = cfg.posenc_GraphormerBias
    if max_in_degree >= bias_cfg.num_in_degrees:
        print(
            f"Encountered in_degree: {max_in_degree}, setting posenc_"
            f"GraphormerBias.num_in_degrees to at least {max_in_degree + 1}"
        )
    if max_out_degree >= bias_cfg.num_out_degrees:
        print(
            f"Encountered out_degree: {max_out_degree}, setting posenc_"
            f"GraphormerBias.num_out_degrees to at least {max_out_degree + 1}"
        )
    # Grow the degree vocabularies so every degree seen so far is embeddable.
    bias_cfg.num_in_degrees = max(bias_cfg.num_in_degrees, max_in_degree + 1)
    bias_cfg.num_out_degrees = max(bias_cfg.num_out_degrees, max_out_degree + 1)

    if bias_cfg.node_degrees_only:
        return data

    N = graph.number_of_nodes()
    has_edge_attr = getattr(data, "edge_attr", None) is not None

    # {source: {target: [source, ..., target]}}; only reachable targets appear.
    shortest_paths = dict(nx.shortest_path(graph))

    # Pairs that never get overwritten below (unreachable) keep ``distance``.
    spatial_types = torch.full((N * N,), distance, dtype=torch.long)
    # All ordered pairs in row-major order: column i * N + j holds (i, j).
    nodes = torch.arange(N)
    graph_index = torch.stack([nodes.repeat_interleave(N), nodes.repeat(N)])

    if has_edge_attr:
        shortest_path_types = torch.zeros(N * N, distance, dtype=torch.long)
        # Dense N x N lookup table of the edge attribute per directed edge.
        edge_attr = torch.zeros(N, N, dtype=torch.long)
        edge_values = (data.edge_attr if data.edge_attr.dim() == 1
                       else data.edge_attr[:, 0])
        # index_put requires matching dtypes, and the values are later used as
        # embedding indices, so cast explicitly.
        edge_attr[data.edge_index[0], data.edge_index[1]] = edge_values.long()

    for i, paths in shortest_paths.items():
        for j, path in paths.items():
            # Keep at most ``distance`` nodes, i.e. at most distance - 1 hops.
            path = path[:distance]
            assert len(path) >= 1
            num_hops = len(path) - 1
            spatial_types[i * N + j] = num_hops

            if has_edge_attr and num_hops > 0:
                # Attributes of the consecutive edges (path[k], path[k + 1]).
                shortest_path_types[i * N + j, :num_hops] = \
                    edge_attr[path[:-1], path[1:]]

    data.spatial_types = spatial_types
    data.graph_index = graph_index
    if has_edge_attr:
        data.shortest_path_types = shortest_path_types
    return data
class BiasEncoder(torch.nn.Module):
    def __init__(self, num_heads: int, num_spatial_types: int,
                 num_edge_types: int, use_graph_token: bool = True):
        """Implementation of the bias encoder of Graphormer.

        This encoder is based on the implementation at:
        https://github.com/microsoft/Graphormer/tree/v1.0
        Note that this refers to v1 of Graphormer.

        Parameters:
            num_heads (int): The number of heads of the Graphormer model
            num_spatial_types (int): The total number of different spatial types
            num_edge_types (int): The total number of different edge types
            use_graph_token (bool): If True, pads the attn_bias to account for
                the additional graph token that can be added by the
                ``NodeEncoder``.
        """
        super().__init__()
        self.num_heads = num_heads

        # One extra row for the "unreachable" spatial type (disconnected nodes).
        self.spatial_encoder = torch.nn.Embedding(
            num_spatial_types + 1, num_heads)
        # Flattened storage of one (num_heads x num_heads) mixing matrix per
        # hop position along a shortest path.
        self.edge_dis_encoder = torch.nn.Embedding(
            num_spatial_types * num_heads * num_heads, 1)
        self.edge_encoder = torch.nn.Embedding(num_edge_types, num_heads)

        self.use_graph_token = use_graph_token
        if self.use_graph_token:
            self.graph_token = torch.nn.Parameter(torch.zeros(1, num_heads, 1))
        self.reset_parameters()

    def reset_parameters(self):
        self.spatial_encoder.weight.data.normal_(std=0.02)
        self.edge_encoder.weight.data.normal_(std=0.02)
        self.edge_dis_encoder.weight.data.normal_(std=0.02)
        if self.use_graph_token:
            self.graph_token.data.normal_(std=0.02)

    def forward(self, data):
        """Computes the bias matrix that can be induced into multi-head
        attention via the attention mask.

        Adds the tensor ``attn_bias`` of shape
        ``(num_graphs * num_heads, N, N)`` to the data object, where ``N`` is
        the (padded) maximum number of nodes per graph, plus one if the graph
        token is used.
        """
        # To convert 2D matrices to dense-batch mode, one needs to decompose
        # them into index and value. One example is the adjacency matrix
        # but this generalizes actually to any 2D matrix
        spatial_embeddings: torch.Tensor = self.spatial_encoder(data.spatial_types)
        spatial_encodings = to_dense_adj(data.graph_index, data.batch,
                                         spatial_embeddings)
        bias = spatial_encodings.permute(BATCH_HEAD_NODE_NODE)

        shortest_path_types = getattr(data, "shortest_path_types", None)
        if shortest_path_types is not None:
            max_dist = shortest_path_types.size(-1)
            edge_types: torch.Tensor = self.edge_encoder(shortest_path_types)

            # shortest_path_types is zero-padded past the path length, and for
            # unreachable pairs (spatial type == max_dist) holds no path at
            # all. Zero is also a real edge type, so mask those positions out
            # instead of letting them contribute edge_encoder(0) terms.
            hops = torch.arange(max_dist, device=data.spatial_types.device)
            num_hops = data.spatial_types.unsqueeze(-1)
            valid = (hops < num_hops) & (num_hops < max_dist)
            edge_types = edge_types * valid.unsqueeze(-1).to(edge_types.dtype)

            edge_encodings = to_dense_adj(data.graph_index, data.batch,
                                          edge_types)
            spatial_distances = to_dense_adj(data.graph_index, data.batch,
                                             data.spatial_types)
            # Average over hops; clamp avoids dividing by 0 for self-pairs and
            # padded entries.
            spatial_distances = spatial_distances.float().clamp(min=1.0).unsqueeze(1)

            # (B, N, N, max_dist, H) -> (max_dist, B*N*N, H) so that each hop
            # position is multiplied by its own H x H matrix.
            B, N = edge_encodings.shape[:2]
            edge_encodings = edge_encodings.permute(3, 0, 1, 2, 4).reshape(
                max_dist, -1, self.num_heads)
            # Only the first max_dist matrices are needed; slicing allows a
            # preprocessing distance smaller than num_spatial_types.
            dis_weights = self.edge_dis_encoder.weight.reshape(
                -1, self.num_heads, self.num_heads)[:max_dist]
            edge_encodings = torch.bmm(edge_encodings, dis_weights)
            edge_encodings = edge_encodings.reshape(
                max_dist, B, N, N, self.num_heads).permute(1, 2, 3, 0, 4)
            edge_encodings = edge_encodings.sum(-2).permute(
                BATCH_HEAD_NODE_NODE) / spatial_distances
            bias = bias + edge_encodings

        if self.use_graph_token:
            # Row/column 0 become the graph-token interaction entries.
            bias = F.pad(bias, INSERT_GRAPH_TOKEN)
            bias[:, :, 1:, 0] = self.graph_token
            bias[:, :, 0, :] = self.graph_token

        B, H, N, _ = bias.shape
        data.attn_bias = bias.reshape(B * H, N, N)
        return data
def add_graph_token(data, token):
    """Helper function to augment a batch of PyG graphs with a graph token each.

    One copy of ``token`` is prepended to the nodes of every graph present in
    ``data.batch``, so that afterwards each graph's first row in ``data.x`` is
    its token and its original nodes follow in their original order.

    Parameters:
        data (torch_geometric.data.Batch): A batch object providing ``x`` and
            ``batch`` (graph id per node)
        token (torch.Tensor): A tensor containing the graph token values; it is
            repeated once per graph along dimension 0

    Returns:
        data (torch_geometric.data.Batch): The augmented batch object.
    """
    # Label the tokens with the graph ids actually present (sorted), which is
    # identical to 0..B-1 for contiguous ids and still correct if ids have gaps.
    graph_ids = torch.unique(data.batch)
    tokens = torch.repeat_interleave(token, graph_ids.numel(), 0)

    data.x = torch.cat([tokens, data.x], 0)
    data.batch = torch.cat([graph_ids.to(torch.long), data.batch])

    # A *stable* sort is required: tokens were concatenated first, so they end
    # up at the front of their graph, and each graph's nodes keep their order
    # (which must match the node order used for attn_bias).
    data.batch, sort_idx = torch.sort(data.batch, stable=True)
    data.x = data.x[sort_idx]
    return data
class NodeEncoder(torch.nn.Module):
    def __init__(self, embed_dim, num_in_degree, num_out_degree,
                 input_dropout=0.0, use_graph_token: bool = True):
        """Implementation of the node encoder of Graphormer.

        This encoder is based on the implementation at:
        https://github.com/microsoft/Graphormer/tree/v1.0
        Note that this refers to v1 of Graphormer.

        Parameters:
            embed_dim (int): The number of hidden dimensions of the model
            num_in_degree (int): Number of in-degree embeddings; every value
                in ``data.in_degrees`` must be smaller than this
            num_out_degree (int): Number of out-degree embeddings; every value
                in ``data.out_degrees`` must be smaller than this
            input_dropout (float): Dropout applied to the input features
            use_graph_token (bool): If True, adds the graph token to the
                incoming batch.
        """
        super().__init__()
        self.in_degree_encoder = torch.nn.Embedding(num_in_degree, embed_dim)
        self.out_degree_encoder = torch.nn.Embedding(num_out_degree, embed_dim)
        self.use_graph_token = use_graph_token
        if use_graph_token:
            self.graph_token = torch.nn.Parameter(torch.zeros(1, embed_dim))
        self.input_dropout = torch.nn.Dropout(input_dropout)
        self.reset_parameters()

    def forward(self, data):
        """Add degree embeddings to ``data.x`` (or use them as ``data.x`` when
        it has no feature columns), optionally prepend graph tokens, and apply
        input dropout."""
        in_enc = self.in_degree_encoder(data.in_degrees)
        out_enc = self.out_degree_encoder(data.out_degrees)
        has_features = data.x.size(1) > 0
        data.x = data.x + in_enc + out_enc if has_features else in_enc + out_enc

        if self.use_graph_token:
            data = add_graph_token(data, self.graph_token)
        data.x = self.input_dropout(data.x)
        return data

    def reset_parameters(self):
        for degree_encoder in (self.in_degree_encoder, self.out_degree_encoder):
            degree_encoder.weight.data.normal_(std=0.02)
        if self.use_graph_token:
            self.graph_token.data.normal_(std=0.02)


@register_node_encoder("GraphormerBias")
class GraphormerEncoder(torch.nn.Sequential):
    def __init__(self, dim_emb, *args, **kwargs):
        """Graphormer input encoder registered as ``"GraphormerBias"``.

        Chains a ``BiasEncoder`` (attention bias) and a ``NodeEncoder``
        (degree embeddings / graph token). With
        ``cfg.posenc_GraphormerBias.node_degrees_only`` only the
        ``NodeEncoder`` is kept.

        Parameters:
            dim_emb (int): The node embedding dimension
            *args, **kwargs: Accepted for the registry interface and ignored
        """
        bias_cfg = cfg.posenc_GraphormerBias
        model_cfg = cfg.graphormer

        # Both encoders are always constructed, bias encoder first; the bias
        # encoder is simply left out when only node degrees are requested.
        bias_encoder = BiasEncoder(
            model_cfg.num_heads,
            bias_cfg.num_spatial_types,
            cfg.dataset.edge_encoder_num_types,
            model_cfg.use_graph_token,
        )
        node_encoder = NodeEncoder(
            dim_emb,
            bias_cfg.num_in_degrees,
            bias_cfg.num_out_degrees,
            model_cfg.input_dropout,
            model_cfg.use_graph_token,
        )

        if bias_cfg.node_degrees_only:
            super().__init__(node_encoder)
        else:
            super().__init__(bias_encoder, node_encoder)