#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
import operator
from collections.abc import Iterable
from functools import partial, reduce

from .array import SparseNDArray, call_sparse
from .core import get_sparse_module, issparse
from .matrix import SparseMatrix
from .vector import SparseVector


def asarray(x, shape=None):
    """Return *x* wrapped in ``SparseNDArray`` when ``issparse(x)``, else *x*.

    ``shape`` is only used in the wrapping case, where it is forwarded to the
    ``SparseNDArray`` constructor.
    """
    if not issparse(x):
        return x
    return SparseNDArray(x, shape=shape)


def add(a, b, **_):
    """Return ``a + b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__radd__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.add(a, b)
    except TypeError:
        if not hasattr(b, "__radd__"):
            raise
        return b.__radd__(a)


def subtract(a, b, **_):
    """Return ``a - b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rsub__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.sub(a, b)
    except TypeError:
        if not hasattr(b, "__rsub__"):
            raise
        return b.__rsub__(a)


def multiply(a, b, **_):
    """Return ``a * b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rmul__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.mul(a, b)
    except TypeError:
        if not hasattr(b, "__rmul__"):
            raise
        return b.__rmul__(a)


def divide(a, b, **_):
    """Return ``a / b`` (true division).

    If the forward operation raises ``TypeError``, retry with *b*'s reflected
    division method. Python 3 names it ``__rtruediv__``. The Python 2 name
    ``__rdiv__`` is still tried afterwards for objects that only define that.
    If neither exists, the original ``TypeError`` is re-raised.
    """
    try:
        return a / b
    except TypeError:
        for reflected in ("__rtruediv__", "__rdiv__"):
            if hasattr(b, reflected):
                return getattr(b, reflected)(a)
        raise


def true_divide(a, b, **_):
    """Return ``a / b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rtruediv__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.truediv(a, b)
    except TypeError:
        if not hasattr(b, "__rtruediv__"):
            raise
        return b.__rtruediv__(a)


def floor_divide(a, b, **_):
    """Return ``a // b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rfloordiv__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.floordiv(a, b)
    except TypeError:
        if not hasattr(b, "__rfloordiv__"):
            raise
        return b.__rfloordiv__(a)


def power(a, b, **_):
    """Return ``a ** b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rpow__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.pow(a, b)
    except TypeError:
        if not hasattr(b, "__rpow__"):
            raise
        return b.__rpow__(a)


def mod(a, b, **_):
    """Return ``a % b``.

    If the forward operation raises ``TypeError``, retry with
    ``b.__rmod__(a)`` when *b* defines it; otherwise re-raise.
    """
    try:
        return operator.mod(a, b)
    except TypeError:
        if not hasattr(b, "__rmod__"):
            raise
        return b.__rmod__(a)


def _call_bin(method, a, b, **kwargs):
    """Apply the binary operation named *method* to ``a`` and ``b``.

    Dispatch order:

    1. ``a.<method>(b, **kwargs)`` when *a* has that attribute.
    2. ``call_sparse(method, a, b, **kwargs)`` when *a* is a scalar according
       to its array module's ``isscalar``.
    3. ``xp.<method>(a, b, **kwargs)`` on the array module shared by *a* and
       *b*. If that raises ``TypeError`` while ``xp`` is ``cp`` and *b* is
       sparse, the call is retried with ``b.toarray()``.

    Any ``order`` keyword is dropped because it has no effect for sparse data.

    Raises:
        NotImplementedError: if the dispatched call returns ``NotImplemented``.
        AssertionError: if *a* and *b* resolve to different array modules
            (only checked while assertions are enabled).
    """
    from .core import cp, get_array_module, issparse

    kwargs.pop("order", None)

    if hasattr(a, method):
        res = getattr(a, method)(b, **kwargs)
    else:
        xp = get_array_module(a)
        if xp.isscalar(a):
            res = call_sparse(method, a, b, **kwargs)
        else:
            assert xp == get_array_module(b)
            func = getattr(xp, method)
            try:
                res = func(a, b, **kwargs)
            except TypeError:
                if not (xp is cp and issparse(b)):
                    raise
                # dense fallback: retry with b converted via toarray()
                res = func(a, b.toarray(), **kwargs)

    if res is NotImplemented:
        raise NotImplementedError
    return res


