"""Defines the base logger.
New loggers should implement whichever `log_` functions they need to.
Unimplemented functions are simply ignored. The framework will handle logging
rate limiting and munging the logged values to a common format.
The internal ergonomics for logging are a bit confusing to follow. Each
component has access to a ``MultiLogger`` which it can use to log values.
After each step, each ``MultiLogger`` sends its values to any implemented
loggers which have ``should_write`` return ``True``. This lets the implemented
loggers aggregate values over multiple ``MultiLogger``s. New loggers should
follow the implementation of one of the existing loggers.
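
Example (a minimal hypothetical logger; ``StdoutLogger`` is illustrative and
not part of the framework)::

    class StdoutLogger(BaseLogger[BaseLoggerConfig]):
        def __init__(self, config: BaseLoggerConfig) -> None:
            super().__init__(config)

            self.scalars: dict[str, Callable[[], Number]] = {}

        def log_scalar(self, key: str, value: Callable[[], Number], state: State, namespace: str) -> None:
            self.scalars[f"{namespace}/{key}"] = value  # Deferred; only evaluated on write

        def write(self, state: State) -> None:
            # Called by the framework only when ``should_write`` returns True.
            for key, value in self.scalars.items():
                print(f"{state.phase} {key}: {value()}")
            self.scalars.clear()

        def default_write_every_n_seconds(self, state: State) -> float:
            return 1.0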
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Generic, TypeVar
from omegaconf import DictConfig
from torch import Tensor
from ml.core.config import BaseConfig, BaseObject, conf_field
from ml.core.state import Phase, State
Number = int | float | Tensor

@dataclass
class BaseLoggerConfig(BaseConfig):
    write_every_n_seconds: float | None = conf_field(None, help="Only write a log line every N seconds")
    write_train_every_n_seconds: float | None = conf_field(None, help="Only write a train log line every N seconds")
    write_val_every_n_seconds: float | None = conf_field(None, help="Only write a val log line every N seconds")


LoggerConfigT = TypeVar("LoggerConfigT", bound=BaseLoggerConfig)

class BaseLogger(BaseObject[LoggerConfigT], Generic[LoggerConfigT], ABC):
    """Defines the base logger."""

    log_directory: Path

    def __init__(self, config: LoggerConfigT) -> None:
        super().__init__(config)

        self.last_write_time: dict[Phase, float] = {}

    def initialize(self, log_directory: Path) -> None:
        self.log_directory = log_directory

    def log_scalar(self, key: str, value: Callable[[], Number], state: State, namespace: str) -> None:
        """Logs a scalar value.

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
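
        Example (hypothetical; ``logger``, ``loss`` and ``state`` are assumed
        to be in scope)::

            # The callable defers evaluation until the logger actually writes.
            logger.log_scalar("loss", lambda: loss.item(), state, "train")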
"""

    def log_string(self, key: str, value: Callable[[], str], state: State, namespace: str) -> None:
        """Logs a string value.

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
        """

    def log_image(self, key: str, value: Callable[[], Tensor], state: State, namespace: str) -> None:
        """Logs a normalized image, with shape (C, H, W).

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
        """

    def log_audio(self, key: str, value: Callable[[], tuple[Tensor, int]], state: State, namespace: str) -> None:
        """Logs normalized audio, with shape (T,).

        Args:
            key: The key to log
            value: The value to log, as a tuple of the waveform with shape
                (T,) and its sample rate
            state: The current log state
            namespace: The namespace to be logged
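
        Example (hypothetical; ``logger`` and ``waveform`` are assumed to be
        in scope, and 16000 is an assumed sample rate)::

            logger.log_audio("sample", lambda: (waveform, 16000), state, "val")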
"""

    def log_video(self, key: str, value: Callable[[], Tensor], state: State, namespace: str) -> None:
        """Logs a normalized video, with shape (T, C, H, W).

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
        """

    def log_histogram(self, key: str, value: Callable[[], Tensor], state: State, namespace: str) -> None:
        """Logs a histogram, with any shape.

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
        """

    def log_point_cloud(self, key: str, value: Callable[[], Tensor], state: State, namespace: str) -> None:
        """Logs a normalized point cloud, with shape (B, N, 3).

        Args:
            key: The key to log
            value: The value to log
            state: The current log state
            namespace: The namespace to be logged
        """

    def log_config(self, config: DictConfig) -> None:
        """Logs the run configuration.

        This is only called once per run.

        Args:
            config: The run config
        """

    def should_write(self, state: State) -> bool:
        """Returns whether or not the current state should be written.

        This function checks that the last time the current phase was written
        was more than some interval in the past, to avoid writing tons of
        values when the iteration time is extremely small.

        Args:
            state: The state to check

        Returns:
            If the logger should write values for the current state
        """
        if state.phase not in self.last_write_time:
            # First write for this phase; record the time and write.
            self.last_write_time[state.phase] = state.elapsed_time_s
            return True
        elif state.elapsed_time_s - self.last_write_time[state.phase] < self.write_every_n_seconds(state):
            # Not enough time has elapsed since the last write for this phase.
            return False
        else:
            self.last_write_time[state.phase] = state.elapsed_time_s
            return True

    @abstractmethod
    def write(self, state: State) -> None:
        """Writes the logs.

        Args:
            state: The current log state
        """

    @abstractmethod
    def default_write_every_n_seconds(self, state: State) -> float:
        """Returns the default write interval in seconds.

        Args:
            state: The state to get the default write interval for

        Returns:
            The default write interval, in seconds
        """

    def write_every_n_seconds(self, state: State) -> float:
        """Returns the write interval in seconds.
        Args:
            state: The state to get the write interval for

        Returns:
            The write interval, in seconds
"""
        if state.phase == "train":
            if self.config.write_train_every_n_seconds is not None:
                return self.config.write_train_every_n_seconds
        elif self.config.write_val_every_n_seconds is not None:
            # Any non-train phase falls under the validation interval.
            return self.config.write_val_every_n_seconds
        if self.config.write_every_n_seconds is not None:
            return self.config.write_every_n_seconds
        return self.default_write_every_n_seconds(state)