private def replaceWithExpressionTransformerInternal()

in gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala [98:515]


  /**
   * Converts a Spark [[Expression]] into the Gluten [[ExpressionTransformer]] used to build its
   * Substrait representation, converting children recursively.
   *
   * Python UDFs, Scala UDFs and Hive simple UDFs are dispatched to their dedicated converters.
   * Every other expression is handled by [[replaceMappedExpressionWithTransformer]], which requires
   * the expression class to be registered in `expressionsMap` and accepted by the backend
   * validator.
   *
   * @param expr           the Spark expression to convert
   * @param attributeSeq   input attributes that [[AttributeReference]]s are bound against; must be
   *                       non-null whenever an attribute reference is converted
   * @param expressionsMap mapping from expression class to its Substrait function name
   * @return the transformer for `expr`
   * @throws UnsupportedOperationException if `expr` or any of its children cannot be converted;
   *                                       this lets the corresponding operator fall back
   */
  private def replaceWithExpressionTransformerInternal(
      expr: Expression,
      attributeSeq: Seq[Attribute],
      expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
    logDebug(
      s"replaceWithExpressionTransformer expr: $expr class: ${expr.getClass} " +
        s"name: ${expr.prettyName}")

    expr match {
      case p: PythonUDF =>
        replacePythonUDFWithExpressionTransformer(p, attributeSeq, expressionsMap)
      case s: ScalaUDF =>
        replaceScalaUDFWithExpressionTransformer(s, attributeSeq, expressionsMap)
      case _ if HiveSimpleUDFTransformer.isHiveSimpleUDF(expr) =>
        HiveSimpleUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
      case _ =>
        replaceMappedExpressionWithTransformer(expr, attributeSeq, expressionsMap)
    }
  }

  /**
   * Converts a non-UDF expression whose class must be registered in `expressionsMap`.
   *
   * The expression class name is recorded in `TestStats`. The expression must then pass the
   * backend validator. Expressions registered with the extension transformer are handled before
   * any built-in case. Children are converted through [[replaceWithExpressionTransformerInternal]],
   * so they get the same logging and UDF dispatch as the top-level expression.
   *
   * @throws UnsupportedOperationException if the class is not in `expressionsMap`, the backend
   *                                       rejects it, `attributeSeq` is null for an attribute
   *                                       reference, an attribute cannot be bound, or the backend
   *                                       disallows decimal arithmetic
   */
  private def replaceMappedExpressionWithTransformer(
      expr: Expression,
      attributeSeq: Seq[Attribute],
      expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
    TestStats.addExpressionClassName(expr.getClass.getName)
    // Check whether Gluten supports this expression
    val substraitExprName = expressionsMap.getOrElse(
      expr.getClass,
      throw new UnsupportedOperationException(s"Not supported: $expr. ${expr.getClass}"))

    // Check whether each backend supports this expression
    if (!BackendsApiManager.getValidatorApiInstance.doExprValidate(substraitExprName, expr)) {
      throw new UnsupportedOperationException(s"Not supported: $expr.")
    }

    // Converts a child with the same binding context (attributes and expression mapping).
    def transformChild(child: Expression): ExpressionTransformer =
      replaceWithExpressionTransformerInternal(child, attributeSeq, expressionsMap)

    expr match {
      case extendedExpr
          if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
            extendedExpr.getClass) =>
        // Use extended expression transformer to replace custom expression first
        ExpressionMappings.expressionExtensionTransformer
          .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq)
      case c: CreateArray =>
        CreateArrayTransformer(substraitExprName, c.children.map(transformChild), true, c)
      case g: GetArrayItem =>
        GetArrayItemTransformer(
          substraitExprName,
          transformChild(g.left),
          transformChild(g.right),
          g.failOnError,
          g)
      case c: CreateMap =>
        CreateMapTransformer(
          substraitExprName,
          c.children.map(transformChild),
          c.useStringTypeWhenEmpty,
          c)
      case g: GetMapValue =>
        BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer(
          substraitExprName,
          transformChild(g.child),
          transformChild(g.key),
          g)
      case e: Explode =>
        ExplodeTransformer(substraitExprName, transformChild(e.child), e)
      case p: PosExplode =>
        BackendsApiManager.getSparkPlanExecApiInstance.genPosExplodeTransformer(
          substraitExprName,
          transformChild(p.child),
          p,
          attributeSeq)
      case a: Alias =>
        BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer(
          substraitExprName,
          transformChild(a.child),
          a)
      case a: AttributeReference =>
        if (attributeSeq == null) {
          throw new UnsupportedOperationException(s"attributeSeq should not be null.")
        }
        try {
          // `expr` (statically an Expression) is bound rather than `a`: binding replaces the
          // attribute with a BoundReference, which must not be cast back to AttributeReference.
          val bound = BindReferences
            .bindReference(expr, attributeSeq, allowFailures = false)
            .asInstanceOf[BoundReference]
          AttributeReferenceTransformer(
            a.name,
            bound.ordinal,
            a.dataType,
            bound.nullable,
            a.exprId,
            a.qualifier,
            a.metadata)
        } catch {
          case e: IllegalStateException =>
            // This situation may need developers to fix, although we just throw the below
            // exception to let the corresponding operator fall back.
            throw new UnsupportedOperationException(
              s"Failed to bind reference for $expr: ${e.getMessage}")
        }
      case b: BoundReference =>
        BoundReferenceTransformer(b.ordinal, b.dataType, b.nullable)
      case l: Literal =>
        LiteralTransformer(l)
      case d: DateDiff =>
        DateDiffTransformer(
          substraitExprName,
          transformChild(d.endDate),
          transformChild(d.startDate),
          d)
      case t: ToUnixTimestamp =>
        BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
          substraitExprName,
          transformChild(t.timeExp),
          transformChild(t.format),
          t)
      case u: UnixTimestamp =>
        // The original is handed to the backend as a ToUnixTimestamp built from the same fields.
        BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
          substraitExprName,
          transformChild(u.timeExp),
          transformChild(u.format),
          ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError))
      case t: TruncTimestamp =>
        BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer(
          substraitExprName,
          transformChild(t.format),
          transformChild(t.timestamp),
          t.timeZoneId,
          t)
      case m: MonthsBetween =>
        MonthsBetweenTransformer(
          substraitExprName,
          transformChild(m.date1),
          transformChild(m.date2),
          transformChild(m.roundOff),
          m.timeZoneId,
          m)
      case i: If =>
        IfTransformer(
          transformChild(i.predicate),
          transformChild(i.trueValue),
          transformChild(i.falseValue),
          i)
      case cw: CaseWhen =>
        CaseWhenTransformer(
          cw.branches.map {
            case (condition, value) => (transformChild(condition), transformChild(value))
          },
          cw.elseValue.map(transformChild),
          cw)
      case i: In =>
        // Only the tested value is converted; the IN list is passed through unconverted.
        InTransformer(transformChild(i.value), i.list, i.value.dataType, i)
      case i: InSet =>
        InSetTransformer(transformChild(i.child), i.hset, i.child.dataType, i)
      case s: org.apache.spark.sql.execution.ScalarSubquery =>
        ScalarSubqueryTransformer(s.plan, s.exprId, s)
      case c: Cast =>
        // Add trim node, as necessary.
        val newCast = BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c)
        CastTransformer(
          transformChild(newCast.child),
          newCast.dataType,
          newCast.timeZoneId,
          newCast)
      case s: String2TrimExpression =>
        val (srcStr, trimStr) = s match {
          case StringTrim(srcStr, trimStr) => (srcStr, trimStr)
          case StringTrimLeft(srcStr, trimStr) => (srcStr, trimStr)
          case StringTrimRight(srcStr, trimStr) => (srcStr, trimStr)
        }
        String2TrimExpressionTransformer(
          substraitExprName,
          trimStr.map(transformChild),
          transformChild(srcStr),
          s)
      case m: HashExpression[_] =>
        BackendsApiManager.getSparkPlanExecApiInstance.genHashExpressionTransformer(
          substraitExprName,
          m.children.map(transformChild),
          m)
      case getStructField: GetStructField =>
        // Different backends may have different result.
        BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
          substraitExprName,
          transformChild(getStructField.child),
          getStructField.ordinal,
          getStructField)
      case getArrayStructFields: GetArrayStructFields =>
        GetArrayStructFieldsTransformer(
          substraitExprName,
          transformChild(getArrayStructFields.child),
          getArrayStructFields.ordinal,
          getArrayStructFields.numFields,
          getArrayStructFields.containsNull,
          getArrayStructFields)
      case t: StringTranslate =>
        BackendsApiManager.getSparkPlanExecApiInstance.genStringTranslateTransformer(
          substraitExprName,
          transformChild(t.srcExpr),
          transformChild(t.matchingExpr),
          transformChild(t.replaceExpr),
          t)
      case l: StringLocate =>
        BackendsApiManager.getSparkPlanExecApiInstance.genStringLocateTransformer(
          substraitExprName,
          transformChild(l.first),
          transformChild(l.second),
          transformChild(l.third),
          l)
      case s: StringSplit =>
        BackendsApiManager.getSparkPlanExecApiInstance.genStringSplitTransformer(
          substraitExprName,
          transformChild(s.str),
          transformChild(s.regex),
          transformChild(s.limit),
          s)
      case r: RegExpReplace =>
        RegExpReplaceTransformer(
          substraitExprName,
          transformChild(r.subject),
          transformChild(r.regexp),
          transformChild(r.rep),
          transformChild(r.pos),
          r)
      case equal: EqualNullSafe =>
        BackendsApiManager.getSparkPlanExecApiInstance.genEqualNullSafeTransformer(
          substraitExprName,
          transformChild(equal.left),
          transformChild(equal.right),
          equal)
      case md5: Md5 =>
        BackendsApiManager.getSparkPlanExecApiInstance.genMd5Transformer(
          substraitExprName,
          transformChild(md5.child),
          md5)
      case sha1: Sha1 =>
        BackendsApiManager.getSparkPlanExecApiInstance.genSha1Transformer(
          substraitExprName,
          transformChild(sha1.child),
          sha1)
      case sha2: Sha2 =>
        BackendsApiManager.getSparkPlanExecApiInstance.genSha2Transformer(
          substraitExprName,
          transformChild(sha2.left),
          transformChild(sha2.right),
          sha2)
      case size: Size =>
        BackendsApiManager.getSparkPlanExecApiInstance.genSizeExpressionTransformer(
          substraitExprName,
          transformChild(size.child),
          size)
      case namedStruct: CreateNamedStruct =>
        BackendsApiManager.getSparkPlanExecApiInstance.genNamedStructTransformer(
          substraitExprName,
          namedStruct.children.map(transformChild),
          namedStruct,
          attributeSeq)
      case namedLambdaVariable: NamedLambdaVariable =>
        NamedLambdaVariableTransformer(
          substraitExprName,
          name = namedLambdaVariable.name,
          dataType = namedLambdaVariable.dataType,
          nullable = namedLambdaVariable.nullable,
          exprId = namedLambdaVariable.exprId)
      case lambdaFunction: LambdaFunction =>
        LambdaFunctionTransformer(
          substraitExprName,
          function = transformChild(lambdaFunction.function),
          arguments = lambdaFunction.arguments.map(transformChild),
          hidden = false,
          original = lambdaFunction)
      case j: JsonTuple =>
        JsonTupleExpressionTransformer(substraitExprName, j.children.map(transformChild), j)
      case l: Like =>
        LikeTransformer(substraitExprName, transformChild(l.left), transformChild(l.right), l)
      case c: CheckOverflow =>
        CheckOverflowTransformer(substraitExprName, transformChild(c.child), c)
      case m: MakeDecimal =>
        MakeDecimalTransformer(substraitExprName, transformChild(m.child), m)
      case rand: Rand =>
        BackendsApiManager.getSparkPlanExecApiInstance.genRandTransformer(
          substraitExprName,
          transformChild(rand.child),
          rand)
      case _: KnownFloatingPointNormalized | _: NormalizeNaNAndZero | _: PromotePrecision =>
        // Only the first child is converted; it is wrapped in a ChildTransformer.
        ChildTransformer(transformChild(expr.children.head))
      case _: GetDateField | _: GetTimeField =>
        ExtractDateTransformer(substraitExprName, transformChild(expr.children.head), expr)
      case _: StringToMap =>
        BackendsApiManager.getSparkPlanExecApiInstance.genStringToMapTransformer(
          substraitExprName,
          expr.children.map(transformChild),
          expr)
      case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
        // PrecisionLoss=true: velox support / ch not support
        // PrecisionLoss=false: velox not support / ch support
        // TODO ch support PrecisionLoss=true
        if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
          throw new UnsupportedOperationException(
            s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
              s"${conf.decimalOperationsAllowPrecisionLoss} mode")
        }
        val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
          DecimalArithmeticUtil.rescaleLiteral(b)
        } else {
          b
        }
        val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
          DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
          DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
        val leftChild = transformChild(left)
        val rightChild = transformChild(right)

        // Prefer the decimal type reported by the converted child; otherwise fall back to the
        // Spark-side child's data type.
        val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
          DecimalArithmeticUtil.getOperationType(b),
          DecimalArithmeticUtil
            .getResultType(leftChild)
            .getOrElse(left.dataType.asInstanceOf[DecimalType]),
          DecimalArithmeticUtil
            .getResultType(rightChild)
            .getOrElse(right.dataType.asInstanceOf[DecimalType])
        )
        DecimalArithmeticExpressionTransformer(
          substraitExprName,
          leftChild,
          rightChild,
          resultType,
          b)
      case n: NaNvl =>
        BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
          substraitExprName,
          transformChild(n.left),
          transformChild(n.right),
          n)
      case e: Transformable =>
        e.getTransformer(e.children.map(transformChild))
      case other =>
        GenericExpressionTransformer(substraitExprName, other.children.map(transformChild), other)
    }
  }