# Source code for ml.tasks.environments.worker

"""Defines workers for reinforcement learning environments."""

import collections
import logging
import queue
import threading
from abc import ABC, abstractmethod
from multiprocessing.managers import SyncManager
from queue import Queue
from typing import TYPE_CHECKING, Deque, Generic, Literal, Sequence, cast, get_args, overload

import torch.multiprocessing as mp

from ml.core.common_types import RLAction, RLState
from ml.utils.logging import configure_logging

if TYPE_CHECKING:
    from ml.tasks.environments.base import Environment

logger = logging.getLogger(__name__)

# Where an asynchronous worker runs its environment loop: a background thread or a separate process.
Mode = Literal["thread", "process"]
# Control commands that may be sent to a worker in place of a regular action.
SpecialAction = Literal["reset", "close"]
# Marker reported in place of a state once the environment says it has terminated.
SpecialState = Literal["terminated"]


def clear_queue(q: Queue) -> None:
    """Discards every item currently in ``q`` without ever blocking.

    Args:
        q: The queue to drain; removed items are thrown away.
    """
    drained = False
    while not drained:
        try:
            q.get_nowait()
        except queue.Empty:
            # Nothing left to take right now; stop draining.
            drained = True
def cast_worker_mode(m: str) -> Mode:
    """Validates that ``m`` names a supported worker mode.

    Args:
        m: The candidate mode string.

    Returns:
        ``m`` unchanged, typed as a ``Mode`` literal.

    Raises:
        ValueError: If ``m`` is not one of the ``Mode`` choices.
    """
    valid_modes = get_args(Mode)
    if m in valid_modes:
        return cast(Mode, m)
    raise ValueError(f"`{m}` is not a valid mode; choices are {valid_modes}")
class BaseEnvironmentWorker(ABC, Generic[RLState, RLAction]):
    """Interface for objects that drive an environment on behalf of a caller.

    Callers push actions with :meth:`send_action` and read results back with
    :meth:`get_state`; concrete subclasses decide where the environment runs.
    """

    @abstractmethod
    def cleanup(self) -> None:
        """Releases whatever resources the worker holds."""

    def __del__(self) -> None:
        # Release resources when the worker is garbage collected.
        self.cleanup()

    @abstractmethod
    def get_state(self) -> RLState | SpecialState:
        """Fetches the latest environment state.

        Returns:
            The most recent state, or the special ``"terminated"`` marker.
        """

    @abstractmethod
    def send_action(self, action: RLAction | SpecialAction) -> None:
        """Delivers an action to the environment.

        Args:
            action: A regular environment action, or one of the special
                commands ``"reset"`` / ``"close"``.
        """

    @classmethod
    @abstractmethod
    def from_environment(
        cls,
        env: "Environment[RLState, RLAction]",
        num_workers: int,
    ) -> Sequence["BaseEnvironmentWorker[RLState, RLAction]"]:
        """Builds a group of workers around ``env``.

        Args:
            env: The environment the workers should wrap.
            num_workers: How many workers to build.

        Returns:
            The newly created workers.
        """
class SyncEnvironmentWorker(BaseEnvironmentWorker[RLState, RLAction], Generic[RLState, RLAction]):
    def __init__(self, env: "Environment[RLState, RLAction]", seed: int = 1337) -> None:
        """Runs an environment inline, in the calling thread.

        Args:
            env: The environment to drive.
            seed: Seed passed to ``env.reset`` on every reset.
        """
        super().__init__()

        self.env = env
        self.seed = seed
        # ``None`` until the first reset; afterwards the latest state or "terminated".
        self.state: RLState | SpecialState | None = None

    @classmethod
    def from_environment(
        cls,
        env: "Environment[RLState, RLAction]",
        num_workers: int,
    ) -> Sequence["SyncEnvironmentWorker[RLState, RLAction]"]:
        # NOTE(review): every worker wraps the *same* ``env`` object with the
        # default seed; confirm callers never step several of them concurrently.
        return [cls(env) for _worker_idx in range(num_workers)]

    def cleanup(self) -> None:
        """Nothing to release: the environment lives in the caller's thread."""

    def get_state(self) -> RLState | SpecialState:
        """Returns the state produced by the most recent action.

        Raises:
            RuntimeError: If no ``"reset"`` has been sent yet.
        """
        current = self.state
        if current is not None:
            return current
        raise RuntimeError("Environment has not been reset")

    def send_action(self, action: RLAction | SpecialAction) -> None:
        """Applies ``action`` immediately and stores the resulting state.

        Args:
            action: ``"reset"`` to reset the environment with ``self.seed``,
                or a regular action to step it with.

        Raises:
            ValueError: If ``action`` is ``"close"``, which is unsupported here.
        """
        if action == "close":
            raise ValueError("Cannot close a synchronous environment")
        self.state = self.env.reset(self.seed) if action == "reset" else self.env.step(action)
        if self.env.terminated(self.state):
            self.state = "terminated"
