def promote()

in spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala [44:109]


  /**
   * Rewrites `expr` bottom-up so that each decimal `Add`, `Subtract`, `Multiply`, `Divide` and
   * `Remainder` node is wrapped in a `CheckOverflow` carrying a result type derived from the
   * precision (`p`) and scale (`s`) of its two operands.
   *
   * @param allowPrecisionLoss
   *   when true the result type is built with `DecimalType.adjustPrecisionScale`, otherwise with
   *   `DecimalType.bounded`
   * @param expr
   *   root of the expression tree to rewrite
   * @param nullOnOverflow
   *   passed through unchanged to every `CheckOverflow` created here
   * @return
   *   the rewritten tree; nodes that match none of the cases are returned as they are
   */
  def promote(
      allowPrecisionLoss: Boolean,
      expr: Expression,
      nullOnOverflow: Boolean): Expression = {

    // Builds the result type with whichever builder `allowPrecisionLoss` selects.
    def decimalOf(precision: Int, scale: Int): DecimalType =
      if (allowPrecisionLoss) DecimalType.adjustPrecisionScale(precision, scale)
      else DecimalType.bounded(precision, scale)

    // Add and Subtract share one rule: keep the larger scale, and the larger number of integral
    // digits plus one extra digit.
    def additiveType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
      val scale = max(s1, s2)
      decimalOf(max(p1 - s1, p2 - s2) + scale + 1, scale)
    }

    // Divide is the only operator whose formula (not just the builder) depends on the flag.
    def quotientType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType =
      if (allowPrecisionLoss) {
        // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
        // Scale: max(6, s1 + p2 + 1)
        val wholeDigits = p1 - s1 + s2
        val fracDigits = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
        DecimalType.adjustPrecisionScale(wholeDigits + fracDigits, fracDigits)
      } else {
        val wholeCap = min(DecimalType.MAX_SCALE, p1 - s1 + s2)
        val fracCap = min(DecimalType.MAX_SCALE, max(6, s1 + p2 + 1))
        val excess = (wholeCap + fracCap) - DecimalType.MAX_SCALE
        // When the two parts together exceed MAX_SCALE, shrink the fractional part by
        // excess / 2 + 1 and give whatever is left of MAX_SCALE to the integral part.
        val (intDig, decDig) =
          if (excess > 0) {
            val trimmedFrac = fracCap - (excess / 2 + 1)
            (DecimalType.MAX_SCALE - trimmedFrac, trimmedFrac)
          } else {
            (wholeCap, fracCap)
          }
        DecimalType.bounded(intDig + decDig, decDig)
      }

    // Wraps an arithmetic node together with its computed result type.
    def guarded(node: Expression, resultType: DecimalType): Expression =
      CheckOverflow(node, resultType, nullOnOverflow)

    expr.transformUp {
      // This means the binary expression is already optimized with the rule in Spark. This can
      // happen if the Spark version is < 3.4
      case e: BinaryArithmetic if e.left.prettyName == "promote_precision" => e

      case add @ Add(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
        guarded(add, additiveType(p1, s1, p2, s2))

      case sub @ Subtract(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
        guarded(sub, additiveType(p1, s1, p2, s2))

      case mul @ Multiply(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
        guarded(mul, decimalOf(p1 + p2 + 1, s1 + s2))

      case div @ Divide(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
        guarded(div, quotientType(p1, s1, p2, s2))

      case rem @ Remainder(DecimalExpression(p1, s1), DecimalExpression(p2, s2), _) =>
        val scale = max(s1, s2)
        guarded(rem, decimalOf(min(p1 - s1, p2 - s2) + scale, scale))

      case other => other
    }
  }