cpp-ch/local-engine/Shuffle/SelectorBuilder.cpp (372 lines of code) (raw):

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "SelectorBuilder.h"
#include <limits>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/TypeParser.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Poco/Base64Decoder.h>
#include <Poco/JSON/JSON.h>
#include <Poco/JSON/Parser.h>
#include <Poco/MemoryStream.h>
#include <Poco/StreamCopier.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
}

namespace local_engine
{
/// Groups row indices by partition id using a counting sort.
///
/// @param selector       selector[i] is the partition id of row i. Every id is used as an index
///                       into a vector of size partition_num + 1 without a bounds check, so callers
///                       must supply ids in [0, partition_num].
/// @param partition_num  number of partitions.
/// @return PartitionInfo where
///   - partition_selector holds the row indices grouped by partition; rows of one partition keep
///     their original relative order because the array is filled back-to-front;
///   - partition_start_points[p] is the offset of the first row of partition p inside
///     partition_selector (the vector has partition_num + 1 entries);
///   - partition_num is passed through unchanged.
PartitionInfo PartitionInfo::fromSelector(DB::IColumn::Selector selector, size_t partition_num)
{
    auto rows = selector.size();
    std::vector<size_t> partition_row_idx_start_points(partition_num + 1, 0);
    IColumn::Selector partition_selector(rows, 0);

    /// 1. Histogram: number of rows per partition.
    for (size_t i = 0; i < rows; ++i)
        partition_row_idx_start_points[selector[i]]++;

    /// 2. Inclusive prefix sum: entry p now holds the END offset of partition p.
    for (size_t i = 1; i <= partition_num; ++i)
        partition_row_idx_start_points[i] += partition_row_idx_start_points[i - 1];

    /// 3. Walk the rows backwards and fill every partition from its end. Each write decrements the
    ///    offset, so when the loop finishes entry p holds the START offset of partition p.
    for (size_t i = rows; i-- > 0;)
    {
        partition_selector[partition_row_idx_start_points[selector[i]] - 1] = i;
        partition_row_idx_start_points[selector[i]]--;
    }

    /// Both buffers are locals that are not used afterwards, so move them instead of copying.
    return PartitionInfo{
        .partition_selector = std::move(partition_selector),
        .partition_start_points = std::move(partition_row_idx_start_points),
        .partition_num = partition_num};
}

/// Assigns partition ids to the rows of `block` in round-robin order.
/// `pid_selection` is advanced for every row and carried over between calls, so consecutive blocks
/// continue the rotation where the previous one stopped.
PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block)
{
    DB::IColumn::Selector result;
    result.resize_fill(block.rows(), 0);
    for (auto & pid : result)
    {
        pid = pid_selection;
        pid_selection = (pid_selection + 1) % parts_num;
    }
    return PartitionInfo::fromSelector(std::move(result), parts_num);
}

/// @param parts_num_           number of output partitions.
/// @param exprs_index_         positions (in the input block) of the columns to hash.
/// @param hash_function_name_  name of the hash function looked up in DB::FunctionFactory.
HashSelectorBuilder::HashSelectorBuilder(
    UInt32 parts_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_)
    : parts_num(parts_num_), exprs_index(exprs_index_), hash_function_name(hash_function_name_)
{
}

