"""Defines a general logger for munging logged values to an expected format.
This logger handles munging, rate limiting, and multiplexing logged values
to each of the implemented child loggers. It is the logging interface that
is exposed to the task and model.
"""
import functools
import logging
import math
import re
from collections import defaultdict
from types import TracebackType
from typing import Callable, Iterator, Literal, Sequence, TypeVar
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as V
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor
from torchvision.transforms import InterpolationMode
from ml.core.state import State
from ml.loggers.base import BaseLogger
from ml.utils.logging import IntervalTicker
logger = logging.getLogger(__name__)
T = TypeVar("T")
LogT = TypeVar("LogT")
Number = int | float | Tensor
ChannelSelectMode = Literal["first", "last", "mean"]
VALID_VIDEO_CHANNEL_COUNTS = {1, 3}
VALID_AUDIO_CHANNEL_COUNTS = {1, 2}
TARGET_FPS = 12
DEFAULT_NAMESPACE = "value"
def _aminmax(t: Tensor) -> tuple[Tensor, Tensor]:
# `aminmax` isn't supported for MPS tensors, fall back to separate calls.
minv, maxv = (t.min(), t.max()) if t.is_mps else tuple(t.aminmax())
return minv, maxv
def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
for i in range(0, len(text), max_length):
yield text[i : i + max_length]
[docs]def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]:
"""Standardizes a text string to a list of lines.
Args:
text: The text to standardize
max_line_length: If set, truncate lines to this length
remove_non_ascii: Remove non-ASCII characters if present
Returns:
The standardized text lines
"""
if remove_non_ascii:
text = "".join(char for char in text if ord(char) < 128)
lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())]
if max_line_length is not None:
lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)]
return lines
def get_audio_channel(audio: Tensor, channel_select_mode: ChannelSelectMode) -> Tensor:
    """For stereo audio, selects (or averages into) a single channel.

    Args:
        audio: The audio tensor to select a channel from, with shape (C, L)
            or batched (..., C, L), where C is 1 or 2
        channel_select_mode: ``"first"`` / ``"last"`` pick that channel,
            ``"mean"`` averages over the channels

    Returns:
        The selected audio channel, with the channel dimension removed

    Raises:
        ValueError: If the audio shape or the channel select mode is invalid
    """
    if audio.ndim < 2:
        raise ValueError(f"Invalid audio shape: {audio.shape}")
    # The channel axis is always the second-to-last one, so batched input
    # (B, C, L) works too; report that same axis in the error.
    num_channels = audio.shape[-2]
    if num_channels not in VALID_AUDIO_CHANNEL_COUNTS:
        raise ValueError(f"Invalid audio channel count: {num_channels}")
    if channel_select_mode == "first":
        return audio[..., 0, :]
    if channel_select_mode == "last":
        return audio[..., -1, :]
    if channel_select_mode == "mean":
        return audio.mean(dim=-2)
    raise ValueError(f"Invalid channel select mode: {channel_select_mode}")
def make_human_viewable_resolution(
    image: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    trg_res: tuple[int, int] = (250, 250),
) -> Tensor:
    """Resizes image to human-viewable resolution.

    The aspect ratio is preserved; the image is scaled so its area matches the
    area of ``trg_res``.

    Args:
        image: The image to resize, with shape (C, H, W)
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image
    """
    width, height = V.get_image_size(image)
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Keep at least one pixel per side so very elongated images don't collapse
    # to a zero-sized dimension, which ``V.resize`` cannot produce.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return V.resize(image, [new_height, new_width], interpolation)
def standardize_image(
    image: Tensor,
    *,
    log_key: str | None = None,
    normalize: bool = True,
    keep_resolution: bool = False,
) -> Tensor:
    """Converts an arbitrary single image to shape (C, H, W).

    Args:
        image: The image tensor to log, with shape (H, W), (C, H, W) or
            (H, W, C), where C is 1 or 3
        log_key: An optional logging key to use in the exception message
        normalize: For floating-point input, map the range
            [min(image.min(), 0), max(image.max(), 1)] onto [0, 1]
        keep_resolution: If set, preserve original image resolution, otherwise
            change image resolution to human-viewable

    Returns:
        The normalized image, with shape (C, H, W)

    Raises:
        ValueError: If the image shape is invalid
    """
    where = "" if log_key is None else f" for {log_key}"

    if normalize and image.is_floating_point():
        lo, hi = _aminmax(image)
        # Widen the range so it always contains [0, 1].
        hi.clamp_min_(1.0)
        lo.clamp_max_(0.0)
        image = torch.clamp((image.detach() - lo) / (hi - lo), 0.0, 1.0)

    if image.ndim == 2:
        image = image.unsqueeze(0)
    elif image.ndim != 3:
        raise ValueError(f"Invalid image shape{where}: {image.shape}")
    elif image.shape[0] not in VALID_VIDEO_CHANNEL_COUNTS:
        # Not channels-first; accept channels-last and move channels to the front.
        if image.shape[2] not in VALID_VIDEO_CHANNEL_COUNTS:
            raise ValueError(f"Invalid channel count{where}: {image.shape}")
        image = image.permute(2, 0, 1)

    return image if keep_resolution else make_human_viewable_resolution(image)
