in spark/src/main/scala/org/apache/spark/sql/comet/operators.scala [189:312]
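// Executes the serialized native plan for this subtree and returns the results as an
// RDD of ColumnarBatch: input RDDs are collected from the child operators, zipped
// partition by partition, and each partition is fed to a CometExecIterator.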
override def doExecuteColumnar(): RDD[ColumnarBatch] = {
serializedPlanOpt.plan match {
case None =>
// This is in the middle of a native execution; it should not be executed directly.
throw new CometRuntimeException(
s"CometNativeExec should not be executed directly without a serialized plan: $this")
case Some(serializedPlan) =>
// Switch to use Decimal128 regardless of precision, since Arrow native execution
// doesn't support Decimal32 and Decimal64 yet.
SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")
val serializedPlanCopy = serializedPlan
// TODO: support native metrics for all operators.
val nativeMetrics = CometMetricNode.fromCometPlan(this)
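// CometMetricNode.fromCometPlan presumably builds a metrics tree mirroring this plan
// so the native side can report per-operator metrics back to Spark.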
def createCometExecIter(
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
partitionIndex: Int): CometExecIterator = {
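// Build the iterator that runs the serialized native plan over this partition's inputs.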
val it = new CometExecIterator(
CometExec.newIterId,
inputs,
output.length,
serializedPlanCopy,
nativeMetrics,
numParts,
partitionIndex)
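// Register scalar subquery results under this iterator's id so the native plan can
// resolve them during execution; they are cleaned up in the listener below.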
setSubqueries(it.id, this)
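// Release native resources and unregister the subqueries once the task completes.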
Option(TaskContext.get()).foreach { context =>
context.addTaskCompletionListener[Unit] { _ =>
it.close()
cleanSubqueries(it.id, this)
}
}
it
}
// Collect the input ColumnarBatches from the child operators and create a CometExecIterator
// to execute the native plan.
val sparkPlans = ArrayBuffer.empty[SparkPlan]
val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]]
foreachUntilCometInput(this)(sparkPlans += _)
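// foreachUntilCometInput walks the plan tree, collecting plans until it reaches the
// inputs of the native execution (e.g. scans and exchanges).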
// Find the first non-broadcast plan.
val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
case (_: CometBroadcastExchangeExec, _) => false
case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
case _ => true
}
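// The first non-broadcast plan determines the reference partition count that every
// input RDD must match.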
val containsBroadcastInput = sparkPlans.exists {
case _: CometBroadcastExchangeExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true
case _ => false
}
// If no non-broadcast plan is found, all the plans are broadcast plans.
// This is not expected, so throw an exception.
if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) {
throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this")
}
// If the first non-broadcast plan is found, adjust the partition number of the
// broadcast plans so that they have the same partition number as the first
// non-broadcast plan.
val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) =
firstNonBroadcastPlan.get._1 match {
case plan: CometNativeExec =>
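// A native plan is executed as part of this serialized plan and contributes no
// separate input RDD; only its partition count is needed.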
(null, plan.outputPartitioning.numPartitions)
case plan =>
val rdd = plan.executeColumnar()
(rdd, rdd.getNumPartitions)
}
// Spark doesn't need to zip broadcast RDDs, so it doesn't schedule them with the same
// partition number as the other RDDs. But Comet zips all input RDDs together, so the
// partition number of the broadcast RDDs must be adjusted to match.
sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
plan match {
case c: CometBroadcastExchangeExec =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case BroadcastQueryStageExec(
_,
ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
_) =>
inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
case _: CometNativeExec =>
// No-op: native plans contribute no separate input RDD.
case _ if idx == firstNonBroadcastPlan.get._2 =>
inputs += firstNonBroadcastPlanRDD
case _ =>
val rdd = plan.executeColumnar()
if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) {
throw new CometRuntimeException(
s"Partition number mismatch: ${rdd.getNumPartitions} != " +
s"$firstNonBroadcastPlanNumPartitions")
} else {
inputs += rdd
}
}
}
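// Sanity check: every non-native input plan should have contributed an input RDD above.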
if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) {
throw new CometRuntimeException(s"No input for CometNativeExec:\n $this")
}
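// With input RDDs, zip them so each native execution sees the matching partition of
// every input; with none (a fully native tree), drive the native plan directly with
// the reference partition count.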
if (inputs.nonEmpty) {
ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
} else {
val partitionNum = firstNonBroadcastPlanNumPartitions
CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
}
}
}