AtomsDataset
- class nfflr.AtomsDataset(df: str | Path | DataFrame, target: str = 'formation_energy_peratom', transform: Callable | None = None, custom_collate_fn: Callable | None = None, custom_prepare_batch_fn: Callable | None = None, train_val_seed: int = 42, id_tag: str = 'jid', group_ids: bool = False, group_split_token: str = '_', n_train: float | int = 0.8, n_val: float | int = 0.1, energy_units: Literal['eV', 'eV/atom'] = 'eV', standardize: bool = False, diskcache: Path | str | bool | None = None)
Dataset of atomic structures (Atoms) and their target properties.
Methods
collate_default(samples): Dataloader helper to batch graphs across samples.
collate_forcefield(samples): Dataloader helper to batch graphs across samples.
prepare_batch_default(batch[, device, ...]): Send a batched DGL crystal graph to device.
Get train/val/test split indices for SubsetRandomSampler.
split_dataset_by_id(n_train, n_val): Get train/val/test split indices for SubsetRandomSampler.
estimate_reference_energies
get_energy_and_forces
setup_target_standardization
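estimate_reference_energies is listed above without a description. One common technique for estimating per-element reference energies is a linear least-squares fit of total energies against per-structure element counts. The sketch below illustrates that technique only; it is not a claim about how this method is implemented, and the function name fit_reference_energies and its dict-based inputs are hypothetical.

```python
import numpy as np


def fit_reference_energies(compositions, total_energies, elements):
    """Hypothetical sketch: least-squares fit of per-element reference energies.

    compositions: list of dicts mapping element symbol -> atom count
    total_energies: total energy of each structure (eV)
    elements: element symbols defining the column order of the count matrix
    Returns a dict mapping element -> fitted reference energy (eV/atom).
    """
    # One row per structure, one column per element: number of atoms of that element.
    counts = np.array(
        [[comp.get(el, 0) for el in elements] for comp in compositions], dtype=float
    )
    energies = np.asarray(total_energies, dtype=float)
    # Solve counts @ refs ~= energies in the least-squares sense.
    refs, *_ = np.linalg.lstsq(counts, energies, rcond=None)
    return dict(zip(elements, refs))
```

Subtracting the summed reference energies from each total energy leaves a residual that depends less on composition, which is one reason such references are often estimated before training.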
- split_dataset_by_id(n_train: float | int, n_val: float | int)
Get train/val/test split indices for SubsetRandomSampler.
Stratify by calculation/trajectory ID (“group_id”).
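A minimal sketch of a group-aware split, assuming that a sample's group ID is the part of its ID before the first group_split_token, that float n_train/n_val are fractions of groups and int values are group counts, and that all remaining groups go to the test set. These are assumptions for illustration; split_ids_by_group is a hypothetical helper, not part of nfflr. The defaults mirror the class signature (train_val_seed=42, group_split_token='_', n_train=0.8, n_val=0.1).

```python
import random
from collections import defaultdict


def split_ids_by_group(ids, n_train=0.8, n_val=0.1, seed=42, split_token="_"):
    """Hypothetical sketch: split sample indices so each group lands in one split.

    The group of a sample is assumed to be the part of its id before the first
    split_token (e.g. "JVASP-1002_3" -> "JVASP-1002"). Float n_train / n_val
    are treated as fractions of groups; ints are treated as group counts.
    """
    # Map each group key to the sample indices that belong to it.
    groups = defaultdict(list)
    for idx, sample_id in enumerate(ids):
        groups[sample_id.split(split_token)[0]].append(idx)

    # Shuffle groups (not samples) reproducibly.
    group_keys = sorted(groups)
    random.Random(seed).shuffle(group_keys)

    n_groups = len(group_keys)
    if isinstance(n_train, float):
        n_train = int(n_train * n_groups)
    if isinstance(n_val, float):
        n_val = int(n_val * n_groups)

    train_keys = group_keys[:n_train]
    val_keys = group_keys[n_train:n_train + n_val]
    test_keys = group_keys[n_train + n_val:]

    def flatten(keys):
        # Expand group keys back into sorted sample indices.
        return sorted(i for k in keys for i in groups[k])

    return flatten(train_keys), flatten(val_keys), flatten(test_keys)
```

Splitting whole groups keeps correlated samples (for example, frames from the same relaxation trajectory) out of both train and test at once. Each returned index list could then be passed to torch.utils.data.SubsetRandomSampler.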
- static prepare_batch_default(batch: Tuple[Any, Dict[str, Tensor]], device=None, non_blocking=False) → Tuple[Any, Dict[str, Tensor]]
Send a batched DGL crystal graph to device.
- static collate_default(samples: List[Tuple[DGLGraph, Tensor]])
Dataloader helper to batch graphs across samples.
Forces are collated into a graph batch by concatenating along the atoms dimension.
Energy and stress are global targets (properties of the whole graph): total energy is a scalar, and stress is a rank-2 tensor.
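The target shapes described above can be illustrated with a small sketch. It assumes a hypothetical per-sample dict with "energy" (scalar), "forces" (n_atoms × 3) and "stress" (3 × 3) entries and uses NumPy in place of torch tensors; the real methods take (DGLGraph, Tensor) tuples, and the graph batching itself (for example with dgl.batch) is omitted here. collate_targets is a hypothetical name.

```python
import numpy as np


def collate_targets(samples):
    """Hypothetical sketch of batching per-sample targets.

    Each sample is a dict with:
      "energy": scalar total energy
      "forces": (n_atoms, 3) array; n_atoms varies between samples
      "stress": (3, 3) array
    """
    # Global scalar target: one value per structure -> shape (batch,)
    energy = np.array([s["energy"] for s in samples], dtype=float)
    # Per-atom target: concatenate along the atoms dimension -> (total_atoms, 3)
    forces = np.concatenate([s["forces"] for s in samples], axis=0)
    # Global rank-2 target: stack along a new batch axis -> (batch, 3, 3)
    stress = np.stack([s["stress"] for s in samples], axis=0)
    return {"energy": energy, "forces": forces, "stress": stress}
```

Concatenating forces in sample order matches how a batched graph orders its nodes (all atoms of the first structure, then the second, and so on), so row i of the force array lines up with node i of the batch.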
- static collate_forcefield(samples: List[Tuple[DGLGraph, Tensor]])
Dataloader helper to batch graphs across samples.
Forces are collated into a graph batch by concatenating along the atoms dimension.
Energy and stress are global targets (properties of the whole graph): total energy is a scalar, and stress is a rank-2 tensor.