in flink-ml-lib/src/main/java/org/apache/flink/ml/stats/chisqtest/ChiSqTest.java [543:603]
public Tuple3<Integer, Double, Integer> map(Tuple4<Integer, Double, Double, Long> v) {
    // Computes the Pearson statistic contribution of a single observed cell of a feature's
    // contingency table, plus that table's degrees of freedom.
    //
    // v.f0: feature index, v.f1: feature (category) value, v.f2: label, v.f3: observed count.
    // Returns (feature index, per-cell statistic from pearsonFunc, degrees of freedom).
    if (categoricalMargins.isEmpty()) {
        // The lookup tables are built from the broadcast margins on the first record only.
        initMarginsFromBroadcast();
    }
    Integer index = v.f0;
    // Degrees of freedom: (number of categories of this feature - 1) * (number of labels - 1).
    int dof = (index2NumCategories.get(index) - 1) * (numLabels - 1);
    Tuple2<Integer, Double> category = new Tuple2<>(index, v.f1);
    Tuple2<Integer, Double> indexAndLabelKey = new Tuple2<>(index, v.f2);
    Long theCategoricalMargin = categoricalMargins.get(category);
    Long theLabelMargin = labelMargins.get(indexAndLabelKey);
    Long observed = v.f3;
    // Promote to double before multiplying: the long product of two large margins can
    // overflow and silently produce a wrong (even negative) expected count.
    double expected = theLabelMargin.doubleValue() * theCategoricalMargin / sampleSize;
    double categoricalStatistic = pearsonFunc(observed, expected);
    return new Tuple3<>(index, categoricalStatistic, dof);
}

// Fills categoricalMargins, index2NumCategories, labelMargins, numLabels and sampleSize from
// the two broadcast margin lists. Throws IllegalArgumentException if the label margins are
// empty.
private void initMarginsFromBroadcast() {
    List<Tuple3<Integer, Double, Long>> categoricalMarginList =
            getRuntimeContext().getBroadcastVariable(bcCategoricalMarginsKey);
    List<Tuple3<Integer, Double, Long>> labelMarginList =
            getRuntimeContext().getBroadcastVariable(bcLabelMarginsKey);

    for (Tuple3<Integer, Double, Long> indexAndFeatureAndCount : categoricalMarginList) {
        // One entry per (index, feature value) pair, so counting entries per index yields
        // the number of categories of that feature.
        index2NumCategories.merge(indexAndFeatureAndCount.f0, 1, Integer::sum);
        categoricalMargins.put(
                new Tuple2<>(indexAndFeatureAndCount.f0, indexAndFeatureAndCount.f1),
                indexAndFeatureAndCount.f2);
    }

    // Number of distinct label values across all label margin entries.
    numLabels = (int) labelMarginList.stream().map(x -> x.f1).distinct().count();

    // Only the label counts of the first feature index encountered are summed into the
    // sample size; this presumes every feature index covers the same samples.
    Integer firstIndex = null;
    long firstIndexTotal = 0L;
    for (Tuple3<Integer, Double, Long> indexAndLabelAndCount : labelMarginList) {
        Integer index = indexAndLabelAndCount.f0;
        if (firstIndex == null) {
            firstIndex = index;
        }
        if (firstIndex.equals(index)) {
            firstIndexTotal += indexAndLabelAndCount.f2;
        }
        labelMargins.put(
                new Tuple2<>(index, indexAndLabelAndCount.f1), indexAndLabelAndCount.f2);
    }
    Preconditions.checkArgument(firstIndex != null, "The label margins must not be empty.");
    sampleSize = (double) firstIndexTotal;
}