ml.utils.device.base

Utilities for working with devices.

This module contains utilities for working with devices, such as moving tensors and modules to devices, and getting prefetchers for non-blocking host-to-device transfers.

The typical flow for using this module is:

    from ml.utils.device.auto import detect_device

    device = detect_device()
    device.module_to(some_module)
    device.tensor_to(some_tensor)
    device.get_prefetcher(some_dataloader)
    device.recursive_apply(some_container, some_func)
ml.utils.device.base.allow_nonblocking(device_a: device, device_b: device) → bool[source]
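The source does not document the rule this function applies. As a rough illustration only: a common convention is that non-blocking transfers are worthwhile when a CUDA device is involved (host-to-device copies from pinned memory can overlap with compute), and pointless for CPU-to-CPU "transfers". The helper below is a hypothetical sketch of that convention, not the actual implementation.

```python
def allow_nonblocking_sketch(type_a: str, type_b: str) -> bool:
    """Hypothetical rule: a non-blocking copy is useful only when
    one endpoint is a CUDA device. Illustration only, not the
    actual allow_nonblocking logic."""
    return "cuda" in (type_a, type_b)

print(allow_nonblocking_sketch("cpu", "cuda"))  # True
print(allow_nonblocking_sketch("cpu", "cpu"))   # False
```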
class ml.utils.device.base.Prefetcher(to_device_func: Callable[[Any], Any], dataloader: DataLoader[T_co], raise_stop_iter: bool = False)[source]

Bases: Iterable[T_co], Generic[T_co]

Helper class for pre-loading samples into device memory.

property dataloader_iter: _BaseDataLoaderIter
prefetch() → None[source]
recursive_chunk(item: Any, chunks: int) → list[Any][source]

Recursively splits the tensors in an item into the requested number of chunks, rebuilding the surrounding containers for each chunk.

Parameters:
  • item – The item to apply the function to

  • chunks – The number of output chunks

Returns:

The item, split into the requested number of chunks
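The chunking behavior can be sketched without torch, using lists as stand-ins for tensors (the real implementation would call something like tensor.chunk on each Tensor leaf; this is an illustration of the recursion, not the library's code):

```python
from typing import Any

def recursive_chunk_sketch(item: Any, chunks: int) -> list[Any]:
    """Sketch of recursive chunking: lists stand in for tensors.

    Dicts are recursed into and regrouped so the result is one dict
    per chunk; leaf lists are split into roughly equal pieces.
    """
    if isinstance(item, dict):
        # Chunk each value, then regroup: one dict per output chunk.
        chunked = {k: recursive_chunk_sketch(v, chunks) for k, v in item.items()}
        return [{k: v[i] for k, v in chunked.items()} for i in range(chunks)]
    # Leaf "tensor": split into `chunks` roughly equal pieces.
    n = (len(item) + chunks - 1) // chunks
    return [item[i * n : (i + 1) * n] for i in range(chunks)]

print(recursive_chunk_sketch({"x": [1, 2, 3, 4]}, 2))
# [{'x': [1, 2]}, {'x': [3, 4]}]
```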

classmethod recursive_apply(item: Any, func: Callable[[Tensor], Tensor]) → Any[source]

Applies a function recursively to the tensors in an item, leaving non-tensor values untouched.
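The recursion pattern is the same one used throughout this module: walk dicts, lists, and tuples, and apply the function at the leaves. A torch-free sketch, with ints standing in for tensors (illustration only, not the library's implementation):

```python
from typing import Any, Callable

def recursive_apply_sketch(item: Any, func: Callable[[int], int]) -> Any:
    """Sketch of recursive application over nested containers.

    Ints stand in for tensors; the real classmethod applies ``func``
    to every Tensor leaf and passes other values through unchanged.
    """
    if isinstance(item, dict):
        return {k: recursive_apply_sketch(v, func) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return type(item)(recursive_apply_sketch(v, func) for v in item)
    if isinstance(item, int):  # "tensor" leaf in this sketch
        return func(item)
    return item  # non-tensor values pass through unchanged

print(recursive_apply_sketch({"a": [1, 2], "b": ("x", 3)}, lambda t: t + 10))
# {'a': [11, 12], 'b': ('x', 13)}
```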
class ml.utils.device.base.InfinitePrefetcher(prefetcher: Prefetcher[T_co])[source]

Bases: Iterable[T_co]
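The "infinite" wrapper idea is simple to sketch: whenever the wrapped prefetcher is exhausted, restart it, so a training loop can draw samples forever without handling StopIteration. A minimal torch-free illustration (names are hypothetical):

```python
import itertools
from typing import Any, Iterable, Iterator

class InfiniteIter:
    """Sketch of an infinite wrapper: restarts the underlying
    iterable whenever it is exhausted."""

    def __init__(self, inner: Iterable[Any]) -> None:
        self.inner = inner

    def __iter__(self) -> Iterator[Any]:
        while True:
            yield from self.inner

# Take six items from a three-item "epoch": the source restarts once.
print(list(itertools.islice(InfiniteIter([1, 2, 3]), 6)))  # [1, 2, 3, 1, 2, 3]
```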

class ml.utils.device.base.base_device[source]

Bases: ABC

Base mixin for different trainer device types.

abstract classmethod has_device() → bool[source]

Detects whether the device is available.

Returns:

True if the device is available, False otherwise

abstract get_torch_compile_backend() → str | Callable[source]

Returns the backend to use for Torch compile.

Returns:

The backend

sample_to_device(sample: T) → T[source]
get_prefetcher(dataloader: DataLoader) → Prefetcher[source]
module_to(module: Module, with_dtype: bool = False) → None[source]
tensor_to(tensor: ndarray | Tensor) → Tensor[source]
recursive_apply(item: T) → T[source]
autocast_context(enabled: bool = True) → ContextManager[source]
supports_grad_scaler() → bool[source]
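The contract above can be sketched as an ABC: concrete devices implement availability detection and a torch.compile backend choice, and inherit the tensor/module-movement helpers. The subclass below is a hypothetical illustration (the backend name "inductor" is an assumed example, not taken from this module):

```python
from abc import ABC, abstractmethod

class base_device_sketch(ABC):
    """Sketch of the base_device contract. Method names mirror the
    documented interface; bodies are placeholders, not the actual
    implementation."""

    @classmethod
    @abstractmethod
    def has_device(cls) -> bool:
        """Detects whether the device is available."""

    @abstractmethod
    def get_torch_compile_backend(self) -> str:
        """Returns the backend to use for torch.compile."""

class cpu_device_sketch(base_device_sketch):
    @classmethod
    def has_device(cls) -> bool:
        return True  # the CPU is always available

    def get_torch_compile_backend(self) -> str:
        return "inductor"  # assumed example backend name

print(cpu_device_sketch.has_device())  # True
```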