#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime

from .expressions import *
from .element import AnyOp, ElementWise
from . import utils
from .. import types


class BinOp(AnyOp):
    """Base node for operations that take a left and a right operand.

    The two operands are stored under the argument names listed in ``_args``.
    """
    __slots__ = ()
    _args = '_lhs', '_rhs'

    @property
    def node_name(self):
        """Name of the concrete node class."""
        return self.__class__.__name__

    @property
    def name(self):
        """Explicit name if one is set, else the name of the only sequence operand.

        Returns ``None`` when no explicit name is set and the operands do not
        contain exactly one ``SequenceExpr``.
        """
        explicit_name = self._name
        if explicit_name:
            return explicit_name

        sequence_operands = [
            operand for operand in self.args if isinstance(operand, SequenceExpr)
        ]
        if len(sequence_operands) != 1:
            # Zero or several sequence operands: there is no single name to inherit.
            return None
        return sequence_operands[0].name

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_binary_op``."""
        visitor.visit_binary_op(self)


class UnaryOp(ElementWise):
    """Base class for the single-operand nodes of this module."""
    __slots__ = ()

    def accept(self, visitor):
        """Hand this node to ``visitor.visit_unary_op``."""
        visitor.visit_unary_op(self)


class Arithmetic(BinOp):
    """Common base class of the arithmetic binary nodes (``Add``, ``Multiply``, ...)."""
    __slots__ = ()


class Comparison(BinOp):
    """Common base class of the comparison binary nodes (``Equal``, ``Less``, ...)."""
    __slots__ = ()


class LogicalBinOp(BinOp):
    """Common base class of the logical binary nodes (``Or``, ``And``)."""
    __slots__ = ()


class Negate(UnaryOp):
    """Unary node built by ``_neg``."""
    __slots__ = ()


class Invert(UnaryOp):
    """Unary node built by ``_invert``."""
    __slots__ = ()


class Abs(UnaryOp):
    """Unary node built by ``_abs``."""
    __slots__ = ()


class Add(Arithmetic):
    """Arithmetic node built by ``_add`` and ``_radd``."""
    __slots__ = ()


class Substract(Arithmetic):
    """Arithmetic node built by ``_sub`` and ``_rsub``."""
    # NOTE: the class name is spelled "Substract" and is referenced under that
    # name by ``_sub`` / ``_rsub``, so it must not be renamed here.
    __slots__ = ()


class Multiply(Arithmetic):
    """Arithmetic node built by ``_mul`` and ``_rmul``."""
    __slots__ = ()


class Divide(Arithmetic):
    """Arithmetic node built by ``_div`` and ``_rdiv``."""
    __slots__ = ()


class FloorDivide(Arithmetic):
    """Arithmetic node built by ``_floordiv`` and ``_rfloordiv``."""
    __slots__ = ()


class Mod(Arithmetic):
    """Arithmetic node built by ``_mod`` and ``_rmod``."""
    __slots__ = ()


class Power(Arithmetic):
    """Arithmetic node built by ``_pow`` and ``_rpow``."""
    __slots__ = ()


class Greater(Comparison):
    """Comparison node built by ``_gt``."""
    __slots__ = ()


class Less(Comparison):
    """Comparison node built by ``_lt``."""
    __slots__ = ()


class Equal(Comparison):
    """Comparison node built by ``_eq``."""
    __slots__ = ()


class NotEqual(Comparison):
    """Comparison node built by ``_ne``."""
    __slots__ = ()


class GreaterEqual(Comparison):
    """Comparison node built by ``_ge``."""
    __slots__ = ()


class LessEqual(Comparison):
    """Comparison node built by ``_le``."""
    __slots__ = ()


class Or(LogicalBinOp):
    """Logical node built by ``_or`` (and therefore ``_ror``)."""
    __slots__ = ()


class And(LogicalBinOp):
    """Logical node built by ``_and`` (and therefore ``_rand``)."""
    __slots__ = ()


def _get_type(other):
    """Return ``(type, operand)`` for ``other``.

    A ``SequenceExpr`` reports its ``_data_type`` and a ``Scalar`` its
    ``_value_type``. Any other value is first wrapped as
    ``Scalar(_value=other)``, and the wrapper is returned in place of the
    raw value.
    """
    if isinstance(other, SequenceExpr):
        return other._data_type, other

    if not isinstance(other, Scalar):
        # Raw Python value: promote it to a Scalar expression first.
        other = Scalar(_value=other)
    return other._value_type, other


def _arithmetic(expr, other, output_expr_cls, reverse=False, output_type=None):
    """Build an ``output_expr_cls`` node from ``expr`` and ``other``.

    :param expr: left operand (right operand when ``reverse`` is true)
    :param other: the other operand; raw values are wrapped by ``_get_type``
    :param output_expr_cls: node class to instantiate
    :param reverse: swap the operands before building the node
    :param output_type: result type; when ``None`` it is computed with
        ``utils.highest_precedence_data_type`` from both operand types
    :return: the new node, or ``None`` when ``expr`` is neither a
        ``SequenceExpr`` nor a ``Scalar``
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None

    other_type, other = _get_type(other)
    if output_type is None:
        output_type = utils.highest_precedence_data_type(expr.dtype, other_type)

    # The result is a sequence as soon as one operand is a sequence;
    # sequences carry ``_data_type`` while scalars carry ``_value_type``.
    if isinstance(expr, SequenceExpr) or isinstance(other, SequenceExpr):
        type_field = '_data_type'
    else:
        type_field = '_value_type'

    lhs, rhs = (other, expr) if reverse else (expr, other)
    return output_expr_cls(**{type_field: output_type, '_lhs': lhs, '_rhs': rhs})


