/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "NativeReader.h"

#include <IO/ReadHelpers.h>
#include <IO/VarInt.h>
#include <DataTypes/DataTypeFactory.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/Arena.h>
#include <Storages/IO/NativeWriter.h>
#include <Storages/IO/AggregateSerializationUtils.h>

namespace DB
{
namespace ErrorCodes
{
    extern const int INCORRECT_INDEX;
    extern const int LOGICAL_ERROR;
    extern const int CANNOT_READ_ALL_DATA;
    extern const int INCORRECT_DATA;
    extern const int TOO_LARGE_ARRAY_SIZE;
}
}

using namespace DB;

namespace local_engine
{

Block NativeReader::getHeader() const
{
    return header;
}

DB::Block NativeReader::read()
{
    DB::Block result_block;
    if (istr.eof())
        return result_block;
    if (columns_parse_util.empty())
    {
        result_block = prepareByFirstBlock();
        if (!result_block)
            return {};
    }
    else
    {
        result_block = header.cloneEmpty();
    }

    /// Append small blocks into a large one, could reduce memory allocations overhead.
    while (result_block.rows() < max_block_size && result_block.bytes() < max_block_bytes)
    {
        if (!appendNextBlock(result_block))
            break;
    }

    if (result_block.rows())
    {
        for (size_t i = 0; i < result_block.columns(); ++i)
        {
            auto & column_parse_util = columns_parse_util[i];
            auto & column = result_block.getByPosition(i);
            column_parse_util.avg_value_size_hint = column.column->byteSize() / column.column->size();
        }
    }
    return result_block;
}

static void readFixedSizeAggregateData(DB::ReadBuffer &in, DB::ColumnPtr & column, size_t rows, NativeReader::ColumnParseUtil & column_parse_util)
{
    ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*column->assumeMutable());
    auto & arena = real_column.createOrGetArena();
    ColumnAggregateFunction::Container & vec = real_column.getData();
    size_t initial_size = vec.size();
    vec.reserve(initial_size + rows);
    for (size_t i = 0; i < rows; ++i)
    {
        AggregateDataPtr place = arena.alignedAlloc(column_parse_util.aggregate_state_size, column_parse_util.aggregate_state_align);
        column_parse_util.aggregate_function->create(place);
        auto n = in.read(place, column_parse_util.aggregate_state_size);
        chassert(n == column_parse_util.aggregate_state_size);
        vec.push_back(place);
    }
}

static void readVarSizeAggregateData(DB::ReadBuffer &in, DB::ColumnPtr & column, size_t rows, NativeReader::ColumnParseUtil & column_parse_util)
{
    ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*column->assumeMutable());
    auto & arena = real_column.createOrGetArena();
    ColumnAggregateFunction::Container & vec = real_column.getData();
    size_t initial_size = vec.size();
    vec.reserve(initial_size + rows);
    for (size_t i = 0; i < rows; ++i)
    {
        AggregateDataPtr place = arena.alignedAlloc(column_parse_util.aggregate_state_size, column_parse_util.aggregate_state_align);
        column_parse_util.aggregate_function->create(place);
        column_parse_util.aggregate_function->deserialize(place, in, std::nullopt, &arena);
        in.ignore();
        vec.push_back(place);
    }
}

static void readNormalData(DB::ReadBuffer &in, DB::ColumnPtr & column, size_t rows, NativeReader::ColumnParseUtil & column_parse_util)
{
    ISerialization::DeserializeBinaryBulkSettings settings;
    settings.getter = [&](ISerialization::SubstreamPath) -> ReadBuffer * { return &in; };
    settings.avg_value_size_hint = column_parse_util.avg_value_size_hint;
    settings.position_independent_encoding = false;
    settings.native_format = true;

    ISerialization::DeserializeBinaryBulkStatePtr state;

    column_parse_util.serializer->deserializeBinaryBulkStatePrefix(settings, state);
    column_parse_util.serializer->deserializeBinaryBulkWithMultipleStreams(column, rows, settings, state, nullptr);
}

DB::Block NativeReader::prepareByFirstBlock()
{
    if (istr.eof())
        return {};
    const DataTypeFactory & data_type_factory = DataTypeFactory::instance();
    DB::Block result_block;

    size_t columns = 0;
    size_t rows = 0;
    readVarUInt(columns, istr);
    readVarUInt(rows, istr);

    if (columns > 1'000'000uz)
        throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many columns in Native format: {}", columns);
    if (rows > 1'000'000'000'000uz)
        throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Suspiciously many rows in Native format: {}", rows);

    if (columns == 0 && !header && rows != 0)
        throw Exception(ErrorCodes::INCORRECT_DATA, "Zero columns but {} rows in Native format.", rows);

    for (size_t i = 0; i < columns; ++i)
    {
        ColumnWithTypeAndName column;
        ColumnParseUtil column_parse_util;

        column.name = "col_" + std::to_string(i);

        /// Type
        String type_name;
        readBinary(type_name, istr);
        bool agg_opt_column = false;
        String real_type_name = type_name;
        if (type_name.ends_with(NativeWriter::AGG_STATE_SUFFIX))
        {
            agg_opt_column = true;
            real_type_name = type_name.substr(0, type_name.length() - NativeWriter::AGG_STATE_SUFFIX.length());
        }
        column.type = data_type_factory.get(real_type_name);
        bool is_agg_state_type = WhichDataType(column.type).isAggregateFunction();
        SerializationPtr serialization = column.type->getDefaultSerialization();

        column_parse_util.type = column.type;
        column_parse_util.name = column.name;
        column_parse_util.serializer = serialization;

        /// Data
        ColumnPtr read_column = column.type->createColumn(*serialization);

        if (rows)    /// If no rows, nothing to read.
        {
            if (is_agg_state_type && agg_opt_column)
            {
                const DataTypeAggregateFunction * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(column.type.get());
                auto aggregate_function = agg_type->getFunction();

                column_parse_util.aggregate_function = aggregate_function;
                column_parse_util.aggregate_state_size = aggregate_function->sizeOfData();
                column_parse_util.aggregate_state_align = aggregate_function->alignOfData();

                bool fixed = isFixedSizeAggregateFunction(aggregate_function);
                if (fixed)
                {
                    readFixedSizeAggregateData(istr, read_column, rows, column_parse_util);
                    column_parse_util.parse = readFixedSizeAggregateData;
                }
                else
                {
                    readVarSizeAggregateData(istr, read_column, rows, column_parse_util);
                    column_parse_util.parse = readVarSizeAggregateData;
                }
            }
            else
            {
                readNormalData(istr, read_column, rows, column_parse_util);
                column_parse_util.parse = readNormalData;
            }
        }
        column.column = std::move(read_column);
        result_block.insert(std::move(column));
        columns_parse_util.emplace_back(column_parse_util);
    }

    header = result_block.cloneEmpty();
    return result_block;
}

bool NativeReader::appendNextBlock(DB::Block & result_block)
{
    if (istr.eof())
        return false;

    size_t columns = 0;
    size_t rows = 0;
    readVarUInt(columns, istr);
    readVarUInt(rows, istr);
    for (size_t i = 0; i < columns; ++i)
    {
        // Not actually read type name.
        size_t type_name_len = 0;
        readVarUInt(type_name_len, istr);
        istr.ignore(type_name_len);

        if (!rows) [[unlikely]]
            continue;
        auto & column_parse_util = columns_parse_util[i];
        auto & column = result_block.getByPosition(i);
        column_parse_util.parse(istr, column.column, rows, column_parse_util);
    }

    return true;
}

}
