Source code for nfflr.nn.layers.common
from typing import Literal, Optional

import torch

from nfflr.nn.layers.norm import Norm
class FeedForward(torch.nn.Module):
    """Two-layer feedforward network with SiLU activation."""

    def __init__(
        self,
        d_in: int,
        d_hidden: Optional[int] = None,
        d_out: Optional[int] = None,
        norm: bool = False,
    ):
        """Initialize the feedforward block.

        `d_hidden` defaults to `4 * d_in` and `d_out` defaults to `d_in`.
        If `norm` is set, LayerNorm is applied after the hidden activation.
        """
        super().__init__()
        if d_hidden is None:
            d_hidden = 4 * d_in
        if d_out is None:
            d_out = d_in

        self.project_hidden = torch.nn.Linear(d_in, d_hidden)
        self.norm = torch.nn.LayerNorm(d_hidden) if norm else None
        self.project_out = torch.nn.Linear(d_hidden, d_out)
        self.reset_parameters()
    def forward(self, x: torch.Tensor):
        """Apply the block: Linear, SiLU, optional LayerNorm, Linear."""
        x = torch.nn.functional.silu(self.project_hidden(x))
        if self.norm is not None:
            x = self.norm(x)
        return self.project_out(x)
    def reset_parameters(self):
        """Reinitialize weights with Kaiming normal init and zero the biases."""
        # torch.nn.init has no dedicated gain for SiLU; the ReLU gain is a
        # close approximation.
        torch.nn.init.kaiming_normal_(self.project_hidden.weight, nonlinearity="relu")
        torch.nn.init.zeros_(self.project_hidden.bias)
        torch.nn.init.kaiming_normal_(self.project_out.weight, nonlinearity="relu")
        torch.nn.init.zeros_(self.project_out.bias)
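
# Usage sketch (illustrative only, not part of the module source; the feature
# size and batch size here are assumed):
#
#     ff = FeedForward(d_in=64, norm=True)
#     x = torch.randn(8, 64)
#     ff(x).shape  # torch.Size([8, 64]), since d_out defaults to d_in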
class MLPLayer(torch.nn.Module):
    """Multilayer perceptron layer helper: Linear, Norm, SiLU."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        norm: Literal["layernorm", "batchnorm"] = "layernorm",
    ):
        """Linear, normalization, SiLU block."""
        super().__init__()
        self.layer = torch.nn.Sequential(
            torch.nn.Linear(in_features, out_features),
            Norm(out_features, norm_type=norm),
            torch.nn.SiLU(),
        )
    def forward(self, x):
        """Apply the Linear, norm, SiLU stack."""
        return self.layer(x)
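
# Usage sketch (illustrative only, not part of the module source; feature
# sizes are assumed):
#
#     layer = MLPLayer(64, 128)  # norm defaults to "layernorm"
#     layer(torch.randn(8, 64)).shape  # torch.Size([8, 128])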