in modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java [175:228]
/**
 * Collects label statistics for every centroid over the data supplied by the dataset builder.
 * <p>
 * For each row the closest centroid is looked up via {@code findClosestCentroid}; the row's label is then
 * recorded in three places of the resulting {@link CentroidStat}: the overall label set, the per-centroid
 * label histogram and the per-centroid row counter. Partial results are combined with
 * {@code CentroidStat#merge}.
 *
 * @param datasetBuilder Dataset builder used to build the temporary dataset.
 * @param vectorizer Preprocessor handed to the partition data builder.
 * @param centers Centroids against which every row is matched.
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 * @return Aggregated centroid statistics.
 * @throws RuntimeException If building, computing on or closing the dataset throws any exception
 *      (the original exception is kept as the cause).
 */
private <K, V> CentroidStat getCentroidStat(DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> vectorizer,
    List<Vector> centers) {
    PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<LabeledVector>> partDataBuilder =
        new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);

    // The dataset is only needed for this single computation; try-with-resources closes it afterwards.
    try (Dataset<EmptyContext, LabeledVectorSet<LabeledVector>> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        partDataBuilder,
        learningEnvironment()
    )) {
        return dataset.compute(data -> {
            // Map step (presumably invoked once per partition): accumulate statistics for the given rows.
            CentroidStat partStat = new CentroidStat();

            for (int rowIdx = 0; rowIdx < data.rowSize(); rowIdx++) {
                int centroidIdx = findClosestCentroid(centers, data.getRow(rowIdx)).get1();
                double lb = data.label(rowIdx);

                // Remember every label value that has been seen.
                partStat.labels().add(lb);

                // Per-centroid histogram: label -> number of rows attached to this centroid.
                ConcurrentHashMap<Double, Integer> lbCounts =
                    partStat.centroidStat.computeIfAbsent(centroidIdx, k -> new ConcurrentHashMap<>());

                lbCounts.merge(lb, 1, Integer::sum);

                // Total number of rows attached to this centroid.
                partStat.counts.merge(centroidIdx, 1, (IgniteBiFunction<Integer, Integer, Integer>)Integer::sum);
            }

            return partStat;
        }, (left, right) -> {
            // Reduce step: either side may be null; two nulls yield an empty statistic.
            if (left != null && right != null)
                return left.merge(right);

            if (left != null)
                return left;

            return right != null ? right : new CentroidStat();
        });
    }
    catch (Exception e) {
        throw new RuntimeException(e);
    }
}