# Source code for ml.loggers.multi

"""Defines a general logger for munging logged values to an expected format.

This logger handles munging, rate limiting, and multiplexing logged values
to each of the implemented child loggers. It is the logging interface that
is exposed to the task and model.
"""

import functools
import logging
import math
import re
from collections import defaultdict
from types import TracebackType
from typing import Callable, Iterator, Literal, Sequence, TypeVar

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as V
from PIL import Image, ImageDraw, ImageFont
from torch import Tensor
from torchvision.transforms import InterpolationMode

from ml.core.state import State
from ml.loggers.base import BaseLogger
from ml.utils.logging import IntervalTicker

# Module logger; used below for the rate-limited audio clipping warning.
logger = logging.getLogger(__name__)

# Generic type variable (used by ``cast_fp32`` to return the input's type).
T = TypeVar("T")
# NOTE(review): ``LogT`` is not referenced in this part of the file; it may be used further down.
LogT = TypeVar("LogT")
# Scalar types accepted by ``MultiLogger.log_scalar``.
Number = int | float | Tensor

# How to reduce multi-channel audio to a single channel before computing a spectrogram.
ChannelSelectMode = Literal["first", "last", "mean"]

# Channel counts accepted for images and videos (grayscale or RGB).
VALID_VIDEO_CHANNEL_COUNTS = {1, 3}
# Channel counts accepted for audio (mono or stereo).
VALID_AUDIO_CHANNEL_COUNTS = {1, 2}
# Default frame rate that ``normalize_video_fps`` resamples videos to.
TARGET_FPS = 12
# Namespace used by ``MultiLogger`` when a ``log_*`` call does not pass one.
DEFAULT_NAMESPACE = "value"


def _aminmax(t: Tensor) -> tuple[Tensor, Tensor]:
    # `aminmax` isn't supported for MPS tensors, fall back to separate calls.
    minv, maxv = (t.min(), t.max()) if t.is_mps else tuple(t.aminmax())
    return minv, maxv


def _chunk_lines(text: str, max_length: int) -> Iterator[str]:
    for i in range(0, len(text), max_length):
        yield text[i : i + max_length]


[docs]def standardize_text(text: str, max_line_length: int | None = None, remove_non_ascii: bool = False) -> list[str]: """Standardizes a text string to a list of lines. Args: text: The text to standardize max_line_length: If set, truncate lines to this length remove_non_ascii: Remove non-ASCII characters if present Returns: The standardized text lines """ if remove_non_ascii: text = "".join(char for char in text if ord(char) < 128) lines = [re.sub(r"\s+", " ", line) for line in re.split(r"[\n\r]+", text.strip())] if max_line_length is not None: lines = [subline for line in lines for subline in _chunk_lines(line, max_line_length)] return lines
def get_audio_channel(audio: Tensor, channel_select_mode: ChannelSelectMode) -> Tensor:
    """For stereo audio, selects a single channel.

    Args:
        audio: The audio tensor to select a channel from, with shape (C, L),
            or batched (..., C, L); the channel axis is the second-to-last one
        channel_select_mode: The channel selection mode; "first" / "last"
            pick that channel, "mean" averages over the channel axis

    Returns:
        The selected audio channel, with the channel axis removed

    Raises:
        ValueError: If the audio shape or channel selection mode is invalid
    """
    if audio.ndim < 2:
        raise ValueError(f"Invalid audio shape: {audio.shape}")
    num_channels = audio.shape[-2]
    if num_channels not in VALID_AUDIO_CHANNEL_COUNTS:
        # Report the axis that was actually checked; for batched (B, C, L)
        # input ``shape[0]`` is the batch size, not the channel count.
        raise ValueError(f"Invalid audio channel count: {num_channels}")
    if channel_select_mode == "first":
        return audio[..., 0, :]
    if channel_select_mode == "last":
        return audio[..., -1, :]
    if channel_select_mode == "mean":
        return audio.mean(dim=-2)
    raise ValueError(f"Invalid channel select mode: {channel_select_mode}")
def make_human_viewable_resolution(
    image: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    trg_res: tuple[int, int] = (250, 250),
) -> Tensor:
    """Resizes image to human-viewable resolution.

    The aspect ratio is preserved; only the total pixel area is matched to
    that of ``trg_res``.

    Args:
        image: The image to resize, with shape (C, H, W)
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution; the image will be reshaped to
            have approximately the same area as an image with this resolution

    Returns:
        The resized image; every output side is at least one pixel

    Raises:
        ValueError: If the image has a zero-sized spatial dimension
    """
    width, height = V.get_image_size(image)
    if width <= 0 or height <= 0:
        raise ValueError(f"Cannot resize an empty image of size {width}x{height}")
    trg_height, trg_width = trg_res
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Very elongated images would otherwise round their short side down to
    # zero pixels, which ``resize`` rejects.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return V.resize(image, [new_height, new_width], interpolation)
