/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "SerializedPlanParser.h"
#include <algorithm>
#include <memory>
#include <string>
#include <string_view>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <Columns/ColumnSet.h>
#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/Names.h>
#include <Core/NamesAndTypes.h>
#include <Core/Types.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime64.h>
#include <DataTypes/DataTypeFactory.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeNothing.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeSet.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <DataTypes/getLeastSupertype.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionsConversion.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ActionsVisitor.h>
#include <Interpreters/CollectJoinOnKeysVisitor.h>
#include <Interpreters/Context.h>
#include <Interpreters/PreparedSets.h>
#include <Interpreters/ProcessList.h>
#include <Interpreters/QueryPriorities.h>
#include <Join/StorageJoinFromReadBuffer.h>
#include <Operator/BlocksBufferPoolTransform.h>
#include <Parser/FunctionParser.h>
#include <Parser/JoinRelParser.h>
#include <Parser/RelParser.h>
#include <Parser/TypeParser.h>
#include <Parser/MergeTreeRelParser.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ExpressionListParsers.h>
#include <Processors/Executors/PullingAsyncPipelineExecutor.h>
#include <Processors/Formats/Impl/ArrowBlockOutputFormat.h>
#include <Processors/QueryPlan/AggregatingStep.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/MergingAggregatedStep.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/ReadFromPreparedSource.h>
#include <Processors/Transforms/AggregatingTransform.h>
#include <Processors/Transforms/MaterializingTransform.h>
#include <QueryPipeline/Pipe.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <QueryPipeline/printPipeline.h>
#include <Storages/CustomStorageMergeTree.h>
#include <Storages/IStorage.h>
#include <Storages/MergeTree/MergeTreeData.h>
#include <Storages/StorageMergeTreeFactory.h>
#include <Storages/SubstraitSource/SubstraitFileSource.h>
#include <Storages/SubstraitSource/SubstraitFileSourceStep.h>
#include <google/protobuf/util/json_util.h>
#include <google/protobuf/wrappers.pb.h>
#include <Poco/Util/MapConfiguration.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>
#include <Common/MergeTreeTool.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>

/// Error codes referenced by this translation unit.
/// They are only declared here (`extern`); their definitions are provided elsewhere.
namespace DB
{
namespace ErrorCodes
{
    extern const int LOGICAL_ERROR;
    extern const int UNKNOWN_TYPE;
    extern const int BAD_ARGUMENTS;
    extern const int NO_SUCH_DATA_PART;
    extern const int UNKNOWN_FUNCTION;
    extern const int CANNOT_PARSE_PROTOBUF_SCHEMA;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    extern const int INVALID_JOIN_ON_EXPRESSION;
}
}

namespace local_engine
{
using namespace DB;

/// Concatenates the `result_name` of every node in `v`, inserting `c` between consecutive names.
/// Returns an empty string when `v` is empty.
std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c)
{
    std::string joined;
    bool is_first = true;
    for (const auto * node : v)
    {
        /// The separator goes before every element except the first one.
        if (!is_first)
            joined.push_back(c);
        is_first = false;
        joined.append(node->result_name);
    }
    return joined;
}

/// Writes `message` as JSON to the "SerializedPlanParser" logger at debug level.
/// `type` is a label placed in front of the JSON text (and used in the error message).
/// Throws LOGICAL_ERROR when the message cannot be converted to JSON.
void logDebugMessage(const google::protobuf::Message & message, const char * type)
{
    auto * logger = &Poco::Logger::get("SerializedPlanParser");
    /// Do not build the JSON string at all when debug output is disabled.
    if (!logger->debug())
        return;

    std::string json;
    google::protobuf::util::JsonOptions options;
    const auto status = google::protobuf::util::MessageToJsonString(message, &json, options);
    if (!status.ok())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Can not convert {} to Json", type);
    LOG_DEBUG(logger, "{}:\n{}", type, json);
}

/// Adds a constant column holding `field` (of type `type`) to `actions_dag` and returns the new node.
/// The column name is produced by getUniqueName() from the first 10 characters of the field's text form.
const ActionsDAG::Node * SerializedPlanParser::addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field)
{
    const auto name_seed = toString(field).substr(0, 10);
    ColumnWithTypeAndName const_column(type->createColumnConst(1, field), type, getUniqueName(name_seed));
    return &actions_dag->addColumn(std::move(const_column));
}


/// Records every extension function declared in `extensions` into `function_mapping`,
/// keyed by the decimal string form of its function anchor and mapped to the function name.
/// Declarations that are not extension functions are ignored.
void SerializedPlanParser::parseExtensions(
    const ::google::protobuf::RepeatedPtrField<substrait::extensions::SimpleExtensionDeclaration> & extensions)
{
    for (const auto & extension : extensions)
    {
        if (!extension.has_extension_function())
            continue;

        const auto & function = extension.extension_function();
        function_mapping.emplace(std::to_string(function.function_anchor()), function.name());
    }
}

/// Builds an ActionsDAG over `header` that evaluates `expressions` and projects exactly their results.
///
/// Supported expression kinds:
///  - selection: a direct field reference, resolved by position against `read_schema`;
///  - scalar function: explode / posexplode / json_tuple are dispatched to their dedicated parsers,
///    everything else goes through parseFunction();
///  - cast / if_then / literal: handled by parseExpression().
///
/// If the same result name is produced more than once, the later occurrences are projected
/// under an alias obtained from getUniqueName().
///
/// Throws BAD_ARGUMENTS for an unsupported expression kind, for a field reference that is out of
/// range of `read_schema`, or for a referenced column that is not among the DAG outputs.
std::shared_ptr<DB::ActionsDAG> SerializedPlanParser::expressionsToActionsDAG(
    const std::vector<substrait::Expression> & expressions, const DB::Block & header, const DB::Block & read_schema)
{
    auto actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
    NamesWithAliases required_columns;
    std::set<String> distinct_columns;

    /// Appends `name` to the projection list. A name that has been seen before is
    /// projected under a fresh unique alias so that the projection has no duplicate outputs.
    auto add_required_column = [&](const String & name)
    {
        if (distinct_columns.contains(name))
        {
            auto unique_name = getUniqueName(name);
            required_columns.emplace_back(NameWithAlias(name, unique_name));
            distinct_columns.emplace(std::move(unique_name));
        }
        else
        {
            required_columns.emplace_back(NameWithAlias(name, name));
            distinct_columns.emplace(name);
        }
    };

    for (const auto & expr : expressions)
    {
        if (expr.has_selection())
        {
            const auto position = expr.selection().direct_reference().struct_field().field();
            /// The index comes from the serialized plan, so validate it before indexing the block.
            if (position < 0 || static_cast<size_t>(position) >= read_schema.columns())
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "Field reference {} is out of range, read schema has {} columns",
                    position,
                    read_schema.columns());

            const auto & col_name = read_schema.getByPosition(static_cast<size_t>(position)).name;
            const ActionsDAG::Node * field = actions_dag->tryFindInOutputs(col_name);
            /// tryFindInOutputs() returns nullptr when the column is not an output of the DAG.
            if (!field)
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS, "Cannot find column {} (position {}) in the outputs of the actions DAG", col_name, position);

            add_required_column(field->result_name);
        }
        else if (expr.has_scalar_function())
        {
            const auto & scalar_function = expr.scalar_function();
            const auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));

            std::vector<String> result_names;
            if (startsWith(function_signature, "explode:"))
                actions_dag = parseArrayJoin(header, expr, result_names, actions_dag, true, false);
            else if (startsWith(function_signature, "posexplode:"))
                actions_dag = parseArrayJoin(header, expr, result_names, actions_dag, true, true);
            else if (startsWith(function_signature, "json_tuple:"))
                actions_dag = parseJsonTuple(header, expr, result_names, actions_dag, true, false);
            else
            {
                result_names.resize(1);
                actions_dag = parseFunction(header, expr, result_names[0], actions_dag, true);
            }

            for (const auto & result_name : result_names)
            {
                /// Empty names are skipped and do not take part in the projection.
                if (result_name.empty())
                    continue;
                add_required_column(result_name);
            }
        }
        else if (expr.has_cast() || expr.has_if_then() || expr.has_literal())
        {
            const auto * node = parseExpression(actions_dag, expr);
            actions_dag->addOrReplaceInOutputs(*node);
            add_required_column(node->result_name);
        }
        else
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case()));
    }
    actions_dag->project(required_columns);
    return actions_dag;
}

/// Returns the name of the toDecimal* conversion function matching the precision of `decimal`.
/// The checks go from Decimal32 to Decimal128, so the first width whose maximum precision is
/// sufficient is chosen. When `null_on_overflow` is set, the "OrNull" suffix is appended.
/// Throws UNKNOWN_TYPE if the precision exceeds the maximum precision of Decimal128.
std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool null_on_overflow)
{
    const UInt32 precision = decimal.precision();

    const char * base_name = nullptr;
    if (precision <= DataTypeDecimal32::maxPrecision())
        base_name = "toDecimal32";
    else if (precision <= DataTypeDecimal64::maxPrecision())
        base_name = "toDecimal64";
    else if (precision <= DataTypeDecimal128::maxPrecision())
        base_name = "toDecimal128";
    else
        throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision);

    std::string ch_function_name(base_name);
    if (null_on_overflow)
        ch_function_name += "OrNull";
    return ch_function_name;
}

/// Returns true when `rel` has local files with exactly one item whose uri starts with "iterator".
bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel)
{
    if (!rel.has_local_files())
        return false;

    const auto & items = rel.local_files().items();
    if (items.size() != 1)
        return false;

    return items.at(0).uri_file().starts_with("iterator");
}

/// Reads the "isMergeTree" flag from the advanced extension of `rel`.
/// The optimization payload is decoded as a protobuf StringValue whose text must start with
/// "isMergeTree=", followed by a boolean and a newline; the parsed boolean is returned.
/// The text helpers (assertString / readBoolText / assertChar) report malformed input.
bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel)
{
    assert(rel.has_advanced_extension());

    google::protobuf::StringValue wrapped_text;
    wrapped_text.ParseFromString(rel.advanced_extension().optimization().value());

    ReadBufferFromString buf(wrapped_text.value());
    assertString("isMergeTree=", buf);
    bool merge_tree_flag = false;
    readBoolText(merge_tree_flag, buf);
    assertChar('\n', buf);
    return merge_tree_flag;
}

/// Creates the source step that reads the local files of `rel`.
/// The file list is taken from `rel` itself when present, otherwise it is parsed from the next
/// entry of `split_infos`. If `rel` carries a filter, it is attached to the step via addFilter().
QueryPlanStepPtr SerializedPlanParser::parseReadRealWithLocalFile(const substrait::ReadRel & rel)
{
    auto header = TypeParser::buildBlockFromNamedStruct(rel.base_schema());

    substrait::ReadRel::LocalFiles files_to_read;
    if (rel.has_local_files())
        files_to_read = rel.local_files();
    else
        files_to_read = parseLocalFiles(split_infos.at(nextSplitInfoIndex()));

    auto file_source = std::make_shared<SubstraitFileSource>(context, header, files_to_read);
    auto pipe = Pipe(file_source);
    auto source_step = std::make_unique<SubstraitFileSourceStep>(context, std::move(pipe), "substrait local files");
    source_step->setStepDescription("read local files");

    if (!rel.has_filter())
        return source_step;

    /// Build the filter expression on top of the source header and hand it to the source step.
    const ActionsDAGPtr filter_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(header));
    const ActionsDAG::Node * filter_node = parseExpression(filter_dag, rel.filter());
    filter_dag->addOrReplaceInOutputs(*filter_node);
    source_step->addFilter(filter_dag, filter_node);
    return source_step;
}

/// Creates a source step that reads blocks from a Java-side iterator.
/// The single local-file uri of `rel` is expected to contain ':' followed by the index of the
/// iterator in `input_iters` / `materialize_inputs`.
/// Throws BAD_ARGUMENTS when the uri has no ':' or the index is out of range.
/// std::stoi may additionally throw if the text after ':' is not a number.
QueryPlanStepPtr SerializedPlanParser::parseReadRealWithJavaIter(const substrait::ReadRel & rel)
{
    assert(rel.has_local_files());
    assert(rel.local_files().items().size() == 1);
    const auto & iter = rel.local_files().items().at(0).uri_file();

    /// Without this check `pos + 1` would wrap to 0 and the whole uri would be fed to std::stoi.
    const auto pos = iter.find(':');
    if (pos == std::string::npos)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid java iterator uri '{}': missing ':' before the iterator index", iter);

    const auto iter_index = std::stoi(iter.substr(pos + 1));
    if (iter_index < 0 || static_cast<size_t>(iter_index) >= input_iters.size()
        || static_cast<size_t>(iter_index) >= materialize_inputs.size())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Java iterator index {} is out of range, uri: '{}'", iter_index, iter);

    auto source = std::make_shared<SourceFromJavaIter>(
        context, TypeParser::buildBlockFromNamedStruct(rel.base_schema()), input_iters[iter_index], materialize_inputs[iter_index]);
    QueryPlanStepPtr source_step = std::make_unique<ReadFromPreparedSource>(Pipe(source));
    source_step->setStepDescription("Read From Java Iter");
    return source_step;
}

/// Appends an expression step to `plan` that applies removeNullable() to the given `columns`.
/// Returns a non-owning pointer to the added step, or nullptr (and leaves `plan` untouched)
/// when `columns` is empty.
IQueryPlanStep * SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, const std::set<String> & columns)
{
    if (columns.empty())
        return nullptr;

    const auto & input_stream = plan.getCurrentDataStream();
    auto dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(input_stream.header));
    removeNullable(columns, dag);

    auto step = std::make_unique<ExpressionStep>(input_stream, dag);
    step->setStepDescription("Remove nullable properties");
    /// Take the raw pointer before ownership is transferred to the plan.
    auto * added_step = step.get();
    plan.addStep(std::move(step));
    return added_step;
}

