ASTPtr ASTParser::parseArgumentToAST()

in cpp-ch/local-engine/Parser/SerializedPlanParser.cpp [1899:2023]


/// Converts a substrait expression into a ClickHouse AST node.
///
/// Handled expression kinds: literal, direct struct-field selection, cast, if/then,
/// scalar function (delegated to parseToAST) and singular-or-list (translated to `in`).
///
/// @param names  Column names; a selection whose struct field index is `i` becomes the identifier `names[i]`.
/// @param rel    The substrait expression to convert.
/// @return The AST node built for `rel`.
/// @throws Exception with BAD_ARGUMENTS for a malformed selection / cast / `in` option or an
///         out-of-range field index, LOGICAL_ERROR for inconsistent singular-or-list options,
///         and UNKNOWN_TYPE for any expression kind not listed above.
ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expression & rel)
{
    switch (rel.rex_type_case())
    {
        case substrait::Expression::RexTypeCase::kLiteral: {
            /// Only the value is needed for an AST literal; the parsed data type is dropped.
            Field field = SerializedPlanParser::parseLiteral(rel.literal()).second;
            return std::make_shared<ASTLiteral>(field);
        }
        case substrait::Expression::RexTypeCase::kSelection: {
            if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");

            const auto field = rel.selection().direct_reference().struct_field().field();
            /// The index comes straight from the serialized plan, so validate it before indexing `names`.
            if (field < 0 || static_cast<size_t>(field) >= names.size())
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "Struct field reference {} is out of range, only {} names are available",
                    field,
                    names.size());

            return std::make_shared<ASTIdentifier>(names[field]);
        }
        case substrait::Expression::RexTypeCase::kCast: {
            if (!rel.cast().has_type() || !rel.cast().has_input())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node.");

            /// First argument: the expression being cast.
            ASTs args;
            args.emplace_back(parseArgumentToAST(names, rel.cast().input()));

            const auto & substrait_type = rel.cast().type();
            /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
            if (substrait_type.has_binary())
                return makeASTFunction("reinterpretAsStringSpark", args);

            /// Second argument: the destination type, passed to CAST by name.
            DataTypePtr ch_type = TypeParser::parseType(substrait_type);
            args.emplace_back(std::make_shared<ASTLiteral>(ch_type->getName()));
            return makeASTFunction("CAST", args);
        }
        case substrait::Expression::RexTypeCase::kIfThen: {
            const auto & if_then = rel.if_then();
            /// A single condition maps to `if`, anything else to `multiIf`.
            /// Only the function name goes into the AST, so the function itself is not resolved here.
            std::string ch_function_name = if_then.ifs_size() == 1 ? "if" : "multiIf";

            /// Arguments are laid out as cond1, then1, cond2, then2, ..., else.
            ASTs args;
            for (const auto & clause : if_then.ifs())
            {
                args.emplace_back(parseArgumentToAST(names, clause.if_()));
                args.emplace_back(parseArgumentToAST(names, clause.then()));
            }
            args.emplace_back(parseArgumentToAST(names, if_then.else_()));

            return makeASTFunction(ch_function_name, args);
        }
        case substrait::Expression::RexTypeCase::kScalarFunction: {
            return parseToAST(names, rel);
        }
        case substrait::Expression::RexTypeCase::kSingularOrList: {
            const auto & options = rel.singular_or_list().options();
            /// options is empty always return false
            if (options.empty())
                return std::make_shared<ASTLiteral>(0);
            /// options should be literals
            if (!options[0].has_literal())
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type");

            ASTs args;
            args.emplace_back(parseArgumentToAST(names, rel.singular_or_list().value()));

            /// First pass: every option must be a literal; remember whether any of them is null,
            /// because that decides whether all option types are wrapped as nullable below.
            bool nullable = false;
            for (const auto & option : options)
            {
                if (!option.has_literal())
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!");
                nullable = nullable || option.literal().has_null();
            }

            /// Second pass: parse each literal once, check that all options share the type of the
            /// first one, and collect the values.
            ASTs in_args;
            in_args.reserve(static_cast<size_t>(options.size()));
            DataTypePtr elem_type;
            for (const auto & option : options)
            {
                auto type_and_field = SerializedPlanParser::parseLiteral(option.literal());
                DataTypePtr option_type = wrapNullableType(nullable, type_and_field.first);
                if (!elem_type)
                    elem_type = option_type;
                else if (!elem_type->equals(*option_type))
                    throw Exception(
                        ErrorCodes::LOGICAL_ERROR,
                        "SingularOrList options type mismatch:{} and {}",
                        elem_type->getName(),
                        option_type->getName());

                in_args.emplace_back(std::make_shared<ASTLiteral>(std::move(type_and_field.second)));
            }
            args.emplace_back(makeASTFunction("array", in_args));

            auto ast = makeASTFunction("in", args);
            if (nullable)
            {
                /// if sets has `null` and value not in sets
                /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920)
                /// In CH: return `false`
                /// So we used if(a, b, c) cast `false` to `null` if sets has `null`
                ast = makeASTFunction("if", ast, std::make_shared<ASTLiteral>(true), std::make_shared<ASTLiteral>(Field()));
            }

            return ast;
        }
        default:
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE,
                "Join on condition error. Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}