"""Defines a bunch of dataset transforms."""
import random
from typing import TypeVar
import torch
import torchvision.transforms.functional as V
from PIL.Image import Image as PILImage
from torch import Tensor, nn
from torchvision.transforms import InterpolationMode
# Constrained type variable covering the two image representations handled by
# the transforms in this module: a torch tensor or a PIL image.
Image = TypeVar("Image", Tensor, PILImage)
# Per-channel normalization statistics, one float for each of three channels.
NormParams = tuple[float, float, float]
# Default image normalization parameters.
# NOTE(review): these appear to be the CLIP preprocessing statistics rather
# than the usual torchvision ImageNet ones - confirm provenance.
MEAN: NormParams = 0.48145466, 0.4578275, 0.40821073
STD: NormParams = 0.26862954, 0.26130258, 0.27577711
def square_crop(img: Image) -> Image:
    """Center-crops an image to the largest square that fits inside it.

    Args:
        img: The image to crop

    Returns:
        The central square region of the image, whose side length equals the
        shorter of the input's two dimensions.
    """
    full_w, full_h = V.get_image_size(img)
    side = min(full_h, full_w)

    # Split the excess along each axis evenly so the crop stays centered; the
    # excess is zero along the shorter axis.
    offset_y = (full_h - side) // 2
    offset_x = (full_w - side) // 2
    return V.crop(img, offset_y, offset_x, side, side)
def square_resize_crop(img: Image, size: int, interpolation: InterpolationMode = InterpolationMode.NEAREST) -> Image:
    """Resizes an image so its shorter side is ``size``, then center-crops a square.

    The resize preserves the aspect ratio of the input, so only the longer
    dimension is trimmed by the crop.

    Args:
        img: The input image
        size: The height and width of the square output
        interpolation: The interpolation mode to use when resizing

    Returns:
        The resized and center-cropped image, with height and width ``size``.
    """
    img_width, img_height = V.get_image_size(img)
    min_dim = min(img_width, img_height)

    # ``V.resize`` takes its target as [height, width], so each output
    # dimension must be scaled from the matching input dimension; mixing them
    # up transposes the aspect ratio and distorts non-square images.
    height = int((img_height / min_dim) * size)
    width = int((img_width / min_dim) * size)
    img = V.resize(img, [height, width], interpolation)

    # Center the square crop window inside the resized image.
    top, left = (height - size) // 2, (width - size) // 2
    return V.crop(img, top, left, size, size)
def upper_left_crop(img: Image, height: int, width: int) -> Image:
    """Crops a window anchored at the image's upper left corner.

    Anchoring the crop at the origin is useful because it preserves camera
    intrinsics for an image.

    Args:
        img: The image to crop
        height: The height of the crop window
        width: The width of the crop window

    Returns:
        The cropped image
    """
    # The window always starts at row 0, column 0.
    return V.crop(img, top=0, left=0, height=height, width=width)
def normalize(t: Tensor, *, mean: NormParams = MEAN, std: NormParams = STD) -> Tensor:
    """Normalizes an image tensor channel-wise.

    By default the module-level ``MEAN`` and ``STD`` statistics are used.
    This can be paired with :func:`denormalize`, which maps a normalized
    tensor back to image space.

    Args:
        t: The image tensor to normalize
        mean: The per-channel mean to subtract
        std: The per-channel standard deviation to divide by

    Returns:
        The normalized tensor
    """
    return V.normalize(t, mean=mean, std=std)
def denormalize(t: Tensor, *, mean: NormParams = MEAN, std: NormParams = STD) -> Tensor:
    """Denormalizes a tensor, computing ``t * std + mean`` per channel.

    This can be paired with :func:`normalize` to convert a normalized tensor
    back to the original image for viewing by humans.

    The statistics are broadcast with shape ``(1, C, 1, 1)``, so the channel
    dimension of ``t`` must be third from last. As a consequence, a 3D
    ``(C, H, W)`` input is returned with a leading dimension of size one.

    Args:
        t: The normalized input tensor
        mean: The per-channel mean to add back
        std: The per-channel standard deviation to multiply by

    Returns:
        The denormalized tensor
    """
    # Build the statistics on the input's device and dtype so the arithmetic
    # below needs no implicit transfer or cast.
    shift = torch.tensor(mean, device=t.device, dtype=t.dtype)[None, :, None, None]
    scale = torch.tensor(std, device=t.device, dtype=t.dtype)[None, :, None, None]
    return (t * scale) + shift
def random_square_crop(img: Image) -> Image:
    """Crops a randomly positioned square out of an image.

    The square's side equals the shorter image dimension, so the crop can
    only move along the longer dimension.

    Args:
        img: The image to crop

    Returns:
        The cropped square image
    """
    full_w, full_h = V.get_image_size(img)
    side = min(full_h, full_w)

    # ``randint`` is inclusive on both ends; the vertical offset is drawn
    # before the horizontal one.
    offset_y = random.randint(0, full_h - side)
    offset_x = random.randint(0, full_w - side)
    return V.crop(img, offset_y, offset_x, side, side)
