ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression()

in cpp-ch/local-engine/Parser/ExpressionParser.cpp [284:507]


ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG & actions_dag, const substrait::Expression & rel) const
{
    switch (rel.rex_type_case())
    {
        case substrait::Expression::RexTypeCase::kLiteral: {
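            /// Parse the Substrait literal into a CH type and field, and add it to the DAG as a constant column.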
            DB::DataTypePtr type;
            DB::Field field;
            std::tie(type, field) = LiteralParser::parse(rel.literal());
            return addConstColumn(actions_dag, type, field);
        }

        case substrait::Expression::RexTypeCase::kSelection: {
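            /// A selection is a direct reference to an input column, resolved by its struct field index.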
            auto field_index = SubstraitParserUtils::getStructFieldIndex(rel);
            if (!field_index)
                throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");

            const auto * field = actions_dag.getInputs()[*field_index];
            return field;
        }

        case substrait::Expression::RexTypeCase::kCast: {
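            /// Spark's cast semantics differ from CH's in several cases, so some casts are mapped to Spark-compatible functions rather than a plain CAST.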
            if (!rel.cast().has_type() || !rel.cast().has_input())
                throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Cast node must have both a type and an input.");
            ActionsDAG::NodeRawConstPtrs args;

            const auto & input = rel.cast().input();
            args.emplace_back(parseExpression(actions_dag, input));

            const auto & substrait_type = rel.cast().type();
            const auto & input_type = args[0]->result_type;
            DataTypePtr denull_input_type = removeNullable(input_type);
            DataTypePtr output_type = TypeParser::parseType(substrait_type);
            DataTypePtr denull_output_type = removeNullable(output_type);
            const ActionsDAG::Node * result_node = nullptr;
            if (substrait_type.has_binary())
            {
                /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
                result_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args);
            }
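            /// Spark cast(x as DATE) from a string -> CH sparkToDate(x)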
            else if (isString(denull_input_type) && isDate32(denull_output_type))
                result_node = toFunctionNode(actions_dag, "sparkToDate", args);
            else if (isString(denull_input_type) && isDateTime64(denull_output_type))
                result_node = toFunctionNode(actions_dag, "sparkToDateTime", args);
            else if (isDecimal(denull_input_type) && isString(denull_output_type))
            {
                /// Spark cast(x as STRING) if x is Decimal -> CH toDecimalString(x, scale)
                UInt8 scale = getDecimalScale(*denull_input_type);
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale)));
                result_node = toFunctionNode(actions_dag, "toDecimalString", args);
            }
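            /// Spark truncates when casting float to an integral type; dedicated sparkCastFloatTo* functions reproduce this.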
            else if (isFloat(denull_input_type) && isInt(denull_output_type))
            {
                String function_name = "sparkCastFloatTo" + denull_output_type->getName();
                result_node = toFunctionNode(actions_dag, function_name, args);
            }
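            /// Spark formats floats its own way when casting to string, handled by sparkCastFloatToString.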
            else if (isFloat(denull_input_type) && isString(denull_output_type))
                result_node = toFunctionNode(actions_dag, "sparkCastFloatToString", args);
            else if (
                (isDecimal(denull_input_type) || isNativeNumber(denull_input_type)) && substrait_type.has_decimal()
                && substrait_type.decimal().precision())
            {
                int precision = substrait_type.decimal().precision();
                int scale = substrait_type.decimal().scale();
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), precision));
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeInt32>(), scale));
                result_node = toFunctionNode(actions_dag, "checkDecimalOverflowSparkOrNull", args);
            }
            else if ((isMap(denull_input_type) || isArray(denull_input_type) || isTuple(denull_input_type)) && isString(denull_output_type))
            {
                /// https://github.com/apache/incubator-gluten/issues/9049
                result_node = toFunctionNode(actions_dag, "sparkCastComplexTypesToString", args);
            }
            else if (isString(denull_input_type) && substrait_type.has_bool_())
            {
                /// Spark cast(string as BOOLEAN) returns NULL on malformed input, so use accurateCastOrNull instead of a plain CAST.
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
                result_node = toFunctionNode(actions_dag, "accurateCastOrNull", args);
            }
            else if (isString(denull_input_type) && isInt(denull_output_type))
            {
                /// Spark cast(x as INT) if x is String -> CH cast(trim(x) as INT)
                /// Refer to https://github.com/apache/incubator-gluten/issues/4956 and https://github.com/apache/incubator-gluten/issues/8598
                const auto * trim_str_arg = addConstColumn(actions_dag, std::make_shared<DataTypeString>(), " \t\n\r\f");
                args[0] = toFunctionNode(actions_dag, "trimBothSpark", {args[0], trim_str_arg});
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
                result_node = toFunctionNode(actions_dag, "CAST", args);
            }
            else
            {
                /// Common process: CAST(input, type)
                args.emplace_back(addConstColumn(actions_dag, std::make_shared<DataTypeString>(), output_type->getName()));
                result_node = toFunctionNode(actions_dag, "CAST", args);
            }

            actions_dag.addOrReplaceInOutputs(*result_node);
            return result_node;
        }

        case substrait::Expression::RexTypeCase::kIfThen: {
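            /// Map the Substrait if-then chain onto CH's if (one condition) or multiIf (several conditions).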
            const auto & if_then = rel.if_then();
            DB::FunctionOverloadResolverPtr function_ptr = nullptr;
            auto condition_nums = if_then.ifs_size();
            if (condition_nums == 1)
                function_ptr = DB::FunctionFactory::instance().get("if", context->queryContext());
            else
                function_ptr = DB::FunctionFactory::instance().get("multiIf", context->queryContext());
            DB::ActionsDAG::NodeRawConstPtrs args;

            for (int i = 0; i < condition_nums; ++i)
            {
                const auto & ifs = if_then.ifs(i);
                const auto * if_node = parseExpression(actions_dag, ifs.if_());
                args.emplace_back(if_node);

                const auto * then_node = parseExpression(actions_dag, ifs.then());
                args.emplace_back(then_node);
            }

            const auto * else_node = parseExpression(actions_dag, if_then.else_());
            args.emplace_back(else_node);
            std::string args_name = join(args, ',');
            std::string result_name;
            if (condition_nums == 1)
                result_name = "if(" + args_name + ")";
            else
                result_name = "multiIf(" + args_name + ")";
            const auto * function_node = &actions_dag.addFunction(function_ptr, args, result_name);
            actions_dag.addOrReplaceInOutputs(*function_node);
            return function_node;
        }

        case substrait::Expression::RexTypeCase::kScalarFunction: {
            return parseFunction(rel.scalar_function(), actions_dag);
        }

        case substrait::Expression::RexTypeCase::kSingularOrList: {
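            /// SingularOrList models Spark's IN predicate: value IN (option, option, ...), where all options are literals.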
            const auto & options = rel.singular_or_list().options();
            /// An empty options list always evaluates to false.
            if (options.empty())
                return addConstColumn(actions_dag, std::make_shared<DB::DataTypeUInt8>(), 0);
            /// All options must be literals; the first is checked here, the rest in the loop below.
            if (!options[0].has_literal())
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type");

            DB::ActionsDAG::NodeRawConstPtrs args;
            args.emplace_back(parseExpression(actions_dag, rel.singular_or_list().value()));

            bool nullable = false;
            int options_len = options.size();
            for (int i = 0; i < options_len; ++i)
            {
                if (!options[i].has_literal())
                    throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!");
                if (!nullable)
                    nullable = options[i].literal().has_null();
            }

            DB::DataTypePtr elem_type;
            std::vector<std::pair<DB::DataTypePtr, DB::Field>> options_type_and_field;
            auto first_option = LiteralParser::parse(options[0].literal());
            elem_type = wrapNullableType(nullable, first_option.first);
            options_type_and_field.emplace_back(std::move(first_option));
            for (int i = 1; i < options_len; ++i)
            {
                auto type_and_field = LiteralParser::parse(options[i].literal());
                auto option_type = wrapNullableType(nullable, type_and_field.first);
                if (!elem_type->equals(*option_type))
                    throw DB::Exception(
                        DB::ErrorCodes::LOGICAL_ERROR,
                        "SingularOrList options type mismatch:{} and {}",
                        elem_type->getName(),
                        option_type->getName());
                options_type_and_field.emplace_back(std::move(type_and_field));
            }

            // check tuple internal types
            if (isTuple(elem_type) && isTuple(args[0]->result_type))
            {
                // Spark guarantees that the tuple types inside an IN filter are structurally consistent,
                // see org.apache.spark.sql.types.DataType#equalsStructurally.
                // The Spark-to-ClickHouse type mapping is also one-to-one (see TypeParser.cpp),
                // so we can reuse the left-hand value's tuple type for the set elements and avoid nullability mismatches.
                elem_type = args[0]->result_type;
            }
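            /// Materialize the option literals into a column and wrap it in a prepared set: the second argument of the in function.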
            DB::MutableColumnPtr elem_column = elem_type->createColumn();
            elem_column->reserve(options_len);
            for (int i = 0; i < options_len; ++i)
                elem_column->insert(options_type_and_field[i].second);
            auto name = getUniqueName("__set");
            ColumnWithTypeAndName elem_block{std::move(elem_column), elem_type, name};

            PreparedSets prepared_sets;
            FutureSet::Hash empty_key;
            auto future_set = prepared_sets.addFromTuple(empty_key, nullptr, {elem_block}, context->queryContext()->getSettingsRef());
            auto arg = DB::ColumnSet::create(1, std::move(future_set));
            args.emplace_back(&actions_dag.addColumn(DB::ColumnWithTypeAndName(std::move(arg), std::make_shared<DB::DataTypeSet>(), name)));

            const auto * function_node = toFunctionNode(actions_dag, "in", args);
            actions_dag.addOrReplaceInOutputs(*function_node);
            if (nullable)
            {
                /// If the set contains `null` and the value is not in it,
                /// Spark returns `null` (the ANSI behaviour, see SPARK-37920) while CH's in returns `false`.
                /// Wrap the result in if(in(...), true, NULL) so that `false` becomes `null` whenever the set contains `null`.
                auto type = wrapNullableType(true, function_node->result_type);
                DB::ActionsDAG::NodeRawConstPtrs cast_args(
                    {function_node, addConstColumn(actions_dag, type, true), addConstColumn(actions_dag, type, DB::Field())});
                function_node = toFunctionNode(actions_dag, "if", cast_args);
                actions_dag.addOrReplaceInOutputs(*function_node);
            }
            return function_node;
        }

        default:
            throw DB::Exception(
                DB::ErrorCodes::UNKNOWN_TYPE,
                "Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}