def _call_unary(method, x, *args, **kwargs):
    """Apply the operation named *method* to *x*.

    Uses ``x.<method>(*args, **kwargs)`` when *x* has that attribute. Otherwise
    it falls back to ``xp.<method>(x, *args, **kwargs)`` on *x*'s array module.
    Any ``order`` keyword is dropped because it has no effect for sparse data.

    Raises:
        NotImplementedError: if the dispatched call returns ``NotImplemented``.
    """
    from .core import get_array_module

    kwargs.pop("order", None)

    if hasattr(x, method):
        func = getattr(x, method)
    else:
        # module-level function: x becomes the first positional argument
        func = partial(getattr(get_array_module(x), method), x)
    res = func(*args, **kwargs)

    if res is NotImplemented:
        raise NotImplementedError
    return res


def float_power(a, b, **kwargs):
    """Dispatch ``float_power`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("float_power", a, b, **kwargs)


def fmod(a, b, **kwargs):
    """Dispatch ``fmod`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("fmod", a, b, **kwargs)


def logaddexp(a, b, **kwargs):
    """Dispatch ``logaddexp`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("logaddexp", a, b, **kwargs)


def logaddexp2(a, b, **kwargs):
    """Dispatch ``logaddexp2`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("logaddexp2", a, b, **kwargs)


def negative(x, **_):
    """Return ``-x``; extra keyword arguments are ignored."""
    return operator.neg(x)


def positive(x, **_):
    """Return ``+x``; extra keyword arguments are ignored."""
    return +x


def absolute(x, **_):
    """Return the absolute value of *x*; extra keyword arguments are ignored.

    Uses ``operator.abs`` because this module rebinds the name ``abs``.
    """
    return operator.abs(x)


# public alias mirroring NumPy's ``abs``
abs = absolute


# kept as a partial so extra positional arguments still reach _call_unary
fabs = partial(_call_unary, "fabs")


def rint(x, **kwargs):
    """Dispatch ``rint`` on *x* through ``_call_unary``."""
    return _call_unary("rint", x, **kwargs)


def sign(x, **kwargs):
    """Dispatch ``sign`` on *x* through ``_call_unary``."""
    return _call_unary("sign", x, **kwargs)


def conj(x, **kwargs):
    """Dispatch ``conj`` on *x* through ``_call_unary``."""
    return _call_unary("conj", x, **kwargs)


def exp(x, **kwargs):
    """Dispatch ``exp`` on *x* through ``_call_unary``."""
    return _call_unary("exp", x, **kwargs)


def exp2(x, **kwargs):
    """Dispatch ``exp2`` on *x* through ``_call_unary``."""
    return _call_unary("exp2", x, **kwargs)


def log(x, **kwargs):
    """Dispatch ``log`` on *x* through ``_call_unary``."""
    return _call_unary("log", x, **kwargs)


def log2(x, **kwargs):
    """Dispatch ``log2`` on *x* through ``_call_unary``."""
    return _call_unary("log2", x, **kwargs)


def log10(x, **kwargs):
    """Dispatch ``log10`` on *x* through ``_call_unary``."""
    return _call_unary("log10", x, **kwargs)


def expm1(x, **kwargs):
    """Dispatch ``expm1`` on *x* through ``_call_unary``."""
    return _call_unary("expm1", x, **kwargs)


def log1p(x, **kwargs):
    """Dispatch ``log1p`` on *x* through ``_call_unary``."""
    return _call_unary("log1p", x, **kwargs)


def sqrt(x, **kwargs):
    """Dispatch ``sqrt`` on *x* through ``_call_unary``."""
    return _call_unary("sqrt", x, **kwargs)


def square(x, **kwargs):
    """Dispatch ``square`` on *x* through ``_call_unary``."""
    return _call_unary("square", x, **kwargs)


def cbrt(x, **kwargs):
    """Dispatch ``cbrt`` on *x* through ``_call_unary``."""
    return _call_unary("cbrt", x, **kwargs)


def reciprocal(x, **kwargs):
    """Dispatch ``reciprocal`` on *x* through ``_call_unary``."""
    return _call_unary("reciprocal", x, **kwargs)


# Special functions. Every public name below is a ``functools.partial`` that
# pre-binds the operation name to one of three dispatchers:
#   * ``_call_unary``  - called with a single operand,
#   * ``_call_bin``    - called with two operands,
#   * ``call_sparse``  - all call arguments are passed through unchanged.
# The ``_*_op`` factories exist only to build those partials and are deleted
# at the end of this section so they do not leak into the module namespace.
_unary_op = partial(partial, _call_unary)
_binary_op = partial(partial, _call_bin)
_sparse_op = partial(partial, call_sparse)

# gamma / beta family
gamma = _unary_op("gamma")
gammaln = _unary_op("gammaln")
loggamma = _unary_op("loggamma")
gammasgn = _unary_op("gammasgn")
gammainc = _binary_op("gammainc")
gammaincinv = _binary_op("gammaincinv")
gammaincc = _binary_op("gammaincc")
gammainccinv = _binary_op("gammainccinv")
beta = _binary_op("beta")
betaln = _binary_op("betaln")
betainc = _sparse_op("betainc")
betaincinv = _sparse_op("betaincinv")
psi = _unary_op("psi")
rgamma = _unary_op("rgamma")
polygamma = _binary_op("polygamma")
multigammaln = _binary_op("multigammaln")
digamma = _unary_op("digamma")
poch = _binary_op("poch")

# entropy
entr = _unary_op("entr")
rel_entr = _binary_op("rel_entr")
kl_div = _binary_op("kl_div")

xlogy = _binary_op("xlogy")

# error functions
erf = _unary_op("erf")
erfc = _unary_op("erfc")
erfcx = _unary_op("erfcx")
erfi = _unary_op("erfi")
erfinv = _unary_op("erfinv")
erfcinv = _unary_op("erfcinv")
wofz = _unary_op("wofz")
dawsn = _unary_op("dawsn")
voigt_profile = _sparse_op("voigt_profile")

# Bessel / Hankel functions
jv = _binary_op("jv")
jve = _binary_op("jve")
yn = _binary_op("yn")
yv = _binary_op("yv")
yve = _binary_op("yve")
kn = _binary_op("kn")
kv = _binary_op("kv")
kve = _binary_op("kve")
iv = _binary_op("iv")
ive = _binary_op("ive")
hankel1 = _binary_op("hankel1")
hankel1e = _binary_op("hankel1e")
hankel2 = _binary_op("hankel2")
hankel2e = _binary_op("hankel2e")

# hypergeometric functions
hyp2f1 = _sparse_op("hyp2f1")
hyp1f1 = _sparse_op("hyp1f1")
hyperu = _sparse_op("hyperu")
hyp0f1 = _binary_op("hyp0f1")

# ellipsoidal harmonics
ellip_harm = _sparse_op("ellip_harm")
ellip_harm_2 = _sparse_op("ellip_harm_2")
ellip_normal = _sparse_op("ellip_normal")

# elliptic integrals
ellipk = _unary_op("ellipk")
ellipkm1 = _unary_op("ellipkm1")
ellipkinc = _binary_op("ellipkinc")
ellipe = _unary_op("ellipe")
ellipeinc = _binary_op("ellipeinc")
elliprc = _binary_op("elliprc")
elliprd = _sparse_op("elliprd")
elliprf = _sparse_op("elliprf")
elliprg = _sparse_op("elliprg")
elliprj = _sparse_op("elliprj")

# Airy functions
airy = _unary_op("airy")
airye = _unary_op("airye")
itairy = _unary_op("itairy")

del _unary_op, _binary_op, _sparse_op


def equal(a, b, **_):
    """Return ``a == b``; if that raises ``TypeError``, return ``b == a``."""
    try:
        return operator.eq(a, b)
    except TypeError:
        return operator.eq(b, a)


def not_equal(a, b, **_):
    """Return ``a != b``; if that raises ``TypeError``, return ``b != a``."""
    try:
        return operator.ne(a, b)
    except TypeError:
        return operator.ne(b, a)


def less(a, b, **_):
    """Return ``a < b``; if that raises ``TypeError``, return ``b > a``."""
    try:
        return operator.lt(a, b)
    except TypeError:
        return operator.gt(b, a)


def less_equal(a, b, **_):
    """Return ``a <= b``; if that raises ``TypeError``, return ``b >= a``."""
    try:
        return operator.le(a, b)
    except TypeError:
        return operator.ge(b, a)


def greater(a, b, **_):
    """Return ``a > b``; if that raises ``TypeError``, return ``b < a``."""
    try:
        return operator.gt(a, b)
    except TypeError:
        return operator.lt(b, a)


def greater_equal(a, b, **_):
    """Return ``a >= b``; if that raises ``TypeError``, return ``b <= a``."""
    try:
        return operator.ge(a, b)
    except TypeError:
        return operator.le(b, a)


def logical_and(a, b, **kwargs):
    """Dispatch ``logical_and`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("logical_and", a, b, **kwargs)


def logical_or(a, b, **kwargs):
    """Dispatch ``logical_or`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("logical_or", a, b, **kwargs)


