# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio.events
import concurrent.futures
import contextvars
import dataclasses
import datetime
import enum
import functools
import importlib
import inspect
import io
import itertools
import numbers
import os
import pkgutil
import random
import struct
import sys
import threading
import time
import tokenize as pytokenize
import types
import weakref
import zlib
from collections.abc import Hashable, Mapping
from contextlib import contextmanager
from typing import (
    Any,
    Awaitable,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import msgpack
import numpy as np
import pandas as pd
import traitlets
from tornado import httpclient, web
from tornado.simple_httpclient import HTTPTimeoutError

from ._utils import (  # noqa: F401 # pylint: disable=unused-import
    NamedType,
    Timer,
    TypeDispatcher,
    ceildiv,
    get_user_call_point,
    new_random_id,
    register_tokenizer,
    reset_id_random_seed,
    to_binary,
    to_str,
    to_text,
    tokenize,
    tokenize_int,
)
from .lib.version import parse as parse_version
from .typing_ import TileableType, TimeoutType

# make flake8 happy by referencing these imports: the names come from
# ``._utils`` (imported with ``noqa: F401`` above) and are re-bound here so
# they are visibly used and remain importable from this module
NamedType = NamedType
TypeDispatcher = TypeDispatcher
tokenize = tokenize
register_tokenizer = register_tokenizer
ceildiv = ceildiv
reset_id_random_seed = reset_id_random_seed
new_random_id = new_random_id
get_user_call_point = get_user_call_point
# True when the ``CI`` environment variable is "1" or "true" (case-insensitive);
# an unset or empty variable counts as "0"
_is_ci = (os.environ.get("CI") or "0").lower() in ("1", "true")
# release tuple of the installed pandas as returned by ``parse_version``
# (presumably integers such as (2, 1, 4) — used for version gating)
pd_release_version: Tuple[int, ...] = parse_version(pd.__version__).release

# ``NoDefault`` / ``no_default``: the "argument not supplied" sentinel. The one
# shipped with pandas is reused when it can be imported, so the very same
# object is shared with pandas APIs; otherwise an equivalent enum is defined.
try:
    from pandas._libs import lib as _pd__libs_lib
    from pandas._libs.lib import NoDefault, no_default

    # original reducer of NoDefault's type, still used for every other instance
    _raw__reduce__ = type(NoDefault).__reduce__

    def _no_default__reduce__(self):
        """
        Reducer installed on ``type(NoDefault)``: ``NoDefault`` itself is
        reduced to ``getattr(pandas._libs.lib, "NoDefault")`` so it is looked
        up again when unpickling; any other instance keeps the original
        reducer.
        """
        if self is not NoDefault:
            return _raw__reduce__(self)
        else:  # pragma: no cover
            return getattr, (_pd__libs_lib, "NoDefault")

    if hasattr(_pd__libs_lib, "_NoDefault"):  # pragma: no cover
        # need to patch __reduce__ to make sure it can be properly unpickled
        type(NoDefault).__reduce__ = _no_default__reduce__
    else:
        # introduced in pandas 1.5.0 : register for pickle compatibility
        _pd__libs_lib._NoDefault = NoDefault
except ImportError:  # pragma: no cover
    # the pandas sentinel is not importable: define a compatible enum-based one

    class NoDefault(enum.Enum):
        no_default = "NO_DEFAULT"

        def __repr__(self) -> str:
            return "<no_default>"

    no_default = NoDefault.no_default

    try:
        # register for pickle compatibility (best effort: failures are ignored)
        from pandas._libs import lib as _pd__libs_lib

        _pd__libs_lib.NoDefault = NoDefault
    except (ImportError, AttributeError):
        pass

# pyarrow is optional: ``pa`` is None when it cannot be imported
try:
    import pyarrow as pa
except ImportError:
    pa = None

# ``pandas.ArrowDtype`` only exists in newer pandas versions; record whether
# it can be imported
try:
    from pandas import ArrowDtype as PandasArrowDtype  # noqa: F401

    ARROW_DTYPE_NOT_SUPPORTED = False
except ImportError:
    ARROW_DTYPE_NOT_SUPPORTED = True


class classproperty:
    """
    Read-only descriptor that calls the wrapped function with the owner
    class, so the value is available on the class as well as on instances.
    """

    def __init__(self, f):
        self.f = f

    def __get__(self, obj, owner):
        getter = self.f
        return getter(owner)


def implements(f: Callable):
    """
    Decorator factory copying the docstring of *f* onto the decorated
    function, which is returned unchanged otherwise.
    """

    def _copy_doc(target):
        target.__doc__ = f.__doc__
        return target

    return _copy_doc


class AttributeDict(dict):
    """
    ``dict`` whose items can also be read as attributes (``d.key``).

    Only attribute reads are mapped to items; no ``__setattr__`` is defined.
    """

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            # ``from None``: the lookup failure is an AttributeError for the
            # caller; the internal KeyError would only add traceback noise
            raise AttributeError(
                f"'AttributeDict' object has no attribute {item}"
            ) from None


def on_serialize_shape(shape: Tuple[int]):
    """Replace NaN (unknown) dimensions with -1; empty/None shapes pass through."""
    if not shape:
        return shape
    return tuple(-1 if np.isnan(dim) else dim for dim in shape)


def on_deserialize_shape(shape: Tuple[int]):
    """Inverse of on_serialize_shape: turn -1 dimensions back into NaN."""
    if not shape:
        return shape
    return tuple(np.nan if dim == -1 else dim for dim in shape)


def on_serialize_numpy_type(value: np.dtype):
    """
    Make a value serializer-friendly: ``pd.NaT`` becomes None and numpy
    scalars become the equivalent Python scalars.
    """
    if value is pd.NaT:
        return None
    if isinstance(value, np.generic):
        return value.item()
    return value


def on_serialize_nsplits(value: Tuple[Tuple[int]]):
    """Replace NA split sizes with None, keeping the nested tuple layout."""
    if value is None:
        return None
    return tuple(
        tuple(None if pd.isna(size) else size for size in dim_splits)
        for dim_splits in value
    )


def has_unknown_shape(*tiled_tileables: TileableType) -> bool:
    """
    Return True if any given tileable has a null dimension or a null chunk
    split. Objects without a ``shape`` (or with ``shape`` None) are skipped.
    """
    for tileable in tiled_tileables:
        shape = getattr(tileable, "shape", None)
        if shape is None:
            continue
        if any(pd.isnull(dim) for dim in shape) or any(
            pd.isnull(split)
            for split in itertools.chain.from_iterable(tileable.nsplits)
        ):
            return True
    return False


def calc_nsplits(chunk_idx_to_shape: Dict[Tuple[int], Tuple[int]]) -> Tuple[Tuple[int]]:
    """
    Calculate a tiled entity's nsplits.

    For every dimension ``i`` the split sizes are read from the chunks whose
    index is zero on all other dimensions and ordered by their index along
    ``i``, so the result does not depend on the insertion order of the
    mapping.

    Parameters
    ----------
    chunk_idx_to_shape : Dict type, {chunk_idx: chunk_shape}

    Returns
    -------
    nsplits
        One tuple of chunk sizes per dimension.
    """
    ndim = len(next(iter(chunk_idx_to_shape)))
    tileable_nsplits = []
    for i in range(ndim):
        # (position along dim i, size along dim i) of the chunks lying on
        # axis i, i.e. whose index is zero on every other dimension
        axis_chunks = [
            (index[i], shape[i])
            for index, shape in chunk_idx_to_shape.items()
            if all(idx == 0 for j, idx in enumerate(index) if j != i)
        ]
        # sort by position only: sizes may be NaN and must not be compared
        axis_chunks.sort(key=lambda item: item[0])
        tileable_nsplits.append(tuple(size for _, size in axis_chunks))
    return tuple(tileable_nsplits)


def copy_tileables(tileables: List[TileableType], **kwargs):
    """
    Copy tileables produced by one operator through a copy of that operator.

    Parameters
    ----------
    tileables : list
        Tileables to copy; when more than one is given they must all share
        the same operator.
    inputs : keyword, optional
        Inputs of the new tileables; the copied operator's inputs are used
        when not given (or falsy).
    copy_key : bool, keyword, default True
        Reuse the keys of the original operator and tileables.
    copy_id : bool, keyword, default True
        Reuse the ids of the original tileables.

    Returns
    -------
    The result of ``op.new_tileables`` on the copied operator.

    Raises
    ------
    TypeError
        On an unexpected keyword argument, or when the tileables do not
        share one operator.
    """
    inputs = kwargs.pop("inputs", None)
    copy_key = kwargs.pop("copy_key", True)
    copy_id = kwargs.pop("copy_id", True)
    if kwargs:
        raise TypeError(f"got an unexpected keyword argument '{next(iter(kwargs))}'")
    if len(tileables) > 1:
        # cannot handle tileables with different operators here
        # try to copy separately if so
        if len({t.op for t in tileables}) != 1:
            raise TypeError("All tileables' operators should be same.")

    op = tileables[0].op.copy().reset_key()
    if copy_key:
        op._key = tileables[0].op.key
    kws = []
    for t in tileables:
        params = t.params.copy()
        if copy_key:
            params["_key"] = t.key
        if copy_id:
            params["_id"] = t.id
        params.update(t.extra_params)
        kws.append(params)
    inputs = inputs or op.inputs
    return op.new_tileables(inputs, kws=kws, output_limit=len(kws))


def get_dtype(dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]):
    """
    Normalize *dtype*: extension dtypes are returned as-is, the
    ``pd.Timestamp`` / ``datetime.datetime`` classes map to
    ``datetime64[ns]``, ``pd.Timedelta`` / ``datetime.timedelta`` map to
    ``timedelta64[ns]`` and anything else goes through ``np.dtype``.
    """
    if pd.api.types.is_extension_array_dtype(dtype):
        return dtype
    # identity checks on purpose: np.dtype("O") compares equal to these classes
    if any(dtype is cls for cls in (pd.Timestamp, datetime.datetime)):
        return np.dtype("datetime64[ns]")
    if any(dtype is cls for cls in (pd.Timedelta, datetime.timedelta)):
        return np.dtype("timedelta64[ns]")
    return np.dtype(dtype)


