Source code for opengt.layer.graphormer_layer

import torch
from torch_geometric.utils import to_dense_batch


class GraphormerLayer(torch.nn.Module):
    """Implementation of the Graphormer layer.

    This layer is based on the implementation at:
    https://github.com/microsoft/Graphormer/tree/v1.0
    Note that this refers to v1 of Graphormer.

    The layer is made of two pre-norm residual sub-blocks:

    1. LayerNorm -> multi-head self-attention over the nodes of each graph
       (computed on the densified batch) -> dropout, plus a residual
       connection to the input features.
    2. An MLP (LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout),
       plus a residual connection.

    Parameters:
        embed_dim: The number of hidden dimensions of the model. Input node
            features must have this many channels.
        num_heads: The number of heads of the Graphormer model.
        dropout: Dropout applied after the attention and after the MLP.
        attention_dropout: Dropout applied within the attention.
        mlp_dropout: Dropout applied within the MLP (after the GELU).

    Input:
        data.x (Tensor): Input node features of shape (N, embed_dim).
        data.batch (Tensor): Batch vector of shape (N,).
        data.attn_bias (Tensor, optional): Attention bias, forwarded as
            ``attn_mask`` to ``torch.nn.MultiheadAttention`` on the densified
            batch. Its shape must therefore be (M, M) or (B * num_heads, M, M),
            where B is the number of graphs in the batch and M the largest
            number of nodes in any of them. If absent or ``None``, no bias is
            applied.

    Output:
        ret.x (Tensor): Output node features of shape (N, embed_dim) after
            applying the Graphormer layer. ``ret`` is ``data.clone()`` with
            ``x`` replaced; ``data`` itself is not reassigned.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        attention_dropout: float,
        mlp_dropout: float,
    ):
        super().__init__()
        self.attention = torch.nn.MultiheadAttention(
            embed_dim, num_heads, attention_dropout, batch_first=True
        )
        self.input_norm = torch.nn.LayerNorm(embed_dim)
        self.dropout = torch.nn.Dropout(dropout)
        # We follow the paper in that all hidden dims are
        # equal to the embedding dim
        self.mlp = torch.nn.Sequential(
            torch.nn.LayerNorm(embed_dim),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(mlp_dropout),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.Dropout(dropout),
        )

    def forward(self, data):
        """Apply the layer to ``data`` and return a clone with updated ``x``.

        See the class docstring for the expected attributes of ``data``.
        """
        x = self.input_norm(data.x)
        # Densify to (B, M, D). ``real_nodes`` marks real (non-padding)
        # positions, so its negation is the key padding mask.
        x, real_nodes = to_dense_batch(x, data.batch)
        padding_mask = ~real_nodes

        attn_bias = getattr(data, "attn_bias", None)
        if attn_bias is not None and torch.is_floating_point(attn_bias):
            # MultiheadAttention deprecates mixing a bool key_padding_mask
            # with a float attn_mask. Build the additive float form directly
            # (0 for real nodes, -inf for padding); this is the same mask
            # PyTorch would otherwise derive from the bool one.
            padding_mask = torch.zeros_like(
                padding_mask, dtype=attn_bias.dtype
            ).masked_fill(padding_mask, float("-inf"))

        # The attention weights are never used, so skip computing them.
        x, _ = self.attention(
            x,
            x,
            x,
            key_padding_mask=padding_mask,
            attn_mask=attn_bias,
            need_weights=False,
        )
        # Drop padded positions to return to the sparse (N, D) layout.
        x = x[real_nodes]

        x = self.dropout(x) + data.x
        ret = data.clone()
        ret.x = self.mlp(x) + x
        return ret