in spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala [696:880]
/**
 * Plans a broadcast-based spatial join for the given pair of logical plans.
 *
 * A plan is only produced when the join type and the broadcast flags allow one side to be
 * broadcast (see the first `match` below) and when the two shape expressions can be matched
 * to the `left` / `right` plans by `matchExpressionsToPlans`. In every other case `Nil` is
 * returned so that no plan is contributed by this method.
 *
 * @param left left logical plan of the join
 * @param right right logical plan of the join
 * @param children shape expressions of the spatial predicate; the first two elements are used
 * @param joinType join type; decides which side is allowed to be broadcast
 * @param spatialPredicate spatial predicate of the join; `KNN` is planned by a dedicated branch
 * @param indexType index type handed to `SpatialIndexExec` for the broadcast side
 * @param broadcastLeft whether the left side may be broadcast
 * @param broadcastRight whether the right side may be broadcast
 * @param isGeography forwarded to the physical operators; also used for the log label
 * @param extraCondition additional join condition forwarded to `BroadcastIndexJoinExec`
 *                       (the KNN operators are always created with `extraCondition = None`)
 * @param distance for KNN joins the expression holding `k`; otherwise the optional distance
 *                 expression of a distance join
 * @return a single-element sequence holding the physical plan, or `Nil` when the join cannot
 *         be planned as a broadcast join
 * @throws IllegalArgumentException if `k` is missing or smaller than 1 for a KNN join, or if
 *                                  the distance expression cannot be bound to one join side
 */