LabelT = TypeVar("LabelT", Sequence[str], None)
def _get_n_samples(t: Tensor, labels: LabelT, n: int, dim: int = 0) -> tuple[Tensor, LabelT]:
if t.shape[dim] <= n:
return t, labels
idxs = torch.linspace(0, t.shape[dim] - 1, n, device=t.device, dtype=t.dtype)
idxs = idxs.round().long().clamp(0, t.shape[dim] - 1)
t = torch.index_select(t, dim, idxs)
if labels is None:
return t, None
return t, [labels[i] for i in idxs.cpu().tolist()]
def standardize_images(
    images: Tensor,
    labels: LabelT,
    *,
    max_images: int | None = None,
    log_key: str | None = None,
    normalize: bool = True,
    keep_resolution: bool = False,
) -> tuple[Tensor, LabelT]:
    """Converts an arbitrary batch of images to shape (B, C, H, W).

    Args:
        images: The image tensor to log, with shape (B, H, W), (B, C, H, W)
            or (B, H, W, C), where C is 1 or 3
        labels: Optional per-image labels, subsampled alongside the images
        max_images: If set, keep at most this many images, picked at evenly
            spaced positions in the batch
        log_key: An optional logging key to use in the exception message
        normalize: For floating-point input, map one range shared by the whole
            batch, [min(images.min(), 0), max(images.max(), 1)], onto [0, 1]
        keep_resolution: If set, preserve original image resolution, otherwise
            resize every image to a human-viewable resolution

    Returns:
        The normalized images, with shape (B, C, H, W), and the matching labels

    Raises:
        ValueError: If the image shape is invalid
    """
    where = "" if log_key is None else f" for {log_key}"

    if normalize and images.is_floating_point():
        lo, hi = _aminmax(images)
        hi.clamp_min_(1.0)
        lo.clamp_max_(0.0)
        images = torch.clamp((images.detach() - lo) / (hi - lo), 0.0, 1.0)

    if images.ndim == 3:
        images = images.unsqueeze(1)
    elif images.ndim != 4:
        raise ValueError(f"Invalid image shape{where}: {images.shape}")
    elif images.shape[1] not in VALID_VIDEO_CHANNEL_COUNTS:
        # Not channels-first; accept channels-last and move channels to dim 1.
        if images.shape[3] not in VALID_VIDEO_CHANNEL_COUNTS:
            raise ValueError(f"Invalid channel count{where}: {images.shape}")
        images = images.permute(0, 3, 1, 2)

    if max_images is not None:
        images, labels = _get_n_samples(images, labels, max_images, dim=0)

    if not keep_resolution:
        images = torch.stack([make_human_viewable_resolution(img) for img in images], dim=0)

    return images, labels
@functools.lru_cache(maxsize=None)
def audio_warning_ticker() -> IntervalTicker:
    """Returns the shared ticker used to rate-limit out-of-range audio warnings.

    The function takes no arguments and is cached, so every caller gets the
    same ``IntervalTicker`` instance (constructed with ``5.0``, presumably a
    5-second interval; see ``ml.utils.logging``).
    """
    ticker = IntervalTicker(5.0)
    return ticker
def standardize_audio(audio: Tensor, *, log_key: str | None = None) -> Tensor:
    """Converts an arbitrary audio tensor to shape (C, T).

    If any sample lies outside [-1, 1], samples are first clamped to
    [-5e3, 5e3] (to tame extreme outliers and infinities) and the clip is then
    rescaled so its peak magnitude is exactly 1. A rate-limited warning is
    logged when this happens.

    Args:
        audio: The audio tensor to log, with shape (T), (C, T) or (T, C),
            where C is 1 or 2
        log_key: An optional logging key to use in the exception message

    Returns:
        The standardized audio tensor, with shape (C, T)

    Raises:
        ValueError: If the audio shape is invalid
    """
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    elif audio.ndim == 2:
        if audio.shape[0] in VALID_AUDIO_CHANNEL_COUNTS:
            pass
        elif audio.shape[1] in VALID_AUDIO_CHANNEL_COUNTS:
            audio = audio.permute(1, 0)
        else:
            raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {audio.shape}")
    else:
        raise ValueError(f"Invalid audio shape{'' if log_key is None else f' for {log_key}'}: {audio.shape}")
    max_abs = audio.abs().max()
    if max_abs > 1.0:
        if audio_warning_ticker().tick():
            logger.warning("Audio is outside the range [-1, 1]; clipping")
        # Rescale by the peak *after* clamping. Dividing the clamped signal by the
        # unclamped peak over-attenuates clips with peaks above 5e3 and silences
        # clips containing an infinite sample entirely.
        audio = audio.clamp(-5e3, 5e3)
        audio = audio / audio.abs().max()
    return audio
def standardize_audios(audios: Tensor, *, log_key: str | None = None, max_audios: int | None = None) -> Tensor:
    """Converts an arbitrary batch of audio clips to shape (B, C, T).

    If any sample lies outside [-1, 1], samples are first clamped to
    [-5e3, 5e3] and the batch is then rescaled so its peak magnitude is
    exactly 1. A rate-limited warning is logged when this happens.

    Args:
        audios: The audio tensor to log, with shape (B, T), (B, C, T) or
            (B, T, C), where C is 1 or 2
        log_key: An optional logging key to use in the exception message
        max_audios: Maximum number of audios to select; clips are picked at
            evenly spaced positions in the batch

    Returns:
        The standardized audio tensor, with shape (B, C, T)

    Raises:
        ValueError: If the audio shape is invalid
    """
    if audios.ndim == 2:
        audios = audios.unsqueeze(1)
    elif audios.ndim == 3:
        if audios.shape[1] in VALID_AUDIO_CHANNEL_COUNTS:
            pass
        elif audios.shape[2] in VALID_AUDIO_CHANNEL_COUNTS:
            # (B, T, C) -> (B, C, T); ``permute`` needs every dimension listed.
            audios = audios.permute(0, 2, 1)
        else:
            raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {audios.shape}")
    else:
        raise ValueError(f"Invalid audio shape{'' if log_key is None else f' for {log_key}'}: {audios.shape}")
    if max_audios is not None:
        audios, _ = _get_n_samples(audios, None, max_audios, dim=0)
    max_abs = audios.abs().max()
    if max_abs > 1.0:
        if audio_warning_ticker().tick():
            logger.warning("Audio is outside the range [-1, 1]; clipping")
        # Rescale by the peak *after* clamping, so the result really spans [-1, 1].
        audios = audios.clamp(-5e3, 5e3)
        audios = audios / audios.abs().max()
    return audios
