################################################################################
#  Licensed to the Apache Software Foundation (ASF) under one
#  or more contributor license agreements.  See the NOTICE file
#  distributed with this work for additional information
#  regarding copyright ownership.  The ASF licenses this file
#  to you under the Apache License, Version 2.0 (the
#  "License"); you may not use this file except in compliance
#  with the License.  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
################################################################################
from abc import ABC
from typing import Tuple

from pyflink.ml.param import WithParams, Param, ParamValidators, StringParam, IntParam, \
    StringArrayParam, FloatParam, WindowsParam, BooleanParam
from pyflink.ml.common.window import Windows, GlobalWindows


class HasDistanceMeasure(WithParams, ABC):
    """
    Base class that declares the shared ``distance_measure`` param and its accessors.
    """
    # Default is "euclidean"; values are checked against the three listed options.
    DISTANCE_MEASURE: Param[str] = StringParam(
        "distance_measure",
        "Distance measure. Supported options: 'euclidean', 'manhattan' and 'cosine'.",
        "euclidean",
        ParamValidators.in_array(['euclidean', 'manhattan', 'cosine'])
    )

    def set_distance_measure(self, distance_measure: str):
        """Assign the ``DISTANCE_MEASURE`` param; returns the result of ``self.set``."""
        param = self.DISTANCE_MEASURE
        return self.set(param, distance_measure)

    def get_distance_measure(self) -> str:
        """Fetch the value held for the ``DISTANCE_MEASURE`` param."""
        param = self.DISTANCE_MEASURE
        return self.get(param)

    @property
    def distance_measure(self) -> str:
        """Read-only attribute form of :meth:`get_distance_measure`."""
        return self.get_distance_measure()


class HasFeaturesCol(WithParams, ABC):
    """
    Base class for the shared features_col param.

    `HasFeaturesCol` is typically used for `Stage`s that implement `HasLabelCol`. It is preferred
    to use `HasInputCol` for other cases.
    """
    # Defaults to the column name "features".
    FEATURES_COL: Param[str] = StringParam(
        "features_col",
        "Features column name.",
        "features",
        ParamValidators.not_null())

    def set_features_col(self, col: str):
        """
        Set the features column name.

        :param col: Name of the features column.
        :return: The result of ``self.set``.
        """
        return self.set(self.FEATURES_COL, col)

    def get_features_col(self) -> str:
        """Return the value of the ``FEATURES_COL`` param."""
        return self.get(self.FEATURES_COL)

    @property
    def features_col(self) -> str:
        """Read-only attribute form of :meth:`get_features_col`."""
        return self.get_features_col()


class HasGlobalBatchSize(WithParams, ABC):
    """
    Base class that declares the shared ``global_batch_size`` param and its accessors.
    """
    # Default is 32; values are validated with ``ParamValidators.gt(0)``.
    GLOBAL_BATCH_SIZE: Param[int] = IntParam(
        "global_batch_size",
        "Global batch size of training algorithms.",
        32,
        ParamValidators.gt(0)
    )

    def set_global_batch_size(self, global_batch_size: int):
        """Assign the ``GLOBAL_BATCH_SIZE`` param; returns the result of ``self.set``."""
        param = self.GLOBAL_BATCH_SIZE
        return self.set(param, global_batch_size)

    def get_global_batch_size(self) -> int:
        """Fetch the value held for the ``GLOBAL_BATCH_SIZE`` param."""
        param = self.GLOBAL_BATCH_SIZE
        return self.get(param)

    @property
    def global_batch_size(self) -> int:
        """Read-only attribute form of :meth:`get_global_batch_size`."""
        return self.get_global_batch_size()


