cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp (686 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "JoinRelParser.h"
#include <optional>
#include <Core/Block.h>
#include <Core/Settings.h>
#include <Functions/FunctionFactory.h>
#include <Interpreters/CollectJoinOnKeysVisitor.h>
#include <Interpreters/ExpressionActions.h>
#include <Interpreters/FullSortingMergeJoin.h>
#include <Interpreters/GraceHashJoin.h>
#include <Interpreters/HashJoin/HashJoin.h>
#include <Interpreters/TableJoin.h>
#include <Join/BroadCastJoinBuilder.h>
#include <Join/StorageJoinFromReadBuffer.h>
#include <Operator/EarlyStopStep.h>
#include <Parser/AdvancedParametersParseUtil.h>
#include <Parser/ExpressionParser.h>
#include <Parsers/ASTIdentifier.h>
#include <Parser/SubstraitParserUtils.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/JoinStep.h>
#include <google/protobuf/wrappers.pb.h>
#include <Common/CHUtil.h>
#include <Common/GlutenConfig.h>
#include <Common/logger_useful.h>
/// Forward declarations of ClickHouse symbols that this file references; they are declared `extern`
/// here and defined elsewhere.
namespace DB
{
/// Query-level settings read in this file through `context->getSettingsRef()[Setting::...]`.
namespace Setting
{
extern const SettingsJoinAlgorithm join_algorithm;
extern const SettingsUInt64 max_block_size;
extern const SettingsUInt64 min_joined_block_size_bytes;
extern const SettingsNonZeroUInt64 grace_hash_join_initial_buckets;
extern const SettingsNonZeroUInt64 grace_hash_join_max_buckets;
}
/// Error codes passed to DB::Exception by the join parser.
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_TYPE;
extern const int BAD_ARGUMENTS;
}
}
using namespace DB;
namespace local_engine
{
/// Creates a TableJoin whose kind and strictness are derived from the substrait join type.
/// When the optimization info marks the join as an "any" join, the strictness is forced to
/// `JoinStrictness::Any` regardless of what the join type maps to.
std::shared_ptr<DB::TableJoin> createDefaultTableJoin(substrait::JoinRel_JoinType join_type, const JoinOptimizationInfo & join_opt_info, ContextPtr & context)
{
    auto result = std::make_shared<TableJoin>(context->getSettingsRef(), context->getGlobalTemporaryVolume(), context->getTempDataOnDisk());

    const std::pair<DB::JoinKind, DB::JoinStrictness> mapped
        = JoinUtil::getJoinKindAndStrictness(join_type, join_opt_info.is_existence_join);
    const auto & [join_kind, join_strictness] = mapped;

    result->setKind(join_kind);
    result->setStrictness(join_opt_info.is_any_join ? DB::JoinStrictness::Any : join_strictness);
    return result;
}
/// Constructs the parser on top of RelParser and stores the query context of the given parser context.
JoinRelParser::JoinRelParser(ParserContextPtr parser_context_)
    : RelParser(parser_context_)
    , context(parser_context_->queryContext())
{
}
/// Single-input entry point. A join node has two inputs, so this overload unconditionally
/// throws LOGICAL_ERROR.
DB::QueryPlanPtr JoinRelParser::parse(DB::QueryPlanPtr, const substrait::Rel &, std::list<const substrait::Rel *> &)
{
    throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call parse().");
}
/// Returns pointers to the left and right input relations of the join, in that order.
/// Throws BAD_ARGUMENTS when either side is not set on the JoinRel.
std::vector<const substrait::Rel *> JoinRelParser::getInputs(const substrait::Rel & rel)
{
    const auto & join_rel = rel.join();
    const bool has_both_sides = join_rel.has_left() && join_rel.has_right();
    if (!has_both_sides)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "left table or right table is missing.");

