Source code for ml.tasks.datasets.multi_iter

"""Defines a dataset for iterating from multiple sub-datasets.

It's often the case that you want to write a dataset for iterating from a
single sample, then combine all those datasets into one mega-dataset for
iterating from all the samples. This dataset serves that purpose by, at each
iteration, randomly choosing a dataset and getting its next sample, until
all samples in all datasets have been exhausted.
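
Example (a minimal sketch, where ``DatasetA`` and ``DatasetB`` stand in for
any ``IterableDataset`` subclasses yielding compatible samples)::

    from ml.tasks.datasets.multi_iter import DatasetInfo, MultiIterDataset

    dataset = MultiIterDataset(
        [
            DatasetInfo(DatasetA(), sampling_rate=1.0),
            DatasetInfo(DatasetB(), sampling_rate=2.0),  # Sampled twice as often.
        ],
        until_all_empty=True,
    )

    for sample in dataset:
        ...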
"""

import random
from dataclasses import dataclass
from typing import Generic, Iterable, Iterator, TypeVar

import numpy as np
from torch.utils.data.dataset import IterableDataset

T = TypeVar("T")


@dataclass
class DatasetInfo(Generic[T]):
    dataset: IterableDataset[T]
    sampling_rate: float = 1.0
class MultiIterDataset(IterableDataset[T]):
    def __init__(
        self,
        datasets: Iterable[DatasetInfo[T]],
        *,
        until_all_empty: bool = False,
        iterate_forever: bool = False,
    ) -> None:
        """Defines a dataset for iterating from multiple iterable datasets.

        Args:
            datasets: The information about the datasets to iterate from and
                how to iterate them; specifically, the sampling rate of each
                dataset.
            until_all_empty: If set, iterate until all datasets are empty;
                otherwise, iterate only until any dataset is empty.
            iterate_forever: If set, iterate the child datasets forever,
                restarting each one whenever it is exhausted.
        """
        super().__init__()

        self.datasets = list(datasets)
        assert all(i.sampling_rate > 0 for i in self.datasets)
        self.until_all_empty = until_all_empty
        self.iterate_forever = iterate_forever

    iterators: list[Iterator[T]]
    rate_cumsum: np.ndarray

    def __iter__(self) -> Iterator[T]:
        # Cumulative sampling rates, starting at zero, so that each dataset
        # owns one interval of the [0, total_rate) range.
        self.rate_cumsum = np.concatenate([np.array([0]), np.cumsum([i.sampling_rate for i in self.datasets])])
        self.iterators = [i.dataset.__iter__() for i in self.datasets]
        return self

    def __next__(self) -> T:
        while True:
            # Pick a dataset with probability proportional to its sampling rate.
            val = random.random() * self.rate_cumsum[-1]
            idx = np.searchsorted(self.rate_cumsum, val, side="right") - 1
            iterator = self.iterators[idx]
            try:
                return iterator.__next__()
            except StopIteration:
                if not (self.until_all_empty or self.iterate_forever) or len(self.iterators) == 1:
                    raise
                if self.iterate_forever:
                    # Restart the exhausted iterator.
                    self.iterators[idx] = self.datasets[idx].dataset.__iter__()
                else:
                    # Drop the exhausted iterator and remove its interval from
                    # the cumulative rates, shifting later entries down by its rate.
                    self.iterators.pop(idx)
                    rate = self.rate_cumsum[idx + 1] - self.rate_cumsum[idx]
                    lhs, rhs = self.rate_cumsum[: idx + 1], self.rate_cumsum[idx + 2 :] - rate
                    self.rate_cumsum = np.concatenate([lhs, rhs])
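

# Minimal usage sketch: mixes two toy integer datasets with different sampling
# rates (``_RangeDataset`` is a throwaway helper defined only for this demo).
if __name__ == "__main__":

    class _RangeDataset(IterableDataset[int]):
        """Toy dataset that yields the integers in ``[start, stop)``."""

        def __init__(self, start: int, stop: int) -> None:
            super().__init__()
            self.start, self.stop = start, stop

        def __iter__(self) -> Iterator[int]:
            return iter(range(self.start, self.stop))

    demo = MultiIterDataset(
        [
            DatasetInfo(_RangeDataset(0, 5), sampling_rate=1.0),
            DatasetInfo(_RangeDataset(100, 110), sampling_rate=2.0),
        ],
        until_all_empty=True,
    )
    print(list(demo))  # All 15 samples, interleaved in a random order.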