[docs]def standardize_image( image: Tensor, *, log_key: str | None = None, normalize: bool = True, keep_resolution: bool = False, ) -> Tensor: """Converts an arbitrary image to shape (C, H, W). Args: image: The image tensor to log log_key: An optional logging key to use in the exception message normalize: Normalize images to (0, 1) keep_resolution: If set, preserve original image resolution, otherwise change image resolution to human-viewable Returns: The normalized image, with shape (C, H, W) Raises: ValueError: If the image shape is invalid """ if normalize and image.is_floating_point(): minv, maxv = _aminmax(image) maxv.clamp_min_(1.0) minv.clamp_max_(0.0) image = torch.clamp((image.detach() - minv) / (maxv - minv), 0.0, 1.0) if image.ndim == 2: image = image.unsqueeze(0) elif image.ndim == 3: if image.shape[0] in VALID_VIDEO_CHANNEL_COUNTS: pass elif image.shape[2] in VALID_VIDEO_CHANNEL_COUNTS: image = image.permute(2, 0, 1) else: raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {image.shape}") else: raise ValueError(f"Invalid image shape{'' if log_key is None else f' for {log_key}'}: {image.shape}") if not keep_resolution: image = make_human_viewable_resolution(image) return image
LabelT = TypeVar("LabelT", Sequence[str], None) def _get_n_samples(t: Tensor, labels: LabelT, n: int, dim: int = 0) -> tuple[Tensor, LabelT]: if t.shape[dim] <= n: return t, labels idxs = torch.linspace(0, t.shape[dim] - 1, n, device=t.device, dtype=t.dtype) idxs = idxs.round().long().clamp(0, t.shape[dim] - 1) t = torch.index_select(t, dim, idxs) if labels is None: return t, None return t, [labels[i] for i in idxs.cpu().tolist()]
def standardize_images(
    images: Tensor,
    labels: LabelT,
    *,
    max_images: int | None = None,
    log_key: str | None = None,
    normalize: bool = True,
    keep_resolution: bool = False,
) -> tuple[Tensor, LabelT]:
    """Reshapes a batch of images to channels-first (B, C, H, W).

    Args:
        images: The image tensor to log; (B, C, H, W), (B, H, W, C) or (B, H, W)
        labels: The labels for the images (or None)
        max_images: Maximum number of images to select; evenly-spaced images
            (and their labels) are kept when there are more
        log_key: An optional logging key to use in the exception message
        normalize: Rescale floating-point images into (0, 1) using the range
            of the whole batch, widened to include [0, 1]
        keep_resolution: If set, preserve original image resolution,
            otherwise change image resolution to human-viewable

    Returns:
        The normalized images, with shape (B, C, H, W), and matching labels

    Raises:
        ValueError: If the image shape is invalid
    """
    where = "" if log_key is None else f" for {log_key}"

    if normalize and images.is_floating_point():
        low, high = _aminmax(images)
        low = low.clamp(max=0.0)
        high = high.clamp(min=1.0)
        images = ((images.detach() - low) / (high - low)).clamp(0.0, 1.0)

    if images.ndim == 3:
        # Grayscale (B, H, W): add a channel axis.
        images = images.unsqueeze(1)
    elif images.ndim != 4:
        raise ValueError(f"Invalid image shape{where}: {images.shape}")
    elif images.shape[1] not in VALID_VIDEO_CHANNEL_COUNTS:
        if images.shape[3] not in VALID_VIDEO_CHANNEL_COUNTS:
            raise ValueError(f"Invalid channel count{where}: {images.shape}")
        # Channels-last (B, H, W, C) -> (B, C, H, W).
        images = images.permute(0, 3, 1, 2)

    if max_images is not None:
        images, labels = _get_n_samples(images, labels, max_images, dim=0)

    if not keep_resolution:
        resized = [make_human_viewable_resolution(image) for image in images.unbind(0)]
        images = torch.stack(resized, 0)

    return images, labels
@functools.lru_cache
def audio_warning_ticker() -> IntervalTicker:
    """Returns the shared ticker that rate-limits the audio clipping warning.

    Memoized, so every call returns the same ``IntervalTicker`` instance. It is
    constructed with 5.0 (presumably the minimum interval in seconds between
    warnings; confirm in ``ml.utils.logging``).
    """
    ticker = IntervalTicker(5.0)
    return ticker
[docs]def standardize_audio(audio: Tensor, *, log_key: str | None = None) -> Tensor: """Converts an arbitrary audio tensor to shape (C, T). Args: audio: The audio tensor to log log_key: An optional logging key to use in the exception message Returns: The standardized audio tensor, with shape (C, T) Raises: ValueError: If the audio shape is invalid """ if audio.ndim == 1: audio = audio.unsqueeze(0) elif audio.ndim == 2: if audio.shape[0] in VALID_AUDIO_CHANNEL_COUNTS: pass elif audio.shape[1] in VALID_AUDIO_CHANNEL_COUNTS: audio = audio.permute(1, 0) else: raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {audio.shape}") else: raise ValueError(f"Invalid audio shape{'' if log_key is None else f' for {log_key}'}: {audio.shape}") max_abs = audio.abs().max() if max_abs > 1.0: if audio_warning_ticker().tick(): logger.warning("Audio is outside the range [-1, 1]; clipping") audio = audio.clamp(-5e3, 5e3) / max_abs return audio
[docs]def standardize_audios(audios: Tensor, *, log_key: str | None = None, max_audios: int | None = None) -> Tensor: """Converts an arbitrary audio tensor to shape (B, C, T). Args: audios: The audio tensor to log log_key: An optional logging key to use in the exception message max_audios: Maximum number of audios to select Returns: The standardized audio tensor, with shape (B, C, T) Raises: ValueError: If the audio shape is invalid """ if audios.ndim == 2: audios = audios.unsqueeze(1) elif audios.ndim == 3: if audios.shape[1] in VALID_AUDIO_CHANNEL_COUNTS: pass elif audios.shape[2] in VALID_AUDIO_CHANNEL_COUNTS: audios = audios.permute(2, 1) else: raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {audios.shape}") else: raise ValueError(f"Invalid audio shape{'' if log_key is None else f' for {log_key}'}: {audios.shape}") if max_audios is not None: audios, _ = _get_n_samples(audios, None, max_audios, dim=0) max_abs = audios.abs().max() if max_abs > 1.0: if audio_warning_ticker().tick(): logger.warning("Audio is outside the range [-1, 1]; clipping") audios = audios.clamp(-5e3, 5e3) / max_abs return audios
def separate_with_padding(audio: Tensor, sep_frames: int) -> Tensor:
    """Converts a (B, C, T) waveform to (C, B * (T + sep_frames) - sep_frames).

    The B clips are laid end-to-end along time, with ``sep_frames`` zeros
    between each adjacent pair (no padding after the last clip).

    Args:
        audio: The audio tensor to separate, with shape (B, C, T)
        sep_frames: Number of frames to insert between each audio tensor

    Returns:
        The separated audio tensor

    Raises:
        ValueError: If the audio shape is invalid or ``sep_frames`` is negative
    """
    # Validate before the fast path so malformed input fails loudly instead of
    # being flattened into the wrong layout.
    if audio.ndim != 3:
        raise ValueError(f"Invalid audio shape: {audio.shape}")
    if sep_frames < 0:
        raise ValueError(f"Invalid separation frame count: {sep_frames}")
    if sep_frames == 0:
        return audio.transpose(0, 1).flatten(1)
    bsz, chans, _ = audio.shape
    gap = audio.new_zeros(bsz, chans, sep_frames)
    # (B, C, T + sep) -> (C, B, T + sep) -> (C, B * (T + sep)); then drop the
    # trailing gap after the final clip.
    padded = torch.cat([audio, gap], dim=-1).transpose(0, 1).flatten(1)
    return padded[:, :-sep_frames].contiguous()
