override protected def doExecute()

in spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala [44:144]


  override protected def doExecute(): RDD[InternalRow] = {
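    // Bind the geometry expressions to each child's output so they can be evaluated against incoming rows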
    val boundLeftShape = BindReferences.bindReference(leftShape, left.output)
    val boundRightShape = BindReferences.bindReference(rightShape, right.output)

    val leftResultsRaw = left.execute().asInstanceOf[RDD[UnsafeRow]]
    val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]

    val sedonaConf = SedonaConf.fromActiveSession

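    // Convert both row RDDs into SpatialRDDs by evaluating the bound shape expressions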
    val (leftShapes, rightShapes) =
      toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, boundRightShape)

    // Run analyze() on the dominant side's SpatialRDD only when the user has not
    // supplied an approximate total count for it
    if (sedonaConf.getJoinApproximateTotalCount == -1) {
      if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
        leftShapes.analyze()
      } else {
        rightShapes.analyze()
      }
    }
    log.info("[SedonaSQL] Number of partitions on the left: " + leftResultsRaw.partitions.size)
    log.info("[SedonaSQL] Number of partitions on the right: " + rightResultsRaw.partitions.size)

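    // Pick the number of spatial partitions: honor the configured fallback if set,
    // otherwise derive it from the partition counts and the dominant side's approximate row count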
    var numPartitions = -1
    try {
      if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
        if (sedonaConf.getFallbackPartitionNum != -1) {
          numPartitions = sedonaConf.getFallbackPartitionNum
        } else {
          numPartitions = joinPartitionNumOptimizer(
            leftShapes.rawSpatialRDD.partitions.size(),
            rightShapes.rawSpatialRDD.partitions.size(),
            leftShapes.approximateTotalCount)
        }
        doSpatialPartitioning(leftShapes, rightShapes, numPartitions, sedonaConf)
      } else {
        if (sedonaConf.getFallbackPartitionNum != -1) {
          numPartitions = sedonaConf.getFallbackPartitionNum
        } else {
          numPartitions = joinPartitionNumOptimizer(
            rightShapes.rawSpatialRDD.partitions.size(),
            leftShapes.rawSpatialRDD.partitions.size(),
            rightShapes.approximateTotalCount)
        }
        doSpatialPartitioning(rightShapes, leftShapes, numPartitions, sedonaConf)
      }
    } catch {
      case e: IllegalArgumentException => {
        log.warn(e.getMessage)
        // The requested partition number is invalid;
        // fall back to the partition number specified in SedonaConf
        numPartitions = sedonaConf.getFallbackPartitionNum
        if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
          doSpatialPartitioning(leftShapes, rightShapes, numPartitions, sedonaConf)
        } else {
          doSpatialPartitioning(rightShapes, leftShapes, numPartitions, sedonaConf)
        }
      }
    }

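    // Bundle the per-partition join settings: index usage, spatial predicate, index type, and build side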
    val joinParams = new JoinParams(
      sedonaConf.getUseIndex,
      spatialPredicate,
      sedonaConf.getIndexType,
      sedonaConf.getJoinBuildSide)

    // logInfo(s"leftShape count ${leftShapes.spatialPartitionedRDD.count()}")
    // logInfo(s"rightShape count ${rightShapes.spatialPartitionedRDD.count()}")

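    // Run the partition-level spatial join; each match is a (left geometry, right geometry) pair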
    val matchesRDD: RDD[(Geometry, Geometry)] =
      (leftShapes.spatialPartitionedRDD, rightShapes.spatialPartitionedRDD) match {
        case (null, null) =>
          // The dominant side is empty, so partitioned RDDs were never created;
          // the join result must also be empty.
          sparkContext.parallelize(Seq[(Geometry, Geometry)]())
        case _ => JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams).rdd
      }

    logDebug(s"Join result has ${matchesRDD.count()} rows")

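    // Final step: turn matched geometry pairs back into SQL rows and apply any leftover predicate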
    matchesRDD.mapPartitions { iter =>
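      // Create the codegen row joiner once per partition, on the executor
      // (generated classes cannot be shipped from the driver)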
      val joinRow = {
        val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
        (l: UnsafeRow, r: UnsafeRow) => joiner.join(l, r)
      }

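      // Recover the original rows stashed in each geometry's user data and concatenate them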
      val joined = iter.map { case (l, r) =>
        val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
        val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
        joinRow(leftRow, rightRow)
      }

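      // Apply any residual non-spatial join condition; rows that fail it are dropped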
      extraCondition match {
        case Some(condition) =>
          val boundCondition = Predicate.create(condition, output)
          joined.filter(row => boundCondition.eval(row))
        case None => joined
      }
    }
  }
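
For reference, every knob consulted above comes from SedonaConf, which is read from the
active SparkSession. The sketch below shows how those settings might be tuned before a
spatial join runs; the config key names are assumptions recalled from SedonaConf and may
differ between Sedona releases, so verify them against the version in use.

  // Minimal tuning sketch -- the key names below are assumptions, not verified API.
  import org.apache.spark.sql.SparkSession

  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("sedona-join-tuning")
    .getOrCreate()

  // Choose which side drives spatial partitioning (the "dominant" side)
  spark.conf.set("sedona.join.spatitionside", "left")

  // Supplying an approximate count for the dominant side lets the planner skip analyze()
  spark.conf.set("sedona.join.approxcount", "1000000")

  // Fallback partition number, used when the optimized partition count is rejected
  spark.conf.set("sedona.join.numpartition", "200")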