def operator2Proto()

in spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala [2277:2811]
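Converts a Spark physical operator (plus the protobuf operators already built for its
children, passed as `childOp`) into Comet's protobuf `Operator` representation. A result
of `None` marks the operator as unsupported, in which case the subtree falls back to
Spark; `withInfo` records the reason on the plan node.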


  def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
    val conf = op.conf
    val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
    childOp.foreach(result.addChildren)

    op match {

      // Fully native scan for V1
      case scan: CometScanExec
          if CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_DATAFUSION =>
        val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
        nativeScanBuilder.setSource(op.simpleStringWithNodeId())

        val scanTypes = op.output.flatMap { attr =>
          serializeDataType(attr.dataType)
        }

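        // serializeDataType returns None for unsupported types, so a shorter list here
        // means at least one output attribute has a data type we cannot serialize.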
        if (scanTypes.length == op.output.length) {
          nativeScanBuilder.addAllFields(scanTypes.asJava)

          // Sink operators don't have children
          result.clearChildren()

          if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED)) {
            // TODO remove flatMap and add error handling for unsupported data filters
            val dataFilters = scan.dataFilters.flatMap(exprToProto(_, scan.output))
            nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
          }

          // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
          scan.inputRDD match {
            case rdd: DataSourceRDD =>
              val partitions = rdd.partitions
              partitions.foreach(p => {
                val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
                inputPartitions.foreach(partition => {
                  partition2Proto(
                    partition.asInstanceOf[FilePartition],
                    nativeScanBuilder,
                    scan.relation.partitionSchema)
                })
              })
            case rdd: FileScanRDD =>
              rdd.filePartitions.foreach(partition => {
                partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
              })
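            // Other RDD types are left as-is: no file partitions are serialized.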
            case _ =>
          }

          val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
          val requiredSchema = schema2Proto(scan.requiredSchema.fields)
          val dataSchema = schema2Proto(scan.relation.dataSchema.fields)

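          // Build a projection over (dataSchema ++ partitionSchema). For example, with
          // dataSchema [a, b, c], requiredSchema [c, a], and partitionSchema [p], the
          // projection vector is [2, 0, 3].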
          val dataSchemaIdxs = scan.requiredSchema.fields.map(field =>
            scan.relation.dataSchema.fieldIndex(field.name))
          val partitionSchemaIdxs = Array.range(
            scan.relation.dataSchema.fields.length,
            scan.relation.dataSchema.fields.length + scan.relation.partitionSchema.fields.length)

          val projectionVector = (dataSchemaIdxs ++ partitionSchemaIdxs).map(idx =>
            idx.toLong.asInstanceOf[java.lang.Long])

          nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)

          // In `CometScanRule`, we ensure partitionSchema is supported.
          assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)

          nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
          nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
          nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
          nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone"))

          Some(result.setNativeScan(nativeScanBuilder).build())

        } else {
          // There are unsupported output data types
          val msg =
            s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
          emitWarning(msg)
          withInfo(op, msg)
          None
        }

      case ProjectExec(projectList, child) if CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) =>
        val exprs = projectList.map(exprToProto(_, child.output))

        if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
          val projectBuilder = OperatorOuterClass.Projection
            .newBuilder()
            .addAllProjectList(exprs.map(_.get).asJava)
          Some(result.setProjection(projectBuilder).build())
        } else {
          withInfo(op, projectList: _*)
          None
        }

      case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
        val cond = exprToProto(condition, child.output)

        if (cond.isDefined && childOp.nonEmpty) {
          val filterBuilder = OperatorOuterClass.Filter
            .newBuilder()
            .setPredicate(cond.get)
            .setUseDatafusionFilter(
              CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_DATAFUSION ||
                CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_ICEBERG_COMPAT)
          Some(result.setFilter(filterBuilder).build())
        } else {
          withInfo(op, condition, child)
          None
        }

      case SortExec(sortOrder, _, child, _) if CometConf.COMET_EXEC_SORT_ENABLED.get(conf) =>
        if (!supportedSortType(op, sortOrder)) {
          return None
        }

        val sortOrders = sortOrder.map(exprToProto(_, child.output))

        if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
          val sortBuilder = OperatorOuterClass.Sort
            .newBuilder()
            .addAllSortOrders(sortOrders.map(_.get).asJava)
          Some(result.setSort(sortBuilder).build())
        } else {
          withInfo(op, "sort order not supported", sortOrder: _*)
          None
        }

      case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
        if (childOp.nonEmpty) {
          // LocalLimit doesn't use an offset, but it shares the same operator serde class.
          // Just set it to zero.
          val limitBuilder = OperatorOuterClass.Limit
            .newBuilder()
            .setLimit(limit)
            .setOffset(0)
          Some(result.setLimit(limitBuilder).build())
        } else {
          withInfo(op, "No child operator")
          None
        }

      case globalLimitExec: GlobalLimitExec
          if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) =>
        // TODO: We don't support negative limit for now.
        if (childOp.nonEmpty && globalLimitExec.limit >= 0) {
          val limitBuilder = OperatorOuterClass.Limit.newBuilder()

          // TODO: Spark 3.3 might have a negative limit (-1) for Offset usage.
          // When we upgrade to Spark 3.3, we need to address it here.
          limitBuilder.setLimit(globalLimitExec.limit)

          Some(result.setLimit(limitBuilder).build())
        } else {
          withInfo(op, "No child operator")
          None
        }

      case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) =>
        val allProjExprs: Seq[Expression] = projections.flatten
        val projExprs = projections.flatMap(_.map(exprToProto(_, child.output)))

        if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
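          // Every projection list produced by Expand has the same arity, so the native
          // side can rechunk the flattened projection list into rows of
          // `numExprPerProject` expressions.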
          val expandBuilder = OperatorOuterClass.Expand
            .newBuilder()
            .addAllProjectList(projExprs.map(_.get).asJava)
            .setNumExprPerProject(projections.head.size)
          Some(result.setExpand(expandBuilder).build())
        } else {
          withInfo(op, allProjExprs: _*)
          None
        }

      case WindowExec(windowExpression, partitionSpec, orderSpec, child)
          if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
        val output = child.output

        val winExprs: Array[WindowExpression] = windowExpression.collect {
          case Alias(winExpr: WindowExpression, _) => winExpr
        }.toArray

        if (winExprs.length != windowExpression.length) {
          withInfo(op, "Unsupported window expression(s)")
          return None
        }

        if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
          !validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
          return None
        }

        val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
        val partitionExprs = partitionSpec.map(exprToProto(_, child.output))

        val sortOrders = orderSpec.map(exprToProto(_, child.output))

        if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined)
          && sortOrders.forall(_.isDefined)) {
          val windowBuilder = OperatorOuterClass.Window.newBuilder()
          windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
          windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
          windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
          Some(result.setWindow(windowBuilder).build())
        } else {
          None
        }

      case aggregate: BaseAggregateExec
          if (aggregate.isInstanceOf[HashAggregateExec] ||
            aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
            CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
        val groupingExpressions = aggregate.groupingExpressions
        val aggregateExpressions = aggregate.aggregateExpressions
        val aggregateAttributes = aggregate.aggregateAttributes
        val resultExpressions = aggregate.resultExpressions
        val child = aggregate.child

        if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
          withInfo(op, "No group by or aggregation")
          return None
        }

        // Aggregate expressions with filter are not supported yet.
        if (aggregateExpressions.exists(_.filter.isDefined)) {
          withInfo(op, "Aggregate expression with filter is not supported")
          return None
        }

        val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))
        if (groupingExprs.exists(_.isEmpty)) {
          withInfo(op, "Not all grouping expressions are supported")
          return None
        }

        // In some cases, aggregateExpressions can be empty. For example, when the
        // aggregation only has a GROUP BY, or when it only contains distinct aggregate
        // functions:
        //
        // SELECT COUNT(distinct col2), col1 FROM test group by col1
        //  +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
        //    +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
        //      +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
        //        +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
        //          +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
        //            +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
        //              +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
        // If aggregateExpressions is empty, we only build groupingExpressions and skip
        // processing aggregateExpressions.
        if (aggregateExpressions.isEmpty) {
          val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
          hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
          val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
          val resultExprs = resultExpressions.map(exprToProto(_, attributes))
          if (resultExprs.exists(_.isEmpty)) {
            val msg = s"Unsupported result expressions found in: ${resultExpressions}"
            emitWarning(msg)
            withInfo(op, msg, resultExpressions: _*)
            return None
          }
          hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
          Some(result.setHashAgg(hashAggBuilder).build())
        } else {
          val modes = aggregateExpressions.map(_.mode).distinct

          if (modes.size != 1) {
            // This shouldn't happen, as all aggregate expressions should share the same
            // mode. Fall back to Spark nevertheless.
            withInfo(op, "All aggregate expressions do not have the same mode")
            return None
          }

          val mode = modes.head match {
            case Partial => CometAggregateMode.Partial
            case Final => CometAggregateMode.Final
            case _ =>
              withInfo(op, s"Unsupported aggregation mode ${modes.head}")
              return None
          }

          // In final mode, Spark's `HashAggregateExec` internally binds the aggregate
          // expressions to the aggregation buffer attributes produced by the partial
          // aggregation (i.e., the child's output). In Comet, we don't have to do this
          // because we don't use the merging expression.
          val binding = mode != CometAggregateMode.Final
          // `output` is only used when `binding` is true (i.e., non-Final)
          val output = child.output

          val aggExprs =
            aggregateExpressions.map(aggExprToProto(_, output, binding, op.conf))
          if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
            aggExprs.forall(_.isDefined)) {
            val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
            hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
            hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
            if (mode == CometAggregateMode.Final) {
              val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
              val resultExprs = resultExpressions.map(exprToProto(_, attributes))
              if (resultExprs.exists(_.isEmpty)) {
                val msg = s"Unsupported result expressions found in: ${resultExpressions}"
                emitWarning(msg)
                withInfo(op, msg, resultExpressions: _*)
                return None
              }
              hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
            }
            hashAggBuilder.setModeValue(mode.getNumber)
            Some(result.setHashAgg(hashAggBuilder).build())
          } else {
            val allChildren: Seq[Expression] =
              groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
            withInfo(op, allChildren: _*)
            None
          }
        }

      case join: HashJoin =>
        // `HashJoin` has only two implementations in Spark, but we still check the
        // concrete class to make sure the matching config is enabled for it.
        if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
            join.isInstanceOf[ShuffledHashJoinExec]) &&
          !(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
            join.isInstanceOf[BroadcastHashJoinExec])) {
          withInfo(join, s"Invalid hash join type ${join.nodeName}")
          return None
        }

        if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
          withInfo(join, "BuildRight with LeftAnti is not supported")
          return None
        }

        val condition = join.condition.map { cond =>
          val condProto = exprToProto(cond, join.left.output ++ join.right.output)
          if (condProto.isEmpty) {
            withInfo(join, cond)
            return None
          }
          condProto.get
        }

        val joinType = join.joinType match {
          case Inner => JoinType.Inner
          case LeftOuter => JoinType.LeftOuter
          case RightOuter => JoinType.RightOuter
          case FullOuter => JoinType.FullOuter
          case LeftSemi => JoinType.LeftSemi
          case LeftAnti => JoinType.LeftAnti
          case _ =>
            // Spark doesn't support other join types
            withInfo(join, s"Unsupported join type ${join.joinType}")
            return None
        }

        val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
        val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))

        if (leftKeys.forall(_.isDefined) &&
          rightKeys.forall(_.isDefined) &&
          childOp.nonEmpty) {
          val joinBuilder = OperatorOuterClass.HashJoin
            .newBuilder()
            .setJoinType(joinType)
            .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
            .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
            .setBuildSide(
              if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight)
          condition.foreach(joinBuilder.setCondition)
          Some(result.setHashJoin(joinBuilder).build())
        } else {
          val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
          withInfo(join, allExprs: _*)
          None
        }

      case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
        // `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
        def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
          keys.map(SortOrder(_, Ascending))
        }

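        // Returns the child's output ordering over the join keys when it already
        // satisfies ascending order on them (preserving same-order expressions);
        // otherwise falls back to the required ascending ordering.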
        def getKeyOrdering(
            keys: Seq[Expression],
            childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
          val requiredOrdering = requiredOrders(keys)
          if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
            keys.zip(childOutputOrdering).map { case (key, childOrder) =>
              val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
              SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
            }
          } else {
            requiredOrdering
          }
        }

        if (join.condition.isDefined &&
          !CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED
            .get(conf)) {
          withInfo(
            join,
            s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled",
            join.condition.get)
          return None
        }

        val condition = join.condition.map { cond =>
          val condProto = exprToProto(cond, join.left.output ++ join.right.output)
          if (condProto.isEmpty) {
            withInfo(join, cond)
            return None
          }
          condProto.get
        }

        val joinType = join.joinType match {
          case Inner => JoinType.Inner
          case LeftOuter => JoinType.LeftOuter
          case RightOuter => JoinType.RightOuter
          case FullOuter => JoinType.FullOuter
          case LeftSemi => JoinType.LeftSemi
          case LeftAnti => JoinType.LeftAnti
          case _ =>
            // Spark doesn't support other join types
            withInfo(op, s"Unsupported join type ${join.joinType}")
            return None
        }

        // Checks if the join keys are supported by DataFusion SortMergeJoin.
        val errorMsgs = join.leftKeys.flatMap { key =>
          if (!supportedSortMergeJoinEqualType(key.dataType)) {
            Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
          } else {
            None
          }
        }

        if (errorMsgs.nonEmpty) {
          withInfo(op, errorMsgs.mkString("\n"))
          return None
        }

        val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
        val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))

        val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering)
          .map(exprToProto(_, join.left.output))

        if (sortOptions.forall(_.isDefined) &&
          leftKeys.forall(_.isDefined) &&
          rightKeys.forall(_.isDefined) &&
          childOp.nonEmpty) {
          val joinBuilder = OperatorOuterClass.SortMergeJoin
            .newBuilder()
            .setJoinType(joinType)
            .addAllSortOptions(sortOptions.map(_.get).asJava)
            .addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
            .addAllRightJoinKeys(rightKeys.map(_.get).asJava)
          condition.foreach(joinBuilder.setCondition)
          Some(result.setSortMergeJoin(joinBuilder).build())
        } else {
          val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
          withInfo(join, allExprs: _*)
          None
        }

      case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
        withInfo(join, "SortMergeJoin is not enabled")
        None

      case op
          if isCometSink(op) && op.output.forall(a =>
            supportedDataType(
              a.dataType,
              // Complex types are supported if:
              // - the native DataFusion reader is enabled (experimental), OR
              // - conversion from Parquet/JSON is enabled
              allowComplex =
                usingDataFusionParquetExec(conf) || CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED
                  .get(conf) || CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf))) =>
        // These operators are the source of a Comet native execution chain
        val scanBuilder = OperatorOuterClass.Scan.newBuilder()
        val source = op.simpleStringWithNodeId()
        if (source.isEmpty) {
          scanBuilder.setSource(op.getClass.getSimpleName)
        } else {
          scanBuilder.setSource(source)
        }

        val scanTypes = op.output.flatMap { attr =>
          serializeDataType(attr.dataType)
        }

        if (scanTypes.length == op.output.length) {
          scanBuilder.addAllFields(scanTypes.asJava)

          // Sink operators don't have children
          result.clearChildren()

          Some(result.setScan(scanBuilder).build())
        } else {
          // There are unsupported output data types
          val msg =
            s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
          emitWarning(msg)
          withInfo(op, msg)
          None
        }

      case op =>
        // Emit a warning if:
        //  1. it is not a Spark shuffle operator, which is handled separately
        //  2. it is not a Comet operator
        if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) {
          val msg = s"unsupported Spark operator: ${op.nodeName}"
          emitWarning(msg)
          withInfo(op, msg)
        }
        None
    }
  }
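
Callers build the plan bottom-up: each child is serialized first and the resulting
protobuf operators are passed in as `childOp`. A minimal sketch of that traversal
(the helper name `serializeBottomUp` is hypothetical; the real call sites live in
Comet's planning rules, not shown here):

  def serializeBottomUp(plan: SparkPlan): Option[Operator] = {
    // Serialize children first; any unsupported child makes the whole subtree
    // fall back to Spark.
    val children = plan.children.map(serializeBottomUp)
    if (children.forall(_.isDefined)) {
      operator2Proto(plan, children.map(_.get): _*)
    } else {
      None
    }
  }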