// cpp-ch/local-engine/Storages/SubstraitSource/SubstraitFileSource.cpp
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <functional>
#include <memory>
#include <boost/algorithm/string/predicate.hpp>
#include <substrait/plan.pb.h>
#include <magic_enum.hpp>
#include <Poco/URI.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeDecimalBase.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <Interpreters/castColumn.h>
#include <QueryPipeline/Pipe.h>
#include <Storages/SubstraitSource/FormatFile.h>
#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <Storages/SubstraitSource/SubstraitFileSourceStep.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/StringUtils.h>
#include <Common/typeid_cast.h>
#include "DataTypes/DataTypesDecimal.h"
#include "IO/readDecimalText.h"
#include <boost/stacktrace.hpp>
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_TYPE;
extern const int LOGICAL_ERROR;
}
}
namespace local_engine
{
// When running a query like "select count(*) from t", there are no columns to read;
// the row count is the only information needed. To handle such cases, we build
// blocks with a const virtual column that indicates how many rows they contain.
static DB::Block getRealHeader(const DB::Block & header)
{
return header ? header : BlockUtil::buildRowCountHeader();
}
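/// Builds the source from a Substrait ReadRel::LocalFiles message. `output_header` is the
/// header downstream operators expect; `to_read_header` excludes the partition columns,
/// since those are reconstructed from the file path rather than read from the file content.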
SubstraitFileSource::SubstraitFileSource(
DB::ContextPtr context_, const DB::Block & header_, const substrait::ReadRel::LocalFiles & file_infos)
: DB::SourceWithKeyCondition(getRealHeader(header_), false), context(context_), output_header(header_), to_read_header(output_header)
{
if (file_infos.items_size())
{
/// Initialize files
Poco::URI file_uri(file_infos.items().Get(0).uri_file());
read_buffer_builder = ReadBufferBuilderFactory::instance().createBuilder(file_uri.getScheme(), context);
for (const auto & item : file_infos.items())
files.emplace_back(FormatFileUtil::createFile(context, read_buffer_builder, item));
/// Partition key columns are derived from the file path, so exclude them from the columns read from the file content
auto partition_keys = files[0]->getFilePartitionKeys();
for (const auto & key : partition_keys)
{
if (to_read_header.findByName(key))
to_read_header.erase(key);
}
}
}
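/// Builds a KeyCondition from the pushed-down filter expressions over the columns actually
/// read from the file, so the underlying reader can use it to skip data where supported.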
void SubstraitFileSource::setKeyCondition(const DB::ActionsDAG::NodeRawConstPtrs & nodes, DB::ContextPtr context_)
{
const auto & keys = to_read_header;
std::unordered_map<std::string, DB::ColumnWithTypeAndName> node_name_to_input_column;
for (const auto & column : keys.getColumnsWithTypeAndName())
node_name_to_input_column.insert({column.name, column});
auto filter_actions_dag = DB::ActionsDAG::buildFilterActionsDAG(nodes, node_name_to_input_column);
key_condition = std::make_shared<const DB::KeyCondition>(
filter_actions_dag,
context_,
keys.getNames(),
std::make_shared<DB::ExpressionActions>(std::make_shared<DB::ActionsDAG>(keys.getColumnsWithTypeAndName())));
}
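/// Pulls the next chunk from the current file reader, moving on to the next file whenever
/// the current one is exhausted. An empty chunk signals that all files have been consumed.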
DB::Chunk SubstraitFileSource::generate()
{
while (true)
{
if (!tryPrepareReader())
{
/// all files finished
return {};
}
DB::Chunk chunk;
if (file_reader->pull(chunk))
return chunk;
/// try to read from next file
file_reader.reset();
}
}
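/// Makes sure `file_reader` points at a reader for the next unread file. Depending on the
/// case this is an EmptyFileReader (non-splittable file with a non-zero offset), a
/// ConstColumnsFileReader (no file columns needed and the row count is known from metadata),
/// or a NormalFileReader. Returns false once all files are consumed.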
bool SubstraitFileSource::tryPrepareReader()
{
if (file_reader)
return true;
if (current_file_index >= files.size())
return false;
auto current_file = files[current_file_index];
current_file_index += 1;
if (!current_file->supportSplit() && current_file->getStartOffset())
{
/// For files that do not support the split strategy, only the task with offset 0 reads the file; a task with a non-zero offset produces no data
file_reader = std::make_unique<EmptyFileReader>(current_file);
return true;
}
if (!to_read_header)
{
auto total_rows = current_file->getTotalRows();
if (total_rows.has_value())
file_reader = std::make_unique<ConstColumnsFileReader>(current_file, context, output_header, *total_rows);
else
{
/// For text/JSON format files, the total row count is not available from the file metadata,
/// so we add a dummy column to carry the number of rows.
file_reader
= std::make_unique<NormalFileReader>(current_file, context, getRealHeader(to_read_header), getRealHeader(output_header));
}
}
else
file_reader = std::make_unique<NormalFileReader>(current_file, context, to_read_header, output_header);
file_reader->applyKeyCondition(key_condition);
return true;
}
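/// Creates a constant column of `rows` copies of `field`, adding an all-zero (non-null)
/// null map when the requested type is nullable.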
DB::ColumnPtr FileReaderWrapper::createConstColumn(DB::DataTypePtr data_type, const DB::Field & field, size_t rows)
{
auto nested_type = DB::removeNullable(data_type);
auto column = nested_type->createColumnConst(rows, field);
if (data_type->isNullable())
column = DB::ColumnNullable::create(column, DB::ColumnUInt8::create(rows, 0));
return column;
}
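/// Materializes a partition value string into a constant column of `rows` rows. A null
/// partition value requires a nullable column type and yields a fully-null column.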
DB::ColumnPtr FileReaderWrapper::createColumn(const String & value, DB::DataTypePtr type, size_t rows)
{
if (StringUtils::isNullPartitionValue(value))
{
if (!type->isNullable())
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Partition column is null value,but column data type is not nullable.");
}
auto nested_type = static_cast<const DB::DataTypeNullable &>(*type).getNestedType();
auto column = nested_type->createColumnConstWithDefaultValue(rows);
return DB::ColumnNullable::create(column, DB::ColumnUInt8::create(rows, 1));
}
else
{
auto field = buildFieldFromString(value, type);
return createConstColumn(type, field, rows);
}
}
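/// Macros that expand to lambdas parsing an integer / floating-point partition value
/// from a ReadBuffer into a DB::Field; used to populate `field_builders` below.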
#define BUILD_INT_FIELD(type) \
[](DB::ReadBuffer & in, const String &) \
{ \
type value = 0; \
DB::readIntText(value, in); \
return DB::Field(value); \
}
#define BUILD_FP_FIELD(type) \
[](DB::ReadBuffer & in, const String &) \
{ \
type value = 0.0; \
DB::readFloatText(value, in); \
return DB::Field(value); \
}
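/// Parses a partition value string into a DB::Field matching the column type (with
/// nullability stripped). Simple types are dispatched by type name through `field_builders`;
/// decimals are handled separately below because their type names embed precision and scale.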
DB::Field FileReaderWrapper::buildFieldFromString(const String & str_value, DB::DataTypePtr type)
{
using FieldBuilder = std::function<DB::Field(DB::ReadBuffer &, const String &)>;
static std::map<std::string, FieldBuilder> field_builders
= {{"Int8", BUILD_INT_FIELD(Int8)},
{"Int16", BUILD_INT_FIELD(Int16)},
{"Int32", BUILD_INT_FIELD(Int32)},
{"Int64", BUILD_INT_FIELD(Int64)},
{"Float32", BUILD_FP_FIELD(Float32)},
{"Float64", BUILD_FP_FIELD(Float64)},
{"String", [](DB::ReadBuffer &, const String & val) { return DB::Field(val); }},
{"Date",
[](DB::ReadBuffer & in, const String &)
{
DayNum value;
readDateText(value, in);
return DB::Field(value);
}},
{"Date32",
[](DB::ReadBuffer & in, const String &)
{
ExtendedDayNum value;
readDateText(value, in);
return DB::Field(value.toUnderType());
}},
{"Bool",
[](DB::ReadBuffer & in, const String &)
{
bool value;
readBoolTextWord(value, in, true);
return DB::Field(value);
}},
{"DateTime64(6)",
[](DB::ReadBuffer &, const String & s)
{
std::string decoded; // s: "2023-07-12 05%3A05%3A33.798" (URI-encoded by Spark) => decoded: "2023-07-12 05:05:33.798"
Poco::URI::decode(s, decoded);
std::string to_read;
if (decoded.length() > 23) // we have seen cases where Spark mistakenly(?) encodes the URI twice, so decode a second time
Poco::URI::decode(decoded, to_read);
else
to_read = decoded;
DB::ReadBufferFromString read_buffer(to_read);
DB::DateTime64 value;
DB::readDateTime64Text(value, 6, read_buffer);
return DB::Field(value);
}}
};
auto nested_type = DB::removeNullable(type);
DB::ReadBufferFromString read_buffer(str_value);
auto it = field_builders.find(nested_type->getName());
if (it == field_builders.end())
{
DB::WhichDataType which(nested_type->getTypeId());
if (which.isDecimal32())
{
const auto & decimal_type = static_cast<const DB::DataTypeDecimal<DB::Decimal32> &>(*nested_type);
DB::Decimal32 value = decimal_type.parseFromString(str_value);
return DB::DecimalField<DB::Decimal32>(value, decimal_type.getScale());
}
else if (which.isDecimal64())
{
const auto & decimal_type = static_cast<const DB::DataTypeDecimal<DB::Decimal64> &>(*nested_type);
DB::Decimal64 value = decimal_type.parseFromString(str_value);
return DB::DecimalField<DB::Decimal64>(value, decimal_type.getScale());
}
else if (which.isDecimal128())
{
const auto & decimal_type = static_cast<const DB::DataTypeDecimal<DB::Decimal128> &>(*nested_type);
DB::Decimal128 value = decimal_type.parseFromString(str_value);
return DB::DecimalField<DB::Decimal128>(value, decimal_type.getScale());
}
else if (which.isDecimal256())
{
const auto & decimal_type = static_cast<const DB::DataTypeDecimal<DB::Decimal256> &>(*nested_type);
DB::Decimal256 value = decimal_type.parseFromString(str_value);
return DB::DecimalField<DB::Decimal256>(value, decimal_type.getScale());
}
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unsupported data type {}", nested_type->getName());
throw DB::Exception(DB::ErrorCodes::UNKNOWN_TYPE, "Unsupported data type {}", nested_type->getName());
}
return it->second(read_buffer, str_value);
}
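/// Reader for files where no columns need to be read from the file content, e.g. count(*)
/// or queries selecting only partition columns. It emits chunks built purely from the known
/// total row count and the partition values.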
ConstColumnsFileReader::ConstColumnsFileReader(FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & header_, size_t block_size_)
: FileReaderWrapper(file_), context(context_), header(header_), remained_rows(0), block_size(block_size_)
{
auto rows = file->getTotalRows();
if (!rows)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Cannot get total rows number from file : {}", file->getURIPath());
remained_rows = *rows;
}
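/// Emits at most `block_size` rows per call until `remained_rows` is exhausted.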
bool ConstColumnsFileReader::pull(DB::Chunk & chunk)
{
if (!remained_rows) [[unlikely]]
return false;
size_t to_read_rows = 0;
if (remained_rows < block_size)
{
to_read_rows = remained_rows;
remained_rows = 0;
}
else
{
to_read_rows = block_size;
remained_rows -= block_size;
}
DB::Columns res_columns;
size_t col_num = header.columns();
if (col_num)
{
res_columns.reserve(col_num);
const auto & partition_values = file->getFilePartitionValues();
for (size_t pos = 0; pos < col_num; ++pos)
{
auto col_with_name_and_type = header.getByPosition(pos);
auto type = col_with_name_and_type.type;
const auto & name = col_with_name_and_type.name;
auto it = partition_values.find(name);
if (it == partition_values.end()) [[unlikely]]
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow partition column : {}", name);
}
res_columns.emplace_back(createColumn(it->second, type, to_read_rows));
}
}
else
{
// The output header is empty; build a chunk that only carries the row count.
res_columns = BlockUtil::buildRowCountChunk(to_read_rows).detachColumns();
}
chunk = DB::Chunk(std::move(res_columns), to_read_rows);
return true;
}
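/// Reader for the regular case: chunks are pulled from the format's input stream and the
/// constant partition columns are appended to match `output_header`.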
NormalFileReader::NormalFileReader(
FormatFilePtr file_, DB::ContextPtr context_, const DB::Block & to_read_header_, const DB::Block & output_header_)
: FileReaderWrapper(file_), context(context_), to_read_header(to_read_header_), output_header(output_header_)
{
input_format = file->createInputFormat(to_read_header);
}
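/// Assembles the output chunk in `output_header` order: columns present in `to_read_header`
/// come from the file; all others must be partition columns materialized as constants.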
bool NormalFileReader::pull(DB::Chunk & chunk)
{
DB::Chunk raw_chunk = input_format->input->generate();
auto rows = raw_chunk.getNumRows();
if (!rows)
return false;
auto read_columns = raw_chunk.detachColumns();
auto columns_with_name_and_type = output_header.getColumnsWithTypeAndName();
auto partition_values = file->getFilePartitionValues();
DB::Columns res_columns;
res_columns.reserve(columns_with_name_and_type.size());
for (auto & column : columns_with_name_and_type)
{
if (to_read_header.has(column.name))
{
auto pos = to_read_header.getPositionByName(column.name);
res_columns.push_back(read_columns[pos]);
}
else
{
auto it = partition_values.find(column.name);
if (it == partition_values.end())
{
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR, "Not found column({}) from file({}) partition keys.", column.name, file->getURIPath());
}
res_columns.push_back(createColumn(it->second, column.type, rows));
}
}
chunk = DB::Chunk(std::move(res_columns), rows);
return true;
}
}