ml.tasks.losses.loss

Defines a general-purpose API for losses: loss_fn takes the name of a loss and returns the corresponding loss function.

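import torch

from ml.tasks.losses.loss import loss_fn

pred, target = torch.randn(4, 10), torch.randn(4, 10)
mse_loss = loss_fn("mse")
loss = mse_loss(pred, target)
assert loss.shape == pred.shape  # The loss is not reduced, so it matches the input shape.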

The following loss functions are supported:

Modality-Agnostic

  • "bce-logits": Binary cross-entropy loss with logits

  • "bce": Binary cross-entropy loss

  • "huber": Huber loss, which is a smoothed version of the L1 loss

  • "l1": L1 loss

  • "log-cosh": Log cosh loss, which is a smoothed version of the L1 loss

  • "mse": Mean squared error loss

  • "pseudo-huber": Pseudo-Huber loss, a smooth approximation of the Huber loss

  • "xent": Cross-entropy loss
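The element-wise losses in this group can be sketched on plain Python floats. The sketch below is illustrative, not the library's implementation: it works on scalars rather than tensors, and the Huber threshold `delta` stands in for `huber_beta` (whether the library scales the quadratic branch the same way is an assumption). It also shows why "bce-logits" exists alongside "bce": both compute the same value, but the logits form stays finite for large logits.

```python
import math


def l1(pred: float, target: float) -> float:
    return abs(pred - target)


def mse(pred: float, target: float) -> float:
    return (pred - target) ** 2


def huber(pred: float, target: float, delta: float = 1.0) -> float:
    # Quadratic for errors below delta, linear (L1-like) above it.
    err = abs(pred - target)
    if err < delta:
        return 0.5 * err**2
    return delta * (err - 0.5 * delta)


def bce(prob: float, target: float) -> float:
    # Expects a probability strictly between 0 and 1.
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


def bce_with_logits(logit: float, target: float) -> float:
    # Equal to bce(sigmoid(logit), target), rearranged as
    # max(z, 0) - z * y + log(1 + exp(-|z|)) so large |z| cannot overflow.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
```

For an error of 10, huber returns 9.5 while l1 returns 10; for small errors huber is half the squared error.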

Audio

  • "log-stft-magnitude": Log STFT magnitude loss

  • "mel": Mel spectrogram loss

  • "mfcc": Mel-frequency cepstral coefficients loss

  • "multi-stft": Multi-resolution short-time Fourier transform loss

  • "spectral-convergence": Spectral convergence loss

  • "stft": Short-time Fourier transform loss
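A common formulation of these spectral losses compares magnitude spectrograms: spectral convergence is the Frobenius norm of the magnitude difference divided by the Frobenius norm of the target magnitude, and the log STFT magnitude loss is the mean absolute difference of log magnitudes. The plain-Python sketch below assumes the magnitudes are already computed (rows are frames, columns are frequency bins). The `eps` floor, the averaging over resolutions, and returning the two terms as a pair (mirroring the tuple returned by the "stft" and "multi-stft" signatures) are assumptions, not confirmed details of this library.

```python
import math

# A magnitude spectrogram: a list of frames, each a list of frequency bins.
Spectrogram = list[list[float]]


def spectral_convergence(pred: Spectrogram, target: Spectrogram) -> float:
    # Frobenius norm of (target - pred) divided by the Frobenius norm of target.
    diff_sq = sum(
        (t - p) ** 2
        for t_row, p_row in zip(target, pred)
        for t, p in zip(t_row, p_row)
    )
    target_sq = sum(t**2 for t_row in target for t in t_row)
    return math.sqrt(diff_sq) / math.sqrt(target_sq)


def log_stft_magnitude(pred: Spectrogram, target: Spectrogram, eps: float = 1e-7) -> float:
    # Mean absolute difference of log magnitudes; eps avoids log(0).
    diffs = [
        abs(math.log(max(t, eps)) - math.log(max(p, eps)))
        for t_row, p_row in zip(target, pred)
        for t, p in zip(t_row, p_row)
    ]
    return sum(diffs) / len(diffs)


def multi_resolution(pairs: list[tuple[Spectrogram, Spectrogram]]) -> tuple[float, float]:
    # Each (pred, target) pair comes from a different STFT resolution;
    # both terms are averaged over resolutions.
    sc = sum(spectral_convergence(p, t) for p, t in pairs) / len(pairs)
    mag = sum(log_stft_magnitude(p, t) for p, t in pairs) / len(pairs)
    return sc, mag
```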

Image

  • "image-grad": Image gradient loss

  • "lpips": Learned perceptual image patch similarity loss

  • "ssim": Structural similarity index loss
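SSIM compares the means, variances, and covariance of two images. The simplified sketch below uses a single window covering the whole flattened image instead of the sliding kernel that `image_kernel_size`, `image_sigma`, and `ssim_stride` configure, and uses the standard constants C1 = (0.01 L)^2 and C2 = (0.03 L)^2 with L taken from `ssim_dynamic_range`. Turning the similarity into a loss as 1 - SSIM is an assumption.

```python
import math


def ssim_single_window(x: list[float], y: list[float], dynamic_range: float = 1.0) -> float:
    # Stabilizing constants from the standard SSIM definition.
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


def ssim_loss(x: list[float], y: list[float], dynamic_range: float = 1.0) -> float:
    # 0 for identical images; grows as structural similarity drops.
    return 1.0 - ssim_single_window(x, y, dynamic_range)
```

Anti-correlated images give a negative SSIM, so this form of the loss can exceed 1.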

KL Divergence

  • "kl-pair": KL divergence loss between two Gaussian distributions

  • "kl-single": KL divergence loss between a Gaussian distribution and a standard normal distribution
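Both KL losses have closed forms for Gaussians. The scalar sketch below assumes each Gaussian is parameterized by a mean and a log-variance, that the log-variance is what `clamp_min` and `clamp_max` clamp, and that "kl-pair" takes its four tensors in the order (mu1, logvar1, mu2, logvar2); this page does not confirm any of these details. Summing or averaging over latent dimensions is omitted.

```python
import math

# Mirrors the documented clamp_min / clamp_max defaults.
CLAMP_MIN, CLAMP_MAX = -30.0, 20.0


def _clamp(logvar: float) -> float:
    return min(max(logvar, CLAMP_MIN), CLAMP_MAX)


def kl_single(mu: float, logvar: float) -> float:
    # KL(N(mu, exp(logvar)) || N(0, 1)) = 0.5 * (mu^2 + var - 1 - logvar).
    logvar = _clamp(logvar)
    return 0.5 * (mu**2 + math.exp(logvar) - 1.0 - logvar)


def kl_pair(mu1: float, logvar1: float, mu2: float, logvar2: float) -> float:
    # KL(N(mu1, var1) || N(mu2, var2))
    #   = 0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2)^2) / var2 - 1).
    logvar1, logvar2 = _clamp(logvar1), _clamp(logvar2)
    return 0.5 * (logvar2 - logvar1 + (math.exp(logvar1) + (mu1 - mu2) ** 2) / math.exp(logvar2) - 1.0)
```

With mu2 = 0 and logvar2 = 0, kl_pair reduces to kl_single, which is why "kl-single" is the special case against a standard normal.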

ml.tasks.losses.loss.log_cosh_loss(pred: Tensor, target: Tensor) → Tensor
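The log-cosh loss is log(cosh(pred - target)), which behaves like (pred - target)^2 / 2 for small errors and like |pred - target| - log 2 for large ones. Evaluating cosh directly overflows for large errors, so the scalar sketch below uses the identity log(cosh(x)) = |x| + log(1 + exp(-2|x|)) - log 2. Whether `log_cosh_loss` uses this exact rearrangement is an assumption.

```python
import math


def log_cosh(pred: float, target: float) -> float:
    # Stable form of log(cosh(x)); math.cosh(1000.0) would raise OverflowError.
    x = abs(pred - target)
    return x + math.log1p(math.exp(-2.0 * x)) - math.log(2.0)
```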
ml.tasks.losses.loss.loss_fn(loss: Literal['mse']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['l1']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['huber'], *, huber_beta: float = 1.0) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['pseudo-huber'], *, pseudo_huber_factor: float = 0.00054, dim: int = -1, keepdim: bool = False) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['log-cosh']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['xent']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['bce']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['bce-logits']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['spectral-convergence']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['log-stft-magnitude']) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['stft'], *, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') → Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
ml.tasks.losses.loss.loss_fn(loss: Literal['multi-stft'], *, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann', fft_size_multiples: list[float] = [0.5, 1.0, 2.0]) → Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
ml.tasks.losses.loss.loss_fn(loss: Literal['mel'], *, sample_rate: int, f_min: float = 0.0, f_max: float | None = None, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, n_mels: int = 80, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') → Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
ml.tasks.losses.loss.loss_fn(loss: Literal['mfcc'], *, sample_rate: int, f_min: float = 0.0, f_max: float | None = None, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, n_mels: int = 80, n_mfcc: int = 40, log_mels: bool = False, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') → Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
ml.tasks.losses.loss.loss_fn(loss: Literal['ssim'], *, image_kernel_size: int = 3, ssim_stride: int = 1, ssim_channels: int = 3, ssim_mode: Literal['avg', 'std'] = 'avg', image_sigma: float = 1.0, ssim_dynamic_range: float = 1.0) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['image-grad'], *, image_kernel_size: int = 3, image_sigma: float = 1.0) → Callable[[Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['lpips'], *, pretrained: bool = True, requires_grad: bool = False) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['kl-single'], *, clamp_min: float = -30.0, clamp_max: float = 20.0) → Callable[[Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['kl-pair'], *, clamp_min: float = -30.0, clamp_max: float = 20.0) → Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
ml.tasks.losses.loss.loss_fn(loss: Literal['bce-logits', 'bce', 'huber', 'pseudo-huber', 'l1', 'log-cosh', 'mse', 'xent', 'log-stft-magnitude', 'mel', 'mfcc', 'multi-stft', 'spectral-convergence', 'stft', 'image-grad', 'lpips', 'ssim', 'kl-pair', 'kl-single'], *, huber_beta: float = 1.0, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann', fft_size_multiples: list[float] = [0.5, 1.0, 2.0], image_kernel_size: int = 3, ssim_stride: int = 1, ssim_channels: int = 3, ssim_mode: Literal['avg', 'std'] = 'avg', image_sigma: float = 1.0, ssim_dynamic_range: float = 1.0) → Callable[[Tensor], Tensor] | Callable[[Tensor, Tensor], Tensor] | Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]] | Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]

Returns a loss function.

Parameters:
  • loss – The loss function to use.

  • huber_beta – The beta parameter for the Huber loss.

  • pseudo_huber_factor – The factor parameter for the Pseudo-Huber loss. The default value is taken from the Consistency Model paper.

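  • sample_rate – The sample rate of the audio, in Hz. Required by the mel spectrogram and MFCC losses.

  • f_min – The minimum frequency of the mel filterbank, for the mel spectrogram and MFCC losses.

  • f_max – The maximum frequency of the mel filterbank, for the mel spectrogram and MFCC losses.

  • fft_size – The FFT size for the STFT-based audio losses.

  • shift_size – The hop size, in samples, between successive STFT frames.

  • win_length – The window length, in samples, for the STFT.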

  • n_mels – The number of mel bins for the mel spectrogram loss.

  • n_mfcc – The number of MFCCs for the MFCC loss.

  • log_mels – Whether to use log-mel spectrograms for the MFCC loss.

  • window_fn – The window function to use for the STFT: "hann", "hamming", or "blackman".

  • fft_size_multiples – The multiples of the FFT size to use for the multi-resolution STFT loss.

  • image_kernel_size – The size of the kernel for the SSIM and image gradient losses.

  • ssim_stride – The stride of the kernel for the SSIM loss.

  • ssim_channels – The number of channels for the SSIM loss.

  • ssim_mode – The mode for the SSIM loss, either "avg" or "std".

  • image_sigma – The sigma parameter for the SSIM and image gradient losses.

  • ssim_dynamic_range – The dynamic range parameter for the SSIM loss.

  • pretrained – Whether to load pretrained weights, for loss functions that use a pretrained model (such as "lpips").

  • requires_grad – Whether the loss function's parameters require gradients, for loss functions whose parameters can be frozen (such as "lpips").

  • clamp_min – The minimum value to clamp the input to, for the KL divergence losses.

  • clamp_max – The maximum value to clamp the input to, for the KL divergence losses.

  • dim – For losses applied over a single dimension (currently "pseudo-huber"), the dimension to apply the loss over.

  • keepdim – For losses applied over a single dimension, whether to keep the reduced dimension in the output.
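The Pseudo-Huber loss is sqrt(||pred - target||^2 + c^2) - c: quadratic for small distances and close to the L2 distance for large ones. The sketch below reduces over a single vector, the role `dim` plays for tensors (`keepdim` would only change the output shape). It assumes c = `pseudo_huber_factor` * sqrt(d), where d is the size of the reduced dimension, which is how the consistency-model training recipe that uses 0.00054 scales the constant; the library may apply the factor differently.

```python
import math


def pseudo_huber(pred: list[float], target: list[float], factor: float = 0.00054) -> float:
    # Assumption: the constant scales with the square root of the reduced dimension.
    c = factor * math.sqrt(len(pred))
    # Squared L2 distance over the single reduced dimension.
    sq_dist = sum((p - t) ** 2 for p, t in zip(pred, target))
    # Near zero this is about sq_dist / (2 * c); far away it approaches the L2 distance minus c.
    return math.sqrt(sq_dist + c**2) - c
```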

Returns:

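The loss function, as a callable that takes the input tensor or tensors and returns the loss tensor or tensors. Most losses take a (pred, target) pair and return a single tensor; "image-grad" takes a single tensor, "kl-pair" takes four tensors, and the "stft", "multi-stft", "mel", and "mfcc" losses return a tuple of two tensors.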
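Because the returned callables differ in arity and return type, the per-loss signatures above read like `typing.overload` declarations keyed on `Literal` loss names, which lets a type checker infer the right callable type from the string. The toy sketch below shows that pattern with a hypothetical `make_loss` factory operating on floats and lists instead of tensors, plus a hypothetical single-input "grad-1d" loss standing in for "image-grad".

```python
from typing import Callable, Literal, overload


@overload
def make_loss(loss: Literal["l1"]) -> Callable[[float, float], float]: ...


@overload
def make_loss(loss: Literal["grad-1d"]) -> Callable[[list[float]], list[float]]: ...


def make_loss(loss: str) -> Callable[..., object]:
    # The overloads above tell a type checker which callable type each literal
    # name produces; this single runtime implementation does the dispatch.
    if loss == "l1":
        # Two inputs, one output, like most losses on this page.
        return lambda pred, target: abs(pred - target)
    if loss == "grad-1d":
        # Hypothetical one-input loss: forward differences of a 1-D signal.
        return lambda signal: [b - a for a, b in zip(signal, signal[1:])]
    raise ValueError(f"Unsupported loss: {loss!r}")
```

Keeping one runtime function behind many overloads means callers get precise types without a separate factory per loss.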