in flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java [271:325]
/**
 * Consumes one buffered local mini-batch together with the latest model data, updates the
 * centroids and weights with a decayed weighted average, and emits the new model data
 * downstream. A no-op until both the model state and at least one mini-batch are present.
 *
 * @throws Exception if reading or updating the operator state fails.
 */
private void alignAndComputeModelData() throws Exception {
    // Both inputs must have arrived before an update round can run.
    if (!modelDataState.get().iterator().hasNext()
            || !localBatchDataState.get().iterator().hasNext()) {
        return;
    }

    KMeansModelData modelData =
            OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
    DenseVector[] centroids = modelData.centroids;
    DenseVector weights = modelData.weights;
    // The model has been consumed for this round; clear it so the next round waits again.
    modelDataState.clear();

    // Precompute norms once so nearest-centroid lookups are cheaper.
    VectorWithNorm[] centroidsWithNorm = new VectorWithNorm[centroids.length];
    for (int idx = 0; idx < centroids.length; idx++) {
        centroidsWithNorm[idx] = new VectorWithNorm(centroids[idx]);
    }

    // Pop the oldest buffered mini-batch and keep the remainder in state.
    List<DenseVector[]> bufferedBatches =
            IteratorUtils.toList(localBatchDataState.get().iterator());
    DenseVector[] batch = bufferedBatches.remove(0);
    localBatchDataState.update(bufferedBatches);

    int dim = centroids[0].size();
    int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();

    // Accumulate, per cluster, the vector sum of assigned points and how many were assigned.
    DenseVector[] sums = new DenseVector[k];
    int[] counts = new int[k];
    for (int clusterId = 0; clusterId < k; clusterId++) {
        sums[clusterId] = new DenseVector(dim);
    }
    for (DenseVector point : batch) {
        int closest =
                distanceMeasure.findClosest(centroidsWithNorm, new VectorWithNorm(point));
        counts[closest]++;
        BLAS.axpy(1.0, point, sums[closest]);
    }

    // Decay the accumulated weights (scaled by parallelism — presumably because each
    // subtask holds a 1/parallelism share of the global weight; confirm against the
    // operator's setup), then blend each touched centroid toward its batch mean.
    BLAS.scal(decayFactor / parallelism, weights);
    for (int clusterId = 0; clusterId < k; clusterId++) {
        if (counts[clusterId] == 0) {
            continue;
        }
        weights.values[clusterId] += counts[clusterId];
        // lambda is the fraction of the updated weight contributed by this mini-batch.
        double lambda = counts[clusterId] / weights.values[clusterId];
        DenseVector centroid = centroids[clusterId];
        BLAS.scal(1.0 - lambda, centroid);
        // lambda / count turns the per-cluster sum into lambda * mean.
        BLAS.axpy(lambda / counts[clusterId], sums[clusterId], centroid);
    }

    output.collect(new StreamRecord<>(new KMeansModelData(centroids, weights)));
}