ml.trainers.mixins.step_wrapper

Defines a mixin to wrap some steps in a context manager.

Other components, such as profilers, use it to find out which step is currently running.

class ml.trainers.mixins.step_wrapper.StepContext(step: Literal['backward', 'build_rl_dataset', 'change_mode', 'clip_grads', 'collect_rl_samples', 'forward', 'get_single_loss', 'log_losses', 'on_epoch_end', 'on_epoch_start', 'on_step_end', 'on_step_start', 'step', 'update_state', 'write_logs', 'zero_grads'])

Bases: ContextManager

Context manager that tracks the current step type; the active step can be read from CURRENT_STEP.

CURRENT_STEP: Literal['backward', 'build_rl_dataset', 'change_mode', 'clip_grads', 'collect_rl_samples', 'forward', 'get_single_loss', 'log_losses', 'on_epoch_end', 'on_epoch_start', 'on_step_end', 'on_step_start', 'step', 'update_state', 'write_logs', 'zero_grads'] | None = None
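The page does not describe how StepContext updates CURRENT_STEP. A minimal sketch, assuming the context manager sets the class attribute on entry and restores the previous value on exit so that nested steps unwind correctly. Both the restore-on-exit behaviour and the StepType alias (a subset of the literal values above) are assumptions for illustration, not the library's implementation.

```python
from contextlib import AbstractContextManager
from typing import Literal, Optional

# Hypothetical subset of the step names accepted by the real StepContext.
StepType = Literal["forward", "backward", "step", "zero_grads"]


class StepContext(AbstractContextManager):
    """Sketch of a context manager that tracks the active step."""

    # Class-level, so any component can read it without a reference to the trainer.
    CURRENT_STEP: Optional[StepType] = None

    def __init__(self, step: StepType) -> None:
        self.step = step
        self._prev: Optional[StepType] = None

    def __enter__(self) -> "StepContext":
        # Assumed behaviour: remember the enclosing step so it can be restored.
        self._prev = StepContext.CURRENT_STEP
        StepContext.CURRENT_STEP = self.step
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Runs on normal exit and on exceptions; returning None re-raises them.
        StepContext.CURRENT_STEP = self._prev
```

With this sketch, CURRENT_STEP reads "backward" inside a backward block nested in a forward block, returns to "forward" when the inner block exits, and returns to None once both have exited, even if a body raises.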
class ml.trainers.mixins.step_wrapper.StepContextMixin(config: TrainerConfigT)

Bases: BaseTrainer[BaseTrainerConfigT, ModelT, TaskT], ABC

step_context(step: Literal['backward', 'build_rl_dataset', 'change_mode', 'clip_grads', 'collect_rl_samples', 'forward', 'get_single_loss', 'log_losses', 'on_epoch_end', 'on_epoch_start', 'on_step_end', 'on_step_start', 'step', 'update_state', 'write_logs', 'zero_grads']) → ContextManager
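A minimal sketch of how a trainer could use step_context, assuming the method simply returns a StepContext for the given step. The block carries its own stand-in StepContext so it runs on its own. StepContextMixin here is a hypothetical simplified version (no config, no BaseTrainer base class), and ToyTrainer with its one-parameter gradient-descent update is invented for illustration; the seen list plays the role of a component, such as a profiler, that reads CURRENT_STEP while a step runs.

```python
from contextlib import AbstractContextManager
from typing import List, Optional


class StepContext(AbstractContextManager):
    """Minimal stand-in that tracks the active step in a class attribute."""

    CURRENT_STEP: Optional[str] = None

    def __init__(self, step: str) -> None:
        self.step = step
        self._prev: Optional[str] = None

    def __enter__(self) -> "StepContext":
        # Save the enclosing step, then mark this one as active.
        self._prev = StepContext.CURRENT_STEP
        StepContext.CURRENT_STEP = self.step
        return self

    def __exit__(self, *exc_info: object) -> None:
        # Restore the enclosing step even if the body raised.
        StepContext.CURRENT_STEP = self._prev


class StepContextMixin:
    """Hypothetical simplified mixin (no config or BaseTrainer base)."""

    def step_context(self, step: str) -> AbstractContextManager:
        # Assumption: the real method wraps the step the same way.
        return StepContext(step)


class ToyTrainer(StepContextMixin):
    """Hypothetical trainer fitting w in w * x ~= 1 with plain gradient descent."""

    def __init__(self) -> None:
        # Steps observed from inside each wrapped block, as a profiler might see them.
        self.seen: List[Optional[str]] = []

    def train_step(self, w: float, x: float, lr: float = 0.1) -> float:
        with self.step_context("forward"):
            self.seen.append(StepContext.CURRENT_STEP)
            residual = w * x - 1.0
        with self.step_context("backward"):
            self.seen.append(StepContext.CURRENT_STEP)
            grad = 2.0 * residual * x  # d/dw of residual ** 2
        with self.step_context("step"):
            self.seen.append(StepContext.CURRENT_STEP)
            return w - lr * grad
```

Calling train_step(w=0.0, x=1.0) on a fresh ToyTrainer returns 0.2 (the residual is -1, so the gradient is -2) and leaves seen equal to ["forward", "backward", "step"]. Because each phase is wrapped separately, an observer never has to be passed a reference to the trainer to learn which phase is active.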