def serialize_serializable(serializable, compress: bool = False):
    """
    Serialize an object into one bytes payload.

    Layout: the length of the msgpack-encoded header as an 8-byte
    little-endian unsigned integer, the header itself, then all buffers
    concatenated. Buffer sizes are stored in ``header[0]["buf_sizes"]`` so
    the reader can split them again. When *compress* is True the whole
    payload is zlib-compressed; deserialize_serializable does not
    decompress, so callers must do that first.
    """
    from .serialization import serialize

    bio = io.BytesIO()
    header, buffers = serialize(serializable)
    # prefer ``nbytes`` and only fall back to len() when it is missing:
    # getattr(buf, "nbytes", len(buf)) evaluated len() eagerly, which raises
    # for 0-d buffers even though ``nbytes`` exists
    buf_sizes = [
        buf.nbytes if hasattr(buf, "nbytes") else len(buf) for buf in buffers
    ]
    header[0]["buf_sizes"] = buf_sizes
    s_header = msgpack.dumps(header)
    bio.write(struct.pack("<Q", len(s_header)))
    bio.write(s_header)
    for buf in buffers:
        bio.write(buf)
    ser_graph = bio.getvalue()

    if compress:
        ser_graph = zlib.compress(ser_graph)
    return ser_graph


def deserialize_serializable(ser_serializable: bytes):
    """
    Rebuild an object from a payload produced by serialize_serializable.

    The payload must be uncompressed: zlib-compressed payloads
    (``compress=True``) have to be decompressed by the caller first.
    """
    from .serialization import deserialize

    bio = io.BytesIO(ser_serializable)
    # the writer packs the header length with an explicit little-endian
    # "<Q"; the native "Q" would misread it on big-endian hosts
    s_header_length = struct.unpack("<Q", bio.read(8))[0]
    header2 = msgpack.loads(bio.read(s_header_length))
    buffers2 = [bio.read(s) for s in header2[0]["buf_sizes"]]
    return deserialize(header2, buffers2)


def skip_na_call(func: Callable):
    """Wrap a one-argument *func* so that a None argument yields None without calling it."""

    @functools.wraps(func)
    def _wrapper(x):
        if x is None:
            return None
        return func(x)

    return _wrapper


