def apply()

in spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala [179:506]


  /**
   * Plans a Sedona physical join for a logical [[Join]] whose condition contains an
   * optimizable spatial, raster, distance or KNN predicate.
   *
   * The method works in three steps:
   *  1. decide which side (if any) is broadcast, from join hints or, for inner joins only,
   *     from `canAutoBroadcastBySize`;
   *  2. detect the join predicate and normalise it into a `JoinQueryDetection`;
   *  3. dispatch to `planBroadcastJoin`, `planSpatialJoin`, `planKNNJoin` or
   *     `planDistanceJoin`.
   *
   * @param plan
   *   logical plan node offered to this strategy
   * @return
   *   the plan(s) returned by the selected planner, or `Nil` when the node is not a join
   *   accepted by `optimizationEnabled` or no supported predicate is found in its condition
   */
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint))
        if optimizationEnabled(left, right, condition) =>
      val hintedLeft = leftHint.exists(_.strategy.contains(BROADCAST))
      val hintedRight = rightHint.exists(_.strategy.contains(BROADCAST))

      /*
       * If either side is small we can automatically broadcast just like Spark does.
       * This only applies to inner joins as there is no optimized fallback plan for other join
       * types. It's better that users are explicit about broadcasting for other join types than
       * seeing wildly different behavior depending on data size.
       * An explicit hint on either side always wins over size-based selection.
       */
      val (broadcastLeft, broadcastRight) =
        if (hintedLeft || hintedRight || joinType != Inner) {
          (hintedLeft, hintedRight)
        } else {
          val leftIsSmallEnough = canAutoBroadcastBySize(left)
          val rightIsSmallEnough = canAutoBroadcastBySize(right)
          if (leftIsSmallEnough && rightIsSmallEnough) {
            // Both sides can be broadcast. Choose the smallest side (left wins ties).
            val preferLeft = left.stats.sizeInBytes <= right.stats.sizeInBytes
            (preferLeft, !preferLeft)
          } else {
            (leftIsSmallEnough, rightIsSmallEnough)
          }
        }

      // Check if the filters in the plans are supported
      checkPlanFilters(left)
      checkPlanFilters(right)

      val joinConditionMatcher = OptimizableJoinCondition(left, right)
      val queryDetection: Option[JoinQueryDetection] = condition.flatMap {
        case joinConditionMatcher(predicate, extraCondition) =>
          predicate match {
            case pred: ST_Predicate =>
              getJoinDetection(left, right, pred, extraCondition)
            case pred: RS_Predicate =>
              getRasterJoinDetection(left, right, pred, extraCondition)
            case other =>
              // Every remaining supported predicate differs only in the spatial predicate kind
              // and the geography flag, so each case yields
              //   (leftShape, rightShape, spatialPredicate, isGeography, distanceOrK)
              // and the detection is assembled once below.
              val distanceOrKnnMatch = other match {
                // ST_DWithin; the optional 4th argument is evaluated eagerly at planning time.
                case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) =>
                  val spheroid = useSpheroid.eval().asInstanceOf[Boolean]
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, spheroid, distance))

                // ST_Distance
                case LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case LessThan(ST_Distance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))

                // ST_DistanceSphere
                case LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, true, distance))
                case LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, true, distance))

                // ST_DistanceSpheroid
                case LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, true, distance))
                case LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, true, distance))

                // ST_HausdorffDistance, with or without the density fraction argument
                case LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case LessThanOrEqual(
                      ST_HausdorffDistance(Seq(leftShape, rightShape, _)),
                      distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape, _)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))

                // ST_FrechetDistance
                case LessThanOrEqual(ST_FrechetDistance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))
                case LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)), distance) =>
                  Some((leftShape, rightShape, SpatialPredicate.INTERSECTS, false, distance))

                // ST_KNN; `k` travels in the slot used for the distance by the other cases.
                case ST_KNN(Seq(leftShape, rightShape, k)) =>
                  Some((leftShape, rightShape, SpatialPredicate.KNN, false, k))
                case ST_KNN(Seq(leftShape, rightShape, k, useSpheroid)) =>
                  val spheroid = useSpheroid.eval().asInstanceOf[Boolean]
                  Some((leftShape, rightShape, SpatialPredicate.KNN, spheroid, k))

                case _ => None
              }

              // For distance joins we execute the actual predicate (condition) and not only
              // extraConditions, hence the whole `condition` is passed as the extra condition.
              distanceOrKnnMatch.map {
                case (leftShape, rightShape, spatialPredicate, isGeography, distanceOrK) =>
                  JoinQueryDetection(
                    left,
                    right,
                    leftShape,
                    rightShape,
                    spatialPredicate,
                    isGeography,
                    condition,
                    Some(distanceOrK))
              }
          }
        case _ => None
      }

      val sedonaConf = new SedonaConf(sparkSession.conf)
      // Broadcast planning is only used when a side is broadcast AND index usage is enabled.
      val useBroadcastJoin = (broadcastLeft || broadcastRight) && sedonaConf.getUseIndex

      // The plans and shapes are taken from the detection rather than from the outer pattern,
      // because they are whatever getJoinDetection / getRasterJoinDetection returned.
      queryDetection match {
        case Some(
              JoinQueryDetection(
                detectedLeft,
                detectedRight,
                leftShape,
                rightShape,
                spatialPredicate,
                isGeography,
                extraCondition,
                distance)) if useBroadcastJoin =>
          planBroadcastJoin(
            detectedLeft,
            detectedRight,
            Seq(leftShape, rightShape),
            joinType,
            spatialPredicate,
            sedonaConf.getIndexType,
            broadcastLeft,
            broadcastRight,
            isGeography,
            extraCondition,
            distance)

        // No distance / k: plain spatial (or raster) predicate join.
        case Some(
              JoinQueryDetection(
                detectedLeft,
                detectedRight,
                leftShape,
                rightShape,
                spatialPredicate,
                _,
                extraCondition,
                None)) =>
          planSpatialJoin(
            detectedLeft,
            detectedRight,
            Seq(leftShape, rightShape),
            joinType,
            spatialPredicate,
            extraCondition)

        case Some(
              JoinQueryDetection(
                detectedLeft,
                detectedRight,
                leftShape,
                rightShape,
                spatialPredicate,
                isGeography,
                extraCondition,
                Some(distance))) =>
          // Option(...) guards against a null predicate coming back from a detection helper.
          Option(spatialPredicate) match {
            case Some(SpatialPredicate.KNN) =>
              // A detection can only exist when `condition` is defined (it is derived from
              // condition.flatMap above); fold avoids an unchecked Option.get.
              condition.fold[Seq[SparkPlan]](Nil) { joinCondition =>
                planKNNJoin(
                  detectedLeft,
                  detectedRight,
                  Seq(leftShape, rightShape),
                  joinType,
                  distance,
                  isGeography,
                  joinCondition,
                  extraCondition)
              }
            case Some(_) =>
              planDistanceJoin(
                detectedLeft,
                detectedRight,
                Seq(leftShape, rightShape),
                joinType,
                distance,
                spatialPredicate,
                isGeography,
                extraCondition)
            case None =>
              Nil
          }

        case None =>
          Nil
      }
    case _ =>
      Nil
  }