    std::vector<const substrait::Rel *> inputs;
    inputs.reserve(2);
    inputs.push_back(&join_rel.left());
    inputs.push_back(&join_rel.right());
    return inputs;
}
/// A join node has two inputs, so asking for a single input is a programming error:
/// this always throws LOGICAL_ERROR.
std::optional<const substrait::Rel *> JoinRelParser::getSingleInput(const substrait::Rel &)
{
    throw Exception(ErrorCodes::LOGICAL_ERROR, "join node has 2 inputs, can't call getSingleInput().");
}
/// Two-input entry point: takes ownership of the two plans in `input_plans_` (left first, right second)
/// and builds the join plan for `rel.join()`. The relation stack is not used.
DB::QueryPlanPtr JoinRelParser::parse(
    std::vector<DB::QueryPlanPtr> & input_plans_, const substrait::Rel & rel, std::list<const substrait::Rel *> & /*rel_stack_*/)
{
    assert(input_plans_.size() == 2);
    auto & left_plan = input_plans_[0];
    auto & right_plan = input_plans_[1];
    return parseJoin(rel.join(), std::move(left_plan), std::move(right_plan));
}
/// Collects which join sides (left and/or right) are referenced by the field references inside `expr`.
/// A field index below `left_header.columns()` counts as the left side, any other index as the right side.
/// Scalar functions, singular-or-list, cast and if-then expressions are traversed recursively; literals
/// contribute nothing. Any other expression kind raises LOGICAL_ERROR.
std::unordered_set<DB::JoinTableSide> JoinRelParser::extractTableSidesFromExpression(
    const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header)
{
    std::unordered_set<DB::JoinTableSide> sides;

    /// Recurse into a sub-expression and merge the sides it references into `sides`.
    auto merge_sides_of = [&](const substrait::Expression & child)
    {
        auto child_sides = extractTableSidesFromExpression(child, left_header, right_header);
        sides.insert(child_sides.begin(), child_sides.end());
    };

    if (expr.has_scalar_function())
    {
        for (const auto & argument : expr.scalar_function().arguments())
            merge_sides_of(argument.value());
        return sides;
    }

    if (auto field_index = SubstraitParserUtils::getStructFieldIndex(expr))
    {
        const bool is_left = *field_index < left_header.columns();
        sides.insert(is_left ? DB::JoinTableSide::Left : DB::JoinTableSide::Right);
        return sides;
    }

    if (expr.has_singular_or_list())
    {
        const auto & in_list = expr.singular_or_list();
        merge_sides_of(in_list.value());
        for (const auto & option : in_list.options())
            merge_sides_of(option);
        return sides;
    }

    if (expr.has_cast())
    {
        merge_sides_of(expr.cast().input());
        return sides;
    }

    if (expr.has_if_then())
    {
        const auto & if_then = expr.if_then();
        for (const auto & branch : if_then.ifs())
        {
            merge_sides_of(branch.if_());
            merge_sides_of(branch.then());
        }
        merge_sides_of(if_then.else_());
        return sides;
    }

    /// Literals reference no column, so they add no side.
    if (expr.has_literal())
        return sides;

    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression '{}'", expr.DebugString());
}
/// Aligns the column names of both inputs with the broadcast (storage) join.
/// - The right plan gets a positional converting step so that its columns match the storage join's
///   right sample block; mixed join conditions rely on these names.
/// - Every left column whose name also exists in the (renamed) right header gets a unique prefix, so that
///   `addConvertStep` does not have to rename the right columns later. This could happen in tpc-ds q44.
void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join)
{
    /// Right side: rename by position to the storage join's sample block.
    ActionsDAG to_storage_names = ActionsDAG::makeConvertingActions(
        right.getCurrentHeader().getColumnsWithTypeAndName(),
        storage_join.getRightSampleBlock().getColumnsWithTypeAndName(),
        ActionsDAG::MatchColumnsMode::Position);
    QueryPlanStepPtr rename_right_step = std::make_unique<ExpressionStep>(right.getCurrentHeader(), std::move(to_storage_names));
    rename_right_step->setStepDescription("Rename Broadcast Table Name");
    steps.emplace_back(rename_right_step.get());
    right.addStep(std::move(rename_right_step));

    /// Left side: prefix only the columns that clash with the right header produced above.
    const auto & renamed_right_header = right.getCurrentHeader();
    const auto left_prefix = getUniqueName("left");
    DB::ColumnsWithTypeAndName renamed_left_columns;
    for (const auto & column : left.getCurrentHeader())
    {
        const bool clashes_with_right = renamed_right_header.has(column.name);
        renamed_left_columns.emplace_back(column.column, column.type, clashes_with_right ? left_prefix + column.name : column.name);
    }

    ActionsDAG to_unique_left_names = ActionsDAG::makeConvertingActions(
        left.getCurrentHeader().getColumnsWithTypeAndName(), renamed_left_columns, ActionsDAG::MatchColumnsMode::Position);
    QueryPlanStepPtr rename_left_step = std::make_unique<ExpressionStep>(left.getCurrentHeader(), std::move(to_unique_left_names));
    rename_left_step->setStepDescription("Rename Left Table Name for broadcast join");
    steps.emplace_back(rename_left_step.get());
    left.addStep(std::move(rename_left_step));
}
/// Builds the query plan for a substrait JoinRel from its left and right input plans.
///
/// Steps, in order:
///  1. Read the join config from the context and the optimization info from the rel's advanced extension.
///  2. For a broadcast join, fetch the pre-built storage join and rename the input columns to match it.
///  3. Create the TableJoin and add rename/convert steps (`addConvertStep`).
///  4. Collect the equi-join keys from `JoinRel::expression`.
///  5. Build one of three join flavours: broadcast (storage) join, sort merge join, or hash join.
///  6. Reorder the join output to "left columns, then joined right columns" and, for an existence join,
///     project the right column into a flag column.
///
/// Takes ownership of `left` and `right`; returns the united plan.
/// Throws LOGICAL_ERROR when a broadcast join's right column names were changed by the convert step, when the
/// null-aware anti join key is not a field reference, or when a non-inner sort merge join has mixed conditions.
DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right)
{
auto join_config = JoinConfig::loadFromContext(getContext());
// The optimization info is a protobuf StringValue stored in the advanced extension.
// NOTE(review): the result of ParseFromString is not checked here.
google::protobuf::StringValue optimization_info;
optimization_info.ParseFromString(join.advanced_extension().optimization().value());
auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value());
LOG_DEBUG(getLogger("JoinRelParser"), "optimization info:{}", optimization_info.value());
// Only a broadcast join has a storage join; it stays nullptr otherwise.
auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
if (storage_join)
renamePlanColumns(*left, *right, *storage_join);
auto table_join = createDefaultTableJoin(join.type(), join_opt_info, context);
// Snapshot of the right header, compared below against the header after the convert step.
DB::Block right_header_before_convert_step = right->getCurrentHeader();
addConvertStep(*table_join, *left, *right);
// Add a check to find error easily.
if (storage_join)
{
bool is_col_names_changed = false;
const auto & current_right_header = right->getCurrentHeader();
if (right_header_before_convert_step.columns() != current_right_header.columns())
is_col_names_changed = true;
if (!is_col_names_changed)
{
for (size_t i = 0; i < right_header_before_convert_step.columns(); i++)
{
if (right_header_before_convert_step.getByPosition(i).name != current_right_header.getByPosition(i).name)
{
is_col_names_changed = true;
break;
}
}
}
if (is_col_names_changed)
{
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}",
left->getCurrentHeader().dumpStructure(),
right_header_before_convert_step.dumpStructure(),
right->getCurrentHeader().dumpStructure());
}
}
// Expected output column order: left columns followed by the columns taken from the joined (right) table.
Names after_join_names;
auto left_names = left->getCurrentHeader().getNames();
after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end());
auto right_name = table_join->columnsFromJoinedTable().getNames();
after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end());
// Copies of the headers taken before any join-filter step is added to the plans.
auto left_header = left->getCurrentHeader();
auto right_header = right->getCurrentHeader();
QueryPlanPtr query_plan;
/// some examples to explain when the post_join_filter is not empty
/// - on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1, 't1.v1> 1' is in the post filter. but 't2.v1 > 1'
/// will be pushed down into right table by spark and is not in the post filter. 't1.key = t2.key ' is
/// in JoinRel::expression.
/// - on t1.key = t2. key and t1.v1 > t2.v2, 't1.v1 > t2.v2' is in the post filter.
collectJoinKeys(*table_join, join, left_header, right_header);
if (storage_join)
{
// Broadcast join: the right side is already built inside `storage_join`.
if (join_opt_info.is_null_aware_anti_join && join.type() == substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_ANTI)
{
if (storage_join->has_null_key_value)
{
// if there is a null key value on the build side, it will return the empty result
auto empty_step = std::make_unique<EarlyStopStep>(left->getCurrentHeader());
left->addStep(std::move(empty_step));
}
else if (!storage_join->is_empty_hash_table)
{
auto input_header = left->getCurrentHeader();
DB::ActionsDAG filter_is_not_null_dag{input_header.getColumnsWithTypeAndName()};
// when is_null_aware_anti_join is true, there is only one join key
auto field_index = SubstraitParserUtils::getStructFieldIndex(join.expression().scalar_function().arguments(0).value());
if (!field_index)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The join key is not found in the expression.");
const auto * key_field = filter_is_not_null_dag.getInputs()[*field_index];
auto result_node = filter_is_not_null_dag.tryFindInOutputs(key_field->result_name);
// add a function isNotNull to filter the null key on the left side
const auto * cond_node = buildFunctionNode(filter_is_not_null_dag, "isNotNull", {result_node});
filter_is_not_null_dag.addOrReplaceInOutputs(*cond_node);
auto filter_step = std::make_unique<FilterStep>(
left->getCurrentHeader(), std::move(filter_is_not_null_dag), cond_node->result_name, true);
left->addStep(std::move(filter_step));
}
// other case: is_empty_hash_table, don't need to handle
}
// Mixed (two-sided) conditions are allowed here and are attached to the table join.
applyJoinFilter(*table_join, join, *left, *right, true);
auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context);
// NOTE(review): 8192 is hard-coded; presumably the max block size of the step — confirm.
QueryPlanStepPtr join_step = std::make_unique<FilledJoinStep>(left->getCurrentHeader(), broadcast_hash_join, 8192);
join_step->setStepDescription("STORAGE_JOIN");
steps.emplace_back(join_step.get());
left->addStep(std::move(join_step));
query_plan = std::move(left);
/// hold right plan for profile
extra_plan_holder.emplace_back(std::move(right));
}
else if (join_opt_info.is_smj)
{
bool need_post_filter = !applyJoinFilter(*table_join, join, *left, *right, false);
/// If applyJoinFilter returns false, it means there are mixed conditions in the post_join_filter.
/// It should be a inner join.
/// TODO: make smj support mixed conditions
if (need_post_filter && table_join->kind() != DB::JoinKind::Inner)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge join doesn't support mixed join conditions, except inner join.");
JoinPtr smj_join = std::make_shared<FullSortingMergeJoin>(table_join, right->getCurrentHeader().cloneEmpty(), -1);
// NOTE(review): `join_algorithm` is read but not used in this branch.
MultiEnum<DB::JoinAlgorithm> join_algorithm = context->getSettingsRef()[Setting::join_algorithm];
QueryPlanStepPtr join_step = std::make_unique<DB::JoinStep>(
left->getCurrentHeader(),
right->getCurrentHeader(),
smj_join,
context->getSettingsRef()[Setting::max_block_size],
context->getSettingsRef()[Setting::min_joined_block_size_bytes],
1,
/* required_output_ = */ NameSet{},
false,
/* use_new_analyzer_ = */ false);
join_step->setStepDescription("SORT_MERGE_JOIN");
steps.emplace_back(join_step.get());
std::vector<QueryPlanPtr> plans;
plans.emplace_back(std::move(left));
plans.emplace_back(std::move(right));
query_plan = std::make_unique<QueryPlan>();
query_plan->unitePlans(std::move(join_step), {std::move(plans)});
// The mixed condition could not be handled inside the join, so it is applied as a filter on the join output.
if (need_post_filter)
addPostFilter(*query_plan, join);
}
else
{
// Hash join: either several OR-ed ON clauses or a single clause with an optional join filter.
std::vector<DB::TableJoin::JoinOnClause> join_on_clauses;
// Make sure there is a clause to use as the prefix even when no equi-join expression was given.
if (table_join->getClauses().empty())
table_join->addDisjunct();
bool is_multi_join_on_clauses
= couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header);
// The multi-clause form is used only when it is preferred by config and the build side,
// measured as rows per partition, is below the configured limit.
if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0
&& join_opt_info.partitions_num > 0
&& join_opt_info.right_table_rows / join_opt_info.partitions_num < join_config.multi_join_on_clauses_build_side_rows_limit)
{
query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses);
}
else
{
query_plan = buildSingleOnClauseHashJoin(join, table_join, std::move(left), std::move(right));
}
}
JoinUtil::reorderJoinOutput(*query_plan, after_join_names);
/// Need to project the right table column into boolean type
if (join_opt_info.is_existence_join)
existenceJoinPostProject(*query_plan, left_names);
return query_plan;
}
/// ExistenceJoin is implemented with a left any join.
/// Its result is the left table columns plus one flag column telling whether the left row was matched.
/// The incoming plan header is "left columns + right columns"; the flag is built here by applying
/// `isNotNull` to the last input column and storing the result under that column's name, so a null
/// right value yields 0 and anything else yields 1. Only `left_input_cols` and the flag are kept.
void JoinRelParser::existenceJoinPostProject(DB::QueryPlan & plan, const DB::Names & left_input_cols)
{
    DB::ActionsDAG project_dag{plan.getCurrentHeader().getColumnsWithTypeAndName()};

    /// Replace the last input column by `isNotNull(last column)` under the same name.
    const auto * last_input = project_dag.getInputs().back();
    auto is_not_null = DB::FunctionFactory::instance().get("isNotNull", getContext());
    const auto * flag_node = &project_dag.addFunction(is_not_null, {last_input}, last_input->result_name);
    project_dag.addOrReplaceInOutputs(*flag_node);

    /// Drop everything except the left columns and the flag.
    DB::Names kept_columns(left_input_cols.begin(), left_input_cols.end());
    kept_columns.push_back(flag_node->result_name);
    project_dag.removeUnusedActions(kept_columns);

    auto flag_project_step = std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(project_dag));
    flag_project_step->setStepDescription("ExistenceJoin Post Project");
    steps.emplace_back(flag_project_step.get());
    plan.addStep(std::move(flag_project_step));
}
/// Prepares both join inputs for the join step.
///  - Registers the right header as the joined table's columns on `table_join`; right columns whose names
///    collide with left columns are renamed with a unique "right" prefix.
///  - Adds a "Right Table Rename" step to the right plan when any right column was renamed.
///  - Registers every joined column on `table_join`.
///  - Adds "Convert joined columns" steps to either plan when `createConvertingActions` returns actions
///    for that side.
void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right)
{
    /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns.
    const auto left_names = left.getCurrentHeader().getNames();
    NameSet left_columns_set(left_names.begin(), left_names.end());
    table_join.setColumnsFromJoinedTable(
        right.getCurrentHeader().getNamesAndTypesList(),
        left_columns_set,
        getUniqueName("right") + ".",
        left.getCurrentHeader().getNamesAndTypesList());

    // fix right table key duplicate
    // The de-duplicated name list is built once here instead of on every loop iteration.
    const auto & right_header = right.getCurrentHeader();
    const Names dedup_names = table_join.columnsFromJoinedTable().getNames();
    NamesWithAliases right_table_alias;
    for (size_t idx = 0; idx < dedup_names.size(); ++idx)
    {
        const auto & origin_name = right_header.getByPosition(idx).name;
        const auto & dedup_name = dedup_names[idx];
        if (origin_name != dedup_name)
            right_table_alias.emplace_back(origin_name, dedup_name);
    }

    if (!right_table_alias.empty())
    {
        ActionsDAG rename_dag{right_header.getNamesAndTypesList()};
        for (const auto & column_alias : right_table_alias)
        {
            if (!right_header.has(column_alias.first))
                continue;
            /// Replace the output at the column's position with an alias node carrying the new name.
            auto pos = right_header.getPositionByName(column_alias.first);
            const auto & alias = rename_dag.addAlias(*rename_dag.getInputs()[pos], column_alias.second);
            rename_dag.getOutputs()[pos] = &alias;
        }
        QueryPlanStepPtr project_step = std::make_unique<ExpressionStep>(right_header, std::move(rename_dag));
        project_step->setStepDescription("Right Table Rename");
        steps.emplace_back(project_step.get());
        /// `right_header` is not used after this point: adding the step changes the plan's current header.
        right.addStep(std::move(project_step));
    }

    for (const auto & column : table_join.columnsFromJoinedTable())
        table_join.addJoinedColumn(column);

    std::optional<ActionsDAG> left_convert_actions;
    std::optional<ActionsDAG> right_convert_actions;
    std::tie(left_convert_actions, right_convert_actions) = table_join.createConvertingActions(
        left.getCurrentHeader().getColumnsWithTypeAndName(), right.getCurrentHeader().getColumnsWithTypeAndName());

    if (right_convert_actions)
    {
        auto converting_step = std::make_unique<ExpressionStep>(right.getCurrentHeader(), std::move(*right_convert_actions));
        converting_step->setStepDescription("Convert joined columns");
        steps.emplace_back(converting_step.get());
        right.addStep(std::move(converting_step));
    }
    if (left_convert_actions)
    {
        auto converting_step = std::make_unique<ExpressionStep>(left.getCurrentHeader(), std::move(*left_convert_actions));
        converting_step->setStepDescription("Convert joined columns");
        steps.emplace_back(converting_step.get());
        left.addStep(std::move(converting_step));
    }
}
/// Join keys are collected from substrait::JoinRel::expression() which only contains the equal join conditions.
///
/// The expression must be a tree of `and` functions whose leaves are `equals` functions over two column
/// references, one resolving to the left header and one to the right header. A field index below
/// `left_header.columns()` addresses the left header; any other index addresses the right header after
/// subtracting `left_header.columns()`. Each `equals` adds one key pair to a single new join clause.
/// Does nothing when the JoinRel has no expression.
///
/// Throws LOGICAL_ERROR when a node is not a function, its name cannot be resolved, it is neither
/// `equals` nor `and`, an `equals` argument is not a column reference, or an `equals` does not
/// reference both sides.
void JoinRelParser::collectJoinKeys(
    TableJoin & table_join, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header)
{
    if (!join_rel.has_expression())
        return;

    /// Support only one join clause.
    table_join.addDisjunct();
    auto & join_clause = table_join.getClauses().back();
    const size_t left_columns = left_header.columns();

    std::list<const substrait::Expression *> expressions_stack;
    expressions_stack.push_back(&join_rel.expression());
    while (!expressions_stack.empty())
    {
        /// Must handle the expressions in DF order. It matters in sort merge join.
        const auto * current_expr = expressions_stack.back();
        expressions_stack.pop_back();

        if (!current_expr->has_scalar_function())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function expression is expected");
        auto function_name = parseFunctionName(current_expr->scalar_function());
        if (!function_name)
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid function expression");

        if (*function_name == "equals")
        {
            String left_key;
            String right_key;
            for (const auto & arg : current_expr->scalar_function().arguments())
            {
                auto field_index = SubstraitParserUtils::getStructFieldIndex(arg.value());
                if (!field_index)
                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected");

                const auto col_pos_ref = *field_index;
                if (col_pos_ref < left_columns)
                    left_key = left_header.getByPosition(col_pos_ref).name;
                else
                    right_key = right_header.getByPosition(col_pos_ref - left_columns).name;
            }
            /// Both sides must have been seen, otherwise this is not a left-to-right key equality.
            if (left_key.empty() || right_key.empty())
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid key equal join condition");
            join_clause.addKey(left_key, right_key, false);
        }
        else if (*function_name == "and")
        {
            /// Push the second argument first so that the first argument is popped and handled first.
            expressions_stack.push_back(&current_expr->scalar_function().arguments().at(1).value());
            expressions_stack.push_back(&current_expr->scalar_function().arguments().at(0).value());
        }
        else
        {
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow function: {}", *function_name);
        }
    }
}
/// Applies `JoinRel::post_join_filter` to the join, depending on which sides the filter references.
///  - Left side only: an expression step computing the filter column is added to the left plan and its name
///    is stored in `analyzer_left_filter_condition_column_name` of the last join clause.
///  - Right side only: same, on the right plan, stored in `analyzer_right_filter_condition_column_name`.
///  - Both sides: when `allow_mixed_condition` is true the filter becomes the table join's mixed join
///    expression; otherwise nothing is changed and false is returned so the caller can handle it.
///
/// Returns true when there is no post join filter or the filter was applied, false only for a mixed
/// condition with `allow_mixed_condition == false`.
/// Throws LOGICAL_ERROR when a right column name equals a left column name, or when the filter
/// references no table column at all.
bool JoinRelParser::applyJoinFilter(
    DB::TableJoin & table_join,
    const substrait::JoinRel & join_rel,
    DB::QueryPlan & left,
    DB::QueryPlan & right,
    bool allow_mixed_condition)
{
    if (!join_rel.has_post_join_filter())
        return true;

    const auto & expr = join_rel.post_join_filter();
    const auto & left_header = left.getCurrentHeader();
    const auto & right_header = right.getCurrentHeader();

    /// Build the combined header (left columns followed by right columns) that the filter's
    /// field references are resolved against.
    ColumnsWithTypeAndName mixed_columns;
    std::unordered_set<String> added_column_name;
    for (const auto & col : left_header.getColumnsWithTypeAndName())
    {
        mixed_columns.emplace_back(col);
        added_column_name.insert(col.name);
    }
    for (const auto & col : right_header.getColumnsWithTypeAndName())
    {
        if (added_column_name.find(col.name) != added_column_name.end())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's name conflict with left column: {}", col.name);
        mixed_columns.emplace_back(col);
        added_column_name.insert(col.name);
    }
    DB::Block mixed_header(mixed_columns);

    auto table_sides = extractTableSidesFromExpression(expr, left_header, right_header);

    /// One pass-through field reference per header column, so the resulting DAG keeps all input columns.
    auto get_input_expressions = [](const DB::Block & header)
    {
        std::vector<substrait::Expression> exprs;
        for (size_t i = 0; i < header.columns(); ++i)
            exprs.emplace_back(SubstraitParserUtils::buildStructFieldExpression(i));
        return exprs;
    };

    /// If the columns in the expression are all from one table, use analyzer_left_filter_condition_column_name
    /// or analyzer_right_filter_condition_column_name to filter the join result data. It requires building the
    /// filter column first.
    /// If the columns in the expression are from both tables, use mixed_join_expression to filter the join
    /// result data; the filter columns will be built inside the join step.
    if (table_sides.size() == 1)
    {
        auto table_side = *table_sides.begin();
        if (table_side == DB::JoinTableSide::Left)
        {
            auto input_exprs = get_input_expressions(left_header);
            input_exprs.push_back(expr);
            auto actions_dag = expressionsToActionsDAG(input_exprs, left_header);
            /// The filter expression was appended last, so its result is the last output.
            table_join.getClauses().back().analyzer_left_filter_condition_column_name = actions_dag.getOutputs().back()->result_name;
            QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(left.getCurrentHeader(), std::move(actions_dag));
            before_join_step->setStepDescription("Before JOIN LEFT");
            steps.emplace_back(before_join_step.get());
            left.addStep(std::move(before_join_step));
        }
        else
        {
            /// since the field reference in expr is the index of left_header ++ right_header, so we use
            /// mixed_header to build the actions_dag
            auto input_exprs = get_input_expressions(mixed_header);
            input_exprs.push_back(expr);
            auto actions_dag = expressionsToActionsDAG(input_exprs, mixed_header);
            /// clear unused columns in actions_dag
            for (const auto & col : left_header.getColumnsWithTypeAndName())
                actions_dag.removeUnusedResult(col.name);
            actions_dag.removeUnusedActions();
            table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag.getOutputs().back()->result_name;
            QueryPlanStepPtr before_join_step = std::make_unique<ExpressionStep>(right.getCurrentHeader(), std::move(actions_dag));
            before_join_step->setStepDescription("Before JOIN RIGHT");
            steps.emplace_back(before_join_step.get());
            right.addStep(std::move(before_join_step));
        }
    }
    else if (table_sides.size() == 2)
    {
        if (!allow_mixed_condition)
            return false;
        auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header);
        mixed_join_expressions_actions.removeUnusedActions();
        table_join.getMixedJoinExpression()
            = std::make_shared<DB::ExpressionActions>(std::move(mixed_join_expressions_actions), ExpressionActionsSettings(context));
    }
    else
    {
        throw DB::Exception(
            DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString());
    }
    return true;
}
/// Appends a "Post Join Filter" step that evaluates `JoinRel::post_join_filter` on the join output.
/// A scalar-function filter is parsed with `parseFunction`; anything else (for example a
/// singular_or_list) goes through the generic `parseExpression`.
void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::JoinRel & join)
{
    const auto & post_filter = join.post_join_filter();
    ActionsDAG filter_dag{query_plan.getCurrentHeader().getColumnsWithTypeAndName()};

    std::string filter_name;
    if (post_filter.has_scalar_function())
        filter_name = expression_parser->parseFunction(post_filter.scalar_function(), filter_dag, true)->result_name;
    else
        filter_name = expression_parser->parseExpression(filter_dag, post_filter)->result_name;

    auto post_filter_step = std::make_unique<FilterStep>(query_plan.getCurrentHeader(), std::move(filter_dag), filter_name, true);
    post_filter_step->setStepDescription("Post Join Filter");
    steps.emplace_back(post_filter_step.get());
    query_plan.addStep(std::move(post_filter_step));
}
/// Checks whether the post join filter can be expressed as several OR-ed join ON clauses.
/// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4)
///
/// The filter is flattened over `or`; each disjunct is flattened over `and`; every resulting leaf must be
/// an `equals` between one left-header column and one right-header column. For each disjunct a copy of
/// `prefix_clause` extended with the disjunct's key pairs is appended to `clauses`.
/// Returns false when there is no post join filter or a leaf does not match the pattern; in the latter
/// case `clauses` may already contain the clauses appended so far.
bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
    const DB::TableJoin::JoinOnClause & prefix_clause,
    std::vector<DB::TableJoin::JoinOnClause> & clauses,
    const substrait::JoinRel & join_rel,
    const DB::Block & left_header,
    const DB::Block & right_header)
{
    if (!join_rel.has_post_join_filter())
        return false;

    using ExprList = std::vector<const substrait::Expression *>;

    /// True when `e` is a scalar function whose resolved name equals `expected_name`.
    auto is_call_of = [&](const String & expected_name, const substrait::Expression & e) -> bool
    {
        if (!e.has_scalar_function())
            return false;
        auto actual_name = parseFunctionName(e.scalar_function());
        return actual_name.has_value() && *actual_name == expected_name;
    };

    /// Depth-first, left-to-right flattening of nested binary `connective` calls into their leaves.
    std::function<void(const String &, ExprList &, const substrait::Expression &)> flatten
        = [&](const String & connective, ExprList & leaves, const substrait::Expression & e) -> void
    {
        if (is_call_of(connective, e))
        {
            const auto & operands = e.scalar_function().arguments();
            flatten(connective, leaves, operands[0].value());
            flatten(connective, leaves, operands[1].value());
        }
        else
        {
            leaves.push_back(&e);
        }
    };

    /// For `equals(x, y)` over one left and one right column, returns (left name, right name);
    /// empty for anything else.
    auto equal_keys_of = [&](const substrait::Expression & e) -> std::optional<std::pair<String, String>>
    {
        if (!is_call_of("equals", e))
            return std::nullopt;

        const auto & operands = e.scalar_function().arguments();
        auto first_ref = SubstraitParserUtils::getStructFieldIndex(operands[0].value());
        auto second_ref = SubstraitParserUtils::getStructFieldIndex(operands[1].value());
        if (!first_ref.has_value() || !second_ref.has_value())
            return std::nullopt;

        const size_t first_pos = *first_ref;
        const size_t second_pos = *second_ref;
        const size_t left_count = left_header.columns();
        const size_t total_count = left_count + right_header.columns();
        auto on_left = [&](size_t pos) { return pos < left_count; };
        auto on_right = [&](size_t pos) { return pos >= left_count && pos < total_count; };

        if (on_left(first_pos) && on_right(second_pos))
            return std::make_pair(left_header.getByPosition(first_pos).name, right_header.getByPosition(second_pos - left_count).name);
        if (on_left(second_pos) && on_right(first_pos))
            return std::make_pair(left_header.getByPosition(second_pos).name, right_header.getByPosition(first_pos - left_count).name);
        return std::nullopt;
    };

    ExprList disjuncts;
    flatten("or", disjuncts, join_rel.post_join_filter());
    if (disjuncts.empty())
        return false;

    for (const auto * disjunct : disjuncts)
    {
        /// Every clause starts from the prefix clause and gets the disjunct's key pairs on top.
        clauses.push_back(prefix_clause);
        auto & clause = clauses.back();

        ExprList conjuncts;
        flatten("and", conjuncts, *disjunct);
        for (const auto * conjunct : conjuncts)
        {
            auto key_pair = equal_keys_of(*conjunct);
            if (!key_pair)
                return false;
            clause.addKey(key_pair->first, key_pair->second, false);
        }
    }
    return true;
}
/// Builds a hash join whose ON condition is the disjunction of `join_on_clauses`.
/// The table join's only clause is overwritten with the first clause and one extra disjunct is added for
/// each remaining clause; the left and right plans are then united under a JoinStep.
/// `join_on_clauses` is indexed at 0 without a check, so it must not be empty.
DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin(
    std::shared_ptr<DB::TableJoin> table_join,
    DB::QueryPlanPtr left_plan,
    DB::QueryPlanPtr right_plan,
    const std::vector<DB::TableJoin::JoinOnClause> & join_on_clauses)
{
    table_join->getOnlyClause() = join_on_clauses[0];
    for (size_t clause_idx = 1; clause_idx < join_on_clauses.size(); ++clause_idx)
    {
        table_join->addDisjunct();
        table_join->getClauses().back() = join_on_clauses[clause_idx];
    }
    LOG_INFO(getLogger("JoinRelParser"), "multi join on clauses:\n{}", DB::TableJoin::formatClauses(table_join->getClauses()));

    const auto & settings = context->getSettingsRef();
    JoinPtr hash_join = std::make_shared<HashJoin>(table_join, right_plan->getCurrentHeader());
    QueryPlanStepPtr multi_clause_step = std::make_unique<DB::JoinStep>(
        left_plan->getCurrentHeader(),
        right_plan->getCurrentHeader(),
        hash_join,
        settings[Setting::max_block_size],
        settings[Setting::min_joined_block_size_bytes],
        1,
        /* required_output_ = */ NameSet{},
        false,
        /* use_new_analyzer_ = */ false);
    multi_clause_step->setStepDescription("Multi join on clause hash join");
    steps.emplace_back(multi_clause_step.get());

    std::vector<QueryPlanPtr> plans;
    plans.emplace_back(std::move(left_plan));
    plans.emplace_back(std::move(right_plan));

    auto united_plan = std::make_unique<QueryPlan>();
    united_plan->unitePlans(std::move(multi_clause_step), {std::move(plans)});
    return united_plan;
}
/// Builds a hash join with a single ON clause. The post join filter is applied first (mixed conditions
/// allowed), then a GraceHashJoin or a plain HashJoin is created depending on the `join_algorithm`
/// setting, and both plans are united under a JoinStep.
DB::QueryPlanPtr JoinRelParser::buildSingleOnClauseHashJoin(
    const substrait::JoinRel & join_rel, std::shared_ptr<DB::TableJoin> table_join, DB::QueryPlanPtr left_plan, DB::QueryPlanPtr right_plan)
{
    applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, true);

    /// Following is some configurations for grace hash join.
    /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will
    ///   enable grace hash join.
    /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This sets up
    ///   the memory limitation for grace hash join. If the memory consumption exceeds the limitation,
    ///   data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number
    ///   will be too large and the performance will be bad.
    const auto & settings = context->getSettingsRef();
    MultiEnum<DB::JoinAlgorithm> join_algorithm = settings[Setting::join_algorithm];
    const bool use_grace_hash = join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH);

    JoinPtr hash_join = nullptr;
    if (!use_grace_hash)
    {
        hash_join = std::make_shared<HashJoin>(table_join, right_plan->getCurrentHeader().cloneEmpty());
    }
    else
    {
        hash_join = std::make_shared<GraceHashJoin>(
            settings[Setting::grace_hash_join_initial_buckets],
            settings[Setting::grace_hash_join_max_buckets],
            table_join,
            left_plan->getCurrentHeader(),
            right_plan->getCurrentHeader(),
            context->getTempDataOnDisk());
    }

    QueryPlanStepPtr hash_join_step = std::make_unique<DB::JoinStep>(
        left_plan->getCurrentHeader(),
        right_plan->getCurrentHeader(),
        hash_join,
        settings[Setting::max_block_size],
        settings[Setting::min_joined_block_size_bytes],
        1,
        /* required_output_ = */ NameSet{},
        false,
        /* use_new_analyzer_ = */ false);
    hash_join_step->setStepDescription("HASH_JOIN");
    steps.emplace_back(hash_join_step.get());

    std::vector<QueryPlanPtr> plans;
    plans.emplace_back(std::move(left_plan));
    plans.emplace_back(std::move(right_plan));

    auto united_plan = std::make_unique<QueryPlan>();
    united_plan->unitePlans(std::move(hash_join_step), {std::move(plans)});
    return united_plan;
}
/// Registers a builder for substrait `kJoin` relations that creates a JoinRelParser
/// from the given parser context.
void registerJoinRelParser(RelParserFactory & factory)
{
    auto make_join_parser = [](ParserContextPtr parser_context) -> std::shared_ptr<JoinRelParser>
    {
        return std::make_shared<JoinRelParser>(parser_context);
    };
    factory.registerBuilder(substrait::Rel::RelTypeCase::kJoin, make_join_parser);
}
}