Source code for ml.tasks.datasets.clippify

"""Defines a dataset for converting frame streams to clips."""

import logging
import random
from collections import deque
from typing import Deque, Generic, Iterator, TypeVar

import torch
from torch import Tensor
from torch.utils.data.dataset import IterableDataset

from ml.tasks.datasets.collate import collate

logger: logging.Logger = logging.getLogger(__name__)

Batch = TypeVar("Batch")


class ClippifyDataset(IterableDataset[tuple[Tensor, Batch]], Generic[Batch]):
    def __init__(
        self,
        image_dataset: IterableDataset[Batch],
        num_images: int,
        *,
        stride: int = 1,
        jump_size: int = 1,
        sample_first: bool = False,
        use_last: bool = False,
    ) -> None:
        """Defines a dataset which efficiently yields clips of images.

        The underlying image dataset just needs to iterate through a sequence
        of images, in order. This wrapper dataset collates the images into
        clips, with some stride between adjacent images. Images are inserted
        into a deque and popped once no future clip can use them.

        The underlying dataset should do any necessary error handling, since
        this dataset will simply propagate an error on failure.

        Args:
            image_dataset: The child dataset which yields the images
            num_images: The number of images in each clip
            stride: The stride between adjacent images in a clip
            jump_size: How many frames to jump forward between clips
            sample_first: If set, don't always start on the first item;
                instead, sample the starting item from within `jump_size`
            use_last: If set, always yield a final clip that ends on the
                last item in the dataset
        """
        super().__init__()

        self.image_dataset = image_dataset
        self.num_images = num_images
        self.stride = stride
        self.jump_size = jump_size
        self.sample_first = sample_first
        self.use_last = use_last

    image_iter: Iterator[Batch]
    inds: list[int]
    image_queue: Deque[tuple[int, Batch]]
    image_ptr: int
    image_queue_ptr: int
    hit_last: bool

    def __iter__(self) -> Iterator[tuple[Tensor, Batch]]:
        self.image_iter = self.image_dataset.__iter__()

        # Frame offsets within a clip, relative to `image_ptr`.
        self.inds = [i * self.stride for i in range(self.num_images)]

        self.image_queue = deque()
        self.image_ptr = random.randint(0, self.jump_size - 1) if self.sample_first else 0
        self.image_queue_ptr = 0
        self.hit_last = False
        return self

    def __next__(self) -> tuple[Tensor, Batch]:
        if self.hit_last:
            raise StopIteration

        inds = self.inds

        # Pushes images up to the last index.
        while self.image_queue_ptr <= inds[-1] + self.image_ptr:
            try:
                image = next(self.image_iter)
                self.image_queue.append((self.image_queue_ptr, image))
                self.image_queue_ptr += 1
            except StopIteration:
                if not self.use_last:
                    raise

                # Points `image_ptr` at the first index in the last sequence.
                # If this index isn't in the queue, raise anyway.
                self.image_ptr = self.image_queue_ptr - inds[-1] - 1
                if self.image_queue[0][0] > inds[0] + self.image_ptr:
                    raise

                # Flag to avoid another iteration.
                self.hit_last = True
                break

        # Pops all of the images which are below the first index.
        while self.image_queue[0][0] < inds[0] + self.image_ptr:
            self.image_queue.popleft()

        # Collates the images to return. After popping, the queue is aligned
        # so that position `i` in the queue holds frame `image_ptr + i`.
        item = collate([self.image_queue[i][1] for i in inds])
        assert item is not None, "Indices are empty!"
        self.image_ptr += self.jump_size
        return torch.IntTensor(inds), item
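

A minimal usage sketch, reusing the imports above. The `FrameDataset` below
and the frame shape are hypothetical, for illustration only, and the printed
clip shape assumes that `collate` stacks same-shaped tensors along a new
leading dimension, like torch's default collation:

class FrameDataset(IterableDataset[Tensor]):
    """Hypothetical dataset that yields 100 dummy frames, in order."""

    def __iter__(self) -> Iterator[Tensor]:
        for i in range(100):
            yield torch.full((3, 8, 8), float(i))


# Clips of 4 frames taken 2 frames apart, each clip starting 4 frames after
# the previous one: frames (0, 2, 4, 6), then (4, 6, 8, 10), and so on.
clip_dataset = ClippifyDataset(FrameDataset(), num_images=4, stride=2, jump_size=4)
for inds, clip in clip_dataset:
    # `inds` holds the relative offsets tensor([0, 2, 4, 6], ...); `clip` is
    # the collated stack of frames, here with shape (4, 3, 8, 8).
    print(inds, clip.shape)

Because frames are popped from the deque as soon as the first clip index
moves past them, memory use stays bounded by roughly `num_images * stride`
frames, regardless of how long the underlying stream is.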