ml.tasks.rl.replay
Defines the replay buffer classes for storing sessions from an environment.
- class ml.tasks.rl.replay.MultiReplaySamples(samples: list[ml.tasks.rl.replay.ReplaySamples[T]])
Bases: Generic[T]
- partition(rank: int, world_size: int) → MultiReplaySamples
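The reference gives only the signature of `partition(rank, world_size)`. In distributed training, a method like this usually splits the stored samples so that each worker sees a disjoint shard. The following is a minimal sketch that assumes a round-robin split. `partition_samples` is a hypothetical helper and is not part of `ml.tasks.rl.replay`. The real method may use a different scheme, such as contiguous blocks.

```python
from typing import TypeVar

T = TypeVar("T")


def partition_samples(samples: list[T], rank: int, world_size: int) -> list[T]:
    """Hypothetical round-robin split: worker `rank` keeps every `world_size`-th sample.

    This is an illustrative assumption, not the library's implementation.
    """
    if world_size < 1:
        raise ValueError(f"world_size must be positive, got {world_size}")
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    # Start at this worker's rank and step by the number of workers, so the
    # shards are disjoint and together cover every sample exactly once.
    return samples[rank::world_size]


sessions = ["s0", "s1", "s2", "s3", "s4"]
print(partition_samples(sessions, rank=0, world_size=2))  # ['s0', 's2', 's4']
print(partition_samples(sessions, rank=1, world_size=2))  # ['s1', 's3']
```

With a round-robin split, shard sizes differ by at most one, even when the number of samples is not divisible by `world_size`.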
- class ml.tasks.rl.replay.ReplayDataset(buffer: MultiReplaySamples[T], clip_size: int, stride: int = 1, collate_fn: Callable[[list[T]], T] = collate)
Bases: IterableDataset[T], Generic[T]
- property buffer_partitioned: MultiReplaySamples[T]
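`ReplayDataset` is documented here only by its constructor. The reference does not spell out the roles of `clip_size`, `stride`, and `collate_fn`. One reasonable reading, which is an assumption, is as follows:

- Each yielded item is a window of `clip_size` consecutive steps from a stored session.
- A new window starts every `stride` steps.
- `collate_fn` merges the steps of a window into a single item.

The sketch below illustrates that reading with a hypothetical generator, `iter_clips`. It omits the per-worker partitioning suggested by `buffer_partitioned`, and it uses `list` as a stand-in for the library's `collate`.

```python
from typing import Callable, Iterator, Sequence, TypeVar

T = TypeVar("T")
C = TypeVar("C")


def iter_clips(
    sessions: Sequence[Sequence[T]],
    clip_size: int,
    stride: int = 1,
    collate_fn: Callable[[list[T]], C] = list,
) -> Iterator[C]:
    """Hypothetical clip iterator (an assumption, not the library's code).

    Yields every window of `clip_size` consecutive steps from each session,
    starting a new window every `stride` steps, passed through `collate_fn`.
    """
    if clip_size < 1 or stride < 1:
        raise ValueError("clip_size and stride must both be positive")
    for session in sessions:
        # Sessions shorter than clip_size produce no clips at all.
        for start in range(0, len(session) - clip_size + 1, stride):
            yield collate_fn(list(session[start : start + clip_size]))


sessions = [[0, 1, 2, 3, 4], [10, 11]]
print(list(iter_clips(sessions, clip_size=3, stride=2)))  # [[0, 1, 2], [2, 3, 4]]
print(list(iter_clips(sessions, clip_size=2, collate_fn=sum)))  # [1, 3, 5, 7, 21]
```

A generator fits the iterable-dataset pattern that `ReplayDataset`'s `IterableDataset[T]` base implies. Clips are produced lazily, so the windows never need to be materialized up front.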