ml.models.kmeans
Defines a distributed K-Means module.
It applies precomputed K-Means clusters to a tensor, mapping each input vector to the ID of its nearest cluster center. The cluster centers can be found with Scikit-Learn, Faiss, or other libraries.
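Here is a minimal sketch of that workflow. It assumes scikit-learn and PyTorch are installed. The centers are fitted with scikit-learn's KMeans, its cluster_centers_ array is converted to a tensor, and new points are assigned to their nearest center. For brevity, the assignment step uses torch.cdist rather than this module's forward, so it shows the idea but not the module's own implementation.

```python
import numpy as np
import torch
from sklearn.cluster import KMeans as SkKMeans

# Toy data; float64 throughout so scikit-learn and torch agree numerically.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))

# Fit the cluster centers offline with scikit-learn.
sk_model = SkKMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

# cluster_centers_ is an ndarray of shape (n_clusters, n_features).
centers = torch.from_numpy(sk_model.cluster_centers_)

# Assign each row of X to its nearest center by Euclidean distance.
x = torch.from_numpy(X)
ids = torch.cdist(x, centers).argmin(dim=-1)  # shape (200,)
```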
- class ml.models.kmeans.KMeans(centers: Tensor | ndarray)
Bases: Module
- centers: Tensor
- centers_norm: Tensor
- forward(x: Tensor) → Tensor
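Applies K-Means to get cluster IDs, assigning each input vector to its nearest cluster center.
We compute (x - centers) ^ 2 by rewriting it as x ^ 2 - 2 * x * centers + centers ^ 2, where x * centers is a dot product computed with a single matrix multiplication. This avoids materializing an intermediate tensor of shape (*, n_clusters, n_features) when computing the norm.
- Parameters: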
x – The input tensor, with shape (*, n_features)
- Returns:
The cluster IDs, with shape (*)
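The forward pass can be sketched as follows. This is a hypothetical re-implementation: the class name KMeansSketch is made up, and it assumes that centers has shape (n_clusters, n_features) and that centers_norm holds each center's squared L2 norm. The library's actual code may differ.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import Tensor, nn


class KMeansSketch(nn.Module):
    """Hypothetical sketch of a K-Means assignment module (not the library's code)."""

    def __init__(self, centers: Tensor | np.ndarray) -> None:
        super().__init__()
        centers = torch.as_tensor(centers)  # accepts a Tensor or an ndarray
        # Buffers follow .to(device) and are saved in the state dict.
        self.register_buffer("centers", centers)
        # Assumed meaning of centers_norm: squared L2 norm of each center, shape (n_clusters,).
        self.register_buffer("centers_norm", centers.pow(2).sum(dim=-1))

    def forward(self, x: Tensor) -> Tensor:
        x = x.to(self.centers.dtype)  # match the centers' dtype for the matmul
        # ||x - c||^2 = ||x||^2 - 2 * x . c + ||c||^2 for every (row, center) pair,
        # without building a (*, n_clusters, n_features) difference tensor.
        x_norm = x.pow(2).sum(dim=-1, keepdim=True)  # (*, 1)
        cross = x @ self.centers.t()  # (*, n_clusters)
        dist = x_norm - 2 * cross + self.centers_norm  # (*, n_clusters)
        return dist.argmin(dim=-1)  # (*)


model = KMeansSketch(np.array([[0.0, 0.0], [10.0, 10.0]]))
print(model(torch.tensor([[1.0, 1.0], [9.0, 8.0]])))  # tensor([0, 1])
```

The x ^ 2 term is the same for every center, so it does not change the argmin and could be dropped. It is kept here to mirror the formula above.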