class HasHandleInvalid(WithParams, ABC):
    """
    Base class that declares the shared ``handle_invalid`` param and its accessors.

    The supported options, and what each one does with invalid entries, are:

    <ul>
        <li>error: raise an exception.
        <li>skip: filter out rows with bad values.
    </ul>
    """
    # Default is "error"; only the two options listed above pass validation.
    HANDLE_INVALID: Param[str] = StringParam(
        "handle_invalid",
        "Strategy to handle invalid entries.",
        "error",
        ParamValidators.in_array(['error', 'skip'])
    )

    def set_handle_invalid(self, value: str):
        """Assign the ``HANDLE_INVALID`` param; returns the result of ``self.set``."""
        param = self.HANDLE_INVALID
        return self.set(param, value)

    def get_handle_invalid(self) -> str:
        """Fetch the value held for the ``HANDLE_INVALID`` param."""
        param = self.HANDLE_INVALID
        return self.get(param)

    @property
    def handle_invalid(self) -> str:
        """Read-only attribute form of :meth:`get_handle_invalid`."""
        return self.get_handle_invalid()


class HasInputCol(WithParams, ABC):
    """
    Base class that declares the shared ``input_col`` param and its accessors.
    """
    # Defaults to the column name "input".
    INPUT_COL: Param[str] = StringParam(
        "input_col",
        "Input column name.",
        "input",
        ParamValidators.not_null()
    )

    def set_input_col(self, col: str):
        """Assign the ``INPUT_COL`` param; returns the result of ``self.set``."""
        param = self.INPUT_COL
        return self.set(param, col)

    def get_input_col(self) -> str:
        """Fetch the value held for the ``INPUT_COL`` param."""
        param = self.INPUT_COL
        return self.get(param)

    @property
    def input_col(self) -> str:
        """Read-only attribute form of :meth:`get_input_col`."""
        return self.get_input_col()


class HasInputCols(WithParams, ABC):
    """
    Base class that declares the shared ``input_cols`` param and its accessors.
    """
    # No default (None); values are validated with ``ParamValidators.non_empty_array()``.
    INPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
        "input_cols",
        "Input column names.",
        None,
        ParamValidators.non_empty_array()
    )

    def set_input_cols(self, *cols: str):
        """
        Assign the ``INPUT_COLS`` param from positional column names.

        The names are forwarded to ``self.set`` as the tuple collected by ``*cols``.
        Returns the result of ``self.set``.
        """
        param = self.INPUT_COLS
        return self.set(param, cols)

    def get_input_cols(self) -> Tuple[str, ...]:
        """Fetch the value held for the ``INPUT_COLS`` param."""
        param = self.INPUT_COLS
        return self.get(param)

    @property
    def input_cols(self) -> Tuple[str, ...]:
        """Read-only attribute form of :meth:`get_input_cols`."""
        return self.get_input_cols()


class HasCategoricalCols(WithParams, ABC):
    """
    Base class that declares the shared ``categorical_cols`` param and its accessors.
    """
    # NOTE(review): the default is an empty list although the param is annotated as a
    # tuple of strings -- confirm this is what StringArrayParam expects.
    CATEGORICAL_COLS: Param[Tuple[str, ...]] = StringArrayParam(
        "categorical_cols",
        "Categorical column names.",
        [],
        ParamValidators.not_null()
    )

    def set_categorical_cols(self, *cols: str):
        """
        Assign the ``CATEGORICAL_COLS`` param from positional column names.

        The names are forwarded to ``self.set`` as the tuple collected by ``*cols``.
        Returns the result of ``self.set``.
        """
        param = self.CATEGORICAL_COLS
        return self.set(param, cols)

    def get_categorical_cols(self) -> Tuple[str, ...]:
        """Fetch the value held for the ``CATEGORICAL_COLS`` param."""
        param = self.CATEGORICAL_COLS
        return self.get(param)

    @property
    def categorical_cols(self) -> Tuple[str, ...]:
        """Read-only attribute form of :meth:`get_categorical_cols`."""
        return self.get_categorical_cols()