def url_path_join(*pieces):
    """
    Join URL components into one path without doubled slashes.

    A leading slash on the first piece and a trailing slash on the last
    piece are preserved; empty components are dropped.
    """
    leading = pieces[0].startswith("/")
    trailing = pieces[-1].endswith("/")
    joined = "/".join(filter(None, (piece.strip("/") for piece in pieces)))
    if leading:
        joined = "/" + joined
    if trailing:
        joined += "/"
    return "/" if joined == "//" else joined


def random_ports(port: int, n: int):
    """Generate a list of n random ports near the given port.

    The first 5 ports will be sequential, and the remaining n-5 will be
    randomly selected in the range [port-2*n, port+2*n] (never below 1).
    """
    yield from range(port, port + min(5, n))
    spread = 2 * n
    for _ in range(n - 5):
        yield max(1, port + random.randint(-spread, spread))


def build_temp_table_name(session_id: str, tileable_key: str) -> str:
    """Name of the temporary table holding a tileable's data in a session."""
    prefix = "tmp_mf"
    return f"{prefix}_{session_id}_{tileable_key}"


def build_temp_intermediate_table_name(session_id: str, tileable_key: str) -> str:
    """Temporary table name (see build_temp_table_name) with an ``_intermediate`` suffix."""
    return build_temp_table_name(session_id, tileable_key) + "_intermediate"


def build_session_volume_name(session_id: str) -> str:
    """Volume name for a session; dashes in the session id become underscores."""
    normalized_id = session_id.replace("-", "_")
    return "mf_vol_" + normalized_id


async def wait_http_response(
    url: str, *, request_timeout: TimeoutType = None, **kwargs
) -> httpclient.HTTPResponse:
    """
    Fetch *url*, retrying on HTTP timeouts until *request_timeout* expires.

    Each attempt is limited to at most 10 seconds and never to more than the
    time left before the overall deadline. With ``request_timeout=None`` the
    fetch is retried indefinitely and each attempt is given
    ``request_timeout=None``. Errors other than ``HTTPTimeoutError``
    propagate.

    Raises
    ------
    TimeoutError
        When the overall deadline passes without a response.
    """
    # monotonic clock: immune to wall-clock adjustments while waiting
    start_time = time.monotonic()
    while True:
        elapsed = time.monotonic() - start_time
        if request_timeout is not None and elapsed >= request_timeout:
            break
        # budget for this attempt: the *remaining* time, capped at 10s
        # (previously the elapsed time was used, giving ~0s on the first try)
        timeout_left = (
            min(10.0, request_timeout - elapsed) if request_timeout else None
        )
        try:
            return await httpclient.AsyncHTTPClient().fetch(
                url, request_timeout=timeout_left, **kwargs
            )
        except HTTPTimeoutError:
            pass
    raise TimeoutError


def get_handler_timeout_value(handler: web.RequestHandler) -> TimeoutType:
    """
    Derive a timeout from the ``wait`` and ``timeout`` request arguments.

    Returns 0 when ``wait`` is unset or zero, None (no limit) when ``wait``
    is set and ``timeout`` is (almost) zero, otherwise ``timeout`` as float.
    """
    should_wait = bool(int(handler.get_argument("wait", "0")))
    timeout = float(handler.get_argument("timeout", "0"))
    if not should_wait:
        return 0
    if abs(timeout) < 1e-6:
        return None
    return timeout


def format_timeout_params(timeout: TimeoutType) -> str:
    """
    Build the query string understood by get_handler_timeout_value:
    None waits without limit, (almost) zero does not wait, anything else
    waits up to *timeout*.
    """
    if timeout is None:
        return "?wait=1"
    if abs(timeout) < 1e-6:
        return "?wait=0"
    return f"?wait=1&timeout={timeout}"


_PrimitiveType = TypeVar("_PrimitiveType")


def create_sync_primitive(
    cls: Type[_PrimitiveType], loop: asyncio.AbstractEventLoop
) -> _PrimitiveType:
    """
    Create an asyncio sync primitive (locks, events, etc.)
    in a certain event loop.

    Before Python 3.10 the loop is passed as ``loop=``. From 3.10 on that
    parameter no longer exists, so *loop* is temporarily installed as the
    current event loop of this thread while ``cls()`` is called, and the
    previous current loop (or none) is restored afterwards.
    """
    if sys.version_info < (3, 10):
        return cls(loop=loop)

    try:
        old_loop = asyncio.get_event_loop()
    except RuntimeError:
        # no current event loop in this thread (e.g. a worker thread that
        # never set one, or Python 3.14+ which no longer creates one
        # implicitly): restore "no loop" afterwards
        old_loop = None
    try:
        asyncio.set_event_loop(loop)
        primitive = cls()
    finally:
        asyncio.set_event_loop(old_loop)
    return primitive


class ToThreadCancelledError(asyncio.CancelledError):
    """
    CancelledError raised by ``ToThreadMixin.to_thread`` when the wait was
    cancelled (or timed out) but the threaded call was still awaited; its
    return value is exposed as :attr:`result`.
    """

    def __init__(self, *args, result=None):
        super().__init__(*args)
        self._result = result

    @property
    def result(self):
        """Return value of the call that ran to completion."""
        return self._result


_ToThreadRetType = TypeVar("_ToThreadRetType")


