import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
# (stray "[docs]" source-link marker left by the documentation renderer; not code)
class ExphormerAttention(nn.Module):
    """Multi-head attention restricted to a given edge set.

    For every edge ``(src, dst)`` an unnormalised weight is computed per head
    as ``exp(clamp(sum(K[src] * Q[dst] / sqrt(d) * E[edge]), -5, 5))`` where
    ``d`` is the per-head width.  Messages ``V[src] * weight`` are summed at
    the destination node and divided by the summed weights (plus ``1e-6``).

    Args:
        in_dim (int): Width of the incoming node features.
        out_dim (int): Total output width; split evenly across the heads.
        num_heads (int): Number of attention heads.
        use_bias (bool): Whether the Q/K/E/V projections carry a bias.
        dim_edge (int, optional): Width of the edge features.  Falls back to
            ``in_dim`` when ``None``.
        use_virt_nodes (bool): When true, ``forward`` appends
            ``batch.virt_h`` / ``batch.virt_edge_index`` /
            ``batch.virt_edge_attr`` to the nodes and edges before attending.

    Raises:
        ValueError: If ``out_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, dim_edge=None, use_virt_nodes=False):
        super().__init__()

        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads')

        # NOTE: ``self.out_dim`` holds the *per-head* width, not the total.
        self.out_dim = out_dim // num_heads
        self.num_heads = num_heads
        self.use_virt_nodes = use_virt_nodes
        self.use_bias = use_bias

        if dim_edge is None:
            dim_edge = in_dim

        self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.E = nn.Linear(dim_edge, self.out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)

    def propagate_attention(self, batch, edge_index):
        """Aggregate attention-weighted values along ``edge_index``.

        Reads ``batch.Q_h``, ``batch.K_h``, ``batch.V_h`` and ``batch.E`` and
        writes two new attributes on ``batch``:

        * ``batch.wV`` -- weighted value sum per destination node, same shape
          as ``batch.V_h``.
        * ``batch.Z`` -- summed weights per destination node, shape
          ``(V_h.size(0), num_heads, 1)``.

        Args:
            batch: Object carrying the projected tensors listed above.
            edge_index (Tensor): Integer tensor whose row 0 holds source node
                indices and row 1 destination node indices.
        """
        # Cast once so that both the gathers and the scatter-adds below see
        # int64 indices, whatever integer dtype the caller supplied.
        edge_index = edge_index.to(torch.long)
        src_idx, dst_idx = edge_index[0], edge_index[1]

        keys = batch.K_h[src_idx]      # (num edges) x num_heads x out_dim
        queries = batch.Q_h[dst_idx]   # (num edges) x num_heads x out_dim

        # Element-wise product scaled by sqrt(d), then modulated by the
        # projected edge features.
        score = torch.mul(keys, queries) / math.sqrt(self.out_dim)
        score = torch.mul(score, batch.E)

        # Collapse the feature axis; clamping bounds the exponent so exp()
        # stays within [e^-5, e^5].
        score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5))  # (num edges) x num_heads x 1

        # Edge messages: source values weighted by the attention score.
        msg = batch.V_h[src_idx] * score  # (num edges) x num_heads x out_dim

        # Sum messages into their destination nodes.
        batch.wV = torch.zeros_like(batch.V_h)
        batch.wV.index_add_(0, dst_idx, msg)

        # Normalisation coefficient: summed scores per destination node.
        batch.Z = score.new_zeros(batch.V_h.size(0), self.num_heads, 1)
        batch.Z.index_add_(0, dst_idx, score)

    def forward(self, batch):
        """Run attention over the expander edges and return new node features.

        Reads ``batch.x``, ``batch.expander_edge_index`` and
        ``batch.expander_edge_attr`` (plus the ``virt_*`` attributes when
        ``use_virt_nodes`` is set).  As side effects it stores ``Q_h``,
        ``K_h``, ``E``, ``V_h``, ``wV`` and ``Z`` on ``batch`` and overwrites
        ``batch.virt_h`` with the output rows that follow the real nodes
        (an empty tensor when no virtual nodes were appended).

        Returns:
            Tensor: One row per row of ``batch.x`` with
            ``num_heads * per-head width`` columns.
        """
        edge_attr = batch.expander_edge_attr
        edge_index = batch.expander_edge_index
        h = batch.x
        # Number of real nodes, taken before any virtual nodes are appended.
        num_node = h.size(0)

        if self.use_virt_nodes:
            h = torch.cat([h, batch.virt_h], dim=0)
            edge_index = torch.cat([edge_index, batch.virt_edge_index], dim=1)
            edge_attr = torch.cat([edge_attr, batch.virt_edge_attr], dim=0)

        Q_h = self.Q(h)
        K_h = self.K(h)
        E = self.E(edge_attr)
        V_h = self.V(h)

        # Reshape to [rows, num_heads, per-head width] for multi-head attention.
        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        batch.E = E.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(batch, edge_index)

        # The epsilon keeps nodes without incoming edges (Z == 0) finite.
        h_out = batch.wV / (batch.Z + 1e-6)
        h_out = h_out.view(-1, self.out_dim * self.num_heads)

        # Rows past the real nodes belong to the virtual nodes.
        batch.virt_h = h_out[num_node:]
        return h_out[:num_node]
# Register the attention module with GraphGym's layer registry under the key 'Exphormer'.
register_layer('Exphormer', ExphormerAttention)
def get_activation(activation):
    """Look up an activation module by name.

    Args:
        activation (str): One of ``'relu'``, ``'gelu'``, ``'silu'`` or
            ``'glu'``.

    Returns:
        tuple: ``(factor, module)`` where ``module`` is a freshly constructed
        activation layer and ``factor`` is ``1`` for ``'glu'`` and ``2`` for
        every other supported name.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # (name, factor, constructor) -- scanned with ``==`` so that any value,
    # hashable or not, falls through to the ValueError below.
    known = (
        ('relu', 2, nn.ReLU),
        ('gelu', 2, nn.GELU),
        ('silu', 2, nn.SiLU),
        ('glu', 1, nn.GLU),
    )
    for name, factor, build in known:
        if activation == name:
            return factor, build()
    raise ValueError(f'activation function {activation} is not valid!')
