Source code for opengt.head.ogb_code_graph

import torch.nn as nn

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
from torch_geometric.graphgym.register import register_head


@register_head('ogb_code_graph')
class OGBCodeGraphHead(nn.Module):
    """
    Sequence prediction head for ogbg-code2 graph-level prediction tasks.

    The node features are pooled into one embedding per graph, and that
    embedding is passed through ``max_seq_len`` independent linear layers,
    one per position of the predicted sequence.

    Parameters:
        dim_in (int): Input dimension.
        dim_out (int): IGNORED, kept for GraphGym framework compatibility
        L (int): Number of hidden layers. Only ``L == 1`` is supported.

    Raises:
        ValueError: If ``L`` is not 1.
    """

    def __init__(self, dim_in: int, dim_out: int, L: int = 1):
        super().__init__()
        # Pooling function is selected by the global GraphGym config.
        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]
        self.L = L
        # NOTE(review): 5002 is presumably the ogbg-code2 target vocabulary
        # size plus special tokens, and 5 the fixed target sequence length --
        # confirm against the dataset preprocessing.
        num_vocab = 5002
        self.max_seq_len = 5

        if self.L != 1:
            raise ValueError("Multilayer prediction heads are not supported.")

        # One classifier per output position; each maps the graph embedding
        # to ``num_vocab`` scores.
        self.graph_pred_linear_list = nn.ModuleList(
            [nn.Linear(dim_in, num_vocab) for _ in range(self.max_seq_len)]
        )

    def _apply_index(self, batch):
        """Return ``batch.pred_list`` and a dict holding ``batch.y_arr`` and ``batch.y``."""
        return batch.pred_list, {'y_arr': batch.y_arr, 'y': batch.y}

    def forward(self, batch):
        """
        Pool the batch and compute one prediction per sequence position.

        The list of per-position outputs is also stored on the batch as
        ``batch.pred_list``.

        Parameters:
            batch: Batch object exposing ``x``, ``batch``, ``y_arr`` and ``y``.

        Returns:
            tuple: ``(pred_list, {'y_arr': batch.y_arr, 'y': batch.y})`` where
            ``pred_list`` holds one output of each linear layer, in order.
        """
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        batch.pred_list = [
            linear(graph_emb) for linear in self.graph_pred_linear_list
        ]
        pred, true = self._apply_index(batch)
        return pred, true