"""Atomistic LIne Graph Neural Network.
A prototype crystal line graph network dgl implementation.
"""
from plum import dispatch
import logging
from dataclasses import dataclass
from typing import Tuple, Union, Optional, Literal, Callable
import dgl
import torch
from torch import nn
from dgl.nn import AvgPooling, SumPooling
import nfflr
from nfflr.nn import (
    RBFExpansion,
    MLPLayer,
    ALIGNNConv,
    EdgeGatedGraphConv,
    AttributeEmbedding,
    PeriodicRadiusGraph,
    XPLOR,
)
from nfflr.data.graph import compute_bond_cosines


@dataclass
class ALIGNNConfig:
    """Hyperparameter schema for nfflr.models.gnn.alignn."""

    transform: Callable = PeriodicRadiusGraph(cutoff=8.0)
    # cutoff: Optional[tuple[float]] = (7.5, 8.0)
    cutoff: torch.nn.Module = XPLOR(7.5, 8.0)
    alignn_layers: int = 4
    gcn_layers: int = 4
    norm: Literal["batchnorm", "layernorm"] = "batchnorm"
    atom_features: str = "cgcnn"
    edge_input_features: int = 80
    triplet_input_features: int = 40
    embedding_features: int = 64
    hidden_features: int = 256
    output_features: int = 1
    compute_forces: bool = False
    energy_units: Literal["eV", "eV/atom"] = "eV/atom"
    reference_energies: Optional[Literal["fixed", "trainable"]] = None
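
# Note: the default smoothing window XPLOR(7.5, 8.0) is chosen to match the 8.0 Å
# PeriodicRadiusGraph cutoff; an illustrative short-range variant might pair
# transform=PeriodicRadiusGraph(cutoff=6.0) with cutoff=XPLOR(5.5, 6.0).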


class ALIGNN(torch.nn.Module):
    """Atomistic Line graph network.

    Chain alternating gated graph convolution updates on crystal graph
    and atomistic line graph.
    """

    def __init__(self, config: ALIGNNConfig = ALIGNNConfig()):
        """Initialize class with number of input features, conv layers."""
        super().__init__()
        self.config = config
        self.transform = config.transform
        logging.debug(f"{config=}")
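
        # initial node features: learned species embedding or tabulated atom attributes (e.g. "cgcnn")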
        if config.atom_features == "embedding":
            self.atom_embedding = torch.nn.Embedding(108, config.hidden_features)
        else:
            self.atom_embedding = AttributeEmbedding(
                config.atom_features, d_model=config.hidden_features
            )
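
        # optional per-species reference energy offsets ("fixed" or "trainable")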
        if config.reference_energies is not None:
            self.reference_energy = nfflr.nn.AtomicReferenceEnergy(
                requires_grad=config.reference_energies == "trainable"
            )
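
        # bond features: radial basis expansion of bond lengths followed by MLP embedding layers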
        self.edge_embedding = torch.nn.Sequential(
            RBFExpansion(
                vmin=0,
                vmax=8.0,
                bins=config.edge_input_features,
            ),
            MLPLayer(
                config.edge_input_features, config.embedding_features, norm=config.norm
            ),
            MLPLayer(
                config.embedding_features, config.hidden_features, norm=config.norm
            ),
        )
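
        # triplet features: radial basis expansion of bond angle cosines on [-1, 1]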
        self.angle_embedding = torch.nn.Sequential(
            RBFExpansion(
                vmin=-1,
                vmax=1.0,
                bins=config.triplet_input_features,
            ),
            MLPLayer(
                config.triplet_input_features,
                config.embedding_features,
                norm=config.norm,
            ),
            MLPLayer(
                config.embedding_features, config.hidden_features, norm=config.norm
            ),
        )
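
        # ALIGNN blocks: coupled edge-gated convolutions on the crystal graph and its line graph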
        self.alignn_layers = torch.nn.ModuleList(
            [
                ALIGNNConv(
                    config.hidden_features, config.hidden_features, norm=config.norm
                )
                for idx in range(config.alignn_layers)
            ]
        )
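
        # additional edge-gated graph convolutions on the crystal graph only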
        self.gcn_layers = torch.nn.ModuleList(
            [
                EdgeGatedGraphConv(
                    config.hidden_features, config.hidden_features, norm=config.norm
                )
                for idx in range(config.gcn_layers)
            ]
        )
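
        # energy readout: average pooling for per-atom energies (eV/atom), sum pooling for total energy (eV)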
        if config.energy_units == "eV/atom":
            self.readout = AvgPooling()
        else:
            self.readout = SumPooling()

        self.fc = nn.Linear(config.hidden_features, config.output_features)
        self.reset_atomic_reference_energies()

    def reset_atomic_reference_energies(self, values: Optional[torch.Tensor] = None):
        """Reset (or set) the per-species reference energy parameters, if present."""
        if hasattr(self, "reference_energy"):
            self.reference_energy.reset_parameters(values=values)
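
    # plum's @dispatch selects among the forward overloads below:
    # raw structure -> nfflr.Atoms -> graph transform -> DGLGraph (+ line graph) forward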
    @dispatch
    def forward(self, x):
        logging.debug("converting input to nfflr.Atoms")
        return self.forward(nfflr.Atoms(x))

    @dispatch
    def forward(self, x: nfflr.Atoms):
        device = next(self.parameters()).device
        return self.forward(self.transform(x).to(device))

    @dispatch
    def forward(self, g: Union[Tuple[dgl.DGLGraph, dgl.DGLGraph], dgl.DGLGraph]):
        """ALIGNN: start with `atom_features`.

        x: atom features (g.ndata)
        y: bond features (g.edata and lg.ndata)
        z: angle features (lg.edata)
        """
        config = self.config

        if isinstance(g, dgl.DGLGraph):
            lg = None
        else:
            g, lg = g

        g = g.local_var()

        # to compute forces, take gradient wrt g.edata["r"]
        # need to add bond vectors to autograd graph
        if config.compute_forces:
            g.edata["r"].requires_grad_(True)

        # initial node features: atom feature network...
        atomic_number = g.ndata.pop("atomic_number").int()
        x = self.atom_embedding(atomic_number)

        # initial bond features
        bondlength = torch.norm(g.edata["r"], dim=1)
        y = self.edge_embedding(bondlength)

        if config.cutoff is not None:
            # save cutoff function value for application in EdgeGatedGraphConv
            g.edata["cutoff_value"] = self.config.cutoff(bondlength)

        # initial triplet features
        if len(self.alignn_layers) > 0:
            if lg is None:
                lg = g.line_graph(shared=True)
                lg.apply_edges(compute_bond_cosines)
            z = self.angle_embedding(lg.edata.pop("h"))

        # ALIGNN updates: update node, edge, triplet features
        for alignn_layer in self.alignn_layers:
            x, y, z = alignn_layer(g, lg, x, y, z)

        # gated GCN updates: update node, edge features
        for gcn_layer in self.gcn_layers:
            x, y = gcn_layer(g, x, y)

        # norm-activation-pool-classify
        h = self.readout(g, x)
        atomwise_output = self.fc(h)

        if hasattr(self, "reference_energy"):
            atomwise_output += self.reference_energy(atomic_number)

        output = torch.squeeze(atomwise_output)

        if config.compute_forces:
            forces, stress = nfflr.autograd_forces(
                output,
                g.edata["r"],
                g,
                energy_units=config.energy_units,
                compute_stress=True,
            )
            return dict(energy=output, forces=forces, stress=stress)

        return output
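

# Minimal usage sketch (illustrative; assumes `atoms` is an `nfflr.Atoms` structure):
#
#   model = ALIGNN(ALIGNNConfig(compute_forces=True))
#   prediction = model(atoms)
#   # with compute_forces=True, forward returns dict(energy=..., forces=..., stress=...)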