in cpp-ch/local-engine/Parser/SerializedPlanParser.cpp [701:848]
/// Translate an arrayJoin scalar function (converted from Spark explode/posexplode) into nodes of `actions_dag`.
///
/// The expression built is arrayJoin(materialize(assumeNotNull(ifNull(arg, array() or map())))), followed by
/// sparkTupleElement projections wherever the array-joined value is a tuple that Spark exposes as separate columns.
///
/// @param rel           Expression whose root must be a scalar function resolving to "arrayJoin" with exactly one argument.
/// @param result_names  Receives (appended, in order) the result names of the returned nodes.
/// @param actions_dag   DAG into which every new node is inserted.
/// @param keep_result   When true, each returned node is also added to (or replaces an entry in) the DAG outputs.
/// @param position      false selects the explode branch, true selects the posexplode branch.
/// @return explode(array): {col}; explode(map): {key, value}; posexplode: {pos, col} or {pos, key, value}.
/// @throws Exception    BAD_ARGUMENTS when the root is not a scalar function, the argument count is not 1, or the
///                      argument type / type hint is unsupported; LOGICAL_ERROR when the function is not arrayJoin.
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
    const substrait::Expression & rel, std::vector<String> & result_names, DB::ActionsDAGPtr actions_dag, bool keep_result, bool position)
{
    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString());

    const auto & scalar_function = rel.scalar_function();
    auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    auto function_name = getFunctionName(function_signature, scalar_function);
    if (function_name != "arrayJoin")
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "Function parseArrayJoinWithDAG should only process arrayJoin function, but input is {}",
            rel.ShortDebugString());

    /// arrayJoin converted from Spark explode/posexplode always carries exactly one argument.
    if (scalar_function.arguments_size() != 1)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 but is {}", scalar_function.arguments_size());

    ActionsDAG::NodeRawConstPtrs args;
    parseFunctionArguments(actions_dag, args, function_name, scalar_function);
    auto arg_type = DB::removeNullable(args[0]->result_type);

    /// Constant used as the replacement for a NULL input: array() or map(), matching the argument type.
    Field empty_value = isMap(arg_type) ? Field(Map()) : Field(Array());
    const auto * empty_value_node = addColumn(actions_dag, arg_type, empty_value);

    /// ifNull(args[0], array() or map())
    const auto * null_replaced_node = toFunctionNode(actions_dag, "ifNull", {args[0], empty_value_node});
    /// assumeNotNull(ifNull(args[0], array() or map()))
    const auto * not_null_node = toFunctionNode(actions_dag, "assumeNotNull", {null_replaced_node});
    /// Wrap with materialize so that the column fed into the ARRAY JOIN STEP is materialized.
    const auto * materialized_node = &actions_dag->materializeNode(*not_null_node);

    /// arrayJoin(materialized_node)
    /// The arrayJoin node deliberately reuses the result name of its input: this keeps the later transformation of
    /// the arrayJoin function into an ARRAY JOIN STEP simple, without an extra alias node after that step.
    auto array_join_name = materialized_node->result_name;
    const auto * array_join_node = &actions_dag->addArrayJoin(*materialized_node, array_join_name);

    auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context);
    auto tuple_index_type = std::make_shared<DataTypeUInt32>();

    /// Adds sparkTupleElement(tuple, index) to the DAG and returns the new node.
    auto element_of = [&](const ActionsDAG::Node * tuple, size_t index) -> const ActionsDAG::Node *
    {
        ColumnWithTypeAndName index_column(
            tuple_index_type->createColumnConst(1, index), tuple_index_type, getUniqueName(std::to_string(index)));
        const auto * index_node = &actions_dag->addColumn(std::move(index_column));
        auto element_name = "sparkTupleElement(" + tuple->result_name + ", " + index_node->result_name + ")";
        return &actions_dag->addFunction(tuple_element_builder, {tuple, index_node}, element_name);
    };

    /// Records the names of `nodes`, optionally registers them as DAG outputs, and hands them back.
    /// All names are recorded before any output is touched.
    auto publish = [&](ActionsDAG::NodeRawConstPtrs nodes) -> ActionsDAG::NodeRawConstPtrs
    {
        for (const auto * node : nodes)
            result_names.push_back(node->result_name);
        if (keep_result)
        {
            for (const auto * node : nodes)
                actions_dag->addOrReplaceInOutputs(*node);
        }
        return nodes;
    };

    /// Special handling to stay compatible with Spark.
    WhichDataType type_probe(arg_type.get());

    if (!position)
    {
        /// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map)
        if (type_probe.isMap())
        {
            /// Spark: explode(map(k, v)) outputs 2 columns, by default named "key" and "value".
            /// CH: arrayJoin(map(k, v)) outputs a single Tuple column.
            /// Hence arrayJoin is wrapped with sparkTupleElement: arrayJoin(...).1 and arrayJoin(...).2
            const auto * key_node = element_of(array_join_node, 1);
            const auto * value_node = element_of(array_join_node, 2);
            return publish({key_node, value_node});
        }

        if (type_probe.isArray())
        {
            result_names.push_back(array_join_name);
            if (keep_result)
                actions_dag->addOrReplaceInOutputs(*array_join_node);
            return {array_join_node};
        }

        throw Exception(
            ErrorCodes::BAD_ARGUMENTS,
            "Argument type of arrayJoin converted from explode should be Array or Map but is {}",
            arg_type->getName());
    }

    /// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map)
    if (!type_probe.isMap())
        throw Exception(
            ErrorCodes::BAD_ARGUMENTS,
            "Argument type of arrayJoin converted from posexplode should be Map but is {}",
            arg_type->getName());

    /// Spark: posexplode(array_or_map) outputs 2 or 3 columns: (pos, col) or (pos, key, value).
    /// CH: arrayJoin(map(k, v)) outputs a single Tuple column, so sparkTupleElement is applied on top of it.
    /// pos = arrayJoin(...).1
    const auto * pos_node = element_of(array_join_node, 1);
    /// col = arrayJoin(...).2, or the (key, value) tuple = arrayJoin(...).2
    const auto * item_node = element_of(array_join_node, 2);

    /// Tricky but efficient: the original type of the posexplode argument is read from a suffix of the
    /// argument's result name.
    const auto & raw_arg_name = args[0]->result_name;

    if (endsWith(raw_arg_name, "type_hint:map"))
    {
        /// key = arrayJoin(...).2.1, value = arrayJoin(...).2.2
        const auto * item_key_node = element_of(item_node, 1);
        const auto * item_value_node = element_of(item_node, 2);
        return publish({pos_node, item_key_node, item_value_node});
    }

    if (endsWith(raw_arg_name, "type_hint:array"))
    {
        /// col = arrayJoin(...).2
        return publish({pos_node, item_node});
    }

    throw Exception(ErrorCodes::BAD_ARGUMENTS, "The raw input of arrayJoin converted from posexplode should be Array or Map type");
}