"""Defines a metered logger.
This logger keeps track of statistics of logged values. It is useful for
getting global statistics during evaluation.
"""
from dataclasses import dataclass
from typing import Any, Callable, Iterable
from torch import Tensor
from ml.core.registry import register_logger
from ml.core.state import Phase, State
from ml.loggers.base import BaseLogger, BaseLoggerConfig
from ml.utils.meter import Meter
[docs]def get_value(value: int | float | Tensor) -> int | float:
if isinstance(value, (int, float)):
return value
if isinstance(value, Tensor):
return value.detach().float().cpu().item()
raise TypeError(f"Unexpected log type: {type(value)}")
@dataclass
class MeterLoggerConfig(BaseLoggerConfig):
    """Configuration for the meter logger.

    Declares no fields of its own beyond those inherited from
    ``BaseLoggerConfig``.
    """
@register_logger("meter", MeterLoggerConfig)
class MeterLogger(BaseLogger[MeterLoggerConfig]):
    """Logger which accumulates scalar values in meters.

    Logged scalars are added to a ``Meter`` selected by phase, namespace and
    key. Nothing is emitted by ``write``; the accumulated statistics are
    read back through ``get_value_dict``.
    """

    def __init__(self, config: MeterLoggerConfig) -> None:
        """Initializes the logger with an empty meter table.

        Args:
            config: The logger configuration, forwarded to the base class.
        """
        super().__init__(config)

        # Meters are indexed by phase, then namespace, then scalar key.
        self.meters: dict[Phase, dict[str, dict[str, Meter]]] = {}

    def get_meter(self, state: State, key: str, namespace: str | None) -> Meter:
        """Returns the meter for a key, creating it on first use.

        Args:
            state: The current state; only ``state.phase`` is read.
            key: The name of the scalar being tracked.
            namespace: The namespace of the scalar; ``None`` is mapped to
                ``"default"``.

        Returns:
            The meter stored under ``(state.phase, namespace, key)``.
        """
        if namespace is None:
            namespace = "default"
        namespace_meters = self.meters.setdefault(state.phase, {}).setdefault(namespace, {})
        if key not in namespace_meters:
            # Without this the very first lookup of any key raises KeyError,
            # because nothing else ever inserts a meter into the table.
            # NOTE(review): assumes `Meter` takes no constructor arguments;
            # confirm against `ml.utils.meter`.
            namespace_meters[key] = Meter()
        return namespace_meters[key]

    def log_scalar(self, key: str, value: Callable[[], int | float | Tensor], state: State, namespace: str) -> None:
        """Adds a scalar value to the matching meter.

        Args:
            key: The name of the scalar being logged.
            value: A zero-argument callable producing the value to log; it
                is called immediately and converted with ``get_value``.
            state: The current state, used to choose the phase.
            namespace: The namespace under which the scalar is tracked.
        """
        self.get_meter(state, key, namespace).add(get_value(value()))

    def iter_meters(self) -> Iterable[Meter]:
        """Yields every meter across all phases and namespaces."""
        for phase_meters in self.meters.values():
            for namespace_meters in phase_meters.values():
                yield from namespace_meters.values()

    def get_value_dict(self) -> dict[str, int | float]:
        """Reduces all meters and collects their statistics.

        Returns:
            A dictionary mapping ``"{phase}/{namespace}/{key}/{stat}"`` to
            the statistic value, where ``stat`` is ``min``, ``max`` or
            ``mean``. Statistics whose value is ``None`` are omitted.
        """
        # First, reduces the meters. Every reduction is started before any
        # of the returned handles is waited on.
        # NOTE(review): presumably these are asynchronous cross-process
        # reductions; confirm against `Meter.reduce`.
        works: list[Any] = []
        for meter in self.iter_meters():
            works.extend(meter.reduce())
        for work in works:
            work.wait()

        # Next, builds the output dictionary.
        out_dict: dict[str, int | float] = {}
        for phase, phase_meters in self.meters.items():
            for namespace, namespace_meters in phase_meters.items():
                for key, meter in namespace_meters.items():
                    abs_key = f"{phase}/{namespace}/{key}"
                    if meter.min_val is not None:
                        out_dict[f"{abs_key}/min"] = meter.min_val
                    if meter.max_val is not None:
                        out_dict[f"{abs_key}/max"] = meter.max_val
                    if meter.mean_val is not None:
                        out_dict[f"{abs_key}/mean"] = meter.mean_val
        return out_dict

    def write(self, state: State) -> None:
        """Does nothing; values are only read via ``get_value_dict``."""

    def default_write_every_n_seconds(self, state: State) -> float:
        """Returns the default write interval, which is always ``0.0``."""
        return 0.0