Source code for opengt.head.mlp_mixer_graph

import torch.nn as nn

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('mlp_mixer_graph')
class MLPMixerGraphHead(nn.Module):
    """
    Prediction head for graph-level tasks on top of Graph MLP Mixer.

    The head reads ``batch.x`` directly as the graph embedding and performs
    no pooling of its own, so it is meant to be paired with Graph MLP Mixer
    only and cannot work on other models.

    The network is a stack of ``L`` linear layers, each mapping
    ``dim_in // 2 ** k`` features to ``dim_in // 2 ** (k + 1)`` features and
    followed by the activation selected by ``cfg.gnn.act``, plus a final
    linear layer that maps ``dim_in // 2 ** L`` features to ``dim_out``.

    Parameters:
        dim_in (int): Input dimension.
        dim_out (int): Output dimension. For binary prediction, dim_out=1.
        L (int): Number of hidden layers.

    Input:
        batch.x (torch.Tensor): Graph embedding.
        batch.y (torch.Tensor): Graph labels.

    Output:
        pred (torch.Tensor): Predicted graph labels.
        true (torch.Tensor): True graph labels.
    """

    def __init__(self, dim_in, dim_out, L=2):
        super().__init__()

        # Hidden layers first, output projection last; forward() relies on
        # the output projection living at index L.
        stack = nn.ModuleList()
        for depth in range(L):
            fan_in = dim_in // 2 ** depth
            fan_out = dim_in // 2 ** (depth + 1)
            stack.append(nn.Linear(fan_in, fan_out, bias=True))
        stack.append(nn.Linear(dim_in // 2 ** L, dim_out, bias=True))

        self.FC_layers = stack
        self.L = L
        # The activation class is looked up by name in the GraphGym registry.
        self.activation = register.act_dict[cfg.gnn.act]()

    def _apply_index(self, batch):
        """Return the ``(prediction, label)`` pair stored on ``batch``."""
        return batch.graph_feature, batch.y

    def forward(self, batch):
        """Run the MLP on ``batch.x`` and return ``(pred, true)``.

        The prediction is also written to ``batch.graph_feature``.
        """
        hidden = batch.x
        for depth in range(self.L):
            hidden = self.activation(self.FC_layers[depth](hidden))

        # Final projection to dim_out; no activation is applied after it.
        batch.graph_feature = self.FC_layers[self.L](hidden)

        prediction, target = self._apply_index(batch)
        return prediction, target