in native/core/src/execution/planner.rs [813:892]
fn create_binary_expr_with_options(
&self,
left: &Expr,
right: &Expr,
return_type: Option<&spark_expression::DataType>,
op: DataFusionOperator,
input_schema: SchemaRef,
options: BinaryExprOptions,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let left = self.create_expr(left, Arc::clone(&input_schema))?;
let right = self.create_expr(right, Arc::clone(&input_schema))?;
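        // Dispatch on the operator and the resolved Arrow data types of both inputs.
        // Decimal128 arithmetic that could overflow gets special handling below.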
match (
&op,
left.data_type(&input_schema),
right.data_type(&input_schema),
) {
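            // Decimal128 add/subtract/multiply/modulo whose result precision would
            // exceed the Decimal128 maximum: widen the operands to Decimal256 for the
            // computation and cast the result back afterwards.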
(
DataFusionOperator::Plus
| DataFusionOperator::Minus
| DataFusionOperator::Multiply
| DataFusionOperator::Modulo,
Ok(DataType::Decimal128(p1, s1)),
Ok(DataType::Decimal128(p2, s2)),
) if ((op == DataFusionOperator::Plus || op == DataFusionOperator::Minus)
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
>= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION)
|| (op == DataFusionOperator::Modulo
&& max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
> DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
                // Some Decimal128 operations need wider intermediate precision than
                // Decimal128 provides. Cast both operands to Decimal256, evaluate the
                // operation, then cast the result back to the Decimal128 return type.
let left = Arc::new(Cast::new(
left,
DataType::Decimal256(p1, s1),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let right = Arc::new(Cast::new(
right,
DataType::Decimal256(p2, s2),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new(
child,
data_type,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
)))
}
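            // Decimal128 division (including integral division) is delegated to Comet's
            // own scalar functions rather than to DataFusion's BinaryExpr.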
(
DataFusionOperator::Divide,
Ok(DataType::Decimal128(_p1, _s1)),
Ok(DataType::Decimal128(_p2, _s2)),
) => {
let data_type = return_type.map(to_arrow_datatype).unwrap();
let func_name = if options.is_integral_div {
                    // Decimal256 division in Arrow may overflow, so integral division
                    // still needs this dedicated function. Otherwise we might be able to
                    // reuse the Decimal256 widening arm above instead; see
                    // https://github.com/apache/datafusion-comet/pull/1428#discussion_r1972648463
"decimal_integral_div"
} else {
"decimal_div"
};
let fun_expr = create_comet_physical_fun(
func_name,
data_type.clone(),
&self.session_ctx.state(),
)?;
Ok(Arc::new(ScalarFunctionExpr::new(
func_name,
fun_expr,
vec![left, right],
data_type,
)))
}
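            // Everything else maps directly onto a DataFusion BinaryExpr.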
_ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
}
}