class ToThreadMixin:
    """
    Mixin running blocking callables in a per-instance thread pool.

    The pool is created lazily on the first :meth:`to_thread` call with
    ``_thread_pool_size`` workers and is shut down without waiting when the
    instance is finalized.
    """

    _thread_pool_size = 1
    # class-wide counter giving every pool a distinct thread name prefix
    _counter = itertools.count().__next__

    def __del__(self):
        if hasattr(self, "_pool"):
            kw = {"wait": False}
            if sys.version_info[:2] >= (3, 9):
                # ``cancel_futures`` was added to Executor.shutdown in 3.9
                kw["cancel_futures"] = True
            self._pool.shutdown(**kw)

    async def to_thread(
        self,
        func: Callable[..., _ToThreadRetType],
        *args,
        wait_on_cancel: bool = False,
        timeout: float = None,
        **kwargs,
    ) -> _ToThreadRetType:
        """
        Run ``func(*args, **kwargs)`` in the instance thread pool, inside a
        copy of the current ``contextvars`` context, and await its result.

        Parameters
        ----------
        func : callable
            Blocking callable to run.
        wait_on_cancel : bool
            If True the underlying call is shielded: when the awaiting task
            is cancelled or *timeout* expires, the call is still awaited to
            completion and ToThreadCancelledError carrying its result is
            raised instead.
        timeout : float, optional
            Seconds to wait before raising ``asyncio.TimeoutError`` (or
            ToThreadCancelledError when *wait_on_cancel* is set).
        """
        if not hasattr(self, "_pool"):
            self._pool = concurrent.futures.ThreadPoolExecutor(
                self._thread_pool_size,
                thread_name_prefix=f"{type(self).__name__}Pool-{self._counter()}",
            )

        loop = asyncio.events.get_running_loop()
        ctx = contextvars.copy_context()
        func_call = functools.partial(ctx.run, func, *args, **kwargs)
        fut = loop.run_in_executor(self._pool, func_call)

        try:
            coro = fut
            if wait_on_cancel:
                # cancelling the shield does not cancel ``fut`` itself
                coro = asyncio.shield(coro)
            if timeout is not None:
                coro = asyncio.wait_for(coro, timeout)
            return await coro
        except (asyncio.CancelledError, asyncio.TimeoutError) as ex:
            if not wait_on_cancel:
                raise
            result = await fut
            raise ToThreadCancelledError(*ex.args, result=result)

    def ensure_async_call(
        self,
        func: Callable[..., _ToThreadRetType],
        *args,
        wait_on_cancel: bool = False,
        **kwargs,
    ) -> Awaitable[_ToThreadRetType]:
        """
        Return an awaitable for ``func(*args, **kwargs)``: coroutine
        functions are called directly, anything else goes through
        :meth:`to_thread`.

        NOTE: a ``timeout`` keyword is consumed by :meth:`to_thread` and is
        not forwarded to a non-coroutine *func*.
        """
        if sys.version_info >= (3, 14):
            # asyncio.iscoroutinefunction() is deprecated since Python 3.14
            is_coro_func = inspect.iscoroutinefunction(func)
        else:
            is_coro_func = asyncio.iscoroutinefunction(func)
        if is_coro_func:
            return func(*args, **kwargs)
        return self.to_thread(func, *args, wait_on_cancel=wait_on_cancel, **kwargs)


def config_odps_default_options():
    """Replace PyODPS' global SQL settings with the defaults used by MaxFrame."""
    from odps import options as odps_options

    default_sql_settings = {
        "odps.longtime.instance": "false",
        "odps.sql.session.select.only": "false",
        "metaservice.client.cache.enable": "false",
        "odps.sql.session.result.cache.enable": "false",
        "odps.sql.submit.mode": "script",
        "odps.sql.job.max.time.hours": 72,
    }
    odps_options.sql.settings = default_sql_settings


def to_hashable(obj: Any) -> Hashable:
    """
    Recursively convert nested containers.

    Mappings are rebuilt with their own type and converted values (so a
    plain ``dict`` input yields a ``dict``, which is itself not hashable);
    other non-string iterables become tuples; hashable objects are returned
    unchanged.

    Raises
    ------
    TypeError
        For objects that are neither iterable nor hashable.
    """
    if isinstance(obj, Mapping):
        return type(obj)((key, to_hashable(val)) for key, val in obj.items())
    if isinstance(obj, Iterable) and not isinstance(obj, str):
        return tuple(map(to_hashable, obj))
    if not isinstance(obj, Hashable):
        raise TypeError(type(obj))
    return obj


def estimate_pandas_size(
    pd_obj, max_samples: int = 10, min_sample_rows: int = 100
) -> int:
    """
    Estimate the memory size in bytes of a pandas DataFrame, Series or Index.

    ``sys.getsizeof`` is used directly for objects with at most
    *min_sample_rows* rows, for RangeIndex and MultiIndex objects and when
    every data / index dtype is a numeric numpy dtype or an ArrowDtype.
    Otherwise up to *max_samples* rows are sampled at random (never more
    than the object has) and the sample size is extrapolated. For frames
    and series with a MultiIndex the index is measured fully and only the
    data part is sampled.
    """
    obj_len = len(pd_obj)
    if obj_len <= min_sample_rows or isinstance(pd_obj, pd.RangeIndex):
        return sys.getsizeof(pd_obj)
    if isinstance(pd_obj, pd.MultiIndex):
        # MultiIndex's sample size can't be used to estimate
        return sys.getsizeof(pd_obj)

    from .dataframe.arrays import ArrowDtype

    def _is_fast_dtype(dtype):
        # dtypes whose getsizeof is cheap to compute on the whole object
        if isinstance(dtype, np.dtype):
            return np.issubdtype(dtype, np.number)
        else:
            return isinstance(dtype, ArrowDtype)

    dtypes = []
    is_series = False
    if isinstance(pd_obj, pd.DataFrame):
        dtypes.extend(pd_obj.dtypes)
        index_obj = pd_obj.index
    elif isinstance(pd_obj, pd.Series):
        dtypes.append(pd_obj.dtype)
        index_obj = pd_obj.index
        is_series = True
    else:
        index_obj = pd_obj

    # handling possible MultiIndex
    if hasattr(index_obj, "dtypes"):
        dtypes.extend(index_obj.dtypes)
    else:
        dtypes.append(index_obj.dtype)

    if all(_is_fast_dtype(dtype) for dtype in dtypes):
        return sys.getsizeof(pd_obj)

    # sampling without replacement fails if more samples than rows are asked
    n_samples = min(max_samples, obj_len)
    indices = np.sort(np.random.choice(obj_len, size=n_samples, replace=False))
    iloc = pd_obj if isinstance(pd_obj, pd.Index) else pd_obj.iloc
    if isinstance(index_obj, pd.MultiIndex):
        # MultiIndex's sample size is much greater than expected, thus we calculate
        # the size separately.
        index_size = sys.getsizeof(pd_obj.index)
        if is_series:
            sample_frame_size = iloc[indices].memory_usage(deep=True, index=False)
        else:
            sample_frame_size = iloc[indices].memory_usage(deep=True, index=False).sum()
        return index_size + sample_frame_size * obj_len // n_samples
    else:
        sample_size = sys.getsizeof(iloc[indices])
        return sample_size * obj_len // n_samples


