DB::QueryPlanPtr JoinRelParser::parseJoin()

in cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp [201:364]


/// Builds the query plan for a substrait JoinRel on top of the already-parsed `left` and `right` plans.
///
/// The join strategy is selected from the optimization info carried in the rel's advanced extension:
///  - storage (broadcast) join: when `is_broadcast` is set and BroadCastJoinBuilder::getJoin() returns a join,
///    a FilledJoinStep is appended to the left plan; the right plan is only kept in `extra_plan_holder`.
///  - sort merge join: when `is_smj` is set, both plans are united under a JoinStep using FullSortingMergeJoin.
///  - otherwise a hash join is built by buildMultiOnClauseHashJoin() or buildSingleOnClauseHashJoin().
///
/// Finally the output columns are reordered to the left header names followed by the names of
/// `table_join->columnsFromJoinedTable()`, and existenceJoinPostProject() is applied for existence joins.
///
/// Throws DB::Exception(LOGICAL_ERROR) when
///  - the convert step changed the right column names of a storage join,
///  - the key of a null aware anti join cannot be resolved in the left header,
///  - a non-inner sort merge join has conditions that applyJoinFilter() could not apply.
DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right)
{
    auto join_config = JoinConfig::loadFromContext(getContext());
    // NOTE(review): the result of ParseFromString is not checked; confirm that a malformed payload is acceptable here.
    google::protobuf::StringValue optimization_info;
    optimization_info.ParseFromString(join.advanced_extension().optimization().value());
    auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value());
    LOG_DEBUG(getLogger("JoinRelParser"), "optimization info:{}", optimization_info.value());
    auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
    if (storage_join)
        renamePlanColumns(*left, *right, *storage_join);

    auto table_join = createDefaultTableJoin(join.type(), join_opt_info, context);
    DB::Block right_header_before_convert_step = right->getCurrentHeader();
    addConvertStep(*table_join, *left, *right);

    // Add a check to find error easily: for a storage join the right column names (count and order included)
    // must be the same before and after the convert step.
    if (storage_join && right_header_before_convert_step.getNames() != right->getCurrentHeader().getNames())
    {
        throw DB::Exception(
            DB::ErrorCodes::LOGICAL_ERROR,
            "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}",
            left->getCurrentHeader().dumpStructure(),
            right_header_before_convert_step.dumpStructure(),
            right->getCurrentHeader().dumpStructure());
    }

    // Expected output order: all left columns first, then the columns taken from the joined (right) table.
    auto left_names = left->getCurrentHeader().getNames();
    auto right_names = table_join->columnsFromJoinedTable().getNames();
    Names after_join_names;
    after_join_names.reserve(left_names.size() + right_names.size());
    after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end());
    after_join_names.insert(after_join_names.end(), right_names.begin(), right_names.end());

    // Copies of the headers as they are right after the convert step; the plans may get more steps below.
    auto left_header = left->getCurrentHeader();
    auto right_header = right->getCurrentHeader();

    QueryPlanPtr query_plan;

    /// some examples to explain when the post_join_filter is not empty
    /// - on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1, 't1.v1> 1' is in the  post filter. but 't2.v1 > 1'
    ///   will be pushed down into right table by spark and is not in the post filter. 't1.key = t2.key ' is
    ///   in JoinRel::expression.
    /// - on t1.key = t2. key and t1.v1 > t2.v2, 't1.v1 > t2.v2' is in the post filter.
    collectJoinKeys(*table_join, join, left_header, right_header);

    if (storage_join)
    {
        if (join_opt_info.is_null_aware_anti_join && join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI)
        {
            if (storage_join->has_null_key_value)
            {
                // if there is a null key value on the build side, it will return the empty result
                auto empty_step = std::make_unique<EarlyStopStep>(left->getCurrentHeader());
                left->addStep(std::move(empty_step));
            }
            else if (!storage_join->is_empty_hash_table)
            {
                auto input_header = left->getCurrentHeader();
                DB::ActionsDAG filter_is_not_null_dag{input_header.getColumnsWithTypeAndName()};

                // when is_null_aware_anti_join is true, there is only one join key
                const auto & join_condition = join.expression().scalar_function();
                // Indexing an empty repeated field is not a recoverable error, so reject it explicitly.
                if (join_condition.arguments_size() == 0)
                    throw DB::Exception(
                        DB::ErrorCodes::LOGICAL_ERROR, "The join expression of a null aware anti join has no arguments.");

                auto field_index = SubstraitParserUtils::getStructFieldIndex(join_condition.arguments(0).value());
                if (!field_index)
                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The join key is not found in the expression.");

                const auto & dag_inputs = filter_is_not_null_dag.getInputs();
                const size_t key_index = static_cast<size_t>(*field_index);
                if (key_index >= dag_inputs.size())
                    throw DB::Exception(
                        DB::ErrorCodes::LOGICAL_ERROR,
                        "The join key index {} is out of range of the left header: {}",
                        key_index,
                        input_header.dumpStructure());
                const auto * key_field = dag_inputs[key_index];

                // tryFindInOutputs returns nullptr when no output has that name.
                const auto * result_node = filter_is_not_null_dag.tryFindInOutputs(key_field->result_name);
                if (!result_node)
                    throw DB::Exception(
                        DB::ErrorCodes::LOGICAL_ERROR,
                        "The join key {} is not found in the left header: {}",
                        key_field->result_name,
                        input_header.dumpStructure());

                // add a function isNotNull to filter the null key on the left side
                const auto * cond_node = buildFunctionNode(filter_is_not_null_dag, "isNotNull", {result_node});
                filter_is_not_null_dag.addOrReplaceInOutputs(*cond_node);
                auto filter_step = std::make_unique<FilterStep>(
                    left->getCurrentHeader(), std::move(filter_is_not_null_dag), cond_node->result_name, true);
                left->addStep(std::move(filter_step));
            }
            // other case: is_empty_hash_table, don't need to handle
        }
        applyJoinFilter(*table_join, join, *left, *right, true);
        auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context);

        QueryPlanStepPtr join_step = std::make_unique<FilledJoinStep>(left->getCurrentHeader(), broadcast_hash_join, 8192);

        join_step->setStepDescription("STORAGE_JOIN");
        steps.emplace_back(join_step.get());
        left->addStep(std::move(join_step));
        query_plan = std::move(left);
        /// hold right plan for profile
        extra_plan_holder.emplace_back(std::move(right));
    }
    else if (join_opt_info.is_smj)
    {
        bool need_post_filter = !applyJoinFilter(*table_join, join, *left, *right, false);

        /// If applyJoinFilter returns false, it means there are mixed conditions in the post_join_filter.
        /// It should be a inner join.
        /// TODO: make smj support mixed conditions
        if (need_post_filter && table_join->kind() != DB::JoinKind::Inner)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge join doesn't support mixed join conditions, except inner join.");

        JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join, right->getCurrentHeader().cloneEmpty(), -1);
        QueryPlanStepPtr join_step = std::make_unique<DB::JoinStep>(
            left->getCurrentHeader(),
            right->getCurrentHeader(),
            smj_join,
            context->getSettingsRef()[Setting::max_block_size],
            context->getSettingsRef()[Setting::min_joined_block_size_bytes],
            1,
            /* required_output_ = */ NameSet{},
            false,
            /* use_new_analyzer_ = */ false);

        join_step->setStepDescription("SORT_MERGE_JOIN");
        steps.emplace_back(join_step.get());
        std::vector<QueryPlanPtr> plans;
        plans.emplace_back(std::move(left));
        plans.emplace_back(std::move(right));

        query_plan = std::make_unique<QueryPlan>();
        query_plan->unitePlans(std::move(join_step), {std::move(plans)});
        if (need_post_filter)
            addPostFilter(*query_plan, join);
    }
    else
    {
        std::vector<DB::TableJoin::JoinOnClause> join_on_clauses;
        if (table_join->getClauses().empty())
            table_join->addDisjunct();
        bool is_multi_join_on_clauses
            = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header);
        // The multi-clause form is only used when it is preferred by config and the build side rows per
        // partition stay below the configured limit; both counters must be positive for the ratio to be used.
        if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0
            && join_opt_info.partitions_num > 0
            && join_opt_info.right_table_rows / join_opt_info.partitions_num < join_config.multi_join_on_clauses_build_side_rows_limit)
        {
            query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses);
        }
        else
        {
            query_plan = buildSingleOnClauseHashJoin(join, table_join, std::move(left), std::move(right));
        }
    }

    JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
    /// Need to project the right table column into boolean type
    if (join_opt_info.is_existence_join)
        existenceJoinPostProject(*query_plan, left_names);

    return query_plan;
}