ml.tasks.losses.audio

Defines loss functions suitable for audio signals.

ml.tasks.losses.audio.get_window(window: Literal['hann', 'hamming', 'blackman'], win_length: int) → Tensor

Gets a window tensor from a function name.

Parameters:
  • window – The window function name.

  • win_length – The window length.

Returns:

The window tensor, with shape (win_length).
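A minimal sketch of equivalent behavior, assuming each name maps onto the corresponding PyTorch window constructor with its default (periodic) settings. The name `get_window_sketch` is hypothetical and this is not the library's own code.

```python
from typing import Callable, Literal

import torch

# Assumed mapping from window names to PyTorch's built-in window constructors.
_WINDOW_FNS: dict[str, Callable[[int], torch.Tensor]] = {
    "hann": torch.hann_window,
    "hamming": torch.hamming_window,
    "blackman": torch.blackman_window,
}


def get_window_sketch(window: Literal["hann", "hamming", "blackman"], win_length: int) -> torch.Tensor:
    """Hypothetical stand-in for get_window: returns a tensor of shape (win_length,)."""
    if window not in _WINDOW_FNS:
        raise ValueError(f"Unsupported window: {window}")
    return _WINDOW_FNS[window](win_length)


w = get_window_sketch("hann", 600)
```

A dictionary lookup keeps the supported names in one place, so adding a window type only touches the mapping.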

ml.tasks.losses.audio.stft(x: Tensor, fft_size: int, hop_size: int, win_length: int, window: Tensor) → Tensor

Performs an STFT and converts the result to a magnitude spectrogram.

Parameters:
  • x – Input signal tensor with shape (B, T).

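  • fft_size – FFT size, i.e. the number of points in each FFT.

  • hop_size – Hop size between successive frames, in samples.

  • win_length – Window length, in samples.

  • window – The window tensor, with shape (win_length), e.g. as returned by get_window.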

Returns:

Magnitude spectrogram with shape (B, num_frames, fft_size // 2 + 1).
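The shape contract above can be reproduced with `torch.stft`. This is a sketch under assumptions: the function name `stft_sketch` is hypothetical, it relies on `torch.stft`'s default `center=True` padding, and the `1e-7` clamp before the square root is an assumed numerical-stability choice, not necessarily what the library does.

```python
import torch


def stft_sketch(
    x: torch.Tensor, fft_size: int, hop_size: int, win_length: int, window: torch.Tensor
) -> torch.Tensor:
    """Hypothetical magnitude STFT: (B, T) -> (B, num_frames, fft_size // 2 + 1)."""
    spec = torch.stft(
        x,
        n_fft=fft_size,
        hop_length=hop_size,
        win_length=win_length,
        window=window,
        return_complex=True,
    )  # complex tensor of shape (B, fft_size // 2 + 1, num_frames)
    # Assumed clamp: keeps the square root (and any later log) away from zero.
    mag = torch.sqrt(torch.clamp(spec.real**2 + spec.imag**2, min=1e-7))
    # Put frames before frequency bins to match the documented output layout.
    return mag.transpose(1, 2)


x = torch.randn(2, 16000)
mag = stft_sketch(x, fft_size=1024, hop_size=120, win_length=600, window=torch.hann_window(600))
```

With `center=True`, the number of frames is `1 + T // hop_size`, so a 16000-sample signal with a hop of 120 yields 134 frames of 513 bins each.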

ml.tasks.losses.audio.spectral_convergence_loss(x_mag: Tensor, y_mag: Tensor, eps: float = 1e-08) → Tensor

Computes the spectral convergence loss.

Parameters:
  • x_mag – Magnitude spectrogram of predicted signal, with shape (B, num_frames, num_freq_bins).

  • y_mag – Magnitude spectrogram of groundtruth signal, with shape (B, num_frames, num_freq_bins).

  • eps – A small value to avoid division by zero.

Returns:

Spectral convergence loss value.
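The spectral convergence loss is the Frobenius norm of the error divided by the Frobenius norm of the target. A minimal sketch, where `spectral_convergence_sketch` is a hypothetical name and the exact placement of `eps` (added to the denominator here) is an assumption:

```python
import torch


def spectral_convergence_sketch(x_mag: torch.Tensor, y_mag: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical ||Y - X||_F / ||Y||_F over all elements of the batch."""
    # vector_norm with default arguments is the 2-norm of the flattened tensor,
    # which equals the Frobenius norm taken over every element.
    return torch.linalg.vector_norm(y_mag - x_mag) / (torch.linalg.vector_norm(y_mag) + eps)


y = torch.ones(1, 2, 2)
x = torch.zeros(1, 2, 2)
loss = spectral_convergence_sketch(x, y)  # ||y - x|| = ||y|| = 2, so the loss is close to 1
```

Because the error is normalized by the target's energy, the value is scale-invariant: scaling both spectrograms by the same factor leaves the loss (almost) unchanged.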

ml.tasks.losses.audio.log_stft_magnitude_loss(x_mag: Tensor, y_mag: Tensor, eps: float = 1e-08) → Tensor

Computes the log STFT magnitude loss.

Parameters:
  • x_mag – Magnitude spectrogram of predicted signal (B, num_frames, num_freq_bins).

  • y_mag – Magnitude spectrogram of groundtruth signal (B, num_frames, num_freq_bins).

  • eps – A small value to avoid log(0).

Returns:

Log STFT magnitude loss value.

Return type:

Tensor
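A minimal sketch of the log STFT magnitude loss as the mean absolute difference of log magnitudes. The name `log_stft_magnitude_sketch` is hypothetical, and adding `eps` inside the log is an assumption about how the library guards against log(0).

```python
import math

import torch
import torch.nn.functional as F


def log_stft_magnitude_sketch(x_mag: torch.Tensor, y_mag: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical mean |log Y - log X| over all elements."""
    # eps keeps log() finite when a magnitude bin is exactly zero.
    return F.l1_loss(torch.log(y_mag + eps), torch.log(x_mag + eps))


x = torch.ones(1, 4, 3)
y = torch.full((1, 4, 3), math.e)
loss = log_stft_magnitude_sketch(x, y)  # log(e) - log(1) = 1 in every bin
```

Working in the log domain means a bin that is off by a factor of e costs the same whether it is loud or quiet, which complements the energy-weighted spectral convergence loss.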

ml.tasks.losses.audio.stft_magnitude_loss(x_mag: Tensor, y_mag: Tensor) → Tensor

Computes the STFT magnitude loss.

Parameters:
  • x_mag – Magnitude spectrogram of predicted signal (B, num_frames, num_freq_bins).

  • y_mag – Magnitude spectrogram of groundtruth signal (B, num_frames, num_freq_bins).

Returns:

STFT magnitude loss value.

Return type:

Tensor
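The page does not state the distance this function uses. One plausible reading, shown here purely as an assumption, is an L1 distance between the linear (non-log) magnitudes; `stft_magnitude_sketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def stft_magnitude_sketch(x_mag: torch.Tensor, y_mag: torch.Tensor) -> torch.Tensor:
    """Hypothetical mean |Y - X| on linear magnitudes (the L1 choice is an assumption)."""
    return F.l1_loss(y_mag, x_mag)


x = torch.ones(1, 4, 3)
y = torch.full((1, 4, 3), 3.0)
loss = stft_magnitude_sketch(x, y)  # |3 - 1| = 2 in every bin
```

Unlike the log variant, this distance is dominated by high-energy bins, since errors are not compressed by a logarithm.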

class ml.tasks.losses.audio.STFTLoss(fft_size: int = 1024, shift_size: int = 120, win_length: int = 600, window: Literal['hann', 'hamming', 'blackman'] = 'hann')

Bases: Module

Defines an STFT loss module.

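This module returns two complementary losses, both comparing the predicted and groundtruth magnitude spectrograms: a spectral convergence loss, which is dominated by high-energy spectral components, and a log STFT magnitude loss, which also weighs low-energy components. The spectral convergence loss is defined as: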

\[L_{spec} = \frac{\|Y - X\|_F}{\|Y\|_F}\]

where \(X\) and \(Y\) are the predicted and groundtruth STFT magnitude spectrograms, respectively, and \(\|\cdot\|_F\) is the Frobenius norm. The log STFT magnitude loss is defined as:

\[L_{mag} = \frac{\|\log Y - \log X\|_1}{N}\]

where \(N\) is the number of elements in \(Y\), so \(L_{mag}\) is the mean absolute difference between the log magnitudes.

Parameters:
  • fft_size – FFT size, i.e. the number of points in each FFT; the spectrogram has fft_size // 2 + 1 frequency bins.

  • shift_size – Hop size between successive frames, in samples.

  • win_length – Window length, in samples.

  • window – Window function type. Choices are hann, hamming and blackman.

Inputs:
  • x – Predicted signal, with shape (B, T).

  • y – Groundtruth signal, with shape (B, T).

Outputs:

Spectral convergence loss value and log STFT magnitude loss value.
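The two formulas above can be combined into a module roughly as follows. This is a sketch, not the library's implementation: `STFTLossSketch` is a hypothetical name, it supports only the Hann window, and the `1e-7` magnitude clamp is an assumed stability measure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class STFTLossSketch(nn.Module):
    """Hypothetical re-implementation of the two STFT losses (Hann window only)."""

    def __init__(self, fft_size: int = 1024, shift_size: int = 120, win_length: int = 600) -> None:
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # A buffer moves with the module across devices but is not a trainable parameter.
        self.register_buffer("window", torch.hann_window(win_length))

    def _magnitude(self, signal: torch.Tensor) -> torch.Tensor:
        spec = torch.stft(
            signal,
            n_fft=self.fft_size,
            hop_length=self.shift_size,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        # (B, freq_bins, frames) -> (B, frames, freq_bins); the clamp keeps log() finite.
        return torch.sqrt(torch.clamp(spec.real**2 + spec.imag**2, min=1e-7)).transpose(1, 2)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        x_mag, y_mag = self._magnitude(x), self._magnitude(y)
        # L_spec: Frobenius norm of the error over the Frobenius norm of the target.
        sc_loss = torch.linalg.vector_norm(y_mag - x_mag) / torch.linalg.vector_norm(y_mag)
        # L_mag: mean absolute difference of log magnitudes.
        mag_loss = F.l1_loss(torch.log(y_mag), torch.log(x_mag))
        return sc_loss, mag_loss


torch.manual_seed(0)
loss_fn = STFTLossSketch()
x = torch.randn(2, 4000)
y = torch.randn(2, 4000)
sc_loss, mag_loss = loss_fn(x, y)
```

The two values are returned separately so the caller can weight them; a common pattern is simply to sum them.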


window: Tensor
forward(x: Tensor, y: Tensor) → tuple[torch.Tensor, torch.Tensor]

Computes the spectral convergence loss and the log STFT magnitude loss between the predicted signal x and the groundtruth signal y.

Note

Call the module instance, e.g. loss_fn(x, y), rather than forward() directly: calling the instance runs any registered hooks, while calling forward() silently skips them.

class ml.tasks.losses.audio.MultiResolutionSTFTLoss(fft_sizes: list[int] = [1024, 2048, 512], hop_sizes: list[int] = [120, 240, 60], win_lengths: list[int] = [600, 1200, 300], window: Literal['hann', 'hamming', 'blackman'] = 'hann', factor_sc: float = 1.0, factor_mag: float = 1.0)

Bases: Module

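Multi-resolution STFT loss module. Computes STFTLoss at several STFT resolutions, one per entry of fft_sizes, hop_sizes and win_lengths, and combines the per-resolution losses.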

Parameters:
  • fft_sizes – FFT sizes, one per resolution.

  • hop_sizes – Hop sizes in samples, one per resolution.

  • win_lengths – Window lengths in samples, one per resolution.

  • window – Window function type. Choices are hann, hamming and blackman.

  • factor_sc – Weight applied to the spectral convergence loss.

  • factor_mag – Weight applied to the log STFT magnitude loss.

Inputs:
  • x – Predicted signal, with shape (B, T).

  • y – Groundtruth signal, with shape (B, T).

Outputs:

Multi-resolution spectral convergence loss value and multi-resolution log STFT magnitude loss value.
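A sketch of one way to combine the resolutions. Averaging over resolutions and then scaling by factor_sc and factor_mag is an assumption (it follows a common multi-resolution STFT loss recipe), as are the names `MultiResolutionSTFTLossSketch` and `_stft_mag` and the Hann-only window.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _stft_mag(x: torch.Tensor, fft_size: int, hop_size: int, win_length: int) -> torch.Tensor:
    """Hypothetical helper: (B, T) -> (B, frames, fft_size // 2 + 1) magnitudes."""
    window = torch.hann_window(win_length, device=x.device)
    spec = torch.stft(
        x, n_fft=fft_size, hop_length=hop_size, win_length=win_length, window=window, return_complex=True
    )
    return torch.sqrt(torch.clamp(spec.real**2 + spec.imag**2, min=1e-7)).transpose(1, 2)


class MultiResolutionSTFTLossSketch(nn.Module):
    def __init__(
        self,
        fft_sizes: tuple[int, ...] = (1024, 2048, 512),
        hop_sizes: tuple[int, ...] = (120, 240, 60),
        win_lengths: tuple[int, ...] = (600, 1200, 300),
        factor_sc: float = 1.0,
        factor_mag: float = 1.0,
    ) -> None:
        super().__init__()
        # Each (fft_size, hop_size, win_length) triple defines one resolution.
        if not len(fft_sizes) == len(hop_sizes) == len(win_lengths):
            raise ValueError("fft_sizes, hop_sizes and win_lengths must have equal length")
        self.resolutions = list(zip(fft_sizes, hop_sizes, win_lengths))
        self.factor_sc = factor_sc
        self.factor_mag = factor_mag

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        sc_loss = x.new_zeros(())
        mag_loss = x.new_zeros(())
        for fft_size, hop_size, win_length in self.resolutions:
            x_mag = _stft_mag(x, fft_size, hop_size, win_length)
            y_mag = _stft_mag(y, fft_size, hop_size, win_length)
            sc_loss = sc_loss + torch.linalg.vector_norm(y_mag - x_mag) / torch.linalg.vector_norm(y_mag)
            mag_loss = mag_loss + F.l1_loss(torch.log(y_mag), torch.log(x_mag))
        # Average over resolutions, then apply the balancing factors.
        n = len(self.resolutions)
        return self.factor_sc * sc_loss / n, self.factor_mag * mag_loss / n


torch.manual_seed(0)
x = torch.randn(2, 8000)
y = torch.randn(2, 8000)
loss_fn = MultiResolutionSTFTLossSketch()
sc_loss, mag_loss = loss_fn(x, y)
```

Small FFTs give good time resolution and large FFTs give good frequency resolution, so summing over several sizes discourages artifacts that a single resolution would miss. Tuples are used as defaults to avoid mutable default arguments.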


forward(x: Tensor, y: Tensor) → tuple[torch.Tensor, torch.Tensor]

Computes the multi-resolution spectral convergence and log STFT magnitude losses between the predicted signal x and the groundtruth signal y.

class ml.tasks.losses.audio.MelLoss(sample_rate: int, n_fft: int = 400, win_length: int | None = None, hop_length: int | None = None, f_min: float = 0.0, f_max: float | None = None, n_mels: int = 80, window: Literal['hann', 'hamming', 'blackman'] = 'hann', power: float = 1.0, normalized: bool = False, eps: float = 1e-07)

Bases: Module

Defines a mel spectrogram loss module.

This module is similar to STFTLoss, but it compares mel spectrograms instead of linear-frequency STFT magnitudes. Because the mel scale approximates human pitch perception, this may be more suitable for speech.

Parameters:
  • sample_rate – Sample rate of the input signal.

  • n_fft – FFT size.

  • win_length – Window length.

  • hop_length – Hop size.

  • f_min – Minimum frequency.

  • f_max – Maximum frequency.

  • n_mels – Number of mel bins.

  • window – Window function name.

  • power – Exponent for the magnitude spectrogram.

  • normalized – Whether to normalize by number of frames.

  • eps – A small value for numerical stability, e.g. to avoid log(0).

Inputs:
  • x – Predicted signal, with shape (B, T).

  • y – Groundtruth signal, with shape (B, T).

Outputs:

Spectral convergence loss value and log mel spectrogram loss value.


forward(x: Tensor, y: Tensor) → tuple[torch.Tensor, torch.Tensor]

Computes the spectral convergence and log mel spectrogram losses between the predicted signal x and the groundtruth signal y.

class ml.tasks.losses.audio.MFCCLoss(sample_rate: int, n_mfcc: int = 40, dct_type: int = 2, norm: str | None = 'ortho', log_mels: bool = False, n_fft: int = 400, win_length: int | None = None, hop_length: int | None = None, f_min: float = 0.0, f_max: float | None = None, n_mels: int = 80, window: Literal['hann', 'hamming', 'blackman'] = 'hann')

Bases: Module

Defines an MFCC loss function.

This is similar to MelLoss, but it compares MFCCs instead of mel spectrograms. MFCCs are obtained by applying a discrete cosine transform to a log-scaled mel spectrogram, a kind of “spectrum of a spectrum”, and are usually used as a compact representation. As a loss function it should behave much like the mel spectrogram loss, although it may be more robust to noise.
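The “spectrum of a spectrum” step is a DCT along the mel axis. The sketch below builds an orthonormal DCT-II matrix (corresponding to dct_type=2 and norm='ortho') and applies it to a log-mel spectrogram. The helpers `dct_matrix` and `mfcc_from_mel` are hypothetical, and using a natural log with an additive eps, rather than for example a decibel scale, is an assumption about how MFCCLoss prepares its input.

```python
import math

import torch


def dct_matrix(n_mfcc: int, n_mels: int) -> torch.Tensor:
    """Hypothetical orthonormal DCT-II basis with shape (n_mels, n_mfcc)."""
    n = torch.arange(n_mels, dtype=torch.float64)
    k = torch.arange(n_mfcc, dtype=torch.float64).unsqueeze(1)
    dct = torch.cos(math.pi / n_mels * (n + 0.5) * k)  # (n_mfcc, n_mels)
    # Orthonormal scaling: the k = 0 row gets an extra 1 / sqrt(2).
    dct[0] *= 1.0 / math.sqrt(2.0)
    dct *= math.sqrt(2.0 / n_mels)
    return dct.t()


def mfcc_from_mel(mel: torch.Tensor, n_mfcc: int = 40, eps: float = 1e-6) -> torch.Tensor:
    """(B, frames, n_mels) mel spectrogram -> (B, frames, n_mfcc) coefficients."""
    log_mel = torch.log(mel + eps)  # assumed log compression before the DCT
    return log_mel @ dct_matrix(n_mfcc, mel.shape[-1]).to(mel.dtype)


mel = torch.rand(2, 10, 80) + 0.1
mfcc = mfcc_from_mel(mel)
```

Keeping only the first n_mfcc coefficients discards fine spectral detail, which is one reason the representation is compact and may be less sensitive to noise.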

Parameters:
  • sample_rate – Sample rate of the input signal.

  • n_mfcc – Number of MFCCs.

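  • dct_type – DCT type; the default, 2, selects the DCT-II.

  • norm – Normalization for the DCT; 'ortho' gives an orthonormal transform, None leaves it unnormalized.

  • log_mels – Whether to use log-mel spectrograms instead of decibel-scaled mel spectrograms.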

  • n_fft – FFT size, for Mel spectrogram.

  • win_length – Window length, for Mel spectrogram.

  • hop_length – Hop size, for Mel spectrogram.

  • f_min – Minimum frequency, for Mel spectrogram.

  • f_max – Maximum frequency, for Mel spectrogram.

  • n_mels – Number of mel bins, for Mel spectrogram.

  • window – Window function name, for Mel spectrogram.

Inputs:
  • x – Predicted signal, with shape (B, T).

  • y – Groundtruth signal, with shape (B, T).

Outputs:

Spectral convergence loss value and log mel spectrogram loss value.


forward(x: Tensor, y: Tensor) → tuple[torch.Tensor, torch.Tensor]

Computes the MFCC-based loss values between the predicted signal x and the groundtruth signal y.