/// Builds a PrewhereInfo for the filter expression `rel` evaluated over the columns of `input`.
/// The resulting actions compute the filter column, which is marked to be removed after
/// filtering, while every column of `input` is restored as an output.
PrewhereInfoPtr SerializedPlanParser::parsePreWhereInfo(const substrait::Expression & rel, Block & input)
{
    auto prewhere_info = std::make_shared<PrewhereInfo>();
    prewhere_info->prewhere_actions = std::make_shared<ActionsDAG>(input.getNamesAndTypesList());
    std::string filter_name;
    // for in function
    if (rel.has_singular_or_list())
    {
        const auto * in_node = parseExpression(prewhere_info->prewhere_actions, rel);
        prewhere_info->prewhere_actions->addOrReplaceInOutputs(*in_node);
        filter_name = in_node->result_name;
    }
    else
    {
        parseFunctionWithDAG(rel, filter_name, prewhere_info->prewhere_actions, true);
    }
    prewhere_info->prewhere_column_name = filter_name;
    prewhere_info->need_filter = true;
    prewhere_info->remove_prewhere_column = true;
    // Keep it the same as the input.
    // The order of the following calls matters: first drop everything but the filter column,
    // then restore each input column as an output.
    prewhere_info->prewhere_actions->removeUnusedActions(Names{filter_name}, false, true);
    prewhere_info->prewhere_actions->projectInput(false);
    for (const auto & name : input.getNames())
    {
        prewhere_info->prewhere_actions->tryRestoreColumn(name);
    }
    return prewhere_info;
}

/// Overload taking the substrait nullability enum: `nested_type` is treated as nullable
/// only for NULLABILITY_NULLABLE; the work is delegated to the bool overload.
DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type)
{
    const bool is_nullable = (nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE);
    return wrapNullableType(is_nullable, nested_type);
}

/// Returns Nullable(`nested_type`) when `nullable` is set and the type is not nullable yet;
/// otherwise returns `nested_type` unchanged.
DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type)
{
    const bool needs_wrapping = nullable && !nested_type->isNullable();
    if (!needs_wrapping)
        return nested_type;
    return std::make_shared<DataTypeNullable>(nested_type);
}

/// Converts a substrait plan into a query plan.
///
/// Steps:
///  1. log the plan (debug level) and register its extension functions;
///  2. parse the input of the single root relation via parseOp();
///  3. if the root declares output names, add a "Rename Output" step that renames the columns by position;
///  4. if the root declares an output schema, add a final projection that fixes nullability mismatches
///     (columns of aggregate-function type are left untouched).
///
/// Throws BAD_ARGUMENTS when the plan does not contain exactly one relation or that relation has no root,
/// and LOGICAL_ERROR when the number of names / output types does not match the number of result columns.
QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr<substrait::Plan> plan)
{
    logDebugMessage(*plan, "substrait plan");
    parseExtensions(plan->extensions());
    if (plan->relations_size() == 1)
    {
        /// Bind by reference: copying the relation would deep-copy the whole relation tree.
        /// `plan` owns the message and outlives every use below.
        const auto & root_rel = plan->relations().at(0);
        if (!root_rel.has_root())
        {
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!");
        }
        std::list<const substrait::Rel *> rel_stack;
        auto query_plan = parseOp(root_rel.root().input(), rel_stack);
        if (root_rel.root().names_size())
        {
            ActionsDAGPtr actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(query_plan->getCurrentDataStream().header));
            NamesWithAliases aliases;
            /// Materialize the name list once instead of rebuilding it on every loop iteration.
            const auto col_names = query_plan->getCurrentDataStream().header.getNamesAndTypesList().getNames();
            if (col_names.size() != static_cast<size_t>(root_rel.root().names_size()))
            {
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Missmatch result columns size.");
            }
            aliases.reserve(col_names.size());
            for (int i = 0; i < static_cast<int>(col_names.size()); i++)
            {
                aliases.emplace_back(NameWithAlias(col_names[i], root_rel.root().names(i)));
            }
            actions_dag->project(aliases);
            auto expression_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), actions_dag);
            expression_step->setStepDescription("Rename Output");
            query_plan->addStep(std::move(expression_step));
        }

        // fixes: issue-1874, to keep the nullability as expected.
        const auto & output_schema = root_rel.root().output_schema();
        if (output_schema.types_size())
        {
            auto original_header = query_plan->getCurrentDataStream().header;
            const auto & original_cols = original_header.getColumnsWithTypeAndName();
            if (static_cast<size_t>(output_schema.types_size()) != original_cols.size())
            {
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mismatch output schema");
            }
            bool need_final_project = false;
            DB::ColumnsWithTypeAndName final_cols;
            for (int i = 0; i < output_schema.types_size(); ++i)
            {
                const auto & col = original_cols[i];
                auto type = TypeParser::parseType(output_schema.types(i));
                // At present, we only check nullable mismatch.
                // intermediate aggregate data is special, no check here.
                if (type->isNullable() != col.type->isNullable() && !typeid_cast<const DB::DataTypeAggregateFunction *>(col.type.get()))
                {
                    if (type->isNullable())
                    {
                        final_cols.emplace_back(type->createColumn(), std::make_shared<DB::DataTypeNullable>(col.type), col.name);
                    }
                    else
                    {
                        final_cols.emplace_back(type->createColumn(), DB::removeNullable(col.type), col.name);
                    }
                    need_final_project = true;
                }
                else
                {
                    final_cols.push_back(col);
                }
            }
            if (need_final_project)
            {
                ActionsDAGPtr final_project
                    = ActionsDAG::makeConvertingActions(original_cols, final_cols, ActionsDAG::MatchColumnsMode::Position);
                QueryPlanStepPtr final_project_step = std::make_unique<ExpressionStep>(query_plan->getCurrentDataStream(), final_project);
                final_project_step->setStepDescription("Project for output schema");
                query_plan->addStep(std::move(final_project_step));
            }
        }
        return query_plan;
    }
    else
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "too many relations found");
    }
}

/// Recursively converts the relation `rel` into a query plan.
///
///  - kFetch: parses the input and appends a LIMIT step;
///  - kRead: builds a source from a java iterator / local files, or delegates to MergeTreeRelParser;
///  - filter, generate, project, aggregate, sort, window, join, expand: delegated to the parser
///    registered in RelParserFactory for that relation type;
///  - anything else: throws UNKNOWN_TYPE.
///
/// `rel_stack` is the stack of relations currently being parsed; it is passed on to the nested parsers.
/// When the `query_plan_enable_optimizations` setting is off, a RelMetric covering the created
/// steps is recorded in `metrics`.
QueryPlanPtr SerializedPlanParser::parseOp(const substrait::Rel & rel, std::list<const substrait::Rel *> & rel_stack)
{
    QueryPlanPtr query_plan;
    std::vector<IQueryPlanStep *> steps;
    switch (rel.rel_type_case())
    {
        case substrait::Rel::RelTypeCase::kFetch: {
            rel_stack.push_back(&rel);
            const auto & limit = rel.fetch();
            query_plan = parseOp(limit.input(), rel_stack);
            rel_stack.pop_back();
            auto limit_step = std::make_unique<LimitStep>(query_plan->getCurrentDataStream(), limit.count(), limit.offset());
            limit_step->setStepDescription("LIMIT");
            steps.emplace_back(limit_step.get());
            query_plan->addStep(std::move(limit_step));
            break;
        }
        case substrait::Rel::RelTypeCase::kRead: {
            const auto & read = rel.read();
            // TODO: We still maintain the old logic of parsing LocalFiles or ExtensionTable in ReadRel
            // to be compatible with some suites about metrics.
            // Remove this compatibility later; then only java iter has local files in ReadRel.
            if (read.has_local_files() || (!read.has_extension_table() && !isReadFromMergeTree(read)))
            {
                // The base schema belongs to the ReadRel, not to the enclosing Rel.
                assert(read.has_base_schema());
                QueryPlanStepPtr step;
                if (isReadRelFromJava(read))
                    step = parseReadRealWithJavaIter(read);
                else
                    step = parseReadRealWithLocalFile(read);

                query_plan = std::make_unique<QueryPlan>();
                steps.emplace_back(step.get());
                query_plan->addStep(std::move(step));

                // Add a buffer after the source; it tries to preload data from the source and reduce the
                // waiting time of downstream nodes.
                if (context->getSettingsRef().max_threads > 1)
                {
                    auto buffer_step = std::make_unique<BlocksBufferPoolStep>(query_plan->getCurrentDataStream());
                    steps.emplace_back(buffer_step.get());
                    query_plan->addStep(std::move(buffer_step));
                }
            }
            else
            {
                substrait::ReadRel::ExtensionTable extension_table;
                if (read.has_extension_table())
                    extension_table = read.extension_table();
                else
                    extension_table = parseExtensionTable(split_infos.at(nextSplitInfoIndex()));

                MergeTreeRelParser mergeTreeParser(this, context, query_context, global_context);
                std::list<const substrait::Rel *> stack;
                query_plan = mergeTreeParser.parseReadRel(std::make_unique<QueryPlan>(), read, extension_table, stack);
                steps = mergeTreeParser.getSteps();
            }
            break;
        }
        case substrait::Rel::RelTypeCase::kFilter:
        case substrait::Rel::RelTypeCase::kGenerate:
        case substrait::Rel::RelTypeCase::kProject:
        case substrait::Rel::RelTypeCase::kAggregate:
        case substrait::Rel::RelTypeCase::kSort:
        case substrait::Rel::RelTypeCase::kWindow:
        case substrait::Rel::RelTypeCase::kJoin:
        case substrait::Rel::RelTypeCase::kExpand: {
            auto op_parser = RelParserFactory::instance().getBuilder(rel.rel_type_case())(this);
            query_plan = op_parser->parseOp(rel, rel_stack);
            auto parser_steps = op_parser->getSteps();
            steps.insert(steps.end(), parser_steps.begin(), parser_steps.end());
            break;
        }
        default:
            throw Exception(ErrorCodes::UNKNOWN_TYPE, "doesn't support relation type: {}.\n{}", rel.rel_type_case(), rel.DebugString());
    }

    if (!context->getSettingsRef().query_plan_enable_optimizations)
    {
        if (rel.rel_type_case() == substrait::Rel::RelTypeCase::kRead)
        {
            // A read relation gets its own metric, numbered after the last recorded one.
            size_t id = metrics.empty() ? 0 : metrics.back()->getId() + 1;
            metrics.emplace_back(std::make_shared<RelMetric>(id, String(magic_enum::enum_name(rel.rel_type_case())), steps));
        }
        else
            // Any other relation replaces the metrics recorded so far with a single metric built from them.
            metrics = {std::make_shared<RelMetric>(String(magic_enum::enum_name(rel.rel_type_case())), metrics, steps)};
    }

    return query_plan;
}

/// Converts `header` into a list of (name, type) pairs, one entry per name returned by
/// `header.getNames()`, in that order.
NamesAndTypesList SerializedPlanParser::blockToNameAndTypeList(const Block & header)
{
    NamesAndTypesList result;
    const auto column_names = header.getNames();
    for (const auto & column_name : column_names)
    {
        const auto * column = header.findByName(column_name);
        result.emplace_back(column->name, column->type);
    }
    return result;
}

