/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "SelectorBuilder.h"
#include <limits>
#include <memory>
#include <mutex>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Functions/FunctionFactory.h>
#include <Parser/SerializedPlanParser.h>
#include <Parser/TypeParser.h>
#include <Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Poco/Base64Decoder.h>
#include <Poco/JSON/JSON.h>
#include <Poco/JSON/Parser.h>
#include <Poco/MemoryStream.h>
#include <Poco/StreamCopier.h>
#include <Common/CHUtil.h>
#include <Common/Exception.h>

namespace DB
{
namespace ErrorCodes
{
/// Error code passed to the DB::Exception instances thrown in this file.
/// Only declared here (extern); the definition lives outside this file.
extern const int LOGICAL_ERROR;
}
}
namespace local_engine
{
/// Builds a PartitionInfo from a per-row partition id selector with a counting sort.
///
/// @param selector       selector[i] is the partition id of row i. Ids are used as indices without any
///                       range check, so callers must keep them below partition_num.
/// @param partition_num  number of partitions.
/// @return A PartitionInfo where
///         - partition_selector holds the row indices grouped by partition id (rows of one partition
///           keep their original relative order),
///         - partition_start_points[p] is the offset of the first row of partition p inside
///           partition_selector; the extra last entry equals the total row count as long as all ids
///           are below partition_num.
PartitionInfo PartitionInfo::fromSelector(DB::IColumn::Selector selector, size_t partition_num)
{
    const size_t rows = selector.size();

    /// Pass 1: count the rows of every partition. One extra slot is allocated so that the prefix sum
    /// below leaves the total row count in the last entry.
    std::vector<size_t> start_points(partition_num + 1, 0);
    for (size_t i = 0; i < rows; ++i)
        ++start_points[selector[i]];

    /// Pass 2: inclusive prefix sum; start_points[p] is now the exclusive end offset of partition p.
    for (size_t p = 1; p <= partition_num; ++p)
        start_points[p] += start_points[p - 1];

    /// Pass 3: walk the rows backwards and fill every partition from its end towards its beginning.
    /// Each placement decrements the partition's offset, so when the loop finishes start_points[p]
    /// has become the start offset of partition p.
    DB::IColumn::Selector partition_selector(rows, 0);
    for (size_t i = rows; i-- > 0;)
        partition_selector[--start_points[selector[i]]] = i;

    /// Both locals are dead after this point: move them instead of copying.
    return PartitionInfo{
        .partition_selector = std::move(partition_selector),
        .partition_start_points = std::move(start_points),
        .partition_num = partition_num};
}

/// Distributes the rows of `block` over the partitions in round-robin order.
/// The rotation position is read from and written back to pid_selection, so it is not reset
/// by this function.
PartitionInfo RoundRobinSelectorBuilder::build(DB::Block & block)
{
    const size_t rows = block.rows();

    DB::IColumn::Selector partition_ids;
    partition_ids.resize_fill(rows, 0);
    for (size_t row = 0; row < rows; ++row)
    {
        partition_ids[row] = pid_selection;
        /// advance the rotation, wrapping around after the last partition
        pid_selection = (pid_selection + 1) % parts_num;
    }

    return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);
}

/// Stores the hash partitioning configuration. No work is done here; the hash function is looked up
/// by name when build() runs.
///
/// @param parts_num_           number of output partitions
/// @param exprs_index_         positions, inside the input block, of the columns that are hashed
/// @param hash_function_name_  name of the function requested from DB::FunctionFactory
HashSelectorBuilder::HashSelectorBuilder(
    UInt32 parts_num_, const std::vector<size_t> & exprs_index_, const std::string & hash_function_name_)
    : parts_num(parts_num_)
    , exprs_index(exprs_index_)
    , hash_function_name(hash_function_name_)
{
}

