core/maxframe/lib/sparse/__init__.py (515 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Function-style entry points for the sparse array types of this package.

Most functions here are thin wrappers that either apply a Python operator
to their operands or dispatch *by name* (through ``_call_unary`` /
``_call_bin`` / ``call_sparse``) to a method of the operand or to a function
of the operand's array module.

NOTE: this module deliberately defines ``abs``, ``all``, ``any``, ``max``,
``min`` and ``sum``; inside this file the Python builtins of the same name
must be reached through the ``builtins`` module.
"""

import builtins
import operator
from collections.abc import Iterable
from functools import partial, reduce

from .array import SparseNDArray, call_sparse
from .core import get_sparse_module, issparse
from .matrix import SparseMatrix
from .vector import SparseVector


def asarray(x, shape=None):
    """Wrap ``x`` in a ``SparseNDArray`` when it is sparse; else return it as is."""
    from .core import issparse

    if issparse(x):
        return SparseNDArray(x, shape=shape)

    return x


# ---------------------------------------------------------------------------
# Binary arithmetic.
#
# Each wrapper applies the Python operator and, if that raises ``TypeError``,
# explicitly calls the right operand's reflected method when it exists
# (re-raising otherwise).  Extra keyword arguments are accepted and ignored.
# ---------------------------------------------------------------------------


def add(a, b, **_):
    try:
        return a + b
    except TypeError:
        if hasattr(b, "__radd__"):
            return b.__radd__(a)
        raise


def subtract(a, b, **_):
    try:
        return a - b
    except TypeError:
        if hasattr(b, "__rsub__"):
            return b.__rsub__(a)
        raise


def multiply(a, b, **_):
    try:
        return a * b
    except TypeError:
        if hasattr(b, "__rmul__"):
            return b.__rmul__(a)
        raise


def divide(a, b, **_):
    """Return ``a / b``, falling back to ``b``'s reflected division hook.

    The legacy ``__rdiv__`` name is tried first so objects that expose it keep
    their previous behavior; ``__rtruediv__`` (the hook Python 3 actually
    defines for ``/``) is tried next so the fallback also works for objects
    that only implement the Python 3 protocol.

    Raises:
        TypeError: re-raised when ``b`` offers neither reflected hook.
    """
    try:
        return a / b
    except TypeError:
        for hook in ("__rdiv__", "__rtruediv__"):
            if hasattr(b, hook):
                return getattr(b, hook)(a)
        raise


def true_divide(a, b, **_):
    try:
        return a / b
    except TypeError:
        if hasattr(b, "__rtruediv__"):
            return b.__rtruediv__(a)
        raise


def floor_divide(a, b, **_):
    try:
        return a // b
    except TypeError:
        if hasattr(b, "__rfloordiv__"):
            return b.__rfloordiv__(a)
        raise


def power(a, b, **_):
    try:
        return a**b
    except TypeError:
        if hasattr(b, "__rpow__"):
            return b.__rpow__(a)
        raise


def mod(a, b, **_):
    try:
        return a % b
    except TypeError:
        if hasattr(b, "__rmod__"):
            return b.__rmod__(a)
        raise


# ---------------------------------------------------------------------------
# Name-based dispatch helpers.
# ---------------------------------------------------------------------------


def _call_bin(method, a, b, **kwargs):
    """Dispatch the binary operation named ``method`` on ``a`` and ``b``.

    Resolution order:

    1. ``a.<method>(b, **kwargs)`` when ``a`` has such an attribute;
    2. ``call_sparse(method, a, b, **kwargs)`` when ``a`` is a scalar
       according to its array module;
    3. ``<array module of a>.<method>(a, b, **kwargs)`` otherwise.

    Raises:
        NotImplementedError: when the chosen implementation returns
            ``NotImplemented``.
    """
    from .core import cp, get_array_module, issparse

    # order does not take effect for sparse
    kwargs.pop("order", None)
    if hasattr(a, method):
        res = getattr(a, method)(b, **kwargs)
    elif get_array_module(a).isscalar(a):
        res = call_sparse(method, a, b, **kwargs)
    else:
        assert get_array_module(a) == get_array_module(b)
        xp = get_array_module(a)
        try:
            res = getattr(xp, method)(a, b, **kwargs)
        except TypeError:
            # Only when the module is ``cp`` and ``b`` is sparse: retry with
            # ``b`` converted through ``toarray()``.
            if xp is cp and issparse(b):
                res = getattr(xp, method)(a, b.toarray(), **kwargs)
            else:
                raise

    if res is NotImplemented:
        raise NotImplementedError

    return res


def _call_unary(method, x, *args, **kwargs):
    """Dispatch the operation named ``method`` on ``x``.

    Calls ``x.<method>(*args, **kwargs)`` when ``x`` has such an attribute,
    otherwise ``<array module of x>.<method>(x, *args, **kwargs)``.

    Raises:
        NotImplementedError: when the chosen implementation returns
            ``NotImplemented``.
    """
    from .core import get_array_module

    # order does not take effect for sparse
    kwargs.pop("order", None)
    if hasattr(x, method):
        res = getattr(x, method)(*args, **kwargs)
    else:
        xp = get_array_module(x)
        res = getattr(xp, method)(x, *args, **kwargs)

    if res is NotImplemented:
        raise NotImplementedError

    return res


# ---------------------------------------------------------------------------
# Element-wise math.  Unless stated otherwise, each wrapper forwards to
# ``_call_unary`` / ``_call_bin`` using its own name as the method name.
# ---------------------------------------------------------------------------


def float_power(a, b, **kw):
    return _call_bin("float_power", a, b, **kw)


def fmod(a, b, **kw):
    return _call_bin("fmod", a, b, **kw)


def logaddexp(a, b, **kw):
    return _call_bin("logaddexp", a, b, **kw)


def logaddexp2(a, b, **kw):
    return _call_bin("logaddexp2", a, b, **kw)


def negative(x, **_):
    return -x


def positive(x, **_):
    return operator.pos(x)


def absolute(x, **_):
    # ``abs`` is rebound at module level right below, so the builtin has to
    # be reached through ``builtins``.
    return builtins.abs(x)


abs = absolute


fabs = partial(_call_unary, "fabs")


def rint(x, **kw):
    return _call_unary("rint", x, **kw)


def sign(x, **kw):
    return _call_unary("sign", x, **kw)


def conj(x, **kw):
    return _call_unary("conj", x, **kw)


def exp(x, **kw):
    return _call_unary("exp", x, **kw)


def exp2(x, **kw):
    return _call_unary("exp2", x, **kw)


def log(x, **kw):
    return _call_unary("log", x, **kw)


def log2(x, **kw):
    return _call_unary("log2", x, **kw)


def log10(x, **kw):
    return _call_unary("log10", x, **kw)


def expm1(x, **kw):
    return _call_unary("expm1", x, **kw)


def log1p(x, **kw):
    return _call_unary("log1p", x, **kw)


def sqrt(x, **kw):
    return _call_unary("sqrt", x, **kw)


def square(x, **kw):
    return _call_unary("square", x, **kw)


def cbrt(x, **kw):
    return _call_unary("cbrt", x, **kw)


def reciprocal(x, **kw):
    return _call_unary("reciprocal", x, **kw)


# ---------------------------------------------------------------------------
# Special functions, bound by name to one of the dispatch helpers:
# ``_call_unary`` (one operand), ``_call_bin`` (two operands) or
# ``call_sparse`` (used here for the remaining, multi-operand functions).
# ---------------------------------------------------------------------------

gamma = partial(_call_unary, "gamma")
gammaln = partial(_call_unary, "gammaln")
loggamma = partial(_call_unary, "loggamma")
gammasgn = partial(_call_unary, "gammasgn")
gammainc = partial(_call_bin, "gammainc")
gammaincinv = partial(_call_bin, "gammaincinv")
gammaincc = partial(_call_bin, "gammaincc")
gammainccinv = partial(_call_bin, "gammainccinv")
beta = partial(_call_bin, "beta")
betaln = partial(_call_bin, "betaln")
betainc = partial(call_sparse, "betainc")
betaincinv = partial(call_sparse, "betaincinv")
psi = partial(_call_unary, "psi")
rgamma = partial(_call_unary, "rgamma")
polygamma = partial(_call_bin, "polygamma")
multigammaln = partial(_call_bin, "multigammaln")
digamma = partial(_call_unary, "digamma")
poch = partial(_call_bin, "poch")

entr = partial(_call_unary, "entr")
rel_entr = partial(_call_bin, "rel_entr")
kl_div = partial(_call_bin, "kl_div")

xlogy = partial(_call_bin, "xlogy")

erf = partial(_call_unary, "erf")
erfc = partial(_call_unary, "erfc")
erfcx = partial(_call_unary, "erfcx")
erfi = partial(_call_unary, "erfi")
erfinv = partial(_call_unary, "erfinv")
erfcinv = partial(_call_unary, "erfcinv")
wofz = partial(_call_unary, "wofz")
dawsn = partial(_call_unary, "dawsn")
voigt_profile = partial(call_sparse, "voigt_profile")

jv = partial(_call_bin, "jv")
jve = partial(_call_bin, "jve")
yn = partial(_call_bin, "yn")
yv = partial(_call_bin, "yv")
yve = partial(_call_bin, "yve")
kn = partial(_call_bin, "kn")
kv = partial(_call_bin, "kv")
kve = partial(_call_bin, "kve")
iv = partial(_call_bin, "iv")
ive = partial(_call_bin, "ive")
hankel1 = partial(_call_bin, "hankel1")
hankel1e = partial(_call_bin, "hankel1e")
hankel2 = partial(_call_bin, "hankel2")
hankel2e = partial(_call_bin, "hankel2e")

hyp2f1 = partial(call_sparse, "hyp2f1")
hyp1f1 = partial(call_sparse, "hyp1f1")
hyperu = partial(call_sparse, "hyperu")
hyp0f1 = partial(_call_bin, "hyp0f1")

ellip_harm = partial(call_sparse, "ellip_harm")
ellip_harm_2 = partial(call_sparse, "ellip_harm_2")
ellip_normal = partial(call_sparse, "ellip_normal")

ellipk = partial(_call_unary, "ellipk")
ellipkm1 = partial(_call_unary, "ellipkm1")
ellipkinc = partial(_call_bin, "ellipkinc")
ellipe = partial(_call_unary, "ellipe")
ellipeinc = partial(_call_bin, "ellipeinc")
elliprc = partial(_call_bin, "elliprc")
elliprd = partial(call_sparse, "elliprd")
elliprf = partial(call_sparse, "elliprf")
elliprg = partial(call_sparse, "elliprg")
elliprj = partial(call_sparse, "elliprj")

airy = partial(_call_unary, "airy")
airye = partial(_call_unary, "airye")
itairy = partial(_call_unary, "itairy")


# ---------------------------------------------------------------------------
# Comparisons.  On ``TypeError`` the mirrored comparison is evaluated with the
# operands swapped.  Extra keyword arguments are accepted and ignored.
# ---------------------------------------------------------------------------


def equal(a, b, **_):
    try:
        return a == b
    except TypeError:
        return b == a


def not_equal(a, b, **_):
    try:
        return a != b
    except TypeError:
        return b != a


def less(a, b, **_):
    try:
        return a < b
    except TypeError:
        return b > a


def less_equal(a, b, **_):
    try:
        return a <= b
    except TypeError:
        return b >= a


def greater(a, b, **_):
    try:
        return a > b
    except TypeError:
        return b < a


def greater_equal(a, b, **_):
    try:
        return a >= b
    except TypeError:
        return b <= a


# ---------------------------------------------------------------------------
# Logical and bitwise operations.
# ---------------------------------------------------------------------------


def logical_and(a, b, **kw):
    return _call_bin("logical_and", a, b, **kw)


def logical_or(a, b, **kw):
    return _call_bin("logical_or", a, b, **kw)


def logical_xor(a, b, **kw):
    return _call_bin("logical_xor", a, b, **kw)


def logical_not(x, **kw):
    return _call_unary("logical_not", x, **kw)


def isclose(a, b, **kw):
    return _call_bin("isclose", a, b, **kw)


def bitwise_and(a, b, **_):
    # On TypeError, retry with the operands swapped.
    try:
        return a & b
    except TypeError:
        return b & a


def bitwise_or(a, b, **_):
    # On TypeError, retry with the operands swapped.
    try:
        return a | b
    except TypeError:
        return b | a


def bitwise_xor(a, b, **_):
    # On TypeError, retry with the operands swapped.
    try:
        return operator.xor(a, b)
    except TypeError:
        return operator.xor(b, a)


def invert(x, **_):
    return ~x


def left_shift(a, b, **_):
    return a << b


def right_shift(a, b, **_):
    return a >> b


# ---------------------------------------------------------------------------
# Trigonometric / hyperbolic functions and floating-point predicates.
# ---------------------------------------------------------------------------


def sin(x, **kw):
    return _call_unary("sin", x, **kw)


def cos(x, **kw):
    return _call_unary("cos", x, **kw)


def tan(x, **kw):
    return _call_unary("tan", x, **kw)


def arcsin(x, **kw):
    return _call_unary("arcsin", x, **kw)


def arccos(x, **kw):
    return _call_unary("arccos", x, **kw)


def arctan(x, **kw):
    return _call_unary("arctan", x, **kw)


def arctan2(a, b, **kw):
    return _call_bin("arctan2", a, b, **kw)


def hypot(a, b, **kw):
    return _call_bin("hypot", a, b, **kw)


def sinh(x, **kw):
    return _call_unary("sinh", x, **kw)


def cosh(x, **kw):
    return _call_unary("cosh", x, **kw)


def tanh(x, **kw):
    return _call_unary("tanh", x, **kw)


def arcsinh(x, **kw):
    return _call_unary("arcsinh", x, **kw)


def arccosh(x, **kw):
    return _call_unary("arccosh", x, **kw)


def around(x, **kw):
    return _call_unary("around", x, **kw)


def arctanh(x, **kw):
    return _call_unary("arctanh", x, **kw)


def deg2rad(x, **kw):
    return _call_unary("deg2rad", x, **kw)


def rad2deg(x, **kw):
    return _call_unary("rad2deg", x, **kw)


def angle(x, **kw):
    return _call_unary("angle", x, **kw)


def isinf(x, **kw):
    return _call_unary("isinf", x, **kw)


def isnan(x, **kw):
    return _call_unary("isnan", x, **kw)


def signbit(x, **kw):
    return _call_unary("signbit", x, **kw)


# ---------------------------------------------------------------------------
# Products.
# ---------------------------------------------------------------------------


def dot(a, b, sparse=True, **_):
    """Return the dot product of ``a`` and ``b``.

    When ``a`` is sparse the call is delegated to ``a.dot(b, sparse=sparse)``.
    Otherwise ``a.dot(b)`` is computed and, if ``sparse`` is true, the result
    is converted with ``csr_matrix`` and wrapped in a ``SparseNDArray``.
    """
    from .core import issparse

    if not issparse(a):
        ret = a.dot(b)
        if not sparse:
            return ret
        else:
            xps = get_sparse_module(ret)
            return SparseNDArray(xps.csr_matrix(ret), shape=ret.shape)

    return a.dot(b, sparse=sparse)


def tensordot(a, b, axes=2, sparse=True):
    """Tensor dot product, restricted to the cases expressible through ``dot``.

    ``axes`` is either a pair ``(a_axes, b_axes)`` (each an int or an iterable
    of ints) or a single int from which both axis tuples are derived.

    Raises:
        NotImplementedError: for any axis combination that is not one of the
            ``dot``-reducible cases handled below.
    """
    if isinstance(axes, Iterable):
        a_axes, b_axes = axes
    else:
        a_axes = tuple(range(a.ndim - 1, a.ndim - axes - 1, -1))
        b_axes = tuple(range(0, axes))

    # Normalize both axis specifications to tuples.
    if isinstance(a_axes, Iterable):
        a_axes = tuple(a_axes)
    else:
        a_axes = (a_axes,)
    if isinstance(b_axes, Iterable):
        b_axes = tuple(b_axes)
    else:
        b_axes = (b_axes,)

    if a_axes == (a.ndim - 1,) and b_axes == (b.ndim - 2,):
        return dot(a, b, sparse=sparse)

    if a.ndim == b.ndim == 2:
        if a_axes == (a.ndim - 1,) and b_axes == (b.ndim - 1,):
            # inner product of multiple dims
            return dot(a, b.T, sparse=sparse)

    if a.ndim == 1 or b.ndim == 1:
        return dot(a, b, sparse=sparse)

    raise NotImplementedError


def matmul(a, b, sparse=True, **_):
    """Alias of ``dot``; extra keyword arguments are ignored."""
    return dot(a, b, sparse=sparse)


# ---------------------------------------------------------------------------
# Shape manipulation and reductions.
# ---------------------------------------------------------------------------


def concatenate(tensors, axis=0):
    """Concatenate ``tensors`` pairwise, left to right, along ``axis``."""
    return reduce(lambda a, b: _call_bin("concatenate", a, b, axis=axis), tensors)


def transpose(tensor, axes=None):
    return _call_unary("transpose", tensor, axes=axes)


def swapaxes(tensor, axis1, axis2):
    return _call_unary("swapaxes", tensor, axis1, axis2)


def sum(tensor, axis=None, **kw):
    return _call_unary("sum", tensor, axis=axis, **kw)


def prod(tensor, axis=None, **kw):
    return _call_unary("prod", tensor, axis=axis, **kw)


def amax(tensor, axis=None, **kw):
    return _call_unary("amax", tensor, axis=axis, **kw)


max = amax


def amin(tensor, axis=None, **kw):
    return _call_unary("amin", tensor, axis=axis, **kw)


min = amin


def all(tensor, axis=None, **kw):
    return _call_unary("all", tensor, axis=axis, **kw)


def any(tensor, axis=None, **kw):
    return _call_unary("any", tensor, axis=axis, **kw)


def mean(tensor, axis=None, **kw):
    return _call_unary("mean", tensor, axis=axis, **kw)


def nansum(tensor, axis=None, **kw):
    return _call_unary("nansum", tensor, axis=axis, **kw)


def nanprod(tensor, axis=None, **kw):
    return _call_unary("nanprod", tensor, axis=axis, **kw)


def nanmax(tensor, axis=None, **kw):
    return _call_unary("nanmax", tensor, axis=axis, **kw)


def nanmin(tensor, axis=None, **kw):
    return _call_unary("nanmin", tensor, axis=axis, **kw)


def argmax(tensor, axis=None, **kw):
    return _call_unary("argmax", tensor, axis=axis, **kw)


def nanargmax(tensor, axis=None, **kw):
    return _call_unary("nanargmax", tensor, axis=axis, **kw)


def argmin(tensor, axis=None, **kw):
    return _call_unary("argmin", tensor, axis=axis, **kw)


def nanargmin(tensor, axis=None, **kw):
    return _call_unary("nanargmin", tensor, axis=axis, **kw)


def var(tensor, axis=None, **kw):
    return _call_unary("var", tensor, axis=axis, **kw)


def cumsum(tensor, axis=None, **kw):
    return _call_unary("cumsum", tensor, axis=axis, **kw)


def cumprod(tensor, axis=None, **kw):
    return _call_unary("cumprod", tensor, axis=axis, **kw)


def nancumsum(tensor, axis=None, **kw):
    return _call_unary("nancumsum", tensor, axis=axis, **kw)


def nancumprod(tensor, axis=None, **kw):
    return _call_unary("nancumprod", tensor, axis=axis, **kw)


def count_nonzero(tensor, axis=None, **kw):
    return _call_unary("count_nonzero", tensor, axis=axis, **kw)


def maximum(a, b, **kw):
    return _call_bin("maximum", a, b, **kw)


def minimum(a, b, **kw):
    return _call_bin("minimum", a, b, **kw)


def fmax(a, b, **kw):
    return _call_bin("fmax", a, b, **kw)


def fmin(a, b, **kw):
    return _call_bin("fmin", a, b, **kw)


def floor(x, **kw):
    return _call_unary("floor", x, **kw)


def ceil(x, **kw):
    return _call_unary("ceil", x, **kw)


def trunc(x, **kw):
    return _call_unary("trunc", x, **kw)


def degrees(x, **kw):
    return _call_unary("degrees", x, **kw)


def radians(x, **kw):
    return _call_unary("radians", x, **kw)


def clip(a, a_max, a_min, **kw):
    """Clip ``a``, delegating to ``a.clip`` or to its array module's ``clip``.

    The two bounds are forwarded positionally in the order they are received.
    Unlike ``_call_unary``, no ``order`` keyword is stripped from ``kw``.

    NOTE(review): the parameters are named ``a_max`` then ``a_min``, whereas
    NumPy's ``clip`` takes ``(a_min, a_max)`` in those positions -- confirm
    whether the names are intentionally swapped.  They are kept unchanged so
    existing keyword callers keep working.

    Raises:
        NotImplementedError: when the implementation returns
            ``NotImplemented``.
    """
    from .core import get_array_module

    if hasattr(a, "clip"):
        res = a.clip(a_max, a_min, **kw)
    else:
        xp = get_array_module(a)
        res = xp.clip(a, a_max, a_min, **kw)

    if res is NotImplemented:
        raise NotImplementedError

    return res


def iscomplex(x, **kw):
    return _call_unary("iscomplex", x, **kw)


def real(x, **_):
    return x.real


def imag(x, **_):
    return x.imag


def fix(x, **kw):
    return _call_unary("fix", x, **kw)


def i0(x, **kw):
    return _call_unary("i0", x, **kw)


def nan_to_num(x, **kw):
    return _call_unary("nan_to_num", x, **kw)


def copysign(a, b, **kw):
    return _call_bin("copysign", a, b, **kw)


def nextafter(a, b, **kw):
    return _call_bin("nextafter", a, b, **kw)


def spacing(x, **kw):
    return _call_unary("spacing", x, **kw)


def ldexp(a, b, **kw):
    return _call_bin("ldexp", a, b, **kw)


def frexp(x, **kw):
    return _call_unary("frexp", x, **kw)


def modf(x, **kw):
    return _call_unary("modf", x, **kw)


def sinc(x, **kw):
    return _call_unary("sinc", x, **kw)


def isfinite(x, **kw):
    return _call_unary("isfinite", x, **kw)


def isreal(x, **kw):
    return _call_unary("isreal", x, **kw)


def isfortran(x, **kw):
    return call_sparse("isfortran", x, **kw)


def where(cond, x, y):
    """Delegate to the matrix ``where`` for 0-d / 2-d operands.

    Raises:
        NotImplementedError: when any of ``cond``, ``x`` or ``y`` has an
            ``ndim`` other than 0 or 2.
    """
    # ``any`` is redefined in this module as an array reduction, so the
    # builtin must be addressed explicitly for this plain-Python check.
    if builtins.any(i.ndim not in (0, 2) for i in (cond, x, y)):
        raise NotImplementedError

    from .matrix import where as matrix_where

    return matrix_where(cond, x, y)


def digitize(x, bins, right=False):
    return _call_unary("digitize", x, bins, right)


def repeat(a, repeats, axis=None):
    return _call_unary("repeat", a, repeats, axis=axis)


def fill_diagonal(a, val, wrap=False):
    return _call_unary("fill_diagonal", a, val, wrap=wrap)


def unique(a, return_index=False, return_inverse=False, return_counts=False, axis=None):
    return _call_unary(
        "unique",
        a,
        return_index=return_index,
        return_inverse=return_inverse,
        return_counts=return_counts,
        axis=axis,
    )


# ---------------------------------------------------------------------------
# Constructors and linear-algebra helpers; the 2-d implementations live in
# ``.matrix`` and are imported lazily inside each function.
# ---------------------------------------------------------------------------


def zeros(shape, dtype=float, gpu=False):
    """Create a sparse zero matrix; only 2-element shapes are supported.

    Raises:
        NotImplementedError: when ``len(shape) != 2``.
    """
    if len(shape) == 2:
        from .matrix import zeros_sparse_matrix

        return zeros_sparse_matrix(shape, dtype=dtype, gpu=gpu)

    raise NotImplementedError


def ones_like(x):
    """Return ``ones(x.shape)`` from the array module of ``x``.

    NOTE(review): only the shape is forwarded; ``x``'s dtype is not passed
    on -- confirm this is intended.
    """
    from .core import get_array_module

    return get_array_module(x).ones(x.shape)


def diag(v, k=0, gpu=False):
    assert v.ndim in {1, 2}

    from .matrix import diag_sparse_matrix

    return diag_sparse_matrix(v, k=k, gpu=gpu)


def eye(N, M=None, k=0, dtype=float, gpu=False):
    from .matrix import eye_sparse_matrix

    return eye_sparse_matrix(N, M=M, k=k, dtype=dtype, gpu=gpu)


def triu(m, k=0, gpu=False):
    """Delegate to ``triu_sparse_matrix``; only 2-d input is supported.

    Raises:
        NotImplementedError: when ``m.ndim != 2``.
    """
    if m.ndim == 2:
        from .matrix import triu_sparse_matrix

        return triu_sparse_matrix(m, k=k, gpu=gpu)

    raise NotImplementedError


def tril(m, k=0, gpu=False):
    """Delegate to ``tril_sparse_matrix``; only 2-d input is supported.

    Raises:
        NotImplementedError: when ``m.ndim != 2``.
    """
    if m.ndim == 2:
        from .matrix import tril_sparse_matrix

        return tril_sparse_matrix(m, k=k, gpu=gpu)

    raise NotImplementedError


def lu(m):
    from .matrix import lu_sparse_matrix

    return lu_sparse_matrix(m)


def solve_triangular(a, b, lower=False, sparse=True):
    from .matrix import solve_triangular_sparse_matrix

    return solve_triangular_sparse_matrix(a, b, lower=lower, sparse=sparse)


def block(arrs):
    """Assemble a block matrix from the (possibly nested) list ``arrs``.

    The first leaf of ``arrs`` is inspected to check dimensionality before
    delegating to the matrix implementation.

    Raises:
        NotImplementedError: when that first leaf is not 2-d.
    """
    # Descend through nested lists to reach the first leaf array.
    arr = arrs[0]
    while isinstance(arr, list):
        arr = arr[0]
    if arr.ndim != 2:  # pragma: no cover
        raise NotImplementedError

    # Aliased so the import does not shadow this function's own name.
    from .matrix import block as matrix_block

    return matrix_block(arrs)