/// Maps a substrait scalar function to the name of the ClickHouse function that implements it.
///
/// `function_signature` has the form "<name>:<argument types>"; only the part before ':' is used to
/// look the function up in SCALAR_FUNCTIONS. Several Spark functions need special treatment because
/// the target function depends on the number, the value or the type of the arguments
/// (trim family, extract, sha2, check_overflow, make_decimal, char_length, reverse, concat).
///
/// Throws UNKNOWN_FUNCTION if the function is not listed in SCALAR_FUNCTIONS and BAD_ARGUMENTS if
/// the arguments of one of the specially handled functions are invalid.
std::string
SerializedPlanParser::getFunctionName(const std::string & function_signature, const substrait::Expression_ScalarFunction & function)
{
    /// Bind by reference: copying the repeated field would deep-copy every argument expression.
    const auto & args = function.arguments();
    auto pos = function_signature.find(':');
    auto function_name = function_signature.substr(0, pos);
    if (!SCALAR_FUNCTIONS.contains(function_name))
        throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "Unsupported function {}", function_name);

    std::string ch_function_name;
    if (function_name == "trim")
        ch_function_name = args.size() == 1 ? "trimBoth" : "trimBothSpark";
    else if (function_name == "ltrim")
        ch_function_name = args.size() == 1 ? "trimLeft" : "trimLeftSpark";
    else if (function_name == "rtrim")
        ch_function_name = args.size() == 1 ? "trimRight" : "trimRightSpark";
    else if (function_name == "extract")
    {
        if (args.size() != 2)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString());

        // Get the first arg: field
        const auto & extract_field = args.at(0);

        if (extract_field.value().has_literal())
        {
            const auto & field_value = extract_field.value().literal().string();
            if (field_value == "YEAR")
                ch_function_name = "toYear"; // spark: extract(YEAR FROM) or year
            else if (field_value == "YEAR_OF_WEEK")
                ch_function_name = "toISOYear"; // spark: extract(YEAROFWEEK FROM)
            else if (field_value == "QUARTER")
                ch_function_name = "toQuarter"; // spark: extract(QUARTER FROM) or quarter
            else if (field_value == "MONTH")
                ch_function_name = "toMonth"; // spark: extract(MONTH FROM) or month
            else if (field_value == "WEEK_OF_YEAR")
                ch_function_name = "toISOWeek"; // spark: extract(WEEK FROM) or weekofyear
            else if (field_value == "WEEK_DAY")
                /// Spark WeekDay(date) (0 = Monday, 1 = Tuesday, ..., 6 = Sunday)
                /// Substrait: extract(WEEK_DAY from date)
                /// CH: toDayOfWeek(date, 1)
                ch_function_name = "toDayOfWeek";
            else if (field_value == "DAY_OF_WEEK")
                /// Spark: DayOfWeek(date) (1 = Sunday, 2 = Monday, ..., 7 = Saturday)
                /// Substrait: extract(DAY_OF_WEEK from date)
                /// CH: toDayOfWeek(date, 3)
                /// DAYOFWEEK is an alias of function toDayOfWeek.
                /// This trick distinguishes the extract fields DAY_OF_WEEK and WEEK_DAY in later code.
                ch_function_name = "DAYOFWEEK";
            else if (field_value == "DAY")
                ch_function_name = "toDayOfMonth"; // spark: extract(DAY FROM) or dayofmonth
            else if (field_value == "DAY_OF_YEAR")
                ch_function_name = "toDayOfYear"; // spark: extract(DOY FROM) or dayofyear
            else if (field_value == "HOUR")
                ch_function_name = "toHour"; // spark: extract(HOUR FROM) or hour
            else if (field_value == "MINUTE")
                ch_function_name = "toMinute"; // spark: extract(MINUTE FROM) or minute
            else if (field_value == "SECOND")
                ch_function_name = "toSecond"; // spark: extract(SECOND FROM) or secondwithfraction
            else
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong.");
        }
        else
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "The first arg of spark extract function is wrong.");
    }
    else if (function_name == "sha2")
    {
        if (args.size() != 2)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "Spark function sha2 requires two args, function:{}", function.ShortDebugString());

        const auto & bit_length = args.at(1);
        if (!bit_length.value().has_literal() || !bit_length.value().literal().has_i32())
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of spark sha2 function is wrong.");

        /// A bit length of 0 is treated the same as 256.
        const auto bit_length_value = bit_length.value().literal().i32();
        if (bit_length_value == 224)
            ch_function_name = "SHA224";
        else if (bit_length_value == 256 || bit_length_value == 0)
            ch_function_name = "SHA256";
        else if (bit_length_value == 384)
            ch_function_name = "SHA384";
        else if (bit_length_value == 512)
            ch_function_name = "SHA512";
        else
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "The second arg of spark sha2 function is wrong, value:{}", bit_length_value);
    }
    else if (function_name == "check_overflow")
    {
        if (args.size() < 2)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args.");
        ch_function_name = SCALAR_FUNCTIONS.at(function_name);
        /// The second argument selects the "OrNull" variant of the function.
        const auto null_on_overflow = args.at(1).value().literal().boolean();
        if (null_on_overflow)
            ch_function_name = ch_function_name + "OrNull";
    }
    else if (function_name == "make_decimal")
    {
        if (args.size() < 2)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args.");
        ch_function_name = SCALAR_FUNCTIONS.at(function_name);
        /// The second argument selects the "OrNull" variant of the function.
        const auto null_on_overflow = args.at(1).value().literal().boolean();
        if (null_on_overflow)
            ch_function_name = ch_function_name + "OrNull";
    }
    else if (function_name == "char_length")
    {
        /// In Spark
        /// char_length returns the number of bytes when input is binary type, corresponding to CH length function
        /// char_length returns the number of characters when input is string type, corresponding to CH char_length function
        ch_function_name = SCALAR_FUNCTIONS.at(function_name);
        if (function_signature.find("vbin") != std::string::npos)
            ch_function_name = "length";
    }
    else if (function_name == "reverse")
    {
        /// A list output selects the array variant, everything else the UTF-8 string variant.
        if (function.output_type().has_list())
            ch_function_name = "arrayReverse";
        else
            ch_function_name = "reverseUTF8";
    }
    else if (function_name == "concat")
    {
        /// 1. ConcatOverloadResolver cannot build arrayConcat for Nullable(Array) type which causes failures when using functions like concat(split()).
        ///    So we use arrayConcat directly if the output type is array.
        /// 2. CH ConcatImpl can only accept at least 2 arguments, but Spark concat can accept 1 argument, like concat('a')
        ///    in such case we use identity function
        if (function.output_type().has_list())
            ch_function_name = "arrayConcat";
        else if (args.size() == 1)
            ch_function_name = "identity";
        else
            ch_function_name = "concat";
    }
    else
        ch_function_name = SCALAR_FUNCTIONS.at(function_name);

    return ch_function_name;
}

/// Build the ActionsDAG nodes for a Spark explode/posexplode call that has been mapped to CH arrayJoin.
///
/// @param rel          Expression whose root must be a scalar function resolving to "arrayJoin".
/// @param result_names Receives the result names of the produced output nodes (appended in output order).
/// @param actions_dag  DAG the nodes are added to.
/// @param keep_result  When true, the produced nodes are also registered in the DAG outputs.
/// @param position     false for explode, true for posexplode.
/// @return The produced nodes: (col) or (key, value) for explode; (pos, col) or (pos, key, value) for posexplode.
/// @throws Exception BAD_ARGUMENTS / LOGICAL_ERROR when the expression shape or the argument type is not supported.
ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG(
    const substrait::Expression & rel, std::vector<String> & result_names, DB::ActionsDAGPtr actions_dag, bool keep_result, bool position)
{
    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString());

    const auto & array_join_call = rel.scalar_function();

    auto signature = function_mapping.at(std::to_string(rel.scalar_function().function_reference()));
    auto function_name = getFunctionName(signature, array_join_call);
    if (function_name != "arrayJoin")
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "Function parseArrayJoinWithDAG should only process arrayJoin function, but input is {}",
            rel.ShortDebugString());

    /// arrayJoin (converted from Spark explode/posexplode) takes exactly one argument.
    if (array_join_call.arguments_size() != 1)
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 but is {}", array_join_call.arguments_size());

    ActionsDAG::NodeRawConstPtrs args;
    parseFunctionArguments(actions_dag, args, function_name, array_join_call);

    auto arg_type = DB::removeNullable(args[0]->result_type);

    /// Replace a NULL input by an empty array()/map() so that arrayJoin never sees a Nullable argument:
    /// assumeNotNull(ifNull(args[0], array() or map()))
    const auto * empty_container_node = addColumn(actions_dag, arg_type, isMap(arg_type) ? Field(Map()) : Field(Array()));
    const auto * null_replaced_node = toFunctionNode(actions_dag, "ifNull", {args[0], empty_container_node});
    const auto * arg_not_null = toFunctionNode(actions_dag, "assumeNotNull", {null_replaced_node});
    /// Materialize the column that is fed into the ARRAY JOIN step.
    arg_not_null = &actions_dag->materializeNode(*arg_not_null);

    /// arrayJoin(arg_not_null)
    /// The arrayJoin node deliberately reuses the result name of its input: this makes it straightforward to turn the
    /// arrayJoin function into an ARRAY JOIN step without appending an extra alias node afterwards.
    auto array_join_name = arg_not_null->result_name;
    const auto * array_join_node = &actions_dag->addArrayJoin(*arg_not_null, array_join_name);

    auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context);
    auto tuple_index_type = std::make_shared<DataTypeUInt32>();

    /// Adds sparkTupleElement(tuple_node, i) together with the constant index column it needs.
    auto element_of = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node *
    {
        ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i)));
        const auto * index_node = &actions_dag->addColumn(std::move(index_col));
        auto element_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")";
        return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, element_name);
    };

    /// Records the names of the produced nodes, optionally exposes them as DAG outputs, and hands them back.
    auto publish = [&](ActionsDAG::NodeRawConstPtrs nodes) -> ActionsDAG::NodeRawConstPtrs
    {
        for (const auto * node : nodes)
            result_names.push_back(node->result_name);

        if (keep_result)
        {
            for (const auto * node : nodes)
                actions_dag->addOrReplaceInOutputs(*node);
        }
        return nodes;
    };

    /// Special handling to stay compatible with Spark.
    WhichDataType which(arg_type.get());

    if (!position)
    {
        /// Spark: explode(array_or_map) -> CH: arrayJoin(array_or_map)
        if (which.isMap())
        {
            /// Spark's explode(map(k, v)) yields two columns ("key" and "value"), whereas CH's arrayJoin(map(k, v))
            /// yields one Tuple column, so both tuple elements are extracted with sparkTupleElement.
            const auto * key_node = element_of(array_join_node, 1); /// arrayJoin(arg_not_null).1
            const auto * val_node = element_of(array_join_node, 2); /// arrayJoin(arg_not_null).2
            return publish({key_node, val_node});
        }

        if (which.isArray())
            return publish({array_join_node});

        throw Exception(
            ErrorCodes::BAD_ARGUMENTS,
            "Argument type of arrayJoin converted from explode should be Array or Map but is {}",
            arg_type->getName());
    }

    /// Spark: posexplode(array_or_map) -> CH: arrayJoin(map), in which map = mapFromArrays(range(length(array_or_map)), array_or_map)
    if (!which.isMap())
        throw Exception(
            ErrorCodes::BAD_ARGUMENTS,
            "Argument type of arrayJoin converted from posexplode should be Map but is {}",
            arg_type->getName());

    /// Spark's posexplode yields (pos, col) or (pos, key, value), whereas CH's arrayJoin(map) yields one Tuple column,
    /// so the tuple is unpacked with sparkTupleElement.
    const auto * pos_node = element_of(array_join_node, 1); /// pos = arrayJoin(arg_not_null).1
    const auto * item_node = element_of(array_join_node, 2); /// col or (key, value) = arrayJoin(arg_not_null).2

    /// The original argument type of posexplode is recovered from a hint suffix in the argument's result name.
    const auto & raw_arg_name = args[0]->result_name;
    if (endsWith(raw_arg_name, "type_hint:map"))
    {
        const auto * item_key_node = element_of(item_node, 1); /// key = arrayJoin(arg_not_null).2.1
        const auto * item_value_node = element_of(item_node, 2); /// value = arrayJoin(arg_not_null).2.2
        return publish({pos_node, item_key_node, item_value_node});
    }

    if (endsWith(raw_arg_name, "type_hint:array"))
        return publish({pos_node, item_node});

    throw Exception(ErrorCodes::BAD_ARGUMENTS, "The raw input of arrayJoin converted from posexplode should be Array or Map type");
}