[docs]def standardize_video(video: Tensor, *, log_key: str | None = None, normalize: bool = True) -> Tensor: """Converts an arbitrary video to shape (T, C, H, W). Args: video: The video tensor to log log_key: An optional logging key to use in the exception message normalize: Normalize images to (0, 1) Returns: The normalized video, with shape (T, C, H, W) Raises: ValueError: If the video shape is invalid """ if normalize and video.is_floating_point(): minv, maxv = _aminmax(video[-1]) maxv.clamp_min_(1.0) minv.clamp_max_(0.0) video = torch.clamp((video.detach() - minv) / (maxv - minv), 0.0, 1.0) if video.ndim == 3: return video.unsqueeze(1) if video.ndim == 4: if video.shape[1] in VALID_VIDEO_CHANNEL_COUNTS: return video if video.shape[3] in VALID_VIDEO_CHANNEL_COUNTS: return video.permute(0, 3, 1, 2) raise ValueError(f"Invalid video shape{'' if log_key is None else f' for {log_key}'}: {video.shape}")
[docs]def standardize_videos( videos: Tensor, *, max_videos: int | None = None, log_key: str | None = None, normalize: bool = True, ) -> Tensor: """Converts an arbitrary video to shape (B, T, C, H, W). Args: videos: The video tensor to log max_videos: Maximum number of images to select log_key: An optional logging key to use in the exception message normalize: Normalize images to (0, 1) Returns: The normalized video, with shape (B, T, C, H, W) Raises: ValueError: If the video shape is invalid """ if normalize and videos.is_floating_point(): minv, maxv = _aminmax(videos[:, -1]) maxv.clamp_min_(1.0) minv.clamp_max_(0.0) videos = torch.clamp((videos.detach() - minv) / (maxv - minv), 0.0, 1.0) if videos.ndim == 4: return videos.unsqueeze(2) if videos.ndim == 5: if videos.shape[2] in VALID_VIDEO_CHANNEL_COUNTS: return videos if max_videos is None else videos[:max_videos] if videos.shape[4] in VALID_VIDEO_CHANNEL_COUNTS: videos = videos.permute(0, 3, 1, 2) return videos if max_videos is None else videos[:max_videos] raise ValueError(f"Invalid video shape{'' if log_key is None else f' for {log_key}'}: {videos.shape}")
[docs]def image_with_text( image: Tensor, text: list[str], max_num_lines: int | None = None, line_spacing: int = 4, centered: bool = True, ) -> Tensor: """Adds a text label to an image. Args: image: The image to label, with shape (C, H, W) text: The text label for the image max_num_lines: The number of lines of spacing to add to the bottom of the image line_spacing: The spacing between adjacent lines centered: If set, center the text labels, otherwise align to the left Returns: The image with a text label """ if not text: return image if max_num_lines is None: max_num_lines = len(text) else: text = text[:max_num_lines] pil_image = V.to_pil_image(image) width, height = pil_image.size font: ImageFont.ImageFont = ImageFont.load_default() _, _, _, line_height = font.getbbox(text[0]) new_width, new_height = width, height + line_spacing + max_num_lines * (line_height + line_spacing) padded_image = Image.new(pil_image.mode, (new_width, new_height), 255) padded_image.paste(pil_image, (0, 0)) drawer = ImageDraw.Draw(padded_image) for i, text_line in enumerate(text): text_line_top = height + line_spacing + i * (line_height + line_spacing) if centered: _, _, line_width, _ = font.getbbox(text_line) text_line_left = (width - line_width) / 2 drawer.text((text_line_left, text_line_top), text_line, font=font, fill=0) else: drawer.text((line_spacing, text_line_top), text_line, font=font, fill=0) return V.pil_to_tensor(padded_image)
def normalize_video_fps(
    video: Tensor | list[Tensor],
    fps: int | None,
    length: float | None,
    stack_dim: int = 0,
    target_fps: int = TARGET_FPS,
) -> Tensor:
    """Resamples a video's frames so that it plays back at ``target_fps``.

    Frames are picked at evenly spaced positions truncated to integer indices;
    they are dropped or repeated, never interpolated.

    Args:
        video: The video to normalize, with shape (T, C, H, W), or a list of
            T frame tensors to stack
        fps: The frame rate of the input video; if None, it is inferred from
            ``length``
        length: The duration of the input video, in seconds; only used when
            ``fps`` is None
        stack_dim: Which dimension to stack along, for lists
        target_fps: The target frames per second for the logger

    Returns:
        The normalized video. If both ``fps`` and ``length`` are None, the
        video is returned unchanged (lists are only stacked).

    Raises:
        ValueError: If ``fps`` or ``length`` is not positive
    """
    if fps is None and length is None:
        return torch.stack(video, dim=stack_dim) if isinstance(video, list) else video
    pre_frames = len(video) if isinstance(video, list) else video.size(0)
    if fps is None:
        assert length is not None  # Not used, just for type checker
        if length <= 0:
            raise ValueError(f"Video length must be positive, got {length}")
        # Fewer frames than seconds would otherwise give fps == 0 and a
        # division by zero below.
        fps = max(1, int(pre_frames / length))
    elif fps <= 0:
        raise ValueError(f"Video FPS must be positive, got {fps}")
    post_frames = int(pre_frames * (target_fps / fps))
    if pre_frames > 0:
        # Always keep at least one frame of a non-empty video.
        post_frames = max(1, post_frames)
    if isinstance(video, list):
        frame_ids = torch.linspace(0, pre_frames - 1, post_frames).long()
        return torch.stack([video[i] for i in frame_ids], dim=stack_dim)
    frame_ids = torch.linspace(0, pre_frames - 1, post_frames, device=video.device).long()
    return video[frame_ids]
