import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.layer import LayerConfig
from torch_geometric.graphgym.register import register_layer
from torch_scatter import scatter
[docs]
class GatedGCNLayer(pyg_nn.conv.MessagePassing):
"""
GatedGCN layer
Residual Gated Graph ConvNets
https://arxiv.org/pdf/1711.07553.pdf
"""
def __init__(self, in_dim, out_dim, dropout, residual, act='relu',
equivstable_pe=False, **kwargs):
super().__init__(**kwargs)
self.activation = register.act_dict[act]
self.A = pyg_nn.Linear(in_dim, out_dim, bias=True)
self.B = pyg_nn.Linear(in_dim, out_dim, bias=True)
self.C = pyg_nn.Linear(in_dim, out_dim, bias=True)
self.D = pyg_nn.Linear(in_dim, out_dim, bias=True)
self.E = pyg_nn.Linear(in_dim, out_dim, bias=True)
# Handling for Equivariant and Stable PE using LapPE
# ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
self.EquivStablePE = equivstable_pe
if self.EquivStablePE:
self.mlp_r_ij = nn.Sequential(
nn.Linear(1, out_dim),
self.activation(),
nn.Linear(out_dim, 1),
nn.Sigmoid())
self.bn_node_x = nn.BatchNorm1d(out_dim)
self.bn_edge_e = nn.BatchNorm1d(out_dim)
self.act_fn_x = self.activation()
self.act_fn_e = self.activation()
self.dropout = dropout
self.residual = residual
self.e = None
def forward(self, batch):
x, e, edge_index = batch.x, batch.edge_attr, batch.edge_index
"""
x : [n_nodes, in_dim]
e : [n_edges, in_dim]
edge_index : [2, n_edges]
"""
if self.residual:
x_in = x
e_in = e
Ax = self.A(x)
Bx = self.B(x)
Ce = self.C(e)
Dx = self.D(x)
Ex = self.E(x)
# Handling for Equivariant and Stable PE using LapPE
# ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
pe_LapPE = batch.pe_EquivStableLapPE if self.EquivStablePE else None
x, e = self.propagate(edge_index,
Bx=Bx, Dx=Dx, Ex=Ex, Ce=Ce,
e=e, Ax=Ax,
PE=pe_LapPE)
x = self.bn_node_x(x)
e = self.bn_edge_e(e)
x = self.act_fn_x(x)
e = self.act_fn_e(e)
x = F.dropout(x, self.dropout, training=self.training)
e = F.dropout(e, self.dropout, training=self.training)
if self.residual:
x = x_in + x
e = e_in + e
batch.x = x
batch.edge_attr = e
return batch
[docs]
def message(self, Dx_i, Ex_j, PE_i, PE_j, Ce):
"""
{}x_i : [n_edges, out_dim]
{}x_j : [n_edges, out_dim]
{}e : [n_edges, out_dim]
"""
e_ij = Dx_i + Ex_j + Ce
sigma_ij = torch.sigmoid(e_ij)
# Handling for Equivariant and Stable PE using LapPE
# ICLR 2022 https://openreview.net/pdf?id=e95i1IHcWj
if self.EquivStablePE:
r_ij = ((PE_i - PE_j) ** 2).sum(dim=-1, keepdim=True)
r_ij = self.mlp_r_ij(r_ij) # the MLP is 1 dim --> hidden_dim --> 1 dim
sigma_ij = sigma_ij * r_ij
self.e = e_ij
return sigma_ij
[docs]
def aggregate(self, sigma_ij, index, Bx_j, Bx):
"""
sigma_ij : [n_edges, out_dim] ; is the output from message() function
index : [n_edges]
{}x_j : [n_edges, out_dim]
"""
dim_size = Bx.shape[0] # or None ?? <--- Double check this
sum_sigma_x = sigma_ij * Bx_j
numerator_eta_xj = scatter(sum_sigma_x, index, 0, None, dim_size,
reduce='sum')
sum_sigma = sigma_ij
denominator_eta_xj = scatter(sum_sigma, index, 0, None, dim_size,
reduce='sum')
out = numerator_eta_xj / (denominator_eta_xj + 1e-6)
return out
[docs]
def update(self, aggr_out, Ax):
"""
aggr_out : [n_nodes, out_dim] ; is the output from aggregate() function after the aggregation
{}x : [n_nodes, out_dim]
"""
x = Ax + aggr_out
e_out = self.e
del self.e
return x, e_out
[docs]
@register_layer('gatedgcnconv')
class GatedGCNGraphGymLayer(nn.Module):
"""GatedGCN layer.
Residual Gated Graph ConvNets
https://arxiv.org/pdf/1711.07553.pdf
Parameters:
in_dim (int): Number of input features. Handled by GraphGym.
out_dim (int): Number of output features. Handled by GraphGym.
"""
def __init__(self, layer_config: LayerConfig, **kwargs):
super().__init__()
self.model = GatedGCNLayer(in_dim=layer_config.dim_in,
out_dim=layer_config.dim_out,
dropout=0., # Dropout is handled by GraphGym's `GeneralLayer` wrapper
residual=False, # Residual connections are handled by GraphGym's `GNNStackStage` wrapper
act=layer_config.act,
**kwargs)
def forward(self, batch):
return self.model(batch)