in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java [317:375]
/**
 * Builds the naive Bayes model from per-label statistics and emits it as one
 * {@link NaiveBayesModelData} record.
 *
 * <p>Every input tuple is treated as one label: {@code f0} is stored as the label value,
 * {@code f1} is used as that label's weight (presumably its document count -- confirm against
 * the upstream aggregation), and {@code f2[j]} maps each category value of feature {@code j}
 * to the weight accumulated for it under that label.
 *
 * <p>The emitted model holds, per label and feature, the smoothed log conditional probability
 * of every category value seen for that feature across all labels, plus the smoothed log prior
 * of each label.
 *
 * @param iterable per-label statistics; must contain at least one element, and all {@code f2}
 *     arrays must have the same length (checked via {@code Preconditions.checkArgument})
 * @param collector receives exactly one {@link NaiveBayesModelData}
 * @throws IndexOutOfBoundsException if {@code iterable} is empty
 */
public void mapPartition(
        Iterable<Tuple3<Double, Integer, Map<Double, Double>[]>> iterable,
        Collector<NaiveBayesModelData> collector) {
    ArrayList<Tuple3<Double, Integer, Map<Double, Double>[]>> list = new ArrayList<>();
    iterable.iterator().forEachRemaining(list::add);

    // NOTE: an empty input fails on the next line with IndexOutOfBoundsException.
    final int featureSize = list.get(0).f2.length;
    for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
        Preconditions.checkArgument(
                featureSize == tup.f2.length, "Feature vectors should be of equal length.");
    }

    // Per feature: total weight over all labels and the set of distinct category values.
    double[] numDocs = new double[featureSize];
    HashSet<Double>[] categoryNumbers = new HashSet[featureSize];
    for (int i = 0; i < featureSize; i++) {
        categoryNumbers[i] = new HashSet<>();
    }
    for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
        for (int i = 0; i < featureSize; i++) {
            numDocs[i] += tup.f1;
            categoryNumbers[i].addAll(tup.f2[i].keySet());
        }
    }

    int[] categoryNumber = new int[featureSize];
    double piLog = 0;
    int numLabels = list.size();
    for (int i = 0; i < featureSize; i++) {
        categoryNumber[i] = categoryNumbers[i].size();
        piLog += numDocs[i];
    }
    // Log of the prior's denominator: total weight summed over features, plus smoothing.
    piLog = Math.log(piLog + numLabels * smoothing);

    Map<Double, Double>[][] theta = new HashMap[numLabels][featureSize];
    double[] piArray = new double[numLabels];
    double[] labels = new double[numLabels];
    // Consider smoothing.
    for (int i = 0; i < numLabels; i++) {
        Tuple3<Double, Integer, Map<Double, Double>[]> labelStats = list.get(i);
        Map<Double, Double>[] param = labelStats.f2;
        // Widen to double once so that no later arithmetic is done in 32-bit int.
        double labelWeight = labelStats.f1;

        for (int j = 0; j < featureSize; j++) {
            Map<Double, Double> squareData = new HashMap<>();
            double thetaLog = Math.log(labelWeight + smoothing * categoryNumber[j]);
            for (Double cate : categoryNumbers[j]) {
                // Categories never seen under this label get weight 0 before smoothing.
                double value = param[j].getOrDefault(cate, 0.0);
                squareData.put(cate, Math.log(value + smoothing) - thetaLog);
            }
            theta[i][j] = squareData;
        }

        labels[i] = labelStats.f0;
        // Multiply in double: the former int product (f1 * featureSize) could overflow
        // and turn negative, making Math.log return NaN for the label prior.
        double weightSum = labelWeight * featureSize;
        piArray[i] = Math.log(weightSum + smoothing) - piLog;
    }

    NaiveBayesModelData modelData =
            new NaiveBayesModelData(theta, Vectors.dense(piArray), Vectors.dense(labels));
    collector.collect(modelData);
}