# (stray "[docs]" source-link marker left by the documentation renderer; not code)
class ExphormerFullLayer(nn.Module):
    """Exphormer attention + FFN

    Adapted from https://github.com/hamed1375/Exphormer

    The layer applies the attention sub-layer, dropout, an optional residual
    connection and optional normalisation, followed by a two-layer
    feed-forward network with the same dropout / residual / normalisation
    treatment.  When both ``layer_norm`` and ``batch_norm`` are enabled, layer
    normalisation is applied first and batch normalisation second.

    Parameters:
        in_dim (int): Number of input features.
        out_dim (int): Number of output features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate. Default: 0.0.
        dim_edge (int): Number of edge features. Default: None.
        layer_norm (bool): Whether to use layer normalization. Default: False.
        batch_norm (bool): Whether to use batch normalization. Default: True.
        activation (str): Activation function. Default: 'relu'.
        residual (bool): Whether to use residual connection. Default: True.
        use_bias (bool): Whether to use bias in the attention projections.
            Default: False.
        use_virt_nodes (bool): Whether to use virtual nodes. Default: False.

    Input:
        batch.x (Tensor): Input node features.
        batch.expander_edge_attr (Tensor): Edge features for attention.
        batch.expander_edge_index (Tensor): Edge indices for attention.
        batch.virt_h, batch.virt_edge_index, batch.virt_edge_attr (Tensor):
            Virtual node features and their edges; read by the attention
            sub-layer only when ``use_virt_nodes`` is true.
        The attention sub-layer may read further attributes of ``batch``;
        see ``ExphormerAttention.forward``.

    Output:
        batch.x (Tensor): Output node features after applying the Exphormer layer.
        The same ``batch`` object is returned; the attention sub-layer also
        overwrites ``batch.virt_h``.
    """

    def __init__(self, in_dim, out_dim, num_heads,
                 dropout=0.0,
                 dim_edge=None,
                 layer_norm=False, batch_norm=True,
                 activation='relu',
                 residual=True, use_bias=False, use_virt_nodes=False):
        super().__init__()

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.attention = ExphormerAttention(in_dim, out_dim, num_heads,
                                            use_bias=use_bias,
                                            dim_edge=dim_edge,
                                            use_virt_nodes=use_virt_nodes)

        # NOTE: O_h is created (and therefore appears in state_dict) but
        # forward() does not apply it; it is left in place so parameter names
        # stay unchanged.
        self.O_h = nn.Linear(out_dim, out_dim)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm1_h = nn.BatchNorm1d(out_dim)

        # FFN for h: widen to 2 * out_dim, activate, project back.  ``factor``
        # is the multiple of out_dim that the chosen activation hands to the
        # second layer.
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        factor, self.activation_fn = get_activation(activation=activation)
        self.FFN_h_layer2 = nn.Linear(out_dim * factor, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim)

    def forward(self, batch):
        """Apply attention and the feed-forward block, updating ``batch.x``.

        Args:
            batch: Object exposing ``x`` plus whatever the attention
                sub-layer reads from it.

        Returns:
            The same ``batch`` object with ``x`` replaced by the layer output.
        """
        h = batch.x
        h_in1 = h  # for first residual connection

        # Multi-head attention output, flattened to (rows, out_channels).
        h_attn_out = self.attention(batch)
        h = h_attn_out.view(-1, self.out_channels)

        h = F.dropout(h, self.dropout, training=self.training)

        if self.residual:
            h = h_in1 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm1_h(h)

        if self.batch_norm:
            h = self.batch_norm1_h(h)

        h_in2 = h  # for second residual connection

        # FFN for h
        h = self.FFN_h_layer1(h)
        h = self.activation_fn(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_h_layer2(h)

        if self.residual:
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)

        if self.batch_norm:
            h = self.batch_norm2_h(h)

        batch.x = h
        return batch

    def __repr__(self):
        """Return a short summary of the layer's channel sizes, heads and residual flag."""
        return (
            f'{self.__class__.__name__}('
            f'in_channels={self.in_channels}, '
            f'out_channels={self.out_channels}, '
            f'heads={self.num_heads}, '
            f'residual={self.residual})'
        )
# Register the full attention + FFN layer with GraphGym's layer registry under the key 'ExphormerLayer'.
register_layer('ExphormerLayer', ExphormerFullLayer)