"""Defines a wrapper for caching function calls to a file location."""
import functools
import json
import logging
import pickle
from typing import Any, Callable, Generic, Literal, Mapping, Sequence, TypeVar, get_args
import numpy as np
from ml.core.env import get_cache_dir
logger: logging.Logger = logging.getLogger(__name__)
Object = TypeVar("Object", bound=Any)
Tk = TypeVar("Tk")
Tv = TypeVar("Tv")
CacheType = Literal["pkl", "json"]
[docs]class cached_object: # noqa: N801
def __init__(self, cache_key: str, ext: CacheType = "pkl", ignore: bool = False, cache_obj: bool = True) -> None:
"""Defines a wrapper for caching function calls to a file location.
This is just a convenient way of caching heavy operations to disk,
using a specific key.
Args:
cache_key: The key to use for caching the file
ext: The caching type to use (JSON or pickling)
ignore: Should the cache be ignored?
cache_obj: If set, keep the object around to avoid deserializing it
when it is accessed again
"""
self.cache_key = cache_key
self.ext = ext
self.obj = None
self.ignore = ignore
self.cache_obj = cache_obj
assert ext in get_args(CacheType), f"Unexpected extension: {ext}"
def __call__(self, func: Callable[..., Object]) -> Callable[..., Object]:
"""Returns a wrapped function that caches the return value.
Args:
func: The function to cache, which returns the object to load
Returns:
A cached version of the same function
"""
@functools.wraps(func)
def call_function_cached(*args: Any, **kwargs: Any) -> Object: # noqa: ANN401
if self.obj is not None:
return self.obj
keys: list[str] = []
for arg in args:
keys += [str(arg)]
for key, val in sorted(kwargs.items()):
keys += [f"{key}_{val}"]
key = ".".join(keys)
fpath = get_cache_dir() / self.cache_key / f"{key}.{self.ext}"
if fpath.is_file() and not self.ignore:
logger.debug("Loading cached object from %s", fpath)
if self.ext == "json":
with open(fpath, "r", encoding="utf-8") as f:
return json.load(f)
if self.ext == "pkl":
with open(fpath, "rb") as fb:
return pickle.load(fb)
raise NotImplementedError(f"Can't load extension {self.ext}")
obj = func(*args, **kwargs)
if self.cache_obj:
self.obj = obj
logger.debug("Saving cached object to %s", fpath)
fpath.parent.mkdir(exist_ok=True, parents=True)
if self.ext == "json":
with open(fpath, "w", encoding="utf-8") as f:
json.dump(obj, f)
return obj
if self.ext == "pkl":
with open(fpath, "wb") as fb:
pickle.dump(obj, fb)
return obj
raise NotImplementedError(f"Can't save extension {self.ext}")
return call_function_cached
[docs]class DictIndex(Generic[Tk, Tv]):
def __init__(self, items: Mapping[Tk, Sequence[Tv]]) -> None:
"""Indexes a dictionary with values that are lists.
This lazily indexes all the values in the provided dictionary, flattens
them out and allows them to be looked up by a specific index. This is
analogous to PyTorch's ConcatDataset.
Args:
items: The dictionary to index
"""
self.items = items
@functools.cached_property
def _item_list(self) -> list[tuple[Tk, list[Tv]]]:
return [(k, list(v)) for k, v in self.items.items()]
@functools.cached_property
def _indices(self) -> np.ndarray:
return np.concatenate([np.array([0]), np.cumsum(np.array([len(i) for _, i in self._item_list]))])
def __getitem__(self, index: int) -> tuple[Tk, Tv]:
a = np.searchsorted(self._indices, index, side="right") - 1
b = index - self._indices[a]
key, values = self._item_list[a]
value = values[b]
return key, value
def __len__(self) -> int:
return self._indices[-1].item()