# Source code for ml.utils.containers

"""Helper functions for applying functions to containers."""

from dataclasses import fields, is_dataclass
from typing import Any, Callable, Iterable, Mapping, Sequence

from torch import Tensor


def recursive_apply(item: Any, func: Callable[[Tensor], Tensor]) -> Any:  # noqa: ANN401
    """Applies a function recursively to tensors in an item.

    Strings, bytes, ints and floats are returned unchanged. Tensors are passed
    through ``func``. Dataclass instances are rebuilt from their ``init``
    fields, mappings are rebuilt as ``dict`` and other sequences as ``list``.
    Anything else (including dataclass *classes*) is returned as-is.

    Args:
        item: The item to apply the function to
        func: The function to apply (for the tensor)

    Returns:
        The same item, with the function applied
    """
    # bytes/bytearray are Sequences; without listing them here they would be
    # exploded into lists of ints by the Sequence branch below.
    if isinstance(item, (str, bytes, bytearray, int, float)):
        return item
    if isinstance(item, Tensor):
        return func(item)
    # `is_dataclass` is also True for dataclass classes; only rebuild instances.
    if is_dataclass(item) and not isinstance(item, type):
        # Only `init=True` fields may be passed back to the constructor. Using
        # `__dict__` would also pick up non-init fields and extra attributes
        # (and fails outright for slots dataclasses, which have no `__dict__`).
        return item.__class__(
            **{f.name: recursive_apply(getattr(item, f.name), func) for f in fields(item) if f.init}
        )
    if isinstance(item, Mapping):
        return {k: recursive_apply(v, func) for k, v in item.items()}
    if isinstance(item, Sequence):
        return [recursive_apply(i, func) for i in item]
    return item
def _zip_chunks(values: Iterable[Any], num_chunks: int, dim: int) -> Iterable[Sequence[Any]]:
    """Chunks every value and yields one tuple per chunk index, lining the values up."""
    chunked = [recursive_chunk(v, num_chunks, dim) for v in values]
    if not chunked:
        # zip() of nothing is empty, but an empty container must still produce
        # one (empty) entry per chunk, or it would truncate its parents to zero.
        return (() for _ in range(num_chunks))
    return zip(*chunked)


def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:  # noqa: ANN401
    """Recursively chunk tensors N times.

    Tensors are split with ``Tensor.chunk`` along ``dim``. Strings, bytes,
    numbers and any other non-container leaves are repeated once per chunk.
    Dataclass instances (via their ``init`` fields), mappings and sequences are
    rebuilt as the dataclass, ``dict`` and ``list`` respectively, from the
    matching chunk of each element. Empty containers yield ``num_chunks``
    empty containers.

    Note: ``Tensor.chunk`` may return fewer than ``num_chunks`` pieces; because
    children are combined with ``zip``, a container then yields only as many
    chunks as its shortest child.

    Args:
        item: The item to recursively chunk
        num_chunks: The number of splits to make
        dim: The split dimension

    Yields:
        N chunks of items
    """
    # bytes/bytearray are Sequences; treat them as atomic leaves like str.
    if isinstance(item, (str, bytes, bytearray, int, float)):
        yield from (item for _ in range(num_chunks))
    elif isinstance(item, Tensor):
        yield from item.chunk(num_chunks, dim=dim)
    # `is_dataclass` is also True for dataclass classes; only rebuild instances.
    elif is_dataclass(item) and not isinstance(item, type):
        # Only `init=True` fields can be passed back to the constructor.
        names = [f.name for f in fields(item) if f.init]
        for values in _zip_chunks((getattr(item, n) for n in names), num_chunks, dim):
            yield item.__class__(**dict(zip(names, values)))
    elif isinstance(item, Mapping):
        for values in _zip_chunks(item.values(), num_chunks, dim):
            yield dict(zip(item, values))
    elif isinstance(item, Sequence):
        for values in _zip_chunks(item, num_chunks, dim):
            yield list(values)
    else:
        yield from (item for _ in range(num_chunks))