def _reversed_arithmetic(expr, other, output_expr_cls, output_type=None):
    """Call ``_arithmetic`` with ``reverse=True`` so ``other`` becomes the left operand."""
    return _arithmetic(
        expr, other, output_expr_cls,
        output_type=output_type, reverse=True)


def _cmp(expr, other, output_expr_cls):
    """Build a boolean-typed ``output_expr_cls`` node comparing ``expr`` and ``other``.

    Returns ``None`` when ``expr`` is neither a ``SequenceExpr`` nor a ``Scalar``.
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None

    other_type, other = _get_type(other)

    # The result is deliberately discarded; the call is kept only for its
    # effect (presumably rejecting operand types that have no common type;
    # confirm against utils.highest_precedence_data_type).
    utils.highest_precedence_data_type(expr.dtype, other_type)

    # A comparison always yields booleans, whatever the operand types are.
    result_type = types.boolean
    if isinstance(expr, SequenceExpr) or isinstance(other, SequenceExpr):
        return output_expr_cls(_data_type=result_type, _lhs=expr, _rhs=other)
    return output_expr_cls(_value_type=result_type, _lhs=expr, _rhs=other)


def _unary(expr, output_expr_cls):
    """Wrap ``expr`` in an ``output_expr_cls`` node that keeps ``expr``'s dtype.

    Returns ``None`` when ``expr`` is neither a ``SequenceExpr`` nor a ``Scalar``.
    """
    if isinstance(expr, SequenceExpr):
        return output_expr_cls(_data_type=expr.dtype, _input=expr)
    if isinstance(expr, Scalar):
        return output_expr_cls(_value_type=expr.dtype, _input=expr)
    return None


def _logic(expr, other, output_expr_cls):
    """Build a boolean ``output_expr_cls`` node from two boolean operands.

    :return: the new node, or ``None`` when ``expr`` is neither a
        ``SequenceExpr`` nor a ``Scalar``
    :raises TypeError: when either operand's dtype is not ``types.boolean``
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None

    # Only the (possibly wrapped) operand is needed; its dtype is read below.
    _, other = _get_type(other)

    if not (expr.dtype == types.boolean and other.dtype == types.boolean):
        raise TypeError('Logic operation needs boolean operand')

    result_type = types.boolean
    if isinstance(expr, SequenceExpr) or isinstance(other, SequenceExpr):
        return output_expr_cls(_data_type=result_type, _lhs=expr, _rhs=other)
    return output_expr_cls(_value_type=result_type, _lhs=expr, _rhs=other)


def _is_datetime(expr):
    """Tell whether ``expr`` is a datetime expression or a ``datetime`` value.

    For an ``Expr`` the dtype is compared with ``types.datetime``; anything
    else is checked with ``isinstance(..., datetime)``.
    """
    # NOTE(review): only ``types.datetime`` is matched here, although the date
    # and timestamp expression classes also receive the datetime operator
    # methods at the bottom of this module. Confirm this is intended.
    if not isinstance(expr, Expr):
        return isinstance(expr, datetime)
    return expr.dtype == types.datetime