/// Computes a partition id for every row of `block` by hashing the columns listed in exprs_index.
///
/// The key columns are passed through BlockUtil::flattenBlock before hashing. The hash function is
/// built from the first block seen and reused for later calls.
PartitionInfo HashSelectorBuilder::build(DB::Block & block)
{
    /// Collect the key columns and flatten them.
    ColumnsWithTypeAndName key_columns;
    for (const auto position : exprs_index)
        key_columns.emplace_back(block.safeGetByPosition(position));
    auto flattened = BlockUtil::flattenBlock(DB::Block(key_columns), BlockUtil::FLAT_STRUCT_FORCE | BlockUtil::FLAT_NESTED_TABLE, true);
    key_columns = flattened.getColumnsWithTypeAndName();

    /// Resolve the hash function once.
    if (!hash_function) [[unlikely]]
    {
        auto resolver = DB::FunctionFactory::instance().get(hash_function_name, local_engine::SerializedPlanParser::global_context);
        hash_function = resolver->build(key_columns);
    }

    const auto rows = block.rows();
    auto result_type = hash_function->getResultType();
    auto hash_column = hash_function->execute(key_columns, result_type, rows, false);

    DB::IColumn::Selector partition_ids;
    partition_ids.reserve(rows);

    if (isNothing(removeNullable(result_type)))
    {
        /// TODO: implement new hash function sparkCityHash64 like sparkXxHash64 to process null literal as column more gracefully.
        /// Current implementation may cause partition skew: every row lands in partition 0.
        partition_ids.resize_fill(rows, 0);
    }
    else if (hash_function_name == "sparkMurmurHash3_32")
    {
        /// sparkMurmurHash3_32 never returns null, so no null map has to be consulted.
        const auto parts_as_int32 = static_cast<Int32>(parts_num);
        for (size_t row = 0; row < rows; ++row)
        {
            /// The hash is narrowed to Int32 to use the same data type as vanilla Spark.
            const auto hash = static_cast<Int32>(hash_column->get64(row));
            auto pid = hash % parts_as_int32;
            /// A negative remainder is shifted into [0, parts_num).
            if (pid < 0)
                pid += parts_as_int32;
            partition_ids.emplace_back(static_cast<UInt64>(pid));
        }
    }
    else if (hash_column->isNullable())
    {
        const auto * nullable_column = typeid_cast<const ColumnNullable *>(hash_column->getPtr().get());
        const auto & null_map = nullable_column->getNullMapData();
        for (size_t row = 0; row < rows; ++row)
        {
            /// null_map value 0 gives an all-ones mask (hash kept), value 1 gives a zero mask (hash becomes 0).
            const auto keep_mask = static_cast<UInt64>(static_cast<Int64>(null_map[row]) - 1);
            const auto hash = static_cast<UInt64>(hash_column->get64(row)) & keep_mask;
            partition_ids.emplace_back(static_cast<UInt64>(hash % parts_num));
        }
    }
    else
    {
        for (size_t row = 0; row < rows; ++row)
            partition_ids.emplace_back(static_cast<UInt64>(hash_column->get64(row) % parts_num));
    }

    return PartitionInfo::fromSelector(std::move(partition_ids), parts_num);
}


/// Translates the integer "direction" code of an ordering entry into the two values handed to
/// DB::SortColumnDescription by initSortInformation(): {direction, nulls_direction}.
/// NOTE(review): the keys 1..4 presumably mirror the sort direction enum of the plan producer
/// (asc/desc combined with nulls first/last) - confirm against the producer side.
/// The table is only ever read, hence const.
static const std::map<int, std::pair<int, int>> direction_map = {{1, {1, -1}}, {2, {1, 1}}, {3, {-1, 1}}, {4, {-1, -1}}};

/// Creates a range partitioner from a JSON description.
///
/// @param option          JSON object text with an "ordering" array (sort keys) and a "range_bounds"
///                        array (one entry per bound row)
/// @param partition_num_  number of output partitions
RangeSelectorBuilder::RangeSelectorBuilder(const std::string & option, const size_t partition_num_)
{
    Poco::JSON::Parser json_parser;
    auto root = json_parser.parse(option).extract<Poco::JSON::Object::Ptr>();

    /// The sort information must be loaded first: initRangeBlock() iterates over sort_field_types,
    /// which initSortInformation() fills.
    auto orderings = root->get("ordering").extract<Poco::JSON::Array::Ptr>();
    initSortInformation(orderings);

    auto bounds = root->get("range_bounds").extract<Poco::JSON::Array::Ptr>();
    initRangeBlock(bounds);

    partition_num = partition_num_;
}

