Source code for opengt.encoder.signnet_pos_encoder

"""
SignNet https://arxiv.org/abs/2202.13013
based on https://github.com/cptq/SignNet-BasisNet
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder
from torch_geometric.nn import GINConv
from torch_scatter import scatter


class MLP(nn.Module):
    """Multi-layer perceptron with optional normalization and residuals.

    Every hidden layer runs ``Linear -> activation -> [BatchNorm1d] ->
    [LayerNorm] -> [residual add] -> dropout``; the final layer is a bare
    ``Linear`` followed by a residual add when shapes allow.

    Args:
        in_channels: Size of the last input dimension.
        hidden_channels: Width of every hidden layer.
        out_channels: Size of the last output dimension.
        num_layers: Number of ``Linear`` layers. ``1`` builds a single
            linear map; any other value builds an input layer,
            ``num_layers - 2`` hidden layers and an output layer (so values
            below 2 still yield two layers).
        use_bn: Apply ``BatchNorm1d`` after each hidden activation. 2-D
            inputs ``(B, C)`` are normalized directly; 3-D inputs
            ``(A, L, C)`` are transposed so that ``C`` is normalized.
        use_ln: Apply ``LayerNorm`` after each hidden activation.
        dropout: Dropout probability applied after each hidden layer.
        activation: One of ``'relu'``, ``'elu'`` or ``'tanh'``.
        residual: Add the previous representation whenever its shape
            equals the current one.

    Raises:
        ValueError: If ``activation`` is not supported, or (in ``forward``)
            if batch norm is enabled and the input is neither 2-D nor 3-D.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 use_bn=False, use_ln=False, dropout=0.5, activation='relu',
                 residual=False):
        super().__init__()
        self.lins = nn.ModuleList()
        if use_bn:
            self.bns = nn.ModuleList()
        if use_ln:
            self.lns = nn.ModuleList()

        # Widths: a single in -> out map when num_layers == 1, otherwise
        # in -> hidden (-> hidden)* -> out with at least one hidden width.
        n_hidden = 0 if num_layers == 1 else max(num_layers - 1, 1)
        widths = [in_channels] + [hidden_channels] * n_hidden + [out_channels]
        last_depth = len(widths) - 2
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            self.lins.append(nn.Linear(fan_in, fan_out))
            if depth == last_depth:
                continue  # the output layer is never normalized
            if use_bn:
                self.bns.append(nn.BatchNorm1d(fan_out))
            if use_ln:
                self.lns.append(nn.LayerNorm(fan_out))

        for name, factory in (('relu', nn.ReLU), ('elu', nn.ELU),
                              ('tanh', nn.Tanh)):
            if activation == name:
                self.activation = factory()
                break
        else:
            raise ValueError('Invalid activation')

        self.use_bn = use_bn
        self.use_ln = use_ln
        self.dropout = dropout
        self.residual = residual

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        prev = x
        *hidden_layers, head = self.lins
        for idx, linear in enumerate(hidden_layers):
            x = self.activation(linear(x))
            if self.use_bn:
                if x.ndim not in (2, 3):
                    raise ValueError('invalid dimension of x')
                norm = self.bns[idx]
                # BatchNorm1d expects channels on dim 1, so 3-D inputs are
                # transposed to put the feature axis there and back again.
                x = norm(x) if x.ndim == 2 else norm(x.transpose(2, 1)).transpose(2, 1)
            if self.use_ln:
                x = self.lns[idx](x)
            if self.residual and prev.shape == x.shape:
                x = x + prev
            x = F.dropout(x, p=self.dropout, training=self.training)
            prev = x

        out = head(x)
        if self.residual and prev.shape == out.shape:
            out = out + prev
        return out