def logical_xor(a, b, **kwargs):
    """Dispatch ``logical_xor`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("logical_xor", a, b, **kwargs)


def logical_not(x, **kwargs):
    """Dispatch ``logical_not`` on *x* through ``_call_unary``."""
    return _call_unary("logical_not", x, **kwargs)


def isclose(a, b, **kwargs):
    """Dispatch ``isclose`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("isclose", a, b, **kwargs)


def bitwise_and(a, b, **_):
    """Return ``a & b``; if that raises ``TypeError``, return ``b & a``."""
    try:
        return operator.and_(a, b)
    except TypeError:
        return operator.and_(b, a)


def bitwise_or(a, b, **_):
    """Return ``a | b``; if that raises ``TypeError``, return ``b | a``."""
    try:
        return operator.or_(a, b)
    except TypeError:
        return operator.or_(b, a)


def bitwise_xor(a, b, **_):
    """Return ``a ^ b``; if that raises ``TypeError``, return ``b ^ a``."""
    try:
        return a ^ b
    except TypeError:
        return b ^ a


def invert(x, **_):
    """Return ``~x``."""
    return operator.invert(x)


def left_shift(a, b, **_):
    """Return ``a << b`` (no operand swap: shifting is not commutative)."""
    return operator.lshift(a, b)


def right_shift(a, b, **_):
    """Return ``a >> b`` (no operand swap: shifting is not commutative)."""
    return operator.rshift(a, b)


def sin(x, **kwargs):
    """Dispatch ``sin`` on *x* through ``_call_unary``."""
    return _call_unary("sin", x, **kwargs)


def cos(x, **kwargs):
    """Dispatch ``cos`` on *x* through ``_call_unary``."""
    return _call_unary("cos", x, **kwargs)


def tan(x, **kwargs):
    """Dispatch ``tan`` on *x* through ``_call_unary``."""
    return _call_unary("tan", x, **kwargs)


def arcsin(x, **kwargs):
    """Dispatch ``arcsin`` on *x* through ``_call_unary``."""
    return _call_unary("arcsin", x, **kwargs)


def arccos(x, **kwargs):
    """Dispatch ``arccos`` on *x* through ``_call_unary``."""
    return _call_unary("arccos", x, **kwargs)


def arctan(x, **kwargs):
    """Dispatch ``arctan`` on *x* through ``_call_unary``."""
    return _call_unary("arctan", x, **kwargs)


def arctan2(a, b, **kwargs):
    """Dispatch ``arctan2`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("arctan2", a, b, **kwargs)