[docs]def standardize_point_cloud(value: Tensor, max_points: int, *, log_key: str | None) -> Tensor: for i in range(0, value.ndim - 1): if value.shape[i] == 3: value = value.transpose(i, -1) break if value.shape[-1] != 3: raise ValueError(f"Invalid point cloud shape{'' if log_key is None else f' for {log_key}'}: {value.shape}") if value.ndim == 2: value = value.unsqueeze(0) elif value.ndim > 3: value = value.flatten(1, -2) if value.shape[1] > max_points: indices = torch.multinomial(torch.ones(value.shape[1], device=value.device), max_points) value = value[:, indices] return value
def make_square_image_or_video(
    images_or_videos: Tensor,
    *,
    sep: int = 0,
    squareness_weight: float = 1.0,
    emptiness_weight: float = 1.0,
) -> Tensor:
    """Makes a square image by concatenating all the child images.

    This does a simple ternary search to minimize a squareness penalty and an
    emptiness penalty (i.e., the resulting image should be mostly filled in
    and also approximately square). Unused grid cells are filled with zeros.

    Args:
        images_or_videos: The images tensor, with shape (B, C, H, W) or
            (B, T, C, H, W)
        sep: Some optional padding around the images
        squareness_weight: Weight for number of non-square pixels in penalty
        emptiness_weight: Weight for number of empty pixels in penalty

    Returns:
        The square image, with shape (C, H', W') or (T, C, H', W')

    Raises:
        ValueError: If the input is not 4D or 5D, or contains no images
    """
    if images_or_videos.dim() not in (4, 5):
        raise ValueError(f"Expected a (B, C, H, W) or (B, T, C, H, W) tensor, got {images_or_videos.shape}")
    if images_or_videos.shape[0] == 0:
        # With zero tiles the grid search below would divide by zero.
        raise ValueError("Cannot tile an empty batch of images or videos")

    def ternary_search_optimal_side_counts(height: int, width: int, count: int) -> tuple[int, int]:
        """Returns (rows, columns) of the grid used to tile ``count`` images."""
        lo, hi = 1, count

        def columns(rows: int) -> int:
            return (count + rows - 1) // rows

        def squareness_penalty(val: int) -> float:
            h, w = val * height, columns(val) * width
            return (h * w) - min(h, w) ** 2

        def emptiness_penalty(val: int) -> float:
            h, w = val * height, columns(val) * width
            return (h * w) - (height * width * count)

        def penalty(val: int) -> float:
            return squareness_penalty(val) * squareness_weight + emptiness_penalty(val) * emptiness_weight

        # Ternary search over the row count; it is only guaranteed to find the
        # minimum if the penalty is unimodal in the row count.
        while lo < hi - 2:
            lmid, rmid = (lo * 2 + hi) // 3, (lo + hi * 2) // 3
            if penalty(lmid) > penalty(rmid):
                lo = lmid
            else:
                hi = rmid

        # At most three candidates remain; return the lowest-penalty one.
        mid = (lo + hi) // 2
        plo, pmid, phi = penalty(lo), penalty(mid), penalty(hi)
        if pmid <= plo and pmid <= phi:
            return mid, columns(mid)
        if plo <= phi:
            return lo, columns(lo)
        return hi, columns(hi)

    height, width = images_or_videos.shape[-2:]
    image_list = list(torch.unbind(images_or_videos, dim=0))
    hside, wside = ternary_search_optimal_side_counts(height, width, len(image_list))
    image_list = image_list + [torch.zeros_like(images_or_videos[0])] * (hside * wside - len(image_list))
    # Pad every tile by half the separation on each side, then crop the outer
    # border again so separation only appears between tiles.
    a, b = sep // 2, (sep + 1) // 2
    image_list = [F.pad(image, (a, b, a, b)) for image in image_list]
    wconcat = [torch.cat(image_list[i : i + wside], dim=-1) for i in range(0, len(image_list), wside)]
    new_image = torch.cat(wconcat, dim=-2)
    return new_image[..., a : new_image.shape[-2] - b, a : new_image.shape[-1] - b]
