import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter, scatter_max, scatter_add
from opengt.utils import negate_edge_index
def pyg_softmax(src, index, num_nodes=None):
    r"""Evaluate a softmax separately inside each index group.

    The entries of :attr:`src` are grouped along dimension 0 according to
    :attr:`index`; every group is then normalised with its own softmax.

    Args:
        src (Tensor): The values to normalise.
        index (LongTensor): Group id of each entry of :attr:`src` along
            dimension 0.
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    group_count = maybe_num_nodes(index, num_nodes)

    # Shift every entry by its group's maximum before exponentiating, so the
    # largest exponent in each group is exp(0) = 1.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=group_count)
    exp_shifted = (src - group_max[index]).exp()

    # Per-group normaliser, gathered back to one value per entry; a small
    # epsilon is added to the denominator.
    group_total = scatter_add(exp_shifted, index, dim=0, dim_size=group_count)
    return exp_shifted / (group_total[index] + 1e-16)
class MultiHeadAttention2Layer(nn.Module):
    """Multi-Head Graph Attention Layer.

    Ported to PyG and modified compared to the original repo:
    https://github.com/DevinKreuzer/SAN/blob/main/layers/graph_transformer_layer.py

    Args:
        gamma: Accepted by the constructor but never read; the weight that
            balances real against fake edges is always a learnable scalar
            initialised to 0.5.
        in_dim: Width of every input that gets projected (``batch.x``,
            ``batch.edge_attr`` and the fake-edge embedding).
        out_dim: Feature width of a single attention head.
        num_heads: Number of attention heads.
        full_graph: If true, a second set of Q/K/E projections is created
            and attention is additionally computed over the edges returned
            by ``negate_edge_index``.
        fake_edge_emb: Callable applied to a single zero index to obtain the
            embedding shared by all fake edges. Only stored when
            ``full_graph`` is true.
        use_bias: Whether the linear projections carry a bias term.
    """

    def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph,
                 fake_edge_emb, use_bias):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.gamma = nn.Parameter(torch.tensor(0.5, dtype=float),
                                  requires_grad=True)
        self.full_graph = full_graph

        def new_projection():
            # Every projection maps in_dim onto all heads at once.
            return nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)

        # The creation order below fixes both the order of the registered
        # parameters and the order in which initial weights are drawn.
        self.Q = new_projection()
        self.K = new_projection()
        self.E = new_projection()

        if self.full_graph:
            self.Q_2 = new_projection()
            self.K_2 = new_projection()
            self.E_2 = new_projection()
            self.fake_edge_emb = fake_edge_emb

        self.V = new_projection()

    def propagate_attention(self, batch):
        """Compute attention weights and aggregate messages into ``batch.wV``.

        Reads the per-head projections that :meth:`forward` stored on
        ``batch`` and writes ``batch.wV`` with the same shape as
        ``batch.V_h``: (num nodes in batch) x num_heads x out_dim.
        """
        real_src = batch.edge_index[0]
        real_dst = batch.edge_index[1]

        # Element-wise K * Q scaled by sqrt(d), then modulated by the edge
        # features: (num real edges) x num_heads x out_dim.
        real_logits = batch.K_h[real_src] * batch.Q_h[real_dst]
        real_logits = real_logits / np.sqrt(self.out_dim)
        real_logits = real_logits * batch.E

        # Softmax over the incoming edges of every destination node:
        # (num real edges) x num_heads x 1.
        real_attn = pyg_softmax(real_logits.sum(-1, keepdim=True), real_dst)

        # (source index, destination index, attention weight) per edge set,
        # in the order they are accumulated below.
        edge_sets = []

        if self.full_graph:
            fake_edge_index = negate_edge_index(batch.edge_index, batch.batch)
            fake_src = fake_edge_index[0]
            fake_dst = fake_edge_index[1]

            fake_logits = batch.K_2h[fake_src] * batch.Q_2h[fake_dst]
            fake_logits = fake_logits / np.sqrt(self.out_dim)
            # E_2 is 1 x num_heads x out_dim and is broadcast over dim=0.
            fake_logits = fake_logits * batch.E_2

            # (num fake edges) x num_heads x 1
            fake_attn = pyg_softmax(fake_logits.sum(-1, keepdim=True),
                                    fake_dst)

            # Split the weight between real and fake edges with gamma.
            real_attn = real_attn / (self.gamma + 1)
            fake_attn = self.gamma * fake_attn / (self.gamma + 1)

            edge_sets.append((real_src, real_dst, real_attn))
            edge_sets.append((fake_src, fake_dst, fake_attn))
        else:
            edge_sets.append((real_src, real_dst, real_attn))

        # Weight the value vector of each source node and sum the resulting
        # messages at the destination nodes (real edges first, then fake).
        batch.wV = torch.zeros_like(batch.V_h)
        for src_idx, dst_idx, weight in edge_sets:
            scatter(batch.V_h[src_idx] * weight, dst_idx, dim=0,
                    out=batch.wV, reduce='add')

    def forward(self, batch):
        """Project the inputs, run attention and return ``batch.wV``.

        The per-head projections are stored on ``batch`` as ``Q_h``,
        ``K_h``, ``E``, ``V_h`` (plus ``Q_2h``, ``K_2h``, ``E_2`` when
        ``full_graph`` is set) before :meth:`propagate_attention` is called.
        """

        def split_heads(flat):
            # [rows, num_heads * out_dim] -> [rows, num_heads, out_dim]
            return flat.view(-1, self.num_heads, self.out_dim)

        batch.Q_h = split_heads(self.Q(batch.x))
        batch.K_h = split_heads(self.K(batch.x))
        batch.E = split_heads(self.E(batch.edge_attr))

        if self.full_graph:
            batch.Q_2h = split_heads(self.Q_2(batch.x))
            batch.K_2h = split_heads(self.K_2(batch.x))
            # One embedding used for all fake edges; shape: 1 x emb_dim
            dummy_edge = self.fake_edge_emb(batch.edge_index.new_zeros(1))
            batch.E_2 = split_heads(self.E_2(dummy_edge))

        batch.V_h = split_heads(self.V(batch.x))

        self.propagate_attention(batch)
        return batch.wV
class SAN2Layer(nn.Module):
    """Modified GraphTransformerLayer from SAN.

    Ported to PyG from original repo:
    https://github.com/DevinKreuzer/SAN/blob/main/layers/graph_transformer_layer.py

    The layer applies multi-head attention, an output projection and a
    two-layer feed-forward network, each optionally followed by a residual
    connection, layer normalisation and batch normalisation.

    Args:
        gamma: Forwarded unchanged to the attention sub-layer.
        in_dim: Width of ``batch.x`` on input.
        out_dim: Width of ``batch.x`` on output. Must be a positive multiple
            of ``num_heads`` because each head receives
            ``out_dim // num_heads`` features.
        num_heads: Number of attention heads (at least 1).
        full_graph: Forwarded to the attention sub-layer.
        fake_edge_emb: Forwarded to the attention sub-layer.
        dropout: Dropout probability used after attention and inside the
            feed-forward network.
        layer_norm: Whether to apply ``nn.LayerNorm`` after each residual.
        batch_norm: Whether to apply ``nn.BatchNorm1d`` after each residual.
        residual: Whether to add the residual connections. The first one
            adds the layer input to an ``out_dim``-wide tensor, so
            ``in_dim`` is expected to equal ``out_dim`` in that case (this
            is not validated here).
        use_bias: Forwarded to the attention sub-layer.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1 or ``out_dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, gamma, in_dim, out_dim, num_heads, full_graph,
                 fake_edge_emb, dropout=0.0,
                 layer_norm=False, batch_norm=True,
                 residual=True, use_bias=False):
        super().__init__()

        # Fail early: with a non-divisible out_dim the attention output has
        # num_heads * (out_dim // num_heads) < out_dim features per node and
        # the later view(-1, out_dim) would mix features of different nodes
        # or fail with an obscure shape error.
        if num_heads < 1 or out_dim % num_heads != 0:
            raise ValueError(
                f"out_dim ({out_dim}) must be a positive multiple of "
                f"num_heads ({num_heads})")

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.attention = MultiHeadAttention2Layer(gamma=gamma,
                                                  in_dim=in_dim,
                                                  out_dim=out_dim // num_heads,
                                                  num_heads=num_heads,
                                                  full_graph=full_graph,
                                                  fake_edge_emb=fake_edge_emb,
                                                  use_bias=use_bias)

        # Mixes the concatenated head outputs.
        self.O_h = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm1_h = nn.BatchNorm1d(out_dim)

        # FFN for h
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim)

    def forward(self, batch):
        """Update ``batch.x`` in place and return the same ``batch``."""
        h_in1 = batch.x  # for first residual connection

        # Multi-head attention output, then concatenate the heads.
        h = self.attention(batch).view(-1, self.out_channels)

        h = F.dropout(h, self.dropout, training=self.training)
        h = self.O_h(h)

        if self.residual:
            h = h_in1 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm1_h(h)

        if self.batch_norm:
            h = self.batch_norm1_h(h)

        h_in2 = h  # for second residual connection

        # FFN for h
        h = self.FFN_h_layer1(h)
        h = F.relu(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_h_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)

        if self.batch_norm:
            h = self.batch_norm2_h(h)

        batch.x = h
        return batch

    def __repr__(self):
        return (f'{self.__class__.__name__}('
                f'in_channels={self.in_channels}, '
                f'out_channels={self.out_channels}, '
                f'heads={self.num_heads}, '
                f'residual={self.residual})')