Source code for opengt.layer.Exphormer

import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_

from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer



class ExphormerAttention(nn.Module):
    """Multi-head attention evaluated only along a given list of edges.

    For every edge ``(src, dst)`` an unnormalised weight
    ``exp(clamp(sum((K[src] * Q[dst] / sqrt(d)) * E[edge]), -5, 5))`` is
    computed per head, where ``d`` is the per-head feature size. The values
    of the source nodes are weighted with it, summed at the destination
    nodes and divided by the sum of the weights arriving at that node.

    Parameters:
        in_dim (int): Number of input node features.
        out_dim (int): Total number of output features over all heads; must
            be a multiple of ``num_heads``.
        num_heads (int): Number of attention heads.
        use_bias (bool): Whether the Q/K/E/V projections have a bias.
        dim_edge (int): Number of input edge features. Default: None, in
            which case ``in_dim`` is used.
        use_virt_nodes (bool): Whether ``batch.virt_h``,
            ``batch.virt_edge_index`` and ``batch.virt_edge_attr`` are
            appended to the nodes/edges before attention. Default: False.

    Raises:
        ValueError: If ``out_dim`` is not a multiple of ``num_heads``.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias, dim_edge=None,
                 use_virt_nodes=False):
        super().__init__()

        if out_dim % num_heads != 0:
            raise ValueError('hidden dimension is not dividable by the number of heads')
        # NOTE: from here on ``self.out_dim`` is the *per-head* feature size.
        self.out_dim = out_dim // num_heads
        self.num_heads = num_heads
        self.use_virt_nodes = use_virt_nodes
        self.use_bias = use_bias

        if dim_edge is None:
            dim_edge = in_dim

        self.Q = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.K = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)
        self.E = nn.Linear(dim_edge, self.out_dim * num_heads, bias=use_bias)
        self.V = nn.Linear(in_dim, self.out_dim * num_heads, bias=use_bias)

    def propagate_attention(self, batch, edge_index):
        """Compute edge weights and aggregate the weighted values per node.

        Reads ``batch.Q_h``, ``batch.K_h``, ``batch.V_h`` and ``batch.E`` (as
        stored by :meth:`forward`) and writes two new attributes:

        * ``batch.wV`` -- per destination node, the sum of the weighted
          source values; same shape as ``batch.V_h``.
        * ``batch.Z`` -- per destination node, the sum of the edge weights;
          shape ``(batch.V_h.size(0), num_heads, 1)``.

        Parameters:
            batch: Object carrying the projected tensors listed above.
            edge_index (Tensor): Row 0 holds the source node of every edge,
                row 1 the destination node.
        """
        # Cast each row exactly once and use the result for the gathers AND
        # the scatters, so an edge index that is not int64 is handled the
        # same way in every place it is used as an index.
        src_idx = edge_index[0].to(torch.long)
        dst_idx = edge_index[1].to(torch.long)

        src = batch.K_h[src_idx]    # (num edges) x num_heads x out_dim
        dest = batch.Q_h[dst_idx]   # (num edges) x num_heads x out_dim
        score = torch.mul(src, dest)  # element-wise multiplication

        # Scale scores by sqrt(d)
        score = score / np.sqrt(self.out_dim)

        # Use available edge features to modify the scores for edges
        score = torch.mul(score, batch.E)  # (num edges) x num_heads x out_dim
        # Clamp before exponentiating so the weight stays within
        # [exp(-5), exp(5)].
        score = torch.exp(score.sum(-1, keepdim=True).clamp(-5, 5))  # (num edges) x num_heads x 1

        # Apply attention score to each source node to create edge messages
        msg = batch.V_h[src_idx] * score  # (num edges) x num_heads x out_dim

        # Add up the messages in the destination nodes
        batch.wV = torch.zeros_like(batch.V_h)  # (num nodes) x num_heads x out_dim
        scatter(msg, dst_idx, dim=0, out=batch.wV, reduce='add')

        # Compute attention normalization coefficient
        batch.Z = score.new_zeros(batch.V_h.size(0), self.num_heads, 1)  # (num nodes) x num_heads x 1
        scatter(score, dst_idx, dim=0, out=batch.Z, reduce='add')

    def forward(self, batch):
        """Run attention over the expander edges stored on ``batch``.

        Reads ``batch.x``, ``batch.batch``, ``batch.expander_edge_index`` and
        ``batch.expander_edge_attr``; with ``use_virt_nodes`` also
        ``batch.virt_h``, ``batch.virt_edge_index`` and
        ``batch.virt_edge_attr``.

        Side effects: stores ``Q_h``, ``K_h``, ``E``, ``V_h``, ``wV`` and
        ``Z`` on ``batch`` and overwrites ``batch.virt_h`` with the output
        rows that follow the first ``batch.batch.shape[0]`` rows (an empty
        tensor when no extra rows were appended).

        Returns:
            Tensor: The first ``batch.batch.shape[0]`` output rows, each of
            size ``out_dim * num_heads`` (per-head size times heads).
        """
        edge_attr = batch.expander_edge_attr
        edge_index = batch.expander_edge_index
        h = batch.x
        num_node = batch.batch.shape[0]

        if self.use_virt_nodes:
            # Virtual nodes are appended after the real nodes, so they can be
            # split off again by position below.
            h = torch.cat([h, batch.virt_h], dim=0)
            edge_index = torch.cat([edge_index, batch.virt_edge_index], dim=1)
            edge_attr = torch.cat([edge_attr, batch.virt_edge_attr], dim=0)

        Q_h = self.Q(h)
        K_h = self.K(h)
        E = self.E(edge_attr)
        V_h = self.V(h)

        # Reshaping into [num_nodes, num_heads, feat_dim] to
        # get projections for multi-head attention
        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        batch.E = E.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(batch, edge_index)

        # Normalise; the epsilon avoids a division by zero for nodes that
        # receive no edge (their Z stays 0).
        h_out = batch.wV / (batch.Z + 1e-6)

        # Concatenate the heads again.
        h_out = h_out.view(-1, self.out_dim * self.num_heads)

        batch.virt_h = h_out[num_node:]
        h_out = h_out[:num_node]

        return h_out
register_layer('Exphormer', ExphormerAttention)


def get_activation(activation):
    """Look up an activation by name.

    Parameters:
        activation (str): One of 'relu', 'gelu', 'silu' or 'glu'.

    Returns:
        tuple: ``(factor, module)`` -- an integer factor and a freshly
        constructed activation module. The factor is 2 for the
        element-wise activations and 1 for 'glu'.

    Raises:
        ValueError: If ``activation`` is none of the supported names.
    """
    # (name, factor, module class); only the matching class is instantiated.
    known = (
        ('relu', 2, nn.ReLU),
        ('gelu', 2, nn.GELU),
        ('silu', 2, nn.SiLU),
        ('glu', 1, nn.GLU),
    )
    for name, factor, module_cls in known:
        if activation == name:
            return factor, module_cls()
    raise ValueError(f'activation function {activation} is not valid!')
class ExphormerFullLayer(nn.Module):
    """Exphormer attention followed by a two-layer feed-forward network.

    Adapted from https://github.com/hamed1375/Exphormer

    Both the attention step and the feed-forward step are followed by an
    optional residual connection and optional layer/batch normalization.

    Parameters:
        in_dim (int): Number of input features.
        out_dim (int): Number of output features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate. Default: 0.0.
        dim_edge (int): Number of edge features. Default: None.
        layer_norm (bool): Whether to use layer normalization. Default: False.
        batch_norm (bool): Whether to use batch normalization. Default: True.
        activation (str): Activation function name. Default: 'relu'.
        residual (bool): Whether to use residual connections. Default: True.
        use_bias (bool): Whether the attention projections use a bias.
            Default: False.
        use_virt_nodes (bool): Whether to use virtual nodes. Default: False.

    Input:
        batch.x (Tensor): Input node features.
        batch.batch (Tensor): Its first dimension is used as the node count.
        batch.expander_edge_attr (Tensor): Edge features for attention.
        batch.expander_edge_index (Tensor): Edge indices for attention.
        batch.virt_h, batch.virt_edge_index, batch.virt_edge_attr (Tensor):
            Virtual node features and edges; read only with
            ``use_virt_nodes``.

    Output:
        batch.x (Tensor): Node features after applying the layer.
    """

    def __init__(self, in_dim, out_dim, num_heads,
                 dropout=0.0,
                 dim_edge=None,
                 layer_norm=False, batch_norm=True,
                 activation = 'relu',
                 residual=True, use_bias=False, use_virt_nodes=False):
        super().__init__()

        self.in_channels = in_dim
        self.out_channels = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        self.attention = ExphormerAttention(
            in_dim, out_dim, num_heads,
            use_bias=use_bias,
            dim_edge=dim_edge,
            use_virt_nodes=use_virt_nodes,
        )

        # NOTE: this projection is constructed (and therefore part of the
        # module's parameters) but is not applied in ``forward``.
        self.O_h = nn.Linear(out_dim, out_dim)

        # Normalization after the attention step.
        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm1_h = nn.BatchNorm1d(out_dim)

        # Feed-forward network: widen to 2 * out_dim, activate, project back.
        # ``width_factor`` is the input-width multiplier that
        # ``get_activation`` returns for the chosen activation.
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        width_factor, self.activation_fn = get_activation(activation=activation)
        self.FFN_h_layer2 = nn.Linear(out_dim * width_factor, out_dim)

        # Normalization after the feed-forward step.
        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)
        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim)

    def _add_and_norm(self, h, skip, stage):
        """Apply the optional residual and normalizations of step ``stage``."""
        if self.residual:
            h = skip + h
        if self.layer_norm:
            h = getattr(self, 'layer_norm{}_h'.format(stage))(h)
        if self.batch_norm:
            h = getattr(self, 'batch_norm{}_h'.format(stage))(h)
        return h

    def forward(self, batch):
        """Update ``batch.x`` in place and return ``batch``."""
        skip = batch.x

        # Attention step: heads come back already concatenated.
        attended = self.attention(batch)
        h = attended.view(-1, self.out_channels)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self._add_and_norm(h, skip, 1)

        # Feed-forward step.
        skip = h
        hidden = self.activation_fn(self.FFN_h_layer1(h))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        h = self._add_and_norm(self.FFN_h_layer2(hidden), skip, 2)

        batch.x = h
        return batch

    def __repr__(self):
        template = '{}(in_channels={}, out_channels={}, heads={}, residual={})'
        return template.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_heads,
            self.residual,
        )
# Register the full layer class with GraphGym's layer registry under the
# name 'ExphormerLayer'.
register_layer('ExphormerLayer', ExphormerFullLayer)