Source code for opengt.network.mlp_mixer

import torch.nn as nn
from opengt.layer.mlp_mixer import MLPMixer
from opengt.layer.patch_encoder import PatchEncoder
import functorch.einops.rearrange as Rearrange
from torch_scatter import scatter

from torch_geometric.nn import Sequential
import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.models.layer import new_layer_config, MLP, GeneralLayer, GCNConv
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_network
from opengt.encoder.feature_encoder import FeatureEncoder

@register_network('GraphMLPMixer')
class GraphMLPMixer(nn.Module):
    """
    GraphMLPMixer model. Only supports graph-level tasks.

    Adapted from https://github.com/XiaoxinHe/Graph-ViT-MLPMixer

    The model is a straight pipeline of sub-modules that each take the batch
    and return a batch: feature encoder -> linear pre-processing layer ->
    patch encoder -> (optional) patch random-walk positional encoding ->
    reshape to per-graph patch sequences -> MLP-Mixer blocks -> masked mean
    pooling over patches -> prediction head.

    Parameters:
        dim_in (int): Number of input features.
        dim_out (int): Number of output features.

    Configuration (read from the global GraphGym ``cfg``, not passed in):
        - cfg.gt.dim_hidden (int): Hidden dimension for GNN layers and
          GraphMLPMixer layers.
        - cfg.gt.mlpmixer_layers (int): Number of Mixer blocks.
        - cfg.gt.mlpmixer_dropout (float): Dropout rate for the Mixer blocks.
        - cfg.gt.dropout, cfg.gt.pooling: Stored on the instance as
          ``self.dropout`` / ``self.pooling``; not otherwise used by this
          class.
        - cfg.metis.patches (int): Number of patches.
        - cfg.metis.patch_rw_dim (int): Dimension of the patch random walk
          encoder. The patch PE branch is only built when this is > 0.

    Input:
        batch (torch_geometric.data.Batch): Input batch containing node
        features and graph structure.
            - batch.x (torch.Tensor): Input node features.
            - batch.edge_index (torch.Tensor): Edge indices of the graph.
            - batch.subgraphs_nodes_mapper (torch.Tensor): Mapping of nodes
              to subgraphs.
            - batch.subgraphs_edges_mapper (torch.Tensor): Mapping of edges
              to subgraphs.
            - batch.combined_subgraphs (torch.Tensor): Combined subgraphs.
            - batch.subgraphs_batch (torch.Tensor): Batch indices for
              subgraphs.
            - batch.patch_pe (torch.Tensor): Patch positional encoding; read
              only when cfg.metis.patch_rw_dim > 0.
            - batch.mask (torch.Tensor): Mask for global pooling.

    Output:
        Whatever the registered ``'mlp_mixer_graph'`` head returns; per the
        original documentation this is:
            pred (torch.Tensor): Graph logits after model processing.
            true (torch.Tensor): Labels for the input batch.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        # NOTE: forward() executes the sub-modules in registration order, so
        # the order of the module assignments below *is* the pipeline.
        self.encoder = FeatureEncoder(dim_in)
        dim_in = self.encoder.dim_in

        self.dropout = cfg.gt.dropout
        self.pooling = cfg.gt.pooling

        self.pre_mp = GeneralLayer(
            'linear',
            new_layer_config(
                dim_in=dim_in,
                dim_out=cfg.gt.dim_hidden,
                num_layers=1,
                has_act=True,
                has_bias=True,
                cfg=cfg,
            ),
        )

        # Patch encoder. The reference logic it stands in for: gather node
        # and edge features per sub-graph through the mapper tensors, run the
        # GNN stack with a pooled sub-graph residual between layers, then
        # scatter-pool the nodes into one embedding per patch.
        # NOTE(review): the original comment here read "needs scatter & patch
        # rw encoder residual" - confirm PatchEncoder covers both.
        self.patch_encoder = PatchEncoder(
            dim_in=cfg.gt.dim_hidden,
            dim_out=cfg.gt.dim_hidden,
        )

        # Patch positional encoding: embed batch.patch_pe with a 2-layer MLP
        # on a copy of the batch, then add the result onto the patch features.
        if cfg.metis.patch_rw_dim > 0:
            patch_pe_mlp = MLP(
                new_layer_config(
                    dim_in=cfg.metis.patch_rw_dim,
                    dim_out=cfg.gt.dim_hidden,
                    num_layers=2,
                    has_act=True,
                    has_bias=True,
                    cfg=cfg,
                )
            )
            self.patch_rw_encoder = Sequential('x', [
                (lambda x: self.copy_patch_pe(x), 'x -> x1'),
                (patch_pe_mlp, 'x1 -> x1'),
                (lambda x, x1: self.aggregate_batches_add(x, x1),
                 'x, x1 -> x'),
            ])

        self.reshape_layer = Sequential('x', [
            (lambda x: self.reshape(x), 'x -> x'),
        ])

        self.transformer_encoder = MLPMixer(
            dim_hidden=cfg.gt.dim_hidden,
            dropout=cfg.gt.mlpmixer_dropout,
            layers=cfg.gt.mlpmixer_layers,
            patches=cfg.metis.patches,
        )

        # Global pooling over the patch axis.
        self.pooling_layer = Sequential('x', [
            (lambda x: self.global_pooling(x), 'x -> x'),
        ])

        GNNHead = register.head_dict['mlp_mixer_graph']
        self.post_mp = GNNHead(dim_in=cfg.gt.dim_hidden, dim_out=dim_out)

    def copy_patch_pe(self, batch):
        """Return a clone of ``batch`` whose ``x`` is ``batch.patch_pe``."""
        new_batch = batch.clone()
        new_batch.x = batch.patch_pe
        return new_batch

    def aggregate_batches_add(self, x, x1):
        """Return a clone of batch ``x`` whose features are ``x.x + x1.x``."""
        new_batch = x.clone()
        new_batch.x = x.x + x1.x
        return new_batch

    def reshape(self, batch):
        """Return a clone of ``batch`` with ``x`` reshaped to (B, P, D).

        ``P`` is fixed to ``cfg.metis.patches``; ``B`` is inferred.
        """
        ret = batch.clone()
        ret.x = Rearrange(batch.x, '(B P) D -> B P D', P=cfg.metis.patches)
        return ret

    def global_pooling(self, batch):
        """Return a clone of ``batch`` with ``x`` mean-pooled over patches.

        Only patches selected by ``batch.mask`` contribute: features are
        multiplied by the mask, summed over dim 1 and divided by the number
        of selected patches per row.
        """
        ret = batch.clone()
        masked_sum = (batch.x * batch.mask.unsqueeze(-1)).sum(1)
        # Clamp the patch count to at least 1: for a row whose mask is all
        # zero the numerator is already 0, so this yields a zero vector
        # instead of 0/0 = NaN. Rows with >= 1 selected patch are unaffected.
        patch_count = batch.mask.sum(1, keepdim=True).clamp(min=1)
        ret.x = masked_sum / patch_count
        return ret

    def forward(self, batch):
        """Run ``batch`` through every child module in registration order."""
        # children() yields sub-modules in the order they were assigned in
        # __init__, which is what defines the pipeline (and transparently
        # skips patch_rw_encoder when it was never created).
        for module in self.children():
            batch = module(batch)
        return batch