cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp (302 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <google/protobuf/wrappers.pb.h>
#include <Parser/FunctionParser.h>
#include <Parser/TypeParser.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Common/CHUtil.h>
#include <Common/MergeTreeTool.h>
#include "MergeTreeRelParser.h"
namespace DB
{
namespace ErrorCodes
{
/// Error codes available to this translation unit. They are only declared here
/// (`extern`); their definitions live elsewhere.
extern const int NO_SUCH_DATA_PART;
extern const int LOGICAL_ERROR;
extern const int UNKNOWN_FUNCTION;
extern const int UNKNOWN_TYPE;
}
}
namespace local_engine
{
using namespace DB;
/// Returns the smallest primary-key index among the columns used by a condition.
///
/// @param condition_table_columns Names of the columns referenced by the condition.
/// @param primary_key_positions   Map from column name to its index in the primary key.
/// @return The minimal index found, or `max(Int64) - 1` when none of the columns
///         is present in `primary_key_positions`.
static Int64 findMinPosition(const NameSet & condition_table_columns, const NameToIndexMap & primary_key_positions)
{
    /// Sentinel returned when no column of the condition is part of the primary key.
    Int64 best = std::numeric_limits<Int64>::max() - 1;

    for (const auto & name : condition_table_columns)
    {
        const auto found = primary_key_positions.find(name);
        if (found == primary_key_positions.end())
            continue;

        const auto position = static_cast<Int64>(found->second);
        if (position < best)
            best = position;
    }

    return best;
}
/// Obtains the CustomStorageMergeTree described by an extension table.
///
/// The table description is read from `extension_table.detail()`, which is expected to
/// hold a serialized google.protobuf.StringValue. The column list comes from the base
/// schema of the ReadRel inside `rel_`.
///
/// @param rel_            Substrait relation; `rel_.read()` must have a base schema.
/// @param extension_table Extension table carrying the textual merge tree description.
/// @param context         Context passed to the metadata builder and to the storage constructor.
/// @return The storage returned by StorageMergeTreeFactory::getStorage for (database, table).
/// @throws DB::Exception with LOGICAL_ERROR if the extension table detail cannot be parsed.
CustomStorageMergeTreePtr MergeTreeRelParser::parseStorage(
    const substrait::Rel & rel_,
    const substrait::ReadRel::ExtensionTable & extension_table,
    ContextMutablePtr context)
{
    const auto & rel = rel_.read();

    google::protobuf::StringValue table;
    /// Do not continue with a partially parsed description: fail loudly instead.
    if (!table.ParseFromString(extension_table.detail().value()))
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Failed to parse the merge tree extension table detail.");
    auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value());

    chassert(rel.has_base_schema());
    DB::Block header = TypeParser::buildBlockFromNamedStruct(rel.base_schema());
    auto names_and_types_list = header.getNamesAndTypesList();

    auto storage_factory = StorageMergeTreeFactory::instance();
    auto metadata = buildMetaData(names_and_types_list, context, merge_tree_table);

    /// The creator callback builds a fresh storage and loads its data parts.
    auto storage = storage_factory.getStorage(
        StorageID(merge_tree_table.database, merge_tree_table.table),
        metadata->getColumns(),
        [&]() -> CustomStorageMergeTreePtr
        {
            auto custom_storage_merge_tree = std::make_shared<CustomStorageMergeTree>(
                StorageID(merge_tree_table.database, merge_tree_table.table),
                merge_tree_table.relative_path,
                *metadata,
                false,
                context,
                "",
                MergeTreeData::MergingParams(),
                buildMergeTreeSettings());
            custom_storage_merge_tree->loadDataParts(false, std::nullopt);
            return custom_storage_merge_tree;
        });
    return storage;
}
/// Appends a MergeTree read step (and, if needed, a remove-nullable step) to `query_plan`.
///
/// Steps performed:
///  1. Parse the merge tree table description from `extension_table.detail()`.
///  2. Build the table header from the description's schema, and the header to read from
///     `rel.base_schema()`; when the base schema has no names, the first table column is read.
///  3. Build the metadata, obtain the storage from StorageMergeTreeFactory and remember them
///     in `query_context`; record compressed column sizes in `column_sizes`.
///  4. If the relation has a filter, resolve the non-nullable columns and build prewhere info.
///  5. Select the requested data parts, create the read step and attach the mark ranges.
///
/// @param query_plan      Plan that receives the new steps; returned to the caller.
/// @param rel             Substrait read relation (base schema and optional filter).
/// @param extension_table Extension table carrying the textual merge tree description.
/// @return `query_plan` with the read step added.
/// @throws DB::Exception with LOGICAL_ERROR if the extension table detail cannot be parsed,
///         and with NO_SUCH_DATA_PART if none of the requested parts is found.
DB::QueryPlanPtr
MergeTreeRelParser::parseReadRel(
    DB::QueryPlanPtr query_plan,
    const substrait::ReadRel & rel,
    const substrait::ReadRel::ExtensionTable & extension_table,
    std::list<const substrait::Rel *> & /*rel_stack_*/)
{
    google::protobuf::StringValue table;
    /// Do not continue with a partially parsed description: fail loudly instead.
    if (!table.ParseFromString(extension_table.detail().value()))
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Failed to parse the merge tree extension table detail.");
    auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value());

    DB::Block header = TypeParser::buildBlockFromNamedStruct(merge_tree_table.schema);
    DB::Block input;
    if (rel.has_base_schema() && rel.base_schema().names_size())
    {
        input = TypeParser::buildBlockFromNamedStruct(rel.base_schema());
    }
    else
    {
        /// No column was requested: read the first table column instead of an empty header.
        NamesAndTypesList one_column_name_type;
        one_column_name_type.push_back(header.getNamesAndTypesList().front());
        input = BlockUtil::buildHeader(one_column_name_type);
        LOG_DEBUG(&Poco::Logger::get("SerializedPlanParser"), "Try to read ({}) instead of empty header", header.dumpNames());
    }

    auto storage_factory = StorageMergeTreeFactory::instance();
    auto metadata = buildMetaData(header.getNamesAndTypesList(), context, merge_tree_table);
    query_context.metadata = metadata;
    StorageID table_id(merge_tree_table.database, merge_tree_table.table);
    auto storage = storage_factory.getStorage(
        table_id,
        metadata->getColumns(),
        [&]() -> CustomStorageMergeTreePtr
        {
            auto custom_storage_merge_tree = std::make_shared<CustomStorageMergeTree>(
                StorageID(merge_tree_table.database, merge_tree_table.table),
                merge_tree_table.relative_path,
                *metadata,
                false,
                global_context,
                "",
                MergeTreeData::MergingParams(),
                buildMergeTreeSettings());
            return custom_storage_merge_tree;
        });

    /// Compressed sizes are used later by getColumnsSize() to rank prewhere conditions.
    for (const auto & [name, sizes] : storage->getColumnSizes())
        column_sizes[name] = sizes.data_compressed;

    query_context.storage_snapshot = std::make_shared<StorageSnapshot>(*storage, metadata);
    query_context.custom_storage_merge_tree = storage;

    auto names_and_types_list = input.getNamesAndTypesList();
    auto query_info = buildQueryInfo(names_and_types_list);

    std::set<String> non_nullable_columns;
    if (rel.has_filter())
    {
        NonNullableColumnsResolver non_nullable_columns_resolver(input, *getPlanParser(), rel.filter());
        non_nullable_columns = non_nullable_columns_resolver.resolve();
        query_info->prewhere_info = parsePreWhereInfo(rel.filter(), input);
    }

    std::vector<DataPartPtr> selected_parts = storage_factory.getDataParts(table_id, merge_tree_table.getPartNames());
    /// Validate before the parts are used, so that a missing part is always reported
    /// with the dedicated error code.
    if (selected_parts.empty())
        throw Exception(ErrorCodes::NO_SUCH_DATA_PART, "no data part found.");
    auto ranges = merge_tree_table.extractRange(selected_parts);

    auto read_step = query_context.custom_storage_merge_tree->reader.readFromParts(
        selected_parts,
        /* alter_conversions = */ {},
        names_and_types_list.getNames(),
        query_context.storage_snapshot,
        *query_info,
        context,
        context->getSettingsRef().max_block_size,
        1);
    /// NOTE(review): this assumes readFromParts always produces a ReadFromMergeTree step;
    /// a checked cast would be safer than reinterpret_cast - confirm and consider replacing.
    query_context.custom_storage_merge_tree->wrapRangesInDataParts(*reinterpret_cast<ReadFromMergeTree *>(read_step.get()), ranges);
    steps.emplace_back(read_step.get());
    query_plan->addStep(std::move(read_step));

    if (!non_nullable_columns.empty())
    {
        /// Only columns that are actually produced by the read step can be made non-nullable.
        /// The header reference is used only before the plan is modified below.
        const auto & input_header = query_plan->getCurrentDataStream().header;
        std::erase_if(non_nullable_columns, [&input_header](const auto & item) -> bool { return !input_header.has(item); });
        auto * remove_null_step = getPlanParser()->addRemoveNullableStep(*query_plan, non_nullable_columns);
        if (remove_null_step)
            steps.emplace_back(remove_null_step);
    }
    return query_plan;
}
/// Builds the PrewhereInfo for a filter expression.
///
/// @param rel   Substrait filter expression.
/// @param input Header of the columns being read; each of its columns is passed to
///              tryRestoreColumn() on the prewhere actions.
/// @return Prewhere info whose actions come from optimizePrewhereAction(), with
///         `need_filter` and `remove_prewhere_column` both set to true.
PrewhereInfoPtr MergeTreeRelParser::parsePreWhereInfo(const substrait::Expression & rel, Block & input)
{
    auto info = std::make_shared<PrewhereInfo>();

    std::string prewhere_column;
    info->prewhere_actions = optimizePrewhereAction(rel, prewhere_column, input);
    info->prewhere_column_name = prewhere_column;
    info->need_filter = true;
    info->remove_prewhere_column = true;

    auto & actions = info->prewhere_actions;
    actions->projectInput(false);
    for (const auto & column_name : input.getNames())
        actions->tryRestoreColumn(column_name);

    return info;
}
/// Builds the ActionsDAG that evaluates a prewhere filter.
///
/// The filter is split into conjuncts by analyzeExpressions(), the conjuncts are sorted,
/// and they are then either used directly (single conjunct) or recombined with `and`.
///
/// @param rel         Substrait filter expression.
/// @param filter_name Output: result name of the final filter node.
/// @param block       Header describing the input columns of the DAG.
/// @return DAG reduced to the actions required to compute `filter_name`.
DB::ActionsDAGPtr MergeTreeRelParser::optimizePrewhereAction(const substrait::Expression & rel, std::string & filter_name, Block & block)
{
    Conditions res;
    std::set<Int64> pk_positions;
    analyzeExpressions(res, rel, pk_positions, block);

    /// Find the end of the contiguous run 0, 1, 2, ... of primary-key positions
    /// (std::set iterates in ascending order).
    Int64 min_valid_pk_pos = -1;
    for (auto pk_pos : pk_positions)
    {
        if (pk_pos != min_valid_pk_pos + 1)
            break;
        min_valid_pk_pos = pk_pos;
    }

    // TODO need to test
    /// Conditions positioned beyond that run are reset to the "not in primary key" sentinel.
    for (auto & cond : res)
        if (cond.min_position_in_primary_key > min_valid_pk_pos)
            cond.min_position_in_primary_key = std::numeric_limits<Int64>::max() - 1;

    // filter less size column first
    /// The actual ordering is defined by Condition's comparison, which is declared elsewhere.
    res.sort();

    auto filter_action = std::make_shared<ActionsDAG>(block.getNamesAndTypesList());

    if (res.size() == 1)
    {
        parseToAction(filter_action, res.back().node, filter_name);
    }
    else
    {
        DB::ActionsDAG::NodeRawConstPtrs args;
        /// Iterate by reference: copying a Condition would also copy its substrait expression.
        for (const auto & cond : res)
        {
            String ignore;
            parseToAction(filter_action, cond.node, ignore);
            /// NOTE(review): assumes the node appended last is the result of this condition;
            /// confirm for expressions that resolve to an already existing node.
            args.emplace_back(&filter_action->getNodes().back());
        }

        auto function_builder = FunctionFactory::instance().get("and", context);
        std::string args_name = join(args, ',');
        filter_name = "and(" + args_name + ")";
        const auto * and_function = &filter_action->addFunction(function_builder, args, filter_name);
        filter_action->addOrReplaceInOutputs(*and_function);
    }

    filter_action->removeUnusedActions(Names{filter_name}, false, true);
    return filter_action;
}
/// Adds the actions computing `rel` to `filter_action`.
///
/// Scalar functions are delegated to the plan parser's parseFunctionWithDAG(); any other
/// expression goes through parseExpression() and its node is registered as a DAG output.
///
/// @param filter_action DAG that receives the new actions.
/// @param rel           Substrait expression to translate.
/// @param filter_name   Output: for non-function expressions, the result name of the node;
///                      for scalar functions it is passed to parseFunctionWithDAG().
void MergeTreeRelParser::parseToAction(ActionsDAGPtr & filter_action, const substrait::Expression & rel, std::string & filter_name)
{
    if (!rel.has_scalar_function())
    {
        const auto * node = parseExpression(filter_action, rel);
        filter_action->addOrReplaceInOutputs(*node);
        filter_name = node->result_name;
        return;
    }

    getPlanParser()->parseFunctionWithDAG(rel, filter_name, filter_action, true);
}
/// Splits a filter expression into its `and` conjuncts and records one Condition per conjunct.
///
/// `and` functions are flattened recursively; every other expression becomes a Condition
/// annotated with the columns it references, their total size and its minimal primary-key
/// position.
///
/// @param res          Output list that receives the conditions.
/// @param rel          Expression to analyze.
/// @param pk_positions Output set that receives each condition's minimal primary-key position.
/// @param block        Header used to resolve field references to column names.
void MergeTreeRelParser::analyzeExpressions(
    Conditions & res, const substrait::Expression & rel, std::set<Int64> & pk_positions, Block & block)
{
    if (rel.has_scalar_function() && getCHFunctionName(rel.scalar_function()) == "and")
    {
        /// Iterate by reference: copying an argument would deep-copy the protobuf message.
        for (const auto & argument : rel.scalar_function().arguments())
            analyzeExpressions(res, argument.value(), pk_positions, block);
        return;
    }

    Condition cond(rel);
    collectColumns(rel, cond.table_columns, block);
    cond.columns_size = getColumnsSize(cond.table_columns);

    // TODO: get primary_key_names
    /// The primary-key map is empty for now, so findMinPosition() always yields its sentinel.
    const NameToIndexMap primary_key_names_positions;
    cond.min_position_in_primary_key = findMinPosition(cond.table_columns, primary_key_names_positions);
    pk_positions.emplace(cond.min_position_in_primary_key);

    res.emplace_back(std::move(cond));
}
/// Sums the recorded sizes of the given columns.
///
/// @param columns Column names to look up in `column_sizes`.
/// @return Sum of the sizes of the columns present in `column_sizes`; columns without a
///         recorded size contribute nothing.
UInt64 MergeTreeRelParser::getColumnsSize(const NameSet & columns)
{
    UInt64 size = 0;
    for (const auto & column : columns)
    {
        /// Single lookup instead of contains() followed by operator[].
        if (const auto it = column_sizes.find(column); it != column_sizes.end())
            size += it->second;
    }
    return size;
}
/// Collects the names of all columns referenced by an expression.
///
/// @param rel     Expression to inspect (recursively).
/// @param columns Output set that receives the referenced column names.
/// @param block   Header used to translate field indexes into column names; indexes outside
///                the header are ignored.
/// @throws DB::Exception with UNKNOWN_TYPE for expression kinds that are not handled.
void MergeTreeRelParser::collectColumns(const substrait::Expression & rel, NameSet & columns, Block & block)
{
    switch (rel.rex_type_case())
    {
        case substrait::Expression::RexTypeCase::kLiteral: {
            return;
        }

        case substrait::Expression::RexTypeCase::kSelection: {
            const size_t idx = rel.selection().direct_reference().struct_field().field();
            /// Read the single name directly instead of copying every column name of the block.
            if (idx < block.columns())
                columns.insert(block.getByPosition(idx).name);
            return;
        }

        case substrait::Expression::RexTypeCase::kCast: {
            collectColumns(rel.cast().input(), columns, block);
            return;
        }

        case substrait::Expression::RexTypeCase::kIfThen: {
            const auto & if_then = rel.if_then();
            for (const auto & clause : if_then.ifs())
            {
                collectColumns(clause.if_(), columns, block);
                collectColumns(clause.then(), columns, block);
            }
            /// The else branch may reference columns too; skip it only when it is absent,
            /// because an unset expression would fall into the unsupported-type case.
            if (if_then.has_else_())
                collectColumns(if_then.else_(), columns, block);
            return;
        }

        case substrait::Expression::RexTypeCase::kScalarFunction: {
            for (const auto & arg : rel.scalar_function().arguments())
                collectColumns(arg.value(), columns, block);
            return;
        }

        case substrait::Expression::RexTypeCase::kSingularOrList: {
            const auto & options = rel.singular_or_list().options();
            /// options is empty always return false
            if (options.empty())
                return;
            collectColumns(rel.singular_or_list().value(), columns, block);
            return;
        }

        default:
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE,
                "Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}
/// Maps a substrait scalar function to the function name used by this prewhere parser.
///
/// The function reference is looked up in the plan parser's `function_mapping`; the part of
/// the signature before the first ':' is then translated through `SCALAR_FUNCTIONS`.
///
/// @param substrait_func Substrait scalar function call.
/// @return The mapped function name.
/// @throws DB::Exception with UNKNOWN_FUNCTION if the name is not in `SCALAR_FUNCTIONS`.
///         `function_mapping.at()` throws on its own if the reference is not registered.
String MergeTreeRelParser::getCHFunctionName(const substrait::Expression_ScalarFunction & substrait_func)
{
    const auto reference_key = std::to_string(substrait_func.function_reference());
    const auto signature = getPlanParser()->function_mapping.at(reference_key);

    /// Without a ':' the whole signature is taken as the name (substr up to npos).
    const auto substrait_name = signature.substr(0, signature.find(':'));

    if (const auto found = SCALAR_FUNCTIONS.find(substrait_name); found != SCALAR_FUNCTIONS.end())
        return found->second;

    throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported substrait function on mergetree prewhere parser: {}", substrait_name);
}
}