// spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/TraitJoinQueryExec.scala [44:144]
override protected def doExecute(): RDD[InternalRow] = {
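  // Bind the shape expressions to each child's output so they can be
  // evaluated against that side's rows.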
  val boundLeftShape = BindReferences.bindReference(leftShape, left.output)
  val boundRightShape = BindReferences.bindReference(rightShape, right.output)

  val leftResultsRaw = left.execute().asInstanceOf[RDD[UnsafeRow]]
  val rightResultsRaw = right.execute().asInstanceOf[RDD[UnsafeRow]]

  val sedonaConf = SedonaConf.fromActiveSession
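
  // Convert both children's row RDDs into a pair of SpatialRDDs by evaluating
  // the bound shape expressions; each geometry carries its source row as user data.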
  val (leftShapes, rightShapes) =
    toSpatialRddPair(leftResultsRaw, boundLeftShape, rightResultsRaw, boundRightShape)

  // Only analyze the SpatialRDD when the user has not supplied an approximate
  // total count for the dominant side of the spatial partitioning.
  if (sedonaConf.getJoinApproximateTotalCount == -1) {
    if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
      leftShapes.analyze()
    } else {
      rightShapes.analyze()
    }
  }
log.info("[SedonaSQL] Number of partitions on the left: " + leftResultsRaw.partitions.size)
log.info("[SedonaSQL] Number of partitions on the right: " + rightResultsRaw.partitions.size)
  var numPartitions = -1
  try {
    if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
      if (sedonaConf.getFallbackPartitionNum != -1) {
        numPartitions = sedonaConf.getFallbackPartitionNum
      } else {
        numPartitions = joinPartitionNumOptimizer(
          leftShapes.rawSpatialRDD.partitions.size(),
          rightShapes.rawSpatialRDD.partitions.size(),
          leftShapes.approximateTotalCount)
      }
      doSpatialPartitioning(leftShapes, rightShapes, numPartitions, sedonaConf)
    } else {
      if (sedonaConf.getFallbackPartitionNum != -1) {
        numPartitions = sedonaConf.getFallbackPartitionNum
      } else {
        numPartitions = joinPartitionNumOptimizer(
          rightShapes.rawSpatialRDD.partitions.size(),
          leftShapes.rawSpatialRDD.partitions.size(),
          rightShapes.approximateTotalCount)
      }
      doSpatialPartitioning(rightShapes, leftShapes, numPartitions, sedonaConf)
    }
  } catch {
    case e: IllegalArgumentException =>
      log.warn(e.getMessage)
      // The computed partition number was not valid, so fall back to the
      // partition number specified in SedonaConf.
      numPartitions = sedonaConf.getFallbackPartitionNum
      if (sedonaConf.getJoinSpartitionDominantSide == JoinSpartitionDominantSide.LEFT) {
        doSpatialPartitioning(leftShapes, rightShapes, numPartitions, sedonaConf)
      } else {
        doSpatialPartitioning(rightShapes, leftShapes, numPartitions, sedonaConf)
      }
  }
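
  // Parameters controlling how JoinQuery runs the partitioned join: whether to
  // build an index on the fly, the spatial predicate to evaluate, the index
  // type, and which side the index is built on.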
  val joinParams = new JoinParams(
    sedonaConf.getUseIndex,
    spatialPredicate,
    sedonaConf.getIndexType,
    sedonaConf.getJoinBuildSide)

  // logInfo(s"leftShape count ${leftShapes.spatialPartitionedRDD.count()}")
  // logInfo(s"rightShape count ${rightShapes.spatialPartitionedRDD.count()}")
  val matchesRDD: RDD[(Geometry, Geometry)] =
    (leftShapes.spatialPartitionedRDD, rightShapes.spatialPartitionedRDD) match {
      case (null, null) =>
        // The dominant side is empty and partitioned RDDs were never created,
        // so the join result is also empty.
        sparkContext.parallelize(Seq[(Geometry, Geometry)]())
      case _ => JoinQuery.spatialJoin(leftShapes, rightShapes, joinParams).rdd
    }

  logDebug(s"Join result has ${matchesRDD.count()} rows")
  matchesRDD.mapPartitions { iter =>
    val joinRow = {
      val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
      (l: UnsafeRow, r: UnsafeRow) => joiner.join(l, r)
    }
    val joined = iter.map { case (l, r) =>
      val leftRow = l.getUserData.asInstanceOf[UnsafeRow]
      val rightRow = r.getUserData.asInstanceOf[UnsafeRow]
      joinRow(leftRow, rightRow)
    }
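
    // Apply any residual non-spatial join condition as a post-filter on the
    // concatenated rows; with no extra condition, emit the joined rows as-is.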
    extraCondition match {
      case Some(condition) =>
        val boundCondition = Predicate.create(condition, output)
        joined.filter(row => boundCondition.eval(row))
      case None => joined
    }
  }
}