"""Defines a meter for computing statistics over a stream of values."""
import functools
from typing import Any, cast
import torch
import torch.distributed as dist
from torch import Tensor
from ml.utils.device.auto import detect_device
[docs]@functools.lru_cache
def get_device() -> torch.device:
return detect_device()._get_device()
[docs]class Meter:
def __init__(self) -> None:
self._min_val: Tensor | None = None
self._max_val: Tensor | None = None
self._total_val: Tensor | None = None
self._num_seen = torch.zeros((0,), dtype=torch.int64)
self.has_reduce_been_called = False
[docs] def add(self, value: int | float) -> None:
device = get_device()
if self._min_val is None or self._max_val is None or self._total_val is None:
self._min_val = torch.tensor(value, dtype=torch.float64, device=device)
self._max_val = torch.tensor(value, dtype=torch.float64, device=device)
self._total_val = torch.tensor(value, dtype=torch.float64, device=device)
else:
self._min_val.clamp_max_(value)
self._max_val.clamp_min_(value)
self._total_val.add_(value)
self._num_seen.add_(1)
[docs] def reduce(self) -> list[Any]:
if self.has_reduce_been_called:
raise RuntimeError("`reduce` should only be called once, otherwise you will end up with incorrect values")
self.has_reduce_been_called = True
# These are actually the work handles, they just don't have proper
# type support yet.
works: list[Any] = []
if self._min_val is not None:
works.append(dist.all_reduce(self._min_val, dist.ReduceOp.MIN, async_op=True))
if self._max_val is not None:
works.append(dist.all_reduce(self._max_val, dist.ReduceOp.MAX, async_op=True))
if self._total_val is not None:
works.append(dist.all_reduce(self._total_val, dist.ReduceOp.SUM, async_op=True))
works.append(dist.all_reduce(self._num_seen, dist.ReduceOp.SUM, async_op=True))
return works
@property
def num_seen(self) -> int:
return cast(int, self._num_seen.item())
@property
def min_val(self) -> float | None:
return None if self._min_val is None else self._min_val.item()
@property
def max_val(self) -> float | None:
return None if self._max_val is None else self._max_val.item()
@property
def mean_val(self) -> float | None:
return None if self._total_val is None else (self._total_val / self._num_seen).item()