def cast_fp32(value: T) -> T:
    """Detaches a floating-point tensor and moves it to CPU as float32.

    Anything else (non-tensors, integer or boolean tensors) is returned as-is.
    """
    if not isinstance(value, Tensor) or not value.is_floating_point():
        return value
    return value.detach().float().cpu()  # type: ignore[return-value]
NAMESPACE_STACK: list[str] = []
[docs]class namespace_context: # noqa: N801 def __init__(self, name: str | None) -> None: self._name = name self._prev_stack: list[str] | None = None def __enter__(self) -> None: if self._name is None: self._prev_stack = NAMESPACE_STACK[:] NAMESPACE_STACK.clear() else: NAMESPACE_STACK.append(self._name) def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None: if self._prev_stack is not None: NAMESPACE_STACK[:] = self._prev_stack else: NAMESPACE_STACK.pop()
class MultiLogger:
    """Defines an intermediate container which holds values to log somewhere else.

    Each ``log_*`` method stores a memoized, zero-argument callable under
    ``<container>[namespace][key]``; the logged value is only materialized
    (standardized, cast, etc.) when that callable is invoked.
    """

    def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
        """Creates empty containers for every kind of loggable value.

        Args:
            default_namespace: Namespace used by ``resolve_namespace`` when a
                ``log_*`` call does not pass one explicitly
        """
        self.default_namespace = default_namespace

        # Every container maps namespace -> key -> lazily-evaluated value.
        self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
        self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
        self.images: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
        # Audio futures return a (waveform, sample_rate) pair.
        self.audio: dict[str, dict[str, Callable[[], tuple[Tensor, int]]]] = defaultdict(dict)
        self.videos: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
        self.histograms: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
        self.point_clouds: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
        self.poses: dict[str, dict[str, Callable[[], Tensor]]] = defaultdict(dict)
