private double predictLabel(DenseVector feature)

in flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java [155:195]


        /**
         * Predicts the label of one input sample by majority vote among its {@code k} nearest
         * training samples under Euclidean distance.
         *
         * <p>Squared distances use the expansion {@code ||x||^2 - 2 x.y + ||y||^2}. A single
         * {@code gemv} call produces the cross term for every training sample. The training-side
         * squared norms are read from {@code knnModelData.featureNormSquares}, presumably
         * precomputed when the model data was built (confirm against the model-data builder). The
         * resulting distances overwrite the {@code distanceVector} buffer in place.
         *
         * @param feature the input sample to classify.
         * @return the label that occurs most often among the {@code k} closest training samples,
         *     or {@code -1.0} when there are no training labels at all. When several labels tie on
         *     vote count, the first one in {@link HashMap} iteration order wins, so ties are
         *     effectively broken arbitrarily.
         */
        private double predictLabel(DenseVector feature) {
            final double querySquaredNorm = Math.pow(BLAS.norm2(feature), 2);

            // distanceVector := -2 * packedFeatures^T * feature, i.e. the cross term.
            // NOTE(review): assumes each column of packedFeatures is one training sample and that
            // distanceVector was preallocated with one slot per training sample — confirm.
            BLAS.gemv(-2.0, knnModelData.packedFeatures, true, feature, 0.0, distanceVector);

            final double[] distances = distanceVector.values;
            final double[] trainingSquaredNorms = knnModelData.featureNormSquares.values;
            for (int i = 0; i < distanceVector.size(); i++) {
                // The expanded form can dip slightly below zero through rounding; Math.abs keeps
                // Math.sqrt from producing NaN in that case.
                final double squaredDistance =
                        distances[i] + querySquaredNorm + trainingSquaredNorms[i];
                distances[i] = Math.sqrt(Math.abs(squaredDistance));
            }

            // Max-heap on distance (via the negated key) holding the k closest (distance, label)
            // pairs seen so far; its head is the farthest of those k.
            PriorityQueue<Tuple2<Double, Double>> closest =
                    new PriorityQueue<>(Comparator.comparingDouble(pair -> -pair.f0));
            final double[] labels = knnModelData.labels.values;
            for (int i = 0; i < labels.length; ++i) {
                final double distance = distanceVector.get(i);
                if (closest.size() >= k) {
                    // Heap is full: admit the candidate only if it is strictly closer than the
                    // current farthest neighbor. Written as !(a > b) rather than a <= b so that a
                    // NaN distance is still rejected.
                    if (!(closest.peek().f0 > distance)) {
                        continue;
                    }
                    closest.poll();
                }
                closest.add(Tuple2.of(distance, labels[i]));
            }

            // One vote per retained neighbor, keyed by its label.
            Map<Double, Double> votes = new HashMap<>(closest.size());
            for (Tuple2<Double, Double> neighbor = closest.poll();
                    neighbor != null;
                    neighbor = closest.poll()) {
                votes.merge(neighbor.f1, 1.0, Double::sum);
            }

            // Strictly-greater comparison keeps the first label reaching the top count.
            double bestLabel = -1.0;
            double bestCount = 0.0;
            for (Map.Entry<Double, Double> vote : votes.entrySet()) {
                final double count = vote.getValue();
                if (count > bestCount) {
                    bestCount = count;
                    bestLabel = vote.getKey();
                }
            }
            return bestLabel;
        }