/// Translate a substrait scalar function expression into ActionsDAG nodes.
///
/// @param rel          Expression whose root must be a scalar function.
/// @param result_name  Receives the result name of the node produced for the function.
/// @param actions_dag  DAG the nodes are added to.
/// @param keep_result  When true, the result node is registered in the DAG outputs
///                     (the "alias" branch registers its argument in the outputs unconditionally).
/// @return The node holding the function result.
/// @throws Exception BAD_ARGUMENTS when the root is not a scalar function or when check_overflow/make_decimal
///         receive fewer than two arguments.
const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG(
    const substrait::Expression & rel, std::string & result_name, DB::ActionsDAGPtr actions_dag, bool keep_result)
{
    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString());

    const auto & scalar_function = rel.scalar_function();
    /// The signature is looked up by the function reference (anchor) carried by the scalar function.
    auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));

    /// If the substrait function name is registered in FunctionParserFactory, use it to parse the function, and return result directly
    /// The substrait function name is the part of the signature before the first ':'.
    auto pos = function_signature.find(':');
    auto func_name = function_signature.substr(0, pos);

    auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this);
    if (func_parser)
    {
        LOG_DEBUG(
            &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName());
        const auto * result_node = func_parser->parse(scalar_function, actions_dag);
        if (keep_result)
            actions_dag->addOrReplaceInOutputs(*result_node);

        result_name = result_node->result_name;
        return result_node;
    }

    /// Generic path: map to a CH function name and parse the arguments
    /// (parseFunctionArguments takes the name by non-const reference and may rewrite it).
    auto ch_func_name = getFunctionName(function_signature, scalar_function);
    ActionsDAG::NodeRawConstPtrs args;
    parseFunctionArguments(actions_dag, args, ch_func_name, scalar_function);

    /// If the first argument of function formatDateTimeInJodaSyntax is integer, replace formatDateTimeInJodaSyntax with fromUnixTimestampInJodaSyntax
    /// to avoid exception
    if (ch_func_name == "formatDateTimeInJodaSyntax")
    {
        if (args.size() > 1 && isInteger(DB::removeNullable(args[0]->result_type)))
            ch_func_name = "fromUnixTimestampInJodaSyntax";
    }

    const ActionsDAG::Node * result_node;
    if (ch_func_name == "alias")
    {
        /// The alias keeps the name of its argument; the argument is put into the outputs regardless of keep_result
        /// so that it can be found there when the alias node is created.
        result_name = args[0]->result_name;
        actions_dag->addOrReplaceInOutputs(*args[0]);
        result_node = &actions_dag->addAlias(actions_dag->findInOutputs(result_name), result_name);
    }
    else
    {
        if (ch_func_name == "splitByRegexp")
        {
            if (args.size() >= 2)
            {
                /// In Spark: split(str, regex [, limit] )
                /// In CH: splitByRegexp(regexp, str [, limit])
                std::swap(args[0], args[1]);
            }
        }

        /// check_overflow / make_decimal: keep only the first parsed argument and append the precision and scale
        /// of the substrait decimal output type as two constant UInt32 columns.
        if (function_signature.find("check_overflow:", 0) != function_signature.npos)
        {
            if (scalar_function.arguments().size() < 2)
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "check_overflow function requires at least two args.");

            ActionsDAG::NodeRawConstPtrs new_args;
            new_args.reserve(3);
            new_args.emplace_back(args[0]);

            UInt32 precision = rel.scalar_function().output_type().decimal().precision();
            UInt32 scale = rel.scalar_function().output_type().decimal().scale();
            auto uint32_type = std::make_shared<DataTypeUInt32>();
            new_args.emplace_back(&actions_dag->addColumn(
                ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision)))));
            new_args.emplace_back(&actions_dag->addColumn(
                ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale)))));
            args = std::move(new_args);
        }
        else if (startsWith(function_signature, "make_decimal:"))
        {
            if (scalar_function.arguments().size() < 2)
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "make_decimal function requires at least 2 args.");

            ActionsDAG::NodeRawConstPtrs new_args;
            new_args.reserve(3);
            new_args.emplace_back(args[0]);

            UInt32 precision = rel.scalar_function().output_type().decimal().precision();
            UInt32 scale = rel.scalar_function().output_type().decimal().scale();
            auto uint32_type = std::make_shared<DataTypeUInt32>();
            new_args.emplace_back(&actions_dag->addColumn(
                ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision)))));
            new_args.emplace_back(&actions_dag->addColumn(
                ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale)))));
            args = std::move(new_args);
        }

        /// For decimal plus/minus/multiply/divide the first argument may be cast to a wider decimal type; when that
        /// happened, the result-type adjustment below is skipped.
        bool converted_decimal_args = convertBinaryArithmeticFunDecimalArgs(actions_dag, args, scalar_function);
        auto function_builder = FunctionFactory::instance().get(ch_func_name, context);
        /// The node is named "<function>(<comma separated argument names>)".
        std::string args_name = join(args, ',');
        result_name = ch_func_name + "(" + args_name + ")";
        const auto * function_node = &actions_dag->addFunction(function_builder, args, result_name);
        result_node = function_node;
        /// Cast the result when the CH result type does not match the substrait output type.
        if (!TypeParser::isTypeMatched(rel.scalar_function().output_type(), function_node->result_type) && !converted_decimal_args)
        {
            auto result_type = TypeParser::parseType(rel.scalar_function().output_type());
            if (isDecimalOrNullableDecimal(result_type))
            {
                /// Decimal targets are cast with accurateOrNull.
                result_node = ActionsDAGUtil::convertNodeType(
                    actions_dag,
                    function_node,
                    // as stated in isTypeMatched, currently we don't change nullability of the result type
                    function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName()
                                                             : local_engine::removeNullable(result_type)->getName(),
                    function_node->result_name,
                    DB::CastType::accurateOrNull);
            }
            else
            {
                /// Other targets use the default cast type of convertNodeType.
                result_node = ActionsDAGUtil::convertNodeType(
                    actions_dag,
                    function_node,
                    // as stated in isTypeMatched, currently we don't change nullability of the result type
                    function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName()
                                                             : local_engine::removeNullable(result_type)->getName(),
                    function_node->result_name);
            }
        }

        /// NOTE(review): JSON_VALUE needs the resolver attached to the built function; the reason is not visible
        /// in this part of the file - confirm against the JSON_VALUE implementation.
        if (ch_func_name == "JSON_VALUE")
            result_node->function->setResolver(function_builder);

        if (keep_result)
            actions_dag->addOrReplaceInOutputs(*result_node);
    }
    return result_node;
}

/// For decimal plus/minus/multiply/divide, cast the first argument to a wider decimal type before the function is applied.
///
/// The substrait function name is the part of the signature before the first ':'. When it is one of
/// "plus", "minus", "multiply" or "divide" and both of the first two arguments are (Nullable) decimals
/// with (precision, scale) = (p1, s1) and (p2, s2), the first argument is cast to Decimal(precision, scale) with
///   - plus/minus: scale = s1,                  precision = s1 + max(p1 - s1, p2 - s2) + 1
///   - divide:     scale = max(6, s1 + p2 + 1), precision = p1 - s1 + s2 + scale
///   - multiply:   scale = s1,                  precision = p1 + p2 + 1
/// Precision is capped by DataTypeDecimal256::maxPrecision() and scale by DataTypeDecimal128::maxPrecision().
/// Nullability of the cast target follows the substrait decimal output type.
///
/// @param actions_dag   DAG the cast node is added to (the cast node is also registered in the DAG outputs).
/// @param args          Parsed arguments; on success replaced by {CAST(args[0], <type>), args[1]}.
/// @param arithmeticFun The substrait scalar function being translated.
/// @return true when the arguments were rewritten, false when nothing was changed.
bool SerializedPlanParser::convertBinaryArithmeticFunDecimalArgs(
    ActionsDAGPtr actions_dag, ActionsDAG::NodeRawConstPtrs & args, const substrait::Expression_ScalarFunction & arithmeticFun)
{
    auto function_signature = function_mapping.at(std::to_string(arithmeticFun.function_reference()));
    auto pos = function_signature.find(':');
    auto func_name = function_signature.substr(0, pos);

    const bool is_plus_or_minus = func_name == "plus" || func_name == "minus";
    const bool is_divide = func_name == "divide";
    const bool is_multiply = func_name == "multiply";
    if (!is_plus_or_minus && !is_divide && !is_multiply)
        return false;

    /// Both operands are inspected below; never index past the parsed argument list.
    if (args.size() < 2)
        return false;

    if (!isDecimalOrNullableDecimal(args[0]->result_type) || !isDecimalOrNullableDecimal(args[1]->result_type))
        return false;

    const auto lhs_type = DB::removeNullable(args[0]->result_type);
    const auto rhs_type = DB::removeNullable(args[1]->result_type);
    const UInt32 p1 = getDecimalPrecision(*lhs_type);
    const UInt32 s1 = getDecimalScale(*lhs_type);
    const UInt32 p2 = getDecimalPrecision(*rhs_type);
    const UInt32 s2 = getDecimalScale(*rhs_type);

    UInt32 precision;
    UInt32 scale;
    if (is_plus_or_minus)
    {
        scale = s1;
        precision = scale + std::max(p1 - s1, p2 - s2) + 1;
    }
    else if (is_divide)
    {
        scale = std::max(static_cast<UInt32>(6), s1 + p2 + 1);
        precision = p1 - s1 + s2 + scale;
    }
    else // multiply
    {
        scale = s1;
        precision = p1 + p2 + 1;
    }

    const UInt32 max_precision = DataTypeDecimal256::maxPrecision();
    const UInt32 max_scale = DataTypeDecimal128::maxPrecision();
    precision = std::min(precision, max_precision);
    scale = std::min(scale, max_scale);

    /// CAST(args[0], '<Decimal type name>'): the target type is passed to CAST as a constant string column.
    DataTypePtr ch_type = createDecimal<DataTypeDecimal>(precision, scale);
    ch_type = wrapNullableType(arithmeticFun.output_type().decimal().nullability(), ch_type);
    String type_name = ch_type->getName();
    DataTypePtr str_type = std::make_shared<DataTypeString>();
    const ActionsDAG::Node * type_node = &actions_dag->addColumn(
        ColumnWithTypeAndName(str_type->createColumnConst(1, type_name), str_type, getUniqueName(type_name)));

    ActionsDAG::NodeRawConstPtrs cast_args;
    cast_args.reserve(2);
    cast_args.emplace_back(args[0]);
    cast_args.emplace_back(type_node);
    const ActionsDAG::Node * cast_node = toFunctionNode(actions_dag, "CAST", cast_args);
    actions_dag->addOrReplaceInOutputs(*cast_node);

    /// Only the first two arguments are kept: the cast first operand and the untouched second operand.
    ActionsDAG::NodeRawConstPtrs new_args;
    new_args.reserve(2);
    new_args.emplace_back(cast_node);
    new_args.emplace_back(args[1]);
    args = std::move(new_args);
    return true;
}

/// Parse the arguments of a substrait scalar function into DAG nodes, applying per-function adjustments
/// (argument reordering, dropped or appended arguments, casts) where the CH function differs from the substrait one.
///
/// @param actions_dag     DAG the argument nodes are added to.
/// @param parsed_args     Receives the argument nodes in the order expected by the CH function.
/// @param function_name   CH function name; may be rewritten here (e.g. "space" becomes "repeat").
/// @param scalar_function The substrait scalar function whose arguments are parsed.
/// @throws Exception BAD_ARGUMENTS when the struct field index of sparkTupleElement/tupleElement is not an i32 literal.
void SerializedPlanParser::parseFunctionArguments(
    DB::ActionsDAGPtr & actions_dag,
    ActionsDAG::NodeRawConstPtrs & parsed_args,
    std::string & function_name,
    const substrait::Expression_ScalarFunction & scalar_function)
{
    auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    const auto & args = scalar_function.arguments();
    parsed_args.reserve(args.size());

    // Some functions need to be handled specially.
    if (function_name == "JSONExtract")
    {
        // Only the first substrait argument is parsed; the second CH argument is the name of the
        // substrait output type, passed as a constant string.
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        auto data_type = TypeParser::parseType(scalar_function.output_type());
        parsed_args.emplace_back(addColumn(actions_dag, std::make_shared<DB::DataTypeString>(), data_type->getName()));
    }
    else if (function_name == "sparkTupleElement" || function_name == "tupleElement")
    {
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);

        if (!args[1].value().has_literal())
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be a literal");

        auto [data_type, field] = parseLiteral(args[1].value().literal());
        if (data_type->getTypeId() != DB::TypeIndex::Int32)
            throw Exception(ErrorCodes::BAD_ARGUMENTS, "get_struct_field's second argument must be i32");

        // Tuple indices start from 1 in CH, but from 0 in Spark.
        Int32 field_index = static_cast<Int32>(field.get<Int32>() + 1);
        const auto * index_node = addColumn(actions_dag, std::make_shared<DB::DataTypeUInt32>(), field_index);
        parsed_args.emplace_back(index_node);
    }
    else if (function_name == "tuple")
    {
        // Arguments in the format, (<field name>, <value expression>[, <field name>, <value expression> ...])
        // We don't need to care the field names here.
        // Only the odd positions (the value expressions) are parsed.
        for (int index = 1; index < args.size(); index += 2)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[index]);
    }
    else if (function_name == "repeat")
    {
        // repeat. the field index must be unsigned integer in CH, cast the signed integer in substrait
        // which must be a positive value into unsigned integer here.
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, function_name, args[1]);
        DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
        repeat_times_node = ActionsDAGUtil::convertNodeType(actions_dag, repeat_times_node, target_type.getName());
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "isNaN")
    {
        // the result of isNaN(NULL) is NULL in CH, but false in Spark
        const DB::ActionsDAG::Node * arg_node = nullptr;
        if (args[0].value().has_cast())
        {
            // Look through the cast: when its input is a (Nullable) String, convert it with
            // toFloat64OrZero instead of applying the original cast.
            arg_node = parseExpression(actions_dag, args[0].value().cast().input());
            const auto * res_type = arg_node->result_type.get();
            if (res_type->isNullable())
            {
                res_type = typeid_cast<const DB::DataTypeNullable *>(res_type)->getNestedType().get();
            }
            if (isString(*res_type))
            {
                DB::ActionsDAG::NodeRawConstPtrs cast_func_args = {arg_node};
                arg_node = toFunctionNode(actions_dag, "toFloat64OrZero", cast_func_args);
            }
            else
            {
                // Not a string input: parse the whole argument (including the cast) as usual.
                arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
            }
        }
        else
        {
            arg_node = parseFunctionArgument(actions_dag, function_name, args[0]);
        }

        // IfNull(arg, 0): a NULL input is replaced by 0 before isNaN is applied.
        DB::ActionsDAG::NodeRawConstPtrs ifnull_func_args = {arg_node, addColumn(actions_dag, std::make_shared<DataTypeInt32>(), 0)};
        parsed_args.emplace_back(toFunctionNode(actions_dag, "IfNull", ifnull_func_args));
    }
    else if (function_name == "positionUTF8Spark")
    {
        if (args.size() >= 2)
        {
            // In Spark: position(substr, str, Int32)
            // In CH:    position(str, substr, UInt32)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
        }
        if (args.size() >= 3)
        {
            // add cast: cast(start_pos as UInt32)
            const auto * start_pos_node = parseFunctionArgument(actions_dag, function_name, args[2]);
            DB::DataTypeNullable target_type(std::make_shared<DB::DataTypeUInt32>());
            start_pos_node = ActionsDAGUtil::convertNodeType(actions_dag, start_pos_node, target_type.getName());
            parsed_args.emplace_back(start_pos_node);
        }
    }
    else if (function_name == "space")
    {
        // convert space function to repeat: space(n) -> repeat(" ", n)
        // The caller sees the rewritten name through the function_name reference.
        const DB::ActionsDAG::Node * repeat_times_node = parseFunctionArgument(actions_dag, "repeat", args[0]);
        const DB::ActionsDAG::Node * space_str_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), " ");
        function_name = "repeat";
        parsed_args.emplace_back(space_str_node);
        parsed_args.emplace_back(repeat_times_node);
    }
    else if (function_name == "trimBothSpark" || function_name == "trimLeftSpark" || function_name == "trimRightSpark")
    {
        /// In substrait, the first arg is srcStr, the second arg is trimStr
        /// But in CH, the first arg is trimStr, the second arg is srcStr
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[1]);
        parseFunctionArgument(actions_dag, parsed_args, function_name, args[0]);
    }
    else if (startsWith(function_signature, "extract:"))
    {
        /// Skip the first arg of extract in substrait
        for (int i = 1; i < args.size(); i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);

        /// Append extra mode argument for extract(WEEK_DAY from date) or extract(DAY_OF_WEEK from date) in substrait
        if (function_name == "toDayOfWeek" || function_name == "DAYOFWEEK")
        {
            /// Mode 1 is used for toDayOfWeek and mode 3 for DAYOFWEEK.
            UInt8 mode = function_name == "toDayOfWeek" ? 1 : 3;
            auto mode_type = std::make_shared<DataTypeUInt8>();
            ColumnWithTypeAndName mode_col(mode_type->createColumnConst(1, mode), mode_type, getUniqueName(std::to_string(mode)));
            const auto & mode_node = actions_dag->addColumn(std::move(mode_col));
            parsed_args.emplace_back(&mode_node);
        }
    }
    else if (startsWith(function_signature, "sha2:"))
    {
        /// The last substrait argument is not passed to the CH function.
        /// NOTE(review): presumably it is the bit length, already encoded in the chosen CH function name - confirm.
        for (int i = 0; i < args.size() - 1; i++)
            parseFunctionArgument(actions_dag, parsed_args, function_name, args[i]);
    }
    else
    {
        // Default handle: parse every argument in order.
        for (const auto & arg : args)
            parseFunctionArgument(actions_dag, parsed_args, function_name, arg);
    }
}