private def planBroadcastJoin(
    left: LogicalPlan,
    right: LogicalPlan,
    children: Seq[Expression],
    joinType: JoinType,
    spatialPredicate: SpatialPredicate,
    indexType: IndexType,
    broadcastLeft: Boolean,
    broadcastRight: Boolean,
    isGeography: Boolean,
    extraCondition: Option[Expression],
    distance: Option[Expression]): Seq[SparkPlan] = {
  // Resolve the side to broadcast; unsupported (join type, flag) combinations are not planned.
  val broadcastSide = joinType match {
    case Inner if broadcastLeft => LeftSide
    case Inner if broadcastRight => RightSide
    case LeftSemi if broadcastRight => RightSide
    case LeftAnti if broadcastRight => RightSide
    case LeftOuter if broadcastRight => RightSide
    case RightOuter if broadcastLeft => LeftSide
    case _ => return Nil
  }

  if (spatialPredicate == SpatialPredicate.KNN) {
    // validate the k value for KNN join
    val kExpr = distance.getOrElse(
      throw new IllegalArgumentException(
        "The number of neighbors (k) must be specified for a KNN join."))
    val kValue: Int = kExpr.eval().asInstanceOf[Int]
    require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")

    val leftShape = children.head
    val rightShape = children.tail.head
    val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right) match {
      case Some((_, _, false)) => LeftSide
      case Some((_, _, true)) => RightSide
      case None =>
        // The shapes cannot be attributed to the join relations. Mirror the non-KNN path
        // below and plan nothing, instead of silently treating `left` as the object side.
        logInfo(
          s"Spatial join for ST_$spatialPredicate with arguments not aligned " +
            "with join relations is not supported")
        return Nil
    }

    // The object side is the side opposite to the query side.
    val objectSidePlan = if (querySide == LeftSide) right else left
    checkObjectPlanFilterPushdown(objectSidePlan)

    val knnPlan: SparkPlan =
      if (querySide == broadcastSide) {
        // broadcast is on query side
        BroadcastQuerySideKNNJoinExec(
          planLater(left),
          planLater(right),
          leftShape,
          rightShape,
          broadcastSide,
          joinType,
          k = kExpr,
          useApproximate = false,
          spatialPredicate,
          isGeography,
          condition = null,
          extraCondition = None)
      } else {
        // broadcast is on object side
        BroadcastObjectSideKNNJoinExec(
          planLater(left),
          planLater(right),
          leftShape,
          rightShape,
          broadcastSide,
          joinType,
          k = kExpr,
          useApproximate = false,
          spatialPredicate,
          isGeography,
          condition = null,
          extraCondition = None)
      }
    return knnPlan :: Nil
  }

  val a = children.head
  val b = children.tail.head
  val isRasterPredicate =
    a.dataType.isInstanceOf[RasterUDT] || b.dataType.isInstanceOf[RasterUDT]

  // Human-readable label of the join relationship; only used in the log messages below.
  // NOTE(review): this match has no default case. A geography join without a distance, or a
  // raster predicate combined with a distance, raises a MatchError here.
  val relationship =
    (distance, spatialPredicate, isGeography, extraCondition, isRasterPredicate) match {
      case (Some(_), SpatialPredicate.INTERSECTS, false, Some(ST_DWithin(Seq(_*))), false) =>
        "ST_DWithin"
      case (Some(_), SpatialPredicate.INTERSECTS, false, _, false) => "ST_Distance <="
      case (Some(_), _, false, _, false) => "ST_Distance <"
      case (Some(_), SpatialPredicate.INTERSECTS, true, Some(ST_DWithin(Seq(_*))), false) =>
        "ST_DWithin(useSpheroid = true)"
      case (Some(_), SpatialPredicate.INTERSECTS, true, _, false) =>
        "ST_Distance (Geography) <="
      case (Some(_), _, true, _, false) => "ST_Distance (Geography) <"
      case (None, _, false, _, false) => s"ST_$spatialPredicate"
      case (None, _, false, _, true) => s"RS_$spatialPredicate"
    }

  // Decide where the distance expression is evaluated: on the indexed (broadcast) side when
  // it is bound to that side or references no attributes at all, otherwise on the stream side.
  val (distanceOnIndexSide, distanceOnStreamSide) = distance
    .map { distanceExpr =>
      matchDistanceExpressionToJoinSide(distanceExpr, left, right) match {
        case Some(side) =>
          if (broadcastSide == side) (Some(distanceExpr), None)
          else if (distanceExpr.references.isEmpty) (Some(distanceExpr), None)
          else (None, Some(distanceExpr))
        case _ =>
          throw new IllegalArgumentException(
            "Distance expression must be bound to one side of the join")
      }
    }
    .getOrElse((None, None))

  matchExpressionsToPlans(a, b, left, right) match {
    case Some((_, _, swapped)) =>
      logInfo(s"Planning spatial join for $relationship relationship")
      // The broadcast side is wrapped in a SpatialIndexExec built on the shape that belongs
      // to it; the other shape is evaluated on the streamed side.
      val (leftPlan, rightPlan, streamShape, windowSide) = (broadcastSide, swapped) match {
        case (LeftSide, false) => // Broadcast the left side, windows on the left
          (
            SpatialIndexExec(
              planLater(left),
              a,
              indexType,
              isRasterPredicate,
              isGeography,
              distanceOnIndexSide),
            planLater(right),
            b,
            LeftSide)
        case (LeftSide, true) => // Broadcast the left side, objects on the left
          (
            SpatialIndexExec(
              planLater(left),
              b,
              indexType,
              isRasterPredicate,
              isGeography,
              distanceOnIndexSide),
            planLater(right),
            a,
            RightSide)
        case (RightSide, false) => // Broadcast the right side, windows on the left
          (
            planLater(left),
            SpatialIndexExec(
              planLater(right),
              b,
              indexType,
              isRasterPredicate,
              isGeography,
              distanceOnIndexSide),
            a,
            LeftSide)
        case (RightSide, true) => // Broadcast the right side, objects on the left
          (
            planLater(left),
            SpatialIndexExec(
              planLater(right),
              a,
              indexType,
              isRasterPredicate,
              isGeography,
              distanceOnIndexSide),
            b,
            RightSide)
      }
      BroadcastIndexJoinExec(
        leftPlan,
        rightPlan,
        streamShape,
        broadcastSide,
        windowSide,
        joinType,
        spatialPredicate,
        extraCondition,
        distanceOnStreamSide) :: Nil
    case None =>
      logInfo(
        s"Spatial join for $relationship with arguments not aligned " +
          "with join relations is not supported")
      Nil
  }
}