EdgeGatedGraphConv

class nfflr.nn.EdgeGatedGraphConv(input_features: int, output_features: int, residual: bool = True, norm: Literal['layernorm', 'batchnorm', 'instancenorm'] = 'layernorm', skip_edgenorm: bool = False)
References:
[1] Xavier Bresson and Thomas Laurent. Residual gated graph convnets. 2018. arXiv:1711.07553.
[2] Vijay Prakash Dwivedi, Chaitanya K. Joshi, Anh Tuan Luu, Thomas Laurent, Yoshua Bengio, and Xavier Bresson. Benchmarking graph neural networks. Journal of Machine Learning Research, 24(43):1–48, 2023. URL: http://jmlr.org/papers/v24/22-0567.html.
[3] Tian Xie and Jeffrey C. Grossman. Crystal graph convolutional neural networks for an accurate and interpretable prediction of material properties. Physical Review Letters, April 2018. doi:10.1103/physrevlett.120.145301.

Edge-gated graph convolution.

See Bresson and Laurent [1] for the original formulation, and Dwivedi et al. [2] for a detailed discussion.

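\[x_i^{l+1} = \mathrm{SiLU}\left( U x_i^l + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot V x_j^l \right)\]

where \(x_i^l\) is the feature vector of node \(i\) at layer \(l\), \(\mathcal{N}(i)\) is the set of neighbors of \(i\), \(U\) and \(V\) are learned weight matrices, \(\eta_{ij}\) are the edge gates, and \(\odot\) denotes elementwise multiplication.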
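The node update above can be sketched for a single node in plain Python. This is a minimal illustration, not nfflr's implementation (which operates on batched tensors over a DGL graph). It assumes the gates are normalized sigmoids in the style of the GatedGCN layer discussed by Dwivedi et al., \(\eta_{ij} = \sigma(\hat e_{ij}) / (\sum_{j'} \sigma(\hat e_{ij'}) + \epsilon)\), computed per feature channel from gate logits \(\hat e_{ij}\). The names `node_update`, `gate_logits`, and `eps` are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product W @ x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def silu(z: float) -> float:
    # SiLU(z) = z * sigmoid(z)
    return z * sigmoid(z)


def node_update(
    x_i: Vector,
    neighbor_feats: List[Vector],
    gate_logits: List[Vector],
    U: Matrix,
    V: Matrix,
    eps: float = 1e-6,
) -> Vector:
    """Compute SiLU(U x_i + sum_j eta_ij * V x_j) for one node i.

    Assumed gate form (hypothetical, GatedGCN-style):
    eta_ij = sigmoid(e_ij) / (sum_j' sigmoid(e_ij') + eps), normalized
    separately for each feature channel over the neighbors of i.
    gate_logits[j] holds the per-channel logits e_ij for neighbor j.
    """
    out_dim = len(U)
    gates = [[sigmoid(z) for z in logits] for logits in gate_logits]
    # Per-channel normalizer summed over all neighbors j of node i
    denom = [sum(g[k] for g in gates) + eps for k in range(out_dim)]
    out = matvec(U, x_i)  # self term U x_i
    for x_j, g in zip(neighbor_feats, gates):
        msg = matvec(V, x_j)  # neighbor term V x_j
        for k in range(out_dim):
            out[k] += (g[k] / denom[k]) * msg[k]  # elementwise gating
    return [silu(z) for z in out]


# Two neighbors with equal gate logits each get (almost exactly) half the weight.
I2 = [[1.0, 0.0], [0.0, 1.0]]
x_new = node_update([0.0, 0.0], [[1.0, 0.0], [3.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], I2, I2)
print(x_new)  # approximately [1.7616, 0.0], i.e. [SiLU(2.0), SiLU(0.0)]
```

With normalized gates the neighbor sum behaves like a learned, channel-wise weighted average; raising one neighbor's gate logit shifts the result toward that neighbor's \(V x_j\).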

This layer is similar to the interaction in CGCNN [3], but edge features enter only the soft-attention (edge-gating) function, and the primary node update is the affine map \(W \, \mathrm{cat}(u, v) + b\) of the two endpoint node features \(u\) and \(v\).
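The difference from CGCNN can be made concrete with a hypothetical per-edge sketch using scalar features (so every weight matrix collapses to a few numbers). In the CGCNN-style message, the concatenation \(z = \mathrm{cat}(x_i, x_j, e_{ij})\) feeds both a sigmoid filter and a nonlinear core (softplus in this sketch). In the edge-gated message, \(e_{ij}\) feeds only the gate logit, while the core is \(W \, \mathrm{cat}(x_i, x_j) + b\). The function names and weight values are illustrative assumptions; the gate here is unnormalized, before any neighborhood aggregation.

```python
import math
from typing import Sequence, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def softplus(z: float) -> float:
    return math.log1p(math.exp(z))


def affine(weights: Sequence[float], bias: float, inputs: Sequence[float]) -> float:
    """Scalar affine map: sum_k weights[k] * inputs[k] + bias."""
    return sum(w * v for w, v in zip(weights, inputs)) + bias


def cgcnn_message(x_i: float, x_j: float, e_ij: float,
                  w_f: Sequence[float], b_f: float,
                  w_s: Sequence[float], b_s: float) -> Tuple[float, float]:
    """CGCNN-style: z = cat(x_i, x_j, e_ij) feeds both the gate and the core."""
    z = (x_i, x_j, e_ij)
    gate = sigmoid(affine(w_f, b_f, z))
    core = softplus(affine(w_s, b_s, z))
    return gate, core  # per-edge message is gate * core


def edge_gated_message(x_i: float, x_j: float, e_ij: float,
                       w_g: Sequence[float], b_g: float,
                       w: Sequence[float], b: float) -> Tuple[float, float]:
    """Edge-gated: e_ij feeds only the gate; the core is W cat(x_i, x_j) + b."""
    gate = sigmoid(affine(w_g, b_g, (x_i, x_j, e_ij)))  # unnormalized gate
    core = affine(w, b, (x_i, x_j))  # edge feature deliberately absent
    return gate, core


# Perturb only the edge feature and compare gate and core terms.
g0, c0 = edge_gated_message(1.0, 2.0, 0.0, (0.1, 0.2, 0.5), 0.0, (1.0, 2.0), 0.5)
g1, c1 = edge_gated_message(1.0, 2.0, 3.0, (0.1, 0.2, 0.5), 0.0, (1.0, 2.0), 0.5)
print(c0 == c1, g0 == g1)  # True False
```

Changing \(e_{ij}\) moves only the edge-gated gate, whereas in the CGCNN-style message it also changes the core term.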

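Parameters:
input_features : int

dimension of the input features

output_features : int

dimension of the output features

residual : bool, default=True

add a skip connection for both node and edge features

norm : {"layernorm", "batchnorm", "instancenorm"}, default="layernorm"

type of normalization layer applied to the node and edge features

skip_edgenorm : bool, default=False

omit normalization of edge features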
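The effect of `residual` and `skip_edgenorm` can be pictured with a hypothetical sketch of the step that follows the convolution for one node and one edge: each pre-activation output is normalized, passed through SiLU, and, when `residual=True`, added to the layer input. The ordering (normalize, activate, then add the skip connection) and the use of plain layer normalization in place of the `norm` option are assumptions for illustration, not a description of nfflr's source; `finish_layer` is a hypothetical name. Because the skip connection adds input and output elementwise, this sketch requires `input_features == output_features` when `residual=True`.

```python
import math
from typing import List, Tuple

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Zero-mean, unit-variance normalization of one vector (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))


def finish_layer(
    node_in: Vector,
    node_pre: Vector,
    edge_in: Vector,
    edge_pre: Vector,
    residual: bool = True,
    skip_edgenorm: bool = False,
) -> Tuple[Vector, Vector]:
    """Hypothetical post-convolution step for one node and one edge.

    node_pre and edge_pre are the convolution's pre-activation outputs.
    """
    x = [silu(v) for v in layer_norm(node_pre)]
    # skip_edgenorm=True passes edge features to SiLU without normalization
    y = [silu(v) for v in (edge_pre if skip_edgenorm else layer_norm(edge_pre))]
    if residual:
        # The skip connection adds inputs elementwise, so widths must match
        if len(node_in) != len(x) or len(edge_in) != len(y):
            raise ValueError("residual=True needs input_features == output_features")
        x = [a + b for a, b in zip(node_in, x)]
        y = [a + b for a, b in zip(edge_in, y)]
    return x, y


x_out, y_out = finish_layer([1.0, 1.0], [3.0, 3.0], [0.0, 0.0], [2.0, 2.0])
print(x_out, y_out)  # [1.0, 1.0] [0.0, 0.0]
```

Here the constant pre-activations normalize to zero, so the node output equals its input (pure skip connection) and the edge output is zero; with `skip_edgenorm=True` the edge output would instead be `SiLU(2.0)` per channel plus the edge input.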

Methods

forward(g, node_feats, edge_feats)

Edge-gated graph convolution.

forward(g: DGLGraph, node_feats: Tensor, edge_feats: Tensor) → tuple[Tensor, Tensor]

Apply the edge-gated graph convolution to g, returning updated node and edge features.

Parameters:
g : dgl.DGLGraph

input graph

node_feats : torch.Tensor

input node features

edge_feats : torch.Tensor

input edge features

Returns:
node_features : torch.Tensor

updated node features

edge_features : torch.Tensor

updated edge features