ml.trainers.learning

Defines a vanilla trainer which doesn’t do any device or data manipulation.

This trainer expects the task to handle all the relevant movement of data and models to their associated devices.

class ml.trainers.learning.BaseLearningTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False)[source]

Bases: ProfilerTrainerConfig, MixedPrecisionTrainerConfig, GPUStatsConfig, CPUStatsConfig, CompileConfig, TrainerParallelConfig, BaseTrainerConfig

set_to_none: bool = True
deterministic: bool = False
use_tf32: bool = True
detect_anomaly: bool = False
class ml.trainers.learning.BaseLearningTrainer(config: BaseLearningTrainerConfigT)[source]

Bases: ProfilerTrainerMixin[BaseLearningTrainerConfigT, ModelT, TaskT], MixedPrecisionTrainerMixin[BaseLearningTrainerConfigT, ModelT, TaskT], GPUStatsMixin[BaseLearningTrainerConfigT, ModelT, TaskT], CPUStatsMixin[BaseLearningTrainerConfigT, ModelT, TaskT], CompileMixin[BaseLearningTrainerConfigT, ModelT, TaskT], ParallelMixin[BaseLearningTrainerConfigT, ModelT, TaskT], BaseTrainer[BaseLearningTrainerConfigT, ModelT, TaskT], Generic[BaseLearningTrainerConfigT, ModelT, TaskT]

train_step(*, task_model: Module, batches: Iterator[Batch], state: State, task: TaskT, model: ModelT, optim: Optimizer | Collection[Optimizer], lr_sched: SchedulerAdapter | Collection[SchedulerAdapter]) → dict[str, torch.Tensor][source]
val_step(*, task_model: Module, batch: Batch, state: State, task: TaskT, model: ModelT) → None[source]
test_step(*, task_model: Module, batch: Batch, state: State, task: TaskT, model: ModelT) → None[source]