/// Parse one function argument and append the resulting node to parsed_args.
/// Convenience wrapper around the overload that returns the parsed node.
void SerializedPlanParser::parseFunctionArgument(
    DB::ActionsDAGPtr & actions_dag,
    ActionsDAG::NodeRawConstPtrs & parsed_args,
    const std::string & function_name,
    const substrait::FunctionArgument & arg)
{
    const DB::ActionsDAG::Node * parsed_node = parseFunctionArgument(actions_dag, function_name, arg);
    parsed_args.emplace_back(parsed_node);
}

/// Parse one function argument into a DAG node.
///
/// A nested scalar function is parsed with parseFunctionWithDAG; its result is kept in the DAG outputs only when
/// the enclosing function is listed in FUNCTION_NEED_KEEP_ARGUMENTS. Any other expression goes through parseExpression.
///
/// @param actions_dag   DAG the argument nodes are added to.
/// @param function_name Name of the enclosing function (only used for the FUNCTION_NEED_KEEP_ARGUMENTS lookup).
/// @param arg           The substrait argument to parse.
/// @return The node holding the argument value.
const DB::ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument(
    DB::ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg)
{
    if (!arg.value().has_scalar_function())
        return parseExpression(actions_dag, arg.value());

    std::string arg_name;
    const bool keep_arg = FUNCTION_NEED_KEEP_ARGUMENTS.contains(function_name);
    /// Use the node returned by the parser rather than the most recently inserted DAG node: the two differ whenever
    /// the result node is not the last node that was added (e.g. a function parser that returns an existing node).
    return parseFunctionWithDAG(arg.value(), arg_name, actions_dag, keep_arg);
}

// Convert a signed integer struct-field index into an unsigned one.
// An index that already has an unsigned type (UInt8/16/32/64) is returned unchanged. A signed index (Int8/16/32/64)
// is converted to the unsigned type of the same width and incremented by one, because tupleElement field indexes
// start from 1 while they start from 0 in the substrait plan.
// Throws ILLEGAL_TYPE_OF_ARGUMENT for any other type.
std::pair<DB::DataTypePtr, DB::Field> SerializedPlanParser::convertStructFieldType(const DB::DataTypePtr & type, const DB::Field & field)
{
    const auto type_id = type->getTypeId();

    const bool already_unsigned = type_id == DB::TypeIndex::UInt8 || type_id == DB::TypeIndex::UInt16
        || type_id == DB::TypeIndex::UInt32 || type_id == DB::TypeIndex::UInt64;
    if (already_unsigned)
        return {type, field};

    if (type_id == DB::TypeIndex::Int8)
        return {std::make_shared<DB::DataTypeUInt8>(), static_cast<UInt8>(field.get<Int8>()) + 1};

    if (type_id == DB::TypeIndex::Int16)
        return {std::make_shared<DB::DataTypeUInt16>(), static_cast<UInt16>(field.get<Int16>()) + 1};

    if (type_id == DB::TypeIndex::Int32)
        return {std::make_shared<DB::DataTypeUInt32>(), static_cast<UInt32>(field.get<Int32>()) + 1};

    if (type_id == DB::TypeIndex::Int64)
        return {std::make_shared<DB::DataTypeUInt64>(), static_cast<UInt64>(field.get<Int64>()) + 1};

    throw DB::Exception(DB::ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Not valid interger type: {}", type->getName());
}

/// Parse a scalar function expression into an ActionsDAG.
///
/// @param header       Input block; used to create the DAG when none is supplied.
/// @param rel          Expression whose root must be a scalar function.
/// @param result_name  Receives the result name of the function node.
/// @param actions_dag  Existing DAG to extend, or null to start a new one from header.
/// @param keep_result  Forwarded to parseFunctionWithDAG.
/// @return The DAG that was extended or created.
ActionsDAGPtr SerializedPlanParser::parseFunction(
    const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
    if (actions_dag == nullptr)
    {
        auto input_columns = blockToNameAndTypeList(header);
        actions_dag = std::make_shared<ActionsDAG>(input_columns);
    }

    parseFunctionWithDAG(rel, result_name, actions_dag, keep_result);
    return actions_dag;
}

/// Parse either a scalar function or any other expression into an ActionsDAG.
///
/// @param header       Input block; used to create the DAG when none is supplied.
/// @param rel          Expression to parse.
/// @param result_name  Receives the result name of the produced node.
/// @param actions_dag  Existing DAG to extend, or null to start a new one from header.
/// @param keep_result  Forwarded to parseFunctionWithDAG; not used for non-function expressions.
/// @return The DAG that was extended or created.
ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression(
    const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result)
{
    if (actions_dag == nullptr)
    {
        auto input_columns = blockToNameAndTypeList(header);
        actions_dag = std::make_shared<ActionsDAG>(input_columns);
    }

    if (!rel.has_scalar_function())
    {
        /// Plain expression: only its result name is reported.
        const auto * expression_node = parseExpression(actions_dag, rel);
        result_name = expression_node->result_name;
        return actions_dag;
    }

    parseFunctionWithDAG(rel, result_name, actions_dag, keep_result);
    return actions_dag;
}

/// Parse an arrayJoin (Spark explode/posexplode) expression into an ActionsDAG.
///
/// @param input        Input block; used to create the DAG when none is supplied.
/// @param rel          Expression whose root must resolve to arrayJoin.
/// @param result_names Receives the result names of the produced nodes.
/// @param actions_dag  Existing DAG to extend, or null to start a new one from input.
/// @param keep_result  Forwarded to parseArrayJoinWithDAG.
/// @param position     Forwarded to parseArrayJoinWithDAG (true for posexplode).
/// @return The DAG that was extended or created.
ActionsDAGPtr SerializedPlanParser::parseArrayJoin(
    const Block & input,
    const substrait::Expression & rel,
    std::vector<String> & result_names,
    ActionsDAGPtr actions_dag,
    bool keep_result,
    bool position)
{
    if (actions_dag == nullptr)
    {
        auto input_columns = blockToNameAndTypeList(input);
        actions_dag = std::make_shared<ActionsDAG>(input_columns);
    }

    parseArrayJoinWithDAG(rel, result_names, actions_dag, keep_result, position);
    return actions_dag;
}

/// Parse a json_tuple(json, field1, field2, ...) expression into an ActionsDAG.
///
/// The call is translated to JSONExtract(json, 'Tuple(field1 Nullable(String),field2 Nullable(String),...)'), and one
/// sparkTupleElement node per requested field is added on top of it.
///
/// @param input        Input block; used to create the DAG when none is supplied.
/// @param rel          Expression whose root must be the json_tuple scalar function.
/// @param result_names Receives the result names of the per-field nodes (only when keep_result is true).
/// @param actions_dag  Existing DAG to extend, or null to start a new one from input.
/// @param keep_result  When true, the per-field nodes are registered in the DAG outputs.
/// @return The DAG that was extended or created.
/// @throws Exception BAD_ARGUMENTS when the root is not a scalar function, when fewer than two arguments are given,
///         or when a field argument is not a literal.
ActionsDAGPtr SerializedPlanParser::parseJsonTuple(
    const Block & input,
    const substrait::Expression & rel,
    std::vector<String> & result_names,
    ActionsDAGPtr actions_dag,
    bool keep_result,
    bool)
{
    /// Same root validation as the other function parsers in this class.
    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString());

    if (!actions_dag)
    {
        actions_dag = std::make_shared<ActionsDAG>(blockToNameAndTypeList(input));
    }

    const auto & scalar_function = rel.scalar_function();
    auto function_signature = function_mapping.at(std::to_string(scalar_function.function_reference()));
    auto function_name = getFunctionName(function_signature, scalar_function);

    /// Bind by reference: the protobuf argument list and its messages are only read here.
    const auto & args = scalar_function.arguments();
    const int args_count = args.size();
    if (args_count < 2)
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "The function json_tuple should has at least 2 arguments.");
    }

    const DB::ActionsDAG::Node * json_expr_node = parseExpression(actions_dag, args[0].value());

    /// Build the JSONExtract return type: Tuple(<field> Nullable(String),...)
    std::string extract_expr = "Tuple(";
    for (int i = 1; i < args_count; ++i)
    {
        const auto & arg_value = args[i].value();
        if (!arg_value.has_literal())
        {
            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "The arguments of function {} must be string literal", function_name);
        }

        extract_expr.append(arg_value.literal().string()).append(" Nullable(String)");
        if (i != args_count - 1)
        {
            extract_expr.append(",");
        }
    }
    extract_expr.append(")");

    const DB::ActionsDAG::Node * extract_expr_node = addColumn(actions_dag, std::make_shared<DataTypeString>(), extract_expr);
    auto json_extract_builder = FunctionFactory::instance().get("JSONExtract", context);
    auto json_extract_result_name = "JSONExtract(" + json_expr_node->result_name + "," + extract_expr_node->result_name + ")";
    const ActionsDAG::Node * json_extract_node
        = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name);

    auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context);
    auto tuple_index_type = std::make_shared<DataTypeUInt32>();
    /// Adds sparkTupleElement(tuple_node, i) together with the constant index column it needs.
    auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node *
    {
        ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i)));
        const auto * index_node = &actions_dag->addColumn(std::move(index_col));
        auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")";
        return &actions_dag->addFunction(tuple_element_builder, {tuple_node, index_node}, result_name);
    };

    /// Field i of json_tuple is element i (1-based) of the extracted tuple.
    for (int i = 1; i < args_count; ++i)
    {
        const ActionsDAG::Node * tuple_node = add_tuple_element(json_extract_node, i);
        if (keep_result)
        {
            actions_dag->addOrReplaceInOutputs(*tuple_node);
            result_names.push_back(tuple_node->result_name);
        }
    }
    return actions_dag;
}

/// Adds a call of the function named `function` over `args` to `actions_dag`.
/// The function is resolved through FunctionFactory with this parser's context, and the
/// node is named `function` + "(" + join(args, ',') + ")".
/// Returns a pointer to the node created inside `actions_dag`.
const ActionsDAG::Node *
SerializedPlanParser::toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args)
{
    auto resolver = DB::FunctionFactory::instance().get(function, context);
    std::string joined_args = join(args, ',');
    auto node_name = function + "(" + joined_args + ")";
    return &actions_dag->addFunction(resolver, args, node_name);
}

