"""Defines a timer context manager for timing code blocks.
This also provides a simple spinner for long-running tasks.
"""
import errno
import functools
import logging
import os
import signal
import sys
import threading
import time
import warnings
from threading import Thread
from types import TracebackType
from typing import Any, Callable, ContextManager, Generic, Iterable, Iterator, ParamSpec, Sequence, Sized, TypeVar
from ml.utils.colors import colorize
from ml.utils.distributed import is_master
# Module-level logger; Timer and spinnerator fall back to it when no logger is passed in.
timer_logger: logging.Logger = logging.getLogger(__name__)
# T: item type of `spinnerator` and return type of callables wrapped by `timeout`.
T = TypeVar("T")
# P: parameter specification of callables wrapped by `timeout`.
P = ParamSpec("P")
@functools.lru_cache
def allow_spinners() -> bool:
    """Decide whether a terminal spinner may be drawn in this process.

    The answer is memoized by ``functools.lru_cache``, so the environment is
    only inspected on the first call.

    Returns:
        False when running under pytest (detected through the
        ``PYTEST_CURRENT_TEST`` environment variable or an imported ``pytest``
        module), when stdout is not a TTY, or when ``TERM`` is ``"dumb"``;
        otherwise whatever ``is_master()`` reports.
    """
    running_under_pytest = "PYTEST_CURRENT_TEST" in os.environ or "pytest" in sys.modules
    if running_under_pytest:
        return False

    # NOTE(review): the Spinner class in this file writes to stderr, while this
    # check looks at stdout - confirm that is intentional.
    if not sys.stdout.isatty():
        return False

    if os.environ.get("TERM") == "dumb":
        return False

    return is_master()
[docs]class Spinner:
def __init__(self, text: str | None = None) -> None:
self._text = "" if text is None else text
self._max_line_len = 0
self._spinner_stop = False
self._spinner_close = False
self._flag = threading.Event()
self._thread = Thread(target=self._spinner, daemon=True)
self._thread.start()
# If we're in a breakpoint, we want to close the spinner when we exit
# the breakpoint.
self._original_breakpointhook = sys.breakpointhook
def _breakpointhook(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
warnings.warn("Breakpoint hit inside spinner; run `up 1` to see where it was hit")
self.stop()
sys.breakpointhook(*args, **kwargs)
[docs] def set_text(self, text: str) -> "Spinner":
sys.stderr.write(" " * (self._max_line_len + 1) + "\r")
sys.stderr.flush()
self._max_line_len = 0
self._text = colorize(text, "grey")
return self
[docs] def start(self) -> None:
self._spinner_stop = False
self._flag.set()
sys.breakpointhook = self._breakpointhook
[docs] def stop(self) -> None:
self._spinner_stop = True
sys.breakpointhook = self._original_breakpointhook
[docs] def close(self) -> None:
self.stop()
self._spinner_close = True
self._thread.join()
def _spinner(self) -> None:
chars = [colorize(c, "light-yellow") for c in ("|", "/", "-", "\\")]
while not self._spinner_close:
self._flag.wait()
start_time = time.time()
while not self._spinner_stop:
time.sleep(0.1)
char = chars[int((time.time() * 10) % len(chars))]
elapsed_secs = time.time() - start_time
line = f"[ {char} {elapsed_secs:.1f} ] {self._text}\r"
self._max_line_len = max(self._max_line_len, len(line))
sys.stderr.write(line)
sys.stderr.flush()
sys.stderr.write(" " * (self._max_line_len + 1) + "\r")
sys.stderr.flush()
self._flag.clear()
@functools.lru_cache
def spinner() -> Spinner:
    """Return the shared Spinner, constructing it on the first call.

    ``functools.lru_cache`` memoizes the result, so every caller receives the
    same instance.
    """
    shared_instance = Spinner()
    return shared_instance
[docs]class Timer(ContextManager):
"""Defines a simple timer for logging an event."""
def __init__(
self,
description: str,
min_seconds_to_print: float = 5.0,
logger: logging.Logger | None = None,
spinner: bool = False,
) -> None:
self.description = description
self.min_seconds_to_print = min_seconds_to_print
self._start_time: float | None = None
self._elapsed_time: float | None = None
self._logger = timer_logger if logger is None else logger
self._use_spinner = spinner and allow_spinners()
@property
def elapsed_time(self) -> float:
assert (elapsed_time := self._elapsed_time) is not None
return elapsed_time
def __enter__(self) -> "Timer":
self._start_time = time.time()
if self._use_spinner:
spinner().set_text(self.description).start()
return self
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
assert self._start_time is not None
self._elapsed_time = time.time() - self._start_time
if self._elapsed_time > self.min_seconds_to_print:
self._logger.warning("Finished %s in %.3g seconds", self.description, self._elapsed_time)
spinner().stop()
[docs]class spinnerator(Generic[T]): # noqa: N801
"""Defines a spinning iterator which uses the built-in spinner."""
def __init__(
self,
items: Sequence[T] | Iterable[T],
desc: str | None = None,
total: int | None = None,
logger: logging.Logger | None = None,
) -> None:
self._items = items
self._desc = "Processing..." if desc is None else desc
self._num_items: int = 0
self._total_items: int | None = total
self._iter: Iterator[T] | None = None
self._logger = timer_logger if logger is None else logger
self._use_spinner = allow_spinners()
@property
def desc(self) -> str:
n, t = self._num_items, self._total_items
processed_string = f"{n}" if t is None else f"{n}/{t} ({n/t:.0%})"
return " ".join((self._desc, processed_string))
[docs] @classmethod
def range(
cls,
start: int | None,
stop: int | None = None,
step: int = 1,
desc: str | None = None,
logger: logging.Logger | None = None,
) -> "spinnerator[int]":
if start is None:
return spinnerator([], desc=desc, logger=logger)
if stop is None:
return spinnerator(
range(0, start, step),
desc=desc,
logger=logger,
)
return spinnerator(
range(start, stop, step),
desc=desc,
logger=logger,
)
[docs] def update(self, n: int = 1) -> None:
self._num_items += n
if self._use_spinner:
spinner().set_text(self.desc)
def __enter__(self) -> "spinnerator[T]":
if self._use_spinner:
spinner().set_text(self.desc).start()
return self
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
if self._use_spinner:
spinner().stop()
def __iter__(self) -> Iterator[T]:
assert self._items is not None, "Must provide items to iterate over"
self._num_items = 0
if self._total_items is None and isinstance(self._items, Sized):
self._total_items = len(self._items)
if self._use_spinner:
spinner().set_text(self.desc).start()
self._iter = iter(self._items)
return self
def __next__(self) -> T:
assert self._iter is not None, "Must call __iter__ before __next__"
try:
item = next(self._iter)
except Exception:
spinner().stop()
raise
self._num_items += 1
if self._use_spinner:
spinner().set_text(self.desc)
return item
[docs]def timeout(seconds: int, error_message: str = os.strerror(errno.ETIME)) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator for timing out long-running functions.
Note that this function won't work on Windows.
Args:
seconds: Timeout after this many seconds
error_message: Error message to pass to TimeoutError
Returns:
Decorator function
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
def _handle_timeout(*_: Any) -> None: # noqa: ANN401
raise TimeoutError(error_message)
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
finally:
signal.alarm(0)
return result
return wrapper
return decorator