Source code for ml.tasks.rl.base

"""The base task and config for reinforcement learning tasks.

This class expects you to implement the following functions:

.. code-block:: python

    class MyReinforcementLearningTask(ml.ReinforcementLearningTask[Config, Model, State, Action, Output, Loss]):
        def get_actions(self, model: Model, states: list[State], optimal: bool) -> list[Action]:
            ...

        def get_environment(self) -> Environment:
            ...

        def run_model(self, model: Model, batch: tuple[State, Action], state: ml.State) -> Output:
            ...

        def compute_loss(self, model: Model, batch: tuple[State, Action], state: ml.State, output: Output) -> Loss:
            ...

Additionally, you can implement :meth:`postprocess_trajectory` and :meth:`postprocess_trajectories` to apply some
postprocessing to collected batches, such as computing the discounted rewards.
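
For example, a discounted-return pass over a single trajectory might look like the following
sketch. It assumes each state carries a mutable scalar ``reward`` field and that a ``gamma``
discount factor is available; neither is required by this interface.

.. code-block:: python

    def postprocess_trajectory(self, samples: list[tuple[State, Action]]) -> list[tuple[State, Action]]:
        # `gamma` is a hypothetical discount factor; it is not a field of the
        # built-in config dataclasses.
        gamma = 0.99
        running_return = 0.0
        for state, _ in reversed(samples):
            # Walk the trajectory backwards, accumulating the discounted return
            # and writing it back onto the (assumed mutable) `reward` field.
            running_return = state.reward + gamma * running_return
            state.reward = running_return
        return samples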
"""

import functools
import logging
import multiprocessing as mp
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Generic, Iterable, Iterator, Literal, TypeVar, overload

import numpy as np
import torch
from omegaconf import MISSING
from torch import Tensor
from torch.utils.data.dataset import Dataset

from ml.core.common_types import Loss, Output, RLAction, RLState
from ml.core.config import conf_field
from ml.core.state import State
from ml.tasks.base import BaseTask, BaseTaskConfig, ModelT
from ml.tasks.environments.base import Environment
from ml.tasks.environments.worker import (
    AsyncEnvironmentWorker,
    BaseEnvironmentWorker,
    SpecialState,
    SyncEnvironmentWorker,
    WorkerPool,
    cast_worker_mode,
    get_worker_pool,
)
from ml.tasks.rl.replay import MultiReplaySamples, ReplayDataset, ReplaySamples
from ml.utils.timer import spinnerator
from ml.utils.video import Writer, standardize_image, write_video

logger = logging.getLogger(__name__)


@dataclass
class EnvironmentConfig:
    num_env_workers: int = conf_field(1, help="Number of environment workers (0 to run synchronously)")
    env_worker_mode: str = conf_field("process", help="Mode for running environment worker")
    env_seed: int = conf_field(1337, help="Environment seed")
    env_cleanup_time: float = conf_field(5.0, help="Cleanup time for the async environment worker")
    max_steps: int = conf_field(1_000, help="Maximum number of steps in a clip")


@dataclass
class DatasetConfig:
    num_samples: int = conf_field(1, help="Number of training samples in replay")
    num_update_steps: int = conf_field(MISSING, help="How often to interact with the environment")
    stride: int = conf_field(1, help="Replay stride to use")
    replay_buffer_sample_size: int = conf_field(10000, help="Number of epochs of experience to keep in replay buffer")


@dataclass
class ReinforcementLearningTaskConfig(BaseTaskConfig):
    environment: EnvironmentConfig = conf_field(EnvironmentConfig())
    dataset: DatasetConfig = conf_field(DatasetConfig())


ReinforcementLearningTaskConfigT = TypeVar("ReinforcementLearningTaskConfigT", bound=ReinforcementLearningTaskConfig)


class ReinforcementLearningTask(
    BaseTask[ReinforcementLearningTaskConfigT, ModelT, tuple[RLState, RLAction], Output, Loss],
    Generic[ReinforcementLearningTaskConfigT, ModelT, RLState, RLAction, Output, Loss],
    ABC,
):
    @abstractmethod
    def get_actions(self, model: ModelT, states: list[RLState], optimal: bool) -> list[RLAction]:
        """Samples actions from the policy, given the previous states.

        Args:
            model: The model to sample from.
            states: The previous states.
            optimal: Whether to get the optimal action or to sample from the policy.

        Returns:
            The next actions to take for each state.
        """

    @abstractmethod
    def get_environment(self) -> Environment[RLState, RLAction]:
        """Returns the environment for the task.

        Returns:
            The environment for the task.
        """

    def build_rl_dataset(
        self,
        samples: MultiReplaySamples[tuple[RLState, RLAction]],
    ) -> Dataset[tuple[RLState, RLAction]]:
        return ReplayDataset(
            samples,
            clip_size=self.config.dataset.num_samples,
            stride=self.config.dataset.stride,
        )

    @functools.lru_cache
    def get_environment_cached(self) -> Environment[RLState, RLAction]:
        return self.get_environment()

    def get_environment_workers(self, force_sync: bool = False) -> list[BaseEnvironmentWorker[RLState, RLAction]]:
        env_cfg = self.config.environment
        if env_cfg.num_env_workers <= 0 or force_sync:
            return [
                SyncEnvironmentWorker(self.get_environment(), seed=env_cfg.env_seed)
                for _ in range(max(1, env_cfg.num_env_workers))
            ]
        manager = mp.Manager()
        return [
            AsyncEnvironmentWorker(
                self.get_environment(),
                manager,
                rank=rank,
                world_size=env_cfg.num_env_workers,
                seed=env_cfg.env_seed,
                cleanup_time=env_cfg.env_cleanup_time,
                mode=cast_worker_mode(env_cfg.env_worker_mode),
            )
            for rank in range(env_cfg.num_env_workers)
        ]

    def get_worker_pool(self, force_sync: bool = False) -> WorkerPool[RLState, RLAction]:
        return get_worker_pool(self.get_environment_workers(force_sync=force_sync), force_sync=force_sync)

    def postprocess_trajectory(self, samples: list[tuple[RLState, RLAction]]) -> list[tuple[RLState, RLAction]]:
        """Performs any global postprocessing on the trajectory.

        Args:
            samples: The trajectory to postprocess.

        Returns:
            The postprocessed trajectory.
        """
        return samples

    def postprocess_trajectories(
        self,
        trajectories: list[list[tuple[RLState, RLAction]]],
    ) -> list[list[tuple[RLState, RLAction]]]:
        """Performs any global postprocessing on all of the trajectories.

        Args:
            trajectories: The trajectories to postprocess.

        Returns:
            The postprocessed trajectories.
        """
        return trajectories

    def iter_samples(
        self,
        model: ModelT,
        worker_pool: WorkerPool[RLState, RLAction],
        *,
        total_samples: int | None = None,
        min_trajectory_length: int = 1,
        max_trajectory_length: int | None = None,
        min_batch_size: int = 1,
        max_batch_size: int | None = None,
        max_wait_time: float | None = None,
        optimal: bool = True,
    ) -> Iterable[list[tuple[RLState, RLAction]]]:
        """Collects samples from the environment.

        Args:
            model: The model to sample from.
            worker_pool: The pool of workers for the environment.
            total_samples: The total number of samples to collect; if None, iterates forever.
            min_trajectory_length: Minimum sequence length to consider a sequence as having
                contributed to `total_samples`.
            max_trajectory_length: Maximum sequence length to consider a sequence as having
                contributed to `total_samples`.
            min_batch_size: Minimum batch size for doing inference on the model.
            max_batch_size: Maximum batch size for doing inference on the model.
            max_wait_time: Maximum amount of time to wait to build a batch.
            optimal: Whether to get the optimal action or to sample from the policy.

        Yields:
            Lists of samples from the environment.

        Raises:
            ValueError: If `min_batch_size` is greater than `max_batch_size`.
        """
        min_trajectory_length = max(min_trajectory_length, self.config.dataset.num_samples, 1)
        num_samples, num_trajectories = 0, 0
        worker_pool.reset()
        trajectories: list[list[tuple[RLState, RLAction]]] = [[] for _ in range(len(worker_pool))]
        max_batch_size = len(worker_pool) if max_batch_size is None else min(max_batch_size, len(worker_pool))

        if total_samples is not None and min_trajectory_length > total_samples:
            raise ValueError(f"{min_trajectory_length=} > {total_samples=}")
        if min_batch_size > max_batch_size:
            raise ValueError(f"{min_batch_size=} > {max_batch_size=}")

        with spinnerator.range(total_samples, desc="Sampling") as pbar, torch.no_grad():
            while total_samples is None or num_samples < total_samples:
                start_time = time.time()

                # Wait for new samples to be ready.
                batch: list[tuple[RLState, int]] = []
                batch_special: list[tuple[SpecialState, int]] = []
                while len(batch) + len(batch_special) < max_batch_size:
                    elapsed_time = time.time() - start_time
                    if max_wait_time is not None and elapsed_time > max_wait_time and len(batch) >= min_batch_size:
                        break
                    state, worker_id = worker_pool.get_state()
                    pbar.update()  # Update every time we get a new state.

                    if state == "terminated":
                        if len(trajectories[worker_id]) >= min_trajectory_length:
                            yield self.postprocess_trajectory(trajectories[worker_id])
                            num_samples += len(trajectories[worker_id])
                        else:
                            logger.warning(
                                "Discarding trajectory of length %d because it is less than %d",
                                len(trajectories[worker_id]),
                                min_trajectory_length,
                            )
                        trajectories[worker_id] = []
                        batch_special.append((state, worker_id))
                    else:
                        batch.append((state, worker_id))

                # Sample actions for the new samples.
                states, worker_ids = [state for state, _ in batch], [worker_id for _, worker_id in batch]
                actions = self.get_actions(model, states, optimal) if states else []

                # Send the actions to the workers.
                trajectory_lengths = 0
                for state, action, worker_id in zip(states, actions, worker_ids):
                    if max_trajectory_length is not None and len(trajectories[worker_id]) >= max_trajectory_length:
                        yield self.postprocess_trajectory(trajectories[worker_id])
                        num_samples += len(trajectories[worker_id])
                        num_trajectories += 1
                        trajectories[worker_id] = []
                        worker_pool.send_action("reset", worker_id)
                    else:
                        trajectories[worker_id].append((state, action))
                        trajectory_len = len(trajectories[worker_id])
                        if trajectory_len >= min_trajectory_length:
                            trajectory_lengths += trajectory_len
                        worker_pool.send_action(action, worker_id)
                for state, worker_id in batch_special:
                    if state == "terminated":
                        worker_pool.send_action("reset", worker_id)
                    else:
                        raise ValueError(f"Unknown special state {state}")

                # If the current trajectories would finish the episode, then
                # add them to the list of all trajectories.
                if total_samples is not None and num_samples + trajectory_lengths >= total_samples:
                    for t in trajectories:
                        if len(t) < min_trajectory_length:
                            continue
                        yield self.postprocess_trajectory(t)
                        num_trajectories += 1
                    num_samples += trajectory_lengths
                    pbar.update(trajectory_lengths)
                    break

        logger.info("Collected %d total samples and %d trajectories", num_samples, num_trajectories)

    def collect_samples(
        self,
        model: ModelT,
        worker_pool: WorkerPool[RLState, RLAction],
        total_samples: int,
        *,
        min_trajectory_length: int = 1,
        max_trajectory_length: int | None = None,
        min_batch_size: int = 1,
        max_batch_size: int | None = None,
        max_wait_time: float | None = None,
        optimal: bool = True,
    ) -> MultiReplaySamples[tuple[RLState, RLAction]]:
        trajectories_iter = self.iter_samples(
            model=model,
            worker_pool=worker_pool,
            total_samples=total_samples,
            min_trajectory_length=min_trajectory_length,
            max_trajectory_length=max_trajectory_length,
            min_batch_size=min_batch_size,
            max_batch_size=max_batch_size,
            max_wait_time=max_wait_time,
            optimal=optimal,
        )

        # Does global postprocessing on the sampled trajectories.
        all_trajectories = list(trajectories_iter)
        all_trajectories = self.postprocess_trajectories(all_trajectories)
        return MultiReplaySamples([ReplaySamples(t) for t in all_trajectories])
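
    # A rough sketch of how sampling and the replay dataset fit together in a training
    # loop (`task`, `model`, and the sample count below are hypothetical, not part of
    # this module):
    #
    #     worker_pool = task.get_worker_pool()
    #     samples = task.collect_samples(model, worker_pool, total_samples=10_000)
    #     dataset = task.build_rl_dataset(samples)
    #     # ...then iterate over `dataset` with the trainer's usual dataloader, calling
    #     # `run_model` and `compute_loss` on each batch.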

    def epoch_is_over(self, state: State) -> bool:
        return state.num_epoch_steps >= self.config.dataset.num_update_steps

    @overload
    def sample_clip(
        self,
        *,
        save_path: str | Path,
        return_images: Literal[True] = True,
        return_states: Literal[False] = False,
        model: ModelT | None = None,
        writer: Writer = "ffmpeg",
        standardize_images: bool = True,
        optimal: bool = True,
    ) -> None:
        ...

    @overload
    def sample_clip(
        self,
        *,
        return_images: Literal[True] = True,
        return_states: Literal[False] = False,
        model: ModelT | None = None,
        standardize_images: bool = True,
        optimal: bool = True,
    ) -> Tensor:
        ...

    @overload
    def sample_clip(
        self,
        *,
        return_images: Literal[True] = True,
        return_states: Literal[True],
        model: ModelT | None = None,
        standardize_images: bool = True,
        optimal: bool = True,
    ) -> tuple[Tensor, list[tuple[RLState, RLAction]]]:
        ...

    @overload
    def sample_clip(
        self,
        *,
        return_images: Literal[False],
        return_states: Literal[True],
        model: ModelT | None = None,
        optimal: bool = True,
    ) -> list[tuple[RLState, RLAction]]:
        ...

    def sample_clip(
        self,
        *,
        save_path: str | Path | None = None,
        return_images: bool = True,
        return_states: bool = False,
        model: ModelT | None = None,
        writer: Writer = "ffmpeg",
        standardize_images: bool = True,
        optimal: bool = True,
    ) -> Tensor | list[tuple[RLState, RLAction]] | tuple[Tensor, list[tuple[RLState, RLAction]]] | None:
        """Samples a clip for a given model.

        Args:
            save_path: Where to save the sampled clips.
            return_images: Whether to return the images.
            return_states: Whether to return the states.
            model: The model to sample from; if not provided, actions are sampled randomly
                from the environment.
            writer: The writer to use to save the clip.
            standardize_images: Whether to standardize the images.
            optimal: Whether to sample actions optimally.

        Returns:
            The sampled clip, if `save_path` is not provided, otherwise `None`
            (the clip is saved to `save_path`).

        Raises:
            ValueError: If `save_path` is provided and `return_states` is `True`.
        """
        env_cfg = self.config.environment

        if not return_states and not return_images:
            raise ValueError("Must return states, images or both")

        environment = self.get_environment_cached()

        def iter_states() -> Iterator[tuple[RLState, RLAction]]:
            state = environment.reset()
            if environment.terminated(state):
                raise RuntimeError("Initial state is terminated")
            iterator = spinnerator.range(env_cfg.max_steps)
            for i in iterator:
                if environment.terminated(state):
                    logger.info("Terminating environment early, after %d / %d steps", i, env_cfg.max_steps)
                    break
                if model is None:
                    action = environment.sample_action()
                else:
                    (action,) = self.get_actions(model, [state], optimal)
                state = environment.step(action)
                yield (state, action)

        def iter_images() -> Iterator[np.ndarray | Tensor]:
            for state, _ in iter_states():
                yield environment.render(state)

        def iter_images_and_states() -> Iterator[tuple[np.ndarray | Tensor, tuple[RLState, RLAction]]]:
            for state, action in iter_states():
                image = environment.render(state)
                if standardize_images:
                    image = standardize_image(image)
                yield image, (state, action)

        if save_path is None:
            if return_images and return_states:
                images, states = zip(*iter_images_and_states())
                images_np = [standardize_image(image) for image in images]
                return torch.from_numpy(np.stack(images_np)), list(states)
            if return_states:
                return list(iter_states())
            images_np = [standardize_image(image) for image in iter_images()]
            return torch.from_numpy(np.stack(images_np))

        if return_states:
            raise ValueError("Cannot return states when saving to a file")

        write_video(iter_images(), save_path, fps=environment.fps, writer=writer)
        return None
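
    # Example of sampling a rollout clip (a sketch; `task`, `my_model`, and the output
    # path are hypothetical):
    #
    #     task.sample_clip(model=my_model, save_path="rollout.mp4")      # write a video to disk
    #     frames = task.sample_clip(model=my_model, return_images=True)  # or get the frames as a tensor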