class HasNumFeatures(WithParams, ABC):
    """
    Base class that declares the shared ``num_features`` param and its accessors.
    """
    # Default is 262144 (2 ** 18); values are validated with ``ParamValidators.gt(0)``.
    NUM_FEATURES: Param[int] = IntParam(
        "num_features",
        "Number of features.",
        262144,
        ParamValidators.gt(0)
    )

    def set_num_features(self, num_features: int):
        """Assign the ``NUM_FEATURES`` param; returns the result of ``self.set``."""
        param = self.NUM_FEATURES
        return self.set(param, num_features)

    def get_num_features(self) -> int:
        """Fetch the value held for the ``NUM_FEATURES`` param."""
        param = self.NUM_FEATURES
        return self.get(param)

    @property
    def num_features(self) -> int:
        """Read-only attribute form of :meth:`get_num_features`."""
        return self.get_num_features()


class HasLabelCol(WithParams, ABC):
    """
    Base class that declares the shared ``label_col`` param and its accessors.
    """
    # Defaults to the column name "label".
    LABEL_COL: Param[str] = StringParam(
        "label_col",
        "Label column name.",
        "label",
        ParamValidators.not_null()
    )

    def set_label_col(self, col: str):
        """Assign the ``LABEL_COL`` param; returns the result of ``self.set``."""
        param = self.LABEL_COL
        return self.set(param, col)

    def get_label_col(self) -> str:
        """Fetch the value held for the ``LABEL_COL`` param."""
        param = self.LABEL_COL
        return self.get(param)

    @property
    def label_col(self) -> str:
        """Read-only attribute form of :meth:`get_label_col`."""
        return self.get_label_col()


class HasLearningRate(WithParams, ABC):
    """
    Base class that declares the shared ``learning_rate`` param and its accessors.
    """

    # Default is 0.1; values are validated with ``ParamValidators.gt(0)``.
    LEARNING_RATE: Param[float] = FloatParam(
        "learning_rate",
        "Learning rate of optimization method.",
        0.1,
        ParamValidators.gt(0)
    )

    def set_learning_rate(self, learning_rate: float):
        """Assign the ``LEARNING_RATE`` param; returns the result of ``self.set``."""
        param = self.LEARNING_RATE
        return self.set(param, learning_rate)

    def get_learning_rate(self) -> float:
        """Fetch the value held for the ``LEARNING_RATE`` param."""
        param = self.LEARNING_RATE
        return self.get(param)

    @property
    def learning_rate(self) -> float:
        """Read-only attribute form of :meth:`get_learning_rate`."""
        return self.get_learning_rate()


class HasMaxIter(WithParams, ABC):
    """
    Base class that declares the shared ``max_iter`` param and its accessors.
    """
    # Default is 20; values are validated with ``ParamValidators.gt(0)``.
    MAX_ITER: Param[int] = IntParam(
        "max_iter",
        "Maximum number of iterations.",
        20,
        ParamValidators.gt(0)
    )

    def set_max_iter(self, max_iter: int):
        """Assign the ``MAX_ITER`` param; returns the result of ``self.set``."""
        param = self.MAX_ITER
        return self.set(param, max_iter)

    def get_max_iter(self) -> int:
        """Fetch the value held for the ``MAX_ITER`` param."""
        param = self.MAX_ITER
        return self.get(param)

    @property
    def max_iter(self) -> int:
        """Read-only attribute form of :meth:`get_max_iter`."""
        return self.get_max_iter()


