Source code for ml.utils.augmentation

"""Augmentation utilities for image data."""

import torch
from torch import Tensor


[docs]def get_image_mask(image: Tensor, *, min_prct: float = 0.2, max_prct: float = 0.5) -> Tensor: """Gets a random mask for the image. Args: image: The image to mask, with shape (..., H, W) min_prct: Minimum percent of image to mask in either dimension max_prct: Maximum percent of image to mask in either dimension Returns: An image mask, with shape (..., H, W) matching the original image shape """ device, bsz, (height, width) = image.device, image.shape[:-2], image.shape[-2:] mask_dims = bsz + (1, 1) min_height, max_height = int(height * min_prct), int(height * max_prct) min_width, max_width = int(width * min_prct), int(width * max_prct) mask_height = torch.floor(torch.rand(*mask_dims, device=device) * (max_height - min_height) + min_height).long() mask_bottom = torch.floor(torch.rand(*mask_dims, device=device) * (height - mask_height)).long() mask_top = mask_bottom + mask_height mask_width = torch.floor(torch.rand(*mask_dims, device=device) * (max_width - min_width) + min_width) mask_left = torch.floor(torch.rand(*mask_dims, device=device) * (width - mask_height)).long() mask_right = mask_left + mask_width ys, xs = torch.meshgrid(torch.arange(height, device=device), torch.arange(width, device=device), indexing="ij") ys, xs = ys.view((1,) * len(bsz) + (height, width)), xs.view([1] * len(bsz) + [height, width]) mask = (xs >= mask_left) & (xs <= mask_right) & (ys >= mask_bottom) & (ys <= mask_top) return mask