class GIN(nn.Module):
    """Stack of ``GINConv`` layers whose update networks are 2-layer MLPs.

    Widths go ``in_channels -> hidden_channels -> ... -> out_channels``.
    Between consecutive convolutions the representation is passed through
    dropout and, when ``use_bn`` is set, ``BatchNorm1d``.

    Args:
        in_channels: Input feature size.
        hidden_channels: Hidden feature size (also the inner MLP width).
        out_channels: Output feature size.
        n_layers: Number of GIN convolutions. Values below 2 still build the
            input and output convolutions (two layers in total).
        use_bn: Normalize between convolutions and inside every MLP.
        dropout: Dropout probability shared by the MLPs and the
            inter-layer dropout.
        activation: Activation name forwarded to every MLP.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, n_layers,
                 use_bn=True, dropout=0.5, activation='relu'):
        super().__init__()
        self.layers = nn.ModuleList()
        if use_bn:
            self.bns = nn.ModuleList()
        self.use_bn = use_bn

        n_hidden = max(n_layers - 2, 0)
        shapes = ([(in_channels, hidden_channels)]
                  + [(hidden_channels, hidden_channels)] * n_hidden
                  + [(hidden_channels, out_channels)])
        for depth, (fan_in, fan_out) in enumerate(shapes):
            update_net = MLP(fan_in, hidden_channels, fan_out, 2,
                             use_bn=use_bn, dropout=dropout,
                             activation=activation)
            self.layers.append(GINConv(update_net))
            # Each convolution after the first is preceded by a norm over the
            # previous convolution's hidden_channels-wide output.
            if depth > 0 and use_bn:
                self.bns.append(nn.BatchNorm1d(hidden_channels))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, edge_index):
        """Run all convolutions; ``x`` may be 2-D or 3-D (features last)."""
        first, *rest = self.layers
        h = first(x, edge_index)
        for bn_idx, conv in enumerate(rest):
            h = self.dropout(h)
            if self.use_bn:
                norm = self.bns[bn_idx]
                if h.ndim == 2:
                    h = norm(h)
                elif h.ndim == 3:
                    # BatchNorm1d wants channels on dim 1.
                    h = norm(h.transpose(2, 1)).transpose(2, 1)
                else:
                    raise ValueError('invalid x dim')
            h = conv(h, edge_index)
        return h


class GINDeepSigns(nn.Module):
    """Sign-invariant network with an MLP readout over all eigenvectors.

    Computes ``f(v1, ..., vk) = rho(enc(v1) + enc(-v1), ..., enc(vk) + enc(-vk))``
    where ``enc`` is a GIN applied to every eigenvector and ``rho`` is an MLP
    over the concatenation of the ``k`` sign-invariant encodings.

    Args:
        in_channels: Features per eigenvector entry.
        hidden_channels: Hidden width of both ``enc`` and ``rho``.
        out_channels: Output width of ``enc`` per eigenvector.
        num_layers: Number of GIN layers in ``enc``.
        k: Number of eigenvectors; ``rho`` consumes ``k * out_channels``
            features per node.
        dim_pe: Output size of ``rho``.
        rho_num_layers: Number of linear layers in ``rho``.
        use_bn: Use batch norm in ``enc`` and ``rho``.
        use_ln: Accepted for signature compatibility; currently unused.
        dropout: Dropout probability for ``enc`` and ``rho``.
        activation: Activation name for ``enc`` and ``rho``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 k, dim_pe, rho_num_layers, use_bn=False, use_ln=False,
                 dropout=0.5, activation='relu'):
        super().__init__()
        shared = dict(use_bn=use_bn, dropout=dropout, activation=activation)
        self.enc = GIN(in_channels, hidden_channels, out_channels, num_layers,
                       **shared)
        self.rho = MLP(out_channels * k, hidden_channels, dim_pe,
                       rho_num_layers, **shared)

    def forward(self, x, edge_index, batch_index):
        """Encode ``x`` (N x K x In) into an N x dim_pe positional encoding.

        ``batch_index`` is accepted for interface parity with
        ``MaskedGINDeepSigns`` and is not used.
        """
        num_nodes = x.size(0)
        per_freq = x.transpose(0, 1)  # N x K x In -> K x N x In
        encoded = self.enc(per_freq, edge_index) + self.enc(-per_freq, edge_index)
        # K x N x Out -> N x (K * Out)
        flat = encoded.transpose(0, 1).reshape(num_nodes, -1)
        # N x dim_pe (in the original codebase dim_pe is always K)
        return self.rho(flat)