/// Converts a substrait literal into a pair of (ClickHouse data type, Field value).
///
/// Scalars map one-to-one onto a ClickHouse type. Lists, maps and structs are converted
/// recursively: list elements and map values are unified with getLeastSupertype, while all
/// map keys must have exactly the same type.
///
/// Throws:
///  - UNKNOWN_TYPE for an unsupported literal kind or a decimal precision above Decimal128's maximum;
///  - BAD_ARGUMENTS when a decimal payload is shorter than the integer it must be read as;
///  - LOGICAL_ERROR when the keys of a map literal do not share one type.
std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait::Expression_Literal & literal)
{
    DataTypePtr type;
    Field field;

    switch (literal.literal_type_case())
    {
        case substrait::Expression_Literal::kFp64: {
            type = std::make_shared<DataTypeFloat64>();
            field = literal.fp64();
            break;
        }
        case substrait::Expression_Literal::kFp32: {
            type = std::make_shared<DataTypeFloat32>();
            field = literal.fp32();
            break;
        }
        case substrait::Expression_Literal::kString: {
            type = std::make_shared<DataTypeString>();
            field = literal.string();
            break;
        }
        case substrait::Expression_Literal::kBinary: {
            /// Binary literals are carried as plain strings.
            type = std::make_shared<DataTypeString>();
            field = literal.binary();
            break;
        }
        case substrait::Expression_Literal::kI64: {
            type = std::make_shared<DataTypeInt64>();
            field = literal.i64();
            break;
        }
        case substrait::Expression_Literal::kI32: {
            type = std::make_shared<DataTypeInt32>();
            field = literal.i32();
            break;
        }
        case substrait::Expression_Literal::kBoolean: {
            /// Booleans are represented as UInt8 holding 0 or 1.
            type = std::make_shared<DataTypeUInt8>();
            field = literal.boolean() ? UInt8(1) : UInt8(0);
            break;
        }
        case substrait::Expression_Literal::kI16: {
            type = std::make_shared<DataTypeInt16>();
            field = literal.i16();
            break;
        }
        case substrait::Expression_Literal::kI8: {
            type = std::make_shared<DataTypeInt8>();
            field = literal.i8();
            break;
        }
        case substrait::Expression_Literal::kDate: {
            type = std::make_shared<DataTypeDate32>();
            field = literal.date();
            break;
        }
        case substrait::Expression_Literal::kTimestamp: {
            /// Timestamps use DateTime64 with scale 6.
            type = std::make_shared<DataTypeDateTime64>(6);
            field = DecimalField<DateTime64>(literal.timestamp(), 6);
            break;
        }
        case substrait::Expression_Literal::kDecimal: {
            UInt32 precision = literal.decimal().precision();
            UInt32 scale = literal.decimal().scale();
            const auto & bytes = literal.decimal().value();

            /// Copies the leading sizeof(dst) bytes of the payload into dst.
            /// A byte-wise copy avoids reading through a misaligned pointer, and the
            /// size check avoids reading past the end of a short payload.
            auto load_value = [&bytes](auto & dst)
            {
                if (bytes.size() < sizeof(dst))
                    throw Exception(
                        ErrorCodes::BAD_ARGUMENTS,
                        "Decimal literal value has {} bytes, but at least {} bytes are required",
                        bytes.size(),
                        sizeof(dst));
                std::copy_n(bytes.data(), sizeof(dst), reinterpret_cast<char *>(&dst));
            };

            /// The narrowest decimal type able to hold `precision` digits is chosen.
            if (precision <= DataTypeDecimal32::maxPrecision())
            {
                type = std::make_shared<DataTypeDecimal32>(precision, scale);
                Int32 value = 0;
                load_value(value);
                field = DecimalField<Decimal32>(value, scale);
            }
            else if (precision <= DataTypeDecimal64::maxPrecision())
            {
                type = std::make_shared<DataTypeDecimal64>(precision, scale);
                Int64 value = 0;
                load_value(value);
                field = DecimalField<Decimal64>(value, scale);
            }
            else if (precision <= DataTypeDecimal128::maxPrecision())
            {
                type = std::make_shared<DataTypeDecimal128>(precision, scale);
                Decimal128 value{};
                load_value(value);
                field = DecimalField<Decimal128>(value, scale);
            }
            else
                throw Exception(ErrorCodes::UNKNOWN_TYPE, "Spark doesn't support decimal type with precision {}", precision);
            break;
        }
        case substrait::Expression_Literal::kList: {
            const auto & values = literal.list().values();
            if (values.empty())
            {
                type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeNothing>());
                field = Array();
                break;
            }

            /// Every element is parsed exactly once; the element type of the array is the
            /// least supertype over all elements, seeded by the first one.
            DataTypePtr common_type;
            size_t list_len = values.size();
            Array array(list_len);
            for (int i = 0; i < static_cast<int>(list_len); ++i)
            {
                auto type_and_field = parseLiteral(values[i]);
                if (common_type)
                    common_type = getLeastSupertype(DataTypes{common_type, type_and_field.first});
                else
                    common_type = std::move(type_and_field.first);
                array[i] = std::move(type_and_field.second);
            }

            type = std::make_shared<DataTypeArray>(common_type);
            field = std::move(array);
            break;
        }
        case substrait::Expression_Literal::kMap: {
            const auto & key_values = literal.map().key_values();
            if (key_values.empty())
            {
                type = std::make_shared<DataTypeMap>(std::make_shared<DataTypeNothing>(), std::make_shared<DataTypeNothing>());
                field = Map();
                break;
            }

            /// Both common types are seeded by the first entry inside the loop, so each
            /// key and value is parsed exactly once.
            DataTypePtr common_key_type;
            DataTypePtr common_value_type;

            Map map;
            map.reserve(key_values.size());
            for (const auto & key_value : key_values)
            {
                Tuple tuple(2);

                DataTypePtr key_type;
                std::tie(key_type, tuple[0]) = parseLiteral(key_value.key());
                /// Each key must have the same type
                if (!common_key_type)
                    common_key_type = key_type;
                else if (!common_key_type->equals(*key_type))
                    throw Exception(
                        ErrorCodes::LOGICAL_ERROR,
                        "Literal map key type mismatch:{} and {}",
                        common_key_type->getName(),
                        key_type->getName());

                DataTypePtr value_type;
                std::tie(value_type, tuple[1]) = parseLiteral(key_value.value());
                /// The map value type is the least supertype of all values
                if (common_value_type)
                    common_value_type = getLeastSupertype(DataTypes{common_value_type, value_type});
                else
                    common_value_type = value_type;

                map.emplace_back(std::move(tuple));
            }

            type = std::make_shared<DataTypeMap>(common_key_type, common_value_type);
            field = std::move(map);
            break;
        }
        case substrait::Expression_Literal::kStruct: {
            const auto & fields = literal.struct_().fields();

            /// A struct becomes a Tuple whose element types are the types of its fields, in order.
            DataTypes types;
            types.reserve(fields.size());
            Tuple tuple;
            tuple.reserve(fields.size());
            for (const auto & f : fields)
            {
                DataTypePtr field_type;
                Field field_value;
                std::tie(field_type, field_value) = parseLiteral(f);

                types.emplace_back(std::move(field_type));
                tuple.emplace_back(std::move(field_value));
            }

            type = std::make_shared<DataTypeTuple>(types);
            field = std::move(tuple);
            break;
        }
        case substrait::Expression_Literal::kNull: {
            /// A typed null: the type comes from the literal, the value is an empty Field.
            type = TypeParser::parseType(literal.null());
            field = Field{};
            break;
        }
        default: {
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case()));
        }
    }
    return std::make_pair(std::move(type), std::move(field));
}

/// Translates a substrait expression into node(s) of `actions_dag` and returns the node
/// holding the expression result.
///
/// Supported expression kinds: literal, selection (direct struct field reference), cast,
/// if-then, scalar function and singular-or-list (IN). Cast, if-then and IN results are
/// also registered in the DAG outputs.
///
/// Throws BAD_ARGUMENTS for malformed selections / casts / IN options, LOGICAL_ERROR for
/// inconsistent IN option types and UNKNOWN_TYPE for any other expression kind.
const ActionsDAG::Node * SerializedPlanParser::parseExpression(ActionsDAGPtr actions_dag, const substrait::Expression & rel)
{
    switch (rel.rex_type_case())
    {
        case substrait::Expression::RexTypeCase::kLiteral: {
            DataTypePtr type;
            Field field;
            std::tie(type, field) = parseLiteral(rel.literal());
            return addColumn(actions_dag, type, field);
        }

        case substrait::Expression::RexTypeCase::kSelection: {
            if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");

            /// The field index comes from the serialized plan, so validate it before indexing the inputs.
            const auto field_index = rel.selection().direct_reference().struct_field().field();
            const auto & inputs = actions_dag->getInputs();
            if (field_index < 0 || static_cast<size_t>(field_index) >= inputs.size())
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "Selection references field {} but only {} inputs are available",
                    field_index,
                    inputs.size());

            const auto * field = inputs[static_cast<size_t>(field_index)];
            /// The node is looked up by name among the outputs; the result is null when no output has that name.
            return actions_dag->tryFindInOutputs(field->result_name);
        }

        case substrait::Expression::RexTypeCase::kCast: {
            if (!rel.cast().has_type() || !rel.cast().has_input())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node.");

            DB::ActionsDAG::NodeRawConstPtrs args;

            const auto & input = rel.cast().input();
            args.emplace_back(parseExpression(actions_dag, input));

            const auto & substrait_type = rel.cast().type();
            const ActionsDAG::Node * function_node = nullptr;
            if (DB::isString(DB::removeNullable(args.back()->result_type)) && substrait_type.has_date())
            {
                /// String -> date goes through a dedicated function.
                function_node = toFunctionNode(actions_dag, "spark_to_date", args);
            }
            else if (substrait_type.has_binary())
            {
                // Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
                function_node = toFunctionNode(actions_dag, "reinterpretAsStringSpark", args);
            }
            else
            {
                DataTypePtr ch_type = TypeParser::parseType(substrait_type);
                if (DB::isString(DB::removeNullable(ch_type)) && isDecimalOrNullableDecimal(args[0]->result_type))
                {
                    /// Decimal -> string: pass the decimal's scale as the second argument.
                    UInt8 scale = getDecimalScale(*DB::removeNullable(args[0]->result_type));
                    args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeUInt8>(), Field(scale)));
                    function_node = toFunctionNode(actions_dag, "toDecimalString", args);
                }
                else if (isFloat(DB::removeNullable(args[0]->result_type)) && isInt(DB::removeNullable(ch_type)))
                {
                    /// Float -> integer: the function name is derived from the target type name.
                    String function_name = "sparkCastFloatTo" + DB::removeNullable(ch_type)->getName();
                    function_node = toFunctionNode(actions_dag, function_name, args);
                }
                else
                {
                    /// Generic case: CAST(x, '<type name>').
                    args.emplace_back(addColumn(actions_dag, std::make_shared<DataTypeString>(), ch_type->getName()));
                    function_node = toFunctionNode(actions_dag, "CAST", args);
                }
            }

            actions_dag->addOrReplaceInOutputs(*function_node);
            return function_node;
        }

        case substrait::Expression::RexTypeCase::kIfThen: {
            const auto & if_then = rel.if_then();
            DB::FunctionOverloadResolverPtr function_ptr = nullptr;
            auto condition_nums = if_then.ifs_size();
            /// A single condition maps to `if`, anything else to `multiIf`.
            if (condition_nums == 1)
                function_ptr = DB::FunctionFactory::instance().get("if", context);
            else
                function_ptr = DB::FunctionFactory::instance().get("multiIf", context);
            DB::ActionsDAG::NodeRawConstPtrs args;

            /// Arguments are laid out as cond1, then1, cond2, then2, ..., else.
            for (int i = 0; i < condition_nums; ++i)
            {
                const auto & ifs = if_then.ifs(i);
                const auto * if_node = parseExpression(actions_dag, ifs.if_());
                args.emplace_back(if_node);

                const auto * then_node = parseExpression(actions_dag, ifs.then());
                args.emplace_back(then_node);
            }

            const auto * else_node = parseExpression(actions_dag, if_then.else_());
            args.emplace_back(else_node);
            std::string args_name = join(args, ',');
            std::string result_name;
            if (condition_nums == 1)
                result_name = "if(" + args_name + ")";
            else
                result_name = "multiIf(" + args_name + ")";
            const auto * function_node = &actions_dag->addFunction(function_ptr, args, result_name);
            actions_dag->addOrReplaceInOutputs(*function_node);
            return function_node;
        }

        case substrait::Expression::RexTypeCase::kScalarFunction: {
            std::string result;
            return parseFunctionWithDAG(rel, result, actions_dag, false);
        }

        case substrait::Expression::RexTypeCase::kSingularOrList: {
            const auto & options = rel.singular_or_list().options();
            /// options is empty always return false
            if (options.empty())
                return addColumn(actions_dag, std::make_shared<DataTypeUInt8>(), 0);
            /// options should be literals
            if (!options[0].has_literal())
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type");

            DB::ActionsDAG::NodeRawConstPtrs args;
            args.emplace_back(parseExpression(actions_dag, rel.singular_or_list().value()));

            /// First pass: every option must be a literal; remember whether any of them is null.
            bool nullable = false;
            int options_len = static_cast<int>(options.size());
            for (int i = 0; i < options_len; ++i)
            {
                if (!options[i].has_literal())
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!");
                if (!nullable)
                    nullable = options[i].literal().has_null();
            }

            DataTypePtr elem_type;
            std::tie(elem_type, std::ignore) = parseLiteral(options[0].literal());
            elem_type = wrapNullableType(nullable, elem_type);

            /// Second pass: collect the option values into one column; all options must share the element type.
            MutableColumnPtr elem_column = elem_type->createColumn();
            elem_column->reserve(options_len);
            for (int i = 0; i < options_len; ++i)
            {
                auto type_and_field = parseLiteral(options[i].literal());
                auto option_type = wrapNullableType(nullable, type_and_field.first);
                if (!elem_type->equals(*option_type))
                    throw Exception(
                        ErrorCodes::LOGICAL_ERROR,
                        "SingularOrList options type mismatch:{} and {}",
                        elem_type->getName(),
                        option_type->getName());

                elem_column->insert(type_and_field.second);
            }

            MutableColumns elem_columns;
            elem_columns.emplace_back(std::move(elem_column));

            auto name = getUniqueName("__set");
            Block elem_block;
            elem_block.insert(ColumnWithTypeAndName(nullptr, elem_type, name));
            elem_block.setColumns(std::move(elem_columns));

            /// Build the Set from the option block and pass it to `in` as a constant set column.
            SizeLimits limit;
            auto elem_set = std::make_shared<Set>(limit, true, false);
            elem_set->setHeader(elem_block.getColumnsWithTypeAndName());
            elem_set->insertFromBlock(elem_block.getColumnsWithTypeAndName());
            elem_set->finishInsert();

            auto future_set = std::make_shared<FutureSetFromStorage>(std::move(elem_set));
            auto arg = ColumnSet::create(1, std::move(future_set));
            args.emplace_back(&actions_dag->addColumn(ColumnWithTypeAndName(std::move(arg), std::make_shared<DataTypeSet>(), name)));

            const auto * function_node = toFunctionNode(actions_dag, "in", args);
            actions_dag->addOrReplaceInOutputs(*function_node);
            if (nullable)
            {
                /// if sets has `null` and value not in sets
                /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920)
                /// In CH: return `false`
                /// So we used if(a, b, c) cast `false` to `null` if sets has `null`
                auto type = wrapNullableType(true, function_node->result_type);
                DB::ActionsDAG::NodeRawConstPtrs cast_args(
                    {function_node, addColumn(actions_dag, type, true), addColumn(actions_dag, type, Field())});
                function_node = toFunctionNode(actions_dag, "if", cast_args);
                actions_dag->addOrReplaceInOutputs(*function_node);
            }
            return function_node;
        }

        default:
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE,
                "Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}