/// Computes the partition id of every row as hash(key columns) modulo parts_num.
///
/// The key columns are picked by position, flattened via BlockUtil::flattenBlock and fed to the
/// hash function, which is resolved lazily on the first call and cached in `hash_function`.
///
/// @throws DB::Exception(LOGICAL_ERROR) if the hash result reports itself as nullable but is not a
///         ColumnNullable.
PartitionInfo HashSelectorBuilder::build(DB::Block & block)
{
    ColumnsWithTypeAndName args;
    args.reserve(exprs_index.size());
    for (const auto index : exprs_index)
        args.emplace_back(block.safeGetByPosition(index));
    auto flatten_block = BlockUtil::flattenBlock(DB::Block(args), BlockUtil::FLAT_STRUCT_FORCE | BlockUtil::FLAT_NESTED_TABLE, true);
    args = flatten_block.getColumnsWithTypeAndName();

    if (!hash_function) [[unlikely]]
    {
        auto & factory = DB::FunctionFactory::instance();
        auto function = factory.get(hash_function_name, local_engine::SerializedPlanParser::global_context);
        hash_function = function->build(args);
    }

    auto rows = block.rows();
    DB::IColumn::Selector partition_ids;
    partition_ids.reserve(rows);
    auto result_type = hash_function->getResultType();
    auto hash_column = hash_function->execute(args, result_type, rows, false);

    if (isNothing(removeNullable(result_type)))
    {
        /// TODO: implement new hash function sparkCityHash64 like sparkXxHash64 to process null literal as column more gracefully.
        /// Current implementation may cause partition skew.
        partition_ids.resize_fill(rows, 0);
    }
    else
    {
        if (hash_function_name == "sparkMurmurHash3_32")
        {
            /// sparkMurmurHash3_32 returns are all not null.
            auto parts_num_int32 = static_cast<Int32>(parts_num);
            for (size_t i = 0; i < rows; i++)
            {
                // cast to int32 to be the same as the data type of the vanilla spark
                auto hash_int32 = static_cast<Int32>(hash_column->get64(i));
                auto res = hash_int32 % parts_num_int32;
                /// C++ remainder keeps the sign of the dividend; shift negatives into [0, parts_num).
                if (res < 0)
                {
                    res += parts_num_int32;
                }
                partition_ids.emplace_back(static_cast<UInt64>(res));
            }
        }
        else
        {
            if (hash_column->isNullable())
            {
                const auto * null_col = typeid_cast<const ColumnNullable *>(hash_column.get());
                /// isNullable() may be true for a column that is not a ColumnNullable;
                /// fail loudly instead of dereferencing a null pointer.
                if (!null_col)
                    throw DB::Exception(
                        DB::ErrorCodes::LOGICAL_ERROR, "Expected a ColumnNullable hash result, got {}", hash_column->getName());
                const auto & null_map = null_col->getNullMapData();
                for (size_t i = 0; i < rows; ++i)
                {
                    /// Branch-free null handling: (null_flag - 1) is all ones for a non-null row
                    /// (flag 0) and zero for a null row (flag 1), so null rows hash to 0.
                    auto hash_value = static_cast<UInt64>(hash_column->get64(i)) & static_cast<UInt64>(static_cast<Int64>(null_map[i]) - 1);
                    partition_ids.emplace_back(static_cast<UInt64>(hash_value % parts_num));
                }
            }
            else
            {
                for (size_t i = 0; i < rows; i++)
                    partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(i) % parts_num));
            }
        }
    }
    return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);
}

/// Maps the integer "direction" code of an ordering to the pair that is passed as the 2nd and 3rd
/// arguments of DB::SortColumnDescription in initSortInformation().
/// NOTE(review): the codes 1..4 presumably follow the Substrait sort-direction enum
/// (asc/desc x nulls first/last) - confirm against the plan producer.
static const std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

/// @param option          JSON document with an "ordering" array (sort keys) and a "range_bounds"
///                        array (partition boundaries).
/// @param partition_num_  number of output partitions.
RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option, const size_t partition_num_)
{
    Poco::JSON::Parser parser;
    auto info = parser.parse(option).extract<Poco::JSON::Object::Ptr>();
    auto ordering_infos = info->get("ordering").extract<Poco::JSON::Array::Ptr>();
    /// The sort information must be initialised first: initRangeBlock() reads sort_field_types.
    initSortInformation(ordering_infos);
    initRangeBlock(info->get("range_bounds").extract<Poco::JSON::Array::Ptr>());
    partition_num = partition_num_;
}