class HasMultiClass(WithParams, ABC):
    """
    Base class that declares the shared ``multi_class`` param and its accessors.

    Supported options:
        <li>auto: selects the classification type based on the number of classes:
            If the number of unique label values from the input data is one or two,
            set to "binomial". Otherwise, set to "multinomial".
        <li>binomial: binary logistic regression.
        <li>multinomial: multinomial logistic regression.
    """
    # Default is 'auto'; only the three options listed above pass validation.
    MULTI_CLASS: Param[str] = StringParam(
        "multi_class",
        "Classification type. Supported options: 'auto', 'binomial' and 'multinomial'.",
        'auto',
        ParamValidators.in_array(['auto', 'binomial', 'multinomial'])
    )

    def set_multi_class(self, class_type: str):
        """Assign the ``MULTI_CLASS`` param; returns the result of ``self.set``."""
        param = self.MULTI_CLASS
        return self.set(param, class_type)

    def get_multi_class(self) -> str:
        """Fetch the value held for the ``MULTI_CLASS`` param."""
        param = self.MULTI_CLASS
        return self.get(param)

    @property
    def multi_class(self) -> str:
        """Read-only attribute form of :meth:`get_multi_class`."""
        return self.get_multi_class()


class HasOutputCol(WithParams, ABC):
    """
    Base class that declares the shared ``output_col`` param and its accessors.
    """
    # Defaults to the column name "output".
    OUTPUT_COL: Param[str] = StringParam(
        "output_col",
        "Output column name.",
        "output",
        ParamValidators.not_null()
    )

    def set_output_col(self, col: str):
        """Assign the ``OUTPUT_COL`` param; returns the result of ``self.set``."""
        param = self.OUTPUT_COL
        return self.set(param, col)

    def get_output_col(self) -> str:
        """Fetch the value held for the ``OUTPUT_COL`` param."""
        param = self.OUTPUT_COL
        return self.get(param)

    @property
    def output_col(self) -> str:
        """Read-only attribute form of :meth:`get_output_col`."""
        return self.get_output_col()


class HasOutputCols(WithParams, ABC):
    """
    Base class that declares the shared ``output_cols`` param and its accessors.
    """
    # No default (None); values are validated with ``ParamValidators.non_empty_array()``.
    OUTPUT_COLS: Param[Tuple[str, ...]] = StringArrayParam(
        "output_cols",
        "Output column names.",
        None,
        ParamValidators.non_empty_array()
    )

    def set_output_cols(self, *cols: str):
        """
        Assign the ``OUTPUT_COLS`` param from positional column names.

        The names are forwarded to ``self.set`` as the tuple collected by ``*cols``.
        Returns the result of ``self.set``.
        """
        param = self.OUTPUT_COLS
        return self.set(param, cols)

    def get_output_cols(self) -> Tuple[str, ...]:
        """Fetch the value held for the ``OUTPUT_COLS`` param."""
        param = self.OUTPUT_COLS
        return self.get(param)

    @property
    def output_cols(self) -> Tuple[str, ...]:
        """Read-only attribute form of :meth:`get_output_cols`."""
        return self.get_output_cols()


class HasPredictionCol(WithParams, ABC):
    """
    Base class that declares the shared ``prediction_col`` param and its accessors.
    """
    # Defaults to the column name "prediction".
    PREDICTION_COL: Param[str] = StringParam(
        "prediction_col",
        "Prediction column name.",
        "prediction",
        ParamValidators.not_null()
    )

    def set_prediction_col(self, col: str):
        """Assign the ``PREDICTION_COL`` param; returns the result of ``self.set``."""
        param = self.PREDICTION_COL
        return self.set(param, col)

    def get_prediction_col(self) -> str:
        """Fetch the value held for the ``PREDICTION_COL`` param."""
        param = self.PREDICTION_COL
        return self.get(param)

    @property
    def prediction_col(self) -> str:
        """Read-only attribute form of :meth:`get_prediction_col`."""
        return self.get_prediction_col()


