in spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala [44:109]
/**
 * Rewrites decimal arithmetic in `expr` so that every `Add`, `Subtract`, `Multiply`, `Divide`
 * and `Remainder` whose operands are both decimals is wrapped in a `CheckOverflow` carrying the
 * result type derived from the operands' precisions and scales.
 *
 * The tree is rewritten bottom-up with `transformUp`. Nodes that match none of the cases below
 * are returned unchanged.
 *
 * @param allowPrecisionLoss selects how a raw (precision, scale) pair becomes a `DecimalType`:
 *   `DecimalType.adjustPrecisionScale` when true, `DecimalType.bounded` when false (division
 *   additionally uses a different digit-splitting rule in the non-lossy mode)
 * @param expr the expression tree to rewrite
 * @param nullOnOverflow passed through unchanged to every `CheckOverflow` node created here
 * @return the rewritten expression
 */
def promote(
    allowPrecisionLoss: Boolean,
    expr: Expression,
    nullOnOverflow: Boolean): Expression = {

  // For every rule except non-lossy division, the two modes differ only in how the raw
  // (precision, scale) pair is turned into a DecimalType.
  def fit(precision: Int, scale: Int): DecimalType =
    if (allowPrecisionLoss) DecimalType.adjustPrecisionScale(precision, scale)
    else DecimalType.bounded(precision, scale)

  // Shared by Add and Subtract: keep the larger scale and the larger number of integral
  // digits, plus one extra digit of precision.
  def addSubType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val scale = max(s1, s2)
    fit(max(p1 - s1, p2 - s2) + scale + 1, scale)
  }

  def divideType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType =
    if (allowPrecisionLoss) {
      // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
      // Scale: max(6, s1 + p2 + 1)
      val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
      DecimalType.adjustPrecisionScale(p1 - s1 + s2 + scale, scale)
    } else {
      val intDig = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
      val decDig = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
      val excess = (intDig + decDig) - DecimalType.MAX_SCALE
      if (excess > 0) {
        // Too many digits in total: drop (excess / 2 + 1) fractional digits and give the
        // integral part whatever is left of MAX_SCALE, so the total is exactly MAX_SCALE.
        DecimalType.bounded(DecimalType.MAX_SCALE, decDig - (excess / 2 + 1))
      } else {
        DecimalType.bounded(intDig + decDig, decDig)
      }
    }

  expr.transformUp {
    // The binary expression was already optimized by Spark's own rule (its left operand is a
    // `promote_precision` node). This can happen if the Spark version is < 3.4.
    case e: BinaryArithmetic if e.left.prettyName == "promote_precision" => e
    case add @ Add(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
      CheckOverflow(add, addSubType(p1, s1, p2, s2), nullOnOverflow)
    case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
      CheckOverflow(sub, addSubType(p1, s1, p2, s2), nullOnOverflow)
    case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
      CheckOverflow(mul, fit(p1 + p2 + 1, s1 + s2), nullOnOverflow)
    case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
      CheckOverflow(div, divideType(p1, s1, p2, s2), nullOnOverflow)
    case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
      val scale = max(s1, s2)
      CheckOverflow(rem, fit(min(p1 - s1, p2 - s2) + scale, scale), nullOnOverflow)
  }
}