private def addTransformableTag()

in gluten-core/src/main/scala/io/glutenproject/extension/columnar/TransformHintRule.scala [382:799]


  /**
   * Tags `plan` with a transform hint saying whether it may be replaced by its
   * columnar/native transformer.
   *
   * A plan that already carries a hint is skipped (only a debug log is emitted). Otherwise the
   * operator's enabling config switch is checked first; only when it is on is the matching
   * transformer constructed and validated with `doValidate()`, and that validation result
   * becomes the hint. Operators without a dedicated case are tagged transformable by default.
   *
   * An `UnsupportedOperationException` thrown while building, validating or tagging is turned
   * into a "not transformable" hint whose reason carries the exception message together with
   * the plan's class and its children's classes.
   *
   * @param plan the physical operator to tag
   */
  private def addTransformableTag(plan: SparkPlan): Unit = {
    if (TransformHints.isAlreadyTagged(plan)) {
      logDebug(
        s"Skip adding transformable tag, since plan already tagged as " +
          s"${TransformHints.getHint(plan)}: ${plan.toString()}")
      return
    }
    try {
      plan match {
        case plan: BatchScanExec =>
          if (!enableColumnarBatchScan) {
            TransformHints.tagNotTransformable(plan, "columnar BatchScan is disabled")
          } else if (plan.runtimeFilters.nonEmpty) {
            // If runtime filter expressions aren't empty, the inner operators need to be
            // transformed, so the scan is tagged transformable without validating it here.
            TransformHints.tagTransformable(plan)
          } else {
            val transformer =
              ScanTransformerFactory
                .createBatchScanTransformer(plan, reuseSubquery = false, validation = true)
                .asInstanceOf[BatchScanExecTransformer]
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: FileSourceScanExec =>
          if (!enableColumnarFileScan) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar FileScan is not enabled in FileSourceScanExec")
          } else if (plan.partitionFilters.nonEmpty) {
            // If partition filter expressions aren't empty, the inner operators need to be
            // transformed, so the scan is tagged transformable without validating it here.
            TransformHints.tagTransformable(plan)
          } else {
            val transformer =
              ScanTransformerFactory.createFileSourceScanTransformer(
                plan,
                reuseSubquery = false,
                validation = true)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan if HiveTableScanExecTransformer.isHiveTableScan(plan) =>
          if (!enableColumnarHiveTableScan) {
            TransformHints.tagNotTransformable(plan, "columnar hive table scan is disabled")
          } else {
            TransformHints.tag(plan, HiveTableScanExecTransformer.validate(plan).toTransformHint)
          }
        case plan: ProjectExec =>
          if (!enableColumnarProject) {
            TransformHints.tagNotTransformable(plan, "columnar project is disabled")
          } else {
            val transformer = ProjectExecTransformer(plan.projectList, plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: FilterExec =>
          val childIsScan = plan.child.isInstanceOf[FileSourceScanExec] ||
            plan.child.isInstanceOf[BatchScanExec]
          if (!enableColumnarFilter) {
            TransformHints.tagNotTransformable(plan, "columnar Filter is not enabled in FilterExec")
          } else if (scanOnly && !childIsScan) {
            // When scanOnly is enabled, only a filter directly on top of a scan is offloaded.
            TransformHints.tagNotTransformable(
              plan,
              "ScanOnly enabled and plan child is not Scan in FilterExec")
          } else {
            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
              .genFilterExecTransformer(plan.condition, plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: HashAggregateExec =>
          if (!enableColumnarHashAgg) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar HashAggregate is not enabled in HashAggregateExec")
          } else {
            val rewrittenAgg = RewriteMultiChildrenCount.applyForValidation(plan)
            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
              .genHashAggregateExecTransformer(
                rewrittenAgg.requiredChildDistributionExpressions,
                rewrittenAgg.groupingExpressions,
                rewrittenAgg.aggregateExpressions,
                rewrittenAgg.aggregateAttributes,
                rewrittenAgg.initialInputBufferOffset,
                rewrittenAgg.resultExpressions,
                rewrittenAgg.child
              )
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: SortAggregateExec =>
          // A sort aggregate is validated through the hash-aggregate transformer, so both
          // switches must be on. The checks short-circuit: previously execution fell through
          // after a rejection, validated anyway and re-tagged the plan, which could replace the
          // original reason (or clash with it) and wasted a validation.
          if (!BackendsApiManager.getSettings.replaceSortAggWithHashAgg) {
            TransformHints.tagNotTransformable(plan, "replaceSortAggWithHashAgg is not enabled")
          } else if (!enableColumnarHashAgg) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar HashAgg is not enabled in SortAggregateExec")
          } else {
            val rewrittenAgg = RewriteMultiChildrenCount.applyForValidation(plan)
            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
              .genHashAggregateExecTransformer(
                rewrittenAgg.requiredChildDistributionExpressions,
                rewrittenAgg.groupingExpressions,
                rewrittenAgg.aggregateExpressions,
                rewrittenAgg.aggregateAttributes,
                rewrittenAgg.initialInputBufferOffset,
                rewrittenAgg.resultExpressions,
                rewrittenAgg.child
              )
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: ObjectHashAggregateExec =>
          if (!enableColumnarHashAgg) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar HashAgg is not enabled in ObjectHashAggregateExec")
          } else {
            val rewrittenAgg = RewriteMultiChildrenCount.applyForValidation(plan)
            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
              .genHashAggregateExecTransformer(
                rewrittenAgg.requiredChildDistributionExpressions,
                rewrittenAgg.groupingExpressions,
                rewrittenAgg.aggregateExpressions,
                rewrittenAgg.aggregateAttributes,
                rewrittenAgg.initialInputBufferOffset,
                rewrittenAgg.resultExpressions,
                rewrittenAgg.child
              )
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: UnionExec =>
          if (!enableColumnarUnion) {
            TransformHints.tagNotTransformable(plan, "columnar Union is not enabled in UnionExec")
          } else {
            val transformer = ColumnarUnionExec(plan.children)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: ExpandExec =>
          if (!enableColumnarExpand) {
            TransformHints.tagNotTransformable(plan, "columnar Expand is not enabled in ExpandExec")
          } else {
            val transformer = ExpandExecTransformer(plan.projections, plan.output, plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }

        case plan: WriteFilesExec =>
          if (!enableColumnarWrite || !BackendsApiManager.getSettings.supportTransformWriteFiles) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar Write is not enabled in WriteFilesExec")
          } else {
            val transformer = WriteFilesExecTransformer(
              plan.child,
              plan.fileFormat,
              plan.partitionColumns,
              plan.bucketSpec,
              plan.options,
              plan.staticPartitions)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: SortExec =>
          if (!enableColumnarSort) {
            TransformHints.tagNotTransformable(plan, "columnar Sort is not enabled in SortExec")
          } else {
            val transformer =
              SortExecTransformer(plan.sortOrder, plan.global, plan.child, plan.testSpillFrequency)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: ShuffleExchangeExec =>
          if (!enableColumnarShuffle) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar Shuffle is not enabled in ShuffleExchangeExec")
          } else {
            val transformer = ColumnarShuffleExchangeExec(
              plan.outputPartitioning,
              plan.child,
              plan.shuffleOrigin,
              plan.child.output)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: ShuffledHashJoinExec =>
          if (!enableColumnarShuffledHashJoin) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar shufflehashjoin is not enabled in ShuffledHashJoinExec")
          } else {
            val transformer = BackendsApiManager.getSparkPlanExecApiInstance
              .genShuffledHashJoinExecTransformer(
                plan.leftKeys,
                plan.rightKeys,
                plan.joinType,
                plan.buildSide,
                plan.condition,
                plan.left,
                plan.right,
                plan.isSkewJoin)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: BroadcastExchangeExec =>
          // columnar broadcast is enabled only when columnar bhj is enabled.
          if (!enableColumnarBroadcastExchange) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar BroadcastExchange is not enabled in BroadcastExchangeExec")
          } else {
            val transformer = ColumnarBroadcastExchangeExec(plan.mode, plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case bhj: BroadcastHashJoinExec =>
          if (!enableColumnarBroadcastJoin) {
            TransformHints.tagNotTransformable(
              bhj,
              "columnar BroadcastJoin is not enabled in BroadcastHashJoinExec")
          } else {
            tagBroadcastHashJoinWithExchange(bhj)
          }
        case plan: SortMergeJoinExec =>
          if (!enableColumnarSortMergeJoin || plan.joinType == FullOuter) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar sort merge join is not enabled or join type is FullOuter")
          } else {
            val transformer = SortMergeJoinExecTransformer(
              plan.leftKeys,
              plan.rightKeys,
              plan.joinType,
              plan.condition,
              plan.left,
              plan.right,
              plan.isSkewJoin)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: CartesianProductExec =>
          // Validate only when enabled; previously the switch was ignored because execution
          // fell through to validation and re-tagged the plan.
          if (!enableCartesianProduct) {
            TransformHints.tagNotTransformable(
              plan,
              "conversion to CartesianProductTransformer is not enabled.")
          } else {
            val transformer =
              CartesianProductExecTransformer(plan.left, plan.right, plan.condition)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: WindowExec =>
          if (!enableColumnarWindow) {
            TransformHints.tagNotTransformable(plan, "columnar window is not enabled in WindowExec")
          } else {
            val transformer = WindowExecTransformer(
              plan.windowExpression,
              plan.partitionSpec,
              plan.orderSpec,
              plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: CoalesceExec =>
          if (!enableColumnarCoalesce) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar coalesce is not enabled in CoalesceExec")
          } else {
            val transformer = CoalesceExecTransformer(plan.numPartitions, plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: GlobalLimitExec =>
          if (!enableColumnarLimit) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar limit is not enabled in GlobalLimitExec")
          } else {
            val transformer = LimitTransformer(plan.child, 0L, plan.limit)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: LocalLimitExec =>
          if (!enableColumnarLimit) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar limit is not enabled in LocalLimitExec")
          } else {
            val transformer = LimitTransformer(plan.child, 0L, plan.limit)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: GenerateExec =>
          if (!enableColumnarGenerate) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar generate is not enabled in GenerateExec")
          } else {
            val transformer = GenerateExecTransformer(
              plan.generator,
              plan.requiredChildOutput,
              plan.outer,
              plan.generatorOutput,
              plan.child)
            TransformHints.tag(plan, transformer.doValidate().toTransformHint)
          }
        case plan: EvalPythonExec =>
          val transformer = EvalPythonExecTransformer(plan.udfs, plan.resultAttrs, plan.child)
          TransformHints.tag(plan, transformer.doValidate().toTransformHint)
        case _: AQEShuffleReadExec =>
          TransformHints.tagTransformable(plan)
        case plan: TakeOrderedAndProjectExec =>
          if (!enableTakeOrderedAndProject) {
            TransformHints.tagNotTransformable(
              plan,
              "columnar topK is not enabled in TakeOrderedAndProjectExec")
          } else {
            val orderingSatisfies =
              SortOrder.orderingSatisfies(plan.child.outputOrdering, plan.sortOrder)
            val limitValidation: ValidationResult =
              if (orderingSatisfies) {
                // The child already provides the required ordering: only the limit is checked.
                LimitTransformer(plan.child, 0, plan.limit).doValidate()
              } else {
                // Here we are validating sort + limit which is a kind of whole stage
                // transformer, because we would call sort.doTransform in limit.
                // So, we should add adapter to make it work.
                val inputTransformer =
                  ColumnarCollapseTransformStages.wrapInputIteratorTransformer(plan.child)
                val sortPlan = SortExecTransformer(plan.sortOrder, false, inputTransformer)
                LimitTransformer(sortPlan, 0, plan.limit).doValidate()
              }
            // The projection is only validated once the limit (and sort) part has passed.
            val validation: ValidationResult =
              if (limitValidation.isValid) {
                ProjectExecTransformer(plan.projectList, plan.child).doValidate()
              } else {
                limitValidation
              }
            TransformHints.tag(plan, validation.toTransformHint)
          }
        case _ =>
          // currently we assume a plan to be transformable by default
          TransformHints.tagTransformable(plan)
      }
    } catch {
      case e: UnsupportedOperationException =>
        TransformHints.tagNotTransformable(
          plan,
          s"${e.getMessage}, original sparkplan is " +
            s"${plan.getClass}(${plan.children.toList.map(_.getClass)})")
    }
  }

  /**
   * Tags a broadcast hash join so that it agrees with the broadcast exchange on its build side.
   * Both must be either vanilla or columnar. Only called from `addTransformableTag` when
   * columnar broadcast join is enabled.
   *
   * If a `BroadcastExchangeExec` is reachable on the build side, the BHJ gets its own
   * validation result as its hint. A failed validation also marks that exchange
   * not-transformable, with the same result.
   *
   * Otherwise the exchange is searched for inside `BroadcastQueryStageExec`s (including one
   * wrapped by a `ReusedExchangeExec`), and the BHJ follows that exchange's type.
   *
   * Throws `NoSuchElementException` if no broadcast exchange can be found at all.
   * Throws `IllegalStateException` if the exchange is already columnar but the BHJ failed
   * validation.
   */
  private def tagBroadcastHashJoinWithExchange(bhj: BroadcastHashJoinExec): Unit = {
    // FIXME Hongze: In following codes we perform a lot of if-else conditions to
    //  make sure the broadcast exchange and broadcast hash-join are of same type,
    //  either vanilla or columnar. In order to simplify the codes we have to do
    //  some tricks around C2R and R2C to make them adapt to columnar broadcast.
    //  Currently their doBroadcast() methods just propagate child's broadcast
    //  payloads which is not right in speaking of columnar.
    val isBhjTransformable: ValidationResult = BackendsApiManager.getSparkPlanExecApiInstance
      .genBroadcastHashJoinExecTransformer(
        bhj.leftKeys,
        bhj.rightKeys,
        bhj.joinType,
        bhj.buildSide,
        bhj.condition,
        bhj.left,
        bhj.right,
        isNullAwareAntiJoin = bhj.isNullAwareAntiJoin)
      .doValidate()
    val buildSidePlan = bhj.buildSide match {
      case BuildLeft => bhj.left
      case BuildRight => bhj.right
    }

    // Pre-order search for a vanilla broadcast exchange on the build side.
    buildSidePlan.collectFirst { case e: BroadcastExchangeExec => e } match {
      case Some(exchange) =>
        TransformHints.tag(bhj, isBhjTransformable.toTransformHint)
        if (!isBhjTransformable.isValid) {
          TransformHints.tagNotTransformable(exchange, isBhjTransformable)
        }
      case None =>
        // we are in AQE, find the hidden exchange
        // FIXME did we consider the case that AQE: OFF && Reuse: ON ?
        var maybeHiddenExchange: Option[BroadcastExchangeLike] = None
        breakable {
          buildSidePlan.foreach {
            case e: BroadcastExchangeLike =>
              maybeHiddenExchange = Some(e)
              break
            case t: BroadcastQueryStageExec =>
              t.plan.foreach {
                case e2: BroadcastExchangeLike =>
                  maybeHiddenExchange = Some(e2)
                  break
                case r: ReusedExchangeExec =>
                  r.child match {
                    case e2: BroadcastExchangeLike =>
                      maybeHiddenExchange = Some(e2)
                      break
                    case _ =>
                  }
                case _ =>
              }
            case _ =>
          }
        }
        // restriction to force the hidden exchange to be found; fail with context if it isn't.
        val exchange = maybeHiddenExchange.getOrElse(
          throw new NoSuchElementException(
            "Cannot find a broadcast exchange on the build side of " +
              s"BroadcastHashJoinExec: ${bhj.toString()}"))
        // to conform to the underlying exchange's type, columnar or vanilla.
        // NOTE(review): BroadcastExchangeLike is not sealed; any other implementation would
        // raise a MatchError here — confirm only these two types can reach this point.
        exchange match {
          case _: BroadcastExchangeExec =>
            TransformHints.tagNotTransformable(
              bhj,
              "it's a materialized broadcast exchange or reused broadcast exchange")
          case _: ColumnarBroadcastExchangeExec =>
            if (!isBhjTransformable.isValid) {
              throw new IllegalStateException(
                s"BroadcastExchange has already been" +
                  s" transformed to columnar version but BHJ is determined as" +
                  s" non-transformable: ${bhj.toString()}")
            }
            TransformHints.tagTransformable(bhj)
        }
    }
  }