class ModulePlaceholder:
    """
    Stand-in for an optional module that is not installed: reading any
    missing attribute or calling the object raises AttributeError naming
    the module.
    """

    def __init__(self, mod_name: str):
        self._mod_name = mod_name

    def _raises(self):
        message = f"{self._mod_name} is required but not installed."
        raise AttributeError(message)

    def __getattr__(self, key):
        # only invoked for attributes that normal lookup cannot find
        self._raises()

    def __call__(self, *_args, **_kwargs):
        self._raises()


def lazy_import(
    name: str,
    package: str = None,
    globals: Dict = None,  # pylint: disable=redefined-builtin
    locals: Dict = None,  # pylint: disable=redefined-builtin
    rename: str = None,
    placeholder: bool = False,
):
    """
    Return a proxy that imports module *name* on first attribute access.

    Parameters
    ----------
    name : str
        Module to import, passed to ``importlib.import_module``.
    package : str, optional
        Anchor package passed to ``importlib.import_module``.
    globals : dict, optional
        Namespace in which the proxy is replaced by the real module on first
        use, if it contains *rename*. Defaults to the caller's globals (also
        when an empty dict is passed).
    locals : dict, optional
        Namespace used for the replacement when *rename* is not in
        *globals*.
    rename : str, optional
        Name the module is bound to; defaults to *name*.
    placeholder : bool
        Return a ModulePlaceholder instead of None when the top-level
        package of *name* cannot be found.

    Returns
    -------
    A lazy module proxy, a ModulePlaceholder or None.
    """
    from importlib.util import find_spec

    rename = rename or name
    prefix_name = name.split(".", 1)[0]
    globals = globals or inspect.currentframe().f_back.f_globals

    class LazyModule(object):
        def __init__(self):
            self._on_loads = []

        def __getattr__(self, item):
            # keep pytest's collection probes from triggering the import
            if item.startswith("_pytest") or item in ("__bases__", "__test__"):
                raise AttributeError(item)

            real_mod = importlib.import_module(name, package=package)
            if rename in globals:
                globals[rename] = real_mod
            elif locals is not None:
                locals[rename] = real_mod
            ret = getattr(real_mod, item)
            for on_load_func in self._on_loads:
                on_load_func()
            # make sure on_load hooks only executed once
            self._on_loads = []
            return ret

        def add_load_handler(self, func: Callable):
            self._on_loads.append(func)
            return func

    def _find_loader(fullname):
        # same contract as pkgutil.find_loader() (removed in Python 3.14):
        # the loader, None when the module cannot be found, and ImportError
        # when the lookup itself fails
        try:
            spec = find_spec(fullname)
        except (ImportError, AttributeError, TypeError, ValueError) as ex:
            msg = "Error while finding loader for {!r} ({}: {})"
            raise ImportError(msg.format(fullname, type(ex), ex)) from ex
        return spec.loader if spec is not None else None

    if _find_loader(prefix_name) is not None:
        return LazyModule()
    elif placeholder:
        return ModulePlaceholder(prefix_name)
    else:
        return None


def sbytes(x: Any) -> bytes:
    """
    Convert *x* to bytes: None and numbers via their ASCII ``str``, lists and
    tuples as bracketed, comma-separated ``str`` of their items, strings
    UTF-8 encoded, anything else through ``bytes(x)``.
    """
    # NB: bytes() in Python 3 has different semantic with Python 2, see: help(bytes)
    from numbers import Number

    if x is None or isinstance(x, Number):
        return str(x).encode("ascii")
    if isinstance(x, list):
        return ("[" + ", ".join(map(str, x)) + "]").encode("utf-8")
    if isinstance(x, tuple):
        return ("(" + ", ".join(map(str, x)) + ")").encode("utf-8")
    if isinstance(x, str):
        return x.encode("utf-8")
    return bytes(x)


def is_full_slice(slc: Any) -> bool:
    """Check if the input is a full slice ((:) or (0:))"""
    if not isinstance(slc, slice):
        return False
    starts_at_zero = bool(slc.start is None or slc.start == 0)
    return starts_at_zero and slc.stop is None and slc.step is None


# nesting depth of calls wrapped by enter_current_session; module-global and
# updated under AbstractSession._lock
_enter_counter = 0
# default session saved by the outermost wrapped call, restored when it exits
_initial_session = None


def enter_current_session(func: Callable):
    """
    Decorator running ``func(cls, ctx, op)`` with ``ctx``'s current session
    as the default session.

    The outermost wrapped call saves the current default session and makes
    ``ctx.get_current_session()`` the default; nested wrapped calls only
    increase a counter. When the outermost call finishes (normally or with
    an exception) the saved session becomes the default again, or the
    default is reset if no session was saved. Contexts without
    ``get_current_session`` are passed straight through.
    """

    @functools.wraps(func)
    def wrapped(cls, ctx, op):
        from .session import AbstractSession, get_default_session

        global _enter_counter, _initial_session
        # skip in some test cases
        if not hasattr(ctx, "get_current_session"):
            return func(cls, ctx, op)

        with AbstractSession._lock:
            if _enter_counter == 0:
                # to handle nested call, only set initial session
                # in first call
                session = ctx.get_current_session()
                _initial_session = get_default_session()
                session.as_default()
            _enter_counter += 1

        try:
            result = func(cls, ctx, op)
        finally:
            with AbstractSession._lock:
                _enter_counter -= 1
                if _enter_counter == 0:
                    # set previous session when counter is 0
                    if _initial_session:
                        _initial_session.as_default()
                    else:
                        AbstractSession.reset_default()
        return result

    return wrapped


# function -> token cache used by get_func_token; weak keys, so caching a
# token does not keep the function object alive
_func_token_cache = weakref.WeakKeyDictionary()


def _get_func_token_values(func):
    if hasattr(func, "__code__"):
        tokens = [func.__code__.co_code]
        if func.__closure__ is not None:
            cvars = tuple([x.cell_contents for x in func.__closure__])
            tokens.append(cvars)
        return tokens
    else:
        tokens = []
        while isinstance(func, functools.partial):
            tokens.extend([func.args, func.keywords])
            func = func.func
        if hasattr(func, "__code__"):
            tokens.extend(_get_func_token_values(func))
        elif isinstance(func, types.BuiltinFunctionType):
            tokens.extend([func.__module__, func.__qualname__])
        else:
            tokens.append(func)
        return tokens


