# Source code for opengt.network.cobformer

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import new_layer_config, GeneralMultiLayer
from torch_geometric.graphgym.register import register_network

from opengt.network.bga_model import BGA

@register_network('CoBFormer')
class CoBFormer(torch.nn.Module):
    """
    CoBFormer model. Only supports transductive node level tasks.

    Adapted from https://github.com/null-xyj/CoBFormer

    The model runs two branches on the same input graph:
      - a GNN stack built with GraphGym's ``GeneralMultiLayer``, and
      - a ``BGA`` module.
    The returned predictions come from the BGA branch only. The two branches
    are coupled through a symmetric soft-label cross-entropy, which is
    returned as an extra loss term.

    Parameters:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features.
        cfg (dict): Configuration dictionary containing model parameters from GraphGym.
            - cfg.gt.alpha (float): Balance factor for GNN and BGA loss. The
              co-training loss is multiplied by ``(1 - alpha) / alpha``, so
              ``alpha`` must be non-zero (0 divides by zero).
            - cfg.gt.tau (float): Scale multiplied onto both branches' logits
              before the co-training softmax / cross-entropy. It acts as an
              inverse temperature: larger values give sharper soft targets.
            - cfg.gt.layer_type (str): Type of GNN layer to use, e.g. 'GCN'.
              Only the part before the first '+' is used. It is lower-cased
              and suffixed with 'conv' to name the GraphGym layer
              (e.g. 'GCN' -> 'gcnconv').
            - cfg.gnn.layers (int): Number of GNN layers.

    Input:
        batch (torch_geometric.data.Batch): Input batch containing node features and graph structure.
            - batch.x (torch.Tensor): Input node features.
            - batch.edge_index (torch.Tensor): Edge indices of the graph.
            - batch.patch (torch.Tensor): Patch indices (presumably consumed by BGA).
            - batch.y (torch.Tensor): Input labels (optional).
            - batch.split (str): Split type (train, val, test). Optional; when
              present, outputs are restricted to ``batch[f'{split}_mask']``.

    Output:
        pred (torch.Tensor): BGA-branch outputs, masked to the current split when ``batch.split`` is set.
        true (torch.Tensor or None): Labels masked the same way; None if the batch has no ``y``.
        extra_loss (torch.Tensor): Extra loss term for GNN and BGA cotraining,
            computed over all nodes of the batch.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.alpha = cfg.gt.alpha
        self.tau = cfg.gt.tau
        # Only the first '+'-separated component of the layer type selects
        # the GraphGym conv layer for the GNN branch.
        gnn_layer_name = cfg.gt.layer_type.split('+')[0].lower() + 'conv'
        self.gnn = GeneralMultiLayer(
            gnn_layer_name,
            new_layer_config(dim_in=dim_in, dim_out=dim_out, has_bias=True,
                             has_act=False, num_layers=cfg.gnn.layers,
                             cfg=cfg))
        self.bga = BGA(dim_in, dim_out)
        # Placeholder attribute; nothing in this class writes to it.
        self.attn = None

    def _apply_index(self, batch):
        """Return ``(x, y)`` restricted to the current split's node mask.

        If ``batch`` has no ``split`` attribute, the full tensors are
        returned unmasked. ``y`` is None when the batch carries no labels.
        """
        x = batch.x
        y = batch.y if 'y' in batch else None
        if 'split' not in batch:
            return x, y
        # e.g. split 'train' selects batch.train_mask
        mask = batch[f'{batch.split}_mask']
        return x[mask], (y[mask] if y is not None else None)

    def forward(self, batch):
        """Run both branches and return ``(pred, true, extra_loss)``."""
        # The GNN branch runs on a clone. Presumably this is because GraphGym
        # layers overwrite batch.x in place; the clone keeps BGA's input
        # unmodified.
        gnn_batch = self.gnn(batch.clone())
        bga_batch = self.bga(batch)

        # Scale each branch's logits once and reuse them below.
        z1 = gnn_batch.x * self.tau
        z2 = bga_batch.x * self.tau

        # Symmetric soft-label cross-entropy: each branch is pulled towards
        # the other's softmax output. Targets are not detached, so gradients
        # flow through both sides.
        co_loss = (F.cross_entropy(z1, F.softmax(z2, dim=1))
                   + F.cross_entropy(z2, F.softmax(z1, dim=1)))
        extra_loss = co_loss * (1 - self.alpha) / self.alpha

        pred, true = self._apply_index(bga_batch)
        return pred, true, extra_loss