Source code for ml.utils.image

"""Defines utilities for dealing with images."""

import itertools
import math
from fractions import Fraction
from pathlib import Path
from typing import Iterator

import numpy as np
import torch
import torchvision.transforms.functional as V
from PIL import Image, ImageSequence
from torch import Tensor
from torchvision.transforms import InterpolationMode

# Channel counts `standardize_image` accepts along the channel axis:
# single-channel (1) or three-channel (3) images.
VALID_CHANNEL_COUNTS = {3, 1}


def as_uint8(arr: np.ndarray) -> np.ndarray:
    """Converts an image array to ``uint8``.

    Integer arrays are cast directly, without clipping. Floating-point arrays
    are treated as intensities in [0, 1]: they are clipped to that range,
    scaled to [0, 255] and rounded.

    Args:
        arr: The array to convert.

    Returns:
        The converted ``uint8`` array, with the same shape as ``arr``.

    Raises:
        NotImplementedError: If ``arr`` has neither an integer nor a
            floating-point dtype.
    """
    if np.issubdtype(arr.dtype, np.integer):
        return arr.astype(np.uint8)
    if np.issubdtype(arr.dtype, np.floating):
        # Casting out-of-range floats to uint8 is undefined in NumPy (values
        # may wrap or trigger warnings), so saturate before scaling.
        return (np.clip(arr, 0.0, 1.0) * 255).round().astype(np.uint8)
    raise NotImplementedError(f"Unsupported dtype: {arr.dtype}")
def make_human_viewable_resolution(
    image: Tensor,
    interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    trg_res: tuple[int, int] = (250, 250),
) -> Tensor:
    """Resizes image to human-viewable resolution.

    The aspect ratio is preserved; only the total pixel count is matched to
    ``trg_res[0] * trg_res[1]``. Images smaller than the target are upsampled.

    Args:
        image: The image to resize, with shape (C, H, W)
        interpolation: Interpolation mode to use for image resizing
        trg_res: The target image resolution as (height, width); the image
            will be reshaped to have approximately the same area as an image
            with this resolution

    Returns:
        The resized image; each side is at least one pixel
    """
    width, height = V.get_image_size(image)
    trg_height, trg_width = trg_res
    # Uniform scale factor that maps the current area onto the target area.
    factor = math.sqrt((trg_height * trg_width) / (height * width))
    # Clamp each side to one pixel so very elongated images never end up with
    # an empty dimension after truncation.
    new_height = max(1, int(height * factor))
    new_width = max(1, int(width * factor))
    return V.resize(image, [new_height, new_width], interpolation)
def _aminmax(t: Tensor) -> tuple[Tensor, Tensor]:
    """Returns the global minimum and maximum of ``t`` as a ``(min, max)`` pair.

    Uses the fused ``aminmax`` reduction except on MPS tensors.
    """
    if t.is_mps:
        # MPS tensors don't support `aminmax`, so reduce twice instead.
        return t.min(), t.max()
    lo, hi = t.aminmax()
    return lo, hi
def standardize_image(
    image: np.ndarray | Tensor,
    *,
    log_key: str | None = None,
    normalize: bool = True,
    keep_resolution: bool = False,
) -> np.ndarray:
    """Converts an arbitrary image to a channels-last Numpy array.

    Accepts images with shape (H, W), (C, H, W) or (H, W, C), where C must be
    one of ``VALID_CHANNEL_COUNTS``.

    Args:
        image: The image to standardize, as a Numpy array or tensor
        log_key: An optional logging key to use in the exception message
        normalize: Rescale floating-point images into (0, 1); images whose
            values already lie in that range are left unchanged
        keep_resolution: If set, preserve original image resolution,
            otherwise change image resolution to human-viewable

    Returns:
        The standardized image, with shape (H, W, C)

    Raises:
        ValueError: If the image shape is invalid
    """
    if isinstance(image, np.ndarray):
        # `torch.from_numpy` rejects arrays with negative strides (e.g. the
        # `img[..., ::-1]` BGR -> RGB flip); `ascontiguousarray` only copies
        # when the array is not already C-contiguous.
        image = torch.from_numpy(np.ascontiguousarray(image))

    if normalize and image.is_floating_point():
        minv, maxv = _aminmax(image)
        # Widen the range to at least [0, 1] so in-range images pass through
        # untouched and the denominator below is never zero.
        maxv.clamp_min_(1.0)
        minv.clamp_max_(0.0)
        image = torch.clamp((image.detach() - minv) / (maxv - minv), 0.0, 1.0)

    if image.ndim == 2:
        image = image.unsqueeze(0)
    elif image.ndim == 3:
        if image.shape[0] in VALID_CHANNEL_COUNTS:
            pass
        elif image.shape[2] in VALID_CHANNEL_COUNTS:
            image = image.permute(2, 0, 1)
        else:
            raise ValueError(f"Invalid channel count{'' if log_key is None else f' for {log_key}'}: {image.shape}")
    else:
        raise ValueError(f"Invalid image shape{'' if log_key is None else f' for {log_key}'}: {image.shape}")

    # From here on the image is channels-first (C, H, W).
    if not keep_resolution:
        image = make_human_viewable_resolution(image)

    return image.permute(1, 2, 0).detach().cpu().numpy()
