ml.models.quantization.vq

Defines modules for doing vector quantization (or codebook learning).

VQ is a general category of techniques concerned with mapping a continuous distribution to a discrete one. This module learns a set of codebooks which are used to discretize the input.
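As a minimal sketch of what "discretize" means here, assuming Euclidean distance as the matching criterion: each input vector is replaced by the index of its nearest codebook entry (the discrete code), and looking that index back up yields the quantized vector. The function name quantize is hypothetical and not part of this module.

```python
import torch


def quantize(x: torch.Tensor, codebook: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: maps each row of x to its nearest codebook entry.

    x: (num_vectors, dim), codebook: (codebook_size, dim).
    Returns code indices (num_vectors,) and quantized vectors (num_vectors, dim).
    """
    # Euclidean distance from every input vector to every code.
    dists = torch.cdist(x, codebook)
    indices = dists.argmin(dim=-1)  # discrete representation
    return indices, codebook[indices]  # continuous reconstruction


codebook = torch.tensor([[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]])
x = torch.tensor([[0.1, -0.2], [0.9, 1.2], [-0.8, 1.7]])
indices, quantized = quantize(x, codebook)
print(indices.tolist())  # [0, 1, 2]
```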

class ml.models.quantization.vq.VectorQuantization(dim: int, codebook_size: int, codebook_dim: int | None = None, decay: float = 0.99, epsilon: float = 1e-05, kmeans_init: bool = False, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, commitment_weight: float = 1.0)

Bases: Module

Vector quantization implementation.

The codebook is not learned by backpropagation. Instead, it is updated with an exponential moving average (EMA): for each batch, every input vector is assigned to its nearest codebook entry, the mean of the vectors assigned to each entry is computed, and each entry is moved toward that mean by an EMA step.
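The EMA update can be sketched as follows. This assumes an EnCodec-style formulation in which per-code assignment counts and per-code vector sums are tracked as separate EMAs, and epsilon is used to Laplace-smooth the counts so that unused codes do not cause a division by zero. The function and buffer names are hypothetical, not this module's attributes.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def ema_codebook_update(
    x: torch.Tensor,  # (num_vectors, dim) batch of inputs
    codebook: torch.Tensor,  # (codebook_size, dim), updated in place
    cluster_size: torch.Tensor,  # (codebook_size,) EMA of assignment counts (hypothetical buffer)
    embed_avg: torch.Tensor,  # (codebook_size, dim) EMA of per-code vector sums (hypothetical buffer)
    decay: float = 0.99,
    epsilon: float = 1e-5,
) -> torch.Tensor:
    codebook_size = codebook.shape[0]
    # 1. Assign each vector to its nearest code.
    indices = torch.cdist(x, codebook).argmin(dim=-1)
    onehot = F.one_hot(indices, codebook_size).to(x.dtype)  # (num_vectors, codebook_size)

    # 2. EMA of how many vectors each code received, and of their sum.
    cluster_size.mul_(decay).add_(onehot.sum(dim=0), alpha=1 - decay)
    embed_avg.mul_(decay).add_(onehot.t() @ x, alpha=1 - decay)

    # 3. Laplace-smooth the counts (assumed use of epsilon), then move each
    #    code to its EMA cluster mean.
    n = cluster_size.sum()
    smoothed = (cluster_size + epsilon) / (n + codebook_size * epsilon) * n
    codebook.copy_(embed_avg / smoothed.unsqueeze(1))
    return indices


codebook = torch.tensor([[0.0, 0.0], [10.0, 10.0]])
cluster_size = torch.zeros(2)
embed_avg = torch.zeros(2, 2)
x = torch.tensor([[1.0, 1.0], [3.0, 3.0], [9.0, 9.0], [13.0, 13.0]])
indices = ema_codebook_update(x, codebook, cluster_size, embed_avg, decay=0.0)
print(codebook)  # approximately [[2., 2.], [11., 11.]]
```

With decay=0.0, every code that receives at least one vector jumps straight to its batch mean, i.e. one k-means step; with the default decay of 0.99 each code moves only a small fraction of the way per batch, which smooths out noise from individual batches.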

Parameters:
  • dim – Dimension of the input vectors.

  • codebook_size – Codebook size (i.e., the number of codes).

  • codebook_dim – Codebook dimension. If None, defaults to dim.

  • decay – Decay factor for the exponential moving average used to update the codebooks.

  • epsilon – Epsilon value for numerical stability.

  • kmeans_init – Whether to initialize the codebooks with k-means.

  • kmeans_iters – Number of k-means iterations used for initialization.

  • threshold_ema_dead_code – Threshold for dead code expiration. Any code whose exponential moving average cluster size falls below this threshold is replaced with a randomly selected vector from the current batch.

  • commitment_weight – Weight for commitment loss.
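The kmeans_init, kmeans_iters and threshold_ema_dead_code options can be sketched as below, assuming plain Lloyd's k-means for initialization and uniform random sampling from the batch for dead-code replacement. Both functions are hypothetical illustrations, not this module's internals.

```python
import torch


def kmeans_init(samples: torch.Tensor, codebook_size: int, num_iters: int = 50) -> torch.Tensor:
    """Hypothetical: Lloyd's k-means on samples (num_vectors, dim) -> (codebook_size, dim)."""
    num_samples = samples.shape[0]
    # Seed the centroids with randomly chosen input vectors.
    if num_samples >= codebook_size:
        idx = torch.randperm(num_samples)[:codebook_size]
    else:
        idx = torch.randint(0, num_samples, (codebook_size,))
    means = samples[idx].clone()
    for _ in range(num_iters):
        assign = torch.cdist(samples, means).argmin(dim=-1)
        counts = torch.bincount(assign, minlength=codebook_size)
        sums = torch.zeros_like(means).index_add_(0, assign, samples)
        # Clusters that received no vectors keep their previous centroid.
        nonempty = counts > 0
        means[nonempty] = sums[nonempty] / counts[nonempty].unsqueeze(1).to(samples.dtype)
    return means


@torch.no_grad()
def expire_dead_codes(
    codebook: torch.Tensor, cluster_size: torch.Tensor, batch: torch.Tensor, threshold: float = 2.0
) -> torch.Tensor:
    """Hypothetical: replaces codes whose EMA cluster size is below threshold."""
    dead = cluster_size < threshold
    num_dead = int(dead.sum())
    if num_dead > 0:
        # Overwrite each dead code with a random vector from the current batch.
        idx = torch.randint(0, batch.shape[0], (num_dead,))
        codebook[dead] = batch[idx]
    return dead


codebook = torch.zeros(3, 2)
dead = expire_dead_codes(codebook, torch.tensor([5.0, 0.5, 3.0]), torch.ones(4, 2))
print(dead.tolist())  # [False, True, False]
```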


property codebook: Tensor
encode(x: Tensor) → Tensor
decode(embed_ind: Tensor) → Tensor
forward(x: Tensor) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Quantizes x and returns a tuple of three tensors.

Note

Call the module instance (e.g. vq(x)) rather than calling forward() directly: the instance call runs any registered hooks, whereas forward() silently skips them.
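A minimal sketch of how encode, decode and a VQ forward pass typically fit together, assuming the common VQ-VAE recipe: nearest-code lookup, a commitment loss (mean squared error between the input and its detached quantized value, scaled by commitment_weight), and a straight-through estimator so gradients reach the encoder. It further assumes codebook_dim equals dim (no projection) and a (quantized, indices, loss) return order; vq_forward is a hypothetical name and this is not necessarily what this module's forward returns.

```python
import torch
import torch.nn.functional as F


def vq_forward(
    x: torch.Tensor, codebook: torch.Tensor, commitment_weight: float = 1.0
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical. x: (batch_size, seq_len, dim); codebook: (codebook_size, dim)."""
    flat = x.reshape(-1, x.shape[-1])
    # encode: index of the nearest code for every vector.
    indices = torch.cdist(flat, codebook).argmin(dim=-1)
    # decode: look the indices back up in the codebook.
    quantized = F.embedding(indices, codebook).view_as(x)
    # Commitment loss pulls x toward its (fixed) code; no gradient reaches the codebook.
    loss = F.mse_loss(x, quantized.detach()) * commitment_weight
    # Straight-through estimator: the forward value is the quantized vector,
    # but the backward pass treats quantization as the identity.
    quantized = x + (quantized - x).detach()
    return quantized, indices.view(x.shape[:-1]), loss


codebook = torch.randn(512, 128)
x = torch.randn(2, 32, 128, requires_grad=True)
quantized, indices, loss = vq_forward(x, codebook)
quantized.sum().backward()
print(indices.shape)  # torch.Size([2, 32])
```

The straight-through trick is what lets an encoder upstream of the quantizer train at all, since argmin has no useful gradient; the codebook itself is still updated only by the EMA rule.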

class ml.models.quantization.vq.ResidualVectorQuantization(vq_module: VectorQuantization, num_quantizers: int)

Bases: Module

Residual vector quantization implementation.

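This module wraps multiple vector quantization modules and applies them sequentially: each quantizer quantizes the residual left over by the previous ones, its output is subtracted from that residual, and the outputs of all quantizers are summed to form the final quantized tensor.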
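The sequential scheme can be sketched as follows, assuming each stage does a plain nearest-code lookup (a real forward pass would also apply EMA updates, the commitment loss and the straight-through estimator). nearest_code and residual_quantize are hypothetical names.

```python
import torch


def nearest_code(x: torch.Tensor, codebook: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical: nearest code vectors and their indices for each row of x."""
    indices = torch.cdist(x, codebook).argmin(dim=-1)
    return codebook[indices], indices


def residual_quantize(
    x: torch.Tensor, codebooks: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical. x: (num_vectors, dim); one (codebook_size, dim) codebook per quantizer."""
    residual = x
    quantized_out = torch.zeros_like(x)
    all_indices = []
    for codebook in codebooks:
        quantized, indices = nearest_code(residual, codebook)
        residual = residual - quantized  # what this stage failed to capture
        quantized_out = quantized_out + quantized  # running sum is the reconstruction
        all_indices.append(indices)
    return quantized_out, torch.stack(all_indices)  # indices: (num_quantizers, num_vectors)


codebooks = [torch.tensor([[0.0], [10.0]]), torch.tensor([[0.0], [1.0], [2.0]])]
x = torch.tensor([[11.0], [2.0]])
quantized, indices = residual_quantize(x, codebooks)
print(quantized.flatten().tolist(), indices.tolist())  # [11.0, 2.0] [[1, 0], [1, 2]]
```

The first (coarse) stage snaps 11.0 to 10.0, and the second stage encodes the leftover 1.0, so two small codebooks together represent values that neither could alone.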

Parameters:
  • vq_module – Vector quantization module to wrap.

  • num_quantizers – Number of quantizers to use.

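Example:

import torch
from ml.models.quantization.vq import ResidualVectorQuantization, VectorQuantization

vq_module = VectorQuantization(128, 512)  # dim=128, codebook_size=512
rvq_module = ResidualVectorQuantization(vq_module, 4)
x = torch.randn(1, 32, 128)  # (batch_size, seq_len, dim)
quantized, indices, loss, *_ = rvq_module(x)  # forward() is annotated as returning four tensors
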
Input:

x: Tensor of shape (batch_size, seq_len, dim).

Output:

quantized: Tensor of shape (batch_size, seq_len, dim).
indices: Tensor of shape (batch_size, seq_len).
loss: Tensor of shape (codebook_size).


forward(x: Tensor, n_q: int | None = None) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

Quantizes x with the stacked quantizers and returns a tuple of four tensors.

Note

Call the module instance (e.g. rvq_module(x)) rather than calling forward() directly: the instance call runs any registered hooks, whereas forward() silently skips them.

encode(x: Tensor, n_q: int | None = None) → Tensor
decode(q_indices: Tensor) → Tensor