class HasRawPredictionCol(WithParams, ABC):
    """
    Base class for the shared raw prediction column param.
    """
    # Defaults to the column name "raw_prediction"; no validator is passed explicitly.
    RAW_PREDICTION_COL: Param[str] = StringParam(
        "raw_prediction_col",
        "Raw prediction column name.",
        "raw_prediction")

    def set_raw_prediction_col(self, col: str):
        """
        Set the raw prediction column name.

        :param col: Name of the raw prediction column.
        :return: The result of ``self.set``.
        """
        return self.set(self.RAW_PREDICTION_COL, col)

    def get_raw_prediction_col(self) -> str:
        """Return the value of the ``RAW_PREDICTION_COL`` param."""
        return self.get(self.RAW_PREDICTION_COL)

    @property
    def raw_prediction_col(self) -> str:
        """Read-only attribute form of :meth:`get_raw_prediction_col`."""
        return self.get_raw_prediction_col()


class HasReg(WithParams, ABC):
    """
    Base class that declares the shared regularization param (``reg``) and its accessors.
    """
    # Default is 0.; values are validated with ``ParamValidators.gt_eq(0.)``.
    REG: Param[float] = FloatParam(
        "reg",
        "Regularization parameter.",
        0.,
        ParamValidators.gt_eq(0.)
    )

    def set_reg(self, value: float):
        """Assign the ``REG`` param; returns the result of ``self.set``."""
        param = self.REG
        return self.set(param, value)

    def get_reg(self) -> float:
        """Fetch the value held for the ``REG`` param."""
        param = self.REG
        return self.get(param)

    @property
    def reg(self) -> float:
        """Read-only attribute form of :meth:`get_reg`."""
        return self.get_reg()


class HasSeed(WithParams, ABC):
    """
    Base class for the shared seed param.
    """
    # No default (None) and no explicit validator.
    SEED: Param[int] = IntParam(
        "seed",
        "The random seed.",
        None)

    def set_seed(self, seed: int):
        """
        Set the random seed.

        Passing ``None`` leaves the param untouched. In both cases an object that
        supports further setter calls is returned, so chained calls such as
        ``stage.set_seed(seed).set_...()`` keep working.

        :param seed: The seed to store, or ``None`` to keep the current value.
        :return: ``self`` when ``seed`` is ``None``, otherwise the result of ``self.set``.
        """
        if seed is None:
            # Previously this branch returned ``hash(<class name>)``: an int that was
            # never stored as the seed, broke call chaining and, because string hashing
            # is salted per process, was not even reproducible between runs.
            return self
        return self.set(self.SEED, seed)

    def get_seed(self) -> int:
        """Return the value of the ``SEED`` param (``None`` by default when never set)."""
        return self.get(self.SEED)

    @property
    def seed(self) -> int:
        """Read-only attribute form of :meth:`get_seed`."""
        return self.get_seed()


class HasTol(WithParams, ABC):
    """
    Base class that declares the shared tolerance param (``tol``) and its accessors.
    """
    # Default is 1e-6; values are validated with ``ParamValidators.gt_eq(0)``.
    TOL: Param[float] = FloatParam(
        "tol",
        "Convergence tolerance for iterative algorithms.",
        1e-6,
        ParamValidators.gt_eq(0)
    )

    def set_tol(self, value: float):
        """Assign the ``TOL`` param; returns the result of ``self.set``."""
        param = self.TOL
        return self.set(param, value)

    def get_tol(self) -> float:
        """Fetch the value held for the ``TOL`` param."""
        param = self.TOL
        return self.get(param)

    @property
    def tol(self) -> float:
        """Read-only attribute form of :meth:`get_tol`."""
        return self.get_tol()


class HasWeightCol(WithParams, ABC):
    """
    Base class for the shared weight column param. If this is not set, we treat all instance weights
    as 1.0.
    """
    # No default column (None) and no explicit validator.
    WEIGHT_COL: Param[str] = StringParam(
        "weight_col",
        "Weight column name.",
        None)

    def set_weight_col(self, col: str):
        """
        Set the weight column name.

        :param col: Name of the weight column.
        :return: The result of ``self.set``.
        """
        return self.set(self.WEIGHT_COL, col)

    def get_weight_col(self) -> str:
        """Return the value of the ``WEIGHT_COL`` param (``None`` by default when never set)."""
        return self.get(self.WEIGHT_COL)

    @property
    def weight_col(self) -> str:
        """Read-only attribute form of :meth:`get_weight_col`."""
        return self.get_weight_col()


