in cpp-ch/local-engine/Parser/ExpressionParser.cpp [284:507]
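/// Translate a Substrait expression into a node of the ActionsDAG and return that node.
/// Handles literals, field references, casts (with Spark-compatible semantics), if/then
/// chains, scalar functions and IN lists (SingularOrList).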
ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
switch (rel.rex_type_case())
{
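/// Literal: parse the Substrait literal into a (type, value) pair and add it as a constant column.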
case substrait::Expression::RexTypeCase::kLiteral: {
auto [type, field] = LiteralParser::parse(rel.literal());
return addConstColumn(actions_dag, type, field);
}
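/// Field reference: only direct struct field references are supported; they resolve to input columns of the DAG.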
case substrait::Expression::RexTypeCase::kSelection: {
auto field_index = SubstraitParserUtils::getStructFieldIndex(rel);
if (!field_index)
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");
const auto * field = actions_dag.getInputs()[*field_index];
return field;
}
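/// Cast: most casts go through the common CH CAST, but several type combinations need
/// dedicated functions to reproduce Spark's cast semantics.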
case substrait::Expression::RexTypeCase::kCast: {
if (!rel.cast().has_type() || !rel.cast().has_input())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Cast node must have both a type and an input.");
ActionsDAG::NodeRawConstPtrs args;
const auto & input = rel.cast().input();
args.emplace_back(parseExpression(actions_dag, input));
const auto & substrait_type = rel.cast().type();
const auto & input_type = args[0]->result_type;
DataTypePtr denull_input_type = removeNullable(input_type);
DataTypePtr output_type = TypeParser::parseType(substrait_type);
DataTypePtr denull_output_type = removeNullable(output_type);
const ActionsDAG::Node * result_node = nullptr;
if (substrait_type.has_binary())
{
/// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args);
}
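/// Spark-specific parsing semantics for string -> DATE / TIMESTAMP casts.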
else if (isString(denull_input_type) && isDate32(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkToDate", args);
else if (isString(denull_input_type) && isDateTime64(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkToDateTime", args);
else if (isDecimal(denull_input_type) && isString(denull_output_type))
{
/// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x, scale)
UInt8 scale = getDecimalScale(*denull_input_type);
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale)));
result_node = toFunctionNode(actions_dag, "toDecimalString", args);
}
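/// Spark-specific float -> integer cast semantics; the function name is built
/// from the target type, e.g. sparkCastFloatToInt64.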
else if (isFloat(denull_input_type) && isInt(denull_output_type))
{
String function_name = "sparkCastFloatTo" + denull_output_type->getName();
result_node = toFunctionNode(actions_dag, function_name, args);
}
else if (isFloat(denull_input_type) && isString(denull_output_type))
result_node = toFunctionNode(actions_dag, "sparkCastFloatToString", args);
else if ((isDecimal(denull_input_type) || isNativeNumber(denull_input_type)) && substrait_type.has_decimal()
    && substrait_type.decimal().precision())
{
/// Spark cast(x as DECIMAL(p, s)) -> CH checkDecimalOverflowSparkOrNull(x, p, s).
/// The precision check is part of the condition so that a zero precision falls
/// through to the common CAST below instead of leaving result_node null.
int precision = substrait_type.decimal().precision();
int scale = substrait_type.decimal().scale();
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision));
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale));
result_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
}
else if ((isMap(denull_input_type) || isArray(denull_input_type) || isTuple(denull_input_type)) && isString(denull_output_type))
{
/// https://github.com/apache/incubator-gluten/issues/9049
result_node = toFunctionNode(actions_dag, "sparkCastComplexTypesToString", args);
}
else if (isString(denull_input_type) && substrait_type.has_bool_())
{
/// Spark cast(x as BOOLEAN) if x is String -> CH accurateCastOrNull(x, type)
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args);
}
else if (isString(denull_input_type) && isInt(denull_output_type))
{
/// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT)
/// Refer to https://github.com/apache/incubator-gluten/issues/4956 and https://github.com/apache/incubator-gluten/issues/8598
const auto * trim_str_arg = addConstColumn(actions_dag, std::make_shared<DataTypeString>(), " \t\n\r\f");
args[0] = toFunctionNode(actions_dag, "trimBothSpark", {args[0], trim_str_arg});
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "CAST", args);
}
else
{
/// Common process: CAST(input, type)
args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
result_node = toFunctionNode(actions_dag, "CAST", args);
}
actions_dag.addOrReplaceInOutputs(*result_node);
return result_node;
}
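/// IfThen: a single condition maps to CH if(), multiple conditions map to multiIf().
/// Arguments are laid out as (cond1, then1, cond2, then2, ..., else).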
case substrait::Expression::RexTypeCase::kIfThen: {
const auto & if_then = rel.if_then();
DB::FunctionOverloadResolverPtr function_ptr;
auto condition_nums = if_then.ifs_size();
if (condition_nums == 1)
function_ptr = DB::FunctionFactory::instance().get("if", context->queryContext());
else
function_ptr = DB::FunctionFactory::instance().get("multiIf", context->queryContext());
DB::ActionsDAG::NodeRawConstPtrs args;
for (int i = 0; i < condition_nums; ++i)
{
const auto & ifs = if_then.ifs(i);
const auto * if_node = parseExpression(actions_dag, ifs.if_());
args.emplace_back(if_node);
const auto * then_node = parseExpression(actions_dag, ifs.then());
args.emplace_back(then_node);
}
const auto * else_node = parseExpression(actions_dag, if_then.else_());
args.emplace_back(else_node);
std::string args_name = join(args, ',');
std::string result_name = (condition_nums == 1 ? "if(" : "multiIf(") + args_name + ")";
const auto * function_node = &actions_dag.addFunction(function_ptr, args, result_name);
actions_dag.addOrReplaceInOutputs(*function_node);
return function_node;
}
case substrait::Expression::RexTypeCase::kScalarFunction: {
return parseFunction(rel.scalar_function(), actions_dag);
}
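/// SingularOrList: Spark's IN filter. The options are materialized into a constant CH set
/// and evaluated with the `in` function, with extra care for null semantics below.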
case substrait::Expression::RexTypeCase::kSingularOrList: {
const auto & options = rel.singular_or_list().options();
/// An empty option list always evaluates to false.
if (options.empty())
return addConstColumn(actions_dag, std::make_shared<DB::DataTypeUInt8>(), 0);
DB::ActionsDAG::NodeRawConstPtrs args;
args.emplace_back(parseExpression(actions_dag, rel.singular_or_list().value()));
/// All options must be literals; the result becomes nullable if any option is a null literal.
bool nullable = false;
int options_len = options.size();
for (int i = 0; i < options_len; ++i)
{
if (!options[i].has_literal())
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Options of SingularOrList must be literals");
if (!nullable)
nullable = options[i].literal().has_null();
}
DB::DataTypePtr elem_type;
std::vector<std::pair<DB::DataTypePtr, DB::Field>> options_type_and_field;
auto first_option = LiteralParser::parse(options[0].literal());
elem_type = wrapNullableType(nullable, first_option.first);
options_type_and_field.emplace_back(std::move(first_option));
for (int i = 1; i < options_len; ++i)
{
auto type_and_field = LiteralParser::parse(options[i].literal());
auto option_type = wrapNullableType(nullable, type_and_field.first);
if (!elem_type->equals(*option_type))
throw DB::Exception(
DB::ErrorCodes::LOGICAL_ERROR,
"SingularOrList options type mismatch:{} and {}",
elem_type->getName(),
option_type->getName());
options_type_and_field.emplace_back(std::move(type_and_field));
}
// check tuple internal types
if (isTuple(elem_type) && isTuple(args[0]->result_type))
{
// Spark guarantees that the types of tuples in the 'in' filter are completely consistent.
// See org.apache.spark.sql.types.DataType#equalsStructurally
// Additionally, the mapping from Spark types to ClickHouse types is one-to-one; see TypeParser.cpp.
// So we can directly use the value column's tuple type as the set element type to avoid nullable mismatches.
elem_type = args[0]->result_type;
}
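/// Materialize all option literals into one constant column, wrap it in a prepared set,
/// and pass the set as the second argument of `in`.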
DB::MutableColumnPtr elem_column = elem_type->createColumn();
elem_column->reserve(options_len);
for (int i = 0; i < options_len; ++i)
elem_column->insert(options_type_and_field[i].second);
auto name = getUniqueName("__set");
ColumnWithTypeAndName elem_block{std::move(elem_column), elem_type, name};
PreparedSets prepared_sets;
FutureSet::Hash empty_key;
auto future_set = prepared_sets.addFromTuple(empty_key, nullptr, {elem_block}, context->queryContext()->getSettingsRef());
auto arg = DB::ColumnSet::create(1, std::move(future_set));
args.emplace_back(&actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(arg), std::make_shared<DB::DataTypeSet>(), name)));
const auto * function_node = toFunctionNode(actions_dag, "in", args);
actions_dag.addOrReplaceInOutputs(*function_node);
if (nullable)
{
/// If the set contains `null` and the value is not found in it, Spark returns `null`
/// (the ANSI behaviour, see SPARK-37920), while CH `in` returns `false`.
/// Wrap the result in if(in(...), true, NULL) so that `false` becomes `null`
/// whenever the set contains `null`.
auto type = wrapNullableType(true, function_node->result_type);
DB::ActionsDAG::NodeRawConstPtrs cast_args(
{function_node, addConstColumn(actions_dag, type, true), addConstColumn(actions_dag, type, DB::Field())});
function_node = toFunctionNode(actions_dag, "if", cast_args);
actions_dag.addOrReplaceInOutputs(*function_node);
}
return function_node;
}
default:
throw DB::Exception(
DB::ErrorCodes::UNKNOWN_TYPE,
"Unsupported spark expression type {} : {}",
magic_enum::enum_name(rel.rex_type_case()),
rel.DebugString());
}
}