"""Defines a trainer to use for reinforcement learning.
This trainer spawns a number of workers to collect experience from the
environment. The workers then send the experience to the model, which
learns from it. The model sends actions back to the workers, which
perform the actions in the environment and collect the next state.
"""
import contextlib
import logging
import signal
from dataclasses import dataclass
from types import FrameType
from typing import Generic, TypeVar
from omegaconf import MISSING
from ml.core.config import conf_field
from ml.core.registry import register_trainer
from ml.lr_schedulers.base import BaseLRScheduler
from ml.optimizers.base import BaseOptimizer
from ml.tasks.rl.base import ReinforcementLearningTask
from ml.trainers.base import ModelT
from ml.trainers.learning import BaseLearningTrainer, BaseLearningTrainerConfig
from ml.utils.exceptions import TrainingFinishedError
from ml.utils.timer import Timer
logger = logging.getLogger(__name__)


@dataclass
class SamplingConfig:
    num_epoch_samples: int = conf_field(MISSING, help="Number of samples to collect each epoch")
    min_batch_size: int = conf_field(1, help="Minimum batch size for doing inference on the model")
    max_batch_size: int | None = conf_field(None, help="Maximum batch size for doing inference on the model")
    max_wait_time: float | None = conf_field(None, help="Maximum time to wait when collecting a batch for inference")
    min_trajectory_length: int = conf_field(1, help="Minimum length of trajectories to collect")
    max_trajectory_length: int | None = conf_field(None, help="Maximum length of trajectories to collect")
    force_sync: bool = conf_field(False, help="Force workers to run in sync mode rather than async mode")
    optimal: bool = conf_field(False, help="Whether to choose the optimal action or sample from the policy")


@dataclass
class ReinforcementLearningTrainerConfig(BaseLearningTrainerConfig):
    sampling: SamplingConfig = conf_field(SamplingConfig())
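
# A minimal sketch of how these options might appear in a YAML config, assuming the
# registry exposes this trainer under a top-level `trainer` key (the `name: rl` value
# matches the `register_trainer("rl", ...)` call below; the `sampling` keys mirror
# `SamplingConfig`, and the values are illustrative only):
#
#   trainer:
#     name: rl
#     sampling:
#       num_epoch_samples: 10000
#       max_batch_size: 32
#       max_trajectory_length: 1000
#       force_sync: false
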
ReinforcementLearningTrainerConfigT = TypeVar(
    "ReinforcementLearningTrainerConfigT",
    bound=ReinforcementLearningTrainerConfig,
)
ReinforcementLearningTaskT = TypeVar("ReinforcementLearningTaskT", bound=ReinforcementLearningTask)
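# The type variables above let subclasses pin down the concrete config and task types;
# they parameterize the generic trainer class defined below.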


@register_trainer("rl", ReinforcementLearningTrainerConfig)
class ReinforcementLearningTrainer(
    BaseLearningTrainer[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT],
    Generic[ReinforcementLearningTrainerConfigT, ModelT, ReinforcementLearningTaskT],
):
    def train(
        self,
        model: ModelT,
        task: ReinforcementLearningTaskT,
        optimizer: BaseOptimizer,
        lr_scheduler: BaseLRScheduler,
    ) -> None:
        """Runs the training loop.

        Args:
            model: The current model
            task: The current task
            optimizer: The current optimizer
            lr_scheduler: The current learning rate scheduler

        Raises:
            ValueError: If the task is not a reinforcement learning task
        """
        if not isinstance(task, ReinforcementLearningTask):
            raise ValueError(f"Expected task to be a ReinforcementLearningTask, got {type(task)}")

        self._init_environment()

        with Timer("compiling model"):
            model = self._compile_model(model)
        with Timer("compiling training step"):
            train_step = self._compile_func(self.train_step)
        with Timer("building task model"):
            task_model = self._get_task_model(task, model)

        optim, lr_sched = self._get_optim_and_lr_sched(model, optimizer, lr_scheduler)
        state = self._get_state(task, model, optim, lr_sched)

        def on_exit(signum: int, _: FrameType | None) -> None:
            sig = signal.Signals(signum)
            self.on_exit(sig, state, task, model, optim, lr_sched)

        # Handle user-defined interrupts; sending SIGUSR1 to the training process
        # routes through `on_exit` above.
        signal.signal(signal.SIGUSR1, on_exit)

        # Gets the environment worker pool; workers run asynchronously unless
        # `sampling.force_sync` is set.
        worker_pool = task.get_worker_pool(force_sync=self.config.sampling.force_sync)

        self.on_training_start(state, task, model, optim, lr_sched)

        try:
            with contextlib.ExitStack() as ctx:
                profile = self.get_profile()
                if profile is not None:
                    ctx.enter_context(profile)

                while True:
                    with self.step_context("on_epoch_start"):
                        self.on_epoch_start(state, task, model, optim, lr_sched)
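
                    # Roll out the policy in the environment workers to gather this epoch's
                    # experience. The batching and trajectory-length limits below all come
                    # from `SamplingConfig`, and sampling runs under `self.autocast_context`,
                    # so action inference uses the same autocast settings as training.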
with self.step_context("collect_rl_samples"), self.autocast_context:
samples = task.collect_samples(
model=model,
worker_pool=worker_pool,
total_samples=self.config.sampling.num_epoch_samples,
min_trajectory_length=self.config.sampling.min_trajectory_length,
max_trajectory_length=self.config.sampling.max_trajectory_length,
min_batch_size=self.config.sampling.min_batch_size,
max_batch_size=self.config.sampling.max_batch_size,
max_wait_time=self.config.sampling.max_wait_time,
optimal=self.config.sampling.optimal,
)
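
                    # Turn the collected samples into an ordinary dataset and dataloader so the
                    # rest of the epoch looks like standard supervised training, then wrap the
                    # dataloader in the device prefetcher so batches are staged on the device
                    # ahead of each training step.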
with self.step_context("build_rl_dataset"):
with Timer("building dataset"):
train_ds = task.build_rl_dataset(samples)
with Timer("building dataloader"):
train_dl = task.get_dataloader(train_ds, "train")
with Timer("getting prefetcher"):
train_pf = self._device.get_prefetcher(train_dl)
for train_batch in train_pf:
self._log_prefetcher_stats(train_pf)
if task.is_training_over(state):
raise TrainingFinishedError
with self.step_context("on_step_start"):
self.on_step_start(state, task, model, optim, lr_sched)
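
                        # `train_step` expects an iterator of batches, so the single RL batch is
                        # wrapped in `iter([...])`; this lets the trainer reuse the step logic
                        # defined on the base learning trainer.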
                        loss_dict = train_step(
                            task_model=task_model,
                            batches=iter([train_batch]),
                            state=state,
                            task=task,
                            model=model,
                            optim=optim,
                            lr_sched=lr_sched,
                        )

                        if self.should_checkpoint(state):
                            self.save_checkpoint(state, task, model, optim, lr_sched)

                        if profile is not None:
                            profile.step()

                        with self.step_context("on_step_end"):
                            self.on_step_end(state, loss_dict, task, model, optim, lr_sched)

                        if task.epoch_is_over(state):
                            break

                    with self.step_context("on_epoch_end"):
                        self.on_epoch_end(state, task, model, optim, lr_sched)

        except TrainingFinishedError:
            self.save_checkpoint(state, task, model, optim, lr_sched)
            logger.info(
                "Finished training after %d epochs, %d steps, %d samples",
                state.num_epochs,
                state.num_steps,
                state.num_samples,
            )

        except Exception:
            logger.exception("Caught exception during training loop")

        finally:
            self.on_training_end(state, task, model, optim, lr_sched)