in spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java [769:863]
/**
 * Performs a k-nearest-neighbour (KNN) join between {@code queryRDD} and {@code objectRDD}.
 *
 * <p>One of three execution strategies is chosen from {@code broadcastJoin} and the index state
 * of {@code objectRDD}:
 *
 * <ul>
 *   <li><b>Broadcast the query side</b>: {@code broadcastJoin} is set, {@code objectRDD} has a
 *       raw index ({@code indexedRawRDD != null}) and no partitioned index ({@code indexedRDD ==
 *       null}). All query geometries are collected to the driver. Each is wrapped in a {@link
 *       UniqueGeometry} to account for duplicate queries in the join. The list is then broadcast.
 *   <li><b>Broadcast the object side</b>: {@code broadcastJoin} is set and {@code objectRDD} has
 *       neither a raw nor a partitioned index. An STR-tree is built from the object side via
 *       {@code coalesceAndBuildRawIndex(IndexType.RTREE)}, broadcast, and probed from every
 *       partition of {@code queryRDD.rawSpatialRDD}.
 *   <li><b>Partitioned join</b>: every other case. This includes {@code broadcastJoin == true}
 *       when {@code objectRDD.indexedRDD} is set. The two {@code spatialPartitionedRDD}s are
 *       zipped partition by partition. Because zipping requires matching partitioning, the
 *       partition-count check is applied to this path regardless of {@code broadcastJoin}.
 * </ul>
 *
 * @param queryRDD the query side; each of its geometries becomes the key of result pairs
 * @param objectRDD the object (candidate neighbour) side; its geometries become the values
 * @param joinParams supplies {@code k} and the {@code distanceMetric} handed to the judgement
 * @param includeTies forwarded to {@code KnnJoinIndexJudgement}. NOTE(review): presumably keeps
 *     objects tied at the k-th distance; confirm against the judgement implementation.
 * @param broadcastJoin request a broadcast strategy. It is honoured only when the index state of
 *     {@code objectRDD} matches one of the broadcast cases above.
 * @return (query geometry, object geometry) pairs produced by the selected strategy
 * @throws Exception propagated from the CRS / partition-count checks or from index construction.
 *     The exact exception types are not visible here.
 */
public static <U extends Geometry, T extends Geometry> JavaPairRDD<U, T> knnJoin(
    SpatialRDD<U> queryRDD,
    SpatialRDD<T> objectRDD,
    JoinParams joinParams,
    boolean includeTies,
    boolean broadcastJoin)
    throws Exception {
  verifyCRSMatch(queryRDD, objectRDD);

  // Decide the strategy once, before any work is done. Every path that falls back to zipping
  // the partitioned RDDs then gets the partition-count check. This includes broadcastJoin ==
  // true with a partitioned index on the object side, which the original check skipped.
  final boolean broadcastQuerySide =
      broadcastJoin && objectRDD.indexedRawRDD != null && objectRDD.indexedRDD == null;
  final boolean broadcastObjectSide =
      broadcastJoin && objectRDD.indexedRawRDD == null && objectRDD.indexedRDD == null;
  if (!broadcastQuerySide && !broadcastObjectSide) {
    verifyPartitioningNumberMatch(queryRDD, objectRDD);
  }

  SparkContext sparkContext = queryRDD.rawSpatialRDD.context();
  LongAccumulator buildCount = Metrics.createMetric(sparkContext, "buildCount");
  LongAccumulator streamCount = Metrics.createMetric(sparkContext, "streamCount");
  LongAccumulator resultCount = Metrics.createMetric(sparkContext, "resultCount");
  LongAccumulator candidateCount = Metrics.createMetric(sparkContext, "candidateCount");

  final JavaRDD<Pair<U, T>> joinResult;
  if (broadcastQuerySide) {
    // A raw index exists on the object side, so ship the query geometries to it.
    List<U> queryObjects = queryRDD.rawSpatialRDD.collect();
    List<UniqueGeometry<U>> uniqueQueryObjects = new ArrayList<>(queryObjects.size());
    for (U queryObject : queryObjects) {
      // Wrap each query object in a UniqueGeometry to account for duplicate queries in the join.
      uniqueQueryObjects.add(new UniqueGeometry<>(queryObject));
    }
    final Broadcast<List<UniqueGeometry<U>>> broadcastQueryObjects =
        JavaSparkContext.fromSparkContext(sparkContext).broadcast(uniqueQueryObjects);
    final KnnJoinIndexJudgement<UniqueGeometry<U>, T> judgement =
        new KnnJoinIndexJudgement<>(
            joinParams.k,
            joinParams.distanceMetric,
            includeTies,
            broadcastQueryObjects,
            null,
            buildCount,
            streamCount,
            resultCount,
            candidateCount);
    joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, includeTies);
  } else if (broadcastObjectSide) {
    // No index on the object side: build one tree from it and ship it to the query partitions.
    STRtree strTree = objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE);
    final Broadcast<STRtree> broadcastObjectsTreeIndex =
        JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree);
    final KnnJoinIndexJudgement<U, T> judgement =
        new KnnJoinIndexJudgement<>(
            joinParams.k,
            joinParams.distanceMetric,
            includeTies,
            null,
            broadcastObjectsTreeIndex,
            buildCount,
            streamCount,
            resultCount,
            candidateCount);
    // The object shapes are reached through the broadcast tree, so only query partitions are read.
    joinResult = queryRDD.rawSpatialRDD.mapPartitions(judgement::callUsingBroadcastObjectIndex);
  } else {
    // Partitioned (non-broadcast) join: no broadcast index is needed. objectRDD is the right
    // side because the partitions are built on the right side.
    final KnnJoinIndexJudgement<U, T> judgement =
        new KnnJoinIndexJudgement<>(
            joinParams.k,
            joinParams.distanceMetric,
            includeTies,
            null,
            null,
            buildCount,
            streamCount,
            resultCount,
            candidateCount);
    joinResult =
        queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.spatialPartitionedRDD, judgement);
  }

  return joinResult.mapToPair(
      (PairFunction<Pair<U, T>, U, T>) pair -> new Tuple2<>(pair.getKey(), pair.getValue()));
}