def separate_with_padding(audio: Tensor, sep_frames: int) -> Tensor:
    """Concatenates a batch of clips along time, with silence between them.

    Converts a (B, C, T) waveform to (C, B * (T + sep_frames) - sep_frames).

    Args:
        audio: The audio tensor to separate, with shape (B, C, T)
        sep_frames: Number of zero frames to insert between adjacent clips

    Returns:
        The separated audio tensor

    Raises:
        ValueError: If the audio shape is invalid or ``sep_frames`` is negative
    """
    # Validate before the fast path so every input is checked consistently.
    if audio.ndim != 3:
        raise ValueError(f"Invalid audio shape: {audio.shape}")
    if sep_frames < 0:
        raise ValueError(f"Invalid separation: {sep_frames}")
    bsz, chans, tsz = audio.shape
    if sep_frames == 0:
        # (B, C, T) -> (C, B, T) -> (C, B * T)
        return audio.transpose(0, 1).flatten(1)
    stride = tsz + sep_frames
    # ``max`` guards an empty batch, where the formula would go negative.
    output_tensor = audio.new_zeros(chans, max(bsz * stride - sep_frames, 0))
    for i, audio_sample in enumerate(audio.unbind(0)):  # B * (C, T)
        output_tensor[:, i * stride : i * stride + tsz] = audio_sample
    return output_tensor
def standardize_video(video: Tensor, *, log_key: str | None = None, normalize: bool = True) -> Tensor:
    """Converts an arbitrary video to shape (T, C, H, W).

    Args:
        video: The video tensor to log, with shape (T, H, W), (T, C, H, W)
            or (T, H, W, C), where C is 1 or 3
        log_key: An optional logging key to use in the exception message
        normalize: For floating-point input, rescale into [0, 1]. The range
            is taken from the final frame only (``video[-1]``), widened to
            include [0, 1]; values in other frames outside it are clipped.

    Returns:
        The normalized video, with shape (T, C, H, W)

    Raises:
        ValueError: If the video shape is invalid
    """
    if normalize and video.is_floating_point():
        lo, hi = _aminmax(video[-1])
        hi.clamp_min_(1.0)
        lo.clamp_max_(0.0)
        video = torch.clamp((video.detach() - lo) / (hi - lo), 0.0, 1.0)

    if video.ndim == 3:
        return video.unsqueeze(1)
    if video.ndim == 4 and video.shape[1] in VALID_VIDEO_CHANNEL_COUNTS:
        return video
    if video.ndim == 4 and video.shape[3] in VALID_VIDEO_CHANNEL_COUNTS:
        return video.permute(0, 3, 1, 2)
    raise ValueError(f"Invalid video shape{'' if log_key is None else f' for {log_key}'}: {video.shape}")
def standardize_videos(
    videos: Tensor,
    *,
    max_videos: int | None = None,
    log_key: str | None = None,
    normalize: bool = True,
) -> Tensor:
    """Converts an arbitrary batch of videos to shape (B, T, C, H, W).

    Args:
        videos: The video tensor to log, with shape (B, T, H, W),
            (B, T, C, H, W) or (B, T, H, W, C), where C is 1 or 3
        max_videos: If set, keep only the first ``max_videos`` videos
        log_key: An optional logging key to use in the exception message
        normalize: For floating-point input, rescale into [0, 1]. The range
            is taken from the final frame of every video (``videos[:, -1]``),
            widened to include [0, 1]; values outside it are clipped.

    Returns:
        The normalized videos, with shape (B, T, C, H, W)

    Raises:
        ValueError: If the video shape is invalid
    """
    if normalize and videos.is_floating_point():
        minv, maxv = _aminmax(videos[:, -1])
        maxv.clamp_min_(1.0)
        minv.clamp_max_(0.0)
        videos = torch.clamp((videos.detach() - minv) / (maxv - minv), 0.0, 1.0)
    if videos.ndim == 4:
        # Grayscale without a channel axis: add it.
        videos = videos.unsqueeze(2)
    elif videos.ndim == 5:
        if videos.shape[2] in VALID_VIDEO_CHANNEL_COUNTS:
            pass
        elif videos.shape[4] in VALID_VIDEO_CHANNEL_COUNTS:
            # (B, T, H, W, C) -> (B, T, C, H, W); a 5-D permute needs all 5 dims.
            videos = videos.permute(0, 1, 4, 2, 3)
        else:
            raise ValueError(f"Invalid video shape{'' if log_key is None else f' for {log_key}'}: {videos.shape}")
    else:
        raise ValueError(f"Invalid video shape{'' if log_key is None else f' for {log_key}'}: {videos.shape}")
    # Apply ``max_videos`` on every path, including grayscale input.
    return videos if max_videos is None else videos[:max_videos]
def image_with_text(
    image: Tensor,
    text: list[str],
    max_num_lines: int | None = None,
    line_spacing: int = 4,
    centered: bool = True,
) -> Tensor:
    """Adds a text label below an image.

    Args:
        image: The image to label, with shape (C, H, W)
        text: The text label for the image, one entry per line
        max_num_lines: The number of lines of spacing to add to the bottom
            of the image; extra lines in ``text`` are dropped. Defaults to
            ``len(text)``. Passing the same value for several images gives
            them the same output height, even if some labels are empty.
        line_spacing: The spacing between adjacent lines, in pixels
        centered: If set, center the text labels, otherwise align to the left

    Returns:
        The image with a text label, as produced by ``V.pil_to_tensor``
        (shape (C, H', W)); the input image is returned unchanged when
        there are no lines to reserve
    """
    if max_num_lines is None:
        max_num_lines = len(text)
    if max_num_lines <= 0:
        return image
    text = text[:max_num_lines]
    pil_image = V.to_pil_image(image)
    width, height = pil_image.size
    num_bands = len(pil_image.getbands())
    # PIL reads a bare integer colour on a multi-band image as a packed value
    # (e.g. 255 -> pure red for RGB), so spell out one value per band.
    background = 255 if num_bands == 1 else (255,) * num_bands
    foreground = 0 if num_bands == 1 else (0,) * num_bands
    font = ImageFont.load_default()
    # Measure the line height from a fixed reference string (all printable
    # ASCII) rather than the first label line. The padding then depends only
    # on ``max_num_lines``, so labelled images can be stacked, and an empty
    # first line no longer collapses the spacing.
    _, _, _, line_height = font.getbbox("".join(map(chr, range(32, 127))))
    new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing)
    padded_image = Image.new(pil_image.mode, (new_width, new_height), background)
    padded_image.paste(pil_image, (0, 0))
    drawer = ImageDraw.Draw(padded_image)
    for i, text_line in enumerate(text):
        text_line_top = height + line_spacing + i * (line_height + line_spacing)
        if centered:
            _, _, line_width, _ = font.getbbox(text_line)
            text_line_left = (width - line_width) / 2
        else:
            text_line_left = line_spacing
        drawer.text((text_line_left, text_line_top), text_line, font=font, fill=foreground)
    return V.pil_to_tensor(padded_image)
