ml.tasks.losses.image
Defines some loss functions which are suitable for images.
- class ml.tasks.losses.image.SSIMLoss(kernel_size: int = 3, stride: int = 1, channels: int = 3, mode: Literal['avg', 'std'] = 'avg', sigma: float = 1.0, dynamic_range: float = 1.0)[source]
Bases:
Module
Computes structural similarity loss (SSIM).
The dynamic_range is the difference between the maximum and minimum possible values for the image. The output is actually the negative SSIM, so that minimizing it maximizes the SSIM score.
- Parameters:
kernel_size – Size of the Gaussian kernel.
stride – Stride of the Gaussian kernel.
channels – Number of channels in the image.
mode – Mode of the SSIM function, either avg or std. The avg mode uses unweighted (K, K) regions, while the std mode uses Gaussian-weighted (K, K) regions, which allows for larger regions without worrying about blurring.
sigma – Standard deviation of the Gaussian kernel.
dynamic_range – Difference between the maximum and minimum possible values for the image.
- Inputs:
x: float tensor with shape (B, C, H, W)
y: float tensor with shape (B, C, H, W)
- Outputs:
float tensor with shape (B, C, H - K + 1, W - K + 1)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, y: Tensor) Tensor [source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
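The avg mode described above can be sketched by hand. The following is a minimal stand-alone SSIM computed over unweighted (K, K) windows, returning the negative SSIM map as the class does; it is an illustrative sketch, not the module's actual implementation, and the 0.01/0.03 stability constants are the conventional ones from the original SSIM formulation:

```python
import torch
import torch.nn.functional as F


def ssim_loss(x: torch.Tensor, y: torch.Tensor, kernel_size: int = 3, dynamic_range: float = 1.0) -> torch.Tensor:
    k, c = kernel_size, x.shape[1]
    # Unweighted (K, K) averaging kernel, applied per channel ("avg" mode).
    w = torch.ones(c, 1, k, k, dtype=x.dtype) / (k * k)
    mu_x = F.conv2d(x, w, groups=c)
    mu_y = F.conv2d(y, w, groups=c)
    # Local variances and covariance via E[xy] - E[x]E[y].
    var_x = F.conv2d(x * x, w, groups=c) - mu_x**2
    var_y = F.conv2d(y * y, w, groups=c) - mu_y**2
    cov = F.conv2d(x * y, w, groups=c) - mu_x * mu_y
    # Stability constants derived from the dynamic range.
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2
    ssim = ((2 * mu_x * mu_y + c1) * (2 * cov + c2)) / (
        (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    )
    # Negative SSIM, so minimizing this maximizes structural similarity.
    return -ssim
```

Identical inputs give an SSIM of 1 everywhere, so the loss map is -1, and the output shape is (B, C, H - K + 1, W - K + 1) as documented.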
- class ml.tasks.losses.image.ImageGradLoss(kernel_size: int = 3, sigma: float = 1.0)[source]
Bases:
Module
Computes image gradients, for use as a smoothness loss.
This function convolves the image with a special Gaussian kernel that contrasts the current pixel with the surrounding pixels, such that the output is zero if the current pixel is the same as the surrounding pixels, and is larger if the current pixel is different from the surrounding pixels.
- Parameters:
kernel_size – Size of the Gaussian kernel.
sigma – Standard deviation of the Gaussian kernel.
- Inputs:
x: float tensor with shape (B, C, H, W)
- Outputs:
float tensor with shape (B, C, H - ksz + 1, W - ksz + 1)
- kernel: Tensor
- forward(x: Tensor) Tensor [source]
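The kernel described above can be sketched as a Gaussian window whose center weight is replaced so that the kernel sums to zero: the center pixel minus a Gaussian-weighted average of its neighbors. This is an illustrative sketch of the idea, not necessarily the module's exact kernel construction:

```python
import torch
import torch.nn.functional as F


def grad_kernel(kernel_size: int = 3, sigma: float = 1.0) -> torch.Tensor:
    # 2D Gaussian over the window.
    ax = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
    g = torch.exp(-(ax[:, None] ** 2 + ax[None, :] ** 2) / (2 * sigma**2))
    c = kernel_size // 2
    g[c, c] = 0.0
    g = g / g.sum()  # surrounding weights sum to 1
    k = -g
    k[c, c] = 1.0  # center pixel minus weighted average of its neighbors
    return k


def image_grad_loss(x: torch.Tensor, kernel_size: int = 3, sigma: float = 1.0) -> torch.Tensor:
    ch = x.shape[1]
    k = grad_kernel(kernel_size, sigma).expand(ch, 1, kernel_size, kernel_size)
    # Depthwise convolution; output is zero wherever the image is locally constant.
    return F.conv2d(x, k, groups=ch)
```

Because the kernel sums to zero, a constant image produces an all-zero output, and larger outputs indicate pixels that differ from their surroundings.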
- class ml.tasks.losses.image.LPIPS(pretrained: bool = True, requires_grad: bool = False, dropout: float = 0.5)[source]
Bases:
Module
Computes the learned perceptual image patch similarity (LPIPS) loss.
This function extracts the VGG-16 features from each input image, projects them once, then computes the L2 distance between the projected features.
The input images should be in the range [0, 1]. The height and width of the input images should be at least 64 pixels but can otherwise be arbitrary.
- Parameters:
pretrained – Whether to use the pretrained VGG-16 model. This should usually only be disabled for testing.
requires_grad – Whether to require gradients for the VGG-16 model. This should usually be disabled, unless you want to fine-tune the model.
dropout – Dropout probability for the input projections.
- Inputs:
image_a: float tensor with shape (B, C, H, W)
image_b: float tensor with shape (B, C, H, W)
- Outputs:
float tensor with shape (B,)
- forward(image_a: Tensor, image_b: Tensor) Tensor [source]
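The distance head described above (unit-normalize features per spatial position, project with a dropout + 1x1 convolution, then average spatially and sum over layers) can be sketched in isolation. This is a hand-rolled sketch that assumes features have already been extracted, e.g. from several VGG-16 layers; the feat_channels values and class name are hypothetical, not the module's API:

```python
import torch
import torch.nn as nn


class LpipsDistance(nn.Module):
    """Sketch of the LPIPS distance head over pre-extracted feature maps."""

    def __init__(self, feat_channels: tuple[int, ...] = (64, 128, 256), dropout: float = 0.5) -> None:
        super().__init__()
        # One dropout + 1x1 projection per feature layer.
        self.projs = nn.ModuleList(
            nn.Sequential(nn.Dropout(dropout), nn.Conv2d(c, 1, 1, bias=False))
            for c in feat_channels
        )

    @staticmethod
    def _unit_norm(f: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
        # Normalize each spatial feature vector to unit length.
        return f / (f.norm(dim=1, keepdim=True) + eps)

    def forward(self, feats_a: list[torch.Tensor], feats_b: list[torch.Tensor]) -> torch.Tensor:
        total = torch.zeros(feats_a[0].shape[0])
        for proj, fa, fb in zip(self.projs, feats_a, feats_b):
            # Squared difference of normalized features, projected and averaged spatially.
            d = (self._unit_norm(fa) - self._unit_norm(fb)) ** 2
            total = total + proj(d).mean(dim=(2, 3)).squeeze(1)
        return total  # shape (B,)
```

Identical feature lists give a distance of exactly zero (the projection is linear with no bias), matching the intuition that LPIPS measures perceptual difference.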