in gluten-core/src/main/scala/io/glutenproject/expression/ExpressionConverter.scala [98:515]
/**
 * Converts a Catalyst `Expression` tree into the matching `ExpressionTransformer` tree.
 *
 * Python UDFs, Scala UDFs and Hive simple UDFs are handed to their dedicated converters.
 * Every other expression goes through `transformNonUdfExpression`, which recurses back into
 * this method for child expressions so that nested UDFs are still dispatched correctly.
 *
 * @param expr
 *   the expression to convert
 * @param attributeSeq
 *   attributes against which `AttributeReference`s are bound; also forwarded to the converters
 *   that take it
 * @param expressionsMap
 *   lookup from expression class to the name used for it as `substraitExprName`
 * @return
 *   the transformer built for `expr`
 * @throws UnsupportedOperationException
 *   if the expression (or one of its children) cannot be converted
 */
private def replaceWithExpressionTransformerInternal(
    expr: Expression,
    attributeSeq: Seq[Attribute],
    expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
  logDebug(
    s"replaceWithExpressionTransformer expr: $expr class: ${expr.getClass} " +
      s"name: ${expr.prettyName}")
  // UDFs are dispatched before any bookkeeping or validation is performed.
  expr match {
    case p: PythonUDF =>
      replacePythonUDFWithExpressionTransformer(p, attributeSeq, expressionsMap)
    case s: ScalaUDF =>
      replaceScalaUDFWithExpressionTransformer(s, attributeSeq, expressionsMap)
    case _ if HiveSimpleUDFTransformer.isHiveSimpleUDF(expr) =>
      HiveSimpleUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
    case _ =>
      transformNonUdfExpression(expr, attributeSeq, expressionsMap)
  }
}

/**
 * Converts an expression that is not a Python, Scala or Hive simple UDF.
 *
 * The expression class name is recorded in `TestStats`, the expression is looked up in
 * `expressionsMap`, validated through the backend validator API, and finally converted by the
 * first matching case below. Child expressions are converted through
 * `replaceWithExpressionTransformerInternal`.
 *
 * @throws UnsupportedOperationException
 *   if the expression class is missing from `expressionsMap`, if backend validation rejects
 *   it, if `attributeSeq` is null or binding fails for an `AttributeReference`, or if decimal
 *   arithmetic is not allowed by the backend settings
 */
