"""Provides an implementation of Lookup-Free Quantization (LFQ).

LFQ is from the paper `Language Model Beats Diffusion - Tokenizer is Key to
Visual Generation <https://arxiv.org/abs/2310.05737>`_, which argues that
language models can beat diffusion models at visual generation when they are
paired with a high-quality tokenizer.
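
Example usage, with illustrative shapes and hyperparameters (any sizes that
satisfy the constructor's constraints work the same way)::

    quantizer = LookupFreeQuantization(dim=256, codebook_size=2**10)
    x = torch.randn(8, 16, 256)  # (batch, sequence, features)
    quantized, indices, losses = quantizer(x)
    total_loss = losses.single()  # add to your reconstruction loss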
"""

import math
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor, nn


@dataclass
class Losses:
    per_sample_entropy: Tensor
    batch_entropy: Tensor
    commitment: Tensor

    def single(self) -> Tensor:
        # Collapses the three (already weighted) losses into a single scalar by summing their means.
        return self.per_sample_entropy.mean() + self.batch_entropy.mean() + self.commitment.mean()


def euclidean_distance_squared(x: Tensor, y: Tensor) -> Tensor:
    # Pairwise squared distances, using ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * <x, y>.
    x2, y2 = (x**2).sum(-1), (y**2).sum(-1)
    xy = torch.einsum("... i d, j d -> ... i j", x, y) * -2
    return x2.unsqueeze(-1) + y2 + xy


def entropy(prob: Tensor, eps: float = 1e-20) -> Tensor:
    # Elementwise entropy terms -p * log(p), with probabilities clamped for numerical stability.
    return -prob * prob.clamp_min(eps).log()


class LookupFreeQuantization(nn.Module):
    def __init__(
        self,
        *,
        dim: int | None = None,
        codebook_size: int | None = None,
        entropy_loss_weight: float = 0.1,
        commitment_loss_weight: float = 1.0,
        diversity_gamma: float = 2.5,
        num_codebooks: int = 1,
        codebook_scale: float = 1.0,
    ) -> None:
        super().__init__()

        if dim is None and codebook_size is None:
            raise ValueError("Either `dim` or `codebook_size` must be specified for LFQ.")

        # Gets the default codebook size, or validates it if provided.
        if codebook_size is None:
            assert dim is not None
            codebook_size = 2**dim
        elif not math.log2(codebook_size).is_integer():
            suggested = 2 ** math.ceil(math.log2(codebook_size))
            raise ValueError(f"Your codebook size must be a power of 2 (suggested {suggested})")

        # Computes the per-codebook dimension and the default input dimension.
        codebook_dim = round(math.log2(codebook_size))
        codebook_dims = codebook_dim * num_codebooks
        if dim is None:
            dim = codebook_dims
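
        # As a worked example (values chosen only for illustration): with codebook_size=1024
        # and num_codebooks=2, each codebook uses codebook_dim = log2(1024) = 10 bits, the
        # quantized space has codebook_dims = 10 * 2 = 20 dimensions, and an unspecified
        # `dim` would default to 20.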

        # Projects the input to and from the codebook dimension.
        self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
        self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()

        self.dim = dim
        self.codebook_dim = codebook_dim
        self.num_codebooks = num_codebooks
        self.codebook_scale = codebook_scale

        # Stores loss weights.
        self.diversity_gamma = diversity_gamma
        self.entropy_loss_weight = entropy_loss_weight
        self.commitment_loss_weight = commitment_loss_weight

        # Bit mask used to convert between binary sign patterns and integer indices.
        self.register_buffer("mask", 2 ** torch.arange(codebook_dim - 1, -1, -1), persistent=False)

        # Default loss value to use during inference.
        self.register_buffer("zero", torch.tensor(0.0), persistent=False)

        # Enumerates every code, for converting indices back to codes.
        all_codes = torch.arange(codebook_size)
        bits = ((all_codes[..., None].int() & self.mask) != 0).float()
        codebook = self.bits_to_codes(bits)
        self.register_buffer("codebook", codebook, persistent=False)

    mask: Tensor
    zero: Tensor
    codebook: Tensor
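
    # Maps binary bits {0, 1} to code values {-codebook_scale, +codebook_scale}; for example,
    # bits [1.0, 0.0, 1.0] with codebook_scale=1.0 become codes [1.0, -1.0, 1.0].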
    def bits_to_codes(self, bits: Tensor) -> Tensor:
        return bits * self.codebook_scale * 2 - self.codebook_scale

    @property
    def dtype(self) -> torch.dtype:
        return self.codebook.dtype

    def forward(self, x: Tensor, inv_temperature: float = 1.0) -> tuple[Tensor, Tensor, Losses]:
        if x.shape[-1] != self.dim:
            raise ValueError(f"Expected final dimension to be {self.dim}, but got input shape {x.shape}")

        # Projects to the codebook dimensions.
        x = self.project_in(x)
        x = x.unflatten(-1, (self.num_codebooks, self.codebook_dim))

        # Binary quantization: each dimension is snapped to +/- codebook_scale based on its sign.
        original_input = x
        codebook_value = torch.ones_like(x) * self.codebook_scale
        x_pos = x > 0
        quantized = torch.where(x_pos, codebook_value, -codebook_value)

        # If training, applies the straight-through estimator so gradients flow through the
        # quantization as if it were the identity.
        if self.training:
            x = x - x.detach() + quantized
        else:
            x = quantized
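
        # Each code is identified by reading its binary sign pattern as an integer: with
        # codebook_dim = 3 the mask is [4, 2, 1], so a sign pattern of [+, -, +] maps to
        # index 4 + 0 + 1 = 5 (values here are illustrative).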
        # Computes the unique codebook indices.
        indices = (x_pos.int() * self.mask.int()).sum(-1)
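
        # The per-sample entropy term is minimized so that each input concentrates its probability
        # mass on a single code, while the batch-averaged (codebook) entropy is maximized (via the
        # negative diversity_gamma weight below) so that codes are used uniformly across the batch.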

        # Computes the entropy losses if training (otherwise, just returns zeros).
        if self.training:
            distance = euclidean_distance_squared(original_input, self.codebook)
            prob = (-distance * inv_temperature).softmax(dim=-1)
            per_sample_entropy = entropy(prob).mean()
            avg_prob = prob.flatten(0, -2).mean(dim=0)
            codebook_entropy = entropy(avg_prob).mean()
        else:
            per_sample_entropy = codebook_entropy = self.zero

        # Computes the codebook commitment loss.
        if self.training:
            commit_loss = F.mse_loss(original_input, quantized.detach())
        else:
            commit_loss = self.zero

        # Projects the quantized values back to the original dimension.
        x = x.flatten(-2)
        x = self.project_out(x)

        # Packages the weighted losses.
        losses = Losses(
            per_sample_entropy=per_sample_entropy * self.entropy_loss_weight,
            batch_entropy=codebook_entropy * self.entropy_loss_weight * -self.diversity_gamma,
            commitment=commit_loss * self.commitment_loss_weight,
        )

        return x, indices, losses