def get_func_token(func):
    """
    Return the token of *func*, cached per function object when possible.

    Any TypeError along the cached path (e.g. *func* cannot be weakly
    referenced, as for ``numpy.ufunc``) makes the token be recomputed
    without the cache.
    """
    try:
        token = _func_token_cache.get(func)
        if token is None:
            token = _func_token_cache[func] = tokenize(*_get_func_token_values(func))
        return token
    except TypeError:  # cannot create weak reference to func like 'numpy.ufunc'
        return tokenize(*_get_func_token_values(func))


_io_quiet_local = threading.local()
_io_quiet_lock = threading.Lock()


class _QuietIOWrapper:
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def __getattr__(self, item):
        return getattr(self.wrapped, item)

    def write(self, d):
        if getattr(_io_quiet_local, "is_wrapped", False):
            return 0
        return self.wrapped.write(d)


@contextmanager
def quiet_stdio():
    """Quiets standard outputs when inferring types of functions"""
    with _io_quiet_lock:
        _io_quiet_local.is_wrapped = True
        sys.stdout = _QuietIOWrapper(sys.stdout)
        sys.stderr = _QuietIOWrapper(sys.stderr)

    try:
        yield
    finally:
        with _io_quiet_lock:
            sys.stdout = sys.stdout.wrapped
            sys.stderr = sys.stderr.wrapped
            if not isinstance(sys.stdout, _QuietIOWrapper):
                _io_quiet_local.is_wrapped = False


# from https://github.com/ericvsmith/dataclasses/blob/master/dataclass_tools.py
# released under Apache License 2.0
def dataslots(cls):
    """
    Class decorator rebuilding a dataclass with ``__slots__`` set to its
    field names.

    Apply it on top of ``@dataclasses.dataclass``. A new class object is
    returned, since ``__slots__`` cannot be added to an existing class.

    Raises
    ------
    TypeError
        If *cls* already defines ``__slots__``.
    """
    # Need to create a new class, since we can't set __slots__
    #  after a class has been created.

    # Make sure __slots__ isn't already set.
    if "__slots__" in cls.__dict__:  # pragma: no cover
        raise TypeError(f"{cls.__name__} already specifies __slots__")

    # Create a new dict for our new class.
    cls_dict = dict(cls.__dict__)
    field_names = tuple(f.name for f in dataclasses.fields(cls))
    cls_dict["__slots__"] = field_names
    for field_name in field_names:
        # Remove class-level field defaults: a class attribute with the same
        # name as a slot is rejected when the class is created.
        cls_dict.pop(field_name, None)
    # Remove __dict__ itself.
    cls_dict.pop("__dict__", None)
    # The old __weakref__ descriptor belongs to the original class and does
    # not apply to instances of the new one (same fix as CPython gh-102069).
    cls_dict.pop("__weakref__", None)
    # And finally create the class.
    qualname = getattr(cls, "__qualname__", None)
    cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
    if qualname is not None:
        cls.__qualname__ = qualname
    return cls


def adapt_docstring(doc: str) -> str:
    """
    Adapt numpy-style docstrings to MaxFrame docstring.

    Before the first ``>>>`` / ``...`` line, imports of ``maxframe.tensor``
    (as ``mt``) and ``maxframe.dataframe`` (as ``md``) are inserted when the
    docstring mentions ``np.`` or ``pd.`` respectively. In prompt lines
    ``np.`` / ``pd.`` become ``mt.`` / ``md.``, and a prompt line directly
    followed by non-empty output gets ``.execute()`` appended. The result
    still needs a manual check.
    """
    if doc is None:
        return None

    uses_numpy = "np." in doc
    uses_pandas = "pd." in doc
    adapted = []
    imports_added = False
    last_was_prompt = False

    for line in doc.splitlines():
        stripped = line.strip()
        is_prompt = stripped.startswith((">>>", "..."))
        if is_prompt:
            if not imports_added:
                imports_added = True
                indent = line[: len(line) - len(line.lstrip(" \t"))]
                if uses_numpy:
                    adapted.append(indent + ">>> import maxframe.tensor as mt")
                if uses_pandas:
                    adapted.append(indent + ">>> import maxframe.dataframe as md")
            line = line.replace("np.", "mt.").replace("pd.", "md.")
        elif last_was_prompt and stripped:
            # the previous statement produced output: make it execute
            adapted[-1] += ".execute()"
        last_was_prompt = is_prompt
        adapted.append(line)
    return "\n".join(adapted)


def stringify_path(path: Union[str, os.PathLike]) -> str:
    """
    Convert *path* to a string or unicode path if possible.

    Strings are returned unchanged; other objects must implement
    ``__fspath__`` (whose return value is passed through), otherwise
    TypeError is raised.
    """
    if not isinstance(path, str):
        try:
            path = path.__fspath__()
        except AttributeError:
            raise TypeError("not a path-like object")
    return path


# power of 1024 for each (lower-case) size suffix initial
_memory_size_indices = {"": 0, "k": 1, "m": 2, "g": 3, "t": 4}


def parse_readable_size(value: Union[str, int, float]) -> Tuple[float, bool]:
    """
    Parse a size such as ``"10k"``, ``"1.5 GB"`` or a ratio such as ``"50%"``.

    Returns
    -------
    (number, is_ratio)
        For ``"...%"`` inputs the fraction (``"50%"`` -> 0.5) and True;
        otherwise the size in bytes (suffixes k/m/g/t are powers of 1024,
        only their first letter matters) and False. Numbers are returned as
        float with False.

    Raises
    ------
    ValueError
        When the value cannot be parsed; the message shows the input.
    """
    if isinstance(value, numbers.Number):
        return float(value), False

    raw_value = value
    value = value.strip().lower()
    num_pos = 0
    while num_pos < len(value) and value[num_pos] in "0123456789.-":
        num_pos += 1

    value, suffix = value[:num_pos], value[num_pos:]
    suffix = suffix.strip()
    if suffix.endswith("%"):
        return float(value) / 100, True

    try:
        return float(value) * (1024 ** _memory_size_indices[suffix[:1]]), False
    except (ValueError, KeyError):
        # report the original input, not just its numeric prefix
        raise ValueError(f"Unknown limitation value: {raw_value}")