[docs] def resolve_namespace(self, namespace: str | None = None) -> str: return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
[docs] def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None: """Logs a scalar value. Args: key: The key being logged value: The scalar value being logged namespace: An optional logging namespace """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def scalar_future() -> Number: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, (int, float, Tensor)) value_concrete = cast_fp32(value_concrete) return value_concrete self.scalars[namespace][key] = scalar_future
[docs] def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None: """Logs a string value. Args: key: The key being logged value: The string value being logged namespace: An optional logging namespace """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def value_future() -> str: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, str) return value_concrete self.strings[namespace][key] = value_future
[docs] def log_image( self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None, keep_resolution: bool = False, ) -> None: """Logs an image. Args: key: The key being logged value: The image being logged; can be (C, H, W), (H, W, C) or (H, W) as an RGB (3 channel) or grayscale (1 channel) image namespace: An optional logging namespace keep_resolution: If set, keep the image resolution the same, otherwise upscale or downscale the image to a standard resolution """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def image_future() -> Tensor: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, Tensor) value_concrete = cast_fp32(value_concrete) return standardize_image(value_concrete, log_key=f"{namespace}/{key}", keep_resolution=keep_resolution) self.images[namespace][key] = image_future
[docs] def log_labeled_image( self, key: str, value: Callable[[], tuple[Tensor, str]] | tuple[Tensor, str], *, namespace: str | None = None, max_line_length: int | None = None, keep_resolution: bool = False, centered: bool = True, ) -> None: """Logs an image with a label. Args: key: The key being logged value: The image and label being logged; the image can be (C, H, W), (H, W, C) or (H, W) as an RGB (3 channel) or grayscale (1 channel) image namespace: An optional logging namespace max_line_length: Labels longer than this length are wrapped around keep_resolution: If set, keep the image resolution the same, otherwise upscale or downscale the image to a standard resolution centered: If set, center the text labels, otherwise align to the left """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def labeled_image_future() -> Tensor: image, text = value() if callable(value) else value assert isinstance(image, Tensor) assert isinstance(text, str) image = standardize_image(image, log_key=f"{namespace}/{key}", keep_resolution=keep_resolution) text_list = standardize_text(text, max_line_length=max_line_length, remove_non_ascii=True) image = cast_fp32(image) return image_with_text(image, text_list, centered=centered) self.images[namespace][key] = labeled_image_future
[docs] def log_images( self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None, keep_resolution: bool = False, max_images: int | None = None, sep: int = 0, ) -> None: """Logs a set of images. The images are tiled to be nearly-square. Args: key: The key being logged value: The images being logged; can be (B, C, H, W), (B, H, W, C) or (B H, W) as an RGB (3 channel) or grayscale (1 channel) image namespace: An optional logging namespace keep_resolution: If set, keep the image resolution the same, otherwise upscale or downscale the image to a standard resolution max_images: The maximum number of images to show; extra images are clipped sep: An optional separation amount between adjacent images """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def images_future() -> Tensor: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, Tensor) value_concrete, _ = standardize_images( value_concrete, None, max_images=max_images, log_key=f"{namespace}/{key}", keep_resolution=keep_resolution, ) value_concrete = cast_fp32(value_concrete) return make_square_image_or_video(value_concrete, sep=sep) self.images[namespace][key] = images_future
[docs] def log_labeled_images( self, key: str, value: Callable[[], tuple[Tensor, Sequence[str]]] | tuple[Tensor, Sequence[str]], *, namespace: str | None = None, max_line_length: int | None = None, keep_resolution: bool = False, max_images: int | None = None, sep: int = 0, centered: bool = True, ) -> None: """Logs a set of images with labels. The images are tiled to be nearly-square. Args: key: The key being logged value: The images and labels being logged; images can be (B, C, H, W), (B, H, W, C) or (B, H, W) as an RGB (3 channel) or grayscale (1 channel) image, with exactly B labels namespace: An optional logging namespace max_line_length: Labels longer than this length are wrapped around keep_resolution: If set, keep the image resolution the same, otherwise upscale or downscale the image to a standard resolution max_images: The maximum number of images to show; extra images are clipped sep: An optional separation amount between adjacent images centered: If set, center the text labels, otherwise align to the left """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def labeled_images_future() -> Tensor: images, texts = value() if callable(value) else value assert isinstance(images, Tensor) assert images.shape[0] == len(texts) images, texts = standardize_images( images, texts, max_images=max_images, log_key=f"{namespace}/{key}", keep_resolution=keep_resolution, ) num_images = len(images) text_lists = [standardize_text(text, max_line_length, remove_non_ascii=True) for text in texts] max_num_lines = max(len(text_list) for text_list in text_lists) labeled_images = torch.stack( [ image_with_text(images[i], text_lists[i], max_num_lines=max_num_lines, centered=centered) for i in range(num_images) ], dim=0, ) return make_square_image_or_video(labeled_images, sep=sep) self.images[namespace][key] = labeled_images_future
[docs] def log_audio( self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None, sample_rate: int = 44100, log_spec: bool = True, n_fft_ms: float = 32.0, hop_length_ms: float | None = None, channel_select_mode: ChannelSelectMode = "first", keep_resolution: bool = False, ) -> None: """Logs an audio clip. Args: key: The key being logged value: The audio clip being logged; can be (C, T) or (T) as a mono (1 channel) or stereo (2 channel) audio clip namespace: An optional logging namespace sample_rate: The sample rate of the audio clip log_spec: If set, also log the spectrogram n_fft_ms: FFT size, in milliseconds hop_length_ms: The FFT hop length, in milliseconds channel_select_mode: How to select the channel if the audio is stereo; can be "first", "last", or "mean"; this is only used for the spectrogram keep_resolution: If set, keep the resolution of the spectrogram; otherwise, make human-viewable """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def raw_audio_future() -> Tensor: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, Tensor) return value_concrete @functools.lru_cache def audio_future() -> tuple[Tensor, int]: value_concrete = raw_audio_future() audio = standardize_audio(value_concrete, log_key=f"{namespace}/{key}") audio = cast_fp32(audio) return audio, sample_rate self.audio[namespace][key] = audio_future if log_spec: # Using a unique key for the spectrogram is very important because # otherwise Tensorboard will have some issues. self.log_spectrogram( key=f"{key}_spec", value=raw_audio_future, namespace=namespace, sample_rate=sample_rate, n_fft_ms=n_fft_ms, hop_length_ms=hop_length_ms, channel_select_mode=channel_select_mode, keep_resolution=keep_resolution, )
[docs] def log_audios( self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None, sep_ms: float = 0.0, max_audios: int | None = None, sample_rate: int = 44100, log_spec: bool = True, n_fft_ms: float = 32.0, hop_length_ms: float | None = None, channel_select_mode: ChannelSelectMode = "first", spec_sep: int = 0, keep_resolution: bool = False, ) -> None: """Logs multiple audio clips. Args: key: The key being logged value: The audio clip being logged; can be (B, C, T) or (B, T) as a mono (1 channel) or stereo (2 channel) audio clip, with exactly B clips namespace: An optional logging namespace sep_ms: An optional separation amount between adjacent audio clips max_audios: An optional maximum number of audio clips to log sample_rate: The sample rate of the audio clip log_spec: If set, also log the spectrogram n_fft_ms: FFT size, in milliseconds hop_length_ms: The FFT hop length, in milliseconds channel_select_mode: How to select the channel if the audio is stereo; can be "first", "last", or "mean"; this is only used for the spectrogram spec_sep: An optional separation amount between adjacent spectrograms keep_resolution: If set, keep the resolution of the spectrogram; otherwise, make human-viewable """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def raw_audio_future() -> Tensor: value_concrete = value() if callable(value) else value assert isinstance(value_concrete, Tensor) return value_concrete @functools.lru_cache def audio_future() -> tuple[Tensor, int]: value_concrete = raw_audio_future() audio = standardize_audios(value_concrete, log_key=f"{namespace}/{key}", max_audios=max_audios) audio = cast_fp32(audio) def to_frames(ms: float) -> int: return 0 if ms == 0.0 else 2 ** round(math.log2(ms * sample_rate / 1000)) audio = separate_with_padding(audio, to_frames(sep_ms)) return audio, sample_rate self.audio[namespace][key] = audio_future if log_spec: # Using a unique key for the spectrogram is very important because # 
otherwise Tensorboard will have some issues. self.log_spectrograms( key=f"{key}_spec", value=raw_audio_future, namespace=namespace, max_audios=max_audios, sample_rate=sample_rate, n_fft_ms=n_fft_ms, hop_length_ms=hop_length_ms, channel_select_mode=channel_select_mode, spec_sep=spec_sep, keep_resolution=keep_resolution, )
[docs] def log_spectrogram( self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None, sample_rate: int = 44100, n_fft_ms: float = 32.0, hop_length_ms: float | None = None, channel_select_mode: ChannelSelectMode = "first", keep_resolution: bool = False, ) -> None: """Logs spectrograms of an audio clip. Args: key: The key being logged value: The audio clip being logged; can be (C, T) or (T) as a mono (1 channel) or stereo (2 channel) audio clip namespace: An optional logging namespace sample_rate: The sample rate of the audio clip n_fft_ms: FFT size, in milliseconds hop_length_ms: The FFT hop length, in milliseconds channel_select_mode: How to select the channel if the audio is stereo; can be "first", "last", or "mean"; this is only used for the spectrogram keep_resolution: If set, keep the resolution of the spectrogram; otherwise, make human-viewable """ namespace = self.resolve_namespace(namespace) @functools.lru_cache def spec_future() -> Tensor: audio = value() if callable(value) else value audio = standardize_audio(audio, log_key=f"{namespace}/{key}") audio = get_audio_channel(audio, channel_select_mode) def to_frames(ms: float) -> int: return 2 ** round(math.log2(ms * sample_rate / 1000)) n_fft = to_frames(n_fft_ms) hop_length = None if hop_length_ms is None else to_frames(hop_length_ms) audio = audio.to(torch.float32) audio_spec = torch.stft(audio, n_fft, hop_length=hop_length, normalized=True, return_complex=True) audio_spec = torch.log10(torch.abs(audio_spec) + 1e-6) return standardize_image( audio_spec, log_key=f"{namespace}/{key}", keep_resolution=keep_resolution, ) self.images[namespace][key] = spec_future
def log_spectrograms(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    max_audios: int | None = None,
    sample_rate: int = 44100,
    n_fft_ms: float = 32.0,
    hop_length_ms: float | None = None,
    channel_select_mode: ChannelSelectMode = "first",
    spec_sep: int = 0,
    keep_resolution: bool = False,
) -> None:
    """Logs log-magnitude spectrograms of a batch of audio clips as one tiled image.

    Computation is deferred: a memoized zero-argument closure is stored in
    ``self.images[namespace][key]``. When called, it builds the batch of
    spectrogram images and arranges them with ``make_square_image_or_video``.

    Args:
        key: The key being logged
        value: The audio clips being logged, or a zero-argument callable that
            returns them; can be (B, C, T) or (B, T) as a mono (1 channel) or
            stereo (2 channel) audio clip, with exactly B clips
        namespace: An optional logging namespace
        max_audios: An optional maximum number of audio clips to log
        sample_rate: The sample rate of the audio clip
        n_fft_ms: FFT size, in milliseconds; converted to a power-of-two
            number of samples
        hop_length_ms: The FFT hop length, in milliseconds; converted to a
            power-of-two number of samples. If unset, ``torch.stft`` picks
            its own default hop length
        channel_select_mode: How to select the channel if the audio is stereo;
            can be "first", "last", or "mean"; this is only used for the
            spectrogram
        spec_sep: An optional separation amount between adjacent spectrograms
        keep_resolution: If set, keep the resolution of the spectrogram;
            otherwise, make human-viewable
    """
    resolved_ns = self.resolve_namespace(namespace)
    log_key = f"{resolved_ns}/{key}"

    def ms_to_samples(ms: float) -> int:
        # The exponent is rounded in log2 space, giving a power-of-two length.
        return 2 ** round(math.log2(ms * sample_rate / 1000))

    @functools.lru_cache
    def spec_future() -> Tensor:
        if callable(value):
            clips = value()
        else:
            clips = value
        clips = standardize_audios(clips, log_key=log_key, max_audios=max_audios)
        clips = get_audio_channel(clips, channel_select_mode)

        fft_size = ms_to_samples(n_fft_ms)
        hop = ms_to_samples(hop_length_ms) if hop_length_ms is not None else None

        complex_specs = torch.stft(
            clips.to(torch.float32),
            fft_size,
            hop_length=hop,
            normalized=True,
            return_complex=True,
        )
        # Small epsilon keeps log10 finite for silent bins.
        log_mags = torch.log10(complex_specs.abs() + 1e-6)
        images, _ = standardize_images(
            log_mags,
            None,
            log_key=log_key,
            keep_resolution=keep_resolution,
        )
        return make_square_image_or_video(images, sep=spec_sep)

    self.images[resolved_ns][key] = spec_future