def hypot(a, b, **kwargs):
    """Dispatch ``hypot`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("hypot", a, b, **kwargs)


def sinh(x, **kwargs):
    """Dispatch ``sinh`` on *x* through ``_call_unary``."""
    return _call_unary("sinh", x, **kwargs)


def cosh(x, **kwargs):
    """Dispatch ``cosh`` on *x* through ``_call_unary``."""
    return _call_unary("cosh", x, **kwargs)


def tanh(x, **kwargs):
    """Dispatch ``tanh`` on *x* through ``_call_unary``."""
    return _call_unary("tanh", x, **kwargs)


def arcsinh(x, **kwargs):
    """Dispatch ``arcsinh`` on *x* through ``_call_unary``."""
    return _call_unary("arcsinh", x, **kwargs)


def arccosh(x, **kwargs):
    """Dispatch ``arccosh`` on *x* through ``_call_unary``."""
    return _call_unary("arccosh", x, **kwargs)


def around(x, **kwargs):
    """Dispatch ``around`` on *x* through ``_call_unary``."""
    return _call_unary("around", x, **kwargs)


def arctanh(x, **kwargs):
    """Dispatch ``arctanh`` on *x* through ``_call_unary``."""
    return _call_unary("arctanh", x, **kwargs)


def deg2rad(x, **kwargs):
    """Dispatch ``deg2rad`` on *x* through ``_call_unary``."""
    return _call_unary("deg2rad", x, **kwargs)


def rad2deg(x, **kwargs):
    """Dispatch ``rad2deg`` on *x* through ``_call_unary``."""
    return _call_unary("rad2deg", x, **kwargs)


def angle(x, **kwargs):
    """Dispatch ``angle`` on *x* through ``_call_unary``."""
    return _call_unary("angle", x, **kwargs)


def isinf(x, **kwargs):
    """Dispatch ``isinf`` on *x* through ``_call_unary``."""
    return _call_unary("isinf", x, **kwargs)


def isnan(x, **kwargs):
    """Dispatch ``isnan`` on *x* through ``_call_unary``."""
    return _call_unary("isnan", x, **kwargs)


def signbit(x, **kwargs):
    """Dispatch ``signbit`` on *x* through ``_call_unary``."""
    return _call_unary("signbit", x, **kwargs)


def dot(a, b, sparse=True, **_):
    """Return the dot product of *a* and *b*.

    When ``issparse(a)``, the call is delegated to ``a.dot(b, sparse=sparse)``.
    Otherwise ``a.dot(b)`` is computed. If *sparse* is true, that result is
    converted with ``csr_matrix`` from its sparse module and wrapped in a
    ``SparseNDArray`` of the same shape.
    """
    if issparse(a):
        return a.dot(b, sparse=sparse)

    product = a.dot(b)
    if not sparse:
        return product

    xps = get_sparse_module(product)
    return SparseNDArray(xps.csr_matrix(product), shape=product.shape)


def tensordot(a, b, axes=2, sparse=True):
    """Tensor contraction of *a* and *b*, limited to cases expressible via ``dot``.

    *axes* is either a pair ``(a_axes, b_axes)`` (each an int or an iterable
    of ints) or an int N. With an int N, the last N axes of *a* (listed
    last-to-first) are contracted against the first N axes of *b*.

    Supported cases:
      * last axis of *a* with the second-to-last axis of *b* -> ``dot(a, b)``;
      * both 2-D, last axis of each -> ``dot(a, b.T)``;
      * either operand 1-D -> ``dot(a, b)``.

    NOTE(review): the 1-D branch ignores the requested axes entirely.

    Raises:
        NotImplementedError: for any other combination.
    """

    def _as_tuple(ax):
        return tuple(ax) if isinstance(ax, Iterable) else (ax,)

    if isinstance(axes, Iterable):
        a_axes, b_axes = axes
    else:
        a_axes = tuple(range(a.ndim - 1, a.ndim - axes - 1, -1))
        b_axes = tuple(range(axes))
    a_axes = _as_tuple(a_axes)
    b_axes = _as_tuple(b_axes)

    last_of_a = (a.ndim - 1,)
    if a_axes == last_of_a and b_axes == (b.ndim - 2,):
        return dot(a, b, sparse=sparse)
    if a.ndim == b.ndim == 2 and a_axes == last_of_a and b_axes == (b.ndim - 1,):
        # contracting the last axis of both operands is a @ b.T
        return dot(a, b.T, sparse=sparse)
    if a.ndim == 1 or b.ndim == 1:
        return dot(a, b, sparse=sparse)

    raise NotImplementedError


def matmul(a, b, sparse=True, **_):
    """Delegate to ``dot(a, b, sparse=sparse)``; other keywords are ignored."""
    return dot(a, b, sparse=sparse)


def concatenate(tensors, axis=0):
    """Concatenate *tensors* along *axis* by folding pairwise ``_call_bin`` calls."""

    def _join(left, right):
        return _call_bin("concatenate", left, right, axis=axis)

    return reduce(_join, tensors)


def transpose(tensor, axes=None):
    """Dispatch ``transpose`` on *tensor* through ``_call_unary``."""
    return _call_unary("transpose", tensor, axes=axes)


def swapaxes(tensor, axis1, axis2):
    """Dispatch ``swapaxes`` on *tensor* through ``_call_unary``."""
    return _call_unary("swapaxes", tensor, axis1, axis2)


def sum(tensor, axis=None, **kwargs):
    """Dispatch ``sum`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("sum", tensor, axis=axis, **kwargs)


def prod(tensor, axis=None, **kwargs):
    """Dispatch ``prod`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("prod", tensor, axis=axis, **kwargs)


