Source code for opengt.layer.grit_layer
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric as pyg
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter, scatter_max, scatter_add
from opengt.utils import negate_edge_index
from torch_geometric.graphgym.register import *
import opt_einsum as oe
from yacs.config import CfgNode as CN
import warnings
[docs]
def pyg_softmax(src, index, num_nodes=None):
    r"""Sparsely evaluated softmax over groups of rows.

    Rows of :attr:`src` that share the same value in :attr:`index` form one
    group; the softmax is taken independently inside every group along the
    first dimension.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): Group id of every row of :attr:`src`.
        num_nodes (int, optional): The number of groups, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)

    :rtype: :class:`Tensor`
    """
    group_count = maybe_num_nodes(index, num_nodes)

    # Subtract the per-group maximum before exponentiating so that the
    # largest exponent inside every group is exactly zero.
    group_max, _ = scatter_max(src, index, dim=0, dim_size=group_count)
    shifted_exp = (src - group_max[index]).exp()

    # Per-group normaliser; the small constant guards the division.
    group_sum = scatter_add(shifted_exp, index, dim=0, dim_size=group_count)
    return shifted_exp / (group_sum[index] + 1e-16)
[docs]
class MultiHeadAttentionLayerGritSparse(nn.Module):
    """
    Proposed Attention Computation for GRIT

    Sparse multi-head attention evaluated only on the pairs listed in
    ``batch.edge_index``. ``out_dim`` is the width of a *single* head; the
    linear projections produce ``out_dim * num_heads`` features.

    Parameters:
        in_dim (int): Number of input node/edge features.
        out_dim (int): Number of features per attention head.
        num_heads (int): Number of attention heads.
        use_bias (bool): Whether the key and value projections carry a bias
            (the query and edge projections always do).
        clamp (float or None): Attention logits are clamped to
            ``[-|clamp|, |clamp|]``; ``None`` disables clamping.
        dropout (float): Dropout rate applied to the attention weights.
        act (str or None): Key into ``act_dict``; ``None`` means identity.
        edge_enhance (bool): Whether to add the attention-weighted edge
            representation to the node output.
        sqrt_relu, signed_sqrt, cfg, **kwargs: Accepted for signature
            compatibility; they are not read by this class.
    """

    def __init__(self, in_dim, out_dim, num_heads, use_bias,
                 clamp=5., dropout=0., act=None,
                 edge_enhance=True,
                 sqrt_relu=False,
                 signed_sqrt=True,
                 cfg=None,
                 **kwargs):
        super().__init__()

        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout)
        self.clamp = np.abs(clamp) if clamp is not None else None
        self.edge_enhance = edge_enhance

        self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=True)
        self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        # The edge projection is twice as wide: one half is used as a
        # multiplicative gate, the other half as an additive bias.
        self.E = nn.Linear(in_dim, out_dim * num_heads * 2, bias=True)
        self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias)
        nn.init.xavier_normal_(self.Q.weight)
        nn.init.xavier_normal_(self.K.weight)
        nn.init.xavier_normal_(self.E.weight)
        nn.init.xavier_normal_(self.V.weight)

        # Per-head vector that reduces an edge representation to one logit.
        self.Aw = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, 1), requires_grad=True)
        nn.init.xavier_normal_(self.Aw)

        if act is None:
            self.act = nn.Identity()
        else:
            self.act = act_dict[act]()

        if self.edge_enhance:
            self.VeRow = nn.Parameter(torch.zeros(self.out_dim, self.num_heads, self.out_dim), requires_grad=True)
            nn.init.xavier_normal_(self.VeRow)

    def propagate_attention(self, batch):
        """Compute attention on ``batch.edge_index`` and aggregate messages.

        Reads ``batch.Q_h``, ``batch.K_h``, ``batch.V_h`` and, if present,
        ``batch.E``. Writes ``batch.attn``, ``batch.wV`` and, when edge
        features are present, ``batch.wE``.
        """
        src = batch.K_h[batch.edge_index[0]]   # (num relative) x num_heads x out_dim
        dest = batch.Q_h[batch.edge_index[1]]  # (num relative) x num_heads x out_dim
        score = src + dest  # element-wise addition of key and query

        edge_proj = batch.get("E", None)
        if edge_proj is not None:
            batch.E = edge_proj.view(-1, self.num_heads, self.out_dim * 2)
            E_w, E_b = batch.E[:, :, :self.out_dim], batch.E[:, :, self.out_dim:]
            # (num relative) x num_heads x out_dim
            score = score * E_w
            # Signed square root: sqrt(relu(x)) - sqrt(relu(-x)).
            score = torch.sqrt(torch.relu(score)) - torch.sqrt(torch.relu(-score))
            score = score + E_b

        score = self.act(score)
        e_t = score

        # output edge
        if edge_proj is not None:
            batch.wE = score.flatten(1)

        # final attn
        score = oe.contract("ehd, dhc->ehc", score, self.Aw, backend="torch")
        if self.clamp is not None:
            score = torch.clamp(score, min=-self.clamp, max=self.clamp)

        # Normalise over all pairs that share the same edge_index[1] entry.
        score = pyg_softmax(score, batch.edge_index[1])  # (num relative) x num_heads x 1
        score = self.dropout(score)
        batch.attn = score

        # Aggregate with Attn-Score
        msg = batch.V_h[batch.edge_index[0]] * score  # (num relative) x num_heads x out_dim
        batch.wV = torch.zeros_like(batch.V_h)  # (num nodes in batch) x num_heads x out_dim
        scatter(msg, batch.edge_index[1], dim=0, out=batch.wV, reduce='add')

        if self.edge_enhance and edge_proj is not None:
            # dim_size pins the result to one row per node. Without it the
            # scatter output stops at max(edge_index[1]) + 1 rows and the
            # addition below fails when the last nodes receive no edge.
            rowV = scatter(e_t * score, batch.edge_index[1], dim=0,
                           dim_size=batch.wV.size(0), reduce="add")
            rowV = oe.contract("nhd, dhc -> nhc", rowV, self.VeRow, backend="torch")
            batch.wV = batch.wV + rowV

    def forward(self, batch):
        """Project ``batch.x`` / ``batch.edge_attr`` and run the attention.

        Returns:
            tuple: ``(h_out, e_out)`` where ``h_out`` is ``batch.wV`` and
            ``e_out`` is ``batch.wE`` if it exists on the batch, else ``None``.
        """
        Q_h = self.Q(batch.x)
        K_h = self.K(batch.x)
        V_h = self.V(batch.x)

        if batch.get("edge_attr", None) is not None:
            batch.E = self.E(batch.edge_attr)
        else:
            batch.E = None

        # Split the flat projections into heads.
        batch.Q_h = Q_h.view(-1, self.num_heads, self.out_dim)
        batch.K_h = K_h.view(-1, self.num_heads, self.out_dim)
        batch.V_h = V_h.view(-1, self.num_heads, self.out_dim)

        self.propagate_attention(batch)

        h_out = batch.wV
        e_out = batch.get('wE', None)
        return h_out, e_out
