in cpp-ch/local-engine/Parser/SerializedPlanParser.cpp [1051:1192]
/// Translate the arguments of a substrait scalar function into ActionsDAG nodes.
///
/// The resulting nodes are appended to `parsed_args` in the order expected by the
/// function named `function_name`. Several functions need their arguments reordered,
/// dropped, cast or extended with extra constants; those are handled branch by branch
/// below, everything else goes through the default one-to-one translation.
///
/// @param actions_dag      DAG that receives the argument nodes.
/// @param parsed_args      Output list; translated argument nodes are appended to it.
/// @param function_name    Name of the target function. It is an in-out parameter:
///                         "space" is rewritten to "repeat".
/// @param scalar_function  The substrait scalar function whose arguments are parsed.
///
/// @throws Exception(BAD_ARGUMENTS) if a specially handled function has fewer arguments
///         than the branch reads, or if the second argument of
///         sparkTupleElement/tupleElement is not an i32 literal.
void SerializedPlanParser::parseFunctionArguments(
    DB::ActionsDAGPtr & actions_dag,
    ActionsDAG::NodeRawConstPtrs & parsed_args,
    std::string & function_name,
    const substrait::Expression_ScalarFunction & scalar_function)
{
    // Bind by reference: the signature is only inspected, never modified.
    const auto & function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    const auto & args = scalar_function.arguments();
    parsed_args.reserve(args.size());

    // Indexing `args` is not bounds-checked, so validate the arity before any branch
    // reads a fixed position. A malformed plan then fails with a clear error.
    auto require_min_args = [&](int min_count)
    {
        if (args.size() < min_count)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS,
                "Function {} expects at least {} argument(s), but got {}",
                function_name,
                min_count,
                args.size());
    };

    // Some functions need to be handled specially.
    if (function_name == "JSONExtract")
    {
        require_min_args(1);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        // The second argument is a string column carrying the name of the output type.
        auto data_type = TypeParser::parseType(scalar_function.output_type());
        parsed_args.emplace_back(addColumn(actions_dag, std::make_shared<DB::DataTypeString>(), data_type->getName()));
    }
    else if (function_name == "sparkTupleElement" || function_name == "tupleElement")
    {
        require_min_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        if (!args[1].value().has_literal())
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal");
        auto [data_type, field] = parseLiteral(args[1].value().literal());
        if (data_type->getTypeId() != DB::TypeIndex::Int32)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32");
        // Tuple indices start from 1 in CH, but from 0 in Spark.
        Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1);
        const auto * index_node = addColumn(actions_dag, std::make_shared<DB::DataTypeUInt32>(), field_index);
        parsed_args.emplace_back(index_node);
    }
    else if (function_name == "tuple")
    {
        // Arguments in the format, (<field name>, <value expression>[, <field name>, <value expression> ...])
        // We don't need to care the field names here.
        for (int index = 1; index < args.size(); index += 2)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[index]);
    }
    else if (function_name == "repeat")
    {
        // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait
        // which must be a positive value into unsigned integer here.
        require_min_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, function_name, args[1]);
        DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
        repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName());
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "isNaN")
    {
        // the result of isNaN(NULL) is NULL in CH, but false in Spark
        require_min_args(1);
        const DB::ActionsDAG::Node * arg_node = nullptr;
        if (args[0].value().has_cast())
        {
            // Look through the cast: a string input is converted with toFloat64OrZero
            // instead of the cast expressed in the plan.
            arg_node = parseExpression(actions_dag, args[0].value().cast().input());
            const auto * res_type = arg_node->result_type.get();
            if (res_type->isNullable())
            {
                res_type = typeid_cast<const DB::DataTypeNullable *>(res_type)->getNestedType().get();
            }
            if (isString(*res_type))
            {
                DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
                arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args);
            }
            else
            {
                arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
            }
        }
        else
        {
            arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
        }
        // Replace NULL with 0 so that isNaN yields false rather than NULL.
        DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
        parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args));
    }
    else if (function_name == "positionUTF8Spark")
    {
        if (args.size() >= 2)
        {
            // In Spark: position(substr, str, Int32)
            // In CH: position(str, substr, UInt32)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        }
        if (args.size() >= 3)
        {
            // add cast: cast(start_pos as UInt32)
            const auto * start_pos_node = parseFunctionArgument(actions_dag, function_name, args[2]);
            DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
            start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName());
            parsed_args.emplace_back(start_pos_node);
        }
    }
    else if (function_name == "space")
    {
        // convert space function to repeat
        require_min_args(1);
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, "repeat", args[0]);
        const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), " ");
        function_name = "repeat";
        parsed_args.emplace_back(space_str_node);
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "trimBothSpark" || function_name == "trimLeftSpark" || function_name == "trimRightSpark")
    {
        /// In substrait, the first arg is srcStr, the second arg is trimStr
        /// But in CH, the first arg is trimStr, the second arg is srcStr
        require_min_args(2);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
    }
    else if (startsWith(function_signature, "extract:"))
    {
        /// Skip the first arg of extract in substrait
        for (int i = 1; i < args.size(); i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);

        /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait
        if (function_name == "toDayOfWeek" || function_name == "DAYOFWEEK")
        {
            UInt8 mode = function_name == "toDayOfWeek" ? 1 : 3;
            auto mode_type = std::make_shared<DataTypeUInt8>();
            ColumnWithTypeAndName mode_col(mode_type->createColumnConst(1, mode), mode_type, getUniqueName(std::to_string(mode)));
            const auto & mode_node = actions_dag->addColumn(std::move(mode_col));
            parsed_args.emplace_back(&mode_node);
        }
    }
    else if (startsWith(function_signature, "sha2:"))
    {
        // The last substrait argument is not forwarded to the target function.
        for (int i = 0; i < args.size() - 1; i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);
    }
    else
    {
        // Default handle
        for (const auto & arg : args)
            parseFunctionArgument(actions_dag, parsed_args, function_name, arg);
    }
}