ml.trainers.base
Defines the base trainer class and config.
The trainer is the component that actually runs the training loop. There are
separate trainers for supervised and reinforcement learning, since the latter
requires interacting with an environment; use the trainer appropriate to your
task (they are defined in ml.trainers.sl and ml.trainers.rl, respectively).
The base trainer handles common concerns such as setting up the experiment
directory, saving checkpoints, and logging.
- ml.trainers.base.add_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'], *, exists_ok: bool = False) None [source]
- ml.trainers.base.remove_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'], *, missing_ok: bool = False) None [source]
- ml.trainers.base.has_lock_file(exp_dir: Path, lock_type: Literal['running', 'scheduled', 'ckpt'] | None = None) bool [source]
- ml.trainers.base.get_ckpt_path(exp_dir: Path, state: State | None = None) Path [source]
Defines the path to the checkpoint for a given state.
- Parameters:
exp_dir – The experiment directory
state – The current trainer state
- Returns:
The path to the PyTorch checkpoint to save or load
- ml.trainers.base.get_empty_exp_dir(run_dir: Path) Path [source]
Returns the path to the next experiment directory under the run directory that does not have a lock file.
- Parameters:
run_dir – The base run directory for the experiment
- Returns:
An experiment directory without a lockfile
- ml.trainers.base.diff_configs(first: ListConfig | DictConfig, second: ListConfig | DictConfig, prefix: str | None = None) tuple[list[str], list[str]] [source]
Returns the difference between two configs.
- Parameters:
first – The first (original) config
second – The second (new) config
prefix – The prefix to check (used for recursion, not main call)
- Returns:
Two lists of lines describing the diff between the two configs
- class ml.trainers.base.CheckpointConfig(save_every_n_steps: int | None = None, save_every_n_seconds: float | None = 3600.0, only_save_most_recent: bool = True, load_from_ckpt_path: str | None = None, load_strict: bool = True)[source]
Bases: object
- save_every_n_steps: int | None = None
- save_every_n_seconds: float | None = 3600.0
- only_save_most_recent: bool = True
- load_from_ckpt_path: str | None = None
- load_strict: bool = True
- class ml.trainers.base.BaseTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ~ml.trainers.base.CheckpointConfig = <factory>)[source]
Bases: BaseConfig
Defines the base config for all trainers.
- exp_name: str = '${ml.exp_name:null}'
- exp_dir: str = '???'
- log_dir_name: str = 'logs'
- use_double_weight_precision: bool = False
- checkpoint: CheckpointConfig
- classmethod resolve(config: BaseTrainerConfig) None [source]
Runs post-construction config resolution.
- Parameters:
config – The config to resolve
- class ml.trainers.base.BaseTrainer(config: TrainerConfigT)[source]
Bases: BaseObject[TrainerConfigT], Generic[TrainerConfigT, ModelT, TaskT]
Defines the base trainer type.
- loggers: list[ml.loggers.base.BaseLogger]
- logger: MultiLogger
- add_logger(sublogger: BaseLogger) None [source]
- add_loggers(subloggers: list[ml.loggers.base.BaseLogger]) None [source]
- property config_path: Path
- add_lock_file(lock_type: Literal['running', 'scheduled', 'ckpt'], *, exists_ok: bool = False) None [source]
- remove_lock_file(lock_type: Literal['running', 'scheduled', 'ckpt'], *, missing_ok: bool = False) None [source]
- property ckpt_path: Path
- load_checkpoint(ckpt: str | Path, task: TaskT, model: ModelT, optims: Optimizer | dict[str, torch.optim.optimizer.Optimizer] | None = None, lr_scheds: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter] | None = None, *, device: str | device | None = None, strict: bool | None = None) State [source]
Loads a given checkpoint, from a path or dictionary.
- Parameters:
ckpt – The checkpoint to load.
task – The task to load the checkpoint into.
model – The model to load the checkpoint into.
optims – The optimizer to load the checkpoint into.
lr_scheds – The learning rate scheduler to load the checkpoint into.
device – The device to load the checkpoint onto.
strict – If set, raise an error if the checkpoint contains keys which don’t exist in the model.
- Returns:
The state loaded from the checkpoint.
- Raises:
UnpicklingError – If there is some issue unpickling the checkpoint.
- save_checkpoint(state: State, task: TaskT, model: ModelT, optims: Optimizer | dict[str, torch.optim.optimizer.Optimizer] | None = None, lr_scheds: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter] | None = None) Path [source]
- train(model: ModelT, task: TaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) None [source]
Runs the training loop.
- Parameters:
model – The current model
task – The current task
optimizer – The current optimizer
lr_scheduler – The current learning rate scheduler
- Raises:
NotImplementedError – If the subclass does not implement this method
- load_state_dict(state_dict: dict) None [source]
Hook for loading extra state dict keys into the trainer's components.
- Parameters:
state_dict – The state dictionary being loaded (overriders should read back any keys they added at save time)
- update_state_dict(state_dict: dict) None [source]
Hook for adding extra keys to the state dictionary before it is saved.
- Parameters:
state_dict – The state dictionary being saved (overriders should mutate it in place)
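The two hooks are meant to round-trip: a subclass adds its extra state in update_state_dict and reads it back in load_state_dict. A toy sketch (the EMATrainer subclass and its ema_decay field are hypothetical):

```python
class BaseTrainer:
    def update_state_dict(self, state_dict: dict) -> None:
        pass  # base class adds nothing extra

    def load_state_dict(self, state_dict: dict) -> None:
        pass  # base class reads nothing extra


class EMATrainer(BaseTrainer):
    """Hypothetical subclass that checkpoints one extra piece of state."""

    def __init__(self) -> None:
        self.ema_decay = 0.999

    def update_state_dict(self, state_dict: dict) -> None:
        # Mutate the checkpoint dict in place at save time.
        state_dict["ema_decay"] = self.ema_decay

    def load_state_dict(self, state_dict: dict) -> None:
        # Restore the extra state at load time, keeping the default if absent.
        self.ema_decay = state_dict.get("ema_decay", self.ema_decay)
```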
- on_exit(sig: Signals, state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_scheduler: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_step_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_step_end(state: State, loss_dict: dict[str, torch.Tensor], task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_epoch_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_epoch_end(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_training_start(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
- on_training_end(state: State, task: TaskT, model: ModelT, optim: Optimizer | dict[str, torch.optim.optimizer.Optimizer], lr_sched: SchedulerAdapter | dict[str, ml.lr_schedulers.base.SchedulerAdapter]) None [source]
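The callbacks above fire at fixed points in the training loop. A toy trainer that records the expected firing order (the loop body is elided; only the hook placement is shown):

```python
class RecordingTrainer:
    """Toy trainer that records the order in which the lifecycle hooks fire."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def _hook(self, name: str) -> None:
        self.calls.append(name)

    def train(self, num_epochs: int, steps_per_epoch: int) -> None:
        self._hook("on_training_start")
        for _ in range(num_epochs):
            self._hook("on_epoch_start")
            for _ in range(steps_per_epoch):
                self._hook("on_step_start")
                # ... forward pass, loss, backward, optimizer step ...
                self._hook("on_step_end")
            self._hook("on_epoch_end")
        self._hook("on_training_end")
```

Subclasses and tasks override whichever hooks they need (e.g. logging in on_step_end, checkpointing in on_epoch_end); on_exit additionally fires when the process receives a termination signal.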