ml.trainers.rl
Defines a trainer to use for reinforcement learning.
This trainer spawns a number of workers that collect experience from the environment. The workers send this experience to the model, which learns from it. The model in turn sends actions back to the workers, which perform those actions in the environment and observe the resulting next state.
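The worker/model exchange described above can be sketched with threads and queues. This is a minimal illustration under stated assumptions, not the trainer's implementation: the environment (`CounterEnv`), the policy (`policy`), and the `Transition` record are hypothetical, threads stand in for whatever worker mechanism the trainer actually spawns, and the learning step that would consume the collected experience is omitted.

```python
from __future__ import annotations

import queue
import threading
from dataclasses import dataclass


@dataclass
class Transition:
    """One step of experience reported by a worker (hypothetical structure)."""
    worker_id: int
    state: int
    action: int
    reward: float
    next_state: int


class CounterEnv:
    """Hypothetical toy environment: the action is added to a counter and returned as the reward."""

    def __init__(self) -> None:
        self.state = 0

    def step(self, action: int) -> tuple[int, float]:
        self.state += action
        return self.state, float(action)


def policy(state: int) -> int:
    """Hypothetical stand-in for the model; a real policy would depend on the state."""
    return 1


def worker(worker_id: int, num_steps: int, requests: queue.Queue,
           replies: list[queue.Queue], experience: queue.Queue) -> None:
    env = CounterEnv()
    state = env.state
    for _ in range(num_steps):
        requests.put((worker_id, state))       # ask the model for an action
        action = replies[worker_id].get()      # block until the model answers
        next_state, reward = env.step(action)  # perform the action in the environment
        experience.put(Transition(worker_id, state, action, reward, next_state))
        state = next_state


def collect(num_workers: int = 2, num_steps: int = 3) -> list[Transition]:
    requests: queue.Queue = queue.Queue()
    replies = [queue.Queue() for _ in range(num_workers)]
    experience: queue.Queue = queue.Queue()
    threads = [
        threading.Thread(target=worker, args=(i, num_steps, requests, replies, experience))
        for i in range(num_workers)
    ]
    for t in threads:
        t.start()
    # Model side: answer exactly one action request per worker step.
    for _ in range(num_workers * num_steps):
        worker_id, state = requests.get()
        replies[worker_id].put(policy(state))
    for t in threads:
        t.join()
    # Every worker has finished, so draining the queue sees all transitions.
    transitions = []
    while not experience.empty():
        transitions.append(experience.get())
    return transitions
```

Each worker blocks on its own reply queue, so the model can answer requests in whatever order they arrive without mixing up which action belongs to which environment.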
- class ml.trainers.rl.SamplingConfig(num_epoch_samples: int = '???', min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, force_sync: bool = False, optimal: bool = False)
Bases: object
- num_epoch_samples: int = '???'
- min_batch_size: int = 1
- max_batch_size: int | None = None
- max_wait_time: float | None = None
- min_trajectory_length: int = 1
- max_trajectory_length: int | None = None
- force_sync: bool = False
- optimal: bool = False
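The `'???'` default on `num_epoch_samples` is OmegaConf's marker for a mandatory value that has not been supplied, so this field must be set before training. The sketch below mirrors the documented fields in a stand-in dataclass and adds a hypothetical validation function. The class name `SamplingConfigSketch` and the function `check_sampling_config` are illustrative only, and the assumption that the `min_*`/`max_*` pairs form lower and upper bounds is inferred from the field names, not confirmed by this page.

```python
from __future__ import annotations

from dataclasses import dataclass

MISSING = "???"  # OmegaConf's marker for a required value that has not been set


@dataclass
class SamplingConfigSketch:
    """Hypothetical stand-in mirroring the documented fields of ml.trainers.rl.SamplingConfig."""
    num_epoch_samples: int = MISSING  # required: must be overridden
    min_batch_size: int = 1
    max_batch_size: int | None = None
    max_wait_time: float | None = None
    min_trajectory_length: int = 1
    max_trajectory_length: int | None = None
    force_sync: bool = False
    optimal: bool = False


def check_sampling_config(cfg: SamplingConfigSketch) -> None:
    """Hypothetical sanity checks; the real class may validate differently or not at all."""
    if cfg.num_epoch_samples == MISSING:
        raise ValueError("num_epoch_samples must be set")
    # Assumption: min/max pairs are lower/upper bounds, with None meaning "no upper bound".
    if cfg.max_batch_size is not None and cfg.max_batch_size < cfg.min_batch_size:
        raise ValueError("max_batch_size must be >= min_batch_size")
    if cfg.max_trajectory_length is not None and cfg.max_trajectory_length < cfg.min_trajectory_length:
        raise ValueError("max_trajectory_length must be >= min_trajectory_length")


check_sampling_config(SamplingConfigSketch(num_epoch_samples=1000, max_batch_size=32))
```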
- class ml.trainers.rl.ReinforcementLearningTrainerConfig(name: str = '???', exp_name: str = '${ml.exp_name:null}', exp_dir: str = '???', log_dir_name: str = 'logs', use_double_weight_precision: bool = False, checkpoint: ml.trainers.base.CheckpointConfig = <factory>, parallel: ml.trainers.mixins.data_parallel.ParallelConfig = <factory>, compiler: ml.trainers.mixins.compile.TorchCompileConfig = <factory>, cpu_stats_ping_interval: int = 1, cpu_stats_only_log_once: bool = False, gpu_stats_ping_interval: int = 10, gpu_stats_only_log_once: bool = False, mixed_precision: ml.trainers.mixins.mixed_precision.MixedPrecisionConfig = <factory>, clip_grad_norm: float = 10.0, clip_grad_norm_type: Any = 2, balance_grad_norms: bool = False, profiler: ml.trainers.mixins.profiler.Profiler = <factory>, set_to_none: bool = True, deterministic: bool = False, use_tf32: bool = True, detect_anomaly: bool = False, sampling: ml.trainers.rl.SamplingConfig = <factory>)
Bases: BaseLearningTrainerConfig
- sampling: SamplingConfig
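In the signature above, `<factory>` is how Python dataclasses render a field whose default comes from a `default_factory`, so each trainer config gets its own fresh `SamplingConfig` rather than sharing one instance. The sketch below demonstrates that behavior with trimmed, hypothetical stand-in classes (`SamplingConfigSketch`, `RLTrainerConfigSketch`) that keep only a few of the documented fields.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class SamplingConfigSketch:
    """Hypothetical, trimmed stand-in for ml.trainers.rl.SamplingConfig."""
    min_batch_size: int = 1
    max_batch_size: int | None = None


@dataclass
class RLTrainerConfigSketch:
    """Hypothetical, trimmed stand-in for ReinforcementLearningTrainerConfig."""
    name: str = "rl"
    # A default_factory builds a new SamplingConfigSketch for every config instance;
    # this is what appears as `<factory>` in the rendered signature.
    sampling: SamplingConfigSketch = field(default_factory=SamplingConfigSketch)


a = RLTrainerConfigSketch()
b = RLTrainerConfigSketch()
a.sampling.min_batch_size = 8
print(b.sampling.min_batch_size)  # 1: b has its own, unmodified sampling config
```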
- class ml.trainers.rl.ReinforcementLearningTrainer(config: ProfilerTrainerConfigT)
Bases: BaseLearningTrainer[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT], Generic[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT]
- train(model: ModelT, task: ReinforcementLearningTaskT, optimizer: BaseOptimizer, lr_scheduler: BaseLRScheduler) -> None
Runs the training loop.
- Parameters:
model – The current model
task – The current task
optimizer – The current optimizer
lr_scheduler – The current learning rate scheduler
- Raises:
ValueError – If the task is not a reinforcement learning task
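The documented contract of `train` (four arguments, no return value, and a `ValueError` when the task is not a reinforcement learning task) can be illustrated with a minimal sketch. All class names here are hypothetical stand-ins, and the `isinstance` check is an assumption about how the task type might be validated; the real trainer may perform this check differently and also runs the actual training loop, which is omitted.

```python
from __future__ import annotations

from typing import Generic, TypeVar


class TaskSketch:
    """Hypothetical stand-in for a generic (non-RL) task."""


class RLTaskSketch(TaskSketch):
    """Hypothetical stand-in for a reinforcement learning task."""


TaskT = TypeVar("TaskT", bound=TaskSketch)


class RLTrainerSketch(Generic[TaskT]):
    """Hypothetical trainer reproducing only the documented ValueError contract of train()."""

    def train(self, model: object, task: TaskT, optimizer: object, lr_scheduler: object) -> None:
        # Reject any task that is not a reinforcement learning task.
        if not isinstance(task, RLTaskSketch):
            raise ValueError(f"Expected a reinforcement learning task, got {type(task).__name__}")
        # The sampling and optimization loop would run here; it is omitted in this sketch.


trainer: RLTrainerSketch[RLTaskSketch] = RLTrainerSketch()
trainer.train(object(), RLTaskSketch(), object(), object())  # accepted, returns None
```

Passing a plain `TaskSketch()` instead raises `ValueError`, matching the behavior listed under "Raises".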