void SerializedPlanParser::parseFunctionArguments()

in cpp-ch/local-engine/Parser/SerializedPlanParser.cpp [1051:1192]


/// Parse the arguments of a substrait scalar function into ActionsDAG nodes for the
/// ClickHouse function `function_name`.
///
/// By default every substrait argument is parsed in order. Functions whose argument layout
/// differs between Spark/substrait and ClickHouse are handled specially below: arguments may
/// be reordered, skipped, cast, or supplemented with constant columns. For `space`,
/// `function_name` itself is rewritten to `repeat`.
///
/// @param actions_dag      DAG that receives every node created while parsing.
/// @param parsed_args      Output: parsed argument nodes, appended in ClickHouse argument order.
/// @param function_name    ClickHouse function name; may be modified (see the `space` branch).
/// @param scalar_function  Substrait scalar function whose arguments are parsed.
/// @throws Exception (BAD_ARGUMENTS) when a special-cased function has fewer arguments than it
///         needs, or when get_struct_field's index argument is not an i32 literal.
void SerializedPlanParser::parseFunctionArguments(
    DB::ActionsDAGPtr & actions_dag,
    ActionsDAG::NodeRawConstPtrs & parsed_args,
    std::string & function_name,
    const substrait::Expression_ScalarFunction & scalar_function)
{
    // Only read below, so bind by reference instead of copying the signature string.
    const auto & function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    const auto & args = scalar_function.arguments();
    parsed_args.reserve(args.size());

    // The special-cased branches index `args` directly. Reject malformed plans with a clear
    // error instead of reading past the end of the repeated field.
    auto require_args = [&](int min_count)
    {
        if (args.size() < min_count)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS,
                "Function {} expects at least {} argument(s), but got {}",
                function_name,
                min_count,
                args.size());
    };

    // Some functions need to be handled specially.
    if (function_name == "JSONExtract")
    {
        // Only the first argument is parsed. The second CH argument is a constant String column
        // holding the name of the substrait output type.
        require_args(1);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        auto data_type = TypeParser::parseType(scalar_function.output_type());
        parsed_args.emplace_back(addColumn(actions_dag, std::make_shared<DB::DataTypeString>(), data_type->getName()));
    }
    else if (function_name == "sparkTupleElement" || function_name == "tupleElement")
    {
        require_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);

        if (!args[1].value().has_literal())
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal");

        auto [data_type, field] = parseLiteral(args[1].value().literal());
        if (data_type->getTypeId() != DB::TypeIndex::Int32)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32");

        // Tuple indices start from 1 in CH, but from 0 in Spark.
        Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1);
        const auto * index_node = addColumn(actions_dag, std::make_shared<DB::DataTypeUInt32>(), field_index);
        parsed_args.emplace_back(index_node);
    }
    else if (function_name == "tuple")
    {
        // Arguments in the format, (<field name>, <value expression>[, <field name>, <value expression> ...])
        // We don't need to care about the field names here: only every second argument is parsed.
        for (int index = 1; index < args.size(); index += 2)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[index]);
    }
    else if (function_name == "repeat")
    {
        // repeat: the repeat count must be an unsigned integer in CH. Cast the signed integer
        // from substrait, which is expected to be non-negative, to Nullable(UInt32) here.
        // NOTE(review): behaviour for negative counts depends on convertNodeType's cast semantics — confirm.
        require_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, function_name, args[1]);
        DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
        repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName());
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "isNaN")
    {
        // The result of isNaN(NULL) is NULL in CH, but false in Spark. The argument is wrapped in
        // IfNull(arg, 0) so that a NULL input becomes 0 and isNaN yields false.
        require_args(1);
        const DB::ActionsDAG::Node * arg_node = nullptr;
        if (args[0].value().has_cast())
        {
            // Inspect the type of the cast's input. A (possibly Nullable) String input is converted
            // with toFloat64OrZero instead of the substrait cast.
            arg_node = parseExpression(actions_dag, args[0].value().cast().input());
            const auto * res_type = arg_node->result_type.get();
            if (res_type->isNullable())
            {
                res_type = typeid_cast<const DB::DataTypeNullable *>(res_type)->getNestedType().get();
            }
            if (isString(*res_type))
            {
                DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
                arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args);
            }
            else
            {
                // NOTE(review): the node parsed above for the type check is not used on this path.
                arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
            }
        }
        else
        {
            arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
        }

        DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
        parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args));
    }
    else if (function_name == "positionUTF8Spark")
    {
        if (args.size() >= 2)
        {
            // In Spark: position(substr, str, Int32)
            // In CH:    position(str, substr, UInt32)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        }
        if (args.size() >= 3)
        {
            // add cast: cast(start_pos as UInt32)
            const auto * start_pos_node = parseFunctionArgument(actions_dag, function_name, args[2]);
            DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
            start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName());
            parsed_args.emplace_back(start_pos_node);
        }
    }
    else if (function_name == "space")
    {
        // Convert space(n) into repeat(' ', n).
        // NOTE(review): unlike the "repeat" branch, the count is not cast to Nullable(UInt32)
        // here — confirm CH's repeat accepts the signed count coming from substrait.
        require_args(1);
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, "repeat", args[0]);
        const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), " ");
        function_name = "repeat";
        parsed_args.emplace_back(space_str_node);
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "trimBothSpark" || function_name == "trimLeftSpark" || function_name == "trimRightSpark")
    {
        /// In substrait, the first arg is srcStr, the second arg is trimStr
        /// But in CH, the first arg is trimStr, the second arg is srcStr
        require_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
    }
    else if (startsWith(function_signature, "extract:"))
    {
        /// Skip the first arg of extract in substrait
        for (int i = 1; i < args.size(); i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);

        /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait:
        /// mode 1 for toDayOfWeek, mode 3 for DAYOFWEEK, added as a constant UInt8 column.
        if (function_name == "toDayOfWeek" || function_name == "DAYOFWEEK")
        {
            UInt8 mode = function_name == "toDayOfWeek" ? 1 : 3;
            auto mode_type = std::make_shared<DataTypeUInt8>();
            ColumnWithTypeAndName mode_col(mode_type->createColumnConst(1, mode), mode_type, getUniqueName(std::to_string(mode)));
            const auto & mode_node = actions_dag->addColumn(std::move(mode_col));
            parsed_args.emplace_back(&mode_node);
        }
    }
    else if (startsWith(function_signature, "sha2:"))
    {
        // Every argument except the last one is parsed; the last substrait argument is dropped.
        // (args.size() is a signed int, so an empty argument list simply skips the loop.)
        for (int i = 0; i < args.size() - 1; i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);
    }
    else
    {
        // Default handling: parse every argument in order.
        for (const auto & arg : args)
            parseFunctionArgument(actions_dag, parsed_args, function_name, arg);
    }
}