Source code for opengt.layer.degta_layer

import inspect
import torch
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing, GATConv, GCNConv
from torch_geometric.utils import to_dense_adj

from torch.nn import Sequential
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer

@register_layer('DeGTAConv')
class DeGTAConv(torch.nn.Module):
    """
    Decoupled Graph Triple Attention (DeGTA) layer.
    Adapted from https://github.com/wangxiaotang0906/DeGTA

    The input features are split into three views: attribute (AE),
    positional (PE) and structural (SE). Each view is encoded by its own MLP
    and produces its own attention scores. A local channel aggregates AE
    features over the graph edges, a global channel aggregates AE features
    over the top-K sampled non-adjacent nodes, and both are combined by two
    bias-free linear maps followed by a LayerNorm.

    Parameters:
        dim_in (int): Number of input features. Must be larger than the sum
            of the PE and SE widths read from ``cfg.posenc_<name>.dim_pe``.

    Input:
        batch.x (Tensor): Input node features. Should be concatenated from
            three different encoders, laid out as ``[AE | PE | SE]``.
        batch.edge_index (Tensor): Edge indices of the graph.

    Output:
        ret.x (Tensor): Output node features after applying the DeGTA layer:
            the updated AE part concatenated with the *unchanged* input PE
            and SE parts.

    Raises:
        ValueError: If ``cfg.dataset.node_encoder_name`` does not name exactly
            three encoders, or if ``dim_in`` leaves no room for the AE view.
    """

    def __init__(
        self,
        dim_in: int
    ):
        super().__init__()
        enc_name = cfg.dataset.node_encoder_name.split('+')
        if len(enc_name) != 3:
            raise ValueError(f"Expected 3 encoders, got {len(enc_name)}: {enc_name}")
        # The 2nd and 3rd encoder names select the PE / SE config sections.
        self.pe_channels = getattr(cfg, f"posenc_{enc_name[1]}").dim_pe
        self.se_channels = getattr(cfg, f"posenc_{enc_name[2]}").dim_pe
        self.ae_channels = dim_in - self.pe_channels - self.se_channels
        if self.ae_channels <= 0:
            # Fail early with a readable message instead of an opaque error
            # from torch.nn.Linear / torch.split further down.
            raise ValueError(
                f"dim_in={dim_in} is too small for pe_channels={self.pe_channels} "
                f"and se_channels={self.se_channels}: no attribute channels left"
            )
        self.heads = cfg.gt.n_heads
        self.dropout = cfg.gt.dropout

        # NOTE: submodule names and construction order are kept stable so that
        # state_dict keys and seed-dependent initialisation do not change.

        # multi-view encoder
        self.ae_encoder = Sequential(
            torch.nn.Linear(self.ae_channels, self.ae_channels * 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(self.ae_channels * 2, self.ae_channels),
            torch.nn.Dropout(self.dropout),
        )
        self.pe_encoder = Sequential(
            torch.nn.Linear(self.pe_channels, self.pe_channels * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.pe_channels * 2, self.pe_channels),
        )
        self.se_encoder = Sequential(
            torch.nn.Linear(self.se_channels, self.se_channels * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(self.se_channels * 2, self.se_channels),
        )

        # local_channel: the GAT layers are used only for their attention
        # coefficients; the actual aggregation is done by `weightedconv`.
        self.ae_attn_l = GATConv(self.ae_channels, self.ae_channels, add_self_loops=False)
        self.pe_attn_l = GATConv(self.pe_channels, self.pe_channels, add_self_loops=False)
        self.se_attn_l = GATConv(self.se_channels, self.se_channels, add_self_loops=False)
        self.weightedconv = GCNConv(self.ae_channels, self.ae_channels)

        # global_channel: likewise only the attention weights are consumed.
        self.ae_attn_g = torch.nn.MultiheadAttention(self.ae_channels, self.heads)
        self.pe_attn_g = torch.nn.MultiheadAttention(self.pe_channels, self.heads)
        self.se_attn_g = torch.nn.MultiheadAttention(self.se_channels, self.heads)

        # multi-view adaptive integration (learnable mixing coefficients for
        # the PE / SE / AE attention scores, local and global respectively)
        self.a_l = torch.nn.Parameter(torch.tensor(0.33))
        self.b_l = torch.nn.Parameter(torch.tensor(0.33))
        self.c_l = torch.nn.Parameter(torch.tensor(0.33))
        self.a_g = torch.nn.Parameter(torch.tensor(0.33))
        self.b_g = torch.nn.Parameter(torch.tensor(0.33))
        self.c_g = torch.nn.Parameter(torch.tensor(0.33))

        # local-global adaptive integration
        self.localchannel = torch.nn.Linear(self.ae_channels, self.ae_channels, bias=False)
        self.globalchannel = torch.nn.Linear(self.ae_channels, self.ae_channels, bias=False)

        # Only norm3 is applied in forward(); norm1 / norm2 are retained so
        # existing checkpoints keep loading.
        self.norm1 = torch.nn.LayerNorm(self.ae_channels)
        self.norm2 = torch.nn.LayerNorm(self.ae_channels)
        self.norm3 = torch.nn.LayerNorm(self.ae_channels)

        # Records whether the norm's forward() takes a `batch` argument.
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def forward(
        self,
        batch
    ):
        """
        Apply the DeGTA layer.

        Parameters:
            batch: Graph batch providing ``x`` and ``edge_index`` and
                supporting ``clone()``.

        Returns:
            A clone of ``batch`` whose ``x`` is
            ``cat([updated AE, input PE, input SE], dim=1)``.
        """
        # multi_view encoder
        ae, pe, se = torch.split(
            batch.x,
            [self.ae_channels, self.pe_channels, self.se_channels],
            dim=1,
        )
        # The raw PE / SE slices are passed through unchanged to the output.
        # They are only read by torch.cat below, so no copy is needed.
        pe_0, se_0 = pe, se
        ae = self.ae_encoder(ae)
        pe = self.pe_encoder(pe)
        se = self.se_encoder(se)

        edge_index = batch.edge_index
        num_nodes = ae.size(0)
        # topk() raises if asked for more entries than exist along the
        # dimension, so never request more neighbours than there are nodes.
        K = min(cfg.gt.K, num_nodes)

        # local_channel: mix the per-edge attention coefficients of the three
        # views and use them as edge weights for a GCN aggregation of `ae`.
        _, peattn = self.pe_attn_l(pe, edge_index, return_attention_weights=True)
        _, seattn = self.se_attn_l(se, edge_index, return_attention_weights=True)
        _, aeattn = self.ae_attn_l(ae, edge_index, return_attention_weights=True)
        local_attn = self.a_l * peattn[1] + self.b_l * seattn[1] + self.c_l * aeattn[1]
        out_local = self.weightedconv(ae, edge_index, local_attn)
        out_local = F.dropout(out_local, p=self.dropout, training=self.training)

        # global_channel: all-pairs attention weights for each view.
        _, peattn = self.pe_attn_g(pe, pe, pe, need_weights=True)
        _, seattn = self.se_attn_g(se, se, se, need_weights=True)
        _, aeattn = self.ae_attn_g(ae, ae, ae, need_weights=True)

        # PE + SE scores decide which node pairs are sampled.
        sample_attn = self.a_g * peattn + self.b_g * seattn
        # Pairs already connected by an edge are handled by the local channel,
        # so their sampling score is zeroed out here.
        # NOTE(review): to_dense_adj appears to return a leading batch dim of
        # size 1 (hence the squeeze(0) at the end), which makes every tensor
        # below 3-D after broadcasting - confirm against the PyG docs.
        adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)
        zero_vec = torch.zeros_like(adj)
        sample_attn = torch.where(adj > 0, zero_vec, sample_attn)

        # hard sampling: keep only the K largest scores along dim=1.
        # NOTE(review): with the leading batch dim, dim=1 selects per column
        # and the normalisation below is per column too; kept as in upstream.
        values, indices = sample_attn.topk(K, dim=1, largest=True, sorted=True)
        mask = torch.zeros_like(sample_attn).scatter_(1, indices, torch.ones_like(values))
        sample_attn_masked = sample_attn * mask
        aeattn = aeattn * mask

        global_attn = (0.5 * self.a_g + 0.5 * self.b_g) * sample_attn_masked + self.c_g * aeattn
        # Normalise by the sum taken over dim=1. keepdim=True makes the
        # broadcast explicit; it matches the original implicit one.
        column_sums = global_attn.sum(dim=1, keepdim=True)
        global_attn = global_attn / column_sums

        out_global = torch.matmul(global_attn, ae)
        out_global = F.dropout(out_global, p=self.dropout, training=self.training)

        # local_global integration
        out = self.localchannel(out_local) + self.globalchannel(out_global)
        if self.norm3 is not None:
            out = self.norm3(out)

        ret = batch.clone()
        # squeeze(0) drops the leading batch dim inherited from the dense
        # adjacency before re-attaching the untouched PE / SE features.
        ret.x = torch.cat([out.squeeze(0), pe_0, se_0], dim=1)
        return ret