Source code for opengt.network.performer

import torch
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNPreMP
from torch_geometric.graphgym.register import register_network

from opengt.layer.performer_layer import Performer as BackbonePerformer


@register_network('Performer')
class Performer(torch.nn.Module):
    """Performer network operating on node features only.

    This model disregards edge features and runs a linear-attention
    transformer (Performer, https://arxiv.org/abs/2009.14794) over the node
    features. The pipeline is: ``FeatureEncoder`` -> optional ``GNNPreMP``
    -> Performer backbone -> GraphGym head selected by ``cfg.gnn.head``.

    Adapted from https://github.com/rampasek/GraphGPS
    """

    def __init__(self, dim_in, dim_out):
        """Assemble the encoder / pre-MP / backbone / head pipeline.

        Args:
            dim_in: Input dimension handed to ``FeatureEncoder``.
            dim_out: Output dimension handed to the prediction head.

        Raises:
            AssertionError: if ``cfg.gt.dim_hidden``, ``cfg.gnn.dim_inner``
                and the feature width produced before the backbone differ.
        """
        super().__init__()

        # NOTE: attribute assignment order == registration order, and
        # ``forward`` walks ``self.children()`` in that order. Keep it as
        # encoder -> pre_mp -> trf -> post_mp.
        self.encoder = FeatureEncoder(dim_in)
        width = self.encoder.dim_in

        if cfg.gnn.layers_pre_mp > 0:
            self.pre_mp = GNNPreMP(
                width, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
            width = cfg.gnn.dim_inner

        # The backbone is built with ``cfg.gt.dim_hidden`` and the head with
        # ``cfg.gnn.dim_inner``; both must equal the width produced upstream.
        assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == width, \
            "The inner and hidden dims must match."

        hidden = cfg.gt.dim_hidden
        num_heads = cfg.gt.n_heads
        self.trf = BackbonePerformer(
            dim=hidden,
            depth=cfg.gt.layers,
            heads=num_heads,
            # Per-head size uses floor division of the hidden width.
            dim_head=hidden // num_heads,
        )

        head_cls = register.head_dict[cfg.gnn.head]
        self.post_mp = head_cls(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)

    def forward(self, batch):
        """Feed ``batch`` through each registered submodule in turn.

        Each stage receives the previous stage's output, and the output of
        the last registered submodule is returned.
        """
        out = batch
        for stage in self.children():
            out = stage(out)
        return out