private void alignAndComputeModelData()

in flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java [271:325]


        /**
         * Runs one online K-Means update step once both a model and a batch of points are
         * available.
         *
         * <p>Returns without doing anything unless {@code modelDataState} and {@code
         * localBatchDataState} each hold at least one element. Otherwise the model is taken out of
         * {@code modelDataState} (which is then cleared), the first buffered batch is removed from
         * {@code localBatchDataState} (the remaining batches are written back), and every point of
         * that batch is attributed to the centroid chosen by {@code distanceMeasure.findClosest}.
         *
         * <p>All weights are then scaled by {@code decayFactor / parallelism}. For each centroid
         * that received at least one point, the point count is added to its weight and the centroid
         * becomes {@code (1 - lambda) * centroid + (lambda / count) * sum}, with {@code lambda =
         * count / weight}. Centroids without points keep their position. The centroid and weight
         * objects obtained from the stored model are the ones passed to the emitted {@link
         * KMeansModelData}.
         *
         * @throws Exception if any of the state accesses or helper calls made here fails.
         */
        private void alignAndComputeModelData() throws Exception {
            // Both inputs are required; the second check is only made when a model is present.
            boolean inputsReady =
                    modelDataState.get().iterator().hasNext()
                            && localBatchDataState.get().iterator().hasNext();
            if (!inputsReady) {
                return;
            }

            // Take the current model out of state.
            KMeansModelData currentModel =
                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
            DenseVector[] centroids = currentModel.centroids;
            DenseVector weights = currentModel.weights;
            VectorWithNorm[] normedCentroids = new VectorWithNorm[centroids.length];
            int slot = 0;
            for (DenseVector rawCentroid : centroids) {
                normedCentroids[slot++] = new VectorWithNorm(rawCentroid);
            }
            modelDataState.clear();

            // Take the oldest buffered batch and keep the rest for later rounds.
            List<DenseVector[]> pendingBatches =
                    IteratorUtils.toList(localBatchDataState.get().iterator());
            DenseVector[] batch = pendingBatches.remove(0);
            localBatchDataState.update(pendingBatches);

            int dim = centroids[0].size();
            int parallelism = getRuntimeContext().getNumberOfParallelSubtasks();

            // Accumulate, per cluster, the number of assigned points and their vector sum.
            int[] clusterSizes = new int[k];
            DenseVector[] clusterSums = new DenseVector[k];
            for (int cluster = 0; cluster < k; cluster++) {
                clusterSums[cluster] = new DenseVector(dim);
            }
            for (DenseVector sample : batch) {
                VectorWithNorm normedSample = new VectorWithNorm(sample);
                int nearest = distanceMeasure.findClosest(normedCentroids, normedSample);
                clusterSizes[nearest] += 1;
                BLAS.axpy(1.0, sample, clusterSums[nearest]);
            }

            // Decay the historical weights, then blend each touched centroid with its new points.
            BLAS.scal(decayFactor / parallelism, weights);
            for (int cluster = 0; cluster < k; cluster++) {
                int size = clusterSizes[cluster];
                if (size != 0) {
                    weights.values[cluster] += size;
                    double lambda = size / weights.values[cluster];

                    DenseVector target = centroids[cluster];
                    BLAS.scal(1.0 - lambda, target);
                    BLAS.axpy(lambda / size, clusterSums[cluster], target);
                }
            }

            KMeansModelData updatedModel = new KMeansModelData(centroids, weights);
            output.collect(new StreamRecord<>(updatedModel));
        }