def normalize_video_fps(
    video: Tensor | list[Tensor],
    fps: int | None,
    length: float | None,
    stack_dim: int = 0,
    target_fps: int = TARGET_FPS,
) -> Tensor:
    """Resamples a video's frames so it plays back at ``target_fps``.

    Frames are chosen at evenly spaced (truncated) indices, so frames are
    dropped or repeated as needed. If neither ``fps`` nor ``length`` is
    given, the video is returned unchanged (lists are stacked).

    Args:
        video: The video to normalize, with frames along the first axis
            (e.g. (T, C, H, W)), or a list of per-frame tensors
        fps: The frame rate of the input video
        length: The video duration, in seconds; the output plays for this
            long at ``target_fps``. Used to infer the input frame rate when
            ``fps`` is not given.
        stack_dim: Which dimension to stack along, for lists
        target_fps: The target frames per second for the logger

    Returns:
        The normalized video

    Raises:
        ValueError: If ``length`` or the given / inferred frame rate is not
            positive
    """
    if fps is None and length is None:
        return torch.stack(video, dim=stack_dim) if isinstance(video, list) else video
    pre_frames = len(video) if isinstance(video, list) else video.size(0)
    if fps is not None:
        src_fps: float = fps
    else:
        assert length is not None  # Guaranteed by the check above; narrows the type.
        if length <= 0:
            raise ValueError(f"Invalid video length: {length}")
        # Keep the inferred rate fractional: truncating it to an int badly
        # distorts it for short clips, and can make it zero.
        src_fps = pre_frames / length
    if src_fps <= 0:
        raise ValueError(f"Invalid video FPS: {src_fps}")
    post_frames = int(pre_frames * (target_fps / src_fps))
    if pre_frames > 0:
        # A non-empty video always keeps at least one frame.
        post_frames = max(post_frames, 1)
    if isinstance(video, list):
        frame_ids = torch.linspace(0, pre_frames - 1, post_frames).long()
        return torch.stack([video[i] for i in frame_ids.tolist()], dim=stack_dim)
    frame_ids = torch.linspace(0, pre_frames - 1, post_frames, device=video.device).long()
    return video[frame_ids]
[docs]def standardize_point_cloud(value: Tensor, max_points: int, *, log_key: str | None) -> Tensor:
for i in range(0, value.ndim - 1):
if value.shape[i] == 3:
value = value.transpose(i, -1)
break
if value.shape[-1] != 3:
raise ValueError(f"Invalid point cloud shape{'' if log_key is None else f' for {log_key}'}: {value.shape}")
if value.ndim == 2:
value = value.unsqueeze(0)
elif value.ndim > 3:
value = value.flatten(1, -2)
if value.shape[1] > max_points:
indices = torch.multinomial(torch.ones(value.shape[1], device=value.device), max_points)
value = value[:, indices]
return value
def make_square_image_or_video(
    images_or_videos: Tensor,
    *,
    sep: int = 0,
    squareness_weight: float = 1.0,
    emptiness_weight: float = 1.0,
) -> Tensor:
    """Tiles a batch of images (or videos) into one roughly square grid.

    The number of grid rows is chosen by a ternary search over a penalty that
    combines how far the grid is from square and how many cells are left
    empty. Missing cells are filled with zeros.

    Args:
        images_or_videos: The images tensor, with shape (B, C, H, W) or
            (B, T, C, H, W)
        sep: Number of zero pixels between adjacent tiles
        squareness_weight: Weight for number of non-square pixels in penalty
        emptiness_weight: Weight for number of empty pixels in penalty

    Returns:
        The square image, with shape (C, H', W') or (T, C, H', W')
    """
    assert images_or_videos.dim() in (4, 5)

    tile_h, tile_w = images_or_videos.shape[-2:]
    tiles = list(images_or_videos.unbind(0))
    total = len(tiles)

    def cols_for(rows: int) -> int:
        # Columns needed to fit ``total`` tiles into ``rows`` rows.
        return (total + rows - 1) // rows

    def penalty(rows: int) -> float:
        grid_h, grid_w = rows * tile_h, cols_for(rows) * tile_w
        area = grid_h * grid_w
        non_square = area - min(grid_h, grid_w) ** 2
        empty = area - tile_h * tile_w * total
        return non_square * squareness_weight + empty * emptiness_weight

    # Ternary search for the row count, narrowing until at most 3 candidates remain.
    lo, hi = 1, total
    while hi - lo > 2:
        third_lo, third_hi = (2 * lo + hi) // 3, (lo + 2 * hi) // 3
        if penalty(third_lo) > penalty(third_hi):
            lo = third_lo
        else:
            hi = third_hi

    # Pick the cheapest of the remaining candidates, preferring the middle one on ties.
    mid = (lo + hi) // 2
    p_lo, p_mid, p_hi = penalty(lo), penalty(mid), penalty(hi)
    if p_mid <= p_lo and p_mid <= p_hi:
        rows = mid
    elif p_lo <= p_hi:
        rows = lo
    else:
        rows = hi
    cols = cols_for(rows)

    blank = torch.zeros_like(images_or_videos[0])
    tiles.extend([blank] * (rows * cols - total))

    before, after = sep // 2, (sep + 1) // 2
    padded = [F.pad(tile, (before, after, before, after)) for tile in tiles]
    grid_rows = [torch.cat(padded[start : start + cols], dim=-1) for start in range(0, len(padded), cols)]
    grid = torch.cat(grid_rows, dim=-2)
    # Trim the outer padding so ``sep`` only appears between tiles.
    return grid[..., before : grid.shape[-2] - after, before : grid.shape[-1] - after]
