/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Common/assert_cast.h>
#include "Columns/ColumnNullable.h"


/// Error codes thrown by this file. They are only declared here (`extern`); the definitions live in another
/// translation unit.
namespace DB::ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NOT_FOUND_COLUMN_IN_BLOCK;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}

namespace local_engine
{

using namespace DB;

namespace
{

/** Extract element of tuple by constant index or name. The operation is essentially free.
  * Also the function looks through Arrays: you can get Array of tuple elements from Array of Tuples.
  * The difference between this function and tupleElement is that this function supports nullable tuples/arrays as input.
  */
class SparkFunctionTupleElement : public IFunction
{
public:
    static constexpr auto name = "sparkTupleElement";

    static FunctionPtr create(ContextPtr) { return std::make_shared<SparkFunctionTupleElement>(); }
    String getName() const override { return name; }
    bool isVariadic() const override { return true; }
    size_t getNumberOfArguments() const override { return 0; }
    bool useDefaultImplementationForConstants() const override { return true; }
    ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1}; }
    bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }

    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        const size_t number_of_arguments = arguments.size();

        if (number_of_arguments < 2 || number_of_arguments > 3)
            throw Exception(
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
                "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
                getName(),
                number_of_arguments);

        std::vector<bool> arrays_is_nullable;
        DataTypePtr input_type = arguments[0].type;
        while (const DataTypeArray * array = checkAndGetDataType<DataTypeArray>(removeNullable(input_type).get()))
        {
            arrays_is_nullable.push_back(input_type->isNullable());
            input_type = array->getNestedType();
        }

        const DataTypeTuple * tuple = checkAndGetDataType<DataTypeTuple>(removeNullable(input_type).get());
        if (!tuple)
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "First argument for function {} must be tuple or array of tuple. Actual {}",
                getName(),
                arguments[0].type->getName());

        std::optional<size_t> index = getElementIndex(arguments[1].column, *tuple, number_of_arguments);
        if (index.has_value())
        {
            DataTypePtr return_type = tuple->getElements()[index.value()];

            /// Tuple may be wrapped in Nullable
            if (input_type->isNullable())
                return_type = makeNullable(return_type);

            /// Array may be wrapped in Nullable
            for (auto it = arrays_is_nullable.rbegin(); it != arrays_is_nullable.rend(); ++it)
            {
                return_type = std::make_shared<DataTypeArray>(return_type);
                if (*it)
                    return_type = makeNullable(return_type);
            }

            // std::cout << "return_type:" << return_type->getName() << std::endl;

            return return_type;
        }
        else
            return arguments[2].type;
    }

    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const auto & input_arg = arguments[0];
        DataTypePtr input_type = input_arg.type;
        const IColumn * input_col = input_arg.column.get();

        bool input_arg_is_const = false;
        if (typeid_cast<const ColumnConst *>(input_col))
        {
            input_col = assert_cast<const ColumnConst *>(input_col)->getDataColumnPtr().get();
            input_arg_is_const = true;
        }

        Columns array_offsets;
        Columns null_maps;
        while (const DataTypeArray * array_type = checkAndGetDataType<DataTypeArray>(removeNullable(input_type).get()))
        {
            const ColumnNullable * nullable_array_col = input_type->isNullable() ? checkAndGetColumn<ColumnNullable>(input_col) : nullptr;
            const ColumnArray * array_col = nullable_array_col ? checkAndGetColumn<ColumnArray>(&nullable_array_col->getNestedColumn())
                                                               : checkAndGetColumn<ColumnArray>(input_col);

            array_offsets.push_back(array_col->getOffsetsPtr());
            null_maps.push_back(nullable_array_col ? nullable_array_col->getNullMapColumnPtr() : nullptr);
            input_type = array_type->getNestedType();
            input_col = &array_col->getData();
        }

        const DataTypeTuple * input_type_as_tuple = checkAndGetDataType<DataTypeTuple>(removeNullable(input_type).get());
        if (!input_type_as_tuple)
            throw Exception(
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                "First argument for function {} must be tuple or array of tuple. Actual {}",
                getName(),
                input_arg.type->getName());

        const ColumnNullable * input_col_as_nullable_tuple
            = input_type->isNullable() ? checkAndGetColumn<ColumnNullable>(input_col) : nullptr;
        const ColumnTuple * input_col_as_tuple = input_col_as_nullable_tuple
            ? checkAndGetColumn<ColumnTuple>(&input_col_as_nullable_tuple->getNestedColumn())
            : checkAndGetColumn<ColumnTuple>(input_col);

        std::optional<size_t> index = getElementIndex(arguments[1].column, *input_type_as_tuple, arguments.size());
        if (!index.has_value())
            return arguments[2].column;

        ColumnPtr res = input_col_as_tuple->getColumns()[index.value()];

        /// Wrap into Nullable if needed
        if (input_col_as_nullable_tuple)
        {
            auto res_type = input_type_as_tuple->getElements()[index.value()];
            ColumnPtr res_null_map = input_col_as_nullable_tuple->getNullMapColumnPtr();
            if (res_type->isNullable())
            {
                MutableColumnPtr mutable_res_null_map = IColumn::mutate(std::move(res_null_map));

                NullMap & res_null_map_data = assert_cast<ColumnUInt8 &>(*mutable_res_null_map).getData();
                const NullMap & src_null_map = assert_cast<const ColumnNullable &>(*res).getNullMapData();

                for (size_t i = 0, size = res_null_map_data.size(); i < size; ++i)
                    res_null_map_data[i] |= src_null_map[i];

                res_null_map = std::move(mutable_res_null_map);
                res = ColumnNullable::create(assert_cast<const ColumnNullable &>(*res).getNestedColumnPtr(), res_null_map);
            }
            else
                res = ColumnNullable::create(res, res_null_map);
        }

        /// Wrap into Arrays
        for (ssize_t i = array_offsets.size() - 1; i >= 0; --i)
        {
            res = ColumnArray::create(res, array_offsets[i]);

            /// Wrap into Nullable if needed
            if (null_maps[i])
                res = ColumnNullable::create(res, null_maps[i]);
        }

        if (input_arg_is_const)
            res = ColumnConst::create(res, input_rows_count);

        return res;
    }

