Source code for nfflr.train.trainer

from __future__ import annotations

__all__ = ()
import os
from pathlib import Path
from datetime import timedelta
from typing import TYPE_CHECKING

import torch
from torch.utils.data import SubsetRandomSampler

import typer
import ignite
import ignite.distributed as idist
from ignite.utils import manual_seed
from ignite.metrics import Loss, MeanAbsoluteError
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.handlers import Checkpoint, DiskSaver, FastaiLRFinder, TerminateOnNan
from ignite.engine import Events, create_supervised_trainer

import ray
import ray.tune
from ray.air import session

from py_config_runner import ConfigObject

import nfflr


if TYPE_CHECKING:
    from nfflr.train.config import TrainingConfig

from nfflr.train.utils import (
    group_decay,
    select_target,
    setup_evaluator_with_grad,
    setup_optimizer,
    setup_scheduler,
    transfer_outputs,
)
from nfflr.train.swag import SWAGHandler
from nfflr.models.utils import reset_initial_output_bias

cli = typer.Typer()

# set up multi-GPU training, if available
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
num_gpus = len(gpus.split(","))

backend = None
nproc_per_node = None
if num_gpus > 1:
    backend == "nccl" if torch.distributed.is_nccl_available() else "gloo"
    nproc_per_node = num_gpus

spawn_kwargs = {
    "backend": backend,
    "nproc_per_node": nproc_per_node,
    "timeout": timedelta(seconds=60),
}
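
# Worked example (sketch): with CUDA_VISIBLE_DEVICES="0,1" the split above gives
# num_gpus == 2, so spawn_kwargs becomes
#   {"backend": "nccl" or "gloo", "nproc_per_node": 2, "timeout": timedelta(seconds=60)}
# and idist.Parallel(**spawn_kwargs) in the CLI commands below launches one process
# per GPU. With a single visible GPU, backend and nproc_per_node stay None and
# everything runs in the current process.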


def log_console(engine: ignite.engine.Engine, name: str):
    """Log evaluation stats to console."""
    epoch = engine.state.training_epoch  # custom state field
    m = engine.state.metrics
    loss = m["loss"]

    print(f"{name} results - Epoch: {epoch}  Avg loss: {loss:.2f}")
    if "mae_forces" in m.keys():
        print(f"energy: {m['mae_energy']:.2f}  force: {m['mae_forces']:.4f}")


def get_dataflow(dataset: nfflr.AtomsDataset, config: TrainingConfig):
    """Configure training and validation datasets.

    Wraps `train` and `val` splits of `dataset` in ignite's
    :py:func:`auto_dataloader <ignite.distributed.auto.auto_dataloader>`.

    Parameters
    ----------
    dataset : nfflr.AtomsDataset
    config : nfflr.train.TrainingConfig
    """
    train_loader = idist.auto_dataloader(
        dataset,
        collate_fn=dataset.collate,
        batch_size=config.batch_size,
        sampler=SubsetRandomSampler(dataset.split["train"]),
        drop_last=True,
        num_workers=config.dataloader_workers,
        pin_memory=config.pin_memory,
    )

    val_loader = idist.auto_dataloader(
        dataset,
        collate_fn=dataset.collate,
        batch_size=config.batch_size,
        sampler=SubsetRandomSampler(dataset.split["val"]),
        drop_last=False,  # True -> possible issue crashing with MP dataset
        num_workers=config.dataloader_workers,
        pin_memory=config.pin_memory,
    )

    return train_loader, val_loader

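# Usage sketch for get_dataflow (names here are illustrative, not part of this module):
# assuming `dataset` is an nfflr.AtomsDataset with "train"/"val" splits and `cfg` is a
# TrainingConfig,
#   train_loader, val_loader = get_dataflow(dataset, cfg)
#   batch = next(iter(train_loader))  # batches are assembled by dataset.collate
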
def setup_model_and_optimizer(
    model: torch.nn.Module,
    dataset: nfflr.AtomsDataset,
    config: TrainingConfig,
):
    """Initialize model, criterion, and optimizer."""
    model = idist.auto_model(model)
    criterion = idist.auto_model(config.criterion)
    if isinstance(criterion, nfflr.nn.MultitaskLoss):
        # auto_model won't transfer buffers...?
        criterion = criterion.to(idist.device())

    params = group_decay(model)
    if isinstance(criterion, torch.nn.Module) and len(list(criterion.parameters())) > 0:
        params.append({"params": criterion.parameters(), "weight_decay": 0})

    optimizer = setup_optimizer(params, config)
    optimizer = idist.auto_optim(optimizer)

    if config.initialize_estimated_reference_energies:
        model.reset_atomic_reference_energies(dataset.estimate_reference_energies())

    if config.initialize_bias:
        train_loader, _ = get_dataflow(dataset, config)
        reset_initial_output_bias(
            model, train_loader, max_samples=500 / config.batch_size
        )

    return model, criterion, optimizer

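# Call pattern sketch for setup_model_and_optimizer (same illustrative `dataset`/`cfg`):
#   model, criterion, optimizer = setup_model_and_optimizer(model, dataset, cfg)
# idist.auto_model / idist.auto_optim adapt the returned objects to the active
# distributed configuration, so the same call serves single- and multi-GPU runs.
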
def setup_trainer(
    model, criterion, optimizer, scheduler, prepare_batch, config: TrainingConfig
):
    """Create ignite trainer and attach common event handlers."""
    device = idist.device()

    trainer = create_supervised_trainer(
        model,
        optimizer,
        criterion,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        prepare_batch=prepare_batch,
        device=device,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan())

    if scheduler is not None:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED, lambda engine: scheduler.step()
        )

    return trainer

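# Note (illustrative numbers, not defaults): with config.batch_size = 32 and
# config.gradient_accumulation_steps = 4, the optimizer steps once every 4 batches,
# for an effective batch size of 128. The scheduler is stepped on every iteration,
# so schedulers produced by setup_scheduler are assumed to be iteration-based rather
# than epoch-based.
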
def setup_checkpointing(state: dict, config: TrainingConfig):
    """Configure model and trainer checkpointing.

    `state` should contain at least `model`, `optimizer`, and `trainer`.
    """
    checkpoint_handler = Checkpoint(
        state,
        DiskSaver(config.output_dir, create_dir=True, require_empty=False),
        n_saved=1,
        global_step_transform=lambda *_: state["trainer"].state.epoch,
    )
    state["trainer"].add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler)

    if config.resume_checkpoint is not None:
        checkpoint = torch.load(config.resume_checkpoint, map_location=idist.device())
        Checkpoint.load_objects(to_load=state, checkpoint=checkpoint)

    return state

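# Minimal sketch of the `state` dict this handler expects; train() below also adds
# optional entries such as "scheduler", "criterion", and "swagmodel":
#   state = dict(model=model, optimizer=optimizer, trainer=trainer)
#   setup_checkpointing(state, cfg)
# Setting cfg.resume_checkpoint to a saved checkpoint path restores these objects
# in place before training starts.
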
def setup_evaluators(model, prepare_batch, metrics, transfer_outputs):
    """Configure train and validation evaluators."""
    device = idist.device()

    # create_supervised_evaluator
    train_evaluator = setup_evaluator_with_grad(
        model,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
        output_transform=transfer_outputs,
    )
    val_evaluator = setup_evaluator_with_grad(
        model,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
        output_transform=transfer_outputs,
    )

    return train_evaluator, val_evaluator

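# Both evaluators are built with setup_evaluator_with_grad rather than ignite's
# create_supervised_evaluator, presumably so that autograd stays enabled during
# evaluation (force targets require gradients of the predicted energy with respect
# to atomic positions).
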
def train(
    model: torch.nn.Module,
    dataset: nfflr.AtomsDataset,
    config: nfflr.train.TrainingConfig,
    local_rank: int = 0,
):
    """NFFLr trainer entry point.

    Parameters
    ----------
    model : torch.nn.Module
    dataset : nfflr.AtomsDataset
    config : nfflr.train.TrainingConfig
    local_rank : int, optional
    """
    rank = idist.get_rank()
    manual_seed(config.random_seed + local_rank)

    train_loader, val_loader = get_dataflow(dataset, config)
    model, criterion, optimizer = setup_model_and_optimizer(model, dataset, config)
    scheduler = setup_scheduler(config, optimizer, len(train_loader))

    if config.swag:
        swag_handler = SWAGHandler(model)

    trainer = setup_trainer(
        model, criterion, optimizer, scheduler, dataset.prepare_batch, config
    )

    if config.progress and rank == 0:
        pbar = ProgressBar()
        pbar.attach(trainer, output_transform=lambda x: {"loss": x})

    if config.checkpoint:
        state = dict(model=model, optimizer=optimizer, trainer=trainer)
        if scheduler is not None:
            state["scheduler"] = scheduler
        if isinstance(criterion, torch.nn.Module):
            state["criterion"] = criterion
        if config.swag:
            state["swagmodel"] = swag_handler.swagmodel
        setup_checkpointing(state, config)

    if config.swag:
        swag_handler.attach(trainer)

    # evaluation
    metrics = {"loss": Loss(criterion)}
    if isinstance(criterion, nfflr.nn.MultitaskLoss):
        # NOTE: unscaling currently uses a global scale
        # shared across all tasks (intended to scale energy units)
        unscale = None
        if dataset.standardize:
            unscale = dataset.scaler.unscale

        eval_metrics = {
            f"mae_{task}": MeanAbsoluteError(select_target(task, unscale_fn=unscale))
            for task in criterion.tasks
        }
        metrics.update(eval_metrics)

    train_evaluator, val_evaluator = setup_evaluators(
        model, dataset.prepare_batch, metrics, transfer_outputs
    )

    if config.progress and rank == 0:
        vpbar = ProgressBar()
        vpbar.attach(train_evaluator)
        vpbar.attach(val_evaluator)

    history = {
        "train": {m: [] for m in metrics.keys()},
        "validation": {m: [] for m in metrics.keys()},
    }

    @trainer.on(Events.EPOCH_COMPLETED)
    def _eval(engine):
        n_train_eval = int(config.train_eval_fraction * len(train_loader))
        n_train_eval = max(n_train_eval, 1)  # at least one batch
        train_evaluator.state.training_epoch = engine.state.epoch
        val_evaluator.state.training_epoch = engine.state.epoch
        train_evaluator.run(train_loader, epoch_length=n_train_eval, max_epochs=1)
        val_evaluator.run(val_loader)

    def log_metric_history(engine, output_dir: Path):
        phases = {"train": train_evaluator, "validation": val_evaluator}
        for phase, evaluator in phases.items():
            for key, value in evaluator.state.metrics.items():
                history[phase][key].append(value)
        torch.save(history, output_dir / "metric_history.pkl")

    if rank == 0:
        train_evaluator.add_event_handler(Events.COMPLETED, log_console, name="train")
        val_evaluator.add_event_handler(Events.COMPLETED, log_console, name="val")
        val_evaluator.add_event_handler(
            Events.COMPLETED, log_metric_history, config.output_dir
        )

    if ray.tune.is_session_enabled():
        val_evaluator.add_event_handler(
            Events.COMPLETED, lambda engine: session.report(engine.state.metrics)
        )

    trainer.run(train_loader, max_epochs=config.epochs)

    return val_evaluator.state.metrics["loss"]

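# Direct (single-process) invocation sketch, bypassing the typer CLI below;
# `model`, `dataset`, and `cfg` are assumed to be constructed elsewhere:
#   val_loss = train(model, dataset, cfg)
# The final validation loss is returned, so train() can also serve as an objective
# for hyperparameter searches.
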
def lr(
    model: torch.nn.Module,
    dataset: nfflr.AtomsDataset,
    config: nfflr.train.TrainingConfig,
    local_rank: int = 0,
):
    """NFFLr learning rate finder entry point.

    Runs the Fast.ai learning rate finder
    :py:class:`ignite.handlers.lr_finder.FastaiLRFinder`
    for `model`, `dataset`, and `config`.

    Parameters
    ----------
    model : torch.nn.Module
    dataset : nfflr.AtomsDataset
    config : nfflr.train.TrainingConfig
    local_rank : int, optional
    """
    rank = idist.get_rank()
    manual_seed(config.random_seed + local_rank)

    config.checkpoint = False
    scheduler = None

    train_loader, val_loader = get_dataflow(dataset, config)
    model, criterion, optimizer = setup_model_and_optimizer(model, dataset, config)
    trainer = setup_trainer(
        model, criterion, optimizer, scheduler, dataset.prepare_batch, config
    )

    if config.progress and rank == 0:
        pbar = ProgressBar()
        pbar.attach(trainer, output_transform=lambda x: {"loss": x})

    lr_finder = FastaiLRFinder()
    to_save = {"model": model, "optimizer": optimizer}
    with lr_finder.attach(
        trainer, to_save, start_lr=1e-6, end_lr=0.1, num_iter=400, diverge_th=1e9
    ) as finder:
        finder.run(train_loader)

    if rank == 0:
        # print("Suggested LR", lr_finder.lr_suggestion())
        ax = lr_finder.plot(display_suggestion=False)
        ax.loglog()
        ax.set_ylim(None, 5.0)
        ax.figure.savefig("lr.png")

@cli.command("train") def cli_train(config_path: Path, verbose: bool = False): """NFF training entry point.""" with idist.Parallel(**spawn_kwargs) as parallel: config = ConfigObject(config_path) if verbose: print(config) # wrap train entry point for idist.Parallel def train_wrapper(local_rank, model, dataset, args): return train(model, dataset, args, local_rank=local_rank) parallel.run(train_wrapper, config.model, config.dataset, config.args) @cli.command("lr") def cli_lr(config_path: Path, verbose: bool = False): """NFF Learning rate finder entry point.""" with idist.Parallel(**spawn_kwargs) as parallel: if verbose: print(spawn_kwargs) print("loading config") config = ConfigObject(config_path) if verbose: print(config) # wrap lr entry point for idist.Parallel def lr_wrapper(local_rank, model, dataset, args): return lr(model, dataset, args, local_rank=local_rank) parallel.run(lr_wrapper, config.model, config.dataset, config.args) if __name__ == "__main__": cli()