Source code for opengt.layer.bga_layer

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Computes ``softmax((q / temperature) @ k^T) @ v`` over the last two
    dimensions and applies dropout to the attention weights.

    Parameters:
        temperature (float): Divisor applied to the queries before the dot
            product.
        attn_dropout (float): Dropout probability applied to the attention
            weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Parameters:
            q (Tensor): Queries, ``(..., N_q, d)``.
            k (Tensor): Keys, ``(..., N_k, d)``.
            v (Tensor): Values, ``(..., N_k, d_v)``.
            mask (Tensor, optional): Broadcastable to the score tensor
                ``(..., N_q, N_k)``; positions where ``mask == 0`` receive
                no attention.

        Returns:
            tuple: ``(output, attn)`` where ``output`` is ``(..., N_q, d_v)``
            and ``attn`` holds the (post-dropout) attention weights.
        """
        # Transposing the last two axes (rather than axes 2 and 3) lets the
        # module accept inputs with or without a separate head axis.
        scores = torch.matmul(q / self.temperature, k.transpose(-2, -1))

        if mask is not None:
            # Use the most negative value representable in the score dtype:
            # a fixed -1e9 does not fit in float16.
            fill_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, fill_value)

        attn = self.dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(attn, v)
        return output, attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection on the query.

    Queries, keys and values are projected with bias-free linear layers,
    split into ``n_head`` heads of size ``channels // n_head``, attended
    with scaled dot-product attention, merged, projected again and added
    to the original query.

    Parameters:
        n_head (int): Number of attention heads.
        channels (int): Feature size of the inputs and of the output.
        dropout (float): Probability of the ``dropout`` submodule. That
            submodule is created for compatibility but is not applied in
            :meth:`forward`.

    Raises:
        ValueError: If ``n_head`` is not positive or does not divide
            ``channels``.
    """

    def __init__(self, n_head, channels, dropout=0.1):
        super().__init__()
        # Fail here with a clear message instead of with an opaque ``view``
        # size error on the first forward pass.
        if n_head <= 0 or channels % n_head != 0:
            raise ValueError(
                f"channels ({channels}) must be divisible by a positive "
                f"n_head ({n_head})"
            )

        self.n_head = n_head
        self.channels = channels
        d_head = channels // n_head

        self.w_qs = nn.Linear(channels, channels, bias=False)
        self.w_ks = nn.Linear(channels, channels, bias=False)
        self.w_vs = nn.Linear(channels, channels, bias=False)
        self.fc = nn.Linear(channels, channels, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_head ** 0.5)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape ``(B, N, channels)`` to ``(B, n_head, N, channels // n_head)``."""
        d_head = self.channels // self.n_head
        return t.view(t.size(0), t.size(1), self.n_head, d_head).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Parameters:
            q (Tensor): Queries, ``(B, N_q, channels)``.
            k (Tensor): Keys, ``(B, N_k, channels)``.
            v (Tensor): Values, ``(B, N_k, channels)``.
            mask (Tensor, optional): Attention mask; a head axis is inserted
                at dimension 1 so it broadcasts over all heads.

        Returns:
            tuple: ``(output, attn)`` with ``output`` of shape
            ``(B, N_q, channels)`` and ``attn`` the per-head attention
            weights returned by the inner attention module.
        """
        batch_q, len_q = q.size(0), q.size(1)
        residual = q

        # Project, then separate the heads: B x N x (h*d) -> B x h x N x d.
        q = self._split_heads(self.w_qs(q))
        k = self._split_heads(self.w_ks(k))
        v = self._split_heads(self.w_vs(v))

        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast over the head axis

        q, attn = self.attention(q, k, v, mask=mask)

        # Move the head axis back and concatenate the heads:
        # B x h x N x d -> B x N x (h*d).
        q = q.transpose(1, 2).contiguous().view(batch_q, len_q, -1)
        q = self.fc(q)
        q = q + residual

        return q, attn
class FFN(nn.Module):
    """Position-wise two-layer feed-forward block with a residual connection.

    The input is layer-normalised, passed through dropout, fed through two
    ``channels -> channels`` linear layers with a ReLU in between, and the
    untouched input is added back to the result.

    Parameters:
        channels (int): Feature size of the input and output.
        dropout (float): Dropout probability applied after the layer norm.
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        # Submodules are created in a fixed order so that seeded parameter
        # initialisation and ``state_dict`` key order stay stable.
        self.lin1 = nn.Linear(channels, channels)
        self.lin2 = nn.Linear(channels, channels)
        self.layer_norm = nn.LayerNorm(channels, eps=1e-6)
        self.Dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``lin2(relu(lin1(dropout(norm(x))))) + x``."""
        hidden = self.Dropout(self.layer_norm(x))
        hidden = self.lin1(hidden).relu()
        return self.lin2(hidden) + x
@register_layer('BGALayer')
class BGALayer(nn.Module):
    """
    Bilevel Graph Attention (BGA) layer, used in the CoBFormer model.
    Adapted from https://github.com/null-xyj/CoBFormer

    Nodes first attend to the other nodes of their own patch. When
    ``cfg.gt.use_patch_attn`` is set, each patch is additionally summarised
    by the mean of its members, the summaries attend to each other, and the
    result is fused back into every member of the patch.

    Parameters:
        n_head (int): Number of attention heads. Handled by GraphGym: the
            value actually used is ``cfg.gt.n_heads``.
        channels (int): Number of input channels. Handled by GraphGym: the
            value actually used is ``cfg.gt.dim_hidden``.
        dropout (float): Dropout rate passed to the attention and
            feed-forward submodules.

    Input:
        x (Tensor): Input node features.
        patch (Tensor): Patch indices; indexed as a 2-D tensor whose rows
            are patches and whose entries index rows of ``x``.
        attn_mask (Tensor): Attention mask forwarded to the intra-patch
            attention.
        need_attn (bool): Whether to store a dense attention map in
            ``self.attn``.

    Output:
        x (Tensor): Output node features after applying the BGA layer.
    """

    def __init__(self, n_head, channels, dropout=0.1):
        super().__init__()
        # The GraphGym config is authoritative; the positional arguments are
        # accepted only so the layer fits the registry's constructor shape.
        hidden = cfg.gt.dim_hidden
        heads = cfg.gt.n_heads

        self.node_norm = nn.LayerNorm(hidden)
        self.node_transformer = MultiHeadAttention(heads, hidden, dropout)
        self.patch_norm = nn.LayerNorm(hidden)
        self.patch_transformer = MultiHeadAttention(heads, hidden, dropout)
        self.node_ffn = FFN(hidden, dropout)
        self.patch_ffn = FFN(hidden, dropout)
        self.fuse_lin = nn.Linear(2 * hidden, hidden)
        self.use_patch_attn = cfg.gt.use_patch_attn
        self.attn = None

    def _dense_attention(self, attn, patch, num_nodes):
        """Scatter head-averaged per-patch attention into a dense CPU matrix.

        The last row and column are dropped before returning.
        NOTE(review): presumably they belong to a padding node appended to
        ``x`` by the caller -- confirm against the data pipeline.
        """
        dense = torch.zeros((num_nodes, num_nodes))
        for patch_id in tqdm(range(patch.shape[0])):
            members = patch[patch_id].tolist()
            size = len(members)
            index = torch.tensor(members)
            # Every (source, target) pair inside the patch, row-major.
            row = index.repeat_interleave(size)
            col = index.repeat(size)
            weights = attn[patch_id].mean(0).flatten().cpu()
            dense = dense.index_put((row, col), weights)
        return dense[:-1][:, :-1].detach().cpu()

    def forward(self, x, patch, attn_mask=None, need_attn=False):
        x = self.node_norm(x)

        # Intra-patch level: nodes attend within their own patch.
        members = x[patch]
        members, intra_attn = self.node_transformer(
            members, members, members, attn_mask
        )
        members = self.node_ffn(members)

        if need_attn:
            self.attn = self._dense_attention(intra_attn, patch, x.shape[0])

        if self.use_patch_attn:
            # Inter-patch level: one mean-pooled token per patch, with all
            # patches forming a single sequence.
            summary = self.patch_norm(members.mean(dim=1)).unsqueeze(0)
            summary, _ = self.patch_transformer(summary, summary, summary)
            summary = self.patch_ffn(summary).permute(1, 0, 2)
            # Copy each patch summary to every member so it can be
            # concatenated feature-wise.
            summary = summary.repeat(1, patch.shape[1], 1)
            fused = self.fuse_lin(torch.cat([members, summary], dim=2))
            members = F.relu(fused) + members

        # Write the updated member features back to their node slots.
        x[patch] = members
        return x