/// Assigns every row of `block` to a range partition and groups the rows by partition.
PartitionInfo RangeSelectorBuilder::build(DB::Block & block)
{
    DB::IColumn::Selector result;
    computePartitionIdByBinarySearch(block, result);
    return PartitionInfo::fromSelector(std::move(result), partition_num);
}

/// Reads the sort keys from the "ordering" JSON array. For every key it appends
///   - a DB::SortColumnDescription to `sort_descriptions`,
///   - the parsed type plus nullability to `sort_field_types`,
///   - the key's column position ("column_ref") to `sorting_key_columns`.
///
/// @throws DB::Exception(LOGICAL_ERROR) for a direction code that is not in direction_map.
void RangeSelectorBuilder::initSortInformation(Poco::JSON::Array::Ptr orderings)
{
    for (uint32_t i = 0; i < orderings->size(); ++i)
    {
        auto ordering = orderings->get(i).extract<Poco::JSON::Object::Ptr>();
        auto col_pos = ordering->get("column_ref").convert<Int32>();
        auto col_name = ordering->get("column_name").convert<String>();

        auto sort_direction = ordering->get("direction").convert<int>();
        auto d_iter = direction_map.find(sort_direction);
        if (d_iter == direction_map.end())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported sorting direction:{}", sort_direction);
        DB::SortColumnDescription ch_col_sort_descr(col_name, d_iter->second.first, d_iter->second.second);
        sort_descriptions.emplace_back(ch_col_sort_descr);

        auto type_name = ordering->get("data_type").convert<std::string>();
        auto type = TypeParser::getCHTypeByName(type_name);
        SortFieldTypeInfo info;
        info.inner_type = type;
        info.is_nullable = ordering->get("is_nullable").convert<bool>();
        sort_field_types.emplace_back(info);
        sorting_key_columns.emplace_back(col_pos);
    }
}

/// Inserts a floating point bound into `col`.
/// If the direct numeric conversion throws, the value is re-read as a string and the
/// case-insensitive spellings "nan", "infinity" and "-infinity" are accepted.
///
/// @throws DB::Exception(LOGICAL_ERROR) for any other non-numeric string.
template <typename T>
void RangeSelectorBuilder::safeInsertFloatValue(const Poco::Dynamic::Var & field_value, DB::MutableColumnPtr & col)
{
    try
    {
        col->insert(field_value.convert<T>());
    }
    catch (const std::exception &)
    {
        String val = Poco::toLower(field_value.convert<std::string>());
        T res;
        if (val == "nan")
            res = std::numeric_limits<T>::quiet_NaN();
        else if (val == "infinity")
            res = std::numeric_limits<T>::infinity();
        else if (val == "-infinity")
            res = -std::numeric_limits<T>::infinity();
        else
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported value: {}", val);
        col->insert(res);
    }
}