class AsyncEnvironmentWorker(BaseEnvironmentWorker[RLState, RLAction], Generic[RLState, RLAction]):
    def __init__(
        self,
        env: "Environment[RLState, RLAction]",
        manager: SyncManager,
        rank: int | None = None,
        world_size: int | None = None,
        seed: int = 1337,
        cleanup_time: float = 5.0,
        mode: Mode = "process",
        daemon: bool = True,
    ) -> None:
        """Defines an asynchronous environment worker.

        This worker either runs in a separate thread or process, and is used to
        asynchronously interact with an environment. This is useful for
        environments that are slow to interact with, such as a simulator.

        Args:
            env: The environment to wrap.
            manager: The manager used to create the action and state queues.
            rank: The rank of the worker.
            world_size: The number of workers.
            seed: The random seed to use.
            cleanup_time: The time to wait for the worker to finish before
                killing it.
            mode: The mode to use for the worker.
            daemon: Whether to run the worker as a daemon.

        Raises:
            ValueError: If the mode is invalid.
        """
        super().__init__()

        self.cleanup_time = cleanup_time
        self.rank = 0 if rank is None else rank
        self.world_size = 1 if world_size is None else world_size

        # Single-slot queues: at most one pending action and one unread state.
        self.action_queue: "Queue[RLAction | SpecialAction]" = manager.Queue(maxsize=1)
        self.state_queue: "Queue[RLState | SpecialState]" = manager.Queue(maxsize=1)

        # NOTE: the raw ``rank`` / ``world_size`` (possibly ``None``) are what
        # the loop receives, not the defaulted ``self.rank`` / ``self.world_size``.
        args = env, seed, self.action_queue, self.state_queue, rank, world_size

        self._proc: threading.Thread | mp.Process
        if mode == "thread":
            self._proc = threading.Thread(target=self._thread, args=args, daemon=daemon)
            self._proc.start()
        elif mode == "process":
            self._proc = mp.Process(target=self._thread, args=args, daemon=daemon)
            self._proc.start()
        else:
            raise ValueError(f"Invalid mode: {mode}")

    @classmethod
    def from_environment(
        cls,
        env: "Environment[RLState, RLAction]",
        num_workers: int,
    ) -> Sequence["AsyncEnvironmentWorker[RLState, RLAction]"]:
        """Creates ``num_workers`` workers that share one new manager."""
        manager = mp.Manager()
        return [cls(env, manager, rank=rank, world_size=num_workers) for rank in range(num_workers)]

    def cleanup(self) -> None:
        """Asks the worker loop to exit and waits up to ``cleanup_time`` for it.

        A process that does not exit in time is killed. A thread cannot be
        forcibly terminated in Python, so a stuck thread is only reported and
        abandoned. Cleanup is best-effort: errors (for example, attributes
        missing because ``__init__`` failed, or a manager that is already shut
        down) are logged at debug level instead of being raised.
        """
        logger.debug("Cleaning up task...")
        try:
            self.send_action("close")
            self._proc.join(timeout=self.cleanup_time)
            if not self._proc.is_alive():
                return
            if isinstance(self._proc, threading.Thread):
                # There is no supported way to stop a running thread; the private
                # ``Thread._stop`` does not terminate it, so don't pretend to.
                logger.warning(
                    "Thread failed to finish after %.2f seconds; abandoning it",
                    self.cleanup_time,
                )
            else:
                logger.warning("Process failed to finish after %.2f seconds; killing", self.cleanup_time)
                self._proc.kill()
        except Exception:
            logger.debug("Error while cleaning up environment worker", exc_info=True)

    @classmethod
    def _thread(
        cls,
        env: "Environment[RLState, RLAction]",
        seed: int,
        action_queue: "Queue[RLAction | SpecialAction]",
        state_queue: "Queue[RLState | SpecialState]",
        rank: int | None,
        world_size: int | None,
    ) -> None:
        """Worker loop: consumes actions, publishes resulting states.

        Runs until a ``"close"`` action is received. After each reset or step,
        the resulting state (or ``"terminated"``) is put on ``state_queue``.
        """
        configure_logging(rank=rank, world_size=world_size)
        while True:
            action = action_queue.get()
            # Only compare strings against the special commands: calling ``==``
            # on array-like actions can be elementwise, and its truth value is ambiguous.
            is_command = isinstance(action, str)
            if is_command and action == "close":
                logger.debug("Got close action; exiting")
                break
            if is_command and action == "reset":
                logger.debug("Got reset action; resetting environment")
                state = env.reset(seed)
            else:
                state = env.step(action)
            state_queue.put("terminated" if env.terminated(state) else state)

    def get_state(self) -> RLState | SpecialState:
        """Blocks until the worker publishes its next state and returns it."""
        return self.state_queue.get()

    def send_action(self, action: RLAction | SpecialAction) -> None:
        """Queues ``action`` for the worker loop.

        For ``"reset"`` and ``"close"``, any unread state and any not-yet-consumed
        action are discarded first, so the command is not stuck behind stale work.
        """
        if isinstance(action, str) and action in ("reset", "close"):
            clear_queue(self.state_queue)
            clear_queue(self.action_queue)
        self.action_queue.put(action)
