ml.trainers.learning
Defines a vanilla trainer that does not perform any device or data manipulation.
This trainer expects the task to handle all movement of data and models to their associated devices.
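To make that contract concrete, a task used with this trainer is responsible for moving its own tensors. The sketch below is purely illustrative: the task class, its `device` attribute, and the `run_model` hook are hypothetical stand-ins, not the library's actual task API.

```python
import torch
from torch import nn, Tensor


class MyTask:  # hypothetical task; the real task base class lives elsewhere in the library
    def __init__(self) -> None:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def run_model(self, model: nn.Module, batch: Tensor) -> Tensor:
        # The vanilla trainer never calls .to(...), so the task moves the batch itself.
        return model(batch.to(self.device))
```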
- class ml.trainers.learning.BaseLearningTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False)[source]
Bases: ProfilerTrainerConfig, MixedPrecisionTrainerConfig, GPUStatsConfig, CPUStatsConfig, CompileConfig, TrainerParallelConfig, BaseTrainerConfig
- set_to_none: bool = True
- deterministic: bool = False
- use_tf32: bool = True
- detect_anomaly: bool = False
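The training-behavior flags above can be set when the config is built. A minimal sketch, assuming the config is constructed directly in Python rather than through the framework's usual configuration files; the field names match the signature above, but the chosen values and the experiment directory are arbitrary.

```python
from ml.trainers.learning import BaseLearningTrainerConfig

# Only the scalar training flags documented above are overridden here; the
# nested configs (checkpoint, parallel, compiler, ...) keep their factory defaults.
config = BaseLearningTrainerConfig(
    exp_dir="runs/example",   # hypothetical experiment directory
    clip_grad_norm=1.0,       # default is 10.0
    clip_grad_norm_type=2,
    set_to_none=True,         # zero grads by setting them to None
    deterministic=False,      # enable for reproducible (but slower) kernels
    use_tf32=True,            # allow TF32 matmuls on Ampere+ GPUs
    detect_anomaly=False,     # autograd anomaly detection is expensive
)
```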
- class ml.trainers.learning.BaseLearningTrainer(config: ProfilerTrainerConfigT)[source]
Bases: ProfilerTrainerMixin[BaseLearningTrainerConfigT, ModelT, TaskT], MixedPrecisionTrainerMixin[BaseLearningTrainerConfigT, ModelT, TaskT], GPUStatsMixin[BaseLearningTrainerConfigT, ModelT, TaskT], CPUStatsMixin[BaseLearningTrainerConfigT, ModelT, TaskT], CompileMixin[BaseLearningTrainerConfigT, ModelT, TaskT], ParallelMixin[BaseLearningTrainerConfigT, ModelT, TaskT], BaseTrainer[BaseLearningTrainerConfigT, ModelT, TaskT], Generic[BaseLearningTrainerConfigT, ModelT, TaskT]
- train_step(*, task_model: Module, batches: Iterator[Batch], state: State, task: TaskT, model: ModelT, optim: Optimizer | Collection[Optimizer], lr_sched: SchedulerAdapter | Collection[SchedulerAdapter]) → dict[str, torch.Tensor][source]
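train_step runs one optimization step given the (possibly wrapped) task_model, a batch iterator, and the optimizer(s) and scheduler(s), returning a dictionary of loss tensors. The library's actual implementation also drives mixed precision, gradient-norm clipping per the config, and multiple optimizers; the sketch below shows only the general shape of such a step under the simplifying assumptions of a single optimizer, one batch per step, and a model call standing in for the task's loss computation.

```python
from typing import Iterator

import torch
from torch import nn
from torch.optim import Optimizer


def simple_train_step(
    task_model: nn.Module,
    batches: Iterator,
    optim: Optimizer,
    clip_grad_norm: float = 10.0,
    set_to_none: bool = True,
) -> dict[str, torch.Tensor]:
    # Illustrative only: assumes the model returns a scalar loss; the real
    # train_step also steps the LR scheduler and the mixed-precision grad scaler.
    batch = next(batches)
    loss = task_model(batch)                       # stand-in for the task's loss computation
    optim.zero_grad(set_to_none=set_to_none)       # mirrors the set_to_none config flag
    loss.backward()
    nn.utils.clip_grad_norm_(task_model.parameters(), clip_grad_norm)
    optim.step()
    return {"loss": loss.detach()}
```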