def amax(tensor, axis=None, **kwargs):
    """Dispatch ``amax`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("amax", tensor, axis=axis, **kwargs)


# NumPy-style alias (shadows the builtin inside this module)
max = amax


def amin(tensor, axis=None, **kwargs):
    """Dispatch ``amin`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("amin", tensor, axis=axis, **kwargs)


# NumPy-style alias (shadows the builtin inside this module)
min = amin


def all(tensor, axis=None, **kwargs):
    """Dispatch ``all`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("all", tensor, axis=axis, **kwargs)


def any(tensor, axis=None, **kwargs):
    """Dispatch ``any`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("any", tensor, axis=axis, **kwargs)


def mean(tensor, axis=None, **kwargs):
    """Dispatch ``mean`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("mean", tensor, axis=axis, **kwargs)


def nansum(tensor, axis=None, **kwargs):
    """Dispatch ``nansum`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nansum", tensor, axis=axis, **kwargs)


def nanprod(tensor, axis=None, **kwargs):
    """Dispatch ``nanprod`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nanprod", tensor, axis=axis, **kwargs)


def nanmax(tensor, axis=None, **kwargs):
    """Dispatch ``nanmax`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nanmax", tensor, axis=axis, **kwargs)


def nanmin(tensor, axis=None, **kwargs):
    """Dispatch ``nanmin`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nanmin", tensor, axis=axis, **kwargs)