def read_gif(in_file: str | Path, *, skip_first_frame: bool = True) -> Iterator[np.ndarray]:
    """Function that reads a GIF and returns a stream of Numpy arrays.

    Frames that are not already RGB/RGBA (for example palette-based "P" or
    greyscale frames) are converted to RGB, or to RGBA when the frame declares
    transparency. Every yielded array is therefore channels-last.

    Args:
        in_file: The path to the input file.
        skip_first_frame: If set, skip the first frame.

    Yields:
        A stream of Numpy arrays with shape (H, W, C).
    """
    # The `with` block closes the file once the generator is exhausted or
    # closed, instead of leaking the handle.
    with Image.open(str(in_file)) as gif:
        frames = ImageSequence.Iterator(gif)
        if skip_first_frame:
            next(frames, None)
        for frame in frames:
            if frame.mode not in ("RGB", "RGBA"):
                # `np.array` on a palette frame would yield (H, W) palette
                # indices rather than colours.
                frame = frame.convert("RGBA" if "transparency" in frame.info else "RGB")
            yield np.array(frame)
def write_gif(
    itr: Iterator[np.ndarray | Tensor],
    out_file: str | Path,
    *,
    keep_resolution: bool = False,
    fps: int | Fraction = 10,
    loop: bool = False,
    first_frame_zeros: bool = True,
) -> None:
    """Function that writes a GIF from a stream of input tensors.

    Args:
        itr: The images to write, in any layout accepted by
            ``standardize_image`` (e.g. shape (H, W, C)). Any iterable is
            accepted, not only iterators.
        out_file: The path to the output file.
        keep_resolution: If set, preserve original image resolution,
            otherwise change image resolution to human-viewable.
        fps: Frames per second for the GIF.
        loop: If set, the GIF repeats forever; otherwise it plays once.
        first_frame_zeros: If set, an all-zeros frame is written before the
            first image.

    Raises:
        ValueError: If ``itr`` yields no images.
    """

    def to_array(t: np.ndarray | Tensor) -> np.ndarray:
        # PIL cannot build images from float arrays or from (H, W, 1) arrays,
        # so convert to uint8 and drop a singleton channel axis.
        arr = as_uint8(standardize_image(t, keep_resolution=keep_resolution))
        return arr[..., 0] if arr.shape[-1] == 1 else arr

    frames = iter(itr)
    try:
        first_raw = next(frames)
    except StopIteration:
        # A bare StopIteration escaping a plain function can be mistaken for
        # end-of-iteration by calling code; report the empty input explicitly.
        raise ValueError("Cannot write a GIF from an empty image iterator") from None
    first_frame = to_array(first_raw)

    rest = (Image.fromarray(to_array(t)) for t in frames)
    if first_frame_zeros:
        head = Image.fromarray(np.zeros_like(first_frame))
        tail = itertools.chain([Image.fromarray(first_frame)], rest)
    else:
        # The first image is already `head`; appending it again would
        # duplicate it in the output.
        head = Image.fromarray(first_frame)
        tail = rest

    save_kwargs: dict[str, int] = {}
    if loop:
        # In Pillow, `loop=0` means "repeat forever"; omitting the option
        # writes a GIF that plays once.
        save_kwargs["loop"] = 0

    head.save(
        str(out_file),
        save_all=True,
        append_images=tail,
        duration=int(1000 / fps),  # Number of milliseconds per frame.
        **save_kwargs,
    )