def log_video(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    fps: int | None = None,
    length: float | None = None,
) -> None:
    """Logs a video.

    The video is produced lazily: a memoized zero-argument closure is stored
    in ``self.videos[namespace][key]``. Repeated calls to it reuse the first
    result.

    Args:
        key: The key being logged
        value: The video being logged, or a zero-argument callable that
            returns it; the video can be (T, C, H, W), (T, H, W, C) or
            (T, H, W) as an RGB (3 channel) or grayscale (1 channel) video
        namespace: An optional logging namespace
        fps: The video frames per second
        length: The desired video length, in seconds, at the target FPS
    """
    namespace = self.resolve_namespace(namespace)

    @functools.lru_cache
    def video_future() -> Tensor:
        value_concrete = value() if callable(value) else value
        assert isinstance(value_concrete, Tensor)
        # The cast must happen before the tensor is consumed. Casting after
        # `standardize_video` (as before) rebinds a name that is never read
        # again, so it had no effect. `cast_fp32` is defined elsewhere in this
        # module; presumably it upcasts reduced-precision floats to fp32.
        value_concrete = cast_fp32(value_concrete)
        video = standardize_video(value_concrete, log_key=f"{namespace}/{key}")
        return normalize_video_fps(video, fps, length)

    self.videos[namespace][key] = video_future