/// Deserializes a substrait::ReadRel::ExtensionTable from the binary protobuf in `split_info`.
/// The recursion limit of the input stream is raised to 100000 before parsing.
/// Throws CANNOT_PARSE_PROTOBUF_SCHEMA when the bytes cannot be parsed.
substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(const std::string & split_info)
{
    const auto * raw_bytes = reinterpret_cast<const uint8_t *>(split_info.data());
    google::protobuf::io::CodedInputStream input(raw_bytes, static_cast<int>(split_info.size()));
    input.SetRecursionLimit(100000);

    substrait::ReadRel::ExtensionTable table;
    if (!table.ParseFromCodedStream(&input))
        throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::ReadRel::ExtensionTable from string failed");

    logDebugMessage(table, "extension_table");
    return table;
}

/// Deserializes a substrait::ReadRel::LocalFiles from the binary protobuf in `split_info`.
/// The recursion limit of the input stream is raised to 100000 before parsing.
/// Throws CANNOT_PARSE_PROTOBUF_SCHEMA when the bytes cannot be parsed.
substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std::string & split_info)
{
    const auto * raw_bytes = reinterpret_cast<const uint8_t *>(split_info.data());
    google::protobuf::io::CodedInputStream input(raw_bytes, static_cast<int>(split_info.size()));
    input.SetRecursionLimit(100000);

    substrait::ReadRel::LocalFiles files;
    if (!files.ParseFromCodedStream(&input))
        throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::ReadRel::LocalFiles from string failed");

    logDebugMessage(files, "local_files");
    return files;
}


/// Parses a binary-serialized substrait::Plan and converts it into a query plan.
/// When the "SerializedPlanParser" logger has debug enabled, the resulting plan is explained and logged.
/// Throws CANNOT_PARSE_PROTOBUF_SCHEMA when the bytes are not a valid substrait::Plan.
QueryPlanPtr SerializedPlanParser::parse(const std::string & plan)
{
    /// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data
    /// Deeply nested messages can make protobuf parsing fail, and the C++ protobuf library
    /// gives no useful diagnostics when that happens, so the recursion limit is set high
    /// enough to avoid the problem.
    const auto * raw_bytes = reinterpret_cast<const uint8_t *>(plan.data());
    google::protobuf::io::CodedInputStream input(raw_bytes, static_cast<int>(plan.size()));
    input.SetRecursionLimit(100000);

    auto substrait_plan = std::make_unique<substrait::Plan>();
    if (!substrait_plan->ParseFromCodedStream(&input))
        throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed");

    auto query_plan = parse(std::move(substrait_plan));

    auto * log = &Poco::Logger::get("SerializedPlanParser");
    if (log->debug())
    {
        auto explained = PlanUtil::explainPlan(*query_plan);
        LOG_DEBUG(log, "clickhouse plan:\n{}", explained);
    }
    return query_plan;
}

/// Parses a substrait::Plan given as a JSON string and converts it into a query plan.
/// Throws CANNOT_PARSE_PROTOBUF_SCHEMA (with the protobuf status text) when the JSON cannot be converted.
QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan)
{
    auto substrait_plan = std::make_unique<substrait::Plan>();
    auto status = google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan.c_str()), substrait_plan.get());
    if (status.ok())
        return parse(std::move(substrait_plan));

    throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from json string failed: {}", status.ToString());
}

/// Creates a parser that uses `context_` for its work (e.g. resolving functions through FunctionFactory).
SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_)
{
}

/// Definition of the static member `global_context`; it starts out null.
ContextMutablePtr SerializedPlanParser::global_context = nullptr;

/// Definition of the static member `config`; it starts out null.
Context::ConfigurationPtr SerializedPlanParser::config = nullptr;

/// Collects equi-join key pairs from a join condition.
///
/// The condition must be a tree of `and` functions whose leaves are `equals` functions
/// comparing two direct field references. For every `equals`, the pair
/// (left field index, right field index - right_key_start) is appended to `join_keys`.
///
/// Throws BAD_ARGUMENTS when the condition is not a scalar function, uses any function
/// other than `and` / `equals`, has too few arguments, or compares something that is not
/// a field reference.
void SerializedPlanParser::collectJoinKeys(
    const substrait::Expression & condition, std::vector<std::pair<int32_t, int32_t>> & join_keys, int32_t right_key_start)
{
    /// Without this check a non-function expression would be read as a default (empty) scalar function.
    if (!condition.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "join condition must be a scalar function: {}", condition.DebugString());

    const auto & function = condition.scalar_function();
    auto condition_name = getFunctionName(function_mapping.at(std::to_string(function.function_reference())), function);
    if (condition_name == "and")
    {
        if (function.arguments_size() < 2)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS, "join condition 'and' requires at least 2 arguments, got {}", function.arguments_size());

        /// Visit every operand so that no key is dropped when `and` carries more than two of them.
        for (const auto & argument : function.arguments())
            collectJoinKeys(argument.value(), join_keys, right_key_start);
    }
    else if (condition_name == "equals")
    {
        if (function.arguments_size() < 2)
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS, "join condition 'equals' requires 2 arguments, got {}", function.arguments_size());

        const auto & left = function.arguments(0).value();
        const auto & right = function.arguments(1).value();
        /// A non-selection operand would otherwise silently be treated as field 0.
        if (!left.has_selection() || !right.has_selection())
            throw Exception(
                ErrorCodes::BAD_ARGUMENTS, "join keys of an equals condition must be field references: {}", condition.DebugString());

        auto left_key_idx = left.selection().direct_reference().struct_field().field();
        /// Right-side indexes are shifted by right_key_start before being stored.
        auto right_key_idx = right.selection().direct_reference().struct_field().field() - right_key_start;
        join_keys.emplace_back(left_key_idx, right_key_idx);
    }
    else
    {
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "doesn't support condition {}", condition_name);
    }
}

/// Builds an ActionsDAG for `ast`, starting from a DAG whose inputs are `name_and_types`.
/// The AST is walked with ActionsVisitor, configured with empty aggregation keys
/// (GroupByKind::NONE) and default size limits for sets.
ActionsDAGPtr ASTParser::convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast)
{
    /// Empty aggregation info: the visitor data requires one, but no keys are supplied here.
    NamesAndTypesList no_aggregation_keys;
    ColumnNumbersList no_aggregation_key_indexes;
    AggregationKeysInfo aggregation_info(no_aggregation_keys, no_aggregation_key_indexes, GroupByKind::NONE);

    SizeLimits set_size_limits;
    auto initial_dag = std::make_shared<ActionsDAG>(name_and_types);
    auto prepared_sets = std::make_shared<PreparedSets>();

    ActionsMatcher::Data matcher_data(
        context,
        set_size_limits,
        size_t(0),
        name_and_types,
        std::move(initial_dag),
        std::move(prepared_sets),
        false /* no_subqueries */,
        false /* no_makeset */,
        false /* only_consts */,
        aggregation_info);

    ActionsVisitor visitor(matcher_data);
    visitor.visit(ast);
    return matcher_data.getActions();
}

/// Converts a substrait expression into a ClickHouse AST.
/// A singular-or-list root is delegated to parseArgumentToAST; otherwise the root must be
/// a scalar function, which becomes an AST function over its converted arguments.
/// Throws BAD_ARGUMENTS when the root is neither of the two.
ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & rel)
{
    LOG_DEBUG(&Poco::Logger::get("ASTParser"), "substrait plan:\n{}", rel.DebugString());

    if (rel.has_singular_or_list())
        return parseArgumentToAST(names, rel);

    if (!rel.has_scalar_function())
        throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString());

    const auto & func = rel.scalar_function();

    /// Resolve the ClickHouse function name from the signature registered for this function reference.
    auto signature = function_mapping.at(std::to_string(func.function_reference()));
    auto ch_function_name = SerializedPlanParser::getFunctionName(signature, func);

    ASTs converted_args;
    parseFunctionArgumentsToAST(names, func, converted_args);

    return makeASTFunction(ch_function_name, converted_args);
}

/// Converts every argument of `scalar_function` into an AST and appends it to `ast_args`.
/// Arguments that are scalar functions go through parseToAST, all others through parseArgumentToAST.
void ASTParser::parseFunctionArgumentsToAST(
    const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args)
{
    for (const auto & argument : scalar_function.arguments())
    {
        const auto & expr = argument.value();
        if (expr.has_scalar_function())
            ast_args.emplace_back(parseToAST(names, expr));
        else
            ast_args.emplace_back(parseArgumentToAST(names, expr));
    }
}

/// Converts a single substrait expression into a ClickHouse AST node.
///
/// `names` supplies the identifier for each field index used by selections.
/// Supported expression kinds: literal, selection (direct struct field reference), cast,
/// if-then, scalar function and singular-or-list (IN).
///
/// Throws BAD_ARGUMENTS for malformed selections / casts / IN options (including a field
/// index outside `names`), LOGICAL_ERROR for inconsistent IN option types and UNKNOWN_TYPE
/// for any other expression kind.
ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expression & rel)
{
    switch (rel.rex_type_case())
    {
        case substrait::Expression::RexTypeCase::kLiteral: {
            /// Only the value is needed for an AST literal; the parsed type is discarded.
            Field field;
            std::tie(std::ignore, field) = SerializedPlanParser::parseLiteral(rel.literal());
            return std::make_shared<ASTLiteral>(field);
        }
        case substrait::Expression::RexTypeCase::kSelection: {
            if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");

            /// The field index comes from the serialized plan, so validate it before indexing `names`.
            const auto field = rel.selection().direct_reference().struct_field().field();
            if (field < 0 || static_cast<size_t>(field) >= names.size())
                throw Exception(
                    ErrorCodes::BAD_ARGUMENTS,
                    "Selection references field {} but only {} names are available",
                    field,
                    names.size());
            return std::make_shared<ASTIdentifier>(names[static_cast<size_t>(field)]);
        }
        case substrait::Expression::RexTypeCase::kCast: {
            if (!rel.cast().has_type() || !rel.cast().has_input())
                throw Exception(ErrorCodes::BAD_ARGUMENTS, "Doesn't have type or input in cast node.");

            /// Append input to asts
            ASTs args;
            args.emplace_back(parseArgumentToAST(names, rel.cast().input()));

            /// Append destination type to asts
            const auto & substrait_type = rel.cast().type();
            /// Spark cast(x as BINARY) -> CH reinterpretAsStringSpark(x)
            if (substrait_type.has_binary())
                return makeASTFunction("reinterpretAsStringSpark", args);
            else
            {
                DataTypePtr ch_type = TypeParser::parseType(substrait_type);
                args.emplace_back(std::make_shared<ASTLiteral>(ch_type->getName()));

                return makeASTFunction("CAST", args);
            }
        }
        case substrait::Expression::RexTypeCase::kIfThen: {
            const auto & if_then = rel.if_then();
            auto condition_nums = if_then.ifs_size();
            /// A single condition maps to `if`, anything else to `multiIf`.
            std::string ch_function_name = condition_nums == 1 ? "if" : "multiIf";
            ASTs args;

            /// Arguments are laid out as cond1, then1, cond2, then2, ..., else.
            for (int i = 0; i < condition_nums; ++i)
            {
                const auto & ifs = if_then.ifs(i);
                auto if_node = parseArgumentToAST(names, ifs.if_());
                args.emplace_back(if_node);

                auto then_node = parseArgumentToAST(names, ifs.then());
                args.emplace_back(then_node);
            }

            auto else_node = parseArgumentToAST(names, if_then.else_());
            args.emplace_back(std::move(else_node));
            return makeASTFunction(ch_function_name, args);
        }
        case substrait::Expression::RexTypeCase::kScalarFunction: {
            return parseToAST(names, rel);
        }
        case substrait::Expression::RexTypeCase::kSingularOrList: {
            const auto & options = rel.singular_or_list().options();
            /// options is empty always return false
            if (options.empty())
                return std::make_shared<ASTLiteral>(0);
            /// options should be literals
            if (!options[0].has_literal())
                throw Exception(ErrorCodes::LOGICAL_ERROR, "Options of SingularOrList must have literal type");

            ASTs args;
            args.emplace_back(parseArgumentToAST(names, rel.singular_or_list().value()));

            bool nullable = false;
            size_t options_len = options.size();
            ASTs in_args;
            in_args.reserve(options_len);

            /// First pass: every option must be a literal; remember whether any of them is null.
            for (int i = 0; i < static_cast<int>(options_len); ++i)
            {
                if (!options[i].has_literal())
                    throw Exception(ErrorCodes::BAD_ARGUMENTS, "in expression values must be the literal!");
                if (!nullable)
                    nullable = options[i].literal().has_null();
            }

            /// Second pass: convert the options to AST literals; all options must share the element type.
            auto elem_type_and_field = SerializedPlanParser::parseLiteral(options[0].literal());
            DataTypePtr elem_type = wrapNullableType(nullable, elem_type_and_field.first);
            for (int i = 0; i < static_cast<int>(options_len); ++i)
            {
                auto type_and_field = SerializedPlanParser::parseLiteral(options[i].literal());
                auto option_type = wrapNullableType(nullable, type_and_field.first);
                if (!elem_type->equals(*option_type))
                    throw Exception(
                        ErrorCodes::LOGICAL_ERROR,
                        "SingularOrList options type mismatch:{} and {}",
                        elem_type->getName(),
                        option_type->getName());

                in_args.emplace_back(std::make_shared<ASTLiteral>(type_and_field.second));
            }
            /// The result is in(value, array(option1, option2, ...)).
            auto array_ast = makeASTFunction("array", in_args);
            args.emplace_back(array_ast);

            auto ast = makeASTFunction("in", args);
            if (nullable)
            {
                /// if sets has `null` and value not in sets
                /// In Spark: return `null`, is the standard behaviour from ANSI.(SPARK-37920)
                /// In CH: return `false`
                /// So we used if(a, b, c) cast `false` to `null` if sets has `null`
                ast = makeASTFunction("if", ast, std::make_shared<ASTLiteral>(true), std::make_shared<ASTLiteral>(Field()));
            }

            return ast;
        }
        default:
            throw Exception(
                ErrorCodes::UNKNOWN_TYPE,
                "Join on condition error. Unsupported spark expression type {} : {}",
                magic_enum::enum_name(rel.rex_type_case()),
                rel.DebugString());
    }
}

