Source code for ml.tasks.datasets.video_file

"""Defines a dataset which iterates through frames in a video file."""

from pathlib import Path
from typing import Callable, Iterator

import numpy as np
from torch import Tensor
from torch.utils.data.dataset import IterableDataset

from ml.utils.numpy import as_cpu_tensor
from ml.utils.video import Reader, VideoProps, read_video


class VideoFileDataset(IterableDataset[Tensor]):
    def __init__(
        self,
        file_path: str | Path,
        reader: Reader = "ffmpeg",
        transform: None | Callable[[Tensor], Tensor] = None,
    ) -> None:
        """Defines a dataset which iterates through frames in a video file.

        Args:
            file_path: The path to the video file to iterate through
            reader: The video reader to use
            transform: An optional transform to apply to each frame
        """
        super().__init__()

        self.file_path = str(file_path)
        self.reader = reader
        self.transform = transform

    video_props: VideoProps
    video_stream: Iterator[np.ndarray | Tensor]

    def __iter__(self) -> Iterator[Tensor]:
        self.video_props = VideoProps.from_file_ffmpeg(self.file_path)
        self.video_stream = read_video(self.file_path, reader=self.reader)
        return self

    def __next__(self) -> Tensor:
        buffer = next(self.video_stream)
        image = as_cpu_tensor(buffer).permute(2, 0, 1)  # HWC -> CHW
        if self.transform is not None:
            image = self.transform(image)
        return image
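
A minimal usage sketch, not part of the module itself: it assumes a hypothetical video file "clip.mp4" and that the reader yields uint8 frames, and shows wrapping the dataset in a standard DataLoader with a simple normalization transform.

    from torch import Tensor
    from torch.utils.data import DataLoader


    def to_float(frame: Tensor) -> Tensor:
        # Assumes the reader yields uint8 frames; scale to [0, 1].
        return frame.float() / 255.0


    # "clip.mp4" is a placeholder path for illustration only.
    dataset = VideoFileDataset("clip.mp4", reader="ffmpeg", transform=to_float)
    loader = DataLoader(dataset, batch_size=8, num_workers=0)

    for batch in loader:
        # Frames are returned in CHW order, so batches are (B, C, H, W).
        print(batch.shape)
        break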