################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

from abc import ABC, abstractmethod
from typing import TypeVar, Generic, List, Dict, Any, Optional, Tuple, Union
from pyflink.ml.linalg import Vector
from pyflink.ml.common.window import Windows

import jsonpickle

T = TypeVar('T')
V = TypeVar('V')


class WithParams(Generic[T], ABC):
    """
    Interface for objects that carry parameters.

    Parameter values live in the mapping returned by :meth:`get_param_map`; the other methods
    of this interface read from and write to that mapping.
    """

    def set(self, param: 'Param[V]', value: V) -> 'WithParams[T]':
        """
        Stores a value for the parameter after checking its type and validity.

        :param param: The parameter.
        :param value: The parameter value.
        :return: The WithParams instance itself.
        :raises TypeError: If the type of the value is incompatible with the parameter's type.
        :raises ValueError: If the parameter's validator rejects the value.
        """
        if not self._is_compatible_type(param, value):
            raise TypeError(
                f"Parameter {param.name}'s type {param.type} is incompatible with the type of "
                f"{type(value)}")

        accepted = param.validator.validate(value)
        # A rejected None gets its own, more specific error message.
        if not accepted and value is None:
            raise ValueError(f'Parameter {param.name}\'s value should not be None.')
        if not accepted:
            raise ValueError(f'Parameter {param.name} is given an invalid value {value}.')

        self.get_param_map()[param] = value
        return self

    def get_param(self, name: str) -> Optional['Param[V]']:
        """
        Looks up a parameter definition by name among the keys of the parameter map.

        :param name: The parameter name.
        :return: The first parameter with that name, or None if there is no such parameter.
        """
        return next((candidate for candidate in self.get_param_map()
                     if candidate.name == name), None)

    def get(self, param: 'Param[V]') -> V:
        """
        Reads the value currently stored for the parameter.

        :param param: The parameter.
        :return: The parameter value; None if no value is stored and the parameter's validator
            accepts None.
        :raises ValueError: If the stored value is None (or missing) and the parameter's
            validator does not accept None.
        """
        stored = self.get_param_map().get(param)
        if stored is not None:
            return stored
        if param.validator.validate(None):
            return stored
        raise ValueError(f'Parameter {param.name}\'s value should not be None')

    @abstractmethod
    def get_param_map(self) -> Dict['Param[Any]', Any]:
        """
        Returns the mapping from parameter definitions to parameter values.
        """

    @staticmethod
    def _is_compatible_type(param: 'Param[V]', value: V) -> bool:
        """
        Tells whether the value may be assigned to the parameter as far as its type goes.

        None is always compatible. A value whose exact type differs from the parameter's type
        is compatible only if its type is a subclass of the parameter's type. Only values that
        are ``list`` instances have their elements inspected; every other container is
        accepted on the container type alone.
        """
        if value is None:
            return True

        actual_type = type(value)
        if param.type != actual_type:
            return issubclass(actual_type, param.type)

        if isinstance(value, list):
            # Each element must spell out exactly the parameter's declared type name.
            return not any(param.type_name != f'list[{type(element).__name__}]'
                           for element in value)
        return True


class ParamValidator(Generic[T], ABC):
    """
    Abstract interface for checking whether a value is acceptable for a parameter.
    """

    @abstractmethod
    def validate(self, value: T) -> bool:
        """
        Decides whether the given parameter value is valid.

        Subclasses must override this method.

        :param value: The parameter value.
        :return: A boolean indicating whether the parameter value is valid.
        """


