Source code for ml.models.kmeans

"""Defines a distributed K-Means module.

This is used to apply K-Means clusters to a tensor. This module can be used
with cluster centers found via Scikit-Learn, Faiss, or other libraries.
"""

import logging
from typing import Callable

import numpy as np
import torch
from torch import Tensor, nn

from ml.utils.triton import supports_triton

logger = logging.getLogger(__name__)


def _vanilla_kmeans(x: Tensor, centers: Tensor, centers_norm: Tensor) -> Tensor:
    # Equivalent code:
    # dist = torch.norm(x[..., None, :] - centers, p=2, dim=-1)
    # return dist.argmin(dim=-1)
    x_norm = (x**2).sum(-1)
    dist = x_norm[..., None] - (2 * (x @ centers.transpose(0, 1))) + centers_norm
    # Absolute value is required here because sometimes the distance
    # can be negative due to numerical instability.
    return dist.abs().argmin(dim=-1)


[docs]def kmeans_fn(cpu: bool) -> Callable[[Tensor, Tensor, Tensor], Tensor]: if cpu or not supports_triton(): return _vanilla_kmeans from ml.utils.triton.kmeans import kmeans as triton_kmeans_fn return triton_kmeans_fn
[docs]class KMeans(nn.Module): __constants__ = ["n_clusters", "n_features"] centers: Tensor centers_norm: Tensor def __init__(self, centers: Tensor | np.ndarray) -> None: super().__init__() n_clusters, n_features = centers.shape self.n_clusters = n_clusters self.n_features = n_features self.register_buffer("centers", torch.empty(n_clusters, n_features), persistent=False) self.register_buffer("centers_norm", torch.empty(n_clusters), persistent=False) self.load_centers(centers) self.kmeans_fn = kmeans_fn(True) self.kmeans_fn_cuda = kmeans_fn(False)
[docs] def load_centers(self, centers: Tensor | np.ndarray) -> None: if isinstance(centers, np.ndarray): centers = torch.from_numpy(centers) assert centers.shape == self.centers.shape, f"Expected shape {self.centers.shape}, got {centers.shape}" self.centers.copy_(centers.to(self.centers)) self.centers_norm.copy_((self.centers**2).sum(-1))
[docs] def forward(self, x: Tensor) -> Tensor: """Applies K-Means to get cluster IDs. We compute ``(x - centers) ^ 2`` by rewriting as ``x ^ 2 - 2 * x * centers + centers ^ 2`` which avoids expanding the tensor when doing the norm. Args: x: The input tensor, with shape ``(*, n_features)`` Returns: The cluster IDs, with shape ``(*)`` """ kmeans_fn = self.kmeans_fn_cuda if x.is_cuda else self.kmeans_fn return kmeans_fn(x, self.centers, self.centers_norm)