def argmax(tensor, axis=None, **kwargs):
    """Dispatch ``argmax`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("argmax", tensor, axis=axis, **kwargs)


def nanargmax(tensor, axis=None, **kwargs):
    """Dispatch ``nanargmax`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nanargmax", tensor, axis=axis, **kwargs)


def argmin(tensor, axis=None, **kwargs):
    """Dispatch ``argmin`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("argmin", tensor, axis=axis, **kwargs)


def nanargmin(tensor, axis=None, **kwargs):
    """Dispatch ``nanargmin`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nanargmin", tensor, axis=axis, **kwargs)


def var(tensor, axis=None, **kwargs):
    """Dispatch ``var`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("var", tensor, axis=axis, **kwargs)


def cumsum(tensor, axis=None, **kwargs):
    """Dispatch ``cumsum`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("cumsum", tensor, axis=axis, **kwargs)


def cumprod(tensor, axis=None, **kwargs):
    """Dispatch ``cumprod`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("cumprod", tensor, axis=axis, **kwargs)


def nancumsum(tensor, axis=None, **kwargs):
    """Dispatch ``nancumsum`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nancumsum", tensor, axis=axis, **kwargs)


def nancumprod(tensor, axis=None, **kwargs):
    """Dispatch ``nancumprod`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("nancumprod", tensor, axis=axis, **kwargs)


def count_nonzero(tensor, axis=None, **kwargs):
    """Dispatch ``count_nonzero`` over *axis* of *tensor* through ``_call_unary``."""
    return _call_unary("count_nonzero", tensor, axis=axis, **kwargs)


def maximum(a, b, **kwargs):
    """Dispatch ``maximum`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("maximum", a, b, **kwargs)


def minimum(a, b, **kwargs):
    """Dispatch ``minimum`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("minimum", a, b, **kwargs)


def fmax(a, b, **kwargs):
    """Dispatch ``fmax`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("fmax", a, b, **kwargs)