private def transformNonUdfExpression(
    expr: Expression,
    attributeSeq: Seq[Attribute],
    expressionsMap: Map[Class[_], String]): ExpressionTransformer = {
  // Shorthand for the recursive conversion of a child expression.
  def convert(child: Expression): ExpressionTransformer =
    replaceWithExpressionTransformerInternal(child, attributeSeq, expressionsMap)

  TestStats.addExpressionClassName(expr.getClass.getName)
  // Check whether Gluten supports this expression
  val substraitExprName = expressionsMap.getOrElse(
    expr.getClass,
    throw new UnsupportedOperationException(s"Not supported: $expr. ${expr.getClass}"))
  // Check whether each backend supports this expression
  if (!BackendsApiManager.getValidatorApiInstance.doExprValidate(substraitExprName, expr)) {
    throw new UnsupportedOperationException(s"Not supported: $expr.")
  }
  expr match {
    case extendedExpr
        if ExpressionMappings.expressionExtensionTransformer.extensionExpressionsMapping.contains(
          extendedExpr.getClass) =>
      // Use extended expression transformer to replace custom expression first
      ExpressionMappings.expressionExtensionTransformer
        .replaceWithExtensionExpressionTransformer(substraitExprName, extendedExpr, attributeSeq)
    case c: CreateArray =>
      val children = c.children.map(convert)
      CreateArrayTransformer(substraitExprName, children, true, c)
    case g: GetArrayItem =>
      GetArrayItemTransformer(
        substraitExprName,
        convert(g.left),
        convert(g.right),
        g.failOnError,
        g)
    case c: CreateMap =>
      val children = c.children.map(convert)
      CreateMapTransformer(substraitExprName, children, c.useStringTypeWhenEmpty, c)
    case g: GetMapValue =>
      BackendsApiManager.getSparkPlanExecApiInstance.genGetMapValueTransformer(
        substraitExprName,
        convert(g.child),
        convert(g.key),
        g)
    case e: Explode =>
      ExplodeTransformer(substraitExprName, convert(e.child), e)
    case p: PosExplode =>
      BackendsApiManager.getSparkPlanExecApiInstance.genPosExplodeTransformer(
        substraitExprName,
        convert(p.child),
        p,
        attributeSeq)
    case a: Alias =>
      BackendsApiManager.getSparkPlanExecApiInstance.genAliasTransformer(
        substraitExprName,
        convert(a.child),
        a)
    case a: AttributeReference =>
      if (attributeSeq == null) {
        throw new UnsupportedOperationException("attributeSeq should not be null.")
      }
      try {
        val bindReference =
          BindReferences.bindReference(expr, attributeSeq, allowFailures = false)
        val b = bindReference.asInstanceOf[BoundReference]
        AttributeReferenceTransformer(
          a.name,
          b.ordinal,
          a.dataType,
          b.nullable,
          a.exprId,
          a.qualifier,
          a.metadata)
      } catch {
        case e: IllegalStateException =>
          // This situation may need developers to fix, although we just throw the below
          // exception to let the corresponding operator fall back.
          throw new UnsupportedOperationException(
            s"Failed to bind reference for $expr: ${e.getMessage}")
      }
    case b: BoundReference =>
      BoundReferenceTransformer(b.ordinal, b.dataType, b.nullable)
    case l: Literal =>
      LiteralTransformer(l)
    case d: DateDiff =>
      DateDiffTransformer(
        substraitExprName,
        convert(d.endDate),
        convert(d.startDate),
        d)
    case t: ToUnixTimestamp =>
      BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
        substraitExprName,
        convert(t.timeExp),
        convert(t.format),
        t)
    case u: UnixTimestamp =>
      // The backend API is given an equivalent ToUnixTimestamp built from the same fields.
      BackendsApiManager.getSparkPlanExecApiInstance.genUnixTimestampTransformer(
        substraitExprName,
        convert(u.timeExp),
        convert(u.format),
        ToUnixTimestamp(u.timeExp, u.format, u.timeZoneId, u.failOnError))
    case t: TruncTimestamp =>
      BackendsApiManager.getSparkPlanExecApiInstance.genTruncTimestampTransformer(
        substraitExprName,
        convert(t.format),
        convert(t.timestamp),
        t.timeZoneId,
        t)
    case m: MonthsBetween =>
      MonthsBetweenTransformer(
        substraitExprName,
        convert(m.date1),
        convert(m.date2),
        convert(m.roundOff),
        m.timeZoneId,
        m)
    case i: If =>
      IfTransformer(
        convert(i.predicate),
        convert(i.trueValue),
        convert(i.falseValue),
        i)
    case cw: CaseWhen =>
      CaseWhenTransformer(
        cw.branches.map { case (condition, value) => (convert(condition), convert(value)) },
        cw.elseValue.map(convert),
        cw)
    case i: In =>
      // Only the tested value is converted; the candidate list is passed through as-is.
      InTransformer(convert(i.value), i.list, i.value.dataType, i)
    case i: InSet =>
      InSetTransformer(convert(i.child), i.hset, i.child.dataType, i)
    case s: org.apache.spark.sql.execution.ScalarSubquery =>
      ScalarSubqueryTransformer(s.plan, s.exprId, s)
    case c: Cast =>
      // Add trim node, as necessary.
      val newCast =
        BackendsApiManager.getSparkPlanExecApiInstance.genCastWithNewChild(c)
      CastTransformer(convert(newCast.child), newCast.dataType, newCast.timeZoneId, newCast)
    case s: String2TrimExpression =>
      val (srcStr, trimStr) = s match {
        case StringTrim(src, trim) => (src, trim)
        case StringTrimLeft(src, trim) => (src, trim)
        case StringTrimRight(src, trim) => (src, trim)
      }
      // The optional trim string is converted before the source string.
      String2TrimExpressionTransformer(
        substraitExprName,
        trimStr.map(convert),
        convert(srcStr),
        s)
    case m: HashExpression[_] =>
      BackendsApiManager.getSparkPlanExecApiInstance.genHashExpressionTransformer(
        substraitExprName,
        m.children.map(convert),
        m)
    case getStructField: GetStructField =>
      // Different backends may have different result.
      BackendsApiManager.getSparkPlanExecApiInstance.genGetStructFieldTransformer(
        substraitExprName,
        convert(getStructField.child),
        getStructField.ordinal,
        getStructField)
    case getArrayStructFields: GetArrayStructFields =>
      GetArrayStructFieldsTransformer(
        substraitExprName,
        convert(getArrayStructFields.child),
        getArrayStructFields.ordinal,
        getArrayStructFields.numFields,
        getArrayStructFields.containsNull,
        getArrayStructFields)
    case t: StringTranslate =>
      BackendsApiManager.getSparkPlanExecApiInstance.genStringTranslateTransformer(
        substraitExprName,
        convert(t.srcExpr),
        convert(t.matchingExpr),
        convert(t.replaceExpr),
        t)
    case l: StringLocate =>
      BackendsApiManager.getSparkPlanExecApiInstance.genStringLocateTransformer(
        substraitExprName,
        convert(l.first),
        convert(l.second),
        convert(l.third),
        l)
    case s: StringSplit =>
      BackendsApiManager.getSparkPlanExecApiInstance.genStringSplitTransformer(
        substraitExprName,
        convert(s.str),
        convert(s.regex),
        convert(s.limit),
        s)
    case r: RegExpReplace =>
      RegExpReplaceTransformer(
        substraitExprName,
        convert(r.subject),
        convert(r.regexp),
        convert(r.rep),
        convert(r.pos),
        r)
    case equal: EqualNullSafe =>
      BackendsApiManager.getSparkPlanExecApiInstance.genEqualNullSafeTransformer(
        substraitExprName,
        convert(equal.left),
        convert(equal.right),
        equal)
    case md5: Md5 =>
      BackendsApiManager.getSparkPlanExecApiInstance.genMd5Transformer(
        substraitExprName,
        convert(md5.child),
        md5)
    case sha1: Sha1 =>
      BackendsApiManager.getSparkPlanExecApiInstance.genSha1Transformer(
        substraitExprName,
        convert(sha1.child),
        sha1)
    case sha2: Sha2 =>
      BackendsApiManager.getSparkPlanExecApiInstance.genSha2Transformer(
        substraitExprName,
        convert(sha2.left),
        convert(sha2.right),
        sha2)
    case size: Size =>
      BackendsApiManager.getSparkPlanExecApiInstance.genSizeExpressionTransformer(
        substraitExprName,
        convert(size.child),
        size)
    case namedStruct: CreateNamedStruct =>
      BackendsApiManager.getSparkPlanExecApiInstance.genNamedStructTransformer(
        substraitExprName,
        namedStruct.children.map(convert),
        namedStruct,
        attributeSeq)
    case namedLambdaVariable: NamedLambdaVariable =>
      NamedLambdaVariableTransformer(
        substraitExprName,
        name = namedLambdaVariable.name,
        dataType = namedLambdaVariable.dataType,
        nullable = namedLambdaVariable.nullable,
        exprId = namedLambdaVariable.exprId)
    case lambdaFunction: LambdaFunction =>
      LambdaFunctionTransformer(
        substraitExprName,
        function = convert(lambdaFunction.function),
        arguments = lambdaFunction.arguments.map(convert),
        hidden = false,
        original = lambdaFunction)
    case j: JsonTuple =>
      val children = j.children.map(convert)
      JsonTupleExpressionTransformer(substraitExprName, children, j)
    case l: Like =>
      LikeTransformer(substraitExprName, convert(l.left), convert(l.right), l)
    case c: CheckOverflow =>
      CheckOverflowTransformer(substraitExprName, convert(c.child), c)
    case m: MakeDecimal =>
      MakeDecimalTransformer(substraitExprName, convert(m.child), m)
    case rand: Rand =>
      BackendsApiManager.getSparkPlanExecApiInstance.genRandTransformer(
        substraitExprName,
        convert(rand.child),
        rand)
    case _: KnownFloatingPointNormalized | _: NormalizeNaNAndZero | _: PromotePrecision =>
      // These wrappers are replaced by the transformer of their first child.
      ChildTransformer(convert(expr.children.head))
    case _: GetDateField | _: GetTimeField =>
      ExtractDateTransformer(substraitExprName, convert(expr.children.head), expr)
    case _: StringToMap =>
      BackendsApiManager.getSparkPlanExecApiInstance.genStringToMapTransformer(
        substraitExprName,
        expr.children.map(convert),
        expr)
    case b: BinaryArithmetic if DecimalArithmeticUtil.isDecimalArithmetic(b) =>
      // PrecisionLoss=true: velox support / ch not support
      // PrecisionLoss=false: velox not support / ch support
      // TODO ch support PrecisionLoss=true
      if (!BackendsApiManager.getSettings.allowDecimalArithmetic) {
        throw new UnsupportedOperationException(
          s"Not support ${SQLConf.DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key} " +
            s"${conf.decimalOperationsAllowPrecisionLoss} mode")
      }
      val rescaleBinary = if (BackendsApiManager.getSettings.rescaleDecimalLiteral) {
        DecimalArithmeticUtil.rescaleLiteral(b)
      } else {
        b
      }
      val (left, right) = DecimalArithmeticUtil.rescaleCastForDecimal(
        DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.left),
        DecimalArithmeticUtil.removeCastForDecimal(rescaleBinary.right))
      val leftChild = convert(left)
      val rightChild = convert(right)
      // Prefer the result type reported for the converted child; otherwise fall back to the
      // decimal type of the Catalyst operand.
      val resultType = DecimalArithmeticUtil.getResultTypeForOperation(
        DecimalArithmeticUtil.getOperationType(b),
        DecimalArithmeticUtil
          .getResultType(leftChild)
          .getOrElse(left.dataType.asInstanceOf[DecimalType]),
        DecimalArithmeticUtil
          .getResultType(rightChild)
          .getOrElse(right.dataType.asInstanceOf[DecimalType])
      )
      DecimalArithmeticExpressionTransformer(
        substraitExprName,
        leftChild,
        rightChild,
        resultType,
        b)
    case n: NaNvl =>
      BackendsApiManager.getSparkPlanExecApiInstance.genNaNvlTransformer(
        substraitExprName,
        convert(n.left),
        convert(n.right),
        n)
    case e: Transformable =>
      val childrenTransformers = e.children.map(convert)
      e.getTransformer(childrenTransformers)
    case other =>
      // Fallback: wrap the expression generically with all of its children converted.
      GenericExpressionTransformer(substraitExprName, other.children.map(convert), other)
  }
}