def random_square_crop_multi(imgs: list[Image]) -> list[Image]:
    """Applies the same random square crop to every image in a list.

    All images must share the same size; a single crop window is sampled and
    reused so the crops stay spatially aligned with each other.

    Args:
        imgs: The list of images to crop

    Returns:
        The cropped images, in the same order as the input. An empty input
        yields an empty list.

    Raises:
        ValueError: If the images do not all have the same size
    """
    if not imgs:
        return []

    img_dims = V.get_image_size(imgs[0])
    # Validate explicitly rather than with ``assert``, which is stripped when
    # Python runs with ``-O``.
    if any(V.get_image_size(i) != img_dims for i in imgs[1:]):
        raise ValueError("All images must have the same size to share a crop window")

    img_width, img_height = img_dims
    side = min(img_width, img_height)
    top = random.randint(0, img_height - side)
    left = random.randint(0, img_width - side)
    return [V.crop(i, top, left, side, side) for i in imgs]
def make_size(img: Image, ref_size: tuple[int, int]) -> Image:
    """Converts an image to a specific size, zero-padding the smaller dimension.

    The image is resized with bilinear interpolation, preserving its aspect
    ratio, so that it fits inside the reference size. It is then centered and
    the remaining border is filled with zeros.

    Args:
        img: The input image
        ref_size: The reference size, as (width, height)

    Returns:
        The resized and padded image, with the reference width and height.
    """
    (img_w, img_h), (ref_w, ref_h) = V.get_image_size(img), ref_size
    if img_h / img_w < ref_h / ref_w:
        # Relatively wider than the reference: fill the full width and pad
        # the height.
        new_h, new_w = (img_h * ref_w) // img_w, ref_w
    else:
        # Relatively taller (or same aspect ratio): fill the full height and
        # pad the width.
        new_h, new_w = ref_h, (img_w * ref_h) // img_h
    img = V.resize(img, [new_h, new_w], InterpolationMode.BILINEAR)

    # Pad through ``V.pad`` instead of writing into a tensor created with
    # ``new_zeros`` so that PIL images, which have no such method, are
    # supported as the signature promises. Any odd leftover pixel goes to the
    # right / bottom edge, matching a floor-divided centered placement.
    pad_left, pad_top = (ref_w - new_w) // 2, (ref_h - new_h) // 2
    pad_right, pad_bottom = ref_w - new_w - pad_left, ref_h - new_h - pad_top
    return V.pad(img, [pad_left, pad_top, pad_right, pad_bottom], fill=0)
def make_same_size(img: Image, ref_img: Image) -> Image:
    """Resizes an image to match the size of a reference image.

    Args:
        img: The image to resize
        ref_img: The image whose width and height should be matched

    Returns:
        The input image converted to the reference image's size by
        :func:`make_size`, zero-padding dimensions which are too small
    """
    target_w, target_h = V.get_image_size(ref_img)
    return make_size(img, ref_size=(target_w, target_h))
class SquareResizeCrop(nn.Module):
    """Module that resizes and crops an image to a square of a fixed size.

    When using bilinear resize, a square crop followed by a resize should
    generally be preferred, since interpolating the smaller image is faster.
    Nearest neighbor resize on the larger image followed by a crop on the
    smaller image can sometimes be faster, however.
    """

    __constants__ = ["size", "interpolation"]

    def __init__(self, size: int, interpolation: InterpolationMode = InterpolationMode.NEAREST) -> None:
        """Initializes the square resize crop.

        Args:
            size: The height and width of the square output
            interpolation: The interpolation type to use when resizing
        """
        super().__init__()

        # Both values are coerced at construction time, so invalid arguments
        # raise here rather than on the first forward pass.
        self.size = int(size)
        self.interpolation = InterpolationMode(interpolation)

    def forward(self, img: Image) -> Image:
        """Applies :func:`square_resize_crop` with the stored settings.

        Args:
            img: The image to resize and crop

        Returns:
            The square image
        """
        return square_resize_crop(img, size=self.size, interpolation=self.interpolation)
class UpperLeftCrop(nn.Module):
    """Module that crops from the upper left corner, preserving image intrinsics."""

    __constants__ = ["height", "width"]

    def __init__(self, height: int, width: int) -> None:
        """Initializes the upper left crop.

        Args:
            height: The max height of the cropped image
            width: The max width of the cropped image
        """
        super().__init__()

        self.height = height
        self.width = width

    def forward(self, img: Image) -> Image:
        """Applies :func:`upper_left_crop` with the stored height and width.

        Args:
            img: The image to crop

        Returns:
            The cropped image
        """
        return upper_left_crop(img, height=self.height, width=self.width)