def fmin(a, b, **kwargs):
    """Dispatch ``fmin`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("fmin", a, b, **kwargs)


def floor(x, **kwargs):
    """Dispatch ``floor`` on *x* through ``_call_unary``."""
    return _call_unary("floor", x, **kwargs)


def ceil(x, **kwargs):
    """Dispatch ``ceil`` on *x* through ``_call_unary``."""
    return _call_unary("ceil", x, **kwargs)


def trunc(x, **kwargs):
    """Dispatch ``trunc`` on *x* through ``_call_unary``."""
    return _call_unary("trunc", x, **kwargs)


def degrees(x, **kwargs):
    """Dispatch ``degrees`` on *x* through ``_call_unary``."""
    return _call_unary("degrees", x, **kwargs)


def radians(x, **kwargs):
    """Dispatch ``radians`` on *x* through ``_call_unary``."""
    return _call_unary("radians", x, **kwargs)


def clip(a, a_min, a_max, **kw):
    """Limit the values of *a* to the interval ``[a_min, a_max]``.

    Dispatches to ``a.clip`` when *a* provides it. Otherwise it uses the
    ``clip`` function of *a*'s array module.

    The bounds are always forwarded positionally as ``(a_min, a_max)``. The
    parameter names follow NumPy's order, so keyword callers such as
    ``clip(x, a_min=0, a_max=1)`` get the bounds they asked for.

    Raises:
        NotImplementedError: if the dispatched call returns ``NotImplemented``.
    """
    if hasattr(a, "clip"):
        res = a.clip(a_min, a_max, **kw)
    else:
        from .core import get_array_module

        xp = get_array_module(a)
        res = xp.clip(a, a_min, a_max, **kw)

    if res is NotImplemented:
        raise NotImplementedError

    return res


def iscomplex(x, **kwargs):
    """Dispatch ``iscomplex`` on *x* through ``_call_unary``."""
    return _call_unary("iscomplex", x, **kwargs)


def real(x, **_):
    """Return the ``real`` attribute of *x*; keyword arguments are ignored."""
    real_part = x.real
    return real_part


def imag(x, **_):
    """Return the ``imag`` attribute of *x*; keyword arguments are ignored."""
    imag_part = x.imag
    return imag_part


def fix(x, **kwargs):
    """Dispatch ``fix`` on *x* through ``_call_unary``."""
    return _call_unary("fix", x, **kwargs)


def i0(x, **kwargs):
    """Dispatch ``i0`` on *x* through ``_call_unary``."""
    return _call_unary("i0", x, **kwargs)


def nan_to_num(x, **kwargs):
    """Dispatch ``nan_to_num`` on *x* through ``_call_unary``."""
    return _call_unary("nan_to_num", x, **kwargs)


def copysign(a, b, **kwargs):
    """Dispatch ``copysign`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("copysign", a, b, **kwargs)


def nextafter(a, b, **kwargs):
    """Dispatch ``nextafter`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("nextafter", a, b, **kwargs)


def spacing(x, **kwargs):
    """Dispatch ``spacing`` on *x* through ``_call_unary``."""
    return _call_unary("spacing", x, **kwargs)


def ldexp(a, b, **kwargs):
    """Dispatch ``ldexp`` on ``(a, b)`` through ``_call_bin``."""
    return _call_bin("ldexp", a, b, **kwargs)


def frexp(x, **kwargs):
    """Dispatch ``frexp`` on *x* through ``_call_unary``."""
    return _call_unary("frexp", x, **kwargs)


def modf(x, **kwargs):
    """Dispatch ``modf`` on *x* through ``_call_unary``."""
    return _call_unary("modf", x, **kwargs)


def sinc(x, **kwargs):
    """Dispatch ``sinc`` on *x* through ``_call_unary``."""
    return _call_unary("sinc", x, **kwargs)


def isfinite(x, **kwargs):
    """Dispatch ``isfinite`` on *x* through ``_call_unary``."""
    return _call_unary("isfinite", x, **kwargs)


def isreal(x, **kwargs):
    """Dispatch ``isreal`` on *x* through ``_call_unary``."""
    return _call_unary("isreal", x, **kwargs)


def isfortran(x, **kwargs):
    """Forward ``isfortran`` on *x* directly to ``call_sparse``."""
    return call_sparse("isfortran", x, **kwargs)


def where(cond, x, y):
    """Delegate to ``.matrix.where`` for 0-D / 2-D operands.

    Raises:
        NotImplementedError: if any of *cond*, *x*, *y* has an ``ndim`` other
            than 0 or 2.
    """
    # This module rebinds ``any`` to an array reduction (dispatched through
    # _call_unary), so the plain-Python truth test must use the builtin.
    if builtins.any(arg.ndim not in (0, 2) for arg in (cond, x, y)):
        raise NotImplementedError

    from .matrix import where as matrix_where

    return matrix_where(cond, x, y)