/// Assigns every row of `block` to a range partition (binary search over the range bounds) and
/// converts the resulting per-row ids into a PartitionInfo.
PartitionInfo RangeSelectorBuilder::build(DB::Block & block)
{
    DB::IColumn::Selector partition_ids;
    computePartitionIdByBinarySearch(block, partition_ids);
    return PartitionInfo::fromSelector(std::move(partition_ids), partition_num);
}

/// Reads the "ordering" JSON array and appends, for every entry, one element to
/// sort_descriptions, sort_field_types and sorting_key_columns.
///
/// Keys read from each entry: "column_ref", "column_name", "direction", "data_type", "is_nullable".
/// @throws DB::Exception (LOGICAL_ERROR) when "direction" is not a key of direction_map.
void RangeSelectorBuilder::initSortInformation(Poco::JSON::Array::Ptr orderings)
{
    const auto ordering_count = orderings->size();
    for (uint32_t idx = 0; idx < ordering_count; ++idx)
    {
        auto entry = orderings->get(idx).extract<Poco::JSON::Object::Ptr>();
        const auto column_position = entry->get("column_ref").convert<Int32>();
        const auto column_name = entry->get("column_name").convert<String>();

        /// Translate the direction code into ClickHouse's (direction, nulls_direction) pair.
        const auto direction_code = entry->get("direction").convert<int>();
        const auto mapped = direction_map.find(direction_code);
        if (mapped == direction_map.end())
            throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported sorting direction:{}", direction_code);
        sort_descriptions.emplace_back(DB::SortColumnDescription(column_name, mapped->second.first, mapped->second.second));

        /// Remember the value type of the key; initRangeBlock() uses it to build the bound columns.
        SortFieldTypeInfo field_type;
        field_type.inner_type = TypeParser::getCHTypeByName(entry->get("data_type").convert<std::string>());
        field_type.is_nullable = entry->get("is_nullable").convert<bool>();
        sort_field_types.emplace_back(field_type);

        sorting_key_columns.emplace_back(column_position);
    }
}

/// Appends a floating point bound value to `col`.
///
/// The value is first converted to T directly. If that attempt throws a std::exception, the value is
/// read as text instead and the case-insensitive spellings "nan", "infinity" and "-infinity" are
/// accepted.
/// @throws DB::Exception (LOGICAL_ERROR) for any other text.
template <typename T>
void RangeSelectorBuilder::safeInsertFloatValue(const Poco::Dynamic::Var & field_value, DB::MutableColumnPtr & col)
{
    try
    {
        col->insert(field_value.convert<T>());
        return;
    }
    catch (const std::exception &)
    {
        /// Fall through to the textual spellings handled below.
    }

    const String text = Poco::toLower(field_value.convert<std::string>());
    if (text == "nan")
        col->insert(std::numeric_limits<T>::quiet_NaN());
    else if (text == "infinity")
        col->insert(std::numeric_limits<T>::infinity());
    else if (text == "-infinity")
        col->insert(-std::numeric_limits<T>::infinity());
    else
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported value: {}", text);
}