class WorkerPool(ABC, Generic[RLState, RLAction]):
    """Interface for a collection of workers addressed by integer worker ID.

    ``ABC`` is a base class so that the ``@abstractmethod`` markers below are
    actually enforced when a subclass is instantiated.
    """

    @abstractmethod
    def reset(self) -> None:
        """Sends a reset to the pool's workers."""

    @abstractmethod
    def __len__(self) -> int:
        """Returns the number of workers in the pool."""

    @overload
    def get_state(self, timeout: float) -> tuple[RLState | SpecialState, int] | None:
        ...

    @overload
    def get_state(self, timeout: None = None) -> tuple[RLState | SpecialState, int]:
        ...

    def get_state(self, timeout: float | None = None) -> tuple[RLState | SpecialState, int] | None:
        """Returns the next available ``(state, worker_id)`` pair.

        Args:
            timeout: How long to wait; ``None`` delegates a no-timeout request
                to the implementation.

        Returns:
            The worker state and its worker ID, or ``None`` if the
            implementation timed out.
        """
        return self._get_state_impl(timeout)

    @abstractmethod
    def _get_state_impl(self, timeout: float | None = None) -> tuple[RLState | SpecialState, int] | None:
        """Returns the current state for one of the workers.

        Args:
            timeout: The timeout for getting the worker state.

        Returns:
            The worker state, or None if we timed out.
        """

    @abstractmethod
    def send_action(self, action: RLAction | SpecialAction, worker_id: int) -> None:
        """Sends an action to the given worker.

        Args:
            action: The action to send.
            worker_id: The ID of the worker to send the action to.
        """

    @classmethod
    @abstractmethod
    def from_workers(
        cls,
        workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]],
    ) -> "WorkerPool[RLState, RLAction]":
        """Creates a worker pool from a list of workers.

        Args:
            workers: The list of workers.

        Returns:
            The worker pool.
        """
class SyncWorkerPool(WorkerPool[RLState, RLAction], Generic[RLState, RLAction]):
    def __init__(self, workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]]) -> None:
        """Pool that polls its workers one at a time in round-robin order.

        Args:
            workers: The workers to poll.
        """
        super().__init__()

        self.workers = workers
        # Worker indices in the order they will next be polled.
        self.worker_queue: Deque[int] = collections.deque(range(len(self.workers)))

    def reset(self) -> None:
        """Sends ``"reset"`` to every worker, in order."""
        for member in self.workers:
            member.send_action("reset")

    def __len__(self) -> int:
        return len(self.workers)

    def _get_state_impl(self, timeout: float | None = None) -> tuple[RLState | SpecialState, int] | None:
        # ``timeout`` is ignored: the state is read directly from the next worker.
        next_idx = self.worker_queue.popleft()
        self.worker_queue.append(next_idx)
        worker_state = self.workers[next_idx].get_state()
        return worker_state, next_idx

    def send_action(self, action: RLAction | SpecialAction, worker_id: int) -> None:
        """Forwards ``action`` to the worker at index ``worker_id``."""
        target = self.workers[worker_id]
        target.send_action(action)

    @classmethod
    def from_workers(
        cls,
        workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]],
    ) -> "SyncWorkerPool[RLState, RLAction]":
        return cls(workers)
