Source code for ml.utils.device.base

"""Utilities for working with devices.

This module contains utilities for working with devices, such as moving tensors
and modules to devices, and getting prefetchers for non-blocking host-to-device
transfers.

The typical flow for using this module is:

.. code-block:: python

    from ml.utils.device.auto import detect_device

    device = detect_device()
    device.module_to(some_module)
    device.tensor_to(some_tensor)
    device.get_prefetcher(some_dataloader)
    device.recursive_apply(some_container)
"""

import contextlib
from abc import ABC, abstractmethod
from dataclasses import is_dataclass
from typing import Any, Callable, ContextManager, Generic, Iterable, Iterator, Mapping, Sequence, TypeVar

import numpy as np
import torch
from torch import Tensor, nn
from torch.utils.data.dataloader import DataLoader, _BaseDataLoaderIter

from ml.utils.containers import recursive_apply
from ml.utils.timer import Timer

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


def allow_nonblocking(device_a: torch.device, device_b: torch.device) -> bool:
    """Returns whether a transfer between the two devices can be non-blocking."""
    return device_a.type in ("cpu", "cuda") and device_b.type in ("cpu", "cuda")
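
# A quick illustration of the gate above; non-blocking copies are only allowed
# between CPU and CUDA devices, so other backends fall back to blocking copies:
#
#     allow_nonblocking(torch.device("cpu"), torch.device("cuda", 0))  # True
#     allow_nonblocking(torch.device("mps"), torch.device("cpu"))  # False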


class Prefetcher(Iterable[T_co], Generic[T_co]):
    """Helper class for pre-loading samples into device memory."""

    def __init__(
        self,
        to_device_func: Callable[[Any], Any],
        dataloader: DataLoader[T_co],
        raise_stop_iter: bool = False,
    ) -> None:
        super().__init__()

        self.to_device_func = to_device_func
        self.dataloader = dataloader
        self.raise_stop_iter = raise_stop_iter
        self.next_sample = None
        self.get_batch_time = 0.0
        self.to_device_time = 0.0
        self._dataloader_iter: _BaseDataLoaderIter | None = None

    @property
    def dataloader_iter(self) -> _BaseDataLoaderIter:
        if self._dataloader_iter is None:
            with Timer("starting dataloader"):
                self._dataloader_iter = iter(self.dataloader)
        return self._dataloader_iter

    def prefetch(self) -> None:
        try:
            with Timer("getting sample from dataloader") as timer:
                next_sample = next(self.dataloader_iter)
            self.get_batch_time = timer.elapsed_time
            with Timer("moving sample to device") as timer:
                self.next_sample = self.to_device_func(next_sample)
            self.to_device_time = timer.elapsed_time
        except StopIteration:
            self.next_sample = None

    def recursive_chunk(self, item: Any, chunks: int) -> list[Any]:  # noqa: ANN401
        """Recursively splits the tensors in an item into chunks.

        Args:
            item: The item to split
            chunks: The number of output chunks

        Returns:
            A list of ``chunks`` items, with each tensor split along its first
            dimension and non-tensor leaves replicated
        """
        if isinstance(item, (str, int, float)):
            return [item] * chunks
        if isinstance(item, np.ndarray):
            item = torch.from_numpy(item)
        if isinstance(item, Tensor):
            item_chunk_list = list(item.chunk(chunks, dim=0))
            assert len(item_chunk_list) == chunks, f"{len(item_chunk_list)=} != {chunks=}"
            return item_chunk_list
        if is_dataclass(item):
            item_chunk_dict = {k: self.recursive_chunk(v, chunks) for k, v in item.__dict__.items()}
            return [item.__class__(**{k: v[i] for k, v in item_chunk_dict.items()}) for i in range(chunks)]
        if isinstance(item, Mapping):
            item_chunk_dict = {k: self.recursive_chunk(v, chunks) for k, v in item.items()}
            return [{k: v[i] for k, v in item_chunk_dict.items()} for i in range(chunks)]
        if isinstance(item, Sequence):
            item_chunk_lists = [self.recursive_chunk(i, chunks) for i in item]
            return [[k[i] for k in item_chunk_lists] for i in range(chunks)]
        # Unrecognized leaf types are replicated across all chunks, matching the
        # handling of other non-tensor leaves above.
        return [item] * chunks
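
    # A minimal sketch of what ``recursive_chunk`` does to a batch; the
    # dictionary below is illustrative, not part of this module:
    #
    #     batch = {"x": torch.zeros(4, 3), "label": "cat"}
    #     a, b = prefetcher.recursive_chunk(batch, 2)
    #     # a["x"].shape == (2, 3) and b["x"].shape == (2, 3)
    #     # a["label"] == b["label"] == "cat" (non-tensor leaves are replicated)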

    @classmethod
    def recursive_apply(cls, item: Any, func: Callable[[Tensor], Tensor]) -> Any:  # noqa: ANN401
        """Applies ``func`` to each tensor in ``item``, recursing through containers."""
        return recursive_apply(item, func)

    def __iter__(self) -> Iterator[T_co]:
        # Yields one sample quickly.
        next_sample = next(self.dataloader_iter)
        yield self.to_device_func(next_sample)

        try:
            self.prefetch()
            while True:
                if self.next_sample is None:
                    raise StopIteration
                sample = self.next_sample
                self.prefetch()
                yield sample
        except StopIteration:
            # Resets the dataloader if the iteration has completed.
            self._dataloader_iter = iter(self.dataloader)
            if self.raise_stop_iter:
                raise
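
# A minimal usage sketch for ``Prefetcher``; the dataset and identity
# "to device" function below are illustrative assumptions, not part of
# this module:
#
#     from torch.utils.data import DataLoader, TensorDataset
#
#     loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=2)
#     prefetcher = Prefetcher(lambda sample: sample, loader)
#     for batch in prefetcher:
#         ...  # while this batch is used, the next one has already been loaded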


class InfinitePrefetcher(Iterable[T_co]):
    """Wraps a prefetcher so that it iterates forever, restarting on exhaustion."""

    def __init__(self, prefetcher: Prefetcher[T_co]) -> None:
        self.prefetcher = prefetcher

    def __iter__(self) -> Iterator[T_co]:
        while True:
            for batch in self.prefetcher:
                yield batch
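
# ``InfinitePrefetcher`` is convenient for step-based (rather than epoch-based)
# training loops; ``max_steps`` below is an assumed loop bound:
#
#     for step, batch in enumerate(InfinitePrefetcher(prefetcher)):
#         if step >= max_steps:
#             break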


class base_device(ABC):  # noqa: N801
    """Base mixin for different trainer device types."""

    def __init__(self) -> None:
        super().__init__()

        self._device = self._get_device()
        self._dtype_fp = self._get_floating_point_type()

    def __str__(self) -> str:
        return f"device({self._device.type}, {self._device.index}, {self._dtype_fp})"

    def __repr__(self) -> str:
        return str(self)

    @classmethod
    @abstractmethod
    def has_device(cls) -> bool:
        """Detects whether or not the device is available.

        Returns:
            If the device is available
        """

    @abstractmethod
    def _get_device(self) -> torch.device:
        """Returns the device, for instantiating new tensors.

        Returns:
            The device
        """

    @abstractmethod
    def _get_floating_point_type(self) -> torch.dtype:
        """Returns the default floating point type to use.

        Returns:
            The dtype
        """

    @abstractmethod
    def get_torch_compile_backend(self) -> str | Callable:
        """Returns the backend to use for ``torch.compile``.

        Returns:
            The backend
        """
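
    # A minimal sketch of a concrete subclass, assuming a plain CPU backend;
    # the real implementations live in sibling modules (for example, a
    # hypothetical ``ml.utils.device.cpu``) and may differ:
    #
    #     class cpu_device(base_device):  # noqa: N801
    #         @classmethod
    #         def has_device(cls) -> bool:
    #             return True  # the CPU is always available
    #
    #         def _get_device(self) -> torch.device:
    #             return torch.device("cpu")
    #
    #         def _get_floating_point_type(self) -> torch.dtype:
    #             return torch.bfloat16
    #
    #         def get_torch_compile_backend(self) -> str | Callable:
    #             return "inductor"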

    def sample_to_device(self, sample: T) -> T:
        """Moves a sample to the device, casting floating-point tensors to the default dtype."""
        return Prefetcher.recursive_apply(
            sample,
            lambda t: t.to(
                self._device,
                self._dtype_fp if t.is_floating_point() else t.dtype,
                non_blocking=allow_nonblocking(t.device, self._device),
            ),
        )

    def get_prefetcher(self, dataloader: DataLoader) -> Prefetcher:
        return Prefetcher(self.sample_to_device, dataloader)

    def module_to(self, module: nn.Module, with_dtype: bool = False) -> None:
        if with_dtype:
            module.to(self._device, self._dtype_fp)
        else:
            module.to(self._device)

    def tensor_to(self, tensor: np.ndarray | Tensor) -> Tensor:
        if isinstance(tensor, np.ndarray):
            tensor = torch.from_numpy(tensor)
        if tensor.is_floating_point():
            return tensor.to(self._device, self._dtype_fp)
        return tensor.to(self._device)

    def recursive_apply(self, item: T) -> T:
        def func(i: Tensor) -> Tensor:
            if isinstance(i, Tensor):
                return self.tensor_to(i)
            return i

        return recursive_apply(item, func)

    def autocast_context(self, enabled: bool = True) -> ContextManager:
        device_type = self._device.type
        if device_type not in ("cpu", "cuda"):
            return contextlib.nullcontext()
        if device_type == "cpu" and self._dtype_fp != torch.bfloat16:
            return contextlib.nullcontext()
        return torch.autocast(device_type=device_type, dtype=self._dtype_fp, enabled=enabled)
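
    # A sketch of how the autocast context is typically used in a training
    # step; ``model``, ``loss_fn``, and ``batch`` are assumed to exist:
    #
    #     with device.autocast_context():
    #         loss = loss_fn(model(batch))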

    def supports_grad_scaler(self) -> bool:
        """Returns whether the device supports gradient scaling (overridden by subclasses)."""
        return False
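
# Devices whose ``supports_grad_scaler`` returns True (typically CUDA) can be
# paired with a gradient scaler; a hedged sketch using the standard PyTorch API:
#
#     scaler = torch.cuda.amp.GradScaler(enabled=device.supports_grad_scaler())
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()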