override def doExecuteColumnar()

in spark/src/main/scala/org/apache/spark/sql/comet/operators.scala [189:312]


  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
    serializedPlanOpt.plan match {
      case None =>
        // This plan is in the middle of a native execution and should not be executed directly.
        throw new CometRuntimeException(
          s"CometNativeExec should not be executed directly without a serialized plan: $this")
      case Some(serializedPlan) =>
        // Switch to use Decimal128 regardless of precision, since Arrow native execution
        // doesn't support Decimal32 and Decimal64 yet.
        SQLConf.get.setConfString(CometConf.COMET_USE_DECIMAL_128.key, "true")

        val serializedPlanCopy = serializedPlan
        // TODO: support native metrics for all operators.
        val nativeMetrics = CometMetricNode.fromCometPlan(this)

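        // Build the iterator that runs the serialized native plan for one partition,
        // consuming one input batch iterator per zipped input RDD for that partition index.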
        def createCometExecIter(
            inputs: Seq[Iterator[ColumnarBatch]],
            numParts: Int,
            partitionIndex: Int): CometExecIterator = {
          val it = new CometExecIterator(
            CometExec.newIterId,
            inputs,
            output.length,
            serializedPlanCopy,
            nativeMetrics,
            numParts,
            partitionIndex)

          // Register this plan's subqueries under the iterator id; they are cleaned up
          // together with the iterator in the task completion listener below.
          setSubqueries(it.id, this)

          Option(TaskContext.get()).foreach { context =>
            context.addTaskCompletionListener[Unit] { _ =>
              it.close()
              cleanSubqueries(it.id, this)
            }
          }

          it
        }

        // Collect the input RDDs of ColumnarBatches from the child operators; a CometExecIterator
        // is created per partition to execute the native plan.
        val sparkPlans = ArrayBuffer.empty[SparkPlan]
        val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]]

        foreachUntilCometInput(this)(sparkPlans += _)

        // Find the first non-broadcast plan
        val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find {
          case (_: CometBroadcastExchangeExec, _) => false
          case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false
          case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false
          case _ => true
        }

        val containsBroadcastInput = sparkPlans.exists {
          case _: CometBroadcastExchangeExec => true
          case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
          case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true
          case _ => false
        }

        // If there is broadcast input but no non-broadcast plan is found, all the input plans are
        // broadcast plans. This is not expected, so throw an exception.
        if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) {
          throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this")
        }

        // Once the first non-broadcast plan is found, use its partition number to adjust the
        // broadcast plans so that they all have the same partition number as the first
        // non-broadcast plan.
        val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) =
          firstNonBroadcastPlan.get._1 match {
            case plan: CometNativeExec =>
              (null, plan.outputPartitioning.numPartitions)
            case plan =>
              val rdd = plan.executeColumnar()
              (rdd, rdd.getNumPartitions)
          }

        // Spark doesn't need to zip broadcast RDDs, so it doesn't schedule them with the same
        // partition number as the other inputs. Comet, however, zips all input RDDs together,
        // so the partition number of broadcast RDDs must be adjusted to match the other inputs.
        sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
          plan match {
            case c: CometBroadcastExchangeExec =>
              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
            case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) =>
              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
            case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) =>
              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
            case BroadcastQueryStageExec(
                  _,
                  ReusedExchangeExec(_, c: CometBroadcastExchangeExec),
                  _) =>
              inputs += c.setNumPartitions(firstNonBroadcastPlanNumPartitions).executeColumnar()
            case _: CometNativeExec =>
            // no-op
            case _ if idx == firstNonBroadcastPlan.get._2 =>
              inputs += firstNonBroadcastPlanRDD
            case _ =>
              val rdd = plan.executeColumnar()
              if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) {
                throw new CometRuntimeException(
                  s"Partition number mismatch: ${rdd.getNumPartitions} != " +
                    s"$firstNonBroadcastPlanNumPartitions")
              } else {
                inputs += rdd
              }
          }
        }

        if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) {
          throw new CometRuntimeException(s"No input for CometNativeExec:\n $this")
        }

        if (inputs.nonEmpty) {
          ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter)
        } else {
          val partitionNum = firstNonBroadcastPlanNumPartitions
          CometExecRDD(sparkContext, partitionNum)(createCometExecIter)
        }
    }
  }
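
The broadcast partition-number adjustment above exists because all input RDDs are later zipped partition-by-partition (via ZippedPartitionsRDD), and partition-wise zipping only works when every input has the same number of partitions. Below is a minimal, standalone sketch of that constraint using plain Spark's RDD.zipPartitions rather than Comet internals; the object and variable names are illustrative only, not part of the Comet codebase.

import org.apache.spark.sql.SparkSession

// Minimal, illustrative sketch (not Comet code): partition-wise zipping requires that
// every input RDD has the same number of partitions, which is why the broadcast RDDs
// above are adjusted to firstNonBroadcastPlanNumPartitions before being zipped.
object ZipPartitionCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("zip-sketch").getOrCreate()
    val sc = spark.sparkContext

    val mainInput = sc.parallelize(1 to 8, numSlices = 4)         // e.g. a scan with 4 partitions
    val smallSide = sc.parallelize(Seq("a", "b"), numSlices = 1)  // e.g. a broadcast side with 1 partition

    // Zipping 4 partitions against 1 fails at runtime, so align the partition counts first,
    // analogous to CometBroadcastExchangeExec.setNumPartitions(...) in the operator above.
    val aligned = smallSide.repartition(mainInput.getNumPartitions)

    val zipped = mainInput.zipPartitions(aligned) { (left, right) =>
      // One combined iterator per partition, mirroring how createCometExecIter receives
      // one input iterator per zipped RDD for the same partition index.
      Iterator.single((left.toList, right.toList))
    }

    zipped.collect().foreach(println)
    spark.stop()
  }
}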