/// Builds range_bounds_block from the "range_bounds" JSON array.
///
/// The block gets one column per entry of sort_field_types (named "sort_col_<index>") and one row per
/// bound. Each bound is a JSON array whose i-th element is an object with "is_null" and "value".
/// Columns are wrapped in Nullable when the sort field is marked nullable.
/// @throws DB::Exception (LOGICAL_ERROR) when a non-null value belongs to a type not handled below.
void RangeSelectorBuilder::initRangeBlock(Poco::JSON::Array::Ptr range_bounds)
{
    DB::ColumnsWithTypeAndName bound_columns;
    for (uint32_t field_idx = 0; field_idx < sort_field_types.size(); ++field_idx)
    {
        auto & field_type = sort_field_types[field_idx];
        /// The type name drives the conversion of every value of this column.
        const auto type_name = field_type.inner_type->getName();

        auto column_type = field_type.inner_type;
        DB::MutableColumnPtr column = field_type.inner_type->createColumn();
        if (field_type.is_nullable)
        {
            column = ColumnNullable::create(std::move(column), DB::ColumnUInt8::create(0, 0));
            column_type = std::make_shared<DB::DataTypeNullable>(column_type);
        }

        for (uint32_t bound_idx = 0; bound_idx < range_bounds->size(); ++bound_idx)
        {
            auto bound_row = range_bounds->get(bound_idx).extract<Poco::JSON::Array::Ptr>();
            auto field = bound_row->get(field_idx).extract<Poco::JSON::Object::Ptr>();
            if (field->get("is_null").convert<bool>())
            {
                /// A null bound is inserted as "no data".
                column->insertData(nullptr, 0);
                continue;
            }

            const auto value = field->get("value");
            if (type_name == "UInt8")
            {
                /// Read through Int16 first, then narrow to UInt8.
                column->insert(static_cast<UInt8>(value.convert<Int16>()));
            }
            else if (type_name == "Int8")
            {
                column->insert(value.convert<Int8>());
            }
            else if (type_name == "Int16")
            {
                column->insert(value.convert<Int16>());
            }
            else if (type_name == "Int32")
            {
                column->insert(value.convert<Int32>());
            }
            else if (type_name == "Int64")
            {
                column->insert(value.convert<Int64>());
            }
            else if (type_name == "Float32")
            {
                safeInsertFloatValue<Float32>(value, column);
            }
            else if (type_name == "Float64")
            {
                safeInsertFloatValue<Float64>(value, column);
            }
            else if (type_name == "String")
            {
                column->insert(value.convert<std::string>());
            }
            else if (type_name == "Date32")
            {
                int date_value = value.convert<Int32>();
                column->insert(date_value);
            }
            else if (const auto * dec32 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal32> *>(field_type.inner_type.get()))
            {
                /// Decimals arrive as text and are parsed with the scale of the column type.
                auto parsed = dec32->parseFromString(value.convert<std::string>());
                column->insert(DB::DecimalField<DB::Decimal32>(parsed, dec32->getScale()));
            }
            else if (const auto * dec64 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal64> *>(field_type.inner_type.get()))
            {
                auto parsed = dec64->parseFromString(value.convert<std::string>());
                column->insert(DB::DecimalField<DB::Decimal64>(parsed, dec64->getScale()));
            }
            else if (const auto * dec128 = dynamic_cast<const DB::DataTypeDecimal<DB::Decimal128> *>(field_type.inner_type.get()))
            {
                auto parsed = dec128->parseFromString(value.convert<std::string>());
                column->insert(DB::DecimalField<DB::Decimal128>(parsed, dec128->getScale()));
            }
            else
            {
                throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupported data type: {}", type_name);
            }
        }

        auto column_name = "sort_col_" + std::to_string(field_idx);
        bound_columns.emplace_back(std::move(column), column_type, column_name);
    }
    range_bounds_block = DB::Block(bound_columns);
}

/// Builds projection_expression_actions from the project expressions stored in projection_plan_pb.
/// Runs at most once: the work is guarded by actions_dag_mutex and the has_init_actions_dag flag.
///
/// @param block  passed to the plan parser as both block arguments of expressionsToActionsDAG()
void RangeSelectorBuilder::initActionsDAG(const DB::Block & block)
{
    std::lock_guard guard(actions_dag_mutex);
    if (has_init_actions_dag)
        return;

    SerializedPlanParser plan_parser(local_engine::SerializedPlanParser::global_context);
    plan_parser.parseExtensions(projection_plan_pb->extensions());

    /// Copy the project expressions of the first relation into a plain vector for the parser.
    const auto & project_expressions = projection_plan_pb->relations().at(0).root().input().project().expressions();
    std::vector<substrait::Expression> expression_list(project_expressions.begin(), project_expressions.end());

    auto projection_actions_dag = plan_parser.expressionsToActionsDAG(expression_list, block, block);
    projection_expression_actions = std::make_unique<DB::ExpressionActions>(projection_actions_dag);
    has_init_actions_dag = true;
}