def log_videos(
    self,
    key: str,
    value: Callable[[], Tensor | list[Tensor]] | Tensor | list[Tensor],
    *,
    namespace: str | None = None,
    max_videos: int | None = None,
    sep: int = 0,
    fps: int | None = None,
    length: float | None = None,
) -> None:
    """Logs a set of videos, tiled into a single video.

    The result is produced lazily: a memoized zero-argument closure is stored
    in ``self.videos[namespace][key]``. Repeated calls to it reuse the first
    result.

    Args:
        key: The key being logged
        value: The videos being logged, or a zero-argument callable that
            returns them; either a tensor of shape (B, T, C, H, W),
            (B, T, H, W, C) or (B, T, H, W), or a list of per-video tensors,
            as RGB (3 channel) or grayscale (1 channel) videos
        namespace: An optional logging namespace
        max_videos: The maximum number of videos to show; extra videos are
            clipped
        sep: An optional separation amount between adjacent videos
        fps: The video frames per second
        length: The desired video length, in seconds, at the target FPS
    """
    namespace = self.resolve_namespace(namespace)

    @functools.lru_cache
    def videos_future() -> Tensor:
        value_concrete = value() if callable(value) else value
        assert isinstance(value_concrete, (Tensor, list))
        # Cast before any processing. Casting after the value has been
        # consumed (as before) had no effect. `cast_fp32` is applied per
        # element for list inputs so it only ever receives tensors.
        if isinstance(value_concrete, list):
            value_concrete = [cast_fp32(v) for v in value_concrete]
        else:
            value_concrete = cast_fp32(value_concrete)
        video = normalize_video_fps(value_concrete, fps, length, stack_dim=1)
        video = standardize_videos(video, max_videos=max_videos, log_key=f"{namespace}/{key}")
        return make_square_image_or_video(video, sep=sep)

    self.videos[namespace][key] = videos_future
def log_histogram(self, key: str, value: Callable[[], Tensor] | Tensor, *, namespace: str | None = None) -> None:
    """Logs a histogram.

    Registers a memoized zero-argument closure in
    ``self.histograms[namespace][key]``. When invoked, it materializes the
    tensor (calling ``value`` if it is callable) and passes it through
    ``cast_fp32``.

    Args:
        key: The key being logged
        value: The values to create a histogram from, with arbitrary shape,
            or a zero-argument callable that returns them
        namespace: An optional logging namespace
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def histogram_future() -> Tensor:
        if callable(value):
            tensor = value()
        else:
            tensor = value
        assert isinstance(tensor, Tensor)
        return cast_fp32(tensor)

    self.histograms[resolved_ns][key] = histogram_future
def log_point_cloud(
    self,
    key: str,
    value: Callable[[], Tensor] | Tensor,
    *,
    namespace: str | None = None,
    max_points: int = 1000,
) -> None:
    """Logs a point cloud.

    Registers a memoized zero-argument closure in
    ``self.point_clouds[namespace][key]``. When invoked, it materializes the
    tensor, passes it through ``cast_fp32`` and then through
    ``standardize_point_cloud``.

    Args:
        key: The key being logged
        value: The point cloud values, with shape (N, 3) or (B, ..., 3), or a
            zero-argument callable that returns them; can pass multiple
            batches in order to show multiple point clouds
        namespace: An optional logging namespace
        max_points: The maximum number of points in the point cloud
            (defaults to 1000)
    """
    resolved_ns = self.resolve_namespace(namespace)

    @functools.lru_cache
    def point_cloud_future() -> Tensor:
        if callable(value):
            points = value()
        else:
            points = value
        assert isinstance(points, Tensor)
        points_fp32 = cast_fp32(points)
        return standardize_point_cloud(points_fp32, max_points, log_key=f"{resolved_ns}/{key}")

    self.point_clouds[resolved_ns][key] = point_cloud_future
def write_dict(
    self,
    loggers: list[BaseLogger],
    values: dict[str, dict[str, Callable[[], LogT]]],
    state: State,
    func: Callable[[BaseLogger], Callable[[str, Callable[[], LogT], State, str], None]],
) -> None:
    """Sends every pending value to every child logger, then empties ``values``.

    Args:
        loggers: The child loggers to write to
        values: Pending values keyed by namespace, then by key; each entry is
            a zero-argument callable producing the payload
        state: The current state, forwarded to each child logger call
        func: Given a child logger, returns the callable to invoke; it is
            called as ``func(child)(key, value_fn, state, namespace)``

    ``values`` is cleared in place only after all calls return. If any call
    raises, the exception propagates and ``values`` is left as it was.
    """
    # `child` rather than `logger`, to avoid shadowing the module-level logger.
    for child in loggers:
        for namespace, entries in values.items():
            for key, value_fn in entries.items():
                func(child)(key, value_fn, state, namespace)
    values.clear()
def write(self, loggers: list[BaseLogger], state: State) -> None:
    """Flushes all pending values to the child loggers, one category at a time.

    Categories are written in a fixed order: scalars, strings, images, audio,
    videos, histograms, then point clouds. Each category's pending dict is
    emptied by ``write_dict`` once it has been written.

    Args:
        loggers: The child loggers to write to
        state: The current state, forwarded to each child logger call
    """
    dispatch = (
        (self.scalars, lambda child: child.log_scalar),
        (self.strings, lambda child: child.log_string),
        (self.images, lambda child: child.log_image),
        (self.audio, lambda child: child.log_audio),
        (self.videos, lambda child: child.log_video),
        (self.histograms, lambda child: child.log_histogram),
        (self.point_clouds, lambda child: child.log_point_cloud),
    )
    for pending, select_method in dispatch:
        self.write_dict(loggers, pending, state, select_method)