ml.trainers.rl

Defines a trainer for reinforcement learning.

This trainer spawns a number of workers to collect experience from the environment. The workers send this experience to the model, which learns from it and sends actions back to the workers; the workers then perform those actions in the environment and collect the next state.
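The sketch below is a conceptual illustration of that loop, not the trainer's actual implementation: a worker thread steps a dummy environment and pushes transitions onto a queue, while the learner pulls them and sends actions back. Every name in it is a hypothetical stand-in.

    import queue
    import random
    import threading

    def worker(action_q: queue.Queue, experience_q: queue.Queue, steps: int) -> None:
        state = 0.0  # Dummy environment state.
        for _ in range(steps):
            action = action_q.get()                # Receive an action from the model.
            next_state = state + action            # Step the dummy environment.
            reward = -abs(next_state)
            experience_q.put((state, action, reward, next_state))  # Send experience back.
            state = next_state

    def learner(action_q: queue.Queue, experience_q: queue.Queue, steps: int) -> None:
        for _ in range(steps):
            action_q.put(random.uniform(-1.0, 1.0))  # Model chooses an action.
            transition = experience_q.get()          # Collect experience to learn from.
            del transition                           # A real learner would update the model here.

    action_queue: queue.Queue = queue.Queue()
    experience_queue: queue.Queue = queue.Queue()
    thread = threading.Thread(target=worker, args=(action_queue, experience_queue, 100))
    thread.start()
    learner(action_queue, experience_queue, 100)
    thread.join()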

class ml.trainers.rl.SamplingConfig(num_epoch_samples: int = '???', min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, force_sync: bool = False, optimal: bool = False)[source]

Bases: object

num_epoch_samples: int = '???'
min_batch_size: int = 1
max_batch_size: int | None = None
max_wait_time: float | None = None
min_trajectory_length: int = 1
max_trajectory_length: int | None = None
force_sync: bool = False
optimal: bool = False
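The '???' default marks num_epoch_samples as a required value with no usable default; the remaining fields can be left at their defaults. A minimal construction sketch, assuming standard dataclass semantics; the inline comments are plausible readings of the field names, not documented behavior:

    from ml.trainers.rl import SamplingConfig

    sampling = SamplingConfig(
        num_epoch_samples=10_000,   # Required: the '???' default means "missing".
        min_batch_size=4,           # Lower bound on the sampled batch size.
        max_trajectory_length=512,  # Optional upper bound on trajectory length.
    )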
class ml.trainers.rl.ReinforcementLearningTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False, sampling: ml.trainers.rl.SamplingConfig = <factory>)[source]

Bases: BaseLearningTrainerConfig

sampling: SamplingConfig
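The defaults use OmegaConf-style markers ('???' for missing values, '${...}' for interpolations), so the config is presumably meant to be built as a structured config and then overridden. A hedged sketch under that assumption:

    from omegaconf import OmegaConf

    from ml.trainers.rl import ReinforcementLearningTrainerConfig

    # Build a structured config from the dataclass, then fill in the missing fields,
    # including the nested `sampling` section.
    cfg = OmegaConf.structured(ReinforcementLearningTrainerConfig)
    cfg.name = "rl"
    cfg.exp_dir = "runs/rl_example"
    cfg.sampling.num_epoch_samples = 50_000
    cfg.sampling.min_batch_size = 8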
class ml.trainers.rl.ReinforcementLearningTrainer(config: ProfilerTrainerConfigT)[source]

Bases: BaseLearningTrainer[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT], Generic[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT]

train(model: ModelT, task: ReinforcementLearningTaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) → None[source]

Runs the training loop.

Parameters:
  • model – The current model

  • task – The current task

  • optimizer – The current optimizer

  • lr_scheduler – The current learning rate scheduler

Raises:

ValueError – If the task is not a reinforcement learning task
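A hedged usage sketch tying the pieces together; cfg is the structured config from the sketch above, and the model, task, optimizer, and scheduler objects are placeholders for instances of the library's own types, not constructed here:

    from ml.trainers.rl import ReinforcementLearningTrainer

    trainer = ReinforcementLearningTrainer(cfg)
    # `model`, `rl_task`, `optimizer`, and `lr_scheduler` are assumed to already exist.
    trainer.train(model, rl_task, optimizer, lr_scheduler)  # Raises ValueError if the task is not an RL task.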