def digitize(x, bins, right=False):
    """Dispatch ``digitize`` on *x* with positional ``bins``/``right`` via ``_call_unary``."""
    return _call_unary("digitize", x, bins, right)


def repeat(a, repeats, axis=None):
    """Dispatch ``repeat`` on *a* through ``_call_unary``."""
    return _call_unary("repeat", a, repeats, axis=axis)


def fill_diagonal(a, val, wrap=False):
    """Dispatch ``fill_diagonal`` on *a* through ``_call_unary``."""
    return _call_unary("fill_diagonal", a, val, wrap=wrap)


def unique(a, return_index=False, return_inverse=False, return_counts=False, axis=None):
    """Dispatch ``unique`` on *a* through ``_call_unary``, forwarding every flag."""
    options = dict(
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    )
    return _call_unary("unique", a, **options)


def zeros(shape, dtype=float, gpu=False):
    """Delegate to ``.matrix.zeros_sparse_matrix`` for a 2-element *shape*.

    Raises:
        NotImplementedError: if ``len(shape) != 2``.
    """
    if len(shape) != 2:
        raise NotImplementedError

    from .matrix import zeros_sparse_matrix

    return zeros_sparse_matrix(shape, dtype=dtype, gpu=gpu)


def ones_like(x):
    """Return ``xp.ones(x.shape)``, where ``xp`` is *x*'s array module.

    NOTE(review): ``x.dtype`` is not forwarded, so the result uses
    ``xp.ones``'s default dtype; confirm this is intended.
    """
    from .core import get_array_module

    xp = get_array_module(x)
    shape = x.shape
    return xp.ones(shape)


def diag(v, k=0, gpu=False):
    """Delegate to ``.matrix.diag_sparse_matrix`` for 1-D or 2-D *v*.

    The dimensionality is checked with ``assert``, so the check is skipped
    when Python runs with ``-O``.
    """
    assert v.ndim in (1, 2)

    from .matrix import diag_sparse_matrix as _diag_impl

    return _diag_impl(v, k=k, gpu=gpu)


def eye(N, M=None, k=0, dtype=float, gpu=False):
    """Delegate to ``.matrix.eye_sparse_matrix`` with all arguments forwarded."""
    from .matrix import eye_sparse_matrix as _eye_impl

    return _eye_impl(N, M=M, k=k, dtype=dtype, gpu=gpu)


def triu(m, k=0, gpu=False):
    """Delegate to ``.matrix.triu_sparse_matrix`` for 2-D *m*.

    Raises:
        NotImplementedError: if ``m.ndim != 2``.
    """
    if m.ndim != 2:
        raise NotImplementedError

    from .matrix import triu_sparse_matrix

    return triu_sparse_matrix(m, k=k, gpu=gpu)


def tril(m, k=0, gpu=False):
    """Delegate to ``.matrix.tril_sparse_matrix`` for 2-D *m*.

    Raises:
        NotImplementedError: if ``m.ndim != 2``.
    """
    if m.ndim != 2:
        raise NotImplementedError

    from .matrix import tril_sparse_matrix

    return tril_sparse_matrix(m, k=k, gpu=gpu)


def lu(m):
    """Delegate to ``.matrix.lu_sparse_matrix``."""
    from .matrix import lu_sparse_matrix as _lu_impl

    return _lu_impl(m)


def solve_triangular(a, b, lower=False, sparse=True):
    """Delegate to ``.matrix.solve_triangular_sparse_matrix`` with all options."""
    from .matrix import solve_triangular_sparse_matrix as _solve_impl

    return _solve_impl(a, b, lower=lower, sparse=sparse)


def block(arrs):
    """Delegate to ``.matrix.block`` when the innermost first array is 2-D.

    Nested lists are descended through their first element until a non-list
    is found; its ``ndim`` decides whether the layout is supported.

    Raises:
        NotImplementedError: if that first array is not 2-D.
    """
    first = arrs[0]
    while isinstance(first, list):
        first = first[0]
    if first.ndim != 2:  # pragma: no cover
        raise NotImplementedError

    from .matrix import block as _matrix_block

    return _matrix_block(arrs)