/// Fills `selector` with one partition id per row of `block`.
///
/// For every row the bounds in range_bounds_block are searched with binarySearchBound(). A row gets
/// the index returned by that search; a row for which the search returns -1 (no bound found) gets
/// the number of bounds as its id.
///
/// @param block     input rows; the sort key columns are taken at the positions in sorting_key_columns
/// @param selector  output, cleared first
/// @throws DB::Exception (LOGICAL_ERROR) when range_bounds_block has no columns
void RangeSelectorBuilder::computePartitionIdByBinarySearch(DB::Block & block, DB::IColumn::Selector & selector)
{
    const auto total_rows = block.rows();
    selector.clear();
    selector.reserve(total_rows);

    const auto & bounds_columns = range_bounds_block.getColumns();
    /// The number of bounds is read from the first bound column, so at least one column is required.
    if (bounds_columns.empty())
        throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Range bounds block has no columns");
    const auto bound_count = bounds_columns[0]->size();

    /// A nullable bound column is compared against a nullable input column: wrap the input key
    /// column when only the bound side is nullable.
    auto input_columns = block.getColumns();
    for (size_t i = 0; i < bounds_columns.size(); ++i)
    {
        auto & input_column = input_columns[sorting_key_columns[i]];
        if (bounds_columns[i]->isNullable() && !input_column->isNullable())
            input_column = makeNullable(input_column);
    }

    /// Signed arithmetic: with zero bounds this is -1, which makes the search range empty.
    const auto last_bound = static_cast<Int64>(bound_count) - 1;
    for (size_t row = 0; row < total_rows; ++row)
    {
        const auto found = binarySearchBound(bounds_columns, 0, last_bound, input_columns, sorting_key_columns, row);
        selector.emplace_back(found >= 0 ? static_cast<size_t>(found) : bound_count);
    }
}

/// Compares one input row with one bound row, key by key, honouring sort_descriptions.
///
/// @param columns           input columns
/// @param required_columns  positions of the sort key columns inside `columns`
/// @param row               input row index
/// @param bound_columns     bound columns; the i-th bound column belongs to the i-th sort key
/// @param bound_row         bound row index
/// @return the first non-zero key comparison (compareAt result multiplied by the key's direction),
///         or 0 when all keys compare equal
int RangeSelectorBuilder::compareRow(
    const DB::Columns & columns,
    const std::vector<size_t> & required_columns,
    size_t row,
    const DB::Columns & bound_columns,
    size_t bound_row)
{
    const size_t key_count = required_columns.size();
    for (size_t key = 0; key < key_count; ++key)
    {
        const auto & sort_description = sort_descriptions[key];
        const auto & input_column = *columns[required_columns[key]];
        const auto & bound_column = *bound_columns[key];

        const auto order
            = input_column.compareAt(row, bound_row, bound_column, sort_description.nulls_direction) * sort_description.direction;
        if (order != 0)
            return order;
    }
    return 0;
}

// Binary search over the bound rows [l, r] for input row `row`.
// Returns the index of a bound that is not smaller than the row: a bound comparing equal is returned
// as soon as it is probed, otherwise the lowest probed bound that is larger than the row.
// Returns -1 if every bound in the range is smaller than the row, or if the range is empty.
int RangeSelectorBuilder::binarySearchBound(
    const DB::Columns & bound_columns, Int64 l, Int64 r, const DB::Columns & columns, const std::vector<size_t> & used_cols, size_t row)
{
    // Best candidate seen so far: the last probed bound that was larger than the row.
    int upper_bound = -1;
    while (l <= r)
    {
        const auto mid = (l + r) >> 1;
        const auto order = compareRow(columns, used_cols, row, bound_columns, mid);
        if (order == 0)
            return static_cast<int>(mid);

        if (order < 0)
        {
            // The row is smaller than bound `mid`: remember it and look for a smaller bound on the left.
            upper_bound = static_cast<int>(mid);
            r = mid - 1;
        }
        else
        {
            // The row is larger than bound `mid`: only bounds on the right can qualify.
            l = mid + 1;
        }
    }
    return upper_bound;
}
}
