# Source code for ml.models.quantization.vq

"""Defines modules for doing vector quantization (or codebook learning).

VQ is a general category of techniques concerned with mapping a continuous
distribution to a discrete one. This module learns a set of codebooks which
are used to discretize the input.
"""

import copy
from typing import cast

import torch
import torch.distributed
import torch.nn.functional as F
from torch import Tensor, nn

from ml.models.modules import swap_grads


def _ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
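    # In-place update: moving_avg <- decay * moving_avg + (1 - decay) * new.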
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def _laplace_smoothing(x: Tensor, n_categories: int, epsilon: float = 1e-5) -> Tensor:
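    # Additive smoothing: offsets every count by `epsilon` before normalizing,
    # so categories with zero counts never map to exactly zero.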
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def _sample_vectors(samples: Tensor, num: int) -> Tensor:
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)
    return samples[indices]


def _kmeans(samples: Tensor, num_clusters: int, num_iters: int = 10) -> tuple[Tensor, Tensor]:
    dim, dtype = samples.shape[-1], samples.dtype

    means = _sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = samples.unsqueeze(1) - means.unsqueeze(0)
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).repeat(1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
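
# A small illustrative call (hypothetical shapes; not used elsewhere in this file):
#
#     samples = torch.randn(1024, 32)                  # (num_samples, dim)
#     means, bins = _kmeans(samples, num_clusters=16)  # (16, 32) and (16,)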


class _EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Parameters:
        dim: Dimension.
        codebook_size: Codebook size (i.e., the number of codes).
        kmeans_init: Whether to use k-means to initialize the codebooks. If set
            to true, run the k-means algorithm on the first training batch and
            use the learned centroids as initialization.
        kmeans_iters: Number of iterations used for k-means algorithm at
            initialization.
        decay: Decay for exponential moving average over the codebooks.
        epsilon: Epsilon value for numerical stability.
        threshold_ema_dead_code: Threshold for dead code expiration. Replace
            any codes that have an exponential moving average cluster size less
            than the specified threshold with a randomly selected vector from
            the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ) -> None:
        super().__init__()

        self.decay = decay

        embed = torch.empty(codebook_size, dim)
        if not kmeans_init:
            nn.init.kaiming_uniform_(embed)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        self.all_reduce_fn = torch.distributed.all_reduce if torch.distributed.is_initialized() else lambda x: x

        self.register_buffer("inited", Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    inited: Tensor
    cluster_size: Tensor
    embed: Tensor
    embed_avg: Tensor

    @torch.jit.ignore
    def init_embed_(self, data: Tensor) -> None:
        if self.inited:
            return

        embed, cluster_size = _kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(Tensor([True]))

    def replace_(self, samples: Tensor, mask: Tensor) -> None:
        modified_codebook = torch.where(mask[..., None], _sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples: Tensor) -> None:
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = batch_samples.flatten(0, -2)
        self.replace_(batch_samples, mask=expired_codes)

    def preprocess(self, x: Tensor) -> Tensor:
        return x.flatten(0, -2)

    def quantize(self, x: Tensor) -> Tensor:
        embed = self.embed.t()
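        # Negative squared Euclidean distance, expanded as
        # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2, so the argmax below picks
        # the nearest code for each input vector.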
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind: Tensor, shape: torch.Size) -> Tensor:
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind: Tensor) -> Tensor:
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x: Tensor) -> Tensor:
        shape = x.shape
        x = self.preprocess(x)
        embed_ind = self.quantize(x)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind: Tensor) -> Tensor:
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # Expire dead codes at this point, since the buffers are in sync
            # and all workers will make the same decision.
            cluster_size = embed_onehot.sum(0)
            self.all_reduce_fn(cluster_size)
            _ema_inplace(self.cluster_size, cluster_size, self.decay)
            embed_sum = x.t() @ embed_onehot
            self.all_reduce_fn(embed_sum)
            _ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
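            # Laplace smoothing keeps the denominator non-zero for codes that
            # were never selected before renormalizing the EMA sums into means.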
            smoothed_sizes = _laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
            cluster_size = smoothed_sizes * self.cluster_size.sum()
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        return quantize, embed_ind
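
# Shape sketch for ``_EuclideanCodebook`` (illustrative sizes; dim=8, 16 codes):
#
#     codebook = _EuclideanCodebook(dim=8, codebook_size=16)
#     x = torch.randn(4, 10, 8)                      # (..., dim); leading dims are arbitrary
#     quantize, embed_ind = codebook(x)              # (4, 10, 8) and (4, 10)
#     recon = codebook.decode(codebook.encode(x))    # (4, 10, 8)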


class VectorQuantization(nn.Module):
    """Vector quantization implementation.

    The codebook itself doesn't learn any parameters using backpropagation.
    Instead, it uses an exponential moving average to update the codebooks.
    For a given batch, we assign each vector to the nearest codebook item. We
    then take the mean of each cluster and update the codebook towards that
    vector using an exponential moving average.

    Parameters:
        dim: Dimension.
        codebook_size: Codebook size (i.e., the number of codes).
        codebook_dim: Codebook dimension. If not defined, uses the specified
            dimension in dim.
        decay: Decay for exponential moving average over the codebooks.
        epsilon: Epsilon value for numerical stability.
        kmeans_init: Whether to use k-means to initialize the codebooks.
        kmeans_iters: Number of iterations used for k-means initialization.
        threshold_ema_dead_code: Threshold for dead code expiration. Replace
            any codes that have an exponential moving average cluster size less
            than the specified threshold with a randomly selected vector from
            the current batch.
        commitment_weight: Weight for the commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: int | None = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ) -> None:
        super().__init__()

        _codebook_dim = dim if codebook_dim is None else codebook_dim

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = _EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self) -> Tensor:
        return self._codebook.embed

    def encode(self, x: Tensor) -> Tensor:
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind: Tensor) -> Tensor:
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        return quantize

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        device = x.device
        x = self.project_in(x)
        quantize, embed_ind = self._codebook(x)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training and self.commitment_weight > 0:
            commit_loss = F.mse_loss(quantize.detach(), x)
            loss = loss + commit_loss * self.commitment_weight

        # Backpropagates gradients from `quantize` to `x`. Same as
        # ``quantize = x + (quantize - x).detach()`` but uses an autograd
        # function instead. Note that the original `quantize` doesn't require
        # gradients since it is the EMA of the codebook.
        quantize, _ = swap_grads(quantize, x)

        quantize = self.project_out(quantize)
        return quantize, embed_ind, loss
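
# A minimal usage sketch for ``VectorQuantization`` (illustrative names and sizes):
#
#     vq = VectorQuantization(dim=32, codebook_size=256)
#     x = torch.randn(4, 10, 32)            # (batch_size, seq_len, dim)
#     quantized, indices, loss = vq(x)      # (4, 10, 32), (4, 10), (1,)
#     recon = vq.decode(vq.encode(x))       # (4, 10, 32), straight codebook lookup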


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    This module is a wrapper around multiple vector quantization modules. It
    applies the quantization sequentially and adds the quantized output to the
    residual.

    Parameters:
        vq_module: Vector quantization module to wrap.
        num_quantizers: Number of quantizers to use.

    Example::

        vq_module = VectorQuantization(128, 512)
        rvq_module = ResidualVectorQuantization(vq_module, 4)
        x = torch.randn(1, 32, 128)
        quantized, indices, losses, quantized_per_layer = rvq_module(x)

    Input:
        x: Tensor of shape ``(batch_size, seq_len, dim)``.

    Output:
        quantized: Tensor of shape ``(batch_size, seq_len, dim)``.
        indices: Tensor of shape ``(num_quantizers, batch_size, seq_len)``.
        losses: Tensor of shape ``(num_quantizers, 1)``.
        quantized_per_layer: Tensor of shape
            ``(num_quantizers, batch_size, seq_len, dim)``.
    """

    __constants__ = ["codebook_size", "num_quantizers"]

    def __init__(self, vq_module: VectorQuantization, num_quantizers: int) -> None:
        super().__init__()

        self.codebook_size = vq_module.codebook_size
        self.num_quantizers = num_quantizers

        self.layers = cast(
            list[VectorQuantization],
            nn.ModuleList([self._get_copy(vq_module) for _ in range(num_quantizers)]),
        )

    def _get_copy(self, vq_module: VectorQuantization) -> VectorQuantization:
        return copy.deepcopy(vq_module)

    def forward(self, x: Tensor, n_q: int | None = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        quantized_out: Tensor | None = None
        residual = x

        all_losses = []
        all_indices = []
        all_quantized = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized.detach()
            quantized_out = quantized if quantized_out is None else quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)
            all_quantized.append(quantized)

        out_losses = torch.stack(all_losses)
        out_indices = torch.stack(all_indices)
        out_quant = torch.stack(all_quantized)

        assert quantized_out is not None
        return quantized_out, out_indices, out_losses, out_quant

    def encode(self, x: Tensor, n_q: int | None = None) -> Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized.detach()
            all_indices.append(indices)

        out_indices = torch.stack(all_indices, dim=-1)
        return out_indices

    def decode(self, q_indices: Tensor) -> Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)

        for i, indices in enumerate(q_indices.unbind(-1)):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized

        return quantized_out
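
# A residual usage sketch (illustrative; mirrors the Example in the docstring above):
#
#     vq_module = VectorQuantization(128, 512)
#     rvq = ResidualVectorQuantization(vq_module, num_quantizers=4)
#     x = torch.randn(1, 32, 128)
#     codes = rvq.encode(x)      # (1, 32, 4): one code index per quantizer
#     recon = rvq.decode(codes)  # (1, 32, 128): sum of per-quantizer lookups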