Source code for opengt.layer.ETransformer
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
[docs]
class ETransformer(nn.Module):
    """Multi-head graph attention layer that attends along graph edges only.

    Ported to PyG from the original SAN repository:
    https://github.com/DevinKreuzer/SAN/blob/main/layers/graph_transformer_layer.py

    For every edge ``(src, dst)`` a per-head score is computed from the key of
    ``src`` and the query of ``dst`` (optionally multiplied by projected edge
    features), summed over the feature axis, clamped and exponentiated. The
    value of ``src`` is weighted by that score, summed into ``dst``, and the
    result is divided by the sum of the scores arriving at ``dst``
    (a softmax-style normalisation over incoming edges).

    Args:
        in_dim: Input feature size of ``batch.x``. The edge projection is also
            built as ``Linear(in_dim, ...)``, so edge features must have the
            same width when ``use_edge_attr`` is true.
        out_dim: Total output feature size; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        use_bias: Whether the Q/K/E/V linear projections carry a bias.
        edge_index: Name of the batch attribute holding the edge index
            (shape ``2 x E``; an ``E x 2`` tensor is transposed on the fly).
        use_edge_attr: Whether to modulate the scores with edge features.
        edge_attr: Name of the batch attribute holding the edge features.

    Raises:
        ValueError: If ``out_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, in_dim: int, out_dim: int, num_heads: int,
                 use_bias: bool, edge_index: str = 'edge_index',
                 use_edge_attr: bool = False, edge_attr: str = 'edge_attr'):
        super().__init__()
        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads')

        # NOTE: ``self.out_dim`` is the *per-head* width from here on.
        self.out_dim = out_dim // num_heads
        self.num_heads = num_heads
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        self.use_edge_attr = use_edge_attr

        # Creation order (Q, K, [E], V) is kept so parameter initialisation
        # and state_dict ordering match the original implementation.
        proj_dim = self.out_dim * num_heads
        self.Q = nn.Linear(in_dim, proj_dim, bias=use_bias)
        self.K = nn.Linear(in_dim, proj_dim, bias=use_bias)
        if self.use_edge_attr:
            self.E = nn.Linear(in_dim, proj_dim, bias=use_bias)
        self.V = nn.Linear(in_dim, proj_dim, bias=use_bias)

    def propagate_attention(self, batch, use_edge_attr=None):
        """Compute edge-wise attention and aggregate messages per node.

        Reads ``batch.Q_h``, ``batch.K_h``, ``batch.V_h`` (and ``batch.E``
        when edge features are used) and writes:

        * ``batch.wV`` -- score-weighted sum of source values per destination
          node, shaped like ``batch.V_h``.
        * ``batch.Z`` -- sum of scores per destination node,
          ``num_nodes x num_heads x 1``.

        Args:
            batch: Object carrying the projected tensors and the edge index.
            use_edge_attr: Per-call override of ``self.use_edge_attr``;
                ``None`` (default) falls back to the module setting.
        """
        if use_edge_attr is None:
            use_edge_attr = self.use_edge_attr

        edge_index = getattr(batch, self.edge_index)
        src_idx, dst_idx = edge_index[0], edge_index[1]

        # Element-wise key * query per edge: E x num_heads x head_dim,
        # scaled by sqrt(d).
        score = batch.K_h[src_idx] * batch.Q_h[dst_idx]
        score = score / np.sqrt(self.out_dim)

        # Use available edge features to modify the scores for edges.
        if use_edge_attr:
            score = score * batch.E

        # Reduce over features, then clamp before exp so the unnormalised
        # weights stay within [exp(-5), exp(5)]: E x num_heads x 1.
        score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5))

        # Weight each source value by its edge score and add the messages up
        # in the destination nodes.
        msg = batch.V_h[src_idx] * score
        batch.wV = torch.zeros_like(batch.V_h).index_add_(0, dst_idx, msg)

        # Attention normalisation coefficient; sized from V_h so it always
        # agrees with wV.
        num_nodes = batch.V_h.size(0)
        batch.Z = score.new_zeros(num_nodes, self.num_heads, 1).index_add_(
            0, dst_idx, score)

    def forward(self, batch):
        """Apply the attention layer to ``batch`` and return new node features.

        Side effects: stores ``Q_h``, ``K_h``, ``V_h``, ``wV``, ``Z`` (and
        ``E`` when edge features are used) on ``batch``; if the edge index is
        given as ``E x 2`` it is replaced on ``batch`` by its transpose.

        Args:
            batch: Object exposing ``x`` plus the configured edge-index (and
                optionally edge-attribute) attributes.

        Returns:
            Tensor of shape ``num_nodes x (num_heads * head_dim)``. Nodes
            without incoming edges receive zeros.

        Raises:
            ValueError: If the configured edge-index attribute is missing or
                ``None``.
        """
        edge_index = getattr(batch, self.edge_index, None)
        if edge_index is None:
            raise ValueError(f'edge index: {self.edge_index} not found')

        # Accept an ``E x 2`` layout by transposing it to ``2 x E``.
        if (edge_index.dim() == 2 and edge_index.shape[0] != 2
                and edge_index.shape[1] == 2):
            edge_index = torch.t(edge_index)
            setattr(batch, self.edge_index, edge_index)

        # Decide per call whether edge features are usable; never mutate the
        # module-level flag, otherwise a single bad batch would silently
        # disable edge features for every later batch.
        use_edge_attr = self.use_edge_attr
        edge_attr = None
        if use_edge_attr:
            edge_attr = getattr(batch, self.edge_attr, None)
            if edge_attr is None or edge_attr.shape[0] != edge_index.shape[1]:
                warnings.warn('edge_attr shape does not match edge_index shape, ignoring edge_attr')
                use_edge_attr = False

        # Project and reshape into [num_nodes, num_heads, head_dim] to get
        # the per-head views for multi-head attention.
        head_shape = (-1, self.num_heads, self.out_dim)
        x = batch.x
        batch.Q_h = self.Q(x).view(*head_shape)
        batch.K_h = self.K(x).view(*head_shape)
        if use_edge_attr:
            batch.E = self.E(edge_attr).view(*head_shape)
        batch.V_h = self.V(x).view(*head_shape)

        self.propagate_attention(batch, use_edge_attr=use_edge_attr)

        # Small epsilon avoids division by zero for nodes with no in-edges.
        h_out = batch.wV / (batch.Z + 1e-6)
        return h_out.view(-1, self.out_dim * self.num_heads)
# Register the layer class with GraphGym's layer registry under the key
# 'etransformer' (presumably so configs can refer to it by that name).
register_layer('etransformer', ETransformer)