[docs]
@register_layer("GritTransformer")
class GritTransformerLayer(nn.Module):
    """
    Proposed Transformer Layer for GRIT
    Adapted from https://github.com/LiamMa/GRIT

    Parameters:
        in_dim (int): Number of input features.
        out_dim (int): Number of output features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate.
        attn_dropout (float): Attention dropout rate.
        layer_norm (bool): Whether to use layer normalization.
        batch_norm (bool): Whether to use batch normalization.
        residual (bool): Whether to use residual connections.
        act (str): Activation function ('relu', 'gelu', etc.).
        norm_e (bool): Whether to normalize edge features.
        O_e (bool): Whether to use edge features in the output.
        cfg (dict-like, optional): Extra options read with ``.get``:
            ``update_e`` (default ``True``), ``bn_momentum`` (default ``0.1``),
            ``bn_no_runner`` (default ``False``), ``rezero`` (default
            ``False``) and a nested ``attn`` mapping. The object is only
            read, never modified.

    Input:
        batch.x (torch.Tensor): Input node features.
        batch.edge_index (torch.Tensor): Edge indices of the graph.
        batch.edge_attr (torch.Tensor): Edge attributes.

    Output:
        batch.x (torch.Tensor): Output node features after applying the GritTransformer layer.
        batch.edge_attr (torch.Tensor): Updated edge attributes.
    """

    def __init__(self, in_dim, out_dim, num_heads,
                 dropout=0.0,
                 attn_dropout=0.0,
                 layer_norm=False, batch_norm=True,
                 residual=True,
                 act='relu',
                 norm_e=True,
                 O_e=True,
                 cfg=None,
                 **kwargs):
        super().__init__()

        if cfg is None:
            cfg = {}

        self.debug = False
        self.in_channels = in_dim
        self.out_channels = out_dim
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.residual = residual
        self.layer_norm = layer_norm
        self.batch_norm = batch_norm

        # -------
        self.update_e = cfg.get("update_e", True)
        # Fall back to the BatchNorm1d defaults (momentum 0.1, running
        # statistics tracked) when the config does not specify them.
        self.bn_momentum = cfg.get("bn_momentum", 0.1)
        self.bn_no_runner = cfg.get("bn_no_runner", False)
        self.rezero = cfg.get("rezero", False)

        self.act = act_dict[act]() if act is not None else nn.Identity()

        # Use a local mapping instead of writing a default back into the
        # caller's config object.
        attn_cfg = cfg.get("attn", None)
        if attn_cfg is None:
            attn_cfg = {}

        self.use_attn = attn_cfg.get("use", True)
        self.deg_scaler = attn_cfg.get("deg_scaler", True)

        self.attention = MultiHeadAttentionLayerGritSparse(
            in_dim=in_dim,
            out_dim=out_dim // num_heads,
            num_heads=num_heads,
            use_bias=attn_cfg.get("use_bias", False),
            dropout=attn_dropout,
            clamp=attn_cfg.get("clamp", 5.),
            act=attn_cfg.get("act", "relu"),
            edge_enhance=attn_cfg.get("edge_enhance", True),
            sqrt_relu=attn_cfg.get("sqrt_relu", False),
            signed_sqrt=attn_cfg.get("signed_sqrt", False),
            scaled_attn=attn_cfg.get("scaled_attn", False),
            no_qk=attn_cfg.get("no_qk", False),
        )

        if attn_cfg.get('graphormer_attn', False):  # not used in GRIT
            # NOTE(review): MultiHeadAttentionLayerGraphormerSparse is not
            # defined in the visible part of this module -- confirm it is
            # importable before enabling `graphormer_attn`.
            self.attention = MultiHeadAttentionLayerGraphormerSparse(
                in_dim=in_dim,
                out_dim=out_dim // num_heads,
                num_heads=num_heads,
                use_bias=attn_cfg.get("use_bias", False),
                dropout=attn_dropout,
                clamp=attn_cfg.get("clamp", 5.),
                act=attn_cfg.get("act", "relu"),
                edge_enhance=True,
                sqrt_relu=attn_cfg.get("sqrt_relu", False),
                signed_sqrt=attn_cfg.get("signed_sqrt", False),
                scaled_attn=attn_cfg.get("scaled_attn", False),
                no_qk=attn_cfg.get("no_qk", False),
            )

        # Width of the concatenated heads; differs from out_dim when
        # out_dim is not divisible by num_heads.
        heads_dim = out_dim // num_heads * num_heads

        self.O_h = nn.Linear(heads_dim, out_dim)
        if O_e:
            self.O_e = nn.Linear(heads_dim, out_dim)
        else:
            self.O_e = nn.Identity()

        # -------- Deg Scaler Option ------
        if self.deg_scaler:
            self.deg_coef = nn.Parameter(torch.zeros(1, heads_dim, 2))
            nn.init.xavier_normal_(self.deg_coef)

        if self.layer_norm:
            self.layer_norm1_h = nn.LayerNorm(out_dim)
            self.layer_norm1_e = nn.LayerNorm(out_dim) if norm_e else nn.Identity()

        if self.batch_norm:
            # when the batch_size is really small, use smaller momentum to avoid bad mini-batch leading to extremely bad val/test loss (NaN)
            self.batch_norm1_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=self.bn_momentum)
            self.batch_norm1_e = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=self.bn_momentum) if norm_e else nn.Identity()

        # FFN for h
        self.FFN_h_layer1 = nn.Linear(out_dim, out_dim * 2)
        self.FFN_h_layer2 = nn.Linear(out_dim * 2, out_dim)

        if self.layer_norm:
            self.layer_norm2_h = nn.LayerNorm(out_dim)

        if self.batch_norm:
            self.batch_norm2_h = nn.BatchNorm1d(out_dim, track_running_stats=not self.bn_no_runner, eps=1e-5, momentum=self.bn_momentum)

        if self.rezero:
            # Learnable residual gates, initialised to zero.
            self.alpha1_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha2_h = nn.Parameter(torch.zeros(1, 1))
            self.alpha1_e = nn.Parameter(torch.zeros(1, 1))

    def forward(self, batch):
        """Apply attention, optional degree scaling, and the node FFN.

        Updates ``batch.x`` and ``batch.edge_attr`` in place and returns the
        same ``batch`` object.
        """
        h = batch.x
        num_nodes = batch.num_nodes

        h_in1 = h  # for first residual connection
        e_in1 = batch.get("edge_attr", None)
        e = None

        # multi-head attention out
        h_attn_out, e_attn_out = self.attention(batch)

        h = h_attn_out.view(num_nodes, -1)
        h = F.dropout(h, self.dropout, training=self.training)

        # degree scaler
        if self.deg_scaler:
            # The log-degree is only needed here, so it is computed (and its
            # possible on-the-fly warning emitted) only when scaling is on.
            log_deg = get_log_deg(batch)
            h = torch.stack([h, h * log_deg], dim=-1)
            h = (h * self.deg_coef).sum(dim=-1)

        h = self.O_h(h)
        if e_attn_out is not None:
            e = e_attn_out.flatten(1)
            e = F.dropout(e, self.dropout, training=self.training)
            e = self.O_e(e)

        if self.residual:
            if self.rezero:
                h = h * self.alpha1_h
            h = h_in1 + h  # residual connection
            if e is not None:
                if self.rezero:
                    e = e * self.alpha1_e
                e = e + e_in1

        if self.layer_norm:
            h = self.layer_norm1_h(h)
            if e is not None:
                e = self.layer_norm1_e(e)

        if self.batch_norm:
            h = self.batch_norm1_h(h)
            if e is not None:
                e = self.batch_norm1_e(e)

        # FFN for h
        h_in2 = h  # for second residual connection
        h = self.FFN_h_layer1(h)
        h = self.act(h)
        h = F.dropout(h, self.dropout, training=self.training)
        h = self.FFN_h_layer2(h)

        if self.residual:
            if self.rezero:
                h = h * self.alpha2_h
            h = h_in2 + h  # residual connection

        if self.layer_norm:
            h = self.layer_norm2_h(h)

        if self.batch_norm:
            h = self.batch_norm2_h(h)

        batch.x = h
        # Either publish the updated edge features or restore the input ones.
        if self.update_e:
            batch.edge_attr = e
        else:
            batch.edge_attr = e_in1

        return batch

    def __repr__(self):
        return '{}(in_channels={}, out_channels={}, heads={}, residual={})\n[{}]'.format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels, self.num_heads, self.residual,
            super().__repr__(),
        )
@torch.no_grad()
def get_log_deg(batch):
    """Return ``log(degree + 1)`` for every node as a ``(num_nodes, 1)`` tensor.

    Sources are tried in this order:

    1. ``batch.log_deg`` -- used as stored.
    2. ``batch.deg`` -- converted with ``log(deg + 1)``.
    3. Otherwise the degree is counted from ``batch.edge_index[1]`` and a
       warning is emitted.

    The whole function runs with gradient tracking disabled.
    """
    if "log_deg" in batch:
        return batch.log_deg.view(batch.num_nodes, 1)

    if "deg" in batch:
        from_stored_deg = torch.log(batch.deg + 1).unsqueeze(-1)
        return from_stored_deg.view(batch.num_nodes, 1)

    warnings.warn("Compute the degree on the fly; Might be problematric if have applied edge-padding to complete graphs")
    # Number of times each node index occurs in the second row of edge_index.
    counted = pyg.utils.degree(
        batch.edge_index[1],
        num_nodes=batch.num_nodes,
        dtype=torch.float,
    )
    return torch.log(counted + 1).view(batch.num_nodes, 1)