# Source code for ml.tasks.datasets.async_iterable

"""Defines a dataset for asynchronous iteration.

This class is useful when you want to use Python's ``async`` / ``await``
syntax to produce the items of a dataset. It starts a separate thread that
runs the async iterator on its own event loop and puts the results into a
queue, from which the regular (synchronous) iterator reads them.

Example::

    class MyDataset(AsyncIterableDataset):
        async def __aiter__(self) -> AsyncIterator[T]:
            for i in range(10):
                yield i

    for i in MyDataset():
        print(i)
"""

import asyncio
import logging
import queue
import threading
from typing import AsyncIterator, Iterator, TypeVar

from torch.utils.data.dataset import IterableDataset

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Type of the items produced by the async iterators and datasets in this module.
T = TypeVar("T")


# Seconds a producer waits on a full queue before re-checking whether the
# consumer has stopped reading.
_PUT_POLL_INTERVAL = 0.1


def _put(
    q: "queue.Queue[T | None]",
    item: "T | None",
    stop_event: "threading.Event | None",
) -> bool:
    """Put ``item`` on ``q``, abandoning the put if ``stop_event`` gets set.

    Args:
        q: Destination queue.
        item: The item (or the ``None`` end-of-stream sentinel) to enqueue.
        stop_event: If ``None``, block until space is available. Otherwise
            wait in ``_PUT_POLL_INTERVAL`` slices and give up once the event
            is set, so a producer never blocks forever on an abandoned queue.

    Returns:
        True if the item was enqueued, False if the put was abandoned.
    """
    if stop_event is None:
        q.put(item)
        return True
    while not stop_event.is_set():
        try:
            q.put(item, timeout=_PUT_POLL_INTERVAL)
        except queue.Full:
            continue
        return True
    return False


async def add_to_queue(
    async_iter: "AsyncIterator[T]",
    q: "queue.Queue[T | None]",
    stop_event: "threading.Event | None" = None,
) -> None:
    """Drain ``async_iter`` into ``q``, then enqueue a ``None`` sentinel.

    ``None`` marks the end of the stream, so the iterator must never yield it.
    Puts are blocking. When the queue is full, this stalls the event loop
    running this coroutine, which provides backpressure on ``async_iter``.

    Args:
        async_iter: The async iterator to drain.
        q: Queue that receives every item followed by a final ``None``.
        stop_event: Optional event set by the consumer when it stops reading.
            Once it is set, no further items are pulled and pending puts are
            abandoned instead of blocking forever.

    Raises:
        ValueError: If ``async_iter`` yields ``None``.
    """
    try:
        async for item in async_iter:
            # Explicit check rather than ``assert`` so it survives ``python -O``;
            # a stray ``None`` would otherwise silently end the stream early.
            if item is None:
                raise ValueError("Item should not be None")
            if not _put(q, item, stop_event):
                break
    finally:
        # Always signal end-of-stream so a waiting consumer wakes up, including
        # when ``async_iter`` raised.
        _put(q, None, stop_event)


def thread_worker(
    async_iter: "AsyncIterator[T]",
    q: "queue.Queue[T | None]",
    errors: "list[BaseException] | None" = None,
    stop_event: "threading.Event | None" = None,
) -> None:
    """Run ``add_to_queue`` to completion on a fresh event loop.

    Args:
        async_iter: The async iterator to drain.
        q: Queue that receives the items and the final ``None`` sentinel.
        errors: If given, an exception raised while draining is appended here
            instead of propagating, so the consumer thread can re-raise it.
            If ``None``, the exception propagates as before.
        stop_event: Forwarded to ``add_to_queue``.
    """
    try:
        asyncio.run(add_to_queue(async_iter, q, stop_event))
    except BaseException as exc:  # handed over to the consumer thread
        if errors is None:
            raise
        errors.append(exc)


def thread_async_iter(async_iter: "AsyncIterator[T]", max_queue_size: int) -> "Iterator[T]":
    """Iterate synchronously over ``async_iter`` by running it in a daemon thread.

    Args:
        async_iter: The async iterator producing the items.
        max_queue_size: Maximum number of items buffered between the producer
            thread and the caller (``queue.Queue`` semantics: ``<= 0`` means
            unbounded).

    Yields:
        The items of ``async_iter``, in order.

    Raises:
        Whatever exception ``async_iter`` raised (or ``ValueError`` if it
        yielded ``None``). It is re-raised here after all items produced
        before the failure have been yielded.
    """
    q: "queue.Queue[T | None]" = queue.Queue(maxsize=max_queue_size)
    errors: "list[BaseException]" = []
    stop_event = threading.Event()
    thread = threading.Thread(
        target=thread_worker,
        args=(async_iter, q, errors, stop_event),
        daemon=True,
    )
    thread.start()
    try:
        while True:
            item = q.get(block=True)
            if item is None:
                break
            yield item
    finally:
        # On early exit (``break``, ``close()``, garbage collection) this tells
        # the producer to stop instead of blocking forever on a full queue.
        # After a normal end-of-stream it is harmless.
        stop_event.set()
    # Only reached after the sentinel arrived, so the producer has finished
    # iterating and the join just waits for event-loop teardown. On early exit
    # we deliberately skip the join: the producer may still be awaiting inside
    # ``async_iter``.
    thread.join()
    if errors:
        raise errors[0]
class AsyncIterableDataset(IterableDataset[T]):
    """An ``IterableDataset`` whose items are produced by an async iterator.

    Subclasses implement ``__aiter__`` (for example as an ``async def``
    generator). Plain ``for`` iteration hands that async iterator to
    ``thread_async_iter`` together with ``max_async_queue_size``.

    Args:
        max_async_queue_size: Bound of the queue between the async producer
            and the synchronous consumer. The producer blocks while the queue
            is full, which keeps the queue from growing too large.
    """

    def __init__(self, max_async_queue_size: int = 2) -> None:
        super().__init__()
        # Cap on buffered items; see the class docstring.
        self.max_async_queue_size = max_async_queue_size

    def __aiter__(self) -> AsyncIterator[T]:
        """Return the async iterator producing this dataset's items.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError

    def __iter__(self) -> Iterator[T]:
        """Iterate synchronously by draining ``__aiter__`` via ``thread_async_iter``."""
        producer = self.__aiter__()
        bound = self.max_async_queue_size
        return thread_async_iter(producer, bound)