in modules/ml-ext/ml/src/main/java/org/apache/ignite/ml/preprocessing/standardscaling/StandardScalerTrainer.java [56:104]
/**
 * Computes the per-feature sums, the per-feature sums of squares and the total row count over the dataset.
 *
 * <p>Every upstream entry is passed through {@code basePreprocessor}. The resulting feature vectors are
 * accumulated per partition, and the partition results are combined with {@code StandardScalerData.merge}.
 * Null partition results are skipped by the reducer.</p>
 *
 * <p>Within a partition, the first row fixes the expected feature count. Any later row with a different
 * size is rejected with {@link IllegalStateException}; the dataset implementation may wrap it.</p>
 *
 * @param envBuilder Learning environment builder used to build the dataset.
 * @param datasetBuilder Builder of the dataset to aggregate over.
 * @param basePreprocessor Preprocessor applied to every upstream entry before aggregation.
 * @return Aggregated sums, squared sums and row count.
 * @throws RuntimeException Wrapping any checked exception raised while building, computing or closing the dataset.
 */
private StandardScalerData computeSum(LearningEnvironmentBuilder envBuilder,
    DatasetBuilder<K, V> datasetBuilder,
    Preprocessor<K, V> basePreprocessor) {
    try (Dataset<EmptyContext, StandardScalerData> dataset = datasetBuilder.build(
        envBuilder,
        (env, upstream, upstreamSize) -> new EmptyContext(),
        (env, upstream, upstreamSize, ctx) -> {
            double[] sum = null;
            double[] squaredSum = null;
            long cnt = 0;

            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entity = upstream.next();
                Vector row = basePreprocessor.apply(entity.getKey(), entity.getValue()).features();
                int size = row.size();

                if (sum == null) {
                    sum = new double[size];
                    squaredSum = new double[size];
                }
                else if (sum.length != size) {
                    // Always-on check (not an assert): the preprocessor is user supplied. With assertions
                    // disabled, a shorter row would silently skip trailing features and corrupt the sums,
                    // and a longer row would fail with a bare ArrayIndexOutOfBoundsException.
                    throw new IllegalStateException("Base preprocessor must return exactly " + sum.length
                        + " features, but returned " + size);
                }

                ++cnt;

                for (int i = 0; i < size; i++) {
                    double x = row.get(i);
                    sum[i] += x;
                    squaredSum[i] += x * x;
                }
            }

            // NOTE(review): an empty partition yields StandardScalerData(null, null, 0). That is a non-null
            // value, so the reducer below passes it to merge() -- confirm merge() tolerates null arrays.
            return new StandardScalerData(sum, squaredSum, cnt);
        }, learningEnvironment(basePreprocessor)
    )) {
        return dataset.compute(data -> data,
            (a, b) -> {
                if (a == null)
                    return b;
                if (b == null)
                    return a;
                return a.merge(b);
            });
    }
    catch (RuntimeException e) {
        // Already unchecked: propagate unchanged instead of burying it under another wrapper.
        throw e;
    }
    catch (Exception e) {
        // Checked exceptions (e.g. from closing the dataset via try-with-resources) are wrapped.
        throw new RuntimeException(e);
    }
}