import torch
from torch_geometric.graphgym.register import (register_node_encoder,
register_edge_encoder)
"""
=== Description of the ogbg-code2 dataset ===
* Node Encoder code based on OGB's:
https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/utils.py
Node Encoder config parameters are set based on the OGB example:
https://github.com/snap-stanford/ogb/blob/master/examples/graphproppred/code2/main_pyg.py
where the following three node features are used:
1. node type
2. node attribute
3. node depth
nodetypes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'typeidx2type.csv.gz'))
nodeattributes_mapping = pd.read_csv(os.path.join(dataset.root, 'mapping', 'attridx2attr.csv.gz'))
num_nodetypes = len(nodetypes_mapping['type'])
num_nodeattributes = len(nodeattributes_mapping['attr'])
max_depth = 20
* Edge attributes are generated by `augment_edge` function dynamically:
edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1)
edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
"""
# Hard-coded embedding-table sizes for the ogbg-code2 node encoder. The module
# description above derives the first two from the dataset's mapping files;
# presumably these literals are those lengths precomputed -- verify against
# the dataset if it is ever updated.
num_nodetypes = 98  # size of the node-type vocabulary
num_nodeattributes = 10030  # size of the node-attribute vocabulary
max_depth = 20  # depth cap: ASTNodeEncoder treats deeper nodes as this depth
# Node encoder for ogbg-code2 AST graphs.
@register_node_encoder('ASTNode')
class ASTNodeEncoder(torch.nn.Module):
    """
    The Abstract Syntax Tree (AST) Node Encoder used for ogbg-code2 dataset.

    Each node embedding is the sum of three learned embeddings looked up
    from the node type, the node attribute, and the (capped) node depth.

    Parameters:
        emb_dim (int): Output node embedding dimension

    Input:
        batch.x (torch.Tensor): Default node feature. The first and second
            column represents node type and node attributes. Used as
            embedding indices, so it must be an integer tensor.
        batch.node_depth (torch.Tensor): The depth of the node in the AST.
            Flattened before lookup; values above ``max_depth`` are treated
            as ``max_depth``. The tensor itself is left unmodified.

    Output:
        batch.x (torch.Tensor): emb_dim-dimensional vector
    """

    def __init__(self, emb_dim):
        super().__init__()
        self.max_depth = max_depth

        self.type_encoder = torch.nn.Embedding(num_nodetypes, emb_dim)
        self.attribute_encoder = torch.nn.Embedding(num_nodeattributes,
                                                    emb_dim)
        # One extra row so that index ``max_depth`` itself is valid.
        self.depth_encoder = torch.nn.Embedding(self.max_depth + 1, emb_dim)

    def forward(self, batch):
        """Replace ``batch.x`` with the summed embeddings; return ``batch``."""
        x = batch.x
        # ``view`` aliases ``batch.node_depth``. Clamp out-of-place so the
        # caller's depth tensor is not silently overwritten (the previous
        # masked assignment wrote through the view).
        depth = torch.clamp(batch.node_depth.view(-1), max=self.max_depth)
        batch.x = self.type_encoder(x[:, 0]) + self.attribute_encoder(x[:, 1]) \
            + self.depth_encoder(depth)
        return batch
# Edge encoder for ogbg-code2 AST graphs.
@register_edge_encoder('ASTEdge')
class ASTEdgeEncoder(torch.nn.Module):
    """
    Edge encoder for Abstract Syntax Tree (AST) graphs of the ogbg-code2
    dataset.

    The two edge-attribute columns are each embedded through their own
    two-entry lookup table and the results are summed. The attributes are
    expected to be produced dynamically by the `augment_edge` function as:
        edge_attr[:,0]: whether it is AST edge (0) or next-token edge (1)
        edge_attr[:,1]: whether it is original direction (0) or inverse
            direction (1)

    Parameters:
        emb_dim (int): Output edge embedding dimension

    Input:
        batch.edge_attr (torch.Tensor): Default edge feature.

    Output:
        batch.edge_attr (torch.Tensor): emb_dim-dimensional vector
    """

    def __init__(self, emb_dim):
        super().__init__()
        # Two rows per table: each attribute column is binary.
        self.embedding_type = torch.nn.Embedding(2, emb_dim)
        self.embedding_direction = torch.nn.Embedding(2, emb_dim)

    def forward(self, batch):
        """Overwrite ``batch.edge_attr`` with its embedding; return ``batch``."""
        edge_attr = batch.edge_attr
        type_emb = self.embedding_type(edge_attr[:, 0])
        direction_emb = self.embedding_direction(edge_attr[:, 1])
        batch.edge_attr = type_emb + direction_emb
        return batch