ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG()

in cpp-ch/local-engine/Parser/SerializedPlanParser.cpp [701:848]


/// Parse a substrait scalar-function expression that was converted from Spark's
/// explode/posexplode into a CH arrayJoin, appending the required nodes to actions_dag.
///
/// @param rel          Substrait expression; must be a scalar function whose mapped
///                     CH name is "arrayJoin" with exactly one argument.
/// @param result_names Output: the result names of the returned nodes, in order.
/// @param actions_dag  DAG to which all intermediate and result nodes are added.
/// @param keep_result  If true, the result nodes are also registered as DAG outputs.
/// @param position     False for explode semantics, true for posexplode semantics.
/// @return The DAG nodes corresponding to the Spark output columns:
///         explode(array) -> 1 node; explode(map) -> 2 nodes (key, value);
///         posexplode(array) -> 2 nodes (pos, col); posexplode(map) -> 3 nodes (pos, key, value).
/// @throws Exception(BAD_ARGUMENTS / LOGICAL_ERROR) on malformed or unexpected input.
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
    const substrait::Expression & rel, std::vector<String> & result_names, DB::ActionsDAGPtr actions_dag, bool keep_result, bool position)
{
    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString());

    const auto & scalar_function = rel.scalar_function();

    /// Reuse the already-bound scalar_function instead of re-extracting it from rel.
    auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    auto function_name = getFunctionName(function_signature, scalar_function);
    if (function_name != "arrayJoin")
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "Function parseArrayJoinWithDAG should only process arrayJoin function, but input is {}",
            rel.ShortDebugString());

    /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1
    if (scalar_function.arguments_size() != 1)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 but is {}", scalar_function.arguments_size());

    ActionsDAG::NodeRawConstPtrs args;
    parseFunctionArguments(actions_dag, args, function_name, scalar_function);

    auto arg_type = DB::removeNullable(args[0]->result_type);
    /// Constant of the non-nullable argument type holding an empty array() or map(),
    /// used as the NULL replacement below. arg_type is reused instead of recomputing
    /// removeNullable(args[0]->result_type).
    const auto * empty_map_or_array_node
        = addColumn(actions_dag, arg_type, isMap(arg_type) ? Field(Map()) : Field(Array()));
    /// ifNull(args[0], array() or map())
    const auto * if_null_node = toFunctionNode(actions_dag, "ifNull", {args[0], empty_map_or_array_node});
    /// assumeNotNull(ifNull(args[0], array() or map()))
    const auto * arg_not_null = toFunctionNode(actions_dag, "assumeNotNull", {if_null_node});
    /// Wrap with materialize function to make sure column input to ARRAY JOIN STEP is materialized
    arg_not_null = &actions_dag->materializeNode(*arg_not_null);

    /// arrayJoin(arg_not_null)
    /// Note: Make sure result_name keep the same after applying arrayJoin function, which makes it much easier to transform arrayJoin function to ARRAY JOIN STEP
    /// Otherwise an alias node must be appended after ARRAY JOIN STEP, which is not a graceful implementation.
    auto array_join_name = arg_not_null->result_name;
    const auto * array_join_node = &actions_dag->addArrayJoin(*arg_not_null, array_join_name);

    /// Helper: build sparkTupleElement(tuple_node, i), i.e. extract the i-th tuple
    /// element (1-based) with a deterministic result name.
    auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context);
    auto tuple_index_type = std::make_shared<DataTypeUInt32>();
    auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node *
    {
        ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i)));
        const auto * index_node = &actions_dag->addColumn(std::move(index_col));
        auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")";
        return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, result_name);
    };

    /// Special process to keep compatible with Spark
    WhichDataType which(arg_type.get());
    if (!position)
    {
        /// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map)
        if (which.isMap())
        {
            /// In Spark: explode(map(k, v)) output 2 columns with default names "key" and "value"
            /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type.
            /// So we must wrap arrayJoin with sparkTupleElement function for compatibility.

            /// arrayJoin(arg_not_null).1
            const auto * key_node = add_tuple_element(array_join_node, 1);

            /// arrayJoin(arg_not_null).2
            const auto * val_node = add_tuple_element(array_join_node, 2);

            result_names.push_back(key_node->result_name);
            result_names.push_back(val_node->result_name);
            if (keep_result)
            {
                actions_dag->addOrReplaceInOutputs(*key_node);
                actions_dag->addOrReplaceInOutputs(*val_node);
            }
            return {key_node, val_node};
        }
        else if (which.isArray())
        {
            result_names.push_back(array_join_name);
            if (keep_result)
                actions_dag->addOrReplaceInOutputs(*array_join_node);
            return {array_join_node};
        }
        else
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS,
                "Argument type of arrayJoin converted from explode should be Array or Map but is {}",
                arg_type->getName());
    }
    else
    {
        /// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map)
        if (which.isMap())
        {
            /// In Spark: posexplode(array_of_map) output 2 or 3 columns: (pos, col) or (pos, key, value)
            /// In CH: arrayJoin(map(k, v)) output 1 column with Tuple Type.
            /// So we must wrap arrayJoin with sparkTupleElement function for compatibility.

            /// pos = arrayJoin(arg_not_null).1
            const auto * pos_node = add_tuple_element(array_join_node, 1);

            /// col = arrayJoin(arg_not_null).2 or (key, value) = arrayJoin(arg_not_null).2
            const auto * item_node = add_tuple_element(array_join_node, 2);

            /// It is a tricky but efficient way to get the original type of argument type in posexplode:
            /// the upstream encodes the pre-conversion Spark type as a "type_hint:..." suffix on the
            /// argument's result name, since the CH-side type is always Map after conversion.
            if (endsWith(args[0]->result_name, "type_hint:map"))
            {
                /// key = arrayJoin(arg_not_null).2.1
                const auto * item_key_node = add_tuple_element(item_node, 1);

                /// value = arrayJoin(arg_not_null).2.2
                const auto * item_value_node = add_tuple_element(item_node, 2);

                result_names.push_back(pos_node->result_name);
                result_names.push_back(item_key_node->result_name);
                result_names.push_back(item_value_node->result_name);
                if (keep_result)
                {
                    actions_dag->addOrReplaceInOutputs(*pos_node);
                    actions_dag->addOrReplaceInOutputs(*item_key_node);
                    actions_dag->addOrReplaceInOutputs(*item_value_node);
                }

                return {pos_node, item_key_node, item_value_node};
            }
            else if (endsWith(args[0]->result_name, "type_hint:array"))
            {
                /// col = arrayJoin(arg_not_null).2
                result_names.push_back(pos_node->result_name);
                result_names.push_back(item_node->result_name);
                if (keep_result)
                {
                    actions_dag->addOrReplaceInOutputs(*pos_node);
                    actions_dag->addOrReplaceInOutputs(*item_node);
                }
                return {pos_node, item_node};
            }
            else
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS, "The raw input of arrayJoin converted from posexplode should be Array or Map type");
        }
        else
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS,
                "Argument type of arrayJoin converted from posexplode should be Map but is {}",
                arg_type->getName());
    }
}