def remove_suffix(value: str, suffix: str) -> Tuple[str, bool]:
    """
    Remove a suffix from a given string if it exists.

    Parameters
    ----------
    value : str
        The original string.
    suffix : str
        The suffix to be removed.

    Returns
    -------
    Tuple[str, bool]
        The string without the suffix (or unchanged) and whether the suffix
        was found. An empty suffix always counts as found.
    """
    if not suffix:
        return value, True
    if value.endswith(suffix):
        return value[: len(value) - len(suffix)], True
    return value, False


def find_objects(nested: Union[List, Dict], types: Union[Type, Tuple[Type]]) -> List:
    """
    Collect all objects of the given *types* inside a nested structure.

    Lists, tuples, sets and dict values are traversed depth-first in
    iteration order; matching objects are not searched any further.
    """
    matches = []
    pending = [nested]
    while pending:
        current = pending.pop()
        if isinstance(current, types):
            matches.append(current)
        elif isinstance(current, (list, tuple, set)):
            pending.extend(reversed(list(current)))
        elif isinstance(current, dict):
            pending.extend(reversed(list(current.values())))
    return matches


def replace_objects(nested: Union[List, Dict], mapping: Mapping) -> Union[List, Dict]:
    """
    Return a copy of *nested* with values replaced through *mapping*.

    Dicts, lists, tuples and sets are rebuilt recursively with their own
    type (dict keys are kept); other values are looked up in *mapping* and
    kept as-is when missing or unhashable. With an empty mapping *nested*
    itself is returned.
    """
    if not mapping:
        return nested

    def _replace(item):
        if isinstance(item, (dict, list, tuple, set)):
            return replace_objects(item, mapping)
        try:
            return mapping.get(item, item)
        except TypeError:
            return item

    if isinstance(nested, dict):
        return type(nested)((key, _replace(val)) for key, val in nested.items())
    return type(nested)([_replace(item) for item in nested])


def trait_from_env(
    trait_name: str, env: str, trait: Optional[traitlets.TraitType] = None
):
    """
    Build a traitlets ``@default`` handler reading the trait's default value
    from environment variable *env*.

    Parameters
    ----------
    trait_name : str
        Name of the trait attribute.
    env : str
        Environment variable to read.
    trait : traitlets.TraitType, optional
        The trait instance. When omitted it is looked up by *trait_name* in
        the local namespace of the caller (for instance the class body that
        declares the trait).

    Returns
    -------
    The handler wrapped with ``traitlets.default(trait_name)``. For
    ``traitlets.List`` traits the variable is split on commas (each part
    converted with the item trait's ``from_string`` when there is one);
    for other traits string values go through ``trait.from_string``. When
    the variable is unset the trait's own default value is used.
    """
    if trait is None:
        # only the immediate caller is needed; inspect.stack() built
        # FrameInfo records (with source context) for the whole stack
        caller_frame = inspect.currentframe().f_back
        try:
            trait = caller_frame.f_locals[trait_name]
        finally:
            # drop the frame reference to avoid a reference cycle
            del caller_frame

    default_value = trait.default_value
    sub_trait: traitlets.TraitType = getattr(trait, "_trait", None)

    def default_value_simple(self):
        env_val = os.getenv(env, default_value)
        if isinstance(env_val, (str, bytes)):
            return trait.from_string(env_val)
        return env_val

    def default_value_list(self):
        env_val = os.getenv(env, default_value)
        if env_val is None or isinstance(env_val, traitlets.Sentinel):
            return env_val

        parts = env_val.split(",") if env_val else []
        if sub_trait:
            return [sub_trait.from_string(s) for s in parts]
        else:
            return parts

    if isinstance(trait, traitlets.List):
        default_value_fun = default_value_list
    else:  # pragma: no cover
        default_value_fun = default_value_simple

    default_value_fun.__name__ = trait_name + "_default"
    return traitlets.default(trait_name)(default_value_fun)


def relay_future(
    dest: Union[asyncio.Future, concurrent.futures.Future],
    src: Union[asyncio.Future, concurrent.futures.Future],
) -> None:
    """
    Copy the outcome of *src* into *dest* once *src* is done.

    The result or exception of *src* is set on *dest*, and cancellation of
    *src* cancels *dest*. Nothing happens if *dest* is already done (e.g.
    cancelled by its consumer) when *src* completes.

    NOTE(review): the callback runs wherever *src* invokes its done
    callbacks. If *dest* is an asyncio future and *src* completes in another
    thread, *dest* is modified from that thread, which asyncio futures do
    not support — confirm callers relay within one thread.
    """

    def cb(fut: Union[asyncio.Future, concurrent.futures.Future]):
        if dest.done():
            # setting a result on a done future raises InvalidStateError
            return
        if fut.cancelled():
            # propagate cancellation as cancellation, not as an exception
            dest.cancel()
            return
        try:
            result = fut.result()
        except BaseException as ex:  # pylint: disable=broad-except
            dest.set_exception(ex)
        else:
            dest.set_result(result)

    src.add_done_callback(cb)


# type name (as it appears in arrow type strings) -> callable building the
# DataType; consumed by arrow_type_from_str. Stays empty without pyarrow.
_arrow_type_constructors = {}
if pa:
    _arrow_type_constructors = {
        "bool": pa.bool_,
        # the list parameter list is a sequence of (name, type) pairs, e.g.
        # ("item", int64); only the "item" entry is used
        "list": lambda x: pa.list_(dict(x)["item"]),
        "map": lambda x: pa.map_(*x),
        "struct": pa.struct,
        # pa.binary(length) with a length builds a fixed-size binary type
        "fixed_size_binary": pa.binary,
        "halffloat": pa.float16,
        "float": pa.float32,
        "double": pa.float64,
        "decimal": pa.decimal128,
    }
    # names that are also pyarrow factory functions of the same name
    _plain_arrow_types = """
    null
    int8 int16 int32 int64
    uint8 uint16 uint32 uint64
    float16 float32 float64
    date32 date64
    decimal128 decimal256
    string utf8 binary
    time32 time64 duration timestamp
    month_day_nano_interval
    """
    for _type_name in _plain_arrow_types.split():
        try:
            _arrow_type_constructors[_type_name] = getattr(pa, _type_name)
        except AttributeError:  # pragma: no cover
            # factory missing in the installed pyarrow version: skip it
            pass


