# Source code for opengt.layer.mlp_mixer


import torch.nn as nn
from opengt.layer.mixer_block import MixerBlock
from torch_geometric.graphgym.register import register_layer

@register_layer("mlp_mixer")
class MLPMixer(nn.Module):
    """
    GraphMLPMixer layer: a stack of Mixer blocks with an optional final
    layer normalization.

    Adapted from https://github.com/XiaoxinHe/Graph-ViT-MLPMixer

    Parameters:
        layers (int): Number of Mixer blocks to stack.
        dim_hidden (int): Feature dimension; passed to every Mixer block
            and used as the normalized shape of the final LayerNorm.
        patches (int): Number of patches, forwarded to every Mixer block.
        with_final_norm (bool): Whether to apply a LayerNorm after the
            last Mixer block. Default: True.
        dropout (float): Dropout rate forwarded to every Mixer block.
            Default: 0.

    Input:
        batch: Object exposing ``clone()`` and an ``x`` attribute holding
            the input features (presumably a torch.Tensor whose last
            dimension is ``dim_hidden`` -- confirm against MixerBlock).

    Output:
        A clone of ``batch`` whose ``x`` holds the features after all
        Mixer blocks (and the final normalization, when enabled). The
        ``batch`` object passed in is not reassigned.
    """

    def __init__(self, layers, dim_hidden, patches, with_final_norm=True, dropout=0):
        super().__init__()
        self.patches = patches
        self.with_final_norm = with_final_norm

        # Every block receives the same configuration. The third and fourth
        # positional arguments are 4x and half of dim_hidden respectively
        # (presumably MixerBlock's two internal MLP widths -- see MixerBlock).
        stacked = []
        for _ in range(layers):
            stacked.append(
                MixerBlock(
                    dim_hidden,
                    self.patches,
                    dim_hidden * 4,
                    dim_hidden // 2,
                    dropout=dropout,
                )
            )
        self.mixer_blocks = nn.ModuleList(stacked)

        # The norm is only created when requested, so `layer_norm` does not
        # exist as an attribute when with_final_norm is False.
        if self.with_final_norm:
            self.layer_norm = nn.LayerNorm(dim_hidden)

    def forward(self, batch):
        # Operate on a clone so the caller's batch object keeps its own `x`.
        out = batch.clone()

        # Mixer blocks first, then the optional final normalization.
        stages = list(self.mixer_blocks)
        if self.with_final_norm:
            stages.append(self.layer_norm)

        for stage in stages:
            out.x = stage(out.x)
        return out