def cast_fp32(value: T) -> T:
    """Returns floating-point tensors detached, as float32, on the CPU.

    Non-tensor values and non-floating-point tensors are returned unchanged.
    """
    if not isinstance(value, Tensor) or not value.is_floating_point():
        return value
    return value.detach().float().cpu()  # type: ignore[return-value]
NAMESPACE_STACK: list[str] = []
[docs]class namespace_context: # noqa: N801
def __init__(self, name: str | None) -> None:
self._name = name
self._prev_stack: list[str] | None = None
def __enter__(self) -> None:
if self._name is None:
self._prev_stack = NAMESPACE_STACK[:]
NAMESPACE_STACK.clear()
else:
NAMESPACE_STACK.append(self._name)
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
if self._prev_stack is not None:
NAMESPACE_STACK[:] = self._prev_stack
else:
NAMESPACE_STACK.pop()
[docs]class MultiLogger:
"""Defines an intermediate container which holds values to log somewhere else."""
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
    """Creates empty registries for every kind of loggable value.

    Each registry maps namespace -> key -> a zero-argument callable that
    produces the value when it is eventually written out.

    Args:
        default_namespace: Namespace used by ``log_*`` calls that don't pass one
    """
    self.default_namespace = default_namespace
    self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
    self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
    self.images: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
    self.audio: dict[str, dict[str, Callable[[], tuple[Tensor, int]]]] = defaultdict(dict)
    self.videos: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
    self.histograms: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
    self.point_clouds: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
    self.poses: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
def resolve_namespace(self, namespace: str | None = None) -> str:
    """Builds the full namespace name for a logging call.

    The explicit ``namespace`` (or ``self.default_namespace`` when it is
    ``None``) is joined with the entries of ``NAMESPACE_STACK`` using ``_``.
    """
    base = self.default_namespace if namespace is None else namespace
    return "_".join([base, *NAMESPACE_STACK])
def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
    """Registers a scalar value, evaluated lazily (at most once) when read.

    Args:
        key: The key being logged
        value: The scalar value being logged, or a zero-argument callable
            returning it
        namespace: An optional logging namespace
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def scalar_future() -> Number:
        concrete = value() if callable(value) else value
        assert isinstance(concrete, (int, float, Tensor))
        # Floating-point tensors are detached and moved to CPU float32.
        return cast_fp32(concrete)

    self.scalars[resolved_ns][key] = scalar_future
def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
    """Registers a string value, evaluated lazily (at most once) when read.

    Args:
        key: The key being logged
        value: The string value being logged, or a zero-argument callable
            returning it
        namespace: An optional logging namespace
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def value_future() -> str:
        text = value() if callable(value) else value
        assert isinstance(text, str)
        return text

    self.strings[resolved_ns][key] = value_future
def log_image(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    keep_resolution: bool = False,
) -> None:
    """Registers an image, standardized lazily (at most once) when read.

    Args:
        key: The key being logged
        value: The image being logged (or a zero-argument callable returning
            it); can be (C, H, W), (H, W, C) or (H, W) as an RGB (3 channel)
            or grayscale (1 channel) image
        namespace: An optional logging namespace
        keep_resolution: If set, keep the image resolution the same,
            otherwise upscale or downscale the image to a standard
            resolution
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def image_future() -> Tensor:
        raw = value() if callable(value) else value
        assert isinstance(raw, Tensor)
        # Detach / move to CPU float32 first, then convert to (C, H, W).
        return standardize_image(cast_fp32(raw), log_key=f"{resolved_ns}/{key}", keep_resolution=keep_resolution)

    self.images[resolved_ns][key] = image_future
def log_labeled_image(
    self,
    key: str,
    value: Callable[[], tuple[Tensor, str]] | tuple[Tensor, str],
    *,
    namespace: str | None = None,
    max_line_length: int | None = None,
    keep_resolution: bool = False,
    centered: bool = True,
) -> None:
    """Registers an image with a text label rendered underneath it.

    The image is standardized and labelled lazily (at most once) when read.

    Args:
        key: The key being logged
        value: The image and label being logged (or a zero-argument callable
            returning them); the image can be (C, H, W), (H, W, C) or (H, W)
            as an RGB (3 channel) or grayscale (1 channel) image
        namespace: An optional logging namespace
        max_line_length: Labels longer than this length are wrapped around
        keep_resolution: If set, keep the image resolution the same,
            otherwise upscale or downscale the image to a standard
            resolution
        centered: If set, center the text labels, otherwise align to the
            left
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def labeled_image_future() -> Tensor:
        pair = value() if callable(value) else value
        image, label = pair
        assert isinstance(image, Tensor)
        assert isinstance(label, str)
        standardized = standardize_image(image, log_key=f"{resolved_ns}/{key}", keep_resolution=keep_resolution)
        lines = standardize_text(label, max_line_length=max_line_length, remove_non_ascii=True)
        return image_with_text(cast_fp32(standardized), lines, centered=centered)

    self.images[resolved_ns][key] = labeled_image_future
