ml.tasks.losses.loss
Defines a general-purpose API for losses.
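import torch
from ml.tasks.losses.loss import loss_fn
mse_loss = loss_fn("mse")
pred, target = torch.randn(4, 16), torch.randn(4, 16)
loss = mse_loss(pred, target)
assert loss.shape == pred.shape  # The loss is returned element-wise, without reduction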
The following loss functions are supported:
Modality-Agnostic
"bce-logits": Binary cross-entropy loss with logits
"bce": Binary cross-entropy loss
"huber": Huber loss, which is a smoothed version of the L1 loss
"l1": L1 loss
"log-cosh": Log-cosh loss, which is a smoothed version of the L1 loss
"mse": Mean squared error loss
"pseudo-huber": Pseudo-Huber loss
"xent": Cross-entropy loss
Audio
"log-stft-magnitude": Log STFT magnitude loss
"mel": Mel spectrogram loss
"mfcc": Mel-frequency cepstral coefficients (MFCC) loss
"multi-stft": Multi-resolution short-time Fourier transform loss
"spectral-convergence": Spectral convergence loss
"stft": Short-time Fourier transform (STFT) loss
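The "spectral-convergence" and "log-stft-magnitude" terms are commonly defined on STFT magnitude spectrograms, as in the multi-resolution STFT loss of Parallel WaveGAN. The sketch below assumes those definitions and operates on precomputed magnitudes stored as nested lists, so the STFT itself (controlled by fft_size, shift_size, win_length and window_fn) is out of scope. The helper names are hypothetical. Because the "stft" and "multi-stft" signatures return a pair of tensors, it seems plausible that they return these two terms, but that is an assumption.

```python
import math

# A magnitude spectrogram: a list of frames, each a list of non-negative bin magnitudes.
Spectrogram = list[list[float]]


def _pairs(pred_mag: Spectrogram, target_mag: Spectrogram):
    # Yield aligned (predicted, target) magnitudes over every frame and bin.
    for pred_frame, target_frame in zip(pred_mag, target_mag):
        yield from zip(pred_frame, target_frame)


def spectral_convergence(pred_mag: Spectrogram, target_mag: Spectrogram) -> float:
    # Frobenius norm of the magnitude error, relative to the target's Frobenius norm.
    num = sum((t - p) ** 2 for p, t in _pairs(pred_mag, target_mag))
    den = sum(t**2 for frame in target_mag for t in frame)
    return math.sqrt(num) / math.sqrt(den)


def log_stft_magnitude(pred_mag: Spectrogram, target_mag: Spectrogram, eps: float = 1e-7) -> float:
    # Mean absolute difference between log magnitudes; eps keeps log() finite at zero.
    diffs = [abs(math.log(t + eps) - math.log(p + eps)) for p, t in _pairs(pred_mag, target_mag)]
    return sum(diffs) / len(diffs)
```

Spectral convergence emphasizes the large (usually low-frequency) magnitudes, while the log term gives quiet bins comparable weight, which is why the two are typically used together.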
Image
"image-grad": Image gradient loss
"lpips": Learned perceptual image patch similarity (LPIPS) loss
"ssim": Structural similarity index (SSIM) loss
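SSIM compares means, variances and covariance of two images. As a hedged illustration of the formula behind "ssim" (and of what ssim_dynamic_range scales), the sketch below computes SSIM over a single global window on flat lists of pixel values, using the standard stabilizing constants C1 = (0.01 L)^2 and C2 = (0.03 L)^2 for dynamic range L. The library instead slides a kernel of size image_kernel_size with stride ssim_stride over each channel; this sketch omits that, does not model ssim_mode, and assumes the loss is 1 - SSIM. The function names are hypothetical.

```python
from collections.abc import Sequence


def ssim_global(x: Sequence[float], y: Sequence[float], dynamic_range: float = 1.0) -> float:
    # Treat the whole signal as one window (no Gaussian kernel, no stride).
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    # Stabilizing constants from the standard SSIM definition, scaled by the dynamic range L.
    c1 = (0.01 * dynamic_range) ** 2
    c2 = (0.03 * dynamic_range) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


def ssim_loss(x: Sequence[float], y: Sequence[float], dynamic_range: float = 1.0) -> float:
    # Assumed conversion from similarity (1.0 = identical) to a loss (0.0 = identical).
    return 1.0 - ssim_global(x, y, dynamic_range)
```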
KL Divergence
"kl-pair": KL divergence loss between two Gaussian distributions
"kl-single": KL divergence loss between a Gaussian distribution and a standard normal distribution
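Both KL entries have closed forms for diagonal Gaussians. The sketch below assumes each Gaussian is passed as a (mean, log-variance) pair, which fits the two-tensor "kl-single" and four-tensor "kl-pair" signatures, and that clamp_min/clamp_max bound the log-variance; the argument order and parameterization are assumptions, not confirmed by this page. It works on scalars for a single dimension; a tensor version would apply it element-wise and then reduce. The function names are hypothetical.

```python
import math


def kl_single(mean: float, log_var: float, clamp_min: float = -30.0, clamp_max: float = 20.0) -> float:
    # KL(N(mean, exp(log_var)) || N(0, 1)) for one dimension, with the log-variance clamped.
    log_var = min(max(log_var, clamp_min), clamp_max)
    return 0.5 * (mean**2 + math.exp(log_var) - log_var - 1.0)


def kl_pair(
    mean_a: float,
    log_var_a: float,
    mean_b: float,
    log_var_b: float,
    clamp_min: float = -30.0,
    clamp_max: float = 20.0,
) -> float:
    # KL(N(mean_a, var_a) || N(mean_b, var_b)) for one dimension, with both log-variances clamped.
    log_var_a = min(max(log_var_a, clamp_min), clamp_max)
    log_var_b = min(max(log_var_b, clamp_min), clamp_max)
    var_a, var_b = math.exp(log_var_a), math.exp(log_var_b)
    return 0.5 * (log_var_b - log_var_a + (var_a + (mean_a - mean_b) ** 2) / var_b - 1.0)
```

Setting mean_b = 0 and log_var_b = 0 in kl_pair recovers kl_single, which is why the single-distribution case needs only two inputs.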
- ml.tasks.losses.loss.loss_fn(loss: Literal['mse']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['l1']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['huber'], *, huber_beta: float = 1.0) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['pseudo-huber'], *, pseudo_huber_factor: float = 0.00054, dim: int = -1, keepdim: bool = False) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['log-cosh']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['xent']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['bce']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['bce-logits']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['spectral-convergence']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['log-stft-magnitude']) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['stft'], *, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') -> Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
- ml.tasks.losses.loss.loss_fn(loss: Literal['multi-stft'], *, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann', fft_size_multiples: list[float] = [0.5, 1.0, 2.0]) -> Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
- ml.tasks.losses.loss.loss_fn(loss: Literal['mel'], *, sample_rate: int, f_min: float = 0.0, f_max: float | None = None, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, n_mels: int = 80, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') -> Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
- ml.tasks.losses.loss.loss_fn(loss: Literal['mfcc'], *, sample_rate: int, f_min: float = 0.0, f_max: float | None = None, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, n_mels: int = 80, n_mfcc: int = 40, log_mels: bool = False, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann') -> Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]]
- ml.tasks.losses.loss.loss_fn(loss: Literal['ssim'], *, image_kernel_size: int = 3, ssim_stride: int = 1, ssim_channels: int = 3, ssim_mode: Literal['avg', 'std'] = 'avg', image_sigma: float = 1.0, ssim_dynamic_range: float = 1.0) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['image-grad'], *, image_kernel_size: int = 3, image_sigma: float = 1.0) -> Callable[[Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['lpips'], *, pretrained: bool = True, requires_grad: bool = False) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['kl-single'], *, clamp_min: float = -30.0, clamp_max: float = 20.0) -> Callable[[Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['kl-pair'], *, clamp_min: float = -30.0, clamp_max: float = 20.0) -> Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
- ml.tasks.losses.loss.loss_fn(loss: Literal['bce-logits', 'bce', 'huber', 'pseudo-huber', 'l1', 'log-cosh', 'mse', 'xent', 'log-stft-magnitude', 'mel', 'mfcc', 'multi-stft', 'spectral-convergence', 'stft', 'image-grad', 'lpips', 'ssim', 'kl-pair', 'kl-single'], *, huber_beta: float = 1.0, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window_fn: Literal['hann', 'hamming', 'blackman'] = 'hann', fft_size_multiples: list[float] = [0.5, 1.0, 2.0], image_kernel_size: int = 3, ssim_stride: int = 1, ssim_channels: int = 3, ssim_mode: Literal['avg', 'std'] = 'avg', image_sigma: float = 1.0, ssim_dynamic_range: float = 1.0) -> Callable[[Tensor], Tensor] | Callable[[Tensor, Tensor], Tensor] | Callable[[Tensor, Tensor], tuple[torch.Tensor, torch.Tensor]] | Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
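The signatures above appear to be typing overloads: the loss name, given as a string literal, tells a type checker which keyword arguments are accepted and what kind of callable comes back, while a single implementation dispatches at runtime. A minimal, hypothetical sketch of that pattern follows, using a stand-in factory named make_loss that operates on floats and supports only two names; it mirrors the shape of loss_fn, not its implementation.

```python
from typing import Callable, Literal, overload

Loss = Callable[[float, float], float]


@overload
def make_loss(loss: Literal["mse"]) -> Loss: ...
@overload
def make_loss(loss: Literal["huber"], *, huber_beta: float = 1.0) -> Loss: ...


def make_loss(loss: str, *, huber_beta: float = 1.0) -> Loss:
    # Runtime dispatch on the loss name; the @overload stubs above only inform type checkers.
    if loss == "mse":
        return lambda pred, target: (pred - target) ** 2
    if loss == "huber":

        def huber(pred: float, target: float) -> float:
            # Smooth-L1 form with threshold huber_beta, captured from the factory call.
            err = abs(pred - target)
            return 0.5 * err**2 / huber_beta if err < huber_beta else err - 0.5 * huber_beta

        return huber
    raise ValueError(f"Unsupported loss: {loss!r}")
```

With this pattern, a call such as make_loss("mse", huber_beta=0.5) still runs, but a type checker rejects it because no overload pairs "mse" with huber_beta; that is the main benefit of listing one signature per loss name.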
Returns a loss function.
- Parameters:
loss – The name of the loss function to use.
huber_beta – The beta parameter for the Huber loss.
pseudo_huber_factor – The factor parameter for the Pseudo-Huber loss. The default value is taken from the Consistency Models paper.
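sample_rate – The sample rate of the audio, in Hz. Required for the mel spectrogram and MFCC losses.
f_min – The minimum frequency, in Hz, for the mel spectrogram and MFCC losses.
f_max – The maximum frequency, in Hz, for the mel spectrogram and MFCC losses.
fft_size – The size of the FFT.
shift_size – The hop size, in samples, between adjacent STFT frames.
win_length – The length of the analysis window, in samples.
n_mels – The number of mel bins for the mel spectrogram and MFCC losses.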
n_mfcc – The number of MFCCs for the MFCC loss.
log_mels – Whether to use log-mel spectrograms for the MFCC loss.
window_fn – The window function to use.
fft_size_multiples – The multiples of the FFT size to use for the multi-resolution STFT loss.
image_kernel_size – The size of the kernel for the SSIM and image gradient losses.
ssim_stride – The stride of the kernel for the SSIM loss.
ssim_channels – The number of channels for the SSIM loss.
ssim_mode – The mode for the SSIM loss, either "avg" or "std".
image_sigma – The sigma parameter for the SSIM and image gradient losses.
ssim_dynamic_range – The dynamic range parameter for the SSIM loss.
pretrained – Whether to use pretrained weights for the loss function, if the loss function uses a pretrained model.
requires_grad – Whether the loss function's own parameters should require gradients, for loss functions whose parameters can be frozen.
clamp_min – The minimum value to clamp the input to, for the KL divergence losses.
clamp_max – The maximum value to clamp the input to, for the KL divergence losses.
dim – If the loss is only applied over a single dimension, this is the dimension to apply the loss over.
keepdim – If the loss is only applied over a single dimension, whether to keep the dimension.
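The STFT parameters above determine the shape of the spectrograms being compared. The sketch below uses hypothetical helpers to compute that shape, assuming a one-sided STFT with centered frames (as with torch.stft(center=True, onesided=True)), where a window shorter than fft_size is zero-padded. It also assumes that each entry of fft_size_multiples scales fft_size, shift_size and win_length together for the multi-resolution loss; this page does not say whether the shift and window are scaled, so treat that part as a guess.

```python
def stft_shape(num_samples: int, fft_size: int = 1024, shift_size: int = 120) -> tuple[int, int]:
    # (frequency bins, frames) for a one-sided STFT with centered frames.
    num_bins = fft_size // 2 + 1
    num_frames = 1 + num_samples // shift_size
    return num_bins, num_frames


def stft_resolutions(
    fft_size: int = 1024,
    shift_size: int = 120,
    win_length: int = 600,
    fft_size_multiples: tuple[float, ...] = (0.5, 1.0, 2.0),
) -> list[tuple[int, int, int]]:
    # Assumption: each multiple scales all three sizes, giving one (fft, shift, window) per resolution.
    return [
        (int(fft_size * m), int(shift_size * m), int(win_length * m))
        for m in fft_size_multiples
    ]
```

With the defaults, one second of 16 kHz audio gives a 513-bin by 134-frame spectrogram, and the three assumed resolutions trade frequency detail (larger FFTs) against time detail (smaller hops).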
- Returns:
The loss function, as a callable that takes in the input tensor or tensors and returns the loss tensor or tensors.