# Source code for nfflr.nn.layers.basis
import torch
import numpy as np
from typing import Optional, Literal
[docs]
class ChebyshevExpansion(torch.nn.Module):
    """Expand features in the (-1, 1) interval with a Chebyshev basis."""

    def __init__(self, basis_size: int):
        """Set up the polynomial orders ``1, ..., basis_size``.

        Args:
            basis_size: number of Chebyshev polynomials to evaluate.
        """
        super().__init__()
        # Registered as a buffer so the orders follow the module across
        # devices (`.to(...)`, `.cuda()`); non-persistent so the layout of
        # `state_dict()` is the same as when `n` was a plain attribute.
        self.register_buffer(
            "n", torch.arange(1, 1 + basis_size), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Evaluate the Chebyshev basis via its trigonometric definition.

        For :math:`|x| \leq 1`::

            T_n(cos(theta)) = cos(n * theta)

        Args:
            x: input values. A new axis is inserted at position 1 and
                broadcast against the polynomial orders, so a 1-D input of
                shape ``(N,)`` yields an output of shape ``(N, basis_size)``.
                Values are passed to ``torch.acos`` without clamping.

        Returns:
            ``cos(n * acos(x))`` for each order ``n`` in ``1..basis_size``.
        """
        theta = torch.acos(x).unsqueeze(1)
        return torch.cos(self.n * theta)
[docs]
class RBFExpansion(torch.nn.Module):
    """Expand interatomic distances with radial basis functions."""

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register the RBF centers and derive the exponent scale.

        Args:
            vmin: position of the first RBF center.
            vmax: position of the last RBF center.
            bins: number of evenly spaced centers between `vmin` and `vmax`.
            lengthscale: optional explicit length scale. When ``None`` the
                mean spacing between adjacent centers is used.
        """
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            # SchNet-style: set the length scale relative to the granularity
            # of the RBF expansion. Computed with torch (not NumPy) so it
            # works regardless of the device `centers` was created on, and
            # stored as plain Python floats instead of NumPy scalars.
            spacing = torch.diff(self.centers).mean()
            self.lengthscale = spacing.item()
            # NOTE(review): this branch uses 1 / lengthscale while the
            # explicit branch below uses 1 / lengthscale**2. Kept unchanged
            # to preserve numerical behavior -- confirm whether intended.
            self.gamma = spacing.reciprocal().item()
        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale**2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply the RBF expansion to an interatomic distance tensor.

        Args:
            distance: distances to expand. A new axis is inserted at
                position 1 and broadcast against the centers, so a 1-D input
                of shape ``(N,)`` yields an output of shape ``(N, bins)``.

        Returns:
            ``exp(-gamma * (distance - center) ** 2)`` for every center.
        """
        offsets = distance.unsqueeze(1) - self.centers
        return torch.exp(-self.gamma * offsets**2)