Source code for nfflr.nn.layers.norm
import dgl
import torch
from torch import nn
from typing import Literal, Optional
[docs]
class Norm(nn.Module):
    """Feature normalization layer with a selectable normalization scheme.

    Wraps one of ``nn.BatchNorm1d``, ``nn.LayerNorm``, the graph-aware
    ``InstanceNorm`` or ``nn.Identity`` behind a single module.

    Parameters
    ----------
    num_features : int
        Feature size handed to ``nn.BatchNorm1d`` / ``nn.LayerNorm``.
        Not used for ``"instancenorm"`` or ``None``.
    norm_type : {"batchnorm", "layernorm", "instancenorm"} or None
        Which normalization to apply; ``None`` selects ``nn.Identity``.
    mode : {"node", "edge"}
        Forwarded to ``InstanceNorm``; only stored for the other norm types.

    Raises
    ------
    ValueError
        If ``norm_type`` is not one of the supported values.
    """

    def __init__(
        self,
        num_features: int,
        norm_type: Optional[
            Literal["batchnorm", "layernorm", "instancenorm"]
        ] = "layernorm",
        mode: Literal["node", "edge"] = "node",
    ):
        super().__init__()
        self.norm_type = norm_type
        self.mode = mode

        if norm_type == "batchnorm":
            self.norm = nn.BatchNorm1d(num_features)
        elif norm_type == "layernorm":
            self.norm = nn.LayerNorm(num_features)
        elif norm_type == "instancenorm":
            self.norm = InstanceNorm(mode=mode)
            # InstanceNorm needs the graph as well as the features, so this
            # instance swaps in the two-argument forward.
            self.forward = self._forward_graph_instancenorm
        elif norm_type is None:
            self.norm = nn.Identity()
        else:
            # Fail at construction instead of leaving ``self.norm`` unset and
            # hitting an AttributeError on the first forward call.
            raise ValueError(f"Norm(norm_type={norm_type!r}) not supported")

    def forward(self, x: torch.Tensor):
        """Apply the wrapped normalization to ``x``.

        With ``norm_type="instancenorm"`` this method is replaced on the
        instance by a variant called as ``(g, x)``.
        """
        return self.norm(x)

    def _forward_graph_instancenorm(self, g: "dgl.DGLGraph", x: torch.Tensor):
        """Apply the wrapped ``InstanceNorm`` to features ``x`` on graph ``g``."""
        return self.norm(g, x)
[docs]
class InstanceNorm(nn.Module):
    """Channel-wise instance normalization over the graphs of a DGL batch.

    For each graph, features are centred by the per-graph mean and divided by
    the per-graph standard deviation. Both statistics come from the DGL mean
    readout and are broadcast back to the individual nodes or edges.

    Parameters
    ----------
    mode : {"node", "edge"}
        Whether to normalize node features or edge features.
    eps : float
        Constant added to the variance before the square root.

    Raises
    ------
    NotImplementedError
        If ``mode`` is neither ``"node"`` nor ``"edge"``.
    """

    def __init__(self, mode: Literal["node", "edge"] = "node", eps: float = 1e-6):
        super().__init__()
        self.mode = mode
        self.eps = eps

        if self.mode == "node":
            self.readout = dgl.readout.mean_nodes
            self.broadcast = dgl.broadcast_nodes
        elif self.mode == "edge":
            self.readout = dgl.readout.mean_edges
            self.broadcast = dgl.broadcast_edges
        else:
            raise NotImplementedError(f"InstanceNorm(mode='{mode}') not supported")

    def data(self, g):
        """Return the feature storage of ``g`` that matches ``self.mode``.

        A regular method rather than a lambda created in ``__init__``:
        locally defined lambdas cannot be pickled, and storing one on the
        module prevents the module from being pickled.
        """
        return g.ndata if self.mode == "node" else g.edata

    def forward(self, g: dgl.DGLGraph, x: torch.Tensor):
        """Normalize ``x`` separately for every graph in ``g``.

        Parameters
        ----------
        g : dgl.DGLGraph
            Graph (or batch of graphs) that the features belong to.
        x : torch.Tensor
            Node or edge features, matching ``self.mode``; presumably one row
            per node/edge with channels last -- TODO confirm against callers.

        Returns
        -------
        torch.Tensor
            ``(x - mean) / sqrt(var + eps)`` using per-graph statistics.
        """
        # work on a local view so the scratch fields are not written to the caller's graph
        g = g.local_var()

        # compute per-instance channel-wise mean
        self.data(g)["_x"] = x
        mu = self.broadcast(g, self.readout(g, "_x"))
        centered = x - mu

        # compute per-instance channel-wise variance (mean of squared deviations)
        self.data(g)["_sqdev"] = centered**2
        var = self.readout(g, "_sqdev")
        std = torch.sqrt(var + self.eps)

        # apply normalization
        return centered / self.broadcast(g, std)