cpp/velox/substrait/SubstraitParser.cc (358 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "SubstraitParser.h"
#include "TypeUtils.h"
#include "velox/common/base/Exceptions.h"
#include "VeloxSubstraitSignature.h"
namespace gluten {
// Converts a Substrait type message into the matching Velox type.
//
// Container kinds (struct, list, map) are converted recursively. For structs,
// the field names carried by the message are used only when there is exactly
// one name per field type; otherwise names of the form "col_<index>" are
// generated. When 'asLowerCase' is true, struct field names are lower-cased
// with folly::toLowerAscii.
//
// Both the user-defined kind and the "nothing" kind map to UNKNOWN. Any kind
// without a case below raises a Velox not-yet-implemented error.
TypePtr SubstraitParser::parseType(const ::substrait::Type& substraitType, bool asLowerCase) {
  using Kind = ::substrait::Type::KindCase;

  switch (substraitType.kind_case()) {
    // Scalar kinds.
    case Kind::kBool:
      return BOOLEAN();
    case Kind::kI8:
      return TINYINT();
    case Kind::kI16:
      return SMALLINT();
    case Kind::kI32:
      return INTEGER();
    case Kind::kI64:
      return BIGINT();
    case Kind::kFp32:
      return REAL();
    case Kind::kFp64:
      return DOUBLE();
    case Kind::kString:
      return VARCHAR();
    case Kind::kBinary:
      return VARBINARY();
    case Kind::kDate:
      return DATE();
    case Kind::kTimestamp:
      return TIMESTAMP();
    case Kind::kIntervalYear:
      return INTERVAL_YEAR_MONTH();
    case Kind::kDecimal: {
      const auto& decimalMsg = substraitType.decimal();
      return DECIMAL(decimalMsg.precision(), decimalMsg.scale());
    }

    // UNKNOWN is the only user-defined type supported; it is used for null
    // literals whose type is not known. The "nothing" kind is treated alike.
    case Kind::kUserDefined:
    case Kind::kNothing:
      return UNKNOWN();

    // Nested kinds.
    case Kind::kList:
      return ARRAY(parseType(substraitType.list().type(), asLowerCase));
    case Kind::kMap: {
      const auto& mapMsg = substraitType.map();
      return MAP(parseType(mapMsg.key(), asLowerCase), parseType(mapMsg.value(), asLowerCase));
    }
    case Kind::kStruct: {
      const auto& structMsg = substraitType.struct_();
      const auto& fieldTypes = structMsg.types();
      const auto& fieldNames = structMsg.names();
      // Names are only trusted when they line up one-to-one with the types.
      const bool hasNames = fieldTypes.size() == fieldNames.size();

      std::vector<std::string> names;
      std::vector<TypePtr> types;
      for (int idx = 0; idx < fieldTypes.size(); ++idx) {
        types.emplace_back(parseType(fieldTypes[idx], asLowerCase));

        std::string name;
        if (hasNames) {
          name = fieldNames[idx];
        } else {
          name = "col_" + std::to_string(idx);
        }
        if (asLowerCase) {
          folly::toLowerAscii(name);
        }
        names.emplace_back(name);
      }
      return ROW(std::move(names), std::move(types));
    }

    default:
      VELOX_NYI("Parsing for Substrait type not supported: {}", substraitType.DebugString());
  }
}
// Converts the field types of a Substrait NamedStruct into Velox types, in
// declaration order. The struct's "names" are not consulted here;
// 'asLowerCase' is forwarded to parseType for each field.
std::vector<TypePtr> SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct, bool asLowerCase) {
  const auto& fieldTypes = namedStruct.struct_().types();

  std::vector<TypePtr> result;
  result.reserve(fieldTypes.size());
  for (int idx = 0; idx < fieldTypes.size(); ++idx) {
    result.emplace_back(parseType(fieldTypes[idx], asLowerCase));
  }
  return result;
}
// Fills 'columnTypes' with one ColumnType per column of 'namedStruct'.
//
// When the message carries no column types, 'columnTypes' is resized to the
// number of names, with any newly added slots set to kRegular. Otherwise the
// number of column types must equal the number of names (checked), and each
// Substrait column type is translated and appended to 'columnTypes'. An
// unrecognized column type raises a Velox runtime failure.
void SubstraitParser::parseColumnTypes(
    const ::substrait::NamedStruct& namedStruct,
    std::vector<ColumnType>& columnTypes) {
  const auto& declaredKinds = namedStruct.column_types();

  // No explicit column types: regard all columns as regular columns.
  if (declaredKinds.size() == 0) {
    columnTypes.resize(namedStruct.names().size(), ColumnType::kRegular);
    return;
  }

  VELOX_CHECK_EQ(declaredKinds.size(), namedStruct.names().size(), "Wrong size for column types and column names.");

  columnTypes.reserve(declaredKinds.size());
  for (const auto& kind : declaredKinds) {
    switch (kind) {
      case ::substrait::NamedStruct::NORMAL_COL:
        columnTypes.push_back(ColumnType::kRegular);
        break;
      case ::substrait::NamedStruct::PARTITION_COL:
        columnTypes.push_back(ColumnType::kPartitionKey);
        break;
      case ::substrait::NamedStruct::METADATA_COL:
        columnTypes.push_back(ColumnType::kSynthesized);
        break;
      case ::substrait::NamedStruct::ROWINDEX_COL:
        columnTypes.push_back(ColumnType::kRowIndex);
        break;
      default:
        VELOX_FAIL("Unspecified column type.");
    }
  }
}
// Extracts the top-level struct field index from a Substrait reference segment.
//
// Returns true and stores the index in 'fieldIndex' when 'refSegment' is a
// struct-field reference without a child segment and with a non-negative
// field number. Returns false, leaving 'fieldIndex' untouched, when the
// reference has a child (sub-field indexing is not supported) or when the
// field number is negative. Any other reference kind raises a Velox
// not-yet-implemented error.
bool SubstraitParser::parseReferenceSegment(
    const ::substrait::Expression::ReferenceSegment& refSegment,
    uint32_t& fieldIndex) {
  const auto typeCase = refSegment.reference_type_case();
  switch (typeCase) {
    case ::substrait::Expression::ReferenceSegment::ReferenceTypeCase::kStructField: {
      const auto& structField = refSegment.struct_field();
      if (structField.has_child()) {
        // To parse subfield index is not supported.
        return false;
      }
      // Validate the value in its own (signed, per the Substrait schema) type
      // BEFORE storing it in the unsigned out-parameter: once converted to
      // uint32_t a negative number wraps around and a "< 0" test can never
      // fire.
      const auto field = structField.field();
      if (field < 0) {
        return false;
      }
      fieldIndex = static_cast<uint32_t>(field);
      return true;
    }
    default:
      VELOX_NYI("Substrait conversion not supported for ReferenceSegment '{}'", std::to_string(typeCase));
  }
}
// Builds 'size' names of the form "<prefix>_<index>", with index running from
// 0 to size - 1.
std::vector<std::string> SubstraitParser::makeNames(const std::string& prefix, int size) {
  std::vector<std::string> result;
  result.reserve(size);

  int index = 0;
  while (index < size) {
    result.emplace_back(fmt::format("{}_{}", prefix, index));
    ++index;
  }
  return result;
}
// Builds a node name of the form "n<nodeId>_<colIdx>".
std::string SubstraitParser::makeNodeName(int nodeId, int colIdx) {
  std::string nodeName = fmt::format("n{}_{}", nodeId, colIdx);
  return nodeName;
}
// Parses the column index out of a node name: the text following the first
// "_" is converted with stoi and returned.
//
// Raises a Velox runtime failure when the name contains no "_", when nothing
// follows the first "_", or when stoi rejects the remaining text (its
// exception message is forwarded).
int SubstraitParser::getIdxFromNodeName(const std::string& nodeName) {
  // Locate the first "_" in the node name.
  const std::size_t underscore = nodeName.find("_");
  const bool noUnderscore = underscore == std::string::npos;
  if (noUnderscore || underscore == nodeName.size() - 1) {
    VELOX_FAIL("Invalid node name.");
  }

  // Everything after the separator is the column index.
  const std::string indexText = nodeName.substr(underscore + 1);
  try {
    return std::stoi(indexText);
  } catch (const std::exception& err) {
    VELOX_FAIL(err.what());
  }
}
// Returns a copy of the function specification registered under 'id' in
// 'functionMap'. Raises a Velox runtime failure when 'id' is not present.
std::string SubstraitParser::findFunctionSpec(
    const std::unordered_map<uint64_t, std::string>& functionMap,
    uint64_t id) {
  const auto entry = functionMap.find(id);
  if (entry != functionMap.end()) {
    return entry->second;
  }
  VELOX_FAIL("Could not find function id {} in function map.", id);
}
// TODO Refactor using Bison.
// Returns the part of 'signature' that precedes the first occurrence of
// 'delimiter', or the whole of 'signature' when 'delimiter' does not occur.
std::string SubstraitParser::getNameBeforeDelimiter(const std::string& signature, const std::string& delimiter) {
  // find() yields npos when the delimiter is absent, and substr(0, npos)
  // is the entire string, so both cases collapse into one expression.
  return signature.substr(0, signature.find(delimiter));
}
// Extracts the argument type names from a Substrait function signature.
//
// The text after the first ":" is split on "_"; every piece other than "opt"
// and "req" is returned, in order. An empty vector is returned when there is
// no ":" or when ":" is the last character.
std::vector<std::string> SubstraitParser::getSubFunctionTypes(const std::string& substraitFunction) {
  std::vector<std::string> types;

  // Locate the ":" that separates the function name from its argument types.
  const size_t colon = substraitFunction.find(":");
  if (colon == std::string::npos || colon == substraitFunction.size() - 1) {
    return types;
  }

  // Walk the "_"-separated pieces; 'begin' is the first character of the
  // current piece.
  size_t begin = colon + 1;
  while (true) {
    const size_t separator = substraitFunction.find("_", begin);
    const bool isLastPiece = separator == std::string::npos;

    const std::string piece =
        isLastPiece ? substraitFunction.substr(begin) : substraitFunction.substr(begin, separator - begin);
    if (piece != "opt" && piece != "req") {
      types.emplace_back(piece);
    }

    if (isLastPiece) {
      break;
    }
    begin = separator + 1;
  }
  return types;
}
// Resolves the function registered under 'id' in 'functionMap' to a Velox
// function name.
//
// The specification is looked up with findFunctionSpec, its name part and
// argument type names are extracted, and the name is passed to
// mapToVeloxFunction together with a flag telling whether any argument type
// name contains "dec".
std::string SubstraitParser::findVeloxFunction(
    const std::unordered_map<uint64_t, std::string>& functionMap,
    uint64_t id) {
  const std::string spec = findFunctionSpec(functionMap, id);
  const std::string name = getNameBeforeDelimiter(spec);
  const std::vector<std::string> argTypes = getSubFunctionTypes(spec);

  // Stop scanning as soon as one decimal argument has been seen.
  bool hasDecimalArg = false;
  for (size_t idx = 0; idx < argTypes.size() && !hasDecimalArg; ++idx) {
    hasDecimalArg = argTypes[idx].find("dec") != std::string::npos;
  }
  return mapToVeloxFunction(name, hasDecimalArg);
}
std::string SubstraitParser::mapToVeloxFunction(const std::string& substraitFunction, bool isDecimal) {
auto it = substraitVeloxFunctionMap_.find(substraitFunction);
if (isDecimal) {
if (substraitFunction == "lt" || substraitFunction == "lte" || substraitFunction == "gt" ||
substraitFunction == "gte" || substraitFunction == "equal") {
return "decimal_" + it->second;
}
if (substraitFunction == "round") {
return "decimal_round";
}
}
if (it != substraitVeloxFunctionMap_.end()) {
return it->second;
}
// If not finding the mapping from Substrait function name to Velox function
// name, the original Substrait function name will be used.
return substraitFunction;
}
// Tells whether 'config' is switched on in the optimization payload of
// 'extension'.
//
// The payload is unpacked as a google::protobuf::StringValue. The result is
// true only when the first occurrence of 'config' in that string is
// immediately followed by the character '1'. Returns false when the extension
// has no optimization payload, when the payload cannot be unpacked as a
// StringValue, when 'config' does not occur, or when the following character
// is missing or is not '1'.
bool SubstraitParser::configSetInOptimization(
    const ::substrait::extensions::AdvancedExtension& extension,
    const std::string& config) {
  if (!extension.has_optimization()) {
    return false;
  }

  google::protobuf::StringValue msg;
  // Do not inspect 'msg' when unpacking failed: its content is not a valid
  // StringValue in that case.
  if (!extension.optimization().UnpackTo(&msg)) {
    return false;
  }

  const std::string& value = msg.value();
  const std::size_t pos = value.find(config);
  if (pos == std::string::npos) {
    return false;
  }

  // The flag is the single character right after the config key.
  const std::size_t flagPos = pos + config.size();
  return flagPos < value.size() && value[flagPos] == '1';
}
std::vector<TypePtr> SubstraitParser::sigToTypes(const std::string& signature) {
std::vector<std::string> typeStrs = SubstraitParser::getSubFunctionTypes(signature);
std::vector<TypePtr> types;
types.reserve(typeStrs.size());
for (const auto& typeStr : typeStrs) {
types.emplace_back(VeloxSubstraitSignature::fromSubstraitSignature(typeStr));
}
return types;
}
// Generic fallback for getLiteralValue: used for any T that has no explicit
// specialization in this file. The literal is ignored and a Velox
// not-yet-implemented error is raised.
template <typename T>
T SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& /* literal */) {
VELOX_NYI();
}
template <>
std::shared_ptr<void> gluten::SubstraitParser::getLiteralValue(const substrait::Expression_Literal& literal) {
return nullptr;
}
// Specialization for UnknownValue: the literal is not inspected and a
// value-initialized UnknownValue is always returned.
template <>
facebook::velox::UnknownValue gluten::SubstraitParser::getLiteralValue(
    const substrait::Expression_Literal& /* literal */) {
  return facebook::velox::UnknownValue();
}
// Specialization for int8_t: returns the literal's i8 field narrowed to
// int8_t.
template <>
int8_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const auto raw = literal.i8();
  return static_cast<int8_t>(raw);
}
// Specialization for int16_t: returns the literal's i16 field narrowed to
// int16_t.
template <>
int16_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const auto raw = literal.i16();
  return static_cast<int16_t>(raw);
}
// Specialization for int32_t: returns the literal's date field when one is
// set, and its i32 field otherwise.
template <>
int32_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  return literal.has_date() ? static_cast<int32_t>(literal.date()) : literal.i32();
}
// Specialization for int64_t.
//
// When the literal holds a decimal, the first 16 bytes of its value are
// copied into a 128-bit integer, which is then narrowed to int64_t. Otherwise
// the literal's i64 field is returned.
//
// Raises a Velox check failure when the decimal value holds fewer than 16
// bytes, instead of reading past the end of the buffer.
template <>
int64_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  if (!literal.has_decimal()) {
    return literal.i64();
  }

  // Bind by reference: the bytes are only read, so no copy is needed.
  const auto& decimalBytes = literal.decimal().value();
  VELOX_CHECK_GE(decimalBytes.size(), sizeof(int128_t), "Decimal literal value must hold at least 16 bytes.");

  int128_t decimalValue;
  memcpy(&decimalValue, decimalBytes.data(), sizeof(decimalValue));
  return static_cast<int64_t>(decimalValue);
}
// Specialization for int128_t.
//
// Copies the first 16 bytes of the literal's decimal value into a 128-bit
// integer and returns it, rebuilt from its upper and lower 64-bit halves via
// HugeInt::build.
//
// Raises a Velox check failure when the decimal value holds fewer than 16
// bytes, instead of reading past the end of the buffer.
template <>
int128_t SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  // Bind by reference: the bytes are only read, so no copy is needed.
  const auto& decimalBytes = literal.decimal().value();
  VELOX_CHECK_GE(decimalBytes.size(), sizeof(int128_t), "Decimal literal value must hold at least 16 bytes.");

  int128_t decimalValue;
  memcpy(&decimalValue, decimalBytes.data(), sizeof(decimalValue));
  return HugeInt::build(static_cast<uint64_t>(decimalValue >> 64), static_cast<uint64_t>(decimalValue));
}
// Specialization for double: returns the literal's fp64 field.
template <>
double SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const double value = literal.fp64();
  return value;
}
// Specialization for float: returns the literal's fp32 field.
template <>
float SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const float value = literal.fp32();
  return value;
}
// Specialization for bool: returns the literal's boolean field.
template <>
bool SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const bool value = literal.boolean();
  return value;
}
// Specialization for Timestamp: passes the literal's timestamp field to
// Timestamp::fromMicros and returns the result.
template <>
Timestamp SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  const auto micros = literal.timestamp();
  return Timestamp::fromMicros(micros);
}
// Specialization for StringView: wraps the literal's string, var_char or
// binary payload (checked in that order) in a StringView. Raises a Velox
// runtime failure when the literal holds none of the three.
// NOTE(review): the view is built directly from the literal's own string
// storage, so it presumably must not outlive 'literal' - confirm with callers.
template <>
StringView SubstraitParser::getLiteralValue(const ::substrait::Expression::Literal& literal) {
  if (literal.has_string()) {
    return StringView(literal.string());
  }
  if (literal.has_var_char()) {
    return StringView(literal.var_char().value());
  }
  if (literal.has_binary()) {
    return StringView(literal.binary());
  }
  VELOX_FAIL("Unexpected string or binary literal");
}
// Lookup table from Substrait function name (key) to Velox function name
// (value), consulted by mapToVeloxFunction. A Substrait name with no entry
// here is passed through unchanged by that function. The mapped names of the
// comparison entries ("equal", "lt", "lte", "gt", "gte") are also used, with a
// "decimal_" prefix, for decimal arguments.
std::unordered_map<std::string, std::string> SubstraitParser::substraitVeloxFunctionMap_ = {
{"is_not_null", "isnotnull"}, /*Spark functions.*/
{"is_null", "isnull"},
{"equal", "equalto"},
{"equal_null_safe", "equalnullsafe"},
{"lt", "lessthan"},
{"lte", "lessthanorequal"},
{"gt", "greaterthan"},
{"gte", "greaterthanorequal"},
{"char_length", "length"},
{"strpos", "instr"},
{"ends_with", "endswith"},
{"starts_with", "startswith"},
{"named_struct", "row_constructor"},
{"bit_or", "bitwise_or_agg"},
{"bit_and", "bitwise_and_agg"},
{"murmur3hash", "hash_with_seed"},
{"xxhash64", "xxhash64_with_seed"},
{"modulus", "remainder"},
{"negative", "unaryminus"},
{"get_array_item", "get"}};
// Lookup table from abbreviated type name (key) to upper-case type name
// (value).
// NOTE(review): nothing in this file reads this table; presumably the keys
// are the type abbreviations used in Substrait function signatures and the
// values are Velox type-kind names - confirm against the users of typeMap_.
const std::unordered_map<std::string, std::string> SubstraitParser::typeMap_ = {
{"bool", "BOOLEAN"},
{"i8", "TINYINT"},
{"i16", "SMALLINT"},
{"i32", "INTEGER"},
{"i64", "BIGINT"},
{"fp32", "REAL"},
{"fp64", "DOUBLE"},
{"date", "DATE"},
{"ts", "TIMESTAMP"},
{"str", "VARCHAR"},
{"vbin", "VARBINARY"},
{"decShort", "SHORT_DECIMAL"},
{"decLong", "HUGEINT"}};
} // namespace gluten