def log_images(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    keep_resolution: bool = False,
    max_images: int | None = None,
    sep: int = 0,
) -> None:
    """Registers a batch of images, tiled into a nearly-square grid when read.

    Args:
        key: The key being logged
        value: The images being logged (or a zero-argument callable returning
            them); can be (B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB
            (3 channel) or grayscale (1 channel) image
        namespace: An optional logging namespace
        keep_resolution: If set, keep the image resolution the same,
            otherwise upscale or downscale the image to a standard
            resolution
        max_images: The maximum number of images to show; extra images
            are clipped
        sep: An optional separation amount between adjacent images
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def images_future() -> Tensor:
        batch = value() if callable(value) else value
        assert isinstance(batch, Tensor)
        batch, _ = standardize_images(
            batch,
            None,
            max_images=max_images,
            log_key=f"{resolved_ns}/{key}",
            keep_resolution=keep_resolution,
        )
        return make_square_image_or_video(cast_fp32(batch), sep=sep)

    self.images[resolved_ns][key] = images_future
def log_labeled_images(
    self,
    key: str,
    value: Callable[[], tuple[Tensor, Sequence[str]]] | tuple[Tensor, Sequence[str]],
    *,
    namespace: str | None = None,
    max_line_length: int | None = None,
    keep_resolution: bool = False,
    max_images: int | None = None,
    sep: int = 0,
    centered: bool = True,
) -> None:
    """Logs a set of images with labels.

    The images are tiled to be nearly-square. Evaluation is lazy and happens
    at most once.

    Args:
        key: The key being logged
        value: The images and labels being logged (or a zero-argument callable
            returning them); images can be (B, C, H, W), (B, H, W, C) or
            (B, H, W) as an RGB (3 channel) or grayscale (1 channel) image,
            with exactly B labels
        namespace: An optional logging namespace
        max_line_length: Labels longer than this length are wrapped around
        keep_resolution: If set, keep the image resolution the same,
            otherwise upscale or downscale the image to a standard
            resolution
        max_images: The maximum number of images to show; extra images
            are clipped
        sep: An optional separation amount between adjacent images
        centered: If set, center the text labels, otherwise align to the
            left
    """
    namespace = self.resolve_namespace(namespace)

    @functools.lru_cache
    def labeled_images_future() -> Tensor:
        images, texts = value() if callable(value) else value
        assert isinstance(images, Tensor)
        assert images.shape[0] == len(texts)
        images, texts = standardize_images(
            images,
            texts,
            max_images=max_images,
            log_key=f"{namespace}/{key}",
            keep_resolution=keep_resolution,
        )
        # Detach / move to CPU float32 before rendering, like the other image
        # loggers, so text rendering never sees device or low-precision tensors.
        images = cast_fp32(images)
        text_lists = [standardize_text(text, max_line_length, remove_non_ascii=True) for text in texts]
        # Reserve the same number of label lines for every image so they all
        # end up the same height and can be stacked.
        max_num_lines = max(len(text_list) for text_list in text_lists)
        labeled_images = torch.stack(
            [
                image_with_text(image, text_list, max_num_lines=max_num_lines, centered=centered)
                for image, text_list in zip(images, text_lists)
            ],
            dim=0,
        )
        return make_square_image_or_video(labeled_images, sep=sep)

    self.images[namespace][key] = labeled_images_future
def log_audio(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    sample_rate: int = 44100,
    log_spec: bool = True,
    n_fft_ms: float = 32.0,
    hop_length_ms: float | None = None,
    channel_select_mode: ChannelSelectMode = "first",
    keep_resolution: bool = False,
) -> None:
    """Logs an audio clip, and optionally its spectrogram.

    Args:
        key: The key being logged
        value: The audio clip being logged (or a zero-argument callable
            returning it); can be (C, T) or (T) as a mono (1 channel) or
            stereo (2 channel) audio clip
        namespace: An optional logging namespace
        sample_rate: The sample rate of the audio clip
        log_spec: If set, also log the spectrogram under ``f"{key}_spec"``
        n_fft_ms: FFT size, in milliseconds
        hop_length_ms: The FFT hop length, in milliseconds
        channel_select_mode: How to select the channel if the audio is
            stereo; can be "first", "last", or "mean"; this is only used
            for the spectrogram
        keep_resolution: If set, keep the resolution of the
            spectrogram; otherwise, make human-viewable
    """
    resolved_namespace = self.resolve_namespace(namespace)

    # Shared by the audio and spectrogram futures so ``value`` is evaluated once.
    @functools.lru_cache
    def raw_audio_future() -> Tensor:
        value_concrete = value() if callable(value) else value
        assert isinstance(value_concrete, Tensor)
        return value_concrete

    @functools.lru_cache
    def audio_future() -> tuple[Tensor, int]:
        audio = standardize_audio(raw_audio_future(), log_key=f"{resolved_namespace}/{key}")
        audio = cast_fp32(audio)
        return audio, sample_rate

    self.audio[resolved_namespace][key] = audio_future

    if log_spec:
        # Using a unique key for the spectrogram is very important because
        # otherwise Tensorboard will have some issues.
        # Pass the caller's *unresolved* namespace: ``log_spectrogram``
        # resolves it itself, and resolving twice would append the
        # namespace-context stack again, filing the spectrogram under a
        # different namespace than the audio.
        self.log_spectrogram(
            key=f"{key}_spec",
            value=raw_audio_future,
            namespace=namespace,
            sample_rate=sample_rate,
            n_fft_ms=n_fft_ms,
            hop_length_ms=hop_length_ms,
            channel_select_mode=channel_select_mode,
            keep_resolution=keep_resolution,
        )
def log_audios(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    sep_ms: float = 0.0,
    max_audios: int | None = None,
    sample_rate: int = 44100,
    log_spec: bool = True,
    n_fft_ms: float = 32.0,
    hop_length_ms: float | None = None,
    channel_select_mode: ChannelSelectMode = "first",
    spec_sep: int = 0,
    keep_resolution: bool = False,
) -> None:
    """Logs multiple audio clips, concatenated into one, and optionally their spectrograms.

    Args:
        key: The key being logged
        value: The audio clips being logged (or a zero-argument callable
            returning them); can be (B, C, T) or (B, T) as a mono (1 channel)
            or stereo (2 channel) audio clip, with exactly B clips
        namespace: An optional logging namespace
        sep_ms: An optional separation amount between adjacent audio clips;
            non-zero values are rounded to a power-of-two number of samples
        max_audios: An optional maximum number of audio clips to log
        sample_rate: The sample rate of the audio clip
        log_spec: If set, also log the spectrograms under ``f"{key}_spec"``
        n_fft_ms: FFT size, in milliseconds
        hop_length_ms: The FFT hop length, in milliseconds
        channel_select_mode: How to select the channel if the audio is
            stereo; can be "first", "last", or "mean"; this is only used
            for the spectrogram
        spec_sep: An optional separation amount between adjacent
            spectrograms
        keep_resolution: If set, keep the resolution of the
            spectrogram; otherwise, make human-viewable
    """
    resolved_namespace = self.resolve_namespace(namespace)

    # Shared by the audio and spectrogram futures so ``value`` is evaluated once.
    @functools.lru_cache
    def raw_audio_future() -> Tensor:
        value_concrete = value() if callable(value) else value
        assert isinstance(value_concrete, Tensor)
        return value_concrete

    @functools.lru_cache
    def audio_future() -> tuple[Tensor, int]:
        audio = standardize_audios(raw_audio_future(), log_key=f"{resolved_namespace}/{key}", max_audios=max_audios)
        audio = cast_fp32(audio)

        def to_frames(ms: float) -> int:
            # 0 ms means no gap; otherwise round to a power-of-two sample count.
            return 0 if ms == 0.0 else 2 ** round(math.log2(ms * sample_rate / 1000))

        audio = separate_with_padding(audio, to_frames(sep_ms))
        return audio, sample_rate

    self.audio[resolved_namespace][key] = audio_future

    if log_spec:
        # Using a unique key for the spectrogram is very important because
        # otherwise Tensorboard will have some issues.
        # Pass the caller's *unresolved* namespace: ``log_spectrograms``
        # resolves it itself, and resolving twice would append the
        # namespace-context stack again.
        self.log_spectrograms(
            key=f"{key}_spec",
            value=raw_audio_future,
            namespace=namespace,
            max_audios=max_audios,
            sample_rate=sample_rate,
            n_fft_ms=n_fft_ms,
            hop_length_ms=hop_length_ms,
            channel_select_mode=channel_select_mode,
            spec_sep=spec_sep,
            keep_resolution=keep_resolution,
        )
def log_spectrogram(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    sample_rate: int = 44100,
    n_fft_ms: float = 32.0,
    hop_length_ms: float | None = None,
    channel_select_mode: ChannelSelectMode = "first",
    keep_resolution: bool = False,
) -> None:
    """Registers the log-magnitude spectrogram of an audio clip as an image.

    The spectrogram is computed lazily (at most once) when read, with
    ``torch.stft(..., normalized=True)`` and ``log10(|X| + 1e-6)``.

    Args:
        key: The key being logged
        value: The audio clip being logged (or a zero-argument callable
            returning it); can be (C, T) or (T) as a mono (1 channel) or
            stereo (2 channel) audio clip
        namespace: An optional logging namespace
        sample_rate: The sample rate of the audio clip
        n_fft_ms: FFT size, in milliseconds (rounded to a power-of-two
            number of samples)
        hop_length_ms: The FFT hop length, in milliseconds (rounded the same
            way); ``None`` uses the ``torch.stft`` default
        channel_select_mode: How to select the channel if the audio is
            stereo; can be "first", "last", or "mean"
        keep_resolution: If set, keep the resolution of the
            spectrogram; otherwise, make human-viewable
    """
    resolved_ns = self.resolve_namespace(namespace)
    log_key = f"{resolved_ns}/{key}"

    def ms_to_frames(ms: float) -> int:
        # Nearest power-of-two number of samples for a duration in milliseconds.
        return 2 ** round(math.log2(ms * sample_rate / 1000))

    @functools.lru_cache
    def spec_future() -> Tensor:
        raw = value() if callable(value) else value
        mono = get_audio_channel(standardize_audio(raw, log_key=log_key), channel_select_mode)
        n_fft = ms_to_frames(n_fft_ms)
        hop_length = None if hop_length_ms is None else ms_to_frames(hop_length_ms)
        spec = torch.stft(mono.to(torch.float32), n_fft, hop_length=hop_length, normalized=True, return_complex=True)
        log_magnitude = torch.log10(torch.abs(spec) + 1e-6)
        return standardize_image(log_magnitude, log_key=log_key, keep_resolution=keep_resolution)

    self.images[resolved_ns][key] = spec_future
def log_spectrograms(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    max_audios: int | None = None,
    sample_rate: int = 44100,
    n_fft_ms: float = 32.0,
    hop_length_ms: float | None = None,
    channel_select_mode: ChannelSelectMode = "first",
    spec_sep: int = 0,
    keep_resolution: bool = False,
) -> None:
    """Registers the log-magnitude spectrograms of a batch of clips, tiled into one image.

    The spectrograms are computed lazily (at most once) when read, with
    ``torch.stft(..., normalized=True)`` and ``log10(|X| + 1e-6)``.

    Args:
        key: The key being logged
        value: The audio clips being logged (or a zero-argument callable
            returning them); can be (B, C, T) or (B, T) as a mono (1 channel)
            or stereo (2 channel) audio clip, with exactly B clips
        namespace: An optional logging namespace
        max_audios: An optional maximum number of audio clips to log
        sample_rate: The sample rate of the audio clip
        n_fft_ms: FFT size, in milliseconds (rounded to a power-of-two
            number of samples)
        hop_length_ms: The FFT hop length, in milliseconds (rounded the same
            way); ``None`` uses the ``torch.stft`` default
        channel_select_mode: How to select the channel if the audio is
            stereo; can be "first", "last", or "mean"
        spec_sep: An optional separation amount between adjacent
            spectrograms
        keep_resolution: If set, keep the resolution of the
            spectrogram; otherwise, make human-viewable
    """
    resolved_ns = self.resolve_namespace(namespace)
    log_key = f"{resolved_ns}/{key}"

    def ms_to_frames(ms: float) -> int:
        # Nearest power-of-two number of samples for a duration in milliseconds.
        return 2 ** round(math.log2(ms * sample_rate / 1000))

    @functools.lru_cache
    def spec_future() -> Tensor:
        raw = value() if callable(value) else value
        clips = standardize_audios(raw, log_key=log_key, max_audios=max_audios)
        mono = get_audio_channel(clips, channel_select_mode)
        n_fft = ms_to_frames(n_fft_ms)
        hop_length = None if hop_length_ms is None else ms_to_frames(hop_length_ms)
        spec = torch.stft(mono.to(torch.float32), n_fft, hop_length=hop_length, normalized=True, return_complex=True)
        log_magnitude = torch.log10(torch.abs(spec) + 1e-6)
        tiles, _ = standardize_images(log_magnitude, None, log_key=log_key, keep_resolution=keep_resolution)
        return make_square_image_or_video(tiles, sep=spec_sep)

    self.images[resolved_ns][key] = spec_future
[docs] def log_video(
self,
key: str,
value: Callable[[], Tensor] | Tensor,
*,
namespace: str | None = None,
fps: int | None = None,
length: float | None = None,
) -> None:
"""Logs a video.
Args:
key: The key being logged
value: The video being logged; the video can be (T, C, H, W),
(T, H, W, C) or (T, H, W) as an RGB (3 channel) or grayscale
(1 channel) video
namespace: An optional logging namespace
fps: The video frames per second
length: The desired video length, in seconds, at the target FPS
"""
namespace = self.resolve_namespace(namespace)
@functools.lru_cache
def video_future() -> Tensor:
value_concrete = value() if callable(value) else value
assert isinstance(value_concrete, Tensor)
video = standardize_video(value_concrete, log_key=f"{namespace}/{key}")
value_concrete = cast_fp32(value_concrete)
return normalize_video_fps(video, fps, length)
self.videos[namespace][key] = video_future
[docs] def log_videos(
self,
key: str,
value: Callable[[], Tensor | list[Tensor]] | Tensor | list[Tensor],
*,
namespace: str | None = None,
max_videos: int | None = None,
sep: int = 0,
fps: int | None = None,
length: int | None = None,
) -> None:
"""Logs a set of video.
Args:
key: The key being logged
value: The videos being logged; the video can be (B, T, C, H, W),
(B, T, H, W, C) or (B T, H, W) as an RGB (3 channel) or
grayscale (1 channel) video
namespace: An optional logging namespace
max_videos: The maximum number of videos to show; extra images
are clipped
sep: An optional separation amount between adjacent videos
fps: The video frames per second
length: The desired video length, in seconds, at the target FPS
"""
namespace = self.resolve_namespace(namespace)
@functools.lru_cache
def videos_future() -> Tensor:
value_concrete = value() if callable(value) else value
assert isinstance(value_concrete, (Tensor, list))
video = normalize_video_fps(value_concrete, fps, length, stack_dim=1)
video = standardize_videos(video, max_videos=max_videos, log_key=f"{namespace}/{key}")
value_concrete = cast_fp32(value_concrete)
return make_square_image_or_video(video, sep=sep)
self.videos[namespace][key] = videos_future
[docs] def log_histogram(self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None) -> None:
"""Logs a histogram.
Args:
key: The key being logged
value: The values to create a histogram from, with arbitrary shape
namespace: An optional logging namespace
"""
namespace = self.resolve_namespace(namespace)
@functools.lru_cache
def histogram_future() -> Tensor:
value_concrete = value() if callable(value) else value
assert isinstance(value_concrete, Tensor)
value_concrete = cast_fp32(value_concrete)
return value_concrete
self.histograms[namespace][key] = histogram_future
[docs] def log_point_cloud(
self,
key: str,
value: Callable[[], Tensor] | Tensor,
*,
namespace: str | None = None,
max_points: int = 1000,
) -> None:
"""Logs a point cloud.
Args:
key: The key being logged
value: The point cloud values, with shape (N, 3) or (B, ..., 3);
can pass multiple batches in order to show multiple point
clouds
namespace: An optional logging namespace
max_points: An optional maximum number of points in the point cloud
"""
namespace = self.resolve_namespace(namespace)
@functools.lru_cache
def point_cloud_future() -> Tensor:
value_concrete = value() if callable(value) else value
assert isinstance(value_concrete, Tensor)
value_concrete = cast_fp32(value_concrete)
return standardize_point_cloud(value_concrete, max_points, log_key=f"{namespace}/{key}")
self.point_clouds[namespace][key] = point_cloud_future
[docs] def write_dict(
self,
loggers: list[BaseLogger],
values: dict[str, dict[str, Callable[[], LogT]]],
state: State,
func: Callable[[BaseLogger], Callable[[str, Callable[[], LogT], State, str], None]],
) -> None:
for logger in loggers:
for namespace, value in values.items():
for key, log_value in value.items():
func(logger)(key, log_value, state, namespace)
values.clear()
[docs] def write(self, loggers: list[BaseLogger], state: State) -> None:
self.write_dict(loggers, self.scalars, state, lambda logger: logger.log_scalar)
self.write_dict(loggers, self.strings, state, lambda logger: logger.log_string)
self.write_dict(loggers, self.images, state, lambda logger: logger.log_image)
self.write_dict(loggers, self.audio, state, lambda logger: logger.log_audio)
self.write_dict(loggers, self.videos, state, lambda logger: logger.log_video)
self.write_dict(loggers, self.histograms, state, lambda logger: logger.log_histogram)
self.write_dict(loggers, self.point_clouds, state, lambda logger: logger.log_point_cloud)