AtomsDataset

class nfflr.AtomsDataset(df: str | Path | DataFrame, target: str = 'formation_energy_peratom', transform: Callable | None = None, custom_collate_fn: Callable | None = None, custom_prepare_batch_fn: Callable | None = None, train_val_seed: int = 42, id_tag: str = 'jid', group_ids: bool = False, group_split_token: str = '_', n_train: float | int = 0.8, n_val: float | int = 0.1, energy_units: Literal['eV', 'eV/atom'] = 'eV', standardize: bool = False, diskcache: Path | str | bool | None = None)

Dataset of Atoms.

Methods

collate_default(samples)

Dataloader helper to batch graphs across samples.

collate_forcefield(samples)

Dataloader helper to batch graphs across samples.

prepare_batch_default(batch[, device, ...])

Send batched dgl crystal graph to device.

split_dataset()

Get train/val/test split indices for SubsetRandomSampler.

split_dataset_by_id(n_train, n_val)

Get train/val/test split indices for SubsetRandomSampler.

estimate_reference_energies

get_energy_and_forces

setup_target_standardization

split_dataset_by_id(n_train: float | int, n_val: float | int)

Get train/val/test split indices for SubsetRandomSampler.

Stratify by calculation or trajectory id (“group_id”).
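To illustrate what splitting by calculation or trajectory id can look like, here is a minimal group-wise split in plain Python. It is a sketch under stated assumptions, not the nfflr implementation: the helper names `group_of` and `split_by_group` are hypothetical, the group is assumed to be the part of the id before the first split token, and the fractions are assumed to apply to the number of groups rather than the number of samples.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def group_of(sample_id: str, token: str = "_") -> str:
    """Hypothetical rule: the group is the part of the id before the first token."""
    return sample_id.split(token, 1)[0]


def split_by_group(
    ids: Sequence[str],
    n_train: float = 0.8,
    n_val: float = 0.1,
    seed: int = 42,
) -> Dict[str, List[int]]:
    """Assign whole groups to train/val/test so no group straddles two partitions."""
    # Collect the sample indices that belong to each group.
    members: Dict[str, List[int]] = defaultdict(list)
    for index, sample_id in enumerate(ids):
        members[group_of(sample_id)].append(index)

    groups = sorted(members)  # fixed order, so the seeded shuffle is reproducible
    random.Random(seed).shuffle(groups)

    # Assumption: the fractions count groups, not individual samples.
    n_tr = int(round(n_train * len(groups)))
    n_va = int(round(n_val * len(groups)))
    partitions = {
        "train": groups[:n_tr],
        "val": groups[n_tr:n_tr + n_va],
        "test": groups[n_tr + n_va:],  # whatever remains
    }
    # Expand each list of groups back into sample indices.
    return {
        name: [i for g in chosen for i in members[g]]
        for name, chosen in partitions.items()
    }


# Ten hypothetical trajectories with three frames each.
ids = [f"traj{t}_{frame}" for t in range(10) for frame in range(3)]
split = split_by_group(ids)
```

Keeping every frame of a trajectory in the same partition avoids evaluating on configurations that are nearly identical to training samples.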

split_dataset()

Get train/val/test split indices for SubsetRandomSampler.
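As a rough picture of how train/val/test indices can be produced, the sketch below shuffles the sample indices once with a fixed seed and cuts them into three lists. The function names `resolve_size` and `random_split` are hypothetical, and reading a float as a fraction of the dataset and an int as an absolute count is an assumption based on the `n_train: float | int` signature, not confirmed behaviour of `split_dataset`.

```python
import random
from typing import Dict, List, Union


def resolve_size(n: Union[int, float], total: int) -> int:
    """Assumed convention: a float is a fraction of the dataset, an int is a count."""
    return int(round(n * total)) if isinstance(n, float) else n


def random_split(
    n_samples: int,
    n_train: Union[int, float] = 0.8,
    n_val: Union[int, float] = 0.1,
    seed: int = 42,
) -> Dict[str, List[int]]:
    """Shuffle indices once and cut them into train/val/test partitions."""
    n_tr = resolve_size(n_train, n_samples)
    n_va = resolve_size(n_val, n_samples)
    if n_tr + n_va > n_samples:
        raise ValueError("n_train + n_val exceeds the dataset size")

    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # seeded, so the split is reproducible

    return {
        "train": indices[:n_tr],
        "val": indices[n_tr:n_tr + n_va],
        "test": indices[n_tr + n_va:],  # whatever remains
    }


split = random_split(100)
```

Each of the three index lists could then be handed to a `torch.utils.data.SubsetRandomSampler`, which draws only from the indices it is given.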

static prepare_batch_default(batch: Tuple[Any, Dict[str, Tensor]], device=None, non_blocking=False) → Tuple[Any, Dict[str, Tensor]]

Send batched dgl crystal graph to device.

static collate_default(samples: List[Tuple[DGLGraph, Tensor]])

Dataloader helper to batch graphs across samples.

Forces are collated into a graph batch by concatenating along the atoms dimension.

Energy and stress are global targets (properties of the whole graph): the total energy is a scalar and the stress is a rank-2 tensor.

static collate_forcefield(samples: List[Tuple[DGLGraph, Tensor]])

Dataloader helper to batch graphs across samples.

Forces are collated into a graph batch by concatenating along the atoms dimension.

Energy and stress are global targets (properties of the whole graph): the total energy is a scalar and the stress is a rank-2 tensor.
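The difference between per-atom and per-graph targets can be shown with a small sketch. It uses plain Python lists in place of tensors and leaves out graph batching entirely. The function `collate_targets`, the dictionary keys, and the extra `n_atoms` bookkeeping are hypothetical and are not taken from nfflr.

```python
from typing import Dict, List

Vector = List[float]        # one force vector: [fx, fy, fz]
Matrix = List[List[float]]  # one 3x3 stress tensor


def collate_targets(samples: List[Dict]) -> Dict[str, list]:
    """Merge per-structure targets into one batch of targets."""
    energies: List[float] = []
    forces: List[Vector] = []
    stresses: List[Matrix] = []
    n_atoms: List[int] = []
    for target in samples:
        energies.append(target["energy"])      # one scalar per structure
        forces.extend(target["forces"])        # concatenated along the atoms dimension
        stresses.append(target["stress"])      # one rank-2 tensor per structure
        n_atoms.append(len(target["forces"]))  # lets forces be mapped back to structures
    return {
        "energy": energies,
        "forces": forces,
        "stress": stresses,
        "n_atoms": n_atoms,
    }


zero_stress = [[0.0] * 3 for _ in range(3)]
samples = [
    {"energy": -1.5, "forces": [[0.0, 0.0, 0.1], [0.0, 0.0, -0.1]], "stress": zero_stress},
    {
        "energy": -4.2,
        "forces": [[0.1, 0.0, 0.0], [-0.1, 0.0, 0.0], [0.0, 0.0, 0.0]],
        "stress": zero_stress,
    },
]
batch = collate_targets(samples)
```

With a two-atom and a three-atom structure, the batch holds two energies and two stress tensors but five force vectors, because the forces follow the atoms rather than the graphs.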