datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs (211 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use crate::logical_plan::consumer::{from_substrait_func_args, SubstraitConsumer};
use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, ScalarValue};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{expr, BinaryExpr, Expr, Like, Operator};
use substrait::proto::expression::ScalarFunction;
use substrait::proto::function_argument::ArgType;
/// Converts a Substrait [`ScalarFunction`] into a DataFusion [`Expr`].
///
/// The function reference is looked up in the consumer's extensions to obtain
/// the function signature, which is reduced to a plain name by
/// [`substrait_fun_name`]. That name is then matched, in order, against:
///
/// 1. scalar UDFs in the consumer's function registry,
/// 2. binary operators (see [`name_to_op`]),
/// 3. built-in expressions handled by `BuiltinExprBuilder`.
///
/// # Errors
///
/// Returns an error if the function reference is not present in the
/// extensions, if `from_substrait_func_args` fails, if a binary operator is
/// given fewer than two arguments, or if the name matches none of the above.
pub async fn from_scalar_function(
    consumer: &impl SubstraitConsumer,
    f: &ScalarFunction,
    input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
    let Some(fn_signature) = consumer
        .get_extensions()
        .functions
        .get(&f.function_reference)
    else {
        return plan_err!(
            "Scalar function not found: function reference = {:?}",
            f.function_reference
        );
    };
    let fn_name = substrait_fun_name(fn_signature);
    let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

    // try to first match the requested function into registered udfs, then built-in ops
    // and finally built-in expressions
    if let Ok(func) = consumer.get_function_registry().udf(fn_name) {
        Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
            func.to_owned(),
            args,
        )))
    } else if let Some(op) = name_to_op(fn_name) {
        // Guard on the vector that is actually reduced below, so the
        // non-empty invariant is established right here.
        if args.len() < 2 {
            return not_impl_err!(
                "Expect at least two arguments for binary operator {op:?}, the provided number of arguments is {:?}",
                args.len()
            );
        }
        // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait.
        // In those cases the arguments are chained left to right: `((a op b) op c) ...`
        let combined_expr = args
            .into_iter()
            .reduce(|left, right| {
                Expr::BinaryExpr(BinaryExpr {
                    left: Box::new(left),
                    op,
                    right: Box::new(right),
                })
            })
            .expect("argument count was checked to be at least two");
        Ok(combined_expr)
    } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) {
        // NOTE: the builder converts the arguments from `f` itself, so the
        // `args` computed above are not used on this path.
        builder.build(consumer, f, input_schema).await
    } else {
        not_impl_err!("Unsupported function name: {fn_name:?}")
    }
}
/// Returns the plain function name from a Substrait function signature.
///
/// Since 0.32.0, Substrait requires function names to be in a compound format
/// (<https://substrait.io/extensions/#function-signature-compound-names>),
/// for example `add:i8_i8`. On the consumer side only the name matters, so
/// everything from the last `:` onwards is dropped. A name that contains no
/// `:` is returned unchanged.
pub fn substrait_fun_name(name: &str) -> &str {
    name.rsplit_once(':')
        .map_or(name, |(base, _signature)| base)
}
/// Maps a Substrait function name to the DataFusion binary [`Operator`] it
/// corresponds to.
///
/// `name` is expected to be the plain function name, without any
/// `:signature` suffix. Returns `None` when the name is not one of the
/// operator names listed here.
pub fn name_to_op(name: &str) -> Option<Operator> {
    let op = match name {
        // comparison
        "equal" => Operator::Eq,
        "not_equal" => Operator::NotEq,
        "lt" => Operator::Lt,
        "lte" => Operator::LtEq,
        "gt" => Operator::Gt,
        "gte" => Operator::GtEq,
        "is_distinct_from" => Operator::IsDistinctFrom,
        "is_not_distinct_from" => Operator::IsNotDistinctFrom,
        // arithmetic
        "add" => Operator::Plus,
        "subtract" => Operator::Minus,
        "multiply" => Operator::Multiply,
        "divide" => Operator::Divide,
        "mod" | "modulus" => Operator::Modulo,
        // boolean logic
        "and" => Operator::And,
        "or" => Operator::Or,
        // regular-expression matching
        "regex_match" => Operator::RegexMatch,
        "regex_imatch" => Operator::RegexIMatch,
        "regex_not_match" => Operator::RegexNotMatch,
        "regex_not_imatch" => Operator::RegexNotIMatch,
        // bitwise
        "bitwise_and" => Operator::BitwiseAnd,
        "bitwise_or" => Operator::BitwiseOr,
        "bitwise_xor" => Operator::BitwiseXor,
        "bitwise_shift_right" => Operator::BitwiseShiftRight,
        "bitwise_shift_left" => Operator::BitwiseShiftLeft,
        // remaining operators
        "str_concat" => Operator::StringConcat,
        "at_arrow" => Operator::AtArrow,
        "arrow_at" => Operator::ArrowAt,
        _ => return None,
    };
    Some(op)
}
/// Build [`Expr`] from its name and required inputs.
///
/// Instances are created with `try_from_name`, which only accepts the names
/// of the built-in expressions this builder knows how to construct.
struct BuiltinExprBuilder {
    /// Name of the built-in expression to build, e.g. `"like"` or `"is_null"`.
    expr_name: String,
}
impl BuiltinExprBuilder {
pub fn try_from_name(name: &str) -> Option<Self> {
match name {
"not" | "like" | "ilike" | "is_null" | "is_not_null" | "is_true"
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
| "is_not_unknown" | "negative" | "negate" => Some(Self {
expr_name: name.to_string(),
}),
_ => None,
}
}
pub async fn build(
self,
consumer: &impl SubstraitConsumer,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
match self.expr_name.as_str() {
"like" => Self::build_like_expr(consumer, false, f, input_schema).await,
"ilike" => Self::build_like_expr(consumer, true, f, input_schema).await,
"not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true"
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
| "is_not_unknown" => {
Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}", self.expr_name)
}
}
}
async fn build_unary_expr(
consumer: &impl SubstraitConsumer,
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
if f.arguments.len() != 1 {
return substrait_err!("Expect one argument for {fn_name} expr");
}
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return substrait_err!("Invalid arguments type for {fn_name} expr");
};
let arg = consumer
.consume_expression(expr_substrait, input_schema)
.await?;
let arg = Box::new(arg);
let expr = match fn_name {
"not" => Expr::Not(arg),
"negative" | "negate" => Expr::Negative(arg),
"is_null" => Expr::IsNull(arg),
"is_not_null" => Expr::IsNotNull(arg),
"is_true" => Expr::IsTrue(arg),
"is_false" => Expr::IsFalse(arg),
"is_not_true" => Expr::IsNotTrue(arg),
"is_not_false" => Expr::IsNotFalse(arg),
"is_unknown" => Expr::IsUnknown(arg),
"is_not_unknown" => Expr::IsNotUnknown(arg),
_ => return not_impl_err!("Unsupported builtin expression: {}", fn_name),
};
Ok(expr)
}
async fn build_like_expr(
consumer: &impl SubstraitConsumer,
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" };
if f.arguments.len() != 2 && f.arguments.len() != 3 {
return substrait_err!("Expect two or three arguments for `{fn_name}` expr");
}
let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let expr = consumer
.consume_expression(expr_substrait, input_schema)
.await?;
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let pattern = consumer
.consume_expression(pattern_substrait, input_schema)
.await?;
// Default case: escape character is Literal(Utf8(None))
let escape_char = if f.arguments.len() == 3 {
let Some(ArgType::Value(escape_char_substrait)) = &f.arguments[2].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}` expr");
};
let escape_char_expr = consumer
.consume_expression(escape_char_substrait, input_schema)
.await?;
match escape_char_expr {
Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
// Convert Option<String> to Option<char>
escape_char_string.and_then(|s| s.chars().next())
}
_ => {
return substrait_err!(
"Expect Utf8 literal for escape char, but found {escape_char_expr:?}"
)
}
}
} else {
None
};
Ok(Expr::Like(Like {
negated: false,
expr: Box::new(expr),
pattern: Box::new(pattern),
escape_char,
case_insensitive,
}))
}
}