/// Builds `range_bounds_block` from the "range_bounds" JSON array: one column per sort key
/// (named "sort_col_<i>") and one row per boundary. Element [r][i] of the JSON array is an object
/// with "is_null" and "value" describing the value of sort key i in boundary r.
///
/// @throws DB::Exception(LOGICAL_ERROR) if a sort key has a type that is not handled below.
void RangeSelectorBuilder::initRangeBlock(Poco::JSON::Array::Ptr range_bounds)
{
    DB::ColumnsWithTypeAndName columns;
    for (uint32_t i = 0; i < sort_field_types.size(); ++i)
    {
        auto & type_info = sort_field_types[i];
        auto inner_col = type_info.inner_type->createColumn();
        auto data_type = type_info.inner_type;
        DB::MutableColumnPtr col = std::move(inner_col);
        if (type_info.is_nullable)
        {
            col = ColumnNullable::create(std::move(col), DB::ColumnUInt8::create(0, 0));
            data_type = std::make_shared<DB::DataTypeNullable>(data_type);
        }

        /// The type name does not depend on the boundary row, so compute it once per column.
        const auto type_name = type_info.inner_type->getName();
        for (uint32_t r = 0; r < range_bounds->size(); ++r)
        {
            auto row = range_bounds->get(r).extract<Poco::JSON::Array::Ptr>();
            auto field_info = row->get(i).extract<Poco::JSON::Object::Ptr>();
            if (field_info->get("is_null").convert<bool>())
            {
                /// NOTE(review): presumably only reached for nullable sort keys, where this
                /// inserts a NULL - confirm that producers never emit is_null for non-nullable keys.
                col->insertData(nullptr, 0);
            }
            else
            {
                const auto & field_value = field_info->get("value");
                if (type_name == "UInt8")
                {
                    /// Read through a wider integer type first, then narrow to UInt8.
                    col->insert(static_cast<UInt8>(field_value.convert<Int16>()));
                }
                else if (type_name == "Int8")
                {
                    col->insert(field_value.convert<Int8>());
                }
                else if (type_name == "Int16")
                {
                    col->insert(field_value.convert<Int16>());
                }
                else if (type_name == "Int32")
                {
                    col->insert(field_value.convert<Int32>());
                }
                else if (type_name == "Int64")
                {
                    col->insert(field_value.convert<Int64>());
                }
                else if (type_name == "Float32")
                {
                    safeInsertFloatValue<Float32>(field_value, col);
                }
                else if (type_name == "Float64")
                {
                    safeInsertFloatValue<Float64>(field_value, col);
                }
                else if (type_name == "String")
                {
                    col->insert(field_value.convert<std::string>());
                }
                else if (type_name == "Date32")
                {
                    int val = field_value.convert<Int32>();
                    col->insert(val);
                }
                else if (const auto * decimal32 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal32> *>(type_info.inner_type.get()))
                {
                    auto value = decimal32->parseFromString(field_value.convert<std::string>());
                    col->insert(DB::DecimalField<DB::Decimal32>(value, decimal32->getScale()));
                }
                else if (const auto * decimal64 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal64> *>(type_info.inner_type.get()))
                {
                    auto value = decimal64->parseFromString(field_value.convert<std::string>());
                    col->insert(DB::DecimalField<DB::Decimal64>(value, decimal64->getScale()));
                }
                else if (const auto * decimal128 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal128> *>(type_info.inner_type.get()))
                {
                    auto value = decimal128->parseFromString(field_value.convert<std::string>());
                    col->insert(DB::DecimalField<DB::Decimal128>(value, decimal128->getScale()));
                }
                else
                {
                    throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported data type: {}", type_name);
                }
            }
        }
        auto col_name = "sort_col_" + std::to_string(i);
        columns.emplace_back(std::move(col), data_type, col_name);
    }
    range_bounds_block = DB::Block(columns);
}

/// Builds `projection_expression_actions` from the projection expressions of `projection_plan_pb`.
/// The whole body runs under `actions_dag_mutex` and returns early once `has_init_actions_dag`
/// is set, so the actions are built at most once.
void RangeSelectorBuilder::initActionsDAG(const DB::Block & block)
{
    std::lock_guard lock(actions_dag_mutex);
    if (has_init_actions_dag)
        return;
    SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context);
    plan_parser.parseExtensions(projection_plan_pb->extensions());

    const auto & expressions = projection_plan_pb->relations().at(0).root().input().project().expressions();
    std::vector<substrait::Expression> exprs;
    exprs.reserve(expressions.size());
    for (const auto & expression : expressions)
        exprs.emplace_back(expression);

    auto projection_actions_dag = plan_parser.expressionsToActionsDAG(exprs, block, block);
    projection_expression_actions = std::make_unique<DB::ExpressionActions>(projection_actions_dag);
    has_init_actions_dag = true;
}

