in spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala [2277:2811]
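/**
 * Converts a Spark physical operator into its protobuf representation
 * ([[OperatorOuterClass.Operator]]). `childOp` contains the already-serialized children of
 * `op`. Returns `None` when the operator, its data types, or any of its expressions cannot
 * be converted, so that the plan can fall back to Spark for this node.
 */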
def operator2Proto(op: SparkPlan, childOp: Operator*): Option[Operator] = {
val conf = op.conf
val result = OperatorOuterClass.Operator.newBuilder().setPlanId(op.id)
childOp.foreach(result.addChildren)
op match {
// Fully native scan for V1
case scan: CometScanExec
if CometConf.COMET_NATIVE_SCAN_IMPL.get(conf) == CometConf.SCAN_NATIVE_DATAFUSION =>
val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
nativeScanBuilder.setSource(op.simpleStringWithNodeId())
val scanTypes = op.output.flatMap { attr =>
serializeDataType(attr.dataType)
}
if (scanTypes.length == op.output.length) {
nativeScanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
result.clearChildren()
if (conf.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED)) {
// TODO remove flatMap and add error handling for unsupported data filters
val dataFilters = scan.dataFilters.flatMap(exprToProto(_, scan.output))
nativeScanBuilder.addAllDataFilters(dataFilters.asJava)
}
// TODO: modify CometNativeScan to generate the file partitions without instantiating RDD.
scan.inputRDD match {
case rdd: DataSourceRDD =>
val partitions = rdd.partitions
partitions.foreach(p => {
val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions
inputPartitions.foreach(partition => {
partition2Proto(
partition.asInstanceOf[FilePartition],
nativeScanBuilder,
scan.relation.partitionSchema)
})
})
case rdd: FileScanRDD =>
rdd.filePartitions.foreach(partition => {
partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema)
})
case _ =>
}
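// Serialize the data, required, and partition schemas, plus a projection vector that maps
// each required column to its index in the data schema, with partition columns appended
// after the data schema columns.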
val partitionSchema = schema2Proto(scan.relation.partitionSchema.fields)
val requiredSchema = schema2Proto(scan.requiredSchema.fields)
val dataSchema = schema2Proto(scan.relation.dataSchema.fields)
val dataSchemaIndexes = scan.requiredSchema.fields.map(field => {
scan.relation.dataSchema.fieldIndex(field.name)
})
val partitionSchemaIndexes = Array.range(
scan.relation.dataSchema.fields.length,
scan.relation.dataSchema.fields.length + scan.relation.partitionSchema.fields.length)
val projectionVector = (dataSchemaIndexes ++ partitionSchemaIndexes).map(idx =>
idx.toLong.asInstanceOf[java.lang.Long])
nativeScanBuilder.addAllProjectionVector(projectionVector.toIterable.asJava)
// In `CometScanRule`, we ensure partitionSchema is supported.
assert(partitionSchema.length == scan.relation.partitionSchema.fields.length)
nativeScanBuilder.addAllDataSchema(dataSchema.toIterable.asJava)
nativeScanBuilder.addAllRequiredSchema(requiredSchema.toIterable.asJava)
nativeScanBuilder.addAllPartitionSchema(partitionSchema.toIterable.asJava)
nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone"))
Some(result.setNativeScan(nativeScanBuilder).build())
} else {
// There are unsupported data types in the scan output
val msg =
s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
emitWarning(msg)
withInfo(op, msg)
None
}
case ProjectExec(projectList, child) if CometConf.COMET_EXEC_PROJECT_ENABLED.get(conf) =>
val exprs = projectList.map(exprToProto(_, child.output))
if (exprs.forall(_.isDefined) && childOp.nonEmpty) {
val projectBuilder = OperatorOuterClass.Projection
.newBuilder()
.addAllProjectList(exprs.map(_.get).asJava)
Some(result.setProjection(projectBuilder).build())
} else {
withInfo(op, projectList: _*)
None
}
case FilterExec(condition, child) if CometConf.COMET_EXEC_FILTER_ENABLED.get(conf) =>
val cond = exprToProto(condition, child.output)
if (cond.isDefined && childOp.nonEmpty) {
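// Signal the native side to use DataFusion's filter implementation when one of the
// native DataFusion-based scans (SCAN_NATIVE_DATAFUSION or SCAN_NATIVE_ICEBERG_COMPAT)
// is in use.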
val filterBuilder = OperatorOuterClass.Filter
.newBuilder()
.setPredicate(cond.get)
.setUseDatafusionFilter(
CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_DATAFUSION ||
CometConf.COMET_NATIVE_SCAN_IMPL.get() == CometConf.SCAN_NATIVE_ICEBERG_COMPAT)
Some(result.setFilter(filterBuilder).build())
} else {
withInfo(op, condition, child)
None
}
case SortExec(sortOrder, _, child, _) if CometConf.COMET_EXEC_SORT_ENABLED.get(conf) =>
if (!supportedSortType(op, sortOrder)) {
return None
}
val sortOrders = sortOrder.map(exprToProto(_, child.output))
if (sortOrders.forall(_.isDefined) && childOp.nonEmpty) {
val sortBuilder = OperatorOuterClass.Sort
.newBuilder()
.addAllSortOrders(sortOrders.map(_.get).asJava)
Some(result.setSort(sortBuilder).build())
} else {
withInfo(op, "sort order not supported", sortOrder: _*)
None
}
case LocalLimitExec(limit, _) if CometConf.COMET_EXEC_LOCAL_LIMIT_ENABLED.get(conf) =>
if (childOp.nonEmpty) {
// LocalLimit doesn't use an offset, but it shares the same operator serde class.
// Just set the offset to zero.
val limitBuilder = OperatorOuterClass.Limit
.newBuilder()
.setLimit(limit)
.setOffset(0)
Some(result.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
}
case globalLimitExec: GlobalLimitExec
if CometConf.COMET_EXEC_GLOBAL_LIMIT_ENABLED.get(conf) =>
// TODO: We don't support negative limit for now.
if (childOp.nonEmpty && globalLimitExec.limit >= 0) {
val limitBuilder = OperatorOuterClass.Limit.newBuilder()
// TODO: Spark 3.3 might have a negative limit (-1) for Offset usage.
// When we upgrade to Spark 3.3, we need to address it here.
limitBuilder.setLimit(globalLimitExec.limit)
Some(result.setLimit(limitBuilder).build())
} else {
withInfo(op, "No child operator")
None
}
case ExpandExec(projections, _, child) if CometConf.COMET_EXEC_EXPAND_ENABLED.get(conf) =>
val allProjExprs: Seq[Expression] = projections.flatten
val projExprs = allProjExprs.map(exprToProto(_, child.output))
if (projExprs.forall(_.isDefined) && childOp.nonEmpty) {
val expandBuilder = OperatorOuterClass.Expand
.newBuilder()
.addAllProjectList(projExprs.map(_.get).asJava)
.setNumExprPerProject(projections.head.size)
Some(result.setExpand(expandBuilder).build())
} else {
withInfo(op, allProjExprs: _*)
None
}
case WindowExec(windowExpression, partitionSpec, orderSpec, child)
if CometConf.COMET_EXEC_WINDOW_ENABLED.get(conf) =>
val output = child.output
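// Only window expressions of the form `Alias(WindowExpression)` are supported; if anything
// else appears in the projection, fall back to Spark.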
val winExprs: Array[WindowExpression] = windowExpression.flatMap { expr =>
expr match {
case alias: Alias =>
alias.child match {
case winExpr: WindowExpression =>
Some(winExpr)
case _ =>
None
}
case _ =>
None
}
}.toArray
if (winExprs.length != windowExpression.length) {
withInfo(op, "Unsupported window expression(s)")
return None
}
if (partitionSpec.nonEmpty && orderSpec.nonEmpty &&
!validatePartitionAndSortSpecsForWindowFunc(partitionSpec, orderSpec, op)) {
return None
}
val windowExprProto = winExprs.map(windowExprToProto(_, output, op.conf))
val partitionExprs = partitionSpec.map(exprToProto(_, child.output))
val sortOrders = orderSpec.map(exprToProto(_, child.output))
if (windowExprProto.forall(_.isDefined) && partitionExprs.forall(_.isDefined)
&& sortOrders.forall(_.isDefined)) {
val windowBuilder = OperatorOuterClass.Window.newBuilder()
windowBuilder.addAllWindowExpr(windowExprProto.map(_.get).toIterable.asJava)
windowBuilder.addAllPartitionByList(partitionExprs.map(_.get).asJava)
windowBuilder.addAllOrderByList(sortOrders.map(_.get).asJava)
Some(result.setWindow(windowBuilder).build())
} else {
None
}
case aggregate: BaseAggregateExec
if (aggregate.isInstanceOf[HashAggregateExec] ||
aggregate.isInstanceOf[ObjectHashAggregateExec]) &&
CometConf.COMET_EXEC_AGGREGATE_ENABLED.get(conf) =>
val groupingExpressions = aggregate.groupingExpressions
val aggregateExpressions = aggregate.aggregateExpressions
val aggregateAttributes = aggregate.aggregateAttributes
val resultExpressions = aggregate.resultExpressions
val child = aggregate.child
if (groupingExpressions.isEmpty && aggregateExpressions.isEmpty) {
withInfo(op, "No group by or aggregation")
return None
}
// Aggregate expressions with filter are not supported yet.
if (aggregateExpressions.exists(_.filter.isDefined)) {
withInfo(op, "Aggregate expression with filter is not supported")
return None
}
val groupingExprs = groupingExpressions.map(exprToProto(_, child.output))
if (groupingExprs.exists(_.isEmpty)) {
withInfo(op, "Not all grouping expressions are supported")
return None
}
// In some cases, aggregateExpressions can be empty. For example, if the query only has
// a group by, or if it only uses distinct aggregate functions:
//
// SELECT COUNT(distinct col2), col1 FROM test group by col1
// +- HashAggregate (keys =[col1# 6], functions =[count (distinct col2#7)] )
// +- Exchange hashpartitioning (col1#6, 10), ENSURE_REQUIREMENTS, [plan_id = 36]
// +- HashAggregate (keys =[col1#6], functions =[partial_count (distinct col2#7)] )
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- Exchange hashpartitioning (col1#6, col2#7, 10), ENSURE_REQUIREMENTS, ...
// +- HashAggregate (keys =[col1#6, col2#7], functions =[] )
// +- FileScan parquet spark_catalog.default.test[col1#6, col2#7] ......
// If the aggregateExpressions is empty, we only want to build groupingExpressions,
// and skip processing of aggregateExpressions.
if (aggregateExpressions.isEmpty) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
val msg = s"Unsupported result expressions found in: ${resultExpressions}"
emitWarning(msg)
withInfo(op, msg, resultExpressions: _*)
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
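// Full aggregation path: serialize the grouping and aggregate expressions, and, in final
// mode, the result expressions as well.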
val modes = aggregateExpressions.map(_.mode).distinct
if (modes.size != 1) {
// This shouldn't happen, as all aggregate expressions should share the same mode.
// Fall back to Spark nevertheless.
withInfo(op, "Not all aggregate expressions have the same mode")
return None
}
val mode = modes.head match {
case Partial => CometAggregateMode.Partial
case Final => CometAggregateMode.Final
case _ =>
withInfo(op, s"Unsupported aggregation mode ${modes.head}")
return None
}
// In final mode, Spark's `HashAggregateExec` internally binds the aggregate expressions
// to the child's output, i.e. the aggregation buffer attributes produced by the partial
// aggregation. In Comet, we don't have to do this because we don't use the merging
// expressions.
val binding = mode != CometAggregateMode.Final
// `output` is only used when `binding` is true (i.e., non-Final)
val output = child.output
val aggExprs =
aggregateExpressions.map(aggExprToProto(_, output, binding, op.conf))
if (childOp.nonEmpty && groupingExprs.forall(_.isDefined) &&
aggExprs.forall(_.isDefined)) {
val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder()
hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava)
hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava)
if (mode == CometAggregateMode.Final) {
val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes
val resultExprs = resultExpressions.map(exprToProto(_, attributes))
if (resultExprs.exists(_.isEmpty)) {
val msg = s"Unsupported result expressions found in: ${resultExpressions}"
emitWarning(msg)
withInfo(op, msg, resultExpressions: _*)
return None
}
hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava)
}
hashAggBuilder.setModeValue(mode.getNumber)
Some(result.setHashAgg(hashAggBuilder).build())
} else {
val allChildren: Seq[Expression] =
groupingExpressions ++ aggregateExpressions ++ aggregateAttributes
withInfo(op, allChildren: _*)
None
}
}
case join: HashJoin =>
// `HashJoin` has only two implementations in Spark, but we still check the concrete join
// class to make sure we only translate it when the corresponding Comet config is enabled.
if (!(CometConf.COMET_EXEC_HASH_JOIN_ENABLED.get(conf) &&
join.isInstanceOf[ShuffledHashJoinExec]) &&
!(CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.get(conf) &&
join.isInstanceOf[BroadcastHashJoinExec])) {
withInfo(join, s"Invalid hash join type ${join.nodeName}")
return None
}
if (join.buildSide == BuildRight && join.joinType == LeftAnti) {
withInfo(join, "BuildRight with LeftAnti is not supported")
return None
}
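// Serialize the optional join condition against the combined left/right output; bail out
// if the condition is not supported.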
val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
withInfo(join, cond)
return None
}
condProto.get
}
val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(join, s"Unsupported join type ${join.joinType}")
return None
}
val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
if (leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.HashJoin
.newBuilder()
.setJoinType(joinType)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
.setBuildSide(
if (join.buildSide == BuildLeft) BuildSide.BuildLeft else BuildSide.BuildRight)
condition.foreach(joinBuilder.setCondition)
Some(result.setHashJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
None
}
case join: SortMergeJoinExec if CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
// `requiredOrders` and `getKeyOrdering` are copied from Spark's SortMergeJoinExec.
def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
keys.map(SortOrder(_, Ascending))
}
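// `getKeyOrdering` reuses the child's output ordering on the join keys when it already
// satisfies ascending order on those keys; otherwise it falls back to a plain ascending
// ordering, matching Spark's SortMergeJoinExec.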
def getKeyOrdering(
keys: Seq[Expression],
childOutputOrdering: Seq[SortOrder]): Seq[SortOrder] = {
val requiredOrdering = requiredOrders(keys)
if (SortOrder.orderingSatisfies(childOutputOrdering, requiredOrdering)) {
keys.zip(childOutputOrdering).map { case (key, childOrder) =>
val sameOrderExpressionsSet = ExpressionSet(childOrder.children) - key
SortOrder(key, Ascending, sameOrderExpressionsSet.toSeq)
}
} else {
requiredOrdering
}
}
if (join.condition.isDefined &&
!CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED
.get(conf)) {
withInfo(
join,
s"${CometConf.COMET_EXEC_SORT_MERGE_JOIN_WITH_JOIN_FILTER_ENABLED.key} is not enabled",
join.condition.get)
return None
}
val condition = join.condition.map { cond =>
val condProto = exprToProto(cond, join.left.output ++ join.right.output)
if (condProto.isEmpty) {
withInfo(join, cond)
return None
}
condProto.get
}
val joinType = join.joinType match {
case Inner => JoinType.Inner
case LeftOuter => JoinType.LeftOuter
case RightOuter => JoinType.RightOuter
case FullOuter => JoinType.FullOuter
case LeftSemi => JoinType.LeftSemi
case LeftAnti => JoinType.LeftAnti
case _ =>
// Spark doesn't support other join types
withInfo(op, s"Unsupported join type ${join.joinType}")
return None
}
// Checks if the join keys are supported by DataFusion SortMergeJoin.
val errorMsgs = join.leftKeys.flatMap { key =>
if (!supportedSortMergeJoinEqualType(key.dataType)) {
Some(s"Unsupported join key type ${key.dataType} on key: ${key.sql}")
} else {
None
}
}
if (errorMsgs.nonEmpty) {
withInfo(op, errorMsgs.mkString("\n"))
return None
}
val leftKeys = join.leftKeys.map(exprToProto(_, join.left.output))
val rightKeys = join.rightKeys.map(exprToProto(_, join.right.output))
val sortOptions = getKeyOrdering(join.leftKeys, join.left.outputOrdering)
.map(exprToProto(_, join.left.output))
if (sortOptions.forall(_.isDefined) &&
leftKeys.forall(_.isDefined) &&
rightKeys.forall(_.isDefined) &&
childOp.nonEmpty) {
val joinBuilder = OperatorOuterClass.SortMergeJoin
.newBuilder()
.setJoinType(joinType)
.addAllSortOptions(sortOptions.map(_.get).asJava)
.addAllLeftJoinKeys(leftKeys.map(_.get).asJava)
.addAllRightJoinKeys(rightKeys.map(_.get).asJava)
condition.foreach(joinBuilder.setCondition)
Some(result.setSortMergeJoin(joinBuilder).build())
} else {
val allExprs: Seq[Expression] = join.leftKeys ++ join.rightKeys
withInfo(join, allExprs: _*)
None
}
case join: SortMergeJoinExec if !CometConf.COMET_EXEC_SORT_MERGE_JOIN_ENABLED.get(conf) =>
withInfo(join, "SortMergeJoin is not enabled")
None
case op
if isCometSink(op) && op.output.forall(a =>
supportedDataType(
a.dataType,
// Complex types are supported if:
// - the native DataFusion reader is enabled (experimental), OR
// - conversion from Parquet/JSON is enabled
allowComplex =
usingDataFusionParquetExec(conf) || CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED
.get(conf) || CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf))) =>
// These operators are the source of a Comet native execution chain
val scanBuilder = OperatorOuterClass.Scan.newBuilder()
val source = op.simpleStringWithNodeId()
if (source.isEmpty) {
scanBuilder.setSource(op.getClass.getSimpleName)
} else {
scanBuilder.setSource(source)
}
val scanTypes = op.output.flatMap { attr =>
serializeDataType(attr.dataType)
}
if (scanTypes.length == op.output.length) {
scanBuilder.addAllFields(scanTypes.asJava)
// Sink operators don't have children
result.clearChildren()
Some(result.setScan(scanBuilder).build())
} else {
// There are unsupported data types in the scan output
val msg =
s"unsupported Comet operator: ${op.nodeName}, due to unsupported data types above"
emitWarning(msg)
withInfo(op, msg)
None
}
case op =>
// Emit a warning only if the operator is:
// 1. not a Spark shuffle operator, which is handled separately, and
// 2. not a Comet operator
if (!op.nodeName.contains("Comet") && !op.isInstanceOf[ShuffleExchangeExec]) {
val msg = s"unsupported Spark operator: ${op.nodeName}"
emitWarning(msg)
withInfo(op, msg)
}
None
}
}
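// Illustrative usage sketch (not part of this file): a plan-transformation rule is assumed
// to convert children bottom-up and pass their serialized operators in, e.g.
//
//   val childOps: Seq[Operator] = nativeChildren.map(_.nativeOp)   // hypothetical helper
//   val nativeOp: Option[Operator] = operator2Proto(sparkPlan, childOps: _*)
//
// A result of None means this node stays on the Spark execution path.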