def arrow_type_from_str(type_str: str) -> "pa.DataType":
    """
    Convert arrow type representations (for inst., list<item: int64>)
    into arrow DataType instances

    The string is tokenized with Python's tokenizer and parsed with two
    stacks: ``value_stack`` holds type names, numbers, parameter lists and
    built types, while ``op_stack`` tracks open brackets (``<``, ``[``,
    ``(``) and pending ``name: type`` separators (``:``).

    Parameters
    ----------
    type_str : str
        Type string such as ``"int64"``, ``"list<item: int64>"`` or
        ``"decimal(10, 2)"``.

    Returns
    -------
    pa.DataType

    Raises
    ------
    ValueError
        When parsing does not reduce the string to a single value.

    Notes
    -----
    The return annotation is a string so that defining this function does
    not fail when pyarrow is missing (``pa`` is None then).
    """
    raw_type_str = type_str
    # enable consecutive brackets to be tokenized
    type_str = type_str.replace("<", "< ").replace(">", " >")
    token_iter = pytokenize.tokenize(io.BytesIO(type_str.encode()).readline)
    value_stack, op_stack = [], []

    def _pop_make_type(with_args: bool = False, combined: bool = True) -> None:
        """
        Pops tops of value stacks, creates a DataType instance and push back

        Parameters
        ----------
            with_args: bool
                if True, will contain next item (parameter list) in
                the value stack as parameters
            combined: bool
                if True, will use first element of the top of the value stack
                in DataType constructors
        """
        args = () if not with_args else (value_stack.pop(-1),)
        if not combined:
            args = args[0]
        type_name = value_stack.pop(-1)
        if isinstance(type_name, pa.DataType):
            value_stack.append(type_name)
        elif type_name in _arrow_type_constructors:
            value_stack.append(_arrow_type_constructors[type_name](*args))
        else:  # pragma: no cover
            value_stack.append(type_name)

    for token in token_iter:
        if token.type == pytokenize.OP:
            if token.string == ":":
                op_stack.append(token.string)
            elif token.string == ",":
                # gather previous sub-types
                if op_stack[-1] in ("<", ":"):
                    _pop_make_type()

                if op_stack[-1] == ":":
                    # parameterized sub-types need to be represented as tuples
                    op_stack.pop(-1)
                    values = value_stack[-2:]
                    value_stack = value_stack[:-2]
                    value_stack.append(tuple(values))
                # put generated item into the parameter list
                val = value_stack.pop(-1)
                value_stack[-1].append(val)
            elif token.string in ("<", "[", "("):
                # pushes an empty parameter list for future use
                value_stack.append([])
                op_stack.append(token.string)
            elif token.string in (")", "]"):
                # put generated item into the parameter list
                val = value_stack.pop(-1)
                value_stack[-1].append(val)
                # make DataType (i.e., fixed_size_binary / decimal) given args
                _pop_make_type(with_args=True, combined=False)
                op_stack.pop(-1)
            elif token.string == ">":
                _pop_make_type()

                if op_stack[-1] == ":":
                    # parameterized sub-types need to be represented as tuples
                    op_stack.pop(-1)
                    values = value_stack[-2:]
                    value_stack = value_stack[:-2]
                    value_stack.append(tuple(values))

                # put generated item into the parameter list
                val = value_stack.pop(-1)
                value_stack[-1].append(val)
                # make DataType (i.e., list / map / struct) given args
                _pop_make_type(True)
                op_stack.pop(-1)
        elif token.type == pytokenize.NAME:
            value_stack.append(token.string)
        elif token.type == pytokenize.NUMBER:
            value_stack.append(int(token.string))
        elif token.type == pytokenize.ENDMARKER:
            # make final type
            _pop_make_type()
    if len(value_stack) > 1:
        # report the caller's string, not the bracket-padded working copy
        raise ValueError(f"Cannot parse type {raw_type_str}")
    return value_stack[-1]


def get_python_tag():
    """Return the CPython wheel tag of the running interpreter, e.g. ``cp311``."""
    # todo add implementation suffix for non-GIL tags when PEP703 is ready
    major, minor = sys.version_info[:2]
    return f"cp{major}{minor}"


def get_item_if_scalar(val: Any) -> Any:
    """Unwrap a 0-d numpy array into a Python scalar; return anything else unchanged."""
    is_zero_dim_array = isinstance(val, np.ndarray) and val.shape == ()
    return val.item() if is_zero_dim_array else val


def collect_leaf_operators(root) -> List[Type]:
    """
    Return the subclasses of *root* (or *root* itself) that have no
    subclasses, in depth-first order. A class reachable through several
    parents is listed once per path.
    """
    leaves = []
    pending = [root]
    while pending:
        current = pending.pop()
        children = current.__subclasses__()
        if not children:
            leaves.append(current)
        # reversed so that children are visited in declaration order
        pending.extend(reversed(children))
    return leaves


@contextmanager
def sync_pyodps_options():
    """
    Context manager applying MaxFrame's ``local_timezone`` (and
    ``session.enable_schema`` when it is truthy) to a temporary PyODPS
    option context for the duration of the block.
    """
    from odps.config import option_context as pyodps_option_context

    from .config import options

    with pyodps_option_context() as pyodps_cfg:
        pyodps_cfg.local_timezone = options.local_timezone
        enable_schema = options.session.enable_schema
        if enable_schema:
            pyodps_cfg.enable_schema = enable_schema
        yield


def str_to_bool(s: Optional[str]) -> Optional[bool]:
    """Map "true"/"1" (case and surrounding spaces ignored) to True, other strings to False, None to None."""
    if s is None:
        return None
    return s.lower().strip() in ("true", "1")


def is_empty(val):
    """Return ``.empty`` for pandas DataFrame/Series/Index objects and falsiness for anything else."""
    if not isinstance(val, (pd.DataFrame, pd.Series, pd.Index)):
        return not val
    return val.empty


def get_default_table_properties():
    """Return a new dict of default properties for created tables."""
    return dict(storagestrategy="archive")
