private CentroidStat getCentroidStat()

in modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java [175:228]


    /**
     * Builds a dataset from the given upstream and, for every center, counts how many points have it as
     * their closest center (as reported by {@code findClosestCentroid}), broken down by label. The set of
     * all labels encountered is collected as well. Per-partition results are combined with
     * {@link CentroidStat#merge}.
     *
     * @param datasetBuilder Dataset builder providing the training data.
     * @param vectorizer Preprocessor used by the partition data builder to produce labeled vectors.
     * @param centers Candidate centers each point is attributed to.
     * @param <K> Type of a key in upstream data.
     * @param <V> Type of a value in upstream data.
     * @return Centroid statistics combined over all partitions.
     * @throws RuntimeException If building, computing over, or closing the dataset fails. Unchecked
     *     exceptions are propagated unchanged; checked exceptions are wrapped.
     */
    private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
                                                Preprocessor<K, V> vectorizer,
                                                List<Vector> centers) {

        PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
            new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

        try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            partDataBuilder,
            learningEnvironment()
        )) {
            return dataset.compute(data -> {
                CentroidStat res = new CentroidStat();

                for (int i = 0; i < data.rowSize(); i++) {
                    int centroidIdx = findClosestCentroid(centers, data.getRow(i)).get1();

                    double lb = data.label(i);

                    // Record the label in the set of all seen labels.
                    res.labels().add(lb);

                    // Per-center histogram: label -> number of points with that label closest to this center.
                    res.centroidStat
                        .computeIfAbsent(centroidIdx, idx -> new ConcurrentHashMap<>())
                        .merge(lb, 1, Integer::sum);

                    // Total number of points closest to this center.
                    res.counts.merge(centroidIdx, 1, Integer::sum);
                }

                return res;
            }, (a, b) -> {
                // Either side may be null; never return null when both are.
                if (a == null)
                    return b == null ? new CentroidStat() : b;

                if (b == null)
                    return a;

                return a.merge(b);
            });
        }
        catch (RuntimeException e) {
            // Do not hide unchecked failures behind an extra wrapper.
            throw e;
        }
        catch (Exception e) {
            // Checked exceptions (e.g. from closing the dataset) are wrapped.
            throw new RuntimeException(e);
        }
    }