class ParamValidators(object):
    """
    Factory methods that build commonly used :class:`ParamValidator` instances.

    Every validator produced here, except the one from :meth:`always_true`, rejects None.
    """

    @staticmethod
    def always_true() -> ParamValidator[T]:
        """
        Builds a validator that accepts every value, including None.
        """

        class AlwaysTrue(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                return True

        return AlwaysTrue()

    @staticmethod
    def gt(lower_bound: float) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is greater than lower_bound.
        """

        class GT(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                return value > lower_bound  # type: ignore

        return GT()

    @staticmethod
    def gt_eq(lower_bound: float) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is greater than or equal to
        lower_bound.
        """

        class GtEq(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                return value >= lower_bound  # type: ignore

        return GtEq()

    @staticmethod
    def lt(upper_bound: float) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is less than upper_bound.
        """

        class LT(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                return value < upper_bound  # type: ignore

        return LT()

    @staticmethod
    def lt_eq(upper_bound: float) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is less than or equal to
        upper_bound.
        """

        class LtEq(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                return value <= upper_bound  # type: ignore

        return LtEq()

    @staticmethod
    def in_range(lower_bound: float, upper_bound: float, lower_inclusive: bool = True,
                 upper_inclusive: bool = True) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value lies between lower_bound and
        upper_bound. Each bound itself is accepted only when the matching ``*_inclusive``
        flag is truthy.
        """

        class InRange(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                # First the closed interval, then each bound is excluded on request.
                within = lower_bound <= value <= upper_bound  # type: ignore
                if not within:
                    return within
                lower_ok = lower_inclusive or value != lower_bound
                if not lower_ok:
                    return lower_ok
                return upper_inclusive or value != upper_bound

        return InRange()

    @staticmethod
    def in_array(allowed: Union[Tuple[T], List[T]]) -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is one of the allowed values.
        """

        class InArray(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                if value is None:
                    return False
                return value in allowed

        return InArray()

    @staticmethod
    def not_null() -> ParamValidator[T]:
        """
        Builds a validator that checks if the parameter value is not None.
        """

        class NotNull(ParamValidator[T]):
            def validate(self, value: T) -> bool:
                return not (value is None)

        return NotNull()

    @staticmethod
    def non_empty_array() -> ParamValidator[Tuple[T]]:
        """
        Builds a validator that checks if the parameter value is an array with at least one
        element.
        """

        class NonEmptyArray(ParamValidator[Tuple[T]]):
            def validate(self, value: Tuple[T]) -> bool:
                if value is None:
                    return False
                return len(value) > 0

        return NonEmptyArray()

    @staticmethod
    def is_sub_set(allowed: Union[Tuple[T], List[T]]) -> ParamValidator[List[T]]:
        """
        Builds a validator that checks if every element of the array-typed parameter value is
        one of the allowed values. An empty array is accepted.
        """

        class IsSubSet(ParamValidator[List[T]]):
            def validate(self, value: List[T]) -> bool:
                return value is not None and all(element in allowed for element in value)

        return IsSubSet()


class Param(Generic[T]):
    """
    Describes one parameter: its name, value type, description, default value and validator.

    Two definitions compare equal when their names are equal, and the hash is derived from the
    name alone.
    """

    def __init__(self, name: str, type_type: type, type_name: str, description: str,
                 default_value: T, validator: ParamValidator[T]):
        """
        :param name: The parameter name.
        :param type_type: The Python type of the parameter value.
        :param type_name: The textual name of the value type.
        :param description: The parameter description.
        :param default_value: The default value; may be None.
        :param validator: The validator applied to values of this parameter.
        :raises ValueError: If a default value other than None is rejected by the validator.
        """
        self.name, self.type, self.type_name = name, type_type, type_name
        self.description = description
        self.default_value = default_value
        self.validator = validator

        # A None default is never run through the validator.
        if default_value is None:
            return
        if not validator.validate(default_value):
            raise ValueError(f"Parameter {name} is given an invalid value {default_value}")

    def json_encode(self, value: T) -> str:
        """
        Serializes the given object into a json-formatted string using jsonpickle.

        :param value: An object of class type T.
        :return: A json-formatted string.
        """
        encoded = jsonpickle.encode(value, keys=True)
        return str(encoded)

    def json_decode(self, json: str) -> T:
        """
        Restores an object of class type T from a json-formatted string using jsonpickle.

        :param json: A json-formatted string.
        :return: An object of class type T.
        """
        # NOTE(review): jsonpickle can rebuild arbitrary Python objects while decoding;
        # presumably the input always comes from a trusted source - confirm with callers.
        decoded = jsonpickle.decode(json, keys=True)
        return decoded

    def __eq__(self, other):
        if not isinstance(other, Param):
            return False
        return self.name == other.name

    def __hash__(self):
        # Kept in line with __eq__: only the name takes part.
        return hash(self.name)

    def __str__(self):
        return self.name


class BooleanParam(Param[bool]):
    """
    Parameter definition whose value type is ``bool``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[bool],
                 validator: ParamValidator[bool] = ParamValidators.always_true()):
        # Fixes the value type to ``bool`` and the type name to "bool".
        super().__init__(name=name, type_type=bool, type_name="bool", description=description,
                         default_value=default_value, validator=validator)


class IntParam(Param[int]):
    """
    Parameter definition whose value type is ``int``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[int],
                 validator: ParamValidator[int] = ParamValidators.always_true()):
        # Fixes the value type to ``int`` and the type name to "int".
        super().__init__(name=name, type_type=int, type_name="int", description=description,
                         default_value=default_value, validator=validator)


class FloatParam(Param[float]):
    """
    Parameter definition whose value type is ``float``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[float],
                 validator: ParamValidator[float] = ParamValidators.always_true()):
        # Fixes the value type to ``float`` and the type name to "float".
        super().__init__(name=name, type_type=float, type_name="float", description=description,
                         default_value=default_value, validator=validator)


class StringParam(Param[str]):
    """
    Parameter definition whose value type is ``str``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[str],
                 validator: ParamValidator[str] = ParamValidators.always_true()):
        # Fixes the value type to ``str`` and the type name to "str".
        super().__init__(name=name, type_type=str, type_name="str", description=description,
                         default_value=default_value, validator=validator)


class IntArrayParam(Param[Tuple[int, ...]]):
    """
    Parameter definition whose value is an int array, held as a ``tuple``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[Tuple[int, ...]],
                 validator: ParamValidator[Tuple[int, ...]] = ParamValidators.always_true()):
        # Fixes the value type to ``tuple`` and the type name to "Tuple[int]".
        super().__init__(name=name, type_type=tuple, type_name="Tuple[int]",
                         description=description, default_value=default_value,
                         validator=validator)


class FloatArrayParam(Param[Tuple[float, ...]]):
    """
    Parameter definition whose value is a float array, held as a ``tuple``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[Tuple[float, ...]],
                 validator: ParamValidator[Tuple[float, ...]] = ParamValidators.always_true()):
        # Fixes the value type to ``tuple`` and the type name to "Tuple[float]".
        super().__init__(name=name, type_type=tuple, type_name="Tuple[float]",
                         description=description, default_value=default_value,
                         validator=validator)


class FloatArrayArrayParam(Param[Tuple[Tuple[float, ...]]]):
    """
    Parameter definition whose value is an array of float arrays, held as a ``tuple``.
    """

    def __init__(self, name: str,
                 description: str,
                 default_value: Optional[Tuple[Tuple[float, ...]]],
                 validator: ParamValidator[
                     Tuple[Tuple[float, ...]]] = ParamValidators.always_true()):
        # Fixes the value type to ``tuple`` and the type name to "Tuple[Tuple[float]]".
        super().__init__(name=name, type_type=tuple, type_name="Tuple[Tuple[float]]",
                         description=description, default_value=default_value,
                         validator=validator)


class StringArrayParam(Param[Tuple[str, ...]]):
    """
    Parameter definition whose value is a string array, held as a ``tuple``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[Tuple[str, ...]],
                 validator: ParamValidator[Tuple[str, ...]] = ParamValidators.always_true()):
        # Fixes the value type to ``tuple`` and the type name to "Tuple[str]".
        super().__init__(name=name, type_type=tuple, type_name="Tuple[str]",
                         description=description, default_value=default_value,
                         validator=validator)


class VectorParam(Param[Vector]):
    """
    Parameter definition whose value type is ``Vector``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[Vector],
                 validator: ParamValidator[Vector] = ParamValidators.always_true()):
        # NOTE(review): the type name passed on is "str" although the value type is Vector -
        # confirm this is intended.
        super().__init__(name=name, type_type=Vector, type_name="str", description=description,
                         default_value=default_value, validator=validator)


class WindowsParam(Param[Windows]):
    """
    Parameter definition whose value type is ``Windows``.
    """

    def __init__(self, name: str, description: str, default_value: Optional[Windows],
                 validator: ParamValidator[Windows] = ParamValidators.always_true()):
        # NOTE(review): the type name passed on is "str" although the value type is Windows -
        # confirm this is intended.
        super().__init__(name=name, type_type=Windows, type_name="str", description=description,
                         default_value=default_value, validator=validator)
