ml.models.kmeans

Defines a distributed K-Means module.

This module assigns each vector in a tensor to its nearest K-Means cluster center. It only performs the assignment: the cluster centers are expected to be computed beforehand with Scikit-Learn, Faiss, or other libraries.
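The centers this module consumes are a 2-D array of shape (n_clusters, n_features). The sketch below produces such an array with a few iterations of plain Lloyd's algorithm in NumPy. It is only a stand-in for an external library (scikit-learn's `KMeans(...).fit(data).cluster_centers_` yields an array of the same shape), and `fit_centers` is a hypothetical helper, not part of `ml.models.kmeans`.

```python
import numpy as np


def fit_centers(data: np.ndarray, n_clusters: int, n_iters: int = 20, seed: int = 0) -> np.ndarray:
    """Hypothetical Lloyd's-algorithm K-Means, standing in for scikit-learn or Faiss."""
    data = np.asarray(data, dtype=np.float64)
    rng = np.random.default_rng(seed)
    # Initialize the centers from distinct, randomly chosen data points.
    centers = data[rng.choice(len(data), size=n_clusters, replace=False)]
    for _ in range(n_iters):
        # Squared distance from every point to every center: (n_points, n_clusters).
        dists = ((data[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for k in range(n_clusters):
            members = data[labels == k]
            if len(members) > 0:  # keep the previous center if a cluster empties
                centers[k] = members.mean(axis=0)
    # Shape (n_clusters, n_features), ready to hand to a module like KMeans(centers).
    return centers.astype(np.float32)


rng = np.random.default_rng(1)
data = np.concatenate(
    [rng.normal(loc=m, scale=0.1, size=(50, 2)) for m in ((0.0, 0.0), (5.0, 5.0), (0.0, 5.0))]
)
centers = fit_centers(data, n_clusters=3)
```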

ml.models.kmeans.kmeans_fn(cpu: bool) → Callable[[Tensor, Tensor, Tensor], Tensor]
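The signature says `kmeans_fn` returns a function of three tensors that produces one tensor. A plausible reading, given the `centers` and `centers_norm` attributes below, is `(x, centers, centers_norm) -> cluster IDs`; that argument order, and the meaning of `centers_norm` as the squared center norms, are assumptions. The sketch below is hypothetical: it returns one portable PyTorch implementation regardless of `cpu`, whereas the real flag presumably selects a device-specific variant.

```python
from typing import Callable

import torch
from torch import Tensor


def _assign_reference(x: Tensor, centers: Tensor, centers_norm: Tensor) -> Tensor:
    """Nearest-center IDs for x of shape (*, n_features)."""
    # Assumed shapes: centers (n_clusters, n_features); centers_norm (n_clusters,)
    # holding ||c_k||^2. The ||x||^2 term is the same for every center, so
    # dropping it leaves the argmin unchanged.
    scores = centers_norm - 2 * (x @ centers.t())  # (*, n_clusters)
    return scores.argmin(dim=-1)  # (*)


def kmeans_fn(cpu: bool) -> Callable[[Tensor, Tensor, Tensor], Tensor]:
    """Hypothetical factory mirroring the documented signature."""
    # This sketch ignores `cpu` and returns the same function for either value.
    del cpu
    return _assign_reference


fn = kmeans_fn(cpu=True)
centers = torch.tensor([[0.0, 0.0], [10.0, 10.0]])
x = torch.tensor([[0.0, 0.0], [10.0, 10.0], [0.1, -0.2]])
ids = fn(x, centers, (centers ** 2).sum(dim=-1))  # tensor([0, 1, 0])
```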
class ml.models.kmeans.KMeans(centers: Tensor | ndarray)

Bases: Module

Builds the module from the given cluster centers, which may be a PyTorch tensor or a NumPy array.

centers: Tensor
centers_norm: Tensor
load_centers(centers: Tensor | ndarray) → None
forward(x: Tensor) → Tensor
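A minimal sketch of how a module with this interface could be structured, assuming `centers` and `centers_norm` are registered buffers and `centers_norm` caches the squared norm of each center. `KMeansSketch` is a hypothetical stand-in, not the library's implementation, and the float32 conversion is an assumption of the sketch.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import Tensor, nn


class KMeansSketch(nn.Module):
    """Hypothetical stand-in for ml.models.kmeans.KMeans; not the library's code."""

    centers: Tensor
    centers_norm: Tensor

    def __init__(self, centers: Tensor | np.ndarray) -> None:
        super().__init__()
        # Buffers follow .to(device) and appear in state_dict, but are not trainable.
        self.register_buffer("centers", torch.empty(0))
        self.register_buffer("centers_norm", torch.empty(0))
        self.load_centers(centers)

    def load_centers(self, centers: Tensor | np.ndarray) -> None:
        # Accept NumPy arrays (e.g. scikit-learn's cluster_centers_) or tensors,
        # converted to float32 on the module's current device (assumption).
        c = torch.as_tensor(centers, dtype=torch.float32, device=self.centers.device)
        self.centers = c
        # Cache ||c_k||^2 so forward() does not recompute it on every call.
        self.centers_norm = (c ** 2).sum(dim=-1)

    def forward(self, x: Tensor) -> Tensor:
        # ||x||^2 is shared by all centers, so it is omitted before the argmin.
        scores = self.centers_norm - 2 * (x @ self.centers.t())
        return scores.argmin(dim=-1)


model = KMeansSketch(np.array([[0.0, 0.0], [10.0, 10.0]]))
ids = model(torch.tensor([[1.0, 1.0], [9.0, 8.0]]))  # tensor([0, 1])
```

Because the centers live in buffers rather than parameters, calling `load_centers` again swaps in a new set of centers (even a different number of them) without touching any optimizer state.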

Assigns each input vector to its nearest cluster center and returns the cluster IDs.

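Rather than computing the squared distance ||x - c||^2 to each center c directly, which would require materializing a difference tensor of shape (*, n_clusters, n_features), we expand it as ||x||^2 - 2 * (x · c) + ||c||^2. The cross term for all centers can be computed as one matrix multiplication, so no intermediate is larger than (*, n_clusters).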
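A quick numerical check of the expansion: the direct form builds the full (n_points, n_clusters, n_features) difference tensor, while the expanded form needs only a matrix multiply and two norm vectors. Random float64 data is used purely for illustration; both forms should agree to rounding error and pick the same nearest centers.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1000, 16, dtype=torch.float64)     # (n_points, n_features)
centers = torch.randn(8, 16, dtype=torch.float64)  # (n_clusters, n_features)

# Direct form: materializes a (1000, 8, 16) difference tensor.
direct = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(dim=-1)

# Expanded form: ||x||^2 - 2 x.c + ||c||^2; the largest intermediate is (1000, 8).
expanded = (
    (x ** 2).sum(dim=-1, keepdim=True)  # (1000, 1)
    - 2 * (x @ centers.t())             # (1000, 8)
    + (centers ** 2).sum(dim=-1)        # (8,), broadcast across rows
)

max_err = (direct - expanded).abs().max().item()
same_ids = torch.equal(direct.argmin(dim=-1), expanded.argmin(dim=-1))
```

In floating point the expanded form can lose some precision when x and c are nearly equal (it subtracts large, similar quantities), which rarely matters for an argmin but is worth knowing if the distances themselves are needed.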

Parameters:

x – The input tensor, with shape (*, n_features)

Returns:

The cluster IDs, with shape (*)
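To illustrate the shape contract, (*, n_features) in and (*) out: the matrix multiply broadcasts over any leading dimensions, so a batched input gives the same IDs as flattening, assigning, and reshaping. `assign` below is a hypothetical helper using the same expansion, not the module's own code.

```python
import torch


def assign(x: torch.Tensor, centers: torch.Tensor) -> torch.Tensor:
    # (*, n_features) @ (n_features, n_clusters) -> (*, n_clusters)
    scores = (centers ** 2).sum(dim=-1) - 2 * (x @ centers.t())
    return scores.argmin(dim=-1)  # (*)


torch.manual_seed(0)
centers = torch.randn(8, 3, dtype=torch.float64)
x = torch.randn(4, 5, 3, dtype=torch.float64)  # leading shape (4, 5)

ids = assign(x, centers)                                    # shape (4, 5)
flat_ids = assign(x.reshape(-1, 3), centers).reshape(4, 5)  # same IDs
single = assign(x[0, 0], centers)                           # one vector -> shape ()
```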