class AsyncWorkerPool(WorkerPool[RLState, RLAction], Generic[RLState, RLAction]):
    def __init__(self, workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]], daemon: bool = True) -> None:
        """Pool that drives each worker from its own background thread.

        Args:
            workers: The workers to drive.
            daemon: Whether the per-worker threads are daemon threads.
        """
        super().__init__()

        self.workers = workers
        self.manager = mp.Manager()
        # Shared output queue; entries are ``(state, worker_id)``.
        self.state_queue: "Queue[tuple[RLState | SpecialState, int]]" = self.manager.Queue(maxsize=len(workers))
        # One single-slot input queue per worker.
        self.action_queues: list["Queue[RLAction | SpecialAction]"] = [
            self.manager.Queue(maxsize=1) for _ in range(len(workers))
        ]

        # Starts a thread for each worker.
        self._procs = [
            threading.Thread(
                target=self._thread,
                args=(env_id, worker, self.state_queue, action_queue),
                daemon=daemon,
            )
            for env_id, (worker, action_queue) in enumerate(zip(workers, self.action_queues))
        ]
        for proc in self._procs:
            proc.start()

    def cleanup(self) -> None:
        """Tells every pool thread to close its worker, then joins the threads.

        Best-effort: errors (e.g. a manager that is already shut down) are
        logged at debug level instead of being raised.
        """
        logger.debug("Cleaning up worker pool...")
        try:
            clear_queue(self.state_queue)
            for action_queue in self.action_queues:
                clear_queue(action_queue)
                action_queue.put("close")
            # NOTE(review): join has no timeout, so this waits for any in-flight
            # ``worker.get_state()`` call to return.
            for proc in self._procs:
                proc.join()
        except Exception:
            logger.debug("Error while cleaning up worker pool", exc_info=True)

    def reset(self) -> None:
        """Discards pending states/actions and queues ``"reset"`` for every worker."""
        clear_queue(self.state_queue)
        for action_queue in self.action_queues:
            clear_queue(action_queue)
            action_queue.put("reset")

    def __del__(self) -> None:
        self.cleanup()

    def __len__(self) -> int:
        return len(self.workers)

    @classmethod
    def _thread(
        cls,
        env_id: int,
        worker: BaseEnvironmentWorker[RLState, RLAction],
        state_queue: "Queue[tuple[RLState | SpecialState, int]]",
        action_queue: "Queue[RLAction | SpecialAction]",
    ) -> None:
        """Pool thread loop: forwards each action to ``worker`` and publishes the result.

        Exits, after cleaning up ``worker``, when it receives ``"close"``.
        """
        logger.debug("Starting worker pool thread")
        while True:
            action = action_queue.get()
            # Guard with isinstance so ``==`` is never evaluated on array-like actions.
            if isinstance(action, str) and action == "close":
                logger.debug("Got close action; exiting thread")
                worker.cleanup()
                break
            worker.send_action(action)
            state = worker.get_state()
            state_queue.put((state, env_id))

    def _get_state_impl(self, timeout: float | None = None) -> tuple[RLState | SpecialState, int] | None:
        """Pops the next ``(state, worker_id)``; returns ``None`` if ``timeout`` elapses."""
        if timeout is None:
            return self.state_queue.get()
        try:
            return self.state_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def send_action(self, action: RLAction | SpecialAction, worker_id: int) -> None:
        """Queues ``action`` for the pool thread that owns worker ``worker_id``."""
        self.action_queues[worker_id].put(action)

    @classmethod
    def from_workers(
        cls,
        workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]],
    ) -> "AsyncWorkerPool[RLState, RLAction]":
        return cls(workers)
def get_worker_pool(
    workers: Sequence[BaseEnvironmentWorker[RLState, RLAction]],
    force_sync: bool = False,
) -> WorkerPool[RLState, RLAction]:
    """Chooses a pool implementation for ``workers``.

    Args:
        workers: The workers the pool will manage.
        force_sync: If set, always build a synchronous pool.

    Returns:
        A ``SyncWorkerPool`` when there is exactly one synchronous worker or
        ``force_sync`` is set, otherwise an ``AsyncWorkerPool``.
    """
    single_sync_worker = len(workers) == 1 and isinstance(workers[0], SyncEnvironmentWorker)
    if not (single_sync_worker or force_sync):
        return AsyncWorkerPool(workers)
    return SyncWorkerPool(workers)