ml.trainers.base

Defines the base trainer class and config.

The trainer is the component that actually runs the training loop. Supervised and reinforcement learning use separate trainers, since reinforcement learning requires interacting with an environment; use the trainer that matches your task (they are defined in ml.trainers.sl and ml.trainers.rl, respectively). The base trainer handles shared concerns such as setting up the experiment directory, saving checkpoints, and logging.

ml.trainers.base.abs_path(path: str) str[source]
ml.trainers.base.cpu_count(default: int) int[source]
ml.trainers.base.date_str(_: str) str[source]
ml.trainers.base.add_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'], *, exists_ok: bool = False) None[source]
ml.trainers.base.remove_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'], *, missing_ok: bool = False) None[source]
ml.trainers.base.has_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'] | None = None) bool[source]
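
A minimal usage sketch for the lock-file helpers, assuming lock files live inside the experiment directory; the path below is hypothetical.

    from pathlib import Path

    from ml.trainers.base import add_lock_file, has_lock_file, remove_lock_file

    exp_dir = Path("runs/my_experiment/run_0")  # hypothetical experiment directory
    exp_dir.mkdir(parents=True, exist_ok=True)

    # Mark the experiment as running so other jobs can detect and skip it.
    add_lock_file(exp_dir, "running", exists_ok=True)
    assert has_lock_file(exp_dir, "running")
    assert has_lock_file(exp_dir)  # with lock_type=None, checks for any lock file

    # ... training happens here ...

    remove_lock_file(exp_dir, "running", missing_ok=True)
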
ml.trainers.base.get_ckpt_path(exp_dir: Path, state: State | None = None) Path[source]

Defines the path to the checkpoint for a given state.

Parameters:
  • exp_dir – The experiment directory

  • state – The current trainer state

Returns:

The path to the PyTorch checkpoint to save or load
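
A short sketch of resolving the checkpoint path; the directory is hypothetical, and passing a State (instead of None) is assumed to select a state-specific checkpoint path.

    from pathlib import Path

    from ml.trainers.base import get_ckpt_path

    exp_dir = Path("runs/my_experiment/run_0")  # hypothetical experiment directory

    # With state=None this resolves the default checkpoint location for the experiment.
    ckpt_path = get_ckpt_path(exp_dir)
    print(ckpt_path)
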

ml.trainers.base.get_exp_dir(run_dir: Path, run_id: int) Path[source]
ml.trainers.base.get_empty_exp_dir(run_dir: Path) Path[source]

Returns the path to the next empty experiment directory, given the base run directory.

Parameters:

run_dir – The base run directory for the experiment

Returns:

An experiment directory without a lockfile
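
A sketch contrasting the two helpers, using a hypothetical base run directory.

    from pathlib import Path

    from ml.trainers.base import get_empty_exp_dir, get_exp_dir

    run_dir = Path("runs/my_experiment")  # hypothetical base run directory

    specific_dir = get_exp_dir(run_dir, 0)  # directory for run ID 0
    next_dir = get_empty_exp_dir(run_dir)   # first experiment directory without a lock file
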

ml.trainers.base.diff_configs(first: ListConfig | DictConfig, second: ListConfig | DictConfig, prefix: str | None = None) tuple[list[str], list[str]][source]

Returns the difference between two configs.

Parameters:
  • first – The first (original) config

  • second – The second (new) config

  • prefix – The prefix to check (used internally for recursion; not needed for the top-level call)

Returns:

Two lists of lines describing the diff between the two configs
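
A sketch of diffing two OmegaConf configs; which kinds of changes land in which returned list is not specified here, so the variable names below are assumptions.

    from omegaconf import OmegaConf

    from ml.trainers.base import diff_configs

    first = OmegaConf.create({"optimizer": {"lr": 1e-3}, "batch_size": 32})
    second = OmegaConf.create({"optimizer": {"lr": 3e-4}, "batch_size": 32, "seed": 0})

    lines_a, lines_b = diff_configs(first, second)
    for line in lines_a + lines_b:
        print(line)
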

ml.trainers.base.save_config(config_path: Path, raw_config: DictConfig) None[source]
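
A sketch of persisting a raw config to disk; the path and config contents are hypothetical.

    from pathlib import Path

    from omegaconf import OmegaConf

    from ml.trainers.base import save_config

    raw_config = OmegaConf.create({"trainer": {"name": "base"}})
    save_config(Path("runs/my_experiment/run_0/config.yaml"), raw_config)
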
class ml.trainers.base.CheckpointConfig(save_every_n_steps: int | None = None, save_every_n_seconds: float | None = 3600.0, only_save_most_recent: bool = True, load_from_ckpt_path: str | None = None, load_strict: bool = True)[source]

Bases: object

save_every_n_steps: int | None = None
save_every_n_seconds: float | None = 3600.0
only_save_most_recent: bool = True
load_from_ckpt_path: str | None = None
load_strict: bool = True
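
For example, a checkpoint policy that saves every 1000 steps or every 30 minutes, keeping only the latest checkpoint (the values are illustrative).

    from ml.trainers.base import CheckpointConfig

    ckpt_cfg = CheckpointConfig(
        save_every_n_steps=1000,      # checkpoint every 1000 training steps
        save_every_n_seconds=1800.0,  # or every 30 minutes, whichever comes first
        only_save_most_recent=True,   # keep only the latest checkpoint
        load_from_ckpt_path=None,     # optionally warm-start from an existing checkpoint
        load_strict=True,             # error on mismatched keys when loading
    )
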
class ml.trainers.base.BaseTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: CheckpointConfig = <factory>)[source]

Bases: BaseConfig

Defines the base config for all trainers.

exp_name: str = '${ml.exp_name:null}'
exp_dir: str = '???'
log_dir_name: str = 'logs'
use_double_weight_precision: bool = False
checkpoint: CheckpointConfig
classmethod resolve(config: BaseTrainerConfig) None[source]

Runs post-construction config resolution.

Parameters:

config – The config to resolve
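
A sketch of constructing and resolving the config by hand; in normal use the framework fills in the '???' fields (such as exp_dir) and calls resolve() itself, so the values below are illustrative assumptions.

    from ml.trainers.base import BaseTrainerConfig, CheckpointConfig

    config = BaseTrainerConfig(
        name="base",
        exp_dir="runs/my_experiment/run_0",  # hypothetical experiment directory
        checkpoint=CheckpointConfig(save_every_n_steps=1000),
    )
    BaseTrainerConfig.resolve(config)  # post-construction resolution
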

class ml.trainers.base.BaseTrainer(config: TrainerConfigT)[source]

Bases: BaseObject[TrainerConfigT], Generic[TrainerConfigT, ModelT, TaskT]

Defines the base trainer type.

loggers: list[ml.loggers.base.BaseLogger]
logger: MultiLogger
add_logger(sublogger: BaseLogger) None[source]
add_loggers(subloggers: list[ml.loggers.base.BaseLogger]) None[source]
property config_path: Path
save_config() None[source]
log_run_config() None[source]
add_lock_file(lock_type: Literal['running', 'scheduled', 'ckpt'], *, exists_ok: bool = False) None[source]
remove_lock_file(lock_type: Literal['running', 'scheduled', 'ckpt'], *, missing_ok: bool = False) None[source]
get_ckpt_path(state: State | None = None) Path[source]
property ckpt_path: Path
should_checkpoint(state: State) bool[source]
load_checkpoint(ckpt: str | Path, task: TaskT, model: ModelT, optims: Optimizer | dict[str, torch.optim.optimizer.Optimizer] | None = None, lr_scheds: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter] | None = None, *, device: str | device | None = None, strict: bool | None = None) State[source]

Loads a checkpoint from the given path.

Parameters:
  • ckpt – The checkpoint to load.

  • task – The task to load the checkpoint into.

  • model – The model to load the checkpoint into.

  • optims – The optimizer to load the checkpoint into.

  • lr_scheds – The learning rate scheduler to load the checkpoint into.

  • device – The device to load the checkpoint onto.

  • strict – If set, raise an error if the checkpoint contains keys which don’t exist in the model.

Returns:

The state loaded from the checkpoint.

Raises:

UnpicklingError – If there is some issue unpickling the checkpoint.
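
A hedged sketch of restoring state for evaluation; the trainer, task, and model instances are assumed to come from the framework's concrete subclasses, and the helper below is hypothetical.

    from pathlib import Path

    from ml.trainers.base import BaseTrainer


    def restore_for_eval(trainer: BaseTrainer, task, model, ckpt: Path):
        # Skip optimizer and scheduler state, map tensors onto the CPU, and
        # error on checkpoint keys which don't exist in the model.
        return trainer.load_checkpoint(
            ckpt,
            task,
            model,
            optims=None,
            lr_scheds=None,
            device="cpu",
            strict=True,
        )
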

save_checkpoint(state: State, task: TaskT, model: ModelT, optims: Optimizer | dict[str, torch.optim.optimizer.Optimizer] | None = None, lr_scheds: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter] | None = None) Path[source]
train(model: ModelT, task: TaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) None[source]

Runs the training loop.

Parameters:
  • model – The current model

  • task – The current task

  • optimizer – The current optimizer

  • lr_scheduler – The current learning rate scheduler

Raises:

NotImplementedError – If the subclass does not implement this method
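
train() is abstract in the base class; a hedged outline of what a concrete subclass might look like (the real supervised and RL loops live in ml.trainers.sl and ml.trainers.rl, and the class below is purely illustrative).

    from ml.trainers.base import BaseTrainer


    class MyTrainer(BaseTrainer):  # hypothetical subclass for illustration
        def train(self, model, task, optimizer, lr_scheduler) -> None:
            # A real implementation would build the trainer state, iterate over
            # batches from the task, step the optimizer and LR scheduler, and
            # periodically call self.should_checkpoint(state),
            # self.save_checkpoint(state, task, model, ...), and
            # self.write_logs(task, model, state).
            ...
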

write_logs(task: TaskT, model: ModelT, state: State) None[source]
load_state_dict(state_dict: dict) None[source]

Function for loading state dict entries for the trainer's different components.

Parameters:

state_dict – The state dictionary being loaded (overriders should mutate it in place)

update_state_dict(state_dict: dict) None[source]

Function for updating the state dictionary that will be saved to the checkpoint.

Parameters:

state_dict – The state dictionary being saved (overriders should mutate it in place)
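
A hedged sketch of stashing extra trainer state in the checkpoint via these hooks, assuming the base implementations are safe to extend through super(); the counter below is purely illustrative.

    from ml.trainers.base import BaseTrainer


    class CounterTrainer(BaseTrainer):  # hypothetical subclass for illustration
        def update_state_dict(self, state_dict: dict) -> None:
            super().update_state_dict(state_dict)
            state_dict["my_counter"] = getattr(self, "my_counter", 0)  # mutate in place

        def load_state_dict(self, state_dict: dict) -> None:
            super().load_state_dict(state_dict)
            self.my_counter = state_dict.get("my_counter", 0)
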

on_exit(sig: Signals, state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_scheduler: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
add_signal_handler(sig: Signals, handler: Callable[[], None]) None[source]
on_step_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_step_end(state: State, loss_dict: dict[str, torch.Tensor], task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_epoch_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_epoch_end(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_training_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
on_training_end(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None[source]
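
These on_* callbacks are the intended extension points for per-step, per-epoch, and per-run behavior. A hedged sketch of a per-step hook that prints the loss dictionary; in practice values would more likely be routed through self.logger, whose exact API is not shown here.

    from ml.trainers.base import BaseTrainer


    class PrintLossTrainer(BaseTrainer):  # hypothetical subclass for illustration
        def on_step_end(self, state, loss_dict, task, model, optim, lr_sched) -> None:
            super().on_step_end(state, loss_dict, task, model, optim, lr_sched)
            # loss_dict maps loss names to tensors for the step that just finished.
            for name, value in loss_dict.items():
                print(f"{name}: {float(value):.4f}")
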