void SerializedPlanParser::removeNullable(const std::set<String> & require_columns, ActionsDAGPtr actions_dag)
{
    for (const auto & item : require_columns)
    {
        const auto * require_node = actions_dag->tryFindInOutputs(item);
        if (require_node)
        {
            auto function_builder = FunctionFactory::instance().get("assumeNotNull", context);
            ActionsDAG::NodeRawConstPtrs args = {require_node};
            const auto & node = actions_dag->addFunction(function_builder, args, item);
            actions_dag->addOrReplaceInOutputs(node);
        }
    }
}

void SerializedPlanParser::wrapNullable(
    const std::vector<String> & columns, ActionsDAGPtr actions_dag, std::map<std::string, std::string> & nullable_measure_names)
{
    for (const auto & item : columns)
    {
        ActionsDAG::NodeRawConstPtrs args;
        args.emplace_back(&actions_dag->findInOutputs(item));
        const auto * node = toFunctionNode(actions_dag, "toNullable", args);
        actions_dag->addOrReplaceInOutputs(*node);
        nullable_measure_names[item] = node->result_name;
    }
}

/// Definition of the static member `shared_context` (default-constructed).
SharedContextHolder SerializedPlanParser::shared_context;

/// Logs the pipeline dump when the `dump_pipeline` config flag is set (default false),
/// then frees the spark row buffer recorded in `spark_buffer`, if there is one.
LocalExecutor::~LocalExecutor()
{
    const bool dump_requested = context->getConfigRef().getBool("dump_pipeline", false);
    if (dump_requested)
    {
        LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Dump pipeline:\n{}", dumpPipeline());
    }

    if (!spark_buffer)
        return;

    ch_column_to_spark_row->freeMem(spark_buffer->address, spark_buffer->size);
    spark_buffer.reset();
}


/// Takes ownership of `query_plan`, builds its query pipeline and a pulling executor for it,
/// and prepares the output header and the column-to-spark-row converter.
/// The time spent building the pipeline and creating the executor is logged.
void LocalExecutor::execute(QueryPlanPtr query_plan)
{
    Stopwatch timer;

    const Settings & query_settings = context->getSettingsRef();
    current_query_plan = std::move(query_plan);
    auto * log = &Poco::Logger::get("LocalExecutor");

    /// NOTE(review): `priorities` is local to this call while `query_status` is handed to the
    /// pipeline - confirm the handle returned by insert() never refers back to `priorities`.
    DB::QueryPriorities priorities;
    auto query_status = std::make_shared<DB::QueryStatus>(
        context,
        "",
        context->getClientInfo(),
        priorities.insert(static_cast<int>(query_settings.priority)),
        DB::CurrentThread::getGroup(),
        DB::IAST::QueryKind::Select,
        query_settings,
        0);

    /// Expression compilation is enabled for the pipeline built here.
    ExpressionActionsSettings actions_settings{
        .can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes};
    BuildQueryPipelineSettings pipeline_settings{.actions_settings = actions_settings, .process_list_element = query_status};
    QueryPlanOptimizationSettings optimization_settings{.optimize_plan = query_settings.query_plan_enable_optimizations};

    auto builder = current_query_plan->buildQueryPipeline(optimization_settings, pipeline_settings);

    LOG_DEBUG(log, "clickhouse plan after optimization:\n{}", PlanUtil::explainPlan(*current_query_plan));
    query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*builder));
    LOG_DEBUG(log, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(query_pipeline));
    auto pipeline_us = timer.elapsedMicroseconds();

    executor = std::make_unique<PullingPipelineExecutor>(query_pipeline);
    auto executor_us = timer.elapsedMicroseconds() - pipeline_us;
    timer.stop();
    LOG_INFO(
        log,
        "build pipeline {} ms; create executor {} ms;",
        pipeline_us / 1000.0,
        executor_us / 1000.0);

    header = current_query_plan->getCurrentDataStream().header.cloneEmpty();
    ch_column_to_spark_row = std::make_unique<CHColumnToSparkRow>();
}

/// Converts `block` to spark rows by forwarding it to `ch_column_to_spark_row`,
/// which is created in execute().
std::unique_ptr<SparkRowInfo> LocalExecutor::writeBlockToSparkRow(Block & block)
{
    return ch_column_to_spark_row->convertCHColumnToSparkRow(block);
}

/// Reports whether a block is available, pulling a new one from the pipeline when needed.
///
/// When the current block has columns and has not been consumed yet, it is kept and `true` is
/// returned without touching the executor. Otherwise the current block is replaced by an empty
/// clone of the header, the executor is asked to fill it, produce() is called, and the result
/// of executor->pull() is returned.
///
/// A DB::Exception raised along the way is logged together with the explained plan and rethrown.
bool LocalExecutor::hasNext()
{
    try
    {
        const bool block_pending = currentBlock().columns() != 0 && !isConsumed();
        if (block_pending)
            return true;

        auto fresh_block = header.cloneEmpty();
        setCurrentBlock(fresh_block);
        const bool pulled = executor->pull(currentBlock());
        // produce() runs regardless of what pull() reported.
        produce();
        return pulled;
    }
    catch (DB::Exception & e)
    {
        LOG_ERROR(
            &Poco::Logger::get("LocalExecutor"),
            "LocalExecutor run query plan failed with message: {}. Plan Explained: \n{}",
            e.message(),
            PlanUtil::explainPlan(*current_query_plan));
        throw;
    }
}

/// Converts the current block into Spark rows and returns the conversion result.
///
/// After the conversion the block is marked through consume(). The buffer recorded by the
/// previous call (address and size kept in spark_buffer) is freed here, and the address/size
/// of the new result are recorded in its place.
SparkRowInfoPtr LocalExecutor::next()
{
    checkNextValid();
    SparkRowInfoPtr converted = writeBlockToSparkRow(currentBlock());
    consume();

    // Release what the previous call recorded before tracking the new buffer.
    if (spark_buffer)
    {
        auto & converter = *ch_column_to_spark_row;
        converter.freeMem(spark_buffer->address, spark_buffer->size);
        spark_buffer.reset();
    }

    spark_buffer = std::make_unique<SparkBuffer>();
    spark_buffer->address = converted->getBufferAddress();
    spark_buffer->size = converted->getTotalBytes();
    return converted;
}

/// Returns a pointer to the current block and marks it through consume().
///
/// If the current block has no columns, it is first replaced by an empty clone of the header,
/// so the returned block always originates from either the pipeline or the header.
/// The pointer refers to the executor's current block; it is not a copy.
Block * LocalExecutor::nextColumnar()
{
    checkNextValid();

    if (currentBlock().columns() == 0)
    {
        auto header_only_block = header.cloneEmpty();
        setCurrentBlock(header_only_block);
    }

    Block * batch = &currentBlock();
    consume();
    return batch;
}

/// Returns a mutable reference to the stored header block.
///
/// The header is assigned in execute() from an empty clone of the plan's current data stream header.
Block & LocalExecutor::getHeader()
{
    return header;
}

/// Stores the given query context and ClickHouse context.
///
/// No plan or pipeline work happens here; that is done in execute().
LocalExecutor::LocalExecutor(QueryContext & _query_context, ContextPtr context_)
    : query_context(_query_context), context(context_)
{
}

std::string LocalExecutor::dumpPipeline()
{
    const auto & processors = query_pipeline.getProcessors();
    for (auto & processor : processors)
    {
        DB::WriteBufferFromOwnString buffer;
        auto data_stats = processor->getProcessorDataStats();
        buffer << "(";
        buffer << "\nexcution time: " << processor->getElapsedUs() << " us.";
        buffer << "\ninput wait time: " << processor->getInputWaitElapsedUs() << " us.";
        buffer << "\noutput wait time: " << processor->getOutputWaitElapsedUs() << " us.";
        buffer << "\ninput rows: " << data_stats.input_rows;
        buffer << "\ninput bytes: " << data_stats.input_bytes;
        buffer << "\noutput rows: " << data_stats.output_rows;
        buffer << "\noutput bytes: " << data_stats.output_bytes;
        buffer << ")";
        processor->setDescription(buffer.str());
    }
    DB::WriteBufferFromOwnString out;
    DB::printPipeline(processors, out);
    return out.str();
}

/// Records the inputs used by resolve().
///
/// @param header_   block whose column positions are referenced by selections in the condition.
/// @param parser_   parser providing the function mapping and function-name lookup.
/// @param cond_rel_ condition expression to analyse.
NonNullableColumnsResolver::NonNullableColumnsResolver(
    const DB::Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_)
    : header(header_), parser(parser_), cond_rel(cond_rel_)
{
}

// Kept deliberately simple for now: only conjunctions of supported predicates are analysed,
// so a condition whose top-level function is `or` yields an empty result.
//
// Clears any previously collected names, walks the condition expression, and returns a copy of
// the collected column names.
std::set<String> NonNullableColumnsResolver::resolve()
{
    collected_columns.clear();
    visit(cond_rel);
    return collected_columns;
}

// TODO: make it the same as spark, it's too simple at present.
//
// Walks a boolean condition and collects columns that are implied to be non-null by it.
//
// Only scalar functions are inspected:
//   - `and`: every conjunct is visited recursively.
//   - `greater`, `greaterOrEquals`, `less`, `lessOrEquals`: both operands are treated as non-nullable.
//   - `isNotNull`: its operand is treated as non-nullable.
// Anything else (including `or`) is ignored. Functions whose argument list is shorter than
// expected are ignored as well, instead of indexing past the end of the argument list.
void NonNullableColumnsResolver::visit(const substrait::Expression & expr)
{
    if (!expr.has_scalar_function())
        return;

    const auto & scalar_function = expr.scalar_function();
    const auto & function_signature = parser.function_mapping.at(std::to_string(scalar_function.function_reference()));
    const auto function_name = safeGetFunctionName(function_signature, scalar_function);
    const auto arguments_count = scalar_function.arguments_size();

    // Only some special functions are used to judge whether the column is non-nullable.
    if (function_name == "and")
    {
        // Every conjunct must hold, so each one may contribute non-nullable columns.
        for (const auto & argument : scalar_function.arguments())
            visit(argument.value());
    }
    else if (
        function_name == "greaterOrEquals" || function_name == "greater" || function_name == "lessOrEquals"
        || function_name == "less")
    {
        // For a comparison such as a > x, whatever a or x is, both sides are non-nullable.
        // a or x may be a column, or a simple expression like plus etc.
        if (arguments_count >= 2)
        {
            visitNonNullable(scalar_function.arguments(0).value());
            visitNonNullable(scalar_function.arguments(1).value());
        }
    }
    else if (function_name == "isNotNull")
    {
        if (arguments_count >= 1)
            visitNonNullable(scalar_function.arguments(0).value());
    }
    // else do nothing
}

// Collects the columns referenced by an expression that is known to be non-null.
//
//   - `plus`, `minus`, `multiply`, `divide`: both operands are visited recursively.
//   - A selection: the name of the referenced header column is added to collected_columns.
// Any other expression is ignored.
//
// A selection is only used when it is a direct struct-field reference. For any other reference
// kind the protobuf accessors would return the default field index 0, which would wrongly mark
// the first header column as non-nullable.
void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & expr)
{
    if (expr.has_scalar_function())
    {
        const auto & scalar_function = expr.scalar_function();
        const auto & function_signature = parser.function_mapping.at(std::to_string(scalar_function.function_reference()));
        const auto function_name = safeGetFunctionName(function_signature, scalar_function);
        const bool is_arithmetic
            = function_name == "plus" || function_name == "minus" || function_name == "multiply" || function_name == "divide";

        // Skip argument lists that are too short rather than indexing past their end.
        if (is_arithmetic && scalar_function.arguments_size() >= 2)
        {
            visitNonNullable(scalar_function.arguments(0).value());
            visitNonNullable(scalar_function.arguments(1).value());
        }
    }
    else if (expr.has_selection())
    {
        const auto & selection = expr.selection();
        if (!selection.has_direct_reference() || !selection.direct_reference().has_struct_field())
            return;

        const auto column_pos = selection.direct_reference().struct_field().field();
        collected_columns.insert(header.getByPosition(column_pos).name);
    }
    // else, do nothing.
}

// Looks up the function name through the parser without letting an Exception escape.
//
// Returns whatever parser.getFunctionName() yields, or an empty string when that call throws
// an Exception.
std::string NonNullableColumnsResolver::safeGetFunctionName(
    const std::string & function_signature, const substrait::Expression_ScalarFunction & function)
{
    std::string resolved_name;
    try
    {
        resolved_name = parser.getFunctionName(function_signature, function);
    }
    catch (const Exception &)
    {
        // Fall back to an empty name so that callers simply do not match any known function.
        resolved_name.clear();
    }
    return resolved_name;
}
}