class HasBatchStrategy(WithParams, ABC):
    """
    Base class for the shared batch strategy param.
    """
    # "count" is both the default and the only value accepted by the validator; this
    # class defines no setter.
    BATCH_STRATEGY: Param[str] = StringParam(
        "batch_strategy",
        "Strategy to create mini batch from online train data.",
        "count",
        ParamValidators.in_array(["count"]))

    def get_batch_strategy(self) -> str:
        """Return the value of the ``BATCH_STRATEGY`` param."""
        return self.get(self.BATCH_STRATEGY)

    @property
    def batch_strategy(self) -> str:
        """Read-only attribute form of :meth:`get_batch_strategy`."""
        return self.get_batch_strategy()


class HasDecayFactor(WithParams, ABC):
    """
    Base class for the shared decay factor param.
    """
    # Default is 0.; values are validated with ``ParamValidators.in_range(0, 1)``.
    DECAY_FACTOR: Param[float] = FloatParam(
        "decay_factor",
        "The forgetfulness of the previous centroids.",
        0.,
        ParamValidators.in_range(0, 1))

    def set_decay_factor(self, value: float):
        """
        Set the decay factor.

        :param value: The new decay factor.
        :return: The result of ``self.set``.
        """
        return self.set(self.DECAY_FACTOR, value)

    def get_decay_factor(self) -> float:
        """Return the value of the ``DECAY_FACTOR`` param."""
        return self.get(self.DECAY_FACTOR)

    @property
    def decay_factor(self) -> float:
        """Read-only attribute form of :meth:`get_decay_factor`."""
        # Delegate to the getter so the property and the getter can never diverge.
        return self.get_decay_factor()


class HasElasticNet(WithParams, ABC):
    """
    Base class for the shared elastic net param.
    """
    # Default is 0.; values are validated with ``ParamValidators.in_range(0.0, 1.0)``.
    ELASTIC_NET: Param[float] = FloatParam(
        "elastic_net",
        "ElasticNet parameter.",
        0.,
        ParamValidators.in_range(0.0, 1.0))

    def set_elastic_net(self, value: float):
        """
        Set the elastic net parameter.

        :param value: The new elastic net parameter.
        :return: The result of ``self.set``.
        """
        return self.set(self.ELASTIC_NET, value)

    def get_elastic_net(self) -> float:
        """Return the value of the ``ELASTIC_NET`` param."""
        return self.get(self.ELASTIC_NET)

    @property
    def elastic_net(self) -> float:
        """Read-only attribute form of :meth:`get_elastic_net`."""
        # Delegate to the getter so the property and the getter can never diverge.
        return self.get_elastic_net()


class HasWindows(WithParams, ABC):
    """
    Base class for the shared windows param.
    """
    # Defaults to a ``GlobalWindows`` instance created once, when the class is defined.
    WINDOWS: Param[Windows] = WindowsParam(
        "windows",
        "Windowing strategy that determines how to create mini-batches from input data.",
        GlobalWindows(),
        ParamValidators.not_null())

    def set_windows(self, value: Windows):
        """
        Set the windowing strategy.

        :param value: The windowing strategy to use.
        :return: This instance, so that setter calls can be chained.
        """
        self.set(self.WINDOWS, value)
        return self

    def get_windows(self) -> Windows:
        """Return the value of the ``WINDOWS`` param."""
        return self.get(self.WINDOWS)

    @property
    def windows(self) -> Windows:
        """Read-only attribute form of :meth:`get_windows`."""
        # Delegate to the getter so the property and the getter can never diverge.
        return self.get_windows()


