ml.tasks.rl.base

The base task and config for reinforcement learning tasks.

This class expects you to implement the following functions:

class MyReinforcementLearningTask(ml.ReinforcementLearningTask[Config, Model, State, Action, Output, Loss]):
    def get_actions(self, model: Model, states: list[State], optimal: bool) -> list[Action]:
        ...

    def get_environment(self) -> Environment:
        ...

    def run_model(self, model: Model, batch: tuple[State, Action], state: ml.State) -> Output:
        ...

    def compute_loss(self, model: Model, batch: tuple[State, Action], state: ml.State, output: Output) -> Loss:
        ...

Additionally, you can implement postprocess_trajectory() and postprocess_trajectories() to apply postprocessing to collected trajectories, such as computing discounted rewards.
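
For example, postprocess_trajectory() could compute per-step discounted returns. A minimal sketch, assuming a hypothetical MyState dataclass with a reward field (your own RLState type defines the real fields) and a hard-coded discount factor that would normally come from the config:

from dataclasses import dataclass, replace

@dataclass
class MyState:  # hypothetical RLState carrying a per-step reward
    observation: list[float]
    reward: float
    discounted_return: float = 0.0

def postprocess_trajectory(self, samples: list[tuple[MyState, int]]) -> list[tuple[MyState, int]]:
    # Method of your task subclass; walks the trajectory backwards to
    # accumulate the reward-to-go for each step.
    gamma = 0.99  # discount factor; in practice, read this from the config
    running = 0.0
    out: list[tuple[MyState, int]] = []
    for state, action in reversed(samples):
        running = state.reward + gamma * running
        out.append((replace(state, discounted_return=running), action))
    out.reverse()
    return out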

class ml.tasks.rl.base.EnvironmentConfig(num_env_workers: int = 1, env_worker_mode: str = 'process', env_seed: int = 1337, env_cleanup_time: float = 5.0, max_steps: int = 1000)[source]

Bases: object

num_env_workers: int = 1
env_worker_mode: str = 'process'
env_seed: int = 1337
env_cleanup_time: float = 5.0
max_steps: int = 1000
class ml.tasks.rl.base.DatasetConfig(num_samples: int = 1, num_update_steps: int = '???', stride: int = 1, replay_buffer_sample_size: int = 10000)[source]

Bases: object

num_samples: int = 1
num_update_steps: int = '???'
stride: int = 1
replay_buffer_sample_size: int = 10000
class ml.tasks.rl.base.ReinforcementLearningTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>, environment: ml.tasks.rl.base.EnvironmentConfig = <factory>, dataset: ml.tasks.rl.base.DatasetConfig = <factory>)[source]

Bases: BaseTaskConfig

environment: EnvironmentConfig
dataset: DatasetConfig
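
A minimal sketch of building this config directly in Python; the values are illustrative, and in practice the fields are usually populated from a config file (the '???' defaults appear to mark values that must be provided):

from ml.tasks.rl.base import (
    DatasetConfig,
    EnvironmentConfig,
    ReinforcementLearningTaskConfig,
)

config = ReinforcementLearningTaskConfig(
    name="my_rl_task",
    environment=EnvironmentConfig(
        num_env_workers=4,
        env_worker_mode="process",
        env_seed=1337,
        max_steps=500,
    ),
    dataset=DatasetConfig(
        num_samples=2,
        num_update_steps=32,
        replay_buffer_sample_size=10_000,
    ),
)
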
class ml.tasks.rl.base.ReinforcementLearningTask(config: BaseTaskConfigT)[source]

Bases: BaseTask[ReinforcementLearningTaskConfigT, ModelT, tuple[RLState, RLAction], Output, Loss], Generic[ReinforcementLearningTaskConfigT, ModelT, RLState, RLAction, Output, Loss], ABC

Base task for reinforcement learning, which rolls out trajectories from an environment using the model's policy and builds datasets from the collected samples.

abstract get_actions(model: ModelT, states: list[RLState], optimal: bool) list[RLAction][source]

Samples actions from the policy, given the previous states.

Parameters:
  • model – The model to sample from.

  • states – The previous states.

  • optimal – Whether to get the optimal action or to sample from the policy.

Returns:

The next actions to take for each state.
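
A minimal sketch for a discrete action space, assuming the model maps a batch of observations to action logits and that each state exposes an observation tensor (both are assumptions about your own model and RLState types):

import torch

def get_actions(self, model, states, optimal: bool) -> list[int]:
    # Method of your task subclass; here RLAction is assumed to be an int.
    obs = torch.stack([state.observation for state in states])
    with torch.no_grad():
        logits = model(obs)
    if optimal:
        actions = logits.argmax(dim=-1)  # greedy action
    else:
        actions = torch.distributions.Categorical(logits=logits).sample()  # sample from the policy
    return actions.tolist()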

abstract get_environment() Environment[RLState, RLAction][source]

Returns the environment for the task.

Returns:

The environment for the task.
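
A sketch, where MyEnvironment stands in for your own implementation of the framework's Environment interface, and the task is assumed to expose its config as self.config:

def get_environment(self) -> "MyEnvironment":
    # Method of your task subclass; constructs the environment used to
    # roll out trajectories.
    return MyEnvironment(max_steps=self.config.environment.max_steps)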

build_rl_dataset(samples: MultiReplaySamples[tuple[RLState, RLAction]]) Dataset[tuple[RLState, RLAction]][source]
get_environment_cached() Environment[RLState, RLAction][source]
get_environment_workers(force_sync: bool = False) list[ml.tasks.environments.worker.BaseEnvironmentWorker[RLState, RLAction]][source]
get_worker_pool(force_sync: bool = False) WorkerPool[RLState, RLAction][source]
postprocess_trajectory(samples: list[tuple[RLState, RLAction]]) list[tuple[RLState, RLAction]][source]

Performs any global postprocessing on the trajectory.

Parameters:

samples – The trajectory to postprocess.

Returns:

The postprocessed trajectory.

postprocess_trajectories(trajectories: list[list[tuple[RLState, RLAction]]]) list[list[tuple[RLState, RLAction]]][source]

Performs any global postprocessing on all of the trajectories.

Parameters:

trajectories – The trajectories to postprocess.

Returns:

The postprocessed trajectories.
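
For example, this hook could drop trajectories that are too short to be useful. A purely illustrative sketch:

def postprocess_trajectories(self, trajectories):
    # Method of your task subclass; keeps only sufficiently long trajectories.
    min_length = 10  # illustrative threshold
    return [traj for traj in trajectories if len(traj) >= min_length]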

iter_samples(model: ModelT, worker_pool: WorkerPool[RLState, RLAction], *, total_samples: int | None = None, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, optimal: bool = True) Iterable[list[tuple[RLState, RLAction]]][source]

Collects samples from the environment.

Parameters:
  • model – The model to sample from.

  • worker_pool – The pool of workers for the environment.

  • total_samples – The total number of samples to collect; if None, iterates forever.

  • min_trajectory_length – Minimum length for a trajectory to count toward total_samples.

  • max_trajectory_length – Maximum length for a trajectory to count toward total_samples.

  • min_batch_size – Minimum batch size for doing inference on the model.

  • max_batch_size – Maximum batch size for doing inference on the model.

  • max_wait_time – Maximum amount of time to wait to build a batch.

  • optimal – Whether to get the optimal action or to sample from the policy.

Yields:

Lists of samples from the environment.

Raises:

ValueError – If min_batch_size is greater than max_batch_size.
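
A hypothetical sampling loop; task, model, and replay_buffer are assumed to come from your own setup code, and collect_samples() below wraps the same collection into a MultiReplaySamples object:

worker_pool = task.get_worker_pool()

for trajectory in task.iter_samples(
    model,
    worker_pool,
    total_samples=10_000,
    min_trajectory_length=8,
    max_batch_size=32,
    optimal=False,  # sample from the policy rather than acting greedily
):
    trajectory = task.postprocess_trajectory(trajectory)
    replay_buffer.extend(trajectory)  # hypothetical buffer of (state, action) samples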

collect_samples(model: ModelT, worker_pool: WorkerPool[RLState, RLAction], total_samples: int, *, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, optimal: bool = True) MultiReplaySamples[tuple[RLState, RLAction]][source]
epoch_is_over(state: State) bool[source]
sample_clip(*, save_path: str | Path, return_images: Literal[True] = True, return_states: Literal[False] = False, model: ModelT | None = None, writer: Literal['ffmpeg', 'matplotlib', 'av', 'opencv'] = 'ffmpeg', standardize_images: bool = True, optimal: bool = True) None[source]
sample_clip(*, return_images: Literal[True] = True, return_states: Literal[False] = False, model: ModelT | None = None, standardize_images: bool = True, optimal: bool = True) Tensor
sample_clip(*, return_images: Literal[True] = True, return_states: Literal[True], model: ModelT | None = None, standardize_images: bool = True, optimal: bool = True) tuple[torch.Tensor, list[tuple[RLState, RLAction]]]
sample_clip(*, return_images: Literal[False], return_states: Literal[True], model: ModelT | None = None, optimal: bool = True) list[tuple[RLState, RLAction]]

Samples a clip for a given model.

Parameters:
  • save_path – Where to save the sampled clips.

  • return_images – Whether to return the images.

  • return_states – Whether to return the states.

  • model – The model to sample from; if not provided, actions are sampled randomly.

  • writer – The writer to use to save the clip.

  • standardize_images – Whether to standardize the images.

  • optimal – Whether to sample actions optimally.

Returns:

The sampled clip, if save_path is not provided, otherwise None (the clip is saved to save_path).

Raises:

ValueError – If save_path is provided and return_states is True.
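
Hypothetical usage, with the task, model, and paths purely illustrative:

# Render a rollout with the current model and write it to disk.
task.sample_clip(
    save_path="samples/episode.mp4",
    model=model,  # omit to sample actions randomly
    writer="ffmpeg",
    optimal=True,
)

# Or keep the frames in memory as a tensor instead of writing a file.
frames = task.sample_clip(return_images=True, model=model, optimal=True)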