ml.tasks.rl.base
The base task and config for reinforcement learning tasks.
The ReinforcementLearningTask class expects you to implement the following functions:
class MyReinforcementLearningTask(ml.ReinforcementLearningTask[Config, Model, State, Action, Output, Loss]):
    def get_actions(self, model: Model, states: list[State], optimal: bool) -> list[Action]:
        ...

    def get_environment(self) -> Environment:
        ...

    def run_model(self, model: Model, batch: tuple[State, Action], state: ml.State) -> Output:
        ...

    def compute_loss(self, model: Model, batch: tuple[State, Action], state: ml.State, output: Output) -> Loss:
        ...
Additionally, you can implement postprocess_trajectory() and postprocess_trajectories() to apply some postprocessing to collected batches, such as computing the discounted rewards.
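To make the skeleton concrete, the sketch below fills in run_model() and compute_loss() for a simple policy-gradient setup. It assumes Output and Loss are plain tensors, that the collated State batch exposes observation and returns tensors, and that Action carries an integer index; all of these are illustrative placeholders, not part of this module.

import torch
import torch.nn.functional as F

class MyReinforcementLearningTask(ml.ReinforcementLearningTask[Config, Model, State, Action, Output, Loss]):
    # get_actions() and get_environment() as in the skeleton above ...

    def run_model(self, model: Model, batch: tuple[State, Action], state: ml.State) -> Output:
        # Forward pass over the collated states, producing action logits.
        states, _ = batch
        return model(states.observation)

    def compute_loss(self, model: Model, batch: tuple[State, Action], state: ml.State, output: Output) -> Loss:
        # REINFORCE-style loss: negative log-probability of each taken action,
        # weighted by its precomputed discounted return.
        states, actions = batch
        log_probs = F.log_softmax(output, dim=-1)
        taken = log_probs.gather(-1, actions.index.unsqueeze(-1)).squeeze(-1)
        return -(taken * states.returns).mean()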
- class ml.tasks.rl.base.EnvironmentConfig(num_env_workers: int = 1, env_worker_mode: str = 'process', env_seed: int = 1337, env_cleanup_time: float = 5.0, max_steps: int = 1000)[source]
Bases: object
- num_env_workers: int = 1
- env_worker_mode: str = 'process'
- env_seed: int = 1337
- env_cleanup_time: float = 5.0
- max_steps: int = 1000
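These values are normally provided through the task config (see ReinforcementLearningTaskConfig below), but the dataclass can also be constructed directly; the values here are illustrative:

from ml.tasks.rl.base import EnvironmentConfig

env_cfg = EnvironmentConfig(
    num_env_workers=4,  # run four environment workers in parallel
    max_steps=500,      # cap each episode at 500 steps
)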
- class ml.tasks.rl.base.DatasetConfig(num_samples: int = 1, num_update_steps: int = '???', stride: int = 1, replay_buffer_sample_size: int = 10000)[source]
Bases: object
- num_samples: int = 1
- num_update_steps: int = '???'
- stride: int = 1
- replay_buffer_sample_size: int = 10000
- class ml.tasks.rl.base.ReinforcementLearningTaskConfig(max_epochs: int | None = None, max_steps: int | None = None, max_samples: int | None = None, max_seconds: float | None = None, train_dl: ml.tasks.base.DataLoaderConfig = <factory>, valid_dl: ml.tasks.base.DataLoaderConfig = <factory>, test_dl: ml.tasks.base.DataLoaderConfig = <factory>, name: str = '???', errors: ml.tasks.datasets.error_handling.ErrorHandlingConfig = <factory>, environment: ml.tasks.rl.base.EnvironmentConfig = <factory>, dataset: ml.tasks.rl.base.DatasetConfig = <factory>)[source]
Bases: BaseTaskConfig
- environment: EnvironmentConfig
- dataset: DatasetConfig
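A minimal sketch of filling in the fields that have no usable defaults (the '???' values above appear to mark required fields); everything else falls back to the defaults shown in the config classes above:

from ml.tasks.rl.base import DatasetConfig, ReinforcementLearningTaskConfig

task_cfg = ReinforcementLearningTaskConfig(
    name="my_rl_task",
    dataset=DatasetConfig(num_update_steps=32),  # '???' default above, so it must be set
)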
- class ml.tasks.rl.base.ReinforcementLearningTask(config: BaseTaskConfigT)[source]
Bases: BaseTask[ReinforcementLearningTaskConfigT, ModelT, tuple[RLState, RLAction], Output, Loss], Generic[ReinforcementLearningTaskConfigT, ModelT, RLState, RLAction, Output, Loss], ABC
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- abstract get_actions(model: ModelT, states: list[RLState], optimal: bool) list[RLAction] [source]
Samples actions from the policy, given the previous states.
- Parameters:
model – The model to sample from.
states – The previous states.
optimal – Whether to get the optimal action or to sample from the policy.
- Returns:
The next actions to take for each state.
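A minimal sketch for a discrete action space, assuming (as in the sketches above) a hypothetical State with an observation tensor, an Action holding an integer index, and a model that maps observations to logits:

import torch
from torch.distributions import Categorical

def get_actions(self, model: Model, states: list[State], optimal: bool) -> list[Action]:
    obs = torch.stack([s.observation for s in states])  # batch the per-state observations
    logits = model(obs)
    if optimal:
        indices = logits.argmax(dim=-1)  # greedy action
    else:
        indices = Categorical(logits=logits).sample()
    return [Action(index=int(i)) for i in indices]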
- abstract get_environment() Environment[RLState, RLAction] [source]
Returns the environment for the task.
- Returns:
The environment for the task.
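For example, returning a user-defined environment seeded from the task's environment config (MyEnvironment is a hypothetical Environment subclass defined elsewhere; the get_environment_cached() method below suggests the result is constructed once and cached):

def get_environment(self) -> Environment[State, Action]:
    # MyEnvironment is a hypothetical, user-defined Environment subclass.
    return MyEnvironment(seed=self.config.environment.env_seed)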
- build_rl_dataset(samples: MultiReplaySamples[tuple[RLState, RLAction]]) Dataset[tuple[RLState, RLAction]] [source]
- get_environment_cached() Environment[RLState, RLAction] [source]
- get_environment_workers(force_sync: bool = False) list[ml.tasks.environments.worker.BaseEnvironmentWorker[RLState, RLAction]] [source]
- get_worker_pool(force_sync: bool = False) WorkerPool[RLState, RLAction] [source]
- postprocess_trajectory(samples: list[tuple[RLState, RLAction]]) list[tuple[RLState, RLAction]] [source]
Performs any global postprocessing on the trajectory.
- Parameters:
samples – The trajectory to postprocess.
- Returns:
The postprocessed trajectory.
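As mentioned in the module overview, a common use is computing discounted returns. A sketch, assuming the hypothetical State dataclass carries a reward field and a returns field to fill in:

import dataclasses

def postprocess_trajectory(self, samples: list[tuple[State, Action]]) -> list[tuple[State, Action]]:
    gamma, running = 0.99, 0.0
    out: list[tuple[State, Action]] = []
    for state, action in reversed(samples):
        running = state.reward + gamma * running
        out.append((dataclasses.replace(state, returns=running), action))
    out.reverse()  # restore chronological order
    return out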
- postprocess_trajectories(trajectories: list[list[tuple[RLState, RLAction]]]) list[list[tuple[RLState, RLAction]]] [source]
Performs any global postprocessing on all of the trajectories.
- Parameters:
trajectories – The trajectories to postprocess.
- Returns:
The postprocessed trajectories.
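Building on the previous sketch, trajectories can also be postprocessed jointly, for example to normalize returns across the whole batch of rollouts (again assuming the hypothetical returns field):

import dataclasses
import statistics

def postprocess_trajectories(
    self, trajectories: list[list[tuple[State, Action]]]
) -> list[list[tuple[State, Action]]]:
    returns = [s.returns for traj in trajectories for s, _ in traj]
    mean, std = statistics.mean(returns), statistics.pstdev(returns) or 1.0
    return [
        [(dataclasses.replace(s, returns=(s.returns - mean) / std), a) for s, a in traj]
        for traj in trajectories
    ]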
- iter_samples(model: ModelT, worker_pool: WorkerPool[RLState, RLAction], *, total_samples: int | None = None, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, optimal: bool = True) Iterable[list[tuple[RLState, RLAction]]] [source]
Collects samples from the environment.
- Parameters:
model – The model to sample from.
worker_pool – The pool of workers for the environment.
total_samples – The total number of samples to collect; if None, iterates forever.
min_trajectory_length – Minimum sequence length to consider a sequence as having contributed to total_samples.
max_trajectory_length – Maximum sequence length to consider a sequence as having contributed to total_samples.
min_batch_size – Minimum batch size for doing inference on the model.
max_batch_size – Maximum batch size for doing inference on the model.
max_wait_time – Maximum amount of time to wait to build a batch.
optimal – Whether to get the optimal action or to sample from the policy.
- Yields:
Lists of samples from the environment.
- Raises:
ValueError – If min_batch_size is greater than max_batch_size.
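Typical usage, with task and model standing in for an instantiated task and its model:

worker_pool = task.get_worker_pool()
for trajectory in task.iter_samples(
    model,
    worker_pool,
    total_samples=10_000,
    min_trajectory_length=8,
    max_batch_size=32,
    optimal=False,  # sample from the policy rather than acting greedily
):
    ...  # each `trajectory` is a list of (state, action) tuples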
- collect_samples(model: ModelT, worker_pool: WorkerPool[RLState, RLAction], total_samples: int, *, min_trajectory_length: int = 1, max_trajectory_length: int | None = None, min_batch_size: int = 1, max_batch_size: int | None = None, max_wait_time: float | None = None, optimal: bool = True) MultiReplaySamples[tuple[RLState, RLAction]] [source]
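collect_samples() gathers a fixed number of samples into a MultiReplaySamples container, which build_rl_dataset() can then turn into a dataset. A usage sketch, continuing from the example above:

samples = task.collect_samples(model, worker_pool, total_samples=50_000, optimal=False)
dataset = task.build_rl_dataset(samples)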
- sample_clip(*, save_path: str | Path, return_images: Literal[True] = True, return_states: Literal[False] = False, model: ModelT | None = None, writer: Literal['ffmpeg', 'matplotlib', 'av', 'opencv'] = 'ffmpeg', standardize_images: bool = True, optimal: bool = True) None [source]
- sample_clip(*, return_images: Literal[True] = True, return_states: Literal[False] = False, model: ModelT | None = None, standardize_images: bool = True, optimal: bool = True) Tensor
- sample_clip(*, return_images: Literal[True] = True, return_states: Literal[True], model: ModelT | None = None, standardize_images: bool = True, optimal: bool = True) tuple[torch.Tensor, list[tuple[RLState, RLAction]]]
- sample_clip(*, return_images: Literal[False], return_states: Literal[True], model: ModelT | None = None, optimal: bool = True) list[tuple[RLState, RLAction]]
Samples a clip for a given model.
- Parameters:
save_path – Where to save the sampled clips
return_images – Whether to return the images
return_states – Whether to return the states
model – The model to sample from; if not provided, actions are sampled randomly
writer – The writer to use to save the clip
standardize_images – Whether to standardize the images
optimal – Whether to sample actions optimally
- Returns:
The sampled clip if save_path is not provided; otherwise None (the clip is saved to save_path).
- Raises:
ValueError – If save_path is provided and return_states is True.
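For example, to save a greedy rollout to disk, or to get the rendered frames back as a tensor instead (following the overloads above; paths and values are illustrative):

# Render a rollout and write it to disk with the ffmpeg writer.
task.sample_clip(save_path="samples/rollout.mp4", model=model, writer="ffmpeg", optimal=True)

# Keep the frames in memory as a tensor instead of saving them.
frames = task.sample_clip(return_images=True, model=model, optimal=True)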