fn get_valid_types()

in datafusion/expr/src/type_coercion/functions.rs [354:691]


/// Returns the argument-type lists that `signature` accepts for a call to
/// `function_name` whose arguments currently have types `current_types`.
///
/// Each inner `Vec<DataType>` is one candidate list of target types, one entry per
/// argument. Some array signatures return `vec![vec![]]` (rather than an error) when the
/// arguments do not fit the signature. `TypeSignature::OneOf` concatenates the candidates
/// of every alternative that does not return an error.
///
/// # Errors
///
/// * A planning error when the argument count is wrong for the signature, when a
///   `String`/`Numeric`/`Comparable` signature declares zero arguments, or when the
///   argument types cannot be unified the way the signature requires.
/// * An internal error for `TypeSignature::UserDefined` (which must be handled by the
///   function's own `coerce_types`), and for `Coercible` arguments that match neither the
///   desired type nor any allowed source type.
fn get_valid_types(
    function_name: &str,
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    /// Computes the coerced argument types for `ArrayFunctionSignature::Array`.
    ///
    /// Element arguments and the element types of array arguments are unified with
    /// `type_union_resolution`. Array arguments are then rebuilt around that element
    /// type:
    /// * `LargeList` if any argument was a `LargeList`;
    /// * `FixedSizeList` (keeping each argument's own size) if every non-null array
    ///   argument was a `FixedSizeList` and `array_coercion` is not
    ///   `FixedSizedListToList`;
    /// * `List` otherwise.
    ///
    /// `Null` array arguments stay `Null`, `Index` becomes `Int64` and `String` becomes
    /// `Utf8`.
    ///
    /// Returns `vec![vec![]]` when the argument count does not match `arguments` or when
    /// no common element type exists, and a planning error for non-list array arguments.
    fn array_valid_types(
        function_name: &str,
        current_types: &[DataType],
        arguments: &[ArrayFunctionArgument],
        array_coercion: Option<&ListCoercion>,
    ) -> Result<Vec<Vec<DataType>>> {
        if current_types.len() != arguments.len() {
            return Ok(vec![vec![]]);
        }

        let mut large_list = false;
        // Fixed-size output is only possible while every array argument is a
        // FixedSizeList, and only if the signature does not ask to coerce them to List.
        let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
        let mut list_sizes = Vec::with_capacity(arguments.len());
        let mut element_types = Vec::with_capacity(arguments.len());
        for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
            match argument {
                ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
                ArrayFunctionArgument::Element => {
                    element_types.push(current_type.clone())
                }
                ArrayFunctionArgument::Array => match current_type {
                    DataType::Null => element_types.push(DataType::Null),
                    DataType::List(field) => {
                        element_types.push(field.data_type().clone());
                        fixed_size = false;
                    }
                    DataType::LargeList(field) => {
                        element_types.push(field.data_type().clone());
                        large_list = true;
                        fixed_size = false;
                    }
                    DataType::FixedSizeList(field, size) => {
                        element_types.push(field.data_type().clone());
                        list_sizes.push(*size)
                    }
                    arg_type => {
                        plan_err!("{function_name} does not support type {arg_type}")?
                    }
                },
            }
        }

        let Some(element_type) = type_union_resolution(&element_types) else {
            return Ok(vec![vec![]]);
        };

        if !fixed_size {
            list_sizes.clear()
        }

        // Sizes are consumed in argument order, one per non-null array argument.
        let mut list_sizes = list_sizes.into_iter();
        let valid_types = arguments.iter().zip(current_types.iter()).map(
            |(argument_type, current_type)| match argument_type {
                ArrayFunctionArgument::Index => DataType::Int64,
                ArrayFunctionArgument::String => DataType::Utf8,
                ArrayFunctionArgument::Element => element_type.clone(),
                ArrayFunctionArgument::Array => {
                    if current_type.is_null() {
                        DataType::Null
                    } else if large_list {
                        DataType::new_large_list(element_type.clone(), true)
                    } else if let Some(size) = list_sizes.next() {
                        DataType::new_fixed_size_list(element_type.clone(), size, true)
                    } else {
                        DataType::new_list(element_type.clone(), true)
                    }
                }
            },
        );

        Ok(vec![valid_types.collect()])
    }

    /// For `List`, `LargeList` and `FixedSizeList` inputs, returns the type produced by
    /// `coerced_fixed_size_list_to_list`; returns `None` for any other type.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_)
            | DataType::LargeList(_)
            | DataType::FixedSizeList(_, _) => {
                let array_type = coerced_fixed_size_list_to_list(array_type);
                Some(array_type)
            }
            _ => None,
        }
    }

    /// Returns a planning error unless `length == expected_length`.
    fn function_length_check(
        function_name: &str,
        length: usize,
        expected_length: usize,
    ) -> Result<()> {
        if length != expected_length {
            return plan_err!(
                "Function '{function_name}' expects {expected_length} arguments but received {length}"
            );
        }
        Ok(())
    }

    /// Finds the common string type of two types via `string_coercion`. Dictionary types
    /// are unwrapped to their value type before coercing.
    fn find_common_type(
        function_name: &str,
        lhs_type: &DataType,
        rhs_type: &DataType,
    ) -> Result<DataType> {
        match (lhs_type, rhs_type) {
            (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                find_common_type(function_name, lhs, rhs)
            }
            (DataType::Dictionary(_, v), other)
            | (other, DataType::Dictionary(_, v)) => {
                find_common_type(function_name, v, other)
            }
            _ => {
                if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
                    Ok(coerced_type)
                } else {
                    plan_err!(
                        "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
                    )
                }
            }
        }
    }

    /// Strips any (possibly nested) `Dictionary` wrapper and returns the value type.
    fn base_type_or_default_type(data_type: &DataType) -> DataType {
        if let DataType::Dictionary(_, v) = data_type {
            base_type_or_default_type(v)
        } else {
            data_type.to_owned()
        }
    }

    let valid_types = match signature {
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| vec![valid_type.clone(); current_types.len()])
            .collect(),
        TypeSignature::String(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for data_type in current_types.iter() {
                let logical_data_type: NativeType = data_type.into();
                if logical_data_type == NativeType::String {
                    new_types.push(data_type.to_owned());
                } else if logical_data_type == NativeType::Null {
                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
                    new_types.push(DataType::Utf8);
                } else {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
                    );
                }
            }

            // The length check only guarantees `len == number`, and `number` may be 0,
            // so there is not necessarily a first type to seed the fold with.
            let Some((first, rest)) = new_types.split_first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };

            // Find the common string type for the given types
            let coerced_type = rest.iter().try_fold(first.to_owned(), |acc, t| {
                find_common_type(function_name, &acc, t)
            })?;

            vec![vec![base_type_or_default_type(&coerced_type); *number]]
        }
        TypeSignature::Numeric(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // `number` may be 0, in which case there is no first type to start from.
            let Some((first, rest)) = current_types.split_first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };

            // Find common numeric type among given types, ignoring nulls
            let mut valid_type = first.to_owned();
            for t in rest {
                let logical_data_type: NativeType = t.into();
                if logical_data_type == NativeType::Null {
                    continue;
                }

                if !logical_data_type.is_numeric() {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                    );
                }

                // Treat a leading null like every other null: adopt the first non-null
                // numeric type instead of asking `binary_numeric_coercion` to combine it
                // with `Null`, which is not a numeric type.
                let valid_native_type: NativeType = (&valid_type).into();
                if valid_native_type == NativeType::Null {
                    valid_type = t.to_owned();
                    continue;
                }

                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
                    valid_type = coerced_type;
                } else {
                    return plan_err!(
                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
                    );
                }
            }

            let logical_data_type: NativeType = valid_type.clone().into();
            // Fallback to default type if we don't know which type to coerced to
            // f64 is chosen since most of the math functions utilize Signature::numeric,
            // and their default type is double precision
            if logical_data_type == NativeType::Null {
                valid_type = DataType::Float64;
            } else if !logical_data_type.is_numeric() {
                return plan_err!(
                    "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                );
            }

            vec![vec![valid_type; *number]]
        }
        TypeSignature::Comparable(num) => {
            function_length_check(function_name, current_types.len(), *num)?;

            // `num` may be 0, in which case there is no first type to start from.
            let Some((first, rest)) = current_types.split_first() else {
                return plan_err!("The function '{function_name}' expected at least one argument");
            };

            let mut target_type = first.to_owned();
            for data_type in rest {
                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
                    target_type = dt;
                } else {
                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
                }
            }
            // Convert null to String type.
            if target_type.is_null() {
                vec![vec![DataType::Utf8View; *num]]
            } else {
                vec![vec![target_type; *num]]
            }
        }
        TypeSignature::Coercible(param_types) => {
            function_length_check(function_name, current_types.len(), param_types.len())?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
                let current_native_type: NativeType = current_type.into();

                if param.desired_type().matches_native_type(&current_native_type) {
                    let casted_type = param.desired_type().default_casted_type(
                        &current_native_type,
                        current_type,
                    )?;

                    new_types.push(casted_type);
                } else if param
                    .allowed_source_types()
                    .iter()
                    .any(|t| t.matches_native_type(&current_native_type))
                {
                    // If the condition is met, `implicit coercion` is provided, so we can
                    // safely unwrap
                    let default_casted_type = param.default_casted_type().unwrap();
                    let casted_type = default_casted_type.default_cast_for(current_type)?;
                    new_types.push(casted_type);
                } else {
                    return internal_err!(
                        "Expect {} but received {}, DataType: {}",
                        param.desired_type(),
                        current_native_type,
                        current_type
                    );
                }
            }

            vec![new_types]
        }
        TypeSignature::Uniform(number, valid_types) => {
            if *number == 0 {
                return plan_err!("The function '{function_name}' expected at least one argument");
            }

            valid_types
                .iter()
                .map(|valid_type| vec![valid_type.clone(); *number])
                .collect()
        }
        TypeSignature::UserDefined => {
            return internal_err!(
                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
            )
        }
        TypeSignature::VariadicAny => {
            if current_types.is_empty() {
                return plan_err!(
                    "Function '{function_name}' expected at least one argument but received 0"
                );
            }
            vec![current_types.to_vec()]
        }
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::ArraySignature(function_signature) => match function_signature {
            ArrayFunctionSignature::Array { arguments, array_coercion } => {
                array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
            }
            ArrayFunctionSignature::RecursiveArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }
                recursive_array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            ArrayFunctionSignature::MapArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }

                match &current_types[0] {
                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
                    _ => vec![vec![]],
                }
            }
        },
        TypeSignature::Nullary => {
            if !current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected zero argument but received {}",
                    current_types.len()
                );
            }
            vec![vec![]]
        }
        TypeSignature::Any(number) => {
            if current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected at least one argument but received 0"
                );
            }

            if current_types.len() != *number {
                return plan_err!(
                    "The function '{function_name}' expected {number} arguments but received {}",
                    current_types.len()
                );
            }
            // Length was checked above, so the arguments are accepted as-is.
            vec![current_types.to_vec()]
        }
        // Errors from individual alternatives are deliberately discarded; only the
        // candidates of alternatives that succeed are kept.
        TypeSignature::OneOf(types) => types
            .iter()
            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
            .flatten()
            .collect::<Vec<_>>(),
    };

    Ok(valid_types)
}