public static JavaPairRDD knnJoin()

in spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java [769:863]


  /**
   * Runs a KNN join of {@code queryRDD} against {@code objectRDD}, configured by {@code
   * joinParams.k} and {@code joinParams.distanceMetric}.
   *
   * <p>Exactly one of three execution paths is taken:
   *
   * <ul>
   *   <li><b>Partitioned join</b> - {@code broadcastJoin} is false, or {@code objectRDD.indexedRDD}
   *       is non-null: the {@code spatialPartitionedRDD}s of both sides are zipped partition by
   *       partition.
   *   <li><b>Query-side broadcast</b> - {@code broadcastJoin} is true, {@code
   *       objectRDD.indexedRawRDD} is non-null and {@code objectRDD.indexedRDD} is null: the query
   *       geometries are collected, wrapped in {@code UniqueGeometry} and broadcast.
   *   <li><b>Object-side broadcast</b> - {@code broadcastJoin} is true and both index fields of
   *       {@code objectRDD} are null: the tree returned by {@code
   *       objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE)} is broadcast and used while mapping
   *       the partitions of {@code queryRDD.rawSpatialRDD}.
   * </ul>
   *
   * <p>The CRS of the two inputs is always verified. The partition-count check runs only when
   * {@code broadcastJoin} is false.
   *
   * @param queryRDD the query side of the join
   * @param objectRDD the object side of the join
   * @param joinParams supplies {@code k} and {@code distanceMetric}
   * @param includeTies forwarded unchanged to the join judgement (and to the query-side broadcast
   *     helper)
   * @param broadcastJoin whether a broadcast strategy may be used instead of the partitioned join
   * @return the joined pairs, keyed by query geometry
   * @throws Exception declared by the signature; propagated from the verification and join steps
   */
  public static <U extends Geometry, T extends Geometry> JavaPairRDD<U, T> knnJoin(
      SpatialRDD<U> queryRDD,
      SpatialRDD<T> objectRDD,
      JoinParams joinParams,
      boolean includeTies,
      boolean broadcastJoin)
      throws Exception {
    verifyCRSMatch(queryRDD, objectRDD);
    if (!broadcastJoin) {
      verifyPartitioningNumberMatch(queryRDD, objectRDD);
    }

    final SparkContext sc = queryRDD.rawSpatialRDD.context();
    final LongAccumulator buildMetric = Metrics.createMetric(sc, "buildCount");
    final LongAccumulator streamMetric = Metrics.createMetric(sc, "streamCount");
    final LongAccumulator resultMetric = Metrics.createMetric(sc, "resultCount");
    final LongAccumulator candidateMetric = Metrics.createMetric(sc, "candidateCount");

    // Broadcasting is only used when it was requested AND the object side carries no partitioned
    // index. The raw index then decides which side gets broadcast.
    final boolean useBroadcast = broadcastJoin && objectRDD.indexedRDD == null;
    final boolean broadcastQuerySide = useBroadcast && objectRDD.indexedRawRDD != null;

    final JavaRDD<Pair<U, T>> matches;
    if (!useBroadcast) {
      // Partitioned join. objectRDD is the right-hand side because the partitions are built on
      // the right side.
      // NOTE(review): this path is also reached when broadcastJoin is true but
      // objectRDD.indexedRDD is non-null, in which case the partition-count check above was
      // skipped - confirm that this is intended.
      final KnnJoinIndexJudgement<U, T> partitionedJudgement =
          new KnnJoinIndexJudgement<>(
              joinParams.k,
              joinParams.distanceMetric,
              includeTies,
              null,
              null,
              buildMetric,
              streamMetric,
              resultMetric,
              candidateMetric);
      matches =
          queryRDD.spatialPartitionedRDD.zipPartitions(
              objectRDD.spatialPartitionedRDD, partitionedJudgement);
    } else if (broadcastQuerySide) {
      // A raw index exists on the object side, so the query geometries are shipped to it. Each
      // one is wrapped in a UniqueGeometry to account for duplicate queries in the join.
      final List<U> collectedQueries = queryRDD.rawSpatialRDD.collect();
      final List<UniqueGeometry<U>> wrappedQueries = new ArrayList<>();
      for (U query : collectedQueries) {
        wrappedQueries.add(new UniqueGeometry<>(query));
      }
      final Broadcast<List<UniqueGeometry<U>>> queryBroadcast =
          JavaSparkContext.fromSparkContext(sc).broadcast(wrappedQueries);

      final KnnJoinIndexJudgement<UniqueGeometry<U>, T> querySideJudgement =
          new KnnJoinIndexJudgement<>(
              joinParams.k,
              joinParams.distanceMetric,
              includeTies,
              queryBroadcast,
              null,
              buildMetric,
              streamMetric,
              resultMetric,
              candidateMetric);
      matches = querySideBroadcastKNNJoin(objectRDD, joinParams, querySideJudgement, includeTies);
    } else {
      // No index of either kind on the object side: build a tree over objectRDD and ship it to
      // the query side.
      final STRtree objectTree = objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE);
      final Broadcast<STRtree> treeBroadcast =
          JavaSparkContext.fromSparkContext(sc).broadcast(objectTree);

      final KnnJoinIndexJudgement<U, T> objectSideJudgement =
          new KnnJoinIndexJudgement<>(
              joinParams.k,
              joinParams.distanceMetric,
              includeTies,
              null,
              treeBroadcast,
              buildMetric,
              streamMetric,
              resultMetric,
              candidateMetric);
      // Only the query geometries are streamed; the object shapes come from the broadcast tree.
      matches =
          queryRDD.rawSpatialRDD.mapPartitions(objectSideJudgement::callUsingBroadcastObjectIndex);
    }

    // Convert each Pair into a Tuple2 so the result can be exposed as a JavaPairRDD.
    final PairFunction<Pair<U, T>, U, T> toTuple =
        pair -> new Tuple2<>(pair.getKey(), pair.getValue());
    return matches.mapToPair(toTuple);
  }