in flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java [205:286]
/**
 * Runs agglomerative clustering over all rows gathered in the window.
 *
 * <p>Each input row is emitted on the main output with its (remapped, consecutive) cluster id
 * appended. Additionally, one record per applied merge step is emitted to {@code
 * mergeInfoOutputTag} as (smallerClusterId, largerClusterId, mergeDistance, mergedSize); when
 * {@code computeFullTree} is set, every merge of the full dendrogram is reported instead of only
 * the ones applied before the early-stop point.
 *
 * @param context window context used for the merge-info side output.
 * @param values rows of the current window; each must carry a vector under {@code featuresCol}.
 * @param output collector for input rows appended with their cluster id.
 */
public void process(
        ProcessAllWindowFunction<Row, Row, W>.Context context,
        Iterable<Row> values,
        Collector<Row> output) {
    List<Row> rows = IteratorUtils.toList(values.iterator());
    int n = rows.size();
    if (n == 0) {
        return;
    }

    // Pairwise distances between the n data points. The matrix is allocated with
    // 2 * n - 1 slots — presumably so merged clusters created during the NN-chain
    // phase can store distances too (internal to nnChainCore); verify there.
    DistanceMatrix distances = new DistanceMatrix(n * 2 - 1);
    for (int i = 0; i < n; i++) {
        VectorWithNorm left = new VectorWithNorm(rows.get(i).getFieldAs(featuresCol));
        for (int j = i + 1; j < n; j++) {
            VectorWithNorm right = new VectorWithNorm(rows.get(j).getFieldAs(featuresCol));
            distances.set(i, j, distanceMeasure.distance(left, right));
        }
    }

    // Initially every data point is its own active node.
    HashSet<Integer> activeNodes = new HashSet<>(n);
    for (int i = 0; i < n; i++) {
        activeNodes.add(i);
    }

    // Runs the nearest-neighbor-chain algorithm; f0 holds the merge steps, f1 the
    // cluster sizes indexed by cluster id.
    Tuple2<List<Tuple4<Integer, Integer, Integer, Double>>, int[]> chainWithSizes =
            nnChainCore(activeNodes, distances, linkage);
    List<Tuple4<Integer, Integer, Integer, Double>> merges = chainWithSizes.f0;
    merges.sort(Comparator.comparingDouble(m -> m.f3));
    reOrderNnChain(merges);

    // Number of merge steps to actually apply: either all merges whose distance is
    // within the threshold, or as many as needed to reach numCluster clusters.
    int numApplied;
    if (distanceThreshold != null) {
        numApplied = 0;
        for (Tuple4<Integer, Integer, Integer, Double> merge : merges) {
            if (merge.f3 <= distanceThreshold) {
                numApplied++;
            }
        }
    } else {
        numApplied = n - numCluster;
    }

    int[] clusterIds = label(merges.subList(0, numApplied), merges.size() + 1);

    // Remaps cluster ids to consecutive integers in order of first appearance.
    HashMap<Integer, Integer> idMapping = new HashMap<>();
    for (int i = 0; i < clusterIds.length; i++) {
        Integer mapped = idMapping.get(clusterIds[i]);
        if (mapped == null) {
            mapped = idMapping.size();
            idMapping.put(clusterIds[i], mapped);
        }
        clusterIds[i] = mapped;
    }

    // Main output: each row appended with its remapped cluster id.
    for (int i = 0; i < n; i++) {
        output.collect(RowUtils.append(rows.get(i), clusterIds[i]));
    }

    // Side output: one record per reported merge. With computeFullTree the whole
    // dendrogram is reported, not just the merges applied before the early stop.
    int numReported = computeFullTree ? merges.size() : numApplied;
    for (int i = 0; i < numReported; i++) {
        Tuple4<Integer, Integer, Integer, Double> merge = merges.get(i);
        int smallerId = Math.min(merge.f0, merge.f1);
        int largerId = Math.max(merge.f0, merge.f1);
        context.output(
                mergeInfoOutputTag,
                Tuple4.of(
                        smallerId,
                        largerId,
                        merge.f3,
                        chainWithSizes.f1[smallerId] + chainWithSizes.f1[largerId]));
    }
}