def _add(expr, other):
    """Build an ``Add`` node for ``expr + other``.

    :raises ExpressionError: when both operands are datetimes
    """
    both_datetime = _is_datetime(expr) and _is_datetime(other)
    if both_datetime:
        raise ExpressionError('Cannot add two datetimes')
    return _arithmetic(expr, other, output_expr_cls=Add)


def _radd(expr, other):
    """Build an ``Add`` node for ``other + expr`` (reflected addition).

    :raises ExpressionError: when both operands are datetimes
    """
    both_datetime = _is_datetime(expr) and _is_datetime(other)
    if both_datetime:
        raise ExpressionError('Cannot add two datetimes')
    return _reversed_arithmetic(expr, other, output_expr_cls=Add)


def _sub(expr, other):
    """Build a ``Substract`` node for ``expr - other``.

    When both operands are datetimes the result type is forced to
    ``types.int64``; otherwise ``_arithmetic`` derives it.
    """
    both_datetime = _is_datetime(expr) and _is_datetime(other)
    rtype = types.int64 if both_datetime else None
    return _arithmetic(expr, other, Substract, output_type=rtype)


def _rsub(expr, other):
    """Build a ``Substract`` node for ``other - expr`` (reflected subtraction).

    When both operands are datetimes the result type is forced to
    ``types.int64``; otherwise ``_arithmetic`` derives it.
    """
    both_datetime = _is_datetime(expr) and _is_datetime(other)
    rtype = types.int64 if both_datetime else None
    return _reversed_arithmetic(expr, other, Substract, output_type=rtype)


def _eq(expr, other):
    """Build an ``Equal`` node for ``expr == other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=Equal)


def _ne(expr, other):
    """Build a ``NotEqual`` node for ``expr != other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=NotEqual)


def _gt(expr, other):
    """Build a ``Greater`` node for ``expr > other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=Greater)


def _lt(expr, other):
    """Build a ``Less`` node for ``expr < other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=Less)


def _le(expr, other):
    """Build a ``LessEqual`` node for ``expr <= other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=LessEqual)


def _ge(expr, other):
    """Build a ``GreaterEqual`` node for ``expr >= other`` through ``_cmp``."""
    return _cmp(expr, other, output_expr_cls=GreaterEqual)


def _mul(expr, other):
    """Build a ``Multiply`` node for ``expr * other``."""
    return _arithmetic(expr, other, output_expr_cls=Multiply)


def _rmul(expr, other):
    """Build a ``Multiply`` node for ``other * expr`` (reflected multiplication)."""
    return _reversed_arithmetic(expr, other, output_expr_cls=Multiply)


def _div(expr, other):
    """Build a ``Divide`` node for ``expr / other``.

    When both operand types are ``types.Integer`` instances the result type
    is forced to ``types.float64``; otherwise ``_arithmetic`` derives it.
    """
    output_type = None
    if isinstance(expr.dtype, types.Integer):
        # The other operand's type is only resolved once the left one is an integer.
        other_type = _get_type(other)[0]
        if isinstance(other_type, types.Integer):
            output_type = types.float64
    return _arithmetic(expr, other, Divide, output_type=output_type)


def _rdiv(expr, other):
    """Build a ``Divide`` node for ``other / expr`` (reflected division).

    When both operand types are ``types.Integer`` instances the result type
    is forced to ``types.float64``; otherwise ``_arithmetic`` derives it.
    """
    output_type = None
    if isinstance(expr.dtype, types.Integer):
        # The other operand's type is only resolved once ``expr`` is an integer.
        other_type = _get_type(other)[0]
        if isinstance(other_type, types.Integer):
            output_type = types.float64
    return _reversed_arithmetic(expr, other, Divide, output_type=output_type)


def _mod(expr, other):
    """Build a ``Mod`` node for ``expr % other``."""
    return _arithmetic(expr, other, output_expr_cls=Mod)


def _rmod(expr, other):
    """Build a ``Mod`` node for ``other % expr`` (reflected modulo)."""
    return _reversed_arithmetic(expr, other, output_expr_cls=Mod)


def _floordiv(expr, other):
    """Build a ``FloorDivide`` node for ``expr // other``."""
    return _arithmetic(expr, other, output_expr_cls=FloorDivide)


def _rfloordiv(expr, other):
    """Build a ``FloorDivide`` node for ``other // expr`` (reflected floor division)."""
    return _reversed_arithmetic(expr, other, output_expr_cls=FloorDivide)


