Source code for nfflr.nn.layers.conv

from typing import Literal

import torch
from torch.nn import functional as F

import dgl
import dgl.function as fn

# import nfflr
from nfflr.nn import Norm


class EdgeGatedGraphConv(torch.nn.Module):
    r"""Edge gated graph convolution.

    See `Bresson and Laurent <https://arxiv.org/abs/1711.07553>`_
    :footcite:p:`bresson2018residual` for reference, and refer to
    `Dwivedi et al. <https://www.jmlr.org/papers/v24/22-0567.html>`_
    :footcite:p:`dwivedi2022` for detailed discussion.

    .. math::

        x_i^{l+1} = \mathrm{SiLU}\left( U x_i^l
            + \sum_{j \in \mathcal{N}(i)} \eta_{ij} \odot V x_j^l \right)

    This is similar to the interaction from
    `CGCNN <https://dx.doi.org/10.1103/physrevlett.120.145301>`_
    :footcite:p:`cgcnn`, but the edge features only enter the soft
    attention / edge gating function; the primary node update function
    is :math:`W \, \mathrm{cat}(u, v) + b`.

    A minimal usage sketch appears at the end of this module.

    .. footbibliography::

    Parameters
    ----------
    input_features : int
    output_features : int
    residual : bool, default=True
        add skip connection for both node and edge features
    norm : {"layernorm", "batchnorm", "instancenorm"}, optional
    skip_edgenorm : bool, default=False
        omit normalization of edge features
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        residual: bool = True,
        norm: Literal["layernorm", "batchnorm", "instancenorm"] = "layernorm",
        skip_edgenorm: bool = False,
    ):
        """Initialize parameters for ALIGNN update."""
        super().__init__()
        self.skip_edgenorm = skip_edgenorm
        self.residual = residual

        # CGCNN-Conv operates on augmented edge features
        # z_ij = cat(v_i, v_j, u_ij)
        # m_ij = σ(z_ij W_f + b_f) ⊙ g_s(z_ij W_s + b_s)
        # coalesce parameters for W_f and W_s,
        # but split them up along the feature dimension
        self.src_gate = torch.nn.Linear(input_features, output_features)
        self.dst_gate = torch.nn.Linear(input_features, output_features)
        self.edge_gate = torch.nn.Linear(input_features, output_features)

        if self.skip_edgenorm:
            self.norm_edges = torch.nn.Identity()
        else:
            self.norm_edges = Norm(output_features, norm, mode="edge")

        self.src_update = torch.nn.Linear(input_features, output_features)
        self.dst_update = torch.nn.Linear(input_features, output_features)
        self.norm_nodes = Norm(output_features, norm)
    def forward(
        self,
        g: dgl.DGLGraph,
        node_feats: torch.Tensor,
        edge_feats: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Edge-gated graph convolution.

        Parameters
        ----------
        g : dgl.DGLGraph
            input graph
        node_feats : torch.Tensor
            input node features
        edge_feats : torch.Tensor
            input edge features

        Returns
        -------
        node_features : torch.Tensor
        edge_features : torch.Tensor
        """
        g = g.local_var()

        # instead of concatenating (u || v || e) and applying one weight matrix,
        # split the weight matrix into three, apply, then sum
        # see https://docs.dgl.ai/guide/message-efficient.html
        # but split them on feature dimensions to update u, v, e separately
        # m = BatchNorm(Linear(cat(u, v, e)))

        # compute edge updates, equivalent to:
        # Softplus(Linear(u || v || e))
        g.ndata["e_src"] = self.src_gate(node_feats)
        g.ndata["e_dst"] = self.dst_gate(node_feats)
        g.apply_edges(fn.u_add_v("e_src", "e_dst", "e_nodes"))
        m = g.edata.pop("e_nodes") + self.edge_gate(edge_feats)

        # if the edge attributes carry a cutoff function value,
        # multiply the edge gate values by the cutoff value
        cutoff_value = g.edata.get("cutoff_value")
        if cutoff_value is not None:
            g.edata["sigma"] = torch.sigmoid(m) * cutoff_value.unsqueeze(1)
        else:
            g.edata["sigma"] = torch.sigmoid(m)

        g.ndata["Bh"] = self.dst_update(node_feats)
        g.update_all(fn.u_mul_e("Bh", "sigma", "m"), fn.sum("m", "sum_sigma_h"))
        g.update_all(fn.copy_e("sigma", "m"), fn.sum("m", "sum_sigma"))
        g.ndata["h"] = g.ndata["sum_sigma_h"] / (g.ndata["sum_sigma"] + 1e-6)
        x = self.src_update(node_feats) + g.ndata.pop("h")

        # the softmax version seems to perform slightly worse
        # than the sigmoid-gated version
        # compute node updates: Linear(u) + edge_gates ⊙ Linear(v)
        # g.edata["gate"] = edge_softmax(g, y)
        # g.ndata["h_dst"] = self.dst_update(node_feats)
        # g.update_all(fn.u_mul_e("h_dst", "gate", "m"), fn.sum("m", "h"))
        # x = self.src_update(node_feats) + g.ndata.pop("h")

        # node and edge updates
        x = F.silu(self.norm_nodes(x))
        y = F.silu(self.norm_edges(m))

        if self.residual:
            x = node_feats + x
            y = edge_feats + y

        return x, y
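

# Minimal usage sketch, assuming a small random DGL graph; the sizes, the
# random features, and the per-edge "cutoff_value" weights below are
# illustrative assumptions, not part of the library API. When
# g.edata["cutoff_value"] is present, forward() scales the sigmoid edge
# gates by it.
if __name__ == "__main__":
    num_nodes, num_edges, width = 4, 12, 16

    # random directed graph; duplicate edges and self-loops are fine
    # for a shape-checking smoke test
    src = torch.randint(0, num_nodes, (num_edges,))
    dst = torch.randint(0, num_nodes, (num_edges,))
    g = dgl.graph((src, dst), num_nodes=num_nodes)

    node_feats = torch.randn(num_nodes, width)
    edge_feats = torch.randn(num_edges, width)

    # optional smooth cutoff weights in [0, 1], one scalar per edge
    g.edata["cutoff_value"] = torch.rand(num_edges)

    conv = EdgeGatedGraphConv(width, width, residual=True, norm="layernorm")
    x, y = conv(g, node_feats, edge_feats)
    print(x.shape, y.shape)  # torch.Size([4, 16]) torch.Size([12, 16])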