in datafusion/expr/src/type_coercion/functions.rs [354:691]
/// Returns the candidate argument-type lists that `signature` accepts when
/// `function_name` is called with arguments of types `current_types`.
///
/// Each inner `Vec<DataType>` is one candidate: the type each argument should
/// be coerced to. `Variadic`, `Uniform` and `OneOf` signatures can yield several
/// candidates. The array signatures return `vec![vec![]]` when the arguments do
/// not fit the signature.
///
/// # Errors
///
/// * A planning error when the argument count or the argument types are
///   incompatible with `signature`. This includes `String(0)`, `Numeric(0)` and
///   `Comparable(0)` signatures called with no arguments.
/// * An internal error for `TypeSignature::UserDefined`, which must be handled
///   by the function's own `coerce_types`.
/// * An internal error for `Coercible` arguments that match neither the desired
///   type nor any allowed source type.
fn get_valid_types(
    function_name: &str,
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    // Coerces the arguments of an `ArrayFunctionSignature::Array` signature.
    //
    // The following are unified with `type_union_resolution`:
    // * all `Element` arguments;
    // * the element types of all `Array` arguments.
    //
    // `Index` arguments become `Int64` and `String` arguments become `Utf8`.
    //
    // Returns `vec![vec![]]` when the argument count differs from
    // `arguments.len()` or no common element type exists. Returns a planning
    // error for an `Array` argument that is neither null nor a list type.
    fn array_valid_types(
        function_name: &str,
        current_types: &[DataType],
        arguments: &[ArrayFunctionArgument],
        array_coercion: Option<&ListCoercion>,
    ) -> Result<Vec<Vec<DataType>>> {
        if current_types.len() != arguments.len() {
            return Ok(vec![vec![]]);
        }

        // Set when any array argument is a `LargeList`. In that case every
        // non-null array argument is coerced to `LargeList`.
        let mut large_list = false;
        // Fixed-size lists are preserved only when two conditions hold:
        // * no array argument is a `List`/`LargeList`;
        // * the signature does not request `FixedSizedListToList` coercion.
        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
        let mut list_sizes = Vec::with_capacity(arguments.len());
        let mut element_types = Vec::with_capacity(arguments.len());
        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
            match argument {
                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
                ArrayFunctionArgument::Element => element_types.push(current_type.clone()),
                ArrayFunctionArgument::Array => match current_type {
                    DataType::Null => element_types.push(DataType::Null),
                    DataType::List(field) => {
                        element_types.push(field.data_type().clone());
                        fixed_size = false;
                    }
                    DataType::LargeList(field) => {
                        element_types.push(field.data_type().clone());
                        large_list = true;
                        fixed_size = false;
                    }
                    DataType::FixedSizeList(field, size) => {
                        element_types.push(field.data_type().clone());
                        list_sizes.push(*size);
                    }
                    arg_type => {
                        return plan_err!("{function_name} does not support type {arg_type}");
                    }
                },
            }
        }

        let Some(element_type) = type_union_resolution(&element_types) else {
            return Ok(vec![vec![]]);
        };

        if !fixed_size {
            list_sizes.clear();
        }

        // Fixed-size array arguments take back their original sizes, in order.
        // Null array arguments stay `Null` and do not consume a size.
        let mut list_sizes = list_sizes.into_iter();
        let valid_types: Vec<DataType> = arguments
            .iter()
            .zip(current_types.iter())
            .map(|(argument_type, current_type)| match argument_type {
                ArrayFunctionArgument::Index => DataType::Int64,
                ArrayFunctionArgument::String => DataType::Utf8,
                ArrayFunctionArgument::Element => element_type.clone(),
                ArrayFunctionArgument::Array => {
                    if current_type.is_null() {
                        DataType::Null
                    } else if large_list {
                        DataType::new_large_list(element_type.clone(), true)
                    } else if let Some(size) = list_sizes.next() {
                        DataType::new_fixed_size_list(element_type.clone(), size, true)
                    } else {
                        DataType::new_list(element_type.clone(), true)
                    }
                }
            })
            .collect();
        Ok(vec![valid_types])
    }

    // Coerces a single list argument for `ArrayFunctionSignature::RecursiveArray`
    // via `coerced_fixed_size_list_to_list`. Returns `None` for non-list types.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => {
                Some(coerced_fixed_size_list_to_list(array_type))
            }
            _ => None,
        }
    }

    // Returns a planning error unless exactly `expected_length` arguments were given.
    fn function_length_check(
        function_name: &str,
        length: usize,
        expected_length: usize,
    ) -> Result<()> {
        if length != expected_length {
            return plan_err!(
                "Function '{function_name}' expects {expected_length} arguments but received {length}"
            );
        }
        Ok(())
    }

    let valid_types = match signature {
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
            .collect(),
        TypeSignature::String(number) => {
            // Finds the common string type of two argument types. Dictionary
            // encoding on either side is looked through first.
            fn find_common_type(
                function_name: &str,
                lhs_type: &DataType,
                rhs_type: &DataType,
            ) -> Result<DataType> {
                match (lhs_type, rhs_type) {
                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                        find_common_type(function_name, lhs, rhs)
                    }
                    (DataType::Dictionary(_, v), other)
                    | (other, DataType::Dictionary(_, v)) => {
                        find_common_type(function_name, v, other)
                    }
                    _ => {
                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
                            Ok(coerced_type)
                        } else {
                            plan_err!(
                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
                            )
                        }
                    }
                }
            }

            // Strips (possibly nested) dictionary encoding, returning the value type.
            fn base_type_or_default_type(data_type: &DataType) -> DataType {
                if let DataType::Dictionary(_, v) = data_type {
                    base_type_or_default_type(v)
                } else {
                    data_type.to_owned()
                }
            }

            function_length_check(function_name, current_types.len(), *number)?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for data_type in current_types {
                let logical_data_type: NativeType = data_type.into();
                if logical_data_type == NativeType::String {
                    new_types.push(data_type.to_owned());
                } else if logical_data_type == NativeType::Null {
                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
                    new_types.push(DataType::Utf8);
                } else {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
                    );
                }
            }

            // The length check accepts an empty argument list for `String(0)`.
            // Reject that case here instead of unwrapping a missing first element.
            let Some(first_type) = new_types.first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };
            let mut coerced_type = first_type.to_owned();
            for t in new_types.iter().skip(1) {
                coerced_type = find_common_type(function_name, &coerced_type, t)?;
            }
            vec![vec![base_type_or_default_type(&coerced_type); *number]]
        }
        TypeSignature::Numeric(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Guard against `Numeric(0)`, for which the length check passes with
            // an empty argument list.
            let Some(first_type) = current_types.first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };

            // Find the common numeric type of the given types. Null arguments after
            // the first are skipped. The first type is validated after the loop.
            let mut valid_type = first_type.to_owned();
            for t in current_types.iter().skip(1) {
                let logical_data_type: NativeType = t.into();
                if logical_data_type == NativeType::Null {
                    continue;
                }

                if !logical_data_type.is_numeric() {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                    );
                }

                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
                    valid_type = coerced_type;
                } else {
                    return plan_err!(
                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
                    );
                }
            }

            let logical_data_type: NativeType = (&valid_type).into();
            // Fall back to a default type when we don't know which type to coerce to.
            // f64 is chosen because most math functions use Signature::numeric and
            // their default type is double precision.
            if logical_data_type == NativeType::Null {
                valid_type = DataType::Float64;
            } else if !logical_data_type.is_numeric() {
                return plan_err!(
                    "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                );
            }

            vec![vec![valid_type; *number]]
        }
        TypeSignature::Comparable(num) => {
            function_length_check(function_name, current_types.len(), *num)?;

            // Guard against `Comparable(0)`, which would otherwise index an empty slice.
            let Some(first_type) = current_types.first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };
            let mut target_type = first_type.to_owned();
            for data_type in current_types.iter().skip(1) {
                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
                    target_type = dt;
                } else {
                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
                }
            }
            // Convert null to String type.
            if target_type.is_null() {
                vec![vec![DataType::Utf8View; *num]]
            } else {
                vec![vec![target_type; *num]]
            }
        }
        TypeSignature::Coercible(param_types) => {
            function_length_check(function_name, current_types.len(), param_types.len())?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
                let current_native_type: NativeType = current_type.into();

                if param.desired_type().matches_native_type(&current_native_type) {
                    let casted_type = param
                        .desired_type()
                        .default_casted_type(&current_native_type, current_type)?;
                    new_types.push(casted_type);
                } else if param
                    .allowed_source_types()
                    .iter()
                    .any(|t| t.matches_native_type(&current_native_type))
                {
                    // Matching an allowed source type means the parameter provides an
                    // `implicit coercion`, so a default casted type must be present.
                    let default_casted_type = param.default_casted_type().expect(
                        "parameters with allowed source types provide a default casted type",
                    );
                    let casted_type = default_casted_type.default_cast_for(current_type)?;
                    new_types.push(casted_type);
                } else {
                    return internal_err!(
                        "Expect {} but received {}, DataType: {}",
                        param.desired_type(),
                        current_native_type,
                        current_type
                    );
                }
            }

            vec![new_types]
        }
        TypeSignature::Uniform(number, valid_types) => {
            if *number == 0 {
                return plan_err!("The function '{function_name}' expected at least one argument");
            }

            valid_types
                .iter()
                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
                .collect()
        }
        TypeSignature::UserDefined => {
            return internal_err!(
                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
            )
        }
        TypeSignature::VariadicAny => {
            if current_types.is_empty() {
                return plan_err!(
                    "Function '{function_name}' expected at least one argument but received 0"
                );
            }
            vec![current_types.to_vec()]
        }
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::ArraySignature(function_signature) => match function_signature {
            ArrayFunctionSignature::Array {
                arguments,
                array_coercion,
            } => array_valid_types(
                function_name,
                current_types,
                arguments,
                array_coercion.as_ref(),
            )?,
            // Exactly one list argument is required; anything else does not fit.
            ArrayFunctionSignature::RecursiveArray => match current_types {
                [array_type] => recursive_array(array_type)
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]]),
                _ => vec![vec![]],
            },
            // Exactly one `Map` argument is required; it is accepted unchanged.
            ArrayFunctionSignature::MapArray => match current_types {
                [map_type @ DataType::Map(_, _)] => vec![vec![map_type.clone()]],
                _ => vec![vec![]],
            },
        },
        TypeSignature::Nullary => {
            if !current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected zero argument but received {}",
                    current_types.len()
                );
            }
            vec![vec![]]
        }
        TypeSignature::Any(number) => {
            if current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected at least one argument but received 0"
                );
            }

            if current_types.len() != *number {
                return plan_err!(
                    "The function '{function_name}' expected {number} arguments but received {}",
                    current_types.len()
                );
            }

            vec![current_types.to_vec()]
        }
        // Each alternative is tried independently. Alternatives that fail are
        // deliberately dropped (their errors are discarded), so only the
        // candidates of the matching signatures remain.
        TypeSignature::OneOf(types) => types
            .iter()
            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
            .flatten()
            .collect::<Vec<_>>(),
    };

    Ok(valid_types)
}