class HasRelativeError(WithParams, ABC):
    """
    Interface for shared param relativeError.
    """
    # Default is 0.001; values are validated with ``ParamValidators.in_range(0.0, 1.0)``.
    RELATIVE_ERROR: Param[float] = FloatParam(
        "relative_error",
        "The relative target precision for the approximate quantile algorithm.",
        0.001,
        ParamValidators.in_range(0.0, 1.0))

    def set_relative_error(self, value: float):
        """
        Set the relative error.

        :param value: The new relative target precision.
        :return: The result of ``self.set``.
        """
        return self.set(self.RELATIVE_ERROR, value)

    def get_relative_error(self) -> float:
        """Return the value of the ``RELATIVE_ERROR`` param."""
        return self.get(self.RELATIVE_ERROR)

    @property
    def relative_error(self) -> float:
        """Read-only attribute form of :meth:`get_relative_error`."""
        # Delegate to the getter so the property and the getter can never diverge.
        return self.get_relative_error()


class HasFlatten(WithParams, ABC):
    """
    Interface for shared flatten param.
    """
    # Default is False; no explicit validator is passed.
    FLATTEN: Param[bool] = BooleanParam(
        "flatten",
        "If false, the returned table contains only a single row, otherwise, one row per feature.",
        False
    )

    def set_flatten(self, value: bool):
        """
        Set the flatten flag.

        :param value: The new value of the flag.
        :return: The result of ``self.set``.
        """
        return self.set(self.FLATTEN, value)

    def get_flatten(self) -> bool:
        """Return the value of the ``FLATTEN`` param."""
        return self.get(self.FLATTEN)

    @property
    def flatten(self) -> bool:
        """Read-only attribute form of :meth:`get_flatten`."""
        # Delegate to the getter so the property and the getter can never diverge.
        return self.get_flatten()


class HasModelVersionCol(WithParams, ABC):
    """
    Interface for the shared model version column param.
    """
    # Defaults to the column name "version"; no explicit validator is passed.
    MODEL_VERSION_COL: Param[str] = StringParam(
        "model_version_col",
        "The name of the column which contains the version of the model data that "
        "the input data is predicted with. The version should be a 64-bit integer.",
        "version"
    )

    def set_model_version_col(self, value: str):
        """
        Set the model version column name.

        :param value: Name of the model version column.
        :return: The result of ``self.set``.
        """
        return self.set(self.MODEL_VERSION_COL, value)

    def get_model_version_col(self) -> str:
        """Return the value of the ``MODEL_VERSION_COL`` param."""
        return self.get(self.MODEL_VERSION_COL)

    @property
    def model_version_col(self) -> str:
        """Read-only attribute form of :meth:`get_model_version_col`."""
        return self.get_model_version_col()


class HasMaxAllowedModelDelayMs(WithParams, ABC):
    """
    Interface for the shared max allowed model delay in milliseconds param.
    """
    # Default is 0; values are validated with ``ParamValidators.gt_eq(0)``.
    MAX_ALLOWED_MODEL_DELAY_MS: Param[int] = IntParam(
        "max_allowed_model_delay_ms",
        "The maximum difference allowed between the timestamps of the input record "
        "and the model data that is used to predict that input record. "
        "This param only works when the input contains event time.",
        0,
        ParamValidators.gt_eq(0)
    )

    def set_max_allowed_model_delay_ms(self, value: int):
        """
        Set the maximum allowed model delay.

        :param value: The maximum allowed delay, in milliseconds.
        :return: The result of ``self.set``.
        """
        return self.set(self.MAX_ALLOWED_MODEL_DELAY_MS, value)

    def get_max_allowed_model_delay_ms(self) -> int:
        """Return the value of the ``MAX_ALLOWED_MODEL_DELAY_MS`` param."""
        return self.get(self.MAX_ALLOWED_MODEL_DELAY_MS)

    @property
    def max_allowed_model_delay_ms(self) -> int:
        """Read-only attribute form of :meth:`get_max_allowed_model_delay_ms`."""
        return self.get_max_allowed_model_delay_ms()