class MaskedGINDeepSigns(nn.Module):
    """Sign-invariant network with a masked sum-pooling (DeepSet) readout.

    For a node in a graph with ``n`` nodes computes
    ``f(v1, ..., vK) = rho(sum_{i < n} [enc(vi) + enc(-vi)])``: encodings of
    eigenvector slots ``i >= n`` are zeroed before the sum (a graph with ``n``
    nodes has at most ``n`` eigenvectors, so those slots presumably hold
    padding).

    Args:
        in_channels: Features per eigenvector entry.
        hidden_channels: Hidden width of both ``enc`` and ``rho``.
        out_channels: Output width of ``enc`` (and input width of ``rho``).
        num_layers: Number of GIN layers in ``enc``.
        dim_pe: Output size of ``rho``.
        rho_num_layers: Number of linear layers in ``rho``.
        use_bn: Use batch norm in ``enc`` and ``rho``.
        use_ln: Accepted for signature compatibility; currently unused.
        dropout: Dropout probability for ``enc`` and ``rho``.
        activation: Activation name for ``enc`` and ``rho``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dim_pe, rho_num_layers, use_bn=False, use_ln=False,
                 dropout=0.5, activation='relu'):
        super().__init__()
        self.enc = GIN(in_channels, hidden_channels, out_channels, num_layers,
                       use_bn=use_bn, dropout=dropout, activation=activation)
        self.rho = MLP(out_channels, hidden_channels, dim_pe, rho_num_layers,
                       use_bn=use_bn, dropout=dropout, activation=activation)

    def batched_n_nodes(self, batch_index):
        """Return, for every node, the number of nodes in its graph.

        Args:
            batch_index: 1-D non-negative integer tensor mapping each node to
                its graph id. It does not need to be sorted.

        Returns:
            1-D tensor with one entry per node.
        """
        # Vectorized count + gather; replaces a per-graph Python loop that
        # also silently assumed batch_index was sorted.
        n_nodes = torch.bincount(batch_index)  # nodes per graph id
        return n_nodes[batch_index]

    def forward(self, x, edge_index, batch_index):
        """Encode ``x`` (N x K x In) into an N x dim_pe positional encoding."""
        num_freqs = x.size(1)  # K: max. number of eigenvectors / frequencies
        x = x.transpose(0, 1)  # N x K x In -> K x N x In
        x = self.enc(x, edge_index) + self.enc(-x, edge_index)  # K x N x Out
        x = x.transpose(0, 1)  # K x N x Out -> N x K x Out

        # Slot k of a node is valid only if k < size of the node's graph.
        # Built by broadcasting directly on the device (no per-node loop).
        graph_sizes = self.batched_n_nodes(batch_index)  # N
        freq_ids = torch.arange(num_freqs, device=batch_index.device)
        valid = freq_ids.unsqueeze(0) < graph_sizes.unsqueeze(1)  # N x K
        x = x.masked_fill(~valid.unsqueeze(-1), 0.0)

        x = x.sum(dim=1)  # (sum over K) -> N x Out
        # N x Out -> N x dim_pe (in the original codebase dim_pe is always K)
        return self.rho(x)


@register_node_encoder('SignNet')
class SignNetNodeEncoder(torch.nn.Module):
    """SignNet positional-embedding node encoder.

    Paper: https://arxiv.org/abs/2202.13013
    Reference code: https://github.com/cptq/SignNet-BasisNet

    Uses the precomputed Laplacian eigendecomposition, but instead of
    eigenvector sign flipping + DeepSet/Transformer it computes the PE as::

        SignNetPE(v_1, ..., v_k) = rho([phi(v_i) + phi(-v_i)] for i = 1..k)

    where phi is a GIN network applied to the k first non-trivial
    eigenvectors, and rho is an MLP if k is a constant
    (``posenc_SignNet.model == 'MLP'``) or a DeepSet with sum pooling if all
    eigenvectors are used (``posenc_SignNet.model == 'DeepSet'``).

    The SignNet PE of size ``dim_pe`` is appended to each node feature
    vector. If ``expand_x`` is True, the original node features are first
    linearly projected to ``dim_emb - dim_pe`` and then concatenated with
    the PE.

    Args:
        dim_emb (int): Size of the final node embedding.
        expand_x (bool): Project node features ``x`` from ``dim_in`` to
            ``dim_emb - dim_pe`` before concatenation.

    Raises:
        ValueError: If ``posenc_SignNet.model`` is neither 'MLP' nor
            'DeepSet', if ``posenc_SignNet.post_layers`` is below 1, or if
            ``dim_emb - dim_pe`` is below 1.
    """

    def __init__(self, dim_emb, expand_x=True):
        super().__init__()
        dim_in = cfg.share.dim_in  # Expected original input node features dim

        pecfg = cfg.posenc_SignNet
        dim_pe = pecfg.dim_pe  # Size of PE embedding
        model_type = pecfg.model  # Encoder NN model type for SignNet
        if model_type not in ['MLP', 'DeepSet']:
            raise ValueError(f"Unexpected SignNet model {model_type}")
        self.model_type = model_type
        sign_inv_layers = pecfg.layers  # Num. layers in phi GNN part
        rho_layers = pecfg.post_layers  # Num. layers in rho MLP/DeepSet
        if rho_layers < 1:
            raise ValueError("Num layers in rho model has to be positive.")
        max_freqs = pecfg.eigen.max_freqs  # Num. eigenvectors (frequencies)
        self.pass_as_var = pecfg.pass_as_var  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 1:
            raise ValueError(f"SignNet PE size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        if expand_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x

        # Sign invariant neural network.
        if self.model_type == 'MLP':
            self.sign_inv_net = GINDeepSigns(
                in_channels=1,
                hidden_channels=pecfg.phi_hidden_dim,
                out_channels=pecfg.phi_out_dim,
                num_layers=sign_inv_layers,
                k=max_freqs,
                dim_pe=dim_pe,
                rho_num_layers=rho_layers,
                use_bn=True,
                dropout=0.0,
                activation='relu'
            )
        elif self.model_type == 'DeepSet':
            self.sign_inv_net = MaskedGINDeepSigns(
                in_channels=1,
                hidden_channels=pecfg.phi_hidden_dim,
                out_channels=pecfg.phi_out_dim,
                num_layers=sign_inv_layers,
                dim_pe=dim_pe,
                rho_num_layers=rho_layers,
                use_bn=True,
                dropout=0.0,
                activation='relu'
            )
        else:
            raise ValueError(f"Unexpected model {self.model_type}")

    def forward(self, batch):
        """Append the SignNet PE to ``batch.x`` and return ``batch``.

        Reads ``batch.eigvecs_sn``, ``batch.edge_index``, ``batch.batch`` and
        ``batch.x``; writes ``batch.x`` and, if ``pass_as_var`` is set,
        ``batch.pe_SignNet``.

        Raises:
            ValueError: If the precomputed eigenvalues/eigenvectors are missing.
        """
        if not (hasattr(batch, 'eigvals_sn') and hasattr(batch, 'eigvecs_sn')):
            raise ValueError("Precomputed eigen values and vectors are "
                             f"required for {self.__class__.__name__}; "
                             "set config 'posenc_SignNet.enable' to True")

        # (Num nodes) x (Num eigenvectors) x 1; eigenvalues are not used.
        pos_enc = batch.eigvecs_sn.unsqueeze(-1)
        # Zero NaN entries out of place: unsqueeze() returns a view, so an
        # in-place write here would also modify batch.eigvecs_sn itself.
        pos_enc = pos_enc.masked_fill(torch.isnan(pos_enc), 0.0)

        # SignNet: (Num nodes) x dim_pe
        pos_enc = self.sign_inv_net(pos_enc, batch.edge_index, batch.batch)

        # Expand node features if needed
        h = self.linear_x(batch.x) if self.expand_x else batch.x
        # Concatenate final PEs to input embedding
        batch.x = torch.cat((h, pos_enc), 1)
        # Keep PE also separate in a variable (e.g. for skip connections to input)
        if self.pass_as_var:
            batch.pe_SignNet = pos_enc
        return batch