/// Fills `selector` with one partition id per row of `block`.
///
/// A row goes to the partition whose index is returned by binarySearchBound(); rows for which no
/// boundary qualifies go to the partition after the last boundary (index == number of boundaries).
/// `selector` is cleared first.
void RangeSelectorBuilder::computePartitionIdByBinarySearch(DB::Block & block, DB::IColumn::Selector & selector)
{
    selector.clear();
    selector.reserve(block.rows());
    auto input_columns = block.getColumns();
    auto total_rows = block.rows();
    const auto & bounds_columns = range_bounds_block.getColumns();
    /// NOTE(review): assumes at least one sort key, i.e. bounds_columns is not empty.
    auto max_part = bounds_columns[0]->size();

    /// compareAt() is called between an input column and the matching boundary column, so wrap the
    /// input column into Nullable whenever only the boundary column is nullable.
    for (size_t i = 0; i < bounds_columns.size(); i++)
        if (bounds_columns[i]->isNullable() && !input_columns[sorting_key_columns[i]]->isNullable())
            input_columns[sorting_key_columns[i]] = makeNullable(input_columns[sorting_key_columns[i]]);

    for (size_t r = 0; r < total_rows; ++r)
    {
        size_t selected_partition = 0;
        /// With zero boundaries `max_part - 1` wraps around and is converted to a negative Int64,
        /// so the search sees an empty range and returns -1.
        auto ret = binarySearchBound(bounds_columns, 0, max_part - 1, input_columns, sorting_key_columns, r);
        if (ret >= 0)
            selected_partition = ret;
        else
            selected_partition = max_part;
        selector.emplace_back(selected_partition);
    }
}

/// Compares row `row` of `columns` with boundary row `bound_row` of `bound_columns`, key by key.
///
/// @param required_columns  required_columns[i] is the position inside `columns` of the i-th sort
///                          key; the i-th boundary column is bound_columns[i].
/// @return the first non-zero per-key comparison result multiplied by that key's sort direction
///         (negative: the row sorts before the boundary, positive: after), or 0 if all keys are equal.
int RangeSelectorBuilder::compareRow(
    const DB::Columns & columns,
    const std::vector<size_t> & required_columns,
    size_t row,
    const DB::Columns & bound_columns,
    size_t bound_row)
{
    for (size_t i = 0, n = required_columns.size(); i < n; ++i)
    {
        auto lpos = required_columns[i];
        auto rpos = i;
        auto res = columns[lpos]->compareAt(row, bound_row, *bound_columns[rpos], sort_descriptions[i].nulls_direction)
            * sort_descriptions[i].direction;
        if (res != 0)
            return res;
    }
    return 0;
}

/// Searches the boundary rows in the index range [l, r] for a boundary that is not smaller than
/// the given row (according to compareRow()).
///
/// Returns the index of a boundary equal to the row as soon as one is probed; otherwise the
/// smallest index whose boundary is larger than the row; -1 if every boundary in the range is
/// smaller than the row or the range is empty.
/// NOTE(review): relies on the boundaries being sorted consistently with the sort description.
int RangeSelectorBuilder::binarySearchBound(
    const DB::Columns & bound_columns,
    Int64 l,
    Int64 r,
    const DB::Columns & columns,
    const std::vector<size_t> & used_cols,
    size_t row)
{
    if (l > r)
        return -1;
    auto m = (l + r) >> 1;
    auto cmp_ret = compareRow(columns, used_cols, row, bound_columns, m);
    if (l == r)
    {
        if (cmp_ret <= 0)
            return static_cast<int>(m);
        else
            return -1;
    }

    if (cmp_ret == 0)
        return static_cast<int>(m);
    if (cmp_ret < 0)
    {
        /// The row sorts before boundary m: look for an even smaller qualifying boundary on the left.
        cmp_ret = binarySearchBound(bound_columns, l, m - 1, columns, used_cols, row);
        if (cmp_ret < 0)
        {
            // m is the upper bound
            return static_cast<int>(m);
        }
        return cmp_ret;
    }
    else
    {
        /// The row sorts after boundary m: only boundaries to the right can qualify.
        cmp_ret = binarySearchBound(bound_columns, m + 1, r, columns, used_cols, row);
        if (cmp_ret < 0)
            return -1;
        else
            return cmp_ret;
    }
    __builtin_unreachable();
}
}