private def planBroadcastJoin()

in spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala [696:880]


  /**
   * Plans a broadcast-based spatial join for the given join relations.
   *
   * The side to broadcast is derived from `joinType` together with `broadcastLeft` /
   * `broadcastRight`. When no side qualifies, `Nil` is returned. For
   * `SpatialPredicate.KNN` a broadcast KNN join exec is produced; for every other predicate a
   * [[BroadcastIndexJoinExec]] is produced whose broadcast side is wrapped in a
   * [[SpatialIndexExec]].
   *
   * @param left logical plan of the left join relation
   * @param right logical plan of the right join relation
   * @param children shape expressions of the spatial predicate; the first two elements are used
   * @param joinType type of the join being planned
   * @param spatialPredicate spatial predicate of the join; `KNN` selects the KNN path
   * @param indexType index type forwarded to [[SpatialIndexExec]]
   * @param broadcastLeft whether the left relation may be broadcast
   * @param broadcastRight whether the right relation may be broadcast
   * @param isGeography forwarded to the created execs; also selects the label used in logs
   * @param extraCondition extra join condition forwarded to [[BroadcastIndexJoinExec]]
   * @param distance for KNN joins the expression evaluated as `k`; otherwise the optional
   *   distance expression, which is attached either to the indexed side or to the stream side
   * @return a single-element plan list, or `Nil` when no side can be broadcast for `joinType`
   *   or when the shape expressions of a non-KNN join cannot be matched to the join relations
   * @throws IllegalArgumentException if `k` is smaller than 1, or if the distance expression
   *   cannot be matched to a join side
   */
  private def planBroadcastJoin(
      left: LogicalPlan,
      right: LogicalPlan,
      children: Seq[Expression],
      joinType: JoinType,
      spatialPredicate: SpatialPredicate,
      indexType: IndexType,
      broadcastLeft: Boolean,
      broadcastRight: Boolean,
      isGeography: Boolean,
      extraCondition: Option[Expression],
      distance: Option[Expression]): Seq[SparkPlan] = {

    // Only these (join type, broadcastable side) combinations are planned as broadcast joins.
    val broadcastSide = joinType match {
      case Inner if broadcastLeft => Some(LeftSide)
      case Inner if broadcastRight => Some(RightSide)
      case LeftSemi if broadcastRight => Some(RightSide)
      case LeftAnti if broadcastRight => Some(RightSide)
      case LeftOuter if broadcastRight => Some(RightSide)
      case RightOuter if broadcastLeft => Some(LeftSide)
      case _ => None
    }

    broadcastSide match {
      case None =>
        Nil

      case Some(indexedSide) if spatialPredicate == SpatialPredicate.KNN =>
        // NOTE(review): assumes `distance` always carries k for KNN predicates; `.get` throws
        // NoSuchElementException otherwise - confirm against the caller.
        val kExpr = distance.get

        // validate the k value for KNN join
        val kValue: Int = kExpr.eval().asInstanceOf[Int]
        require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")

        val leftShape = children.head
        val rightShape = children.tail.head

        // Query side is the left relation unless the shapes are swapped relative to the
        // relations. `None` means the shapes could not be matched to the relations at all.
        val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right).map {
          case (_, _, swapped) => if (swapped) RightSide else LeftSide
        }
        // When the query side is unknown, `left` is used as the object side plan.
        val objectSidePlan = if (querySide.contains(LeftSide)) right else left

        checkObjectPlanFilterPushdown(objectSidePlan)

        // NOTE(review): both KNN execs are built with `extraCondition = None`, so this method's
        // `extraCondition` argument is not forwarded on the KNN path - confirm this is intended.
        if (querySide.contains(indexedSide)) {
          // broadcast is on query side
          BroadcastQuerySideKNNJoinExec(
            planLater(left),
            planLater(right),
            leftShape,
            rightShape,
            indexedSide,
            joinType,
            k = kExpr,
            useApproximate = false,
            spatialPredicate,
            isGeography,
            condition = null,
            extraCondition = None) :: Nil
        } else {
          // broadcast is on object side (also taken when the query side is unknown)
          BroadcastObjectSideKNNJoinExec(
            planLater(left),
            planLater(right),
            leftShape,
            rightShape,
            indexedSide,
            joinType,
            k = kExpr,
            useApproximate = false,
            spatialPredicate,
            isGeography,
            condition = null,
            extraCondition = None) :: Nil
        }

      case Some(indexedSide) =>
        val a = children.head
        val b = children.tail.head
        val isRasterPredicate =
          a.dataType.isInstanceOf[RasterUDT] || b.dataType.isInstanceOf[RasterUDT]

        // Human-readable name of the join relationship; used for logging only.
        val relationship =
          (distance, spatialPredicate, isGeography, extraCondition, isRasterPredicate) match {
            case (Some(_), SpatialPredicate.INTERSECTS, false, Some(ST_DWithin(Seq(_*))), false) =>
              "ST_DWithin"
            case (Some(_), SpatialPredicate.INTERSECTS, false, _, false) => "ST_Distance <="
            case (Some(_), _, false, _, false) => "ST_Distance <"
            case (Some(_), SpatialPredicate.INTERSECTS, true, Some(ST_DWithin(Seq(_*))), false) =>
              "ST_DWithin(useSpheroid = true)"
            case (Some(_), SpatialPredicate.INTERSECTS, true, _, false) =>
              "ST_Distance (Geography) <="
            case (Some(_), _, true, _, false) => "ST_Distance (Geography) <"
            case (None, _, false, _, false) => s"ST_$spatialPredicate"
            case (None, _, false, _, true) => s"RS_$spatialPredicate"
            // Every remaining combination gets a generic label, so that computing a value that
            // is only used in log messages can never abort planning with a MatchError.
            case (_, _, _, _, true) => s"RS_$spatialPredicate"
            case _ => s"ST_$spatialPredicate"
          }

        // The distance expression goes to the indexed (broadcast) side when it is bound to that
        // side or references no attributes at all; otherwise it is evaluated on the stream side.
        val (distanceOnIndexSide, distanceOnStreamSide) = distance
          .map { distanceExpr =>
            matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
              case Some(side) =>
                if (indexedSide == side || distanceExpr.references.isEmpty) {
                  (Some(distanceExpr), None)
                } else {
                  (None, Some(distanceExpr))
                }
              case _ =>
                throw new IllegalArgumentException(
                  "Distance expression must be bound to one side of the join")
            }
          }
          .getOrElse((None, None))

        matchExpressionsToPlans(a, b, left, right) match {
          case Some((_, _, swapped)) =>
            logInfo(s"Planning spatial join for $relationship relationship")

            // Not swapped: `a` is the shape used for the left relation and the windows are on
            // the left. Swapped: `b` is used for the left relation and the windows are on the
            // right.
            val (leftShape, rightShape) = if (swapped) (b, a) else (a, b)
            val windowSide = if (swapped) RightSide else LeftSide

            // Wraps the broadcast relation so that it is indexed on its own shape expression.
            def indexed(plan: LogicalPlan, shape: Expression) =
              SpatialIndexExec(
                planLater(plan),
                shape,
                indexType,
                isRasterPredicate,
                isGeography,
                distanceOnIndexSide)

            // The non-broadcast relation is streamed and probed with its own shape expression.
            val (leftPlan, rightPlan, streamShape) = indexedSide match {
              case LeftSide => (indexed(left, leftShape), planLater(right), rightShape)
              case RightSide => (planLater(left), indexed(right, rightShape), leftShape)
            }

            BroadcastIndexJoinExec(
              leftPlan,
              rightPlan,
              streamShape,
              indexedSide,
              windowSide,
              joinType,
              spatialPredicate,
              extraCondition,
              distanceOnStreamSide) :: Nil

          case None =>
            logInfo(
              s"Spatial join for $relationship with arguments not aligned " +
                "with join relations is not supported")
            Nil
        }
    }
  }