Source code for opengt.network.trans_conv

import torch
import torch.nn as nn
from opengt.layer.trans_conv_layer import TransConvLayer

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.layer import MLP, new_layer_config
from torch_geometric.graphgym.register import register_network

[docs] @register_network("TransConv") class TransConv(nn.Module): """ TransConv module for SGFormer model. Adapted from https://github.com/qitianwu/SGFormer Parameters: dim_in (int): Number of input features. dim_out (int): Number of output features. cfg (dict): Configuration dictionary containing model parameters from GraphGym. - cfg.gt.layers: Number of TransConv layers. Input: batch (torch_geometric.data.Batch): Input batch containing node features and graph structure. - batch.x (torch.Tensor): Input node features. - batch.edge_index (torch.Tensor): Edge indices of the graph. Output: batch (torch_geometric.data.Batch): Output batch after processing. """ def __init__(self, dim_in, dim_out): super().__init__() self.mlp = MLP(new_layer_config(dim_in=dim_in, dim_out=dim_out, num_layers=1, has_act=True, has_bias=True, cfg=cfg)) layers = nn.ModuleList() for i in range(cfg.gt.layers): layers.append( TransConvLayer(dim_out, dim_out, config=cfg.gt)) self.layers=nn.Sequential(*layers) def forward(self, batch): for module in self.children(): batch = module(batch) return batch