def _pow(expr, other):
    """Build a ``Power`` node for ``expr ** other``."""
    return _arithmetic(expr, other, output_expr_cls=Power)


def _rpow(expr, other):
    """Build a ``Power`` node for ``other ** expr`` (reflected power)."""
    return _reversed_arithmetic(expr, other, output_expr_cls=Power)


def _or(expr, other):
    """Build an ``Or`` node from two boolean operands through ``_logic``."""
    return _logic(expr, other, output_expr_cls=Or)


def _ror(expr, other):
    """Reflected ``or``: delegates to ``_or`` with the operands in the same order."""
    # The operands are not swapped; ``expr`` stays the left-hand side.
    result = _or(expr, other)
    return result


def _and(expr, other):
    """Build an ``And`` node from two boolean operands through ``_logic``."""
    return _logic(expr, other, output_expr_cls=And)


def _rand(expr, other):
    """Reflected ``and``: delegates to ``_and`` with the operands in the same order."""
    # The operands are not swapped; ``expr`` stays the left-hand side.
    result = _and(expr, other)
    return result


def _neg(expr):
    """Negate ``expr``.

    Negating an existing ``Negate`` node returns that node's ``input``
    instead of stacking a second ``Negate`` on top of it.
    """
    if not isinstance(expr, Negate):
        return _unary(expr, Negate)
    return expr.input


def _invert(expr):
    """Invert ``expr``.

    Inverting an existing ``Invert`` node returns that node's ``input``
    instead of stacking a second ``Invert`` on top of it.
    """
    if not isinstance(expr, Invert):
        return _unary(expr, Invert)
    return expr.input


def _abs(expr):
    """Take the absolute value of ``expr``.

    An expression that is already an ``Abs`` node is returned unchanged.
    """
    if not isinstance(expr, Abs):
        return _unary(expr, Abs)
    return expr


# Operator implementations registered on every numeric sequence/scalar class.
_number_methods = {
    # arithmetic and the reflected variants
    '_add': _add,
    '_radd': _radd,
    '_sub': _sub,
    '_rsub': _rsub,
    '_mul': _mul,
    '_rmul': _rmul,
    '_div': _div,
    '_rdiv': _rdiv,
    '_floordiv': _floordiv,
    '_rfloordiv': _rfloordiv,
    '_mod': _mod,
    '_rmod': _rmod,
    '_pow': _pow,
    '_rpow': _rpow,
    # unary
    '_neg': _neg,
    '_abs': _abs,
    # comparisons
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

# Extra operator implementations registered on integer sequence/scalar classes only.
_int_number_methods = {
    '_invert': _invert,
}

# Operator implementations registered on string sequence/scalar classes.
_string_methods = {
    # concatenation
    '_add': _add,
    '_radd': _radd,
    # comparisons
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

# Operator implementations registered on boolean sequence/scalar classes.
_boolean_methods = {
    # logical operators and the reflected variants
    '_or': _or,
    '_ror': _ror,
    '_and': _and,
    '_rand': _rand,
    # equality
    '_eq': _eq,
    '_ne': _ne,
    # unary
    '_invert': _invert,
    '_neg': _neg,
}

# Operator implementations registered on datetime, date and timestamp
# sequence/scalar classes.
_datetime_methods = {
    '_add': _add,  # TODO: check whether addition belongs on datetime types
    '_radd': _radd,
    '_sub': _sub,
    '_rsub': _rsub,
    # comparisons
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

# Register each method table on its expression classes through
# ``utils.add_method``, sequence class first and scalar class second.
for _expr_classes, _method_table in (
    ((StringSequenceExpr, StringScalar), _string_methods),
    ((BooleanSequenceExpr, BooleanScalar), _boolean_methods),
    ((DatetimeSequenceExpr, DatetimeScalar), _datetime_methods),
    ((DateSequenceExpr, DateScalar), _datetime_methods),
    ((TimestampSequenceExpr, TimestampScalar), _datetime_methods),
):
    for _expr_cls in _expr_classes:
        utils.add_method(_expr_cls, _method_table)
# Drop the helper loop variables so they do not remain as module attributes.
del _expr_classes, _method_table, _expr_cls

# Every numeric class gets the numeric operators ...
for number_sequence in number_sequences + number_scalars:
    utils.add_method(number_sequence, _number_methods)

# ... and the integer classes additionally get the integer-only ones.
for int_number_sequence in int_number_sequences + int_number_scalars:
    utils.add_method(int_number_sequence, _int_number_methods)