private:
    std::optional<size_t> getElementIndex(const ColumnPtr & index_column, const DataTypeTuple & tuple, size_t argument_size) const
    {
        if (checkAndGetColumnConst<ColumnUInt8>(index_column.get()) || checkAndGetColumnConst<ColumnUInt16>(index_column.get())
            || checkAndGetColumnConst<ColumnUInt32>(index_column.get()) || checkAndGetColumnConst<ColumnUInt64>(index_column.get())
            || checkAndGetColumnConst<ColumnInt8>(index_column.get()) || checkAndGetColumnConst<ColumnInt16>(index_column.get())
            || checkAndGetColumnConst<ColumnInt32>(index_column.get()) || checkAndGetColumnConst<ColumnInt64>(index_column.get()))
        {
            const ssize_t index = index_column->getInt(0);
            if (index > 0 && index <= static_cast<ssize_t>(tuple.getElements().size()))
                return {index - 1};
            else
            {
                if (argument_size == 2)
                    throw Exception(ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK, "Tuple {} doesn't have element with index '{}'", tuple.getName(), index);
                return std::nullopt;
            }
        }
        else if (const auto * name_col = checkAndGetColumnConst<ColumnString>(index_column.get()))
        {
            std::optional<size_t> index = tuple.tryGetPositionByName(name_col->getValue<String>());

            if (index.has_value())
                return index;
            else
            {
                if (argument_size == 2)
                    throw Exception(
                        ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK, "Tuple doesn't have element with name '{}'", name_col->getValue<String>());
                return std::nullopt;
            }
        }
        else
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Second argument to {} must be a constant UInt or String", getName());
    }
};
}

/// Registers SparkFunctionTupleElement with the function factory.
/// `factory` is a parameter supplied by the REGISTER_FUNCTION macro, which is defined outside this file.
REGISTER_FUNCTION(SparkTupleElement)
{
    factory.registerFunction<SparkFunctionTupleElement>();
}

}
