ml.tasks.rl.replay

Defines the replay buffer classes for storing sessions from an environment.

class ml.tasks.rl.replay.ReplaySamples(samples: list[T])[source]

Bases: Generic[T]

sample(clip_size: int, stride: int = 1, only_last: bool = False) → list[T][source]
class ml.tasks.rl.replay.MultiReplaySamples(samples: list[ml.tasks.rl.replay.ReplaySamples[T]])[source]

Bases: Generic[T]

partition(rank: int, world_size: int) → MultiReplaySamples[source]
sample(clip_size: int, stride: int = 1, only_last: bool = False) → list[list[T]][source]
class ml.tasks.rl.replay.ReplayDataset(buffer: MultiReplaySamples[T], clip_size: int, stride: int = 1, collate_fn: Callable[[list[T]], T] = <function collate>)[source]

Bases: IterableDataset[T], Generic[T]

property buffer_partitioned: MultiReplaySamples[T]