in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java [155:195]
/**
 * Predicts the label of a single sample by majority vote among its nearest stored samples.
 *
 * <p>The method works in three phases:
 *
 * <ol>
 *   <li>Overwrite {@code distanceVector} in place with one distance per stored sample, computed
 *       as {@code sqrt(|g[i] + ||feature||^2 + featureNormSquares[i]|)}, where {@code g} is
 *       the output of the {@code BLAS.gemv} call.
 *   <li>Keep at most {@code k} candidates with the smallest distances in a bounded heap.
 *   <li>Give every kept candidate one vote for its label and return the label with the most
 *       votes.
 * </ol>
 *
 * <p>Ties between labels are resolved by the iteration order of the vote map, because only a
 * strictly larger vote count replaces the current winner.
 *
 * @param feature the feature vector to classify.
 * @return the label with the highest vote count, or {@code -1.0} if no candidate was collected
 *     (for example when there are no stored labels).
 */
private double predictLabel(DenseVector feature) {
    double featureNorm = BLAS.norm2(feature);
    double normSquare = Math.pow(featureNorm, 2);

    // NOTE(review): presumably this stores -2 * packedFeatures^T * feature into distanceVector
    // (conventional gemv semantics with the transpose flag set) -- confirm against BLAS.gemv.
    BLAS.gemv(-2.0, knnModelData.packedFeatures, true, feature, 0.0, distanceVector);

    double[] distances = distanceVector.values;
    for (int idx = 0; idx < distanceVector.size(); idx++) {
        double squaredDistance =
                distances[idx] + normSquare + knnModelData.featureNormSquares.values[idx];
        // The absolute value keeps the argument of sqrt non-negative.
        distances[idx] = Math.sqrt(Math.abs(squaredDistance));
    }

    // Ordering by the negated distance puts the farthest kept candidate at the head of the
    // queue, so it is the one inspected by peek() and evicted by poll().
    Comparator<Tuple2<Double, Double>> farthestFirst =
            (left, right) -> Double.compare(-left.f0, -right.f0);
    PriorityQueue<Tuple2<Double, Double>> kept = new PriorityQueue<>(farthestFirst);

    double[] labels = knnModelData.labels.values;
    for (int idx = 0; idx < labels.length; ++idx) {
        double distance = distanceVector.get(idx);
        if (kept.size() >= k) {
            // Queue is full: skip this sample unless it is strictly closer than the farthest
            // candidate kept so far.
            if (!(kept.peek().f0 > distance)) {
                continue;
            }
            kept.poll();
        }
        kept.add(Tuple2.of(distance, labels[idx]));
    }

    // Drain the queue, counting one vote per kept candidate for its label.
    Map<Double, Double> votes = new HashMap<>(kept.size());
    for (Tuple2<Double, Double> neighbor = kept.poll();
            neighbor != null;
            neighbor = kept.poll()) {
        votes.merge(neighbor.f1, 1.0, Double::sum);
    }

    double bestLabel = -1.0;
    double bestVotes = 0.0;
    for (Map.Entry<Double, Double> vote : votes.entrySet()) {
        double count = vote.getValue();
        if (count > bestVotes) {
            bestVotes = count;
            bestLabel = vote.getKey();
        }
    }
    return bestLabel;
}