#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, print_function

import bisect
import calendar
import codecs
import copy
import glob
import hmac
import logging
import math
import multiprocessing
import os
import random
import re
import shutil
import string
import struct
import sys
import threading
import time
import traceback
import types
import uuid
import warnings
import xml.dom.minidom
from base64 import b64encode
from datetime import date, datetime, timedelta
from email.utils import formatdate, parsedate_tz
from hashlib import md5, sha1

try:
    from collections.abc import Hashable, Iterable, Mapping
except ImportError:
    from collections import Hashable, Mapping, Iterable

from . import compat, options
from .compat import FixedOffset, getargspec, parsedate_to_datetime, six, utc
from .lib.monotonic import monotonic

try:
    import zoneinfo
except ImportError:
    zoneinfo = None
try:
    import pytz
except ImportError:
    pytz = None

try:
    from .src.utils_c import (
        CMillisecondsConverter,
        to_binary,
        to_lower_str,
        to_str,
        to_text,
    )
except ImportError:
    CMillisecondsConverter = to_str = to_text = to_binary = to_lower_str = None

TEMP_TABLE_PREFIX = "tmp_pyodps_"
if six.PY3:  # make flake8 happy
    unicode = str

_IS_WINDOWS = sys.platform.lower().startswith("win")

logger = logging.getLogger(__name__)

notset = object()


def deprecated(msg, cond=None):
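    """Decorator that marks a function as deprecated, emitting a
    ``DeprecationWarning`` when the function is called. An optional ``cond``
    callable can suppress the warning when it returns a false value.

    Usage sketch (function names are illustrative)::

        @deprecated
        def old_func(): ...

        @deprecated("use new_func instead")
        def old_func(): ...
    """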
    def _decorator(func):
        """This is a decorator which can be used to mark functions
        as deprecated. It will result in a warning being emmitted
        when the function is used."""

        @six.wraps(func)
        def _new_func(*args, **kwargs):
            warn_msg = "Call to deprecated function %s." % func.__name__
            if isinstance(msg, six.string_types):
                warn_msg += " " + msg
            if cond is None or cond():
                warnings.warn(warn_msg, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return _new_func

    if isinstance(msg, (types.FunctionType, types.MethodType)):
        return _decorator(msg)
    return _decorator


class ExperimentalNotAllowed(Exception):
    pass


def experimental(msg, cond=None):
    warn_cache = set()

    real_cond = cond
    if callable(cond):
        cond_spec = getargspec(cond)
        if not cond_spec.args and not cond_spec.varargs:
            real_cond = lambda *_, **__: cond()

    def _decorator(func):
        @six.wraps(func)
        def _new_func(*args, **kwargs):
            if real_cond is None or real_cond(*args, **kwargs):
                if not str_to_bool(os.environ.get("PYODPS_EXPERIMENTAL", "true")):
                    err_msg = (
                        "Call to experimental method %s is denied." % func.__name__
                    )
                    if isinstance(msg, six.string_types):
                        err_msg += " " + msg
                    raise ExperimentalNotAllowed(err_msg)

                if func not in warn_cache:
                    warn_msg = "Call to experimental function %s." % func.__name__
                    if isinstance(msg, six.string_types):
                        warn_msg += " " + msg

                    warnings.warn(warn_msg, category=FutureWarning, stacklevel=2)
                    warn_cache.add(func)
            return func(*args, **kwargs)

        # intentionally eliminate __doc__ for Volume 2
        _new_func.__doc__ = None
        return _new_func

    if isinstance(msg, (types.FunctionType, types.MethodType)):
        return _decorator(msg)
    return _decorator


def fixed_writexml(self, writer, indent="", addindent="", newl=""):
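    """Replacement for ``xml.dom.minidom.Element.writexml`` that writes a
    lone text child inline, avoiding the spurious whitespace the stock
    pretty printer adds around text nodes.
    """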
    # indent = current indentation
    # addindent = indentation to add to higher levels
    # newl = newline string
    writer.write(indent + "<" + self.tagName)

    attrs = self._get_attributes()
    a_names = compat.lkeys(attrs)
    a_names.sort()

    for a_name in a_names:
        writer.write(" %s=\"" % a_name)
        xml.dom.minidom._write_data(writer, attrs[a_name].value)
        writer.write("\"")
    if self.childNodes:
        if (
            len(self.childNodes) == 1
            and self.childNodes[0].nodeType == xml.dom.minidom.Node.TEXT_NODE
        ):
            writer.write(">")
            self.childNodes[0].writexml(writer, "", "", "")
            writer.write("</%s>%s" % (self.tagName, newl))
            return
        writer.write(">%s" % (newl))
        for node in self.childNodes:
            node.writexml(writer, indent + addindent, addindent, newl)
        writer.write("%s</%s>%s" % (indent, self.tagName, newl))
    else:
        writer.write("/>%s" % (newl))


# replace minidom's function with ours
xml.dom.minidom.Element.writexml = fixed_writexml
xml_fixed = lambda: None


def hmac_sha1(secret, data):
    return b64encode(hmac.new(secret, data, sha1).digest())


def md5_hexdigest(data):
    return md5(to_binary(data)).hexdigest()


def rshift(val, n):
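    """Logical (unsigned) right shift of a 32-bit value, akin to Java's
    ``>>>`` operator.

    Illustrative example::

        >>> rshift(-1, 28)
        15
    """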
    return val >> n if val >= 0 else (val + 0x100000000) >> n


def long_bits_to_double(bits):
    """
    @type  bits: long
    @param bits: the bit pattern in IEEE 754 layout

    @rtype:  float
    @return: the double-precision floating-point value corresponding
             to the given bit pattern C{bits}.
    """
    return struct.unpack("d", struct.pack("Q", bits))[0]


def double_to_raw_long_bits(value):
    """
    @type  value: float
    @param value: a Python (double-precision) float value

    @rtype: long
    @return: the IEEE 754 bit representation (64 bits as a long integer)
             of the given double-precision floating-point value.
    """
    # pack double into 64 bits, then unpack as long int
    return struct.unpack("Q", struct.pack("d", float(value)))[0]


def camel_to_underline(name):
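    """Convert a CamelCase name to snake_case.

    Illustrative example::

        >>> camel_to_underline("TableSchema")
        'table_schema'
    """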
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


def underline_to_capitalized(name):
    return "".join([s[0].upper() + s[1 : len(s)] for s in name.strip("_").split("_")])


def underline_to_camel(name):
    parts = name.split("_")
    return parts[0] + "".join(v.title() for v in parts[1:])


def long_to_int(value):
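    """Reinterpret an unsigned 32-bit value as a signed integer.

    Illustrative example::

        >>> long_to_int(0xFFFFFFFF)
        -1
    """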
    if value & 0x80000000:
        return int(-((value ^ 0xFFFFFFFF) + 1))
    else:
        return int(value)


def int_to_uint(v):
    if v < 0:
        return int(v + 2**32)
    return v


def long_to_uint(value):
    v = long_to_int(value)
    return int_to_uint(v)


def stringify_expt():
    lines = traceback.format_exception(*sys.exc_info())
    return "\n".join(lines)


def str_to_printable(field_name, auto_quote=True):
    if not field_name:
        return field_name

    escapes = {
        "\\": "\\\\",
        '\'': '\\\'',
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\v": "\\v",
        " ": " ",
    }

    def _escape_char(c):
        if c in escapes:
            return escapes[c]
        elif c < " ":
            return "\\x%02x" % ord(c)
        else:
            return c

    need_escape = lambda c: c <= " " or c in escapes
    if any(need_escape(c) for c in field_name):
        ret = "".join(_escape_char(ch) for ch in field_name)
        return '"' + ret + '"' if auto_quote else ret
    return field_name


def indent(text, n_spaces):
    if n_spaces <= 0:
        return text
    block = " " * n_spaces
    return "\n".join((block + it) if len(it) > 0 else it for it in text.split("\n"))


def parse_rfc822(s, use_legacy_parsedate=None):
    if s is None:
        return None

    use_legacy_parsedate = (
        use_legacy_parsedate
        if use_legacy_parsedate is not None
        else options.use_legacy_parsedate
    )
    if use_legacy_parsedate:
        date_tuple = parsedate_tz(s)
        return datetime(*date_tuple[:6])

    time_obj = parsedate_to_datetime(s)
    if time_obj.tzinfo is None:
        return time_obj
    time_obj = time_obj.astimezone(utc)
    gmt_ts = calendar.timegm(time_obj.timetuple())
    return datetime.fromtimestamp(gmt_ts)


def gen_rfc822(dt=None, localtime=False, usegmt=False):
    if dt is not None:
        t = time.mktime(dt.timetuple())
    else:
        t = None
    return formatdate(t, localtime=localtime, usegmt=usegmt)


try:
    _antique_mills = time.mktime(datetime(1928, 1, 1).timetuple()) * 1000
except OverflowError:
    _antique_mills = (
        int((datetime(1928, 1, 1) - datetime.utcfromtimestamp(0)).total_seconds())
        * 1000
    )
_min_datetime_mills = int(
    (datetime.min - datetime.utcfromtimestamp(0)).total_seconds() * 1000
)
_antique_errmsg = (
    "Date is older than 1928-01-01 and may contain errors. "
    "Ignore this error by setting `options.allow_antique_date` to True."
)
_min_datetime_errmsg = (
    "Date exceeds the range Python can handle. If you are reading data with tunnel, "
    "read the value as None by setting options.tunnel.overflow_date_as_none to True, "
    "or convert the value into strings with SQL before processing them with Python."
)


def to_timestamp(dt, local_tz=None, is_dst=False):
    return int(to_milliseconds(dt, local_tz=local_tz, is_dst=is_dst) / 1000.0)


class MillisecondsConverter(object):
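    """Converter between ``datetime`` objects and Unix epoch milliseconds.

    ``local_tz`` may be ``True`` (use local time), ``False`` (use UTC), a
    tzinfo object or a time zone name; when omitted it falls back to
    ``options.local_timezone``. Instances are cached per
    ``(class, local_tz, is_dst)`` combination.
    """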
    _inst_cache = dict()

    @classmethod
    def _get_tz(cls, tz):
        if isinstance(tz, six.string_types):
            if pytz is None and zoneinfo is None:
                raise ImportError(
                    "Package `zoneinfo` or `pytz` is needed when specifying a "
                    "string-format time zone."
                )
            else:
                return get_zone_from_name(tz)
        else:
            return tz

    def __new__(cls, local_tz=None, is_dst=False):
        cache_key = (cls, local_tz, is_dst)
        if cache_key in cls._inst_cache:
            return cls._inst_cache[cache_key]
        o = super(MillisecondsConverter, cls).__new__(cls)
        o.__init__(local_tz, is_dst)
        cls._inst_cache[cache_key] = o
        return o

    def _windows_mktime(self, timetuple):
        if self._local_tz:
            fromtimestamp = datetime.fromtimestamp
            mktime = time.mktime
        else:
            fromtimestamp = datetime.utcfromtimestamp
            mktime = calendar.timegm

        if timetuple[0] > 1970:
            return mktime(timetuple)
        dt = datetime(*timetuple[:6])
        epoch = fromtimestamp(0)
        return int((dt - epoch).total_seconds())

    def _windows_fromtimestamp(self, seconds):
        fromtimestamp = (
            datetime.fromtimestamp if self._local_tz else datetime.utcfromtimestamp
        )
        if seconds >= 0:
            return fromtimestamp(seconds)
        epoch = fromtimestamp(0)
        return epoch + timedelta(seconds=seconds)

    def __init__(self, local_tz=None, is_dst=False):
        self._local_tz = local_tz if local_tz is not None else options.local_timezone
        if self._local_tz is None:
            self._local_tz = True
        self._use_default_tz = type(self._local_tz) is bool

        self._allow_antique = options.allow_antique_date or _antique_mills is None
        self._is_dst = is_dst

        if self._local_tz:
            self._mktime = time.mktime
            self._fromtimestamp = datetime.fromtimestamp
        else:
            self._mktime = calendar.timegm
            self._fromtimestamp = datetime.utcfromtimestamp

        if _IS_WINDOWS:
            # special logic for negative timestamp under Windows
            self._mktime = self._windows_mktime
            self._fromtimestamp = self._windows_fromtimestamp

        self._tz = self._get_tz(self._local_tz) if not self._use_default_tz else None
        if hasattr(self._tz, "localize"):
            self._localize = lambda dt: self._tz.localize(dt, is_dst=is_dst)
        else:
            self._localize = lambda dt: dt.replace(tzinfo=self._tz)

    def to_milliseconds(self, dt):
        from .errors import DatetimeOverflowError

        if not self._use_default_tz and dt.tzinfo is None:
            dt = self._localize(dt)

        if dt.tzinfo is not None:
            mills = int(
                (
                    calendar.timegm(dt.astimezone(compat.utc).timetuple())
                    + dt.microsecond / 1000000.0
                )
                * 1000
            )
        else:
            mills = int(
                (self._mktime(dt.timetuple()) + dt.microsecond / 1000000.0) * 1000
            )

        if not self._allow_antique and mills < _antique_mills:
            raise DatetimeOverflowError(_antique_errmsg)
        return mills

    def from_milliseconds(self, milliseconds):
        from .errors import DatetimeOverflowError

        if not self._allow_antique and milliseconds < _antique_mills:
            raise DatetimeOverflowError(_antique_errmsg)
        if milliseconds < _min_datetime_mills:
            raise DatetimeOverflowError(_min_datetime_errmsg)

        seconds = compat.long_type(math.floor(milliseconds / 1000))
        microseconds = compat.long_type(milliseconds) % 1000 * 1000
        if self._use_default_tz:
            return self._fromtimestamp(seconds).replace(microsecond=microseconds)
        else:
            return (
                datetime.utcfromtimestamp(seconds)
                .replace(microsecond=microseconds, tzinfo=compat.utc)
                .astimezone(self._tz)
            )


def to_milliseconds(dt, local_tz=None, is_dst=False, force_py=False):
    cls = CMillisecondsConverter
    if force_py or cls is None:
        cls = MillisecondsConverter
    f = cls(local_tz, is_dst=is_dst)
    return f.to_milliseconds(dt)


def to_days(dt):
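    """Number of days between ``dt`` and the Unix epoch date (1970-01-01).

    Illustrative example::

        >>> to_days(date(1970, 1, 2))
        1
    """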
    start_day = date(1970, 1, 1)
    return (dt - start_day).days


def to_date(delta_day):
    start_day = date(1970, 1, 1)
    return start_day + timedelta(delta_day)


def to_datetime(milliseconds, local_tz=None, force_py=False):
    cls = CMillisecondsConverter
    if force_py or cls is None:
        cls = MillisecondsConverter
    f = cls(local_tz)
    return f.from_milliseconds(milliseconds)


def strptime_with_tz(dt, format="%Y-%m-%d %H:%M:%S"):
    try:
        return datetime.strptime(dt, format)
    except ValueError:
        naive_date_str, _, offset_str = dt.rpartition(" ")
        naive_dt = datetime.strptime(naive_date_str, format)
        offset = int(offset_str[-4:-2]) * 60 + int(offset_str[-2:])
        if offset_str[0] == "-":
            offset = -offset
        return naive_dt.replace(tzinfo=FixedOffset(offset))


if to_binary is None or to_text is None or to_str is None or to_lower_str is None:

    def to_binary(text, encoding="utf-8"):
        if text is None:
            return text
        if isinstance(text, six.text_type):
            return text.encode(encoding)
        elif isinstance(text, (six.binary_type, bytearray)):
            return bytes(text)
        else:
            return str(text).encode(encoding) if six.PY3 else str(text)

    def to_text(binary, encoding="utf-8"):
        if binary is None:
            return binary
        if isinstance(binary, (six.binary_type, bytearray)):
            return binary.decode(encoding)
        elif isinstance(binary, six.text_type):
            return binary
        else:
            return str(binary) if six.PY3 else str(binary).decode(encoding)

    def to_str(text, encoding="utf-8"):
        return (
            to_text(text, encoding=encoding)
            if six.PY3
            else to_binary(text, encoding=encoding)
        )

    def to_lower_str(s, encoding="utf-8"):
        if s is None:
            return None
        return to_str(s, encoding).lower()


def get_zone_from_name(tzname):
    return zoneinfo.ZoneInfo(tzname) if zoneinfo else pytz.timezone(tzname)


def get_zone_name(tz):
    return getattr(tz, "key", None) or getattr(tz, "zone", None)


# fix encoding conversion problem under windows
if sys.platform == "win32":

    def _replace_default_encoding(func):
        def _fun(s, encoding=None):
            return func(s, encoding=encoding or options.display.encoding)

        _fun.__name__ = func.__name__
        _fun.__doc__ = func.__doc__
        return _fun

    to_binary = _replace_default_encoding(to_binary)
    to_text = _replace_default_encoding(to_text)
    to_str = _replace_default_encoding(to_str)


def is_lambda(f):
    lam = lambda: 0
    return isinstance(f, type(lam)) and f.__name__ == lam.__name__


def str_to_kv(string, typ=None):
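    """Parse a comma-separated ``key:value`` string into a dict, optionally
    casting values with ``typ``.

    Illustrative example::

        >>> str_to_kv("a:1,b:2", int)
        {'a': 1, 'b': 2}
    """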
    d = dict()
    for pair in string.split(","):
        k, v = pair.split(":", 1)
        if typ:
            v = typ(v)
        d[k] = v
    return d


def interval_select(val, intervals, targets):
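    """Pick an element of ``targets`` according to where ``val`` falls in the
    sorted boundary list ``intervals`` (``targets`` needs one more element
    than ``intervals``).

    Illustrative example::

        >>> interval_select(15, [10, 100], ["small", "medium", "large"])
        'medium'
    """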
    return targets[bisect.bisect_left(intervals, val)]


def is_namedtuple(obj):
    return isinstance(obj, tuple) and hasattr(obj, "_fields")


def str_to_bool(s):
    if isinstance(s, bool) or s is None:
        return s
    s = s.lower().strip()
    if s == "true":
        return True
    elif s == "false":
        return False
    else:
        raise ValueError(s)


def bool_to_str(s):
    return str(s).lower()


def get_root_dir():
    return os.path.dirname(sys.modules[__name__].__file__)


def load_text_file(path):
    file_path = get_root_dir() + path
    if not os.path.exists(file_path):
        return None
    with codecs.open(file_path, encoding="utf-8") as f:
        inp_file = f.read()
    return inp_file


def load_file_paths(pattern):
    file_path = os.path.normpath(
        os.path.dirname(sys.modules[__name__].__file__) + pattern
    )
    return glob.glob(file_path)


def load_static_file_paths(path):
    return load_file_paths("/static/" + path)


def load_text_files(pattern, func=None):
    file_path = os.path.normpath(
        os.path.dirname(sys.modules[__name__].__file__) + pattern
    )
    content_dict = dict()
    for file_path in glob.glob(file_path):
        _, fn = os.path.split(file_path)
        if func and not func(fn):
            continue
        with codecs.open(file_path, encoding="utf-8") as f:
            content_dict[fn] = f.read()
    return content_dict


def load_static_text_file(path):
    return load_text_file("/static/" + path)


def load_internal_static_text_file(path):
    return load_text_file("/internal/static/" + path)


def load_static_text_files(pattern, func=None):
    return load_text_files("/static/" + pattern, func)


def init_progress_bar(val=1, use_console=True):
    try:
        from traitlets import TraitError

        ipython = True
    except ImportError:
        try:
            from IPython.utils.traitlets import TraitError

            ipython = True
        except ImportError:
            ipython = False

    from .console import ProgressBar, is_widgets_available

    if not ipython:
        bar = ProgressBar(val) if use_console else None
    else:
        try:
            if is_widgets_available():
                bar = ProgressBar(val, True)
            else:
                bar = ProgressBar(val) if use_console else None
        except TraitError:
            bar = ProgressBar(val) if use_console else None

    return bar


def init_progress_ui(val=1, lock=False, use_console=True, mock=False):
    from .ui import ProgressGroupUI, html_notify

    progress_group = None
    bar = None
    if not mock and is_main_thread():
        bar = init_progress_bar(val=val, use_console=use_console)
        if bar and bar._ipython_widget:
            try:
                progress_group = ProgressGroupUI(bar._ipython_widget)
            except:
                pass

    _lock = threading.Lock() if lock else None

    def ui_method(func):
        def inner(*args, **kwargs):
            if mock:
                return
            if _lock:
                with _lock:
                    return func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return inner

    class ProgressUI(object):
        @ui_method
        def update(self, value=None):
            if bar:
                bar.update(value=value)

        @ui_method
        def current_progress(self):
            if bar and hasattr(bar, "_current_value"):
                return bar._current_value

        @ui_method
        def inc(self, value):
            if bar and hasattr(bar, "_current_value"):
                current_val = bar._current_value
                bar.update(current_val + value)

        @ui_method
        def status(self, prefix, suffix="", clear_keys=False):
            if progress_group:
                if clear_keys:
                    progress_group.clear_keys()
                progress_group.prefix = prefix
                progress_group.suffix = suffix

        @ui_method
        def add_keys(self, keys):
            if progress_group:
                progress_group.add_keys(keys)

        @ui_method
        def remove_keys(self, keys):
            if progress_group:
                progress_group.remove_keys(keys)

        @ui_method
        def update_group(self):
            if progress_group:
                progress_group.update()

        @ui_method
        def notify(self, msg):
            html_notify(msg)

        @ui_method
        def close(self):
            if bar:
                bar.close()
            if progress_group:
                progress_group.close()

    return ProgressUI()


def escape_odps_string(src):
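    r"""Escape a Python string for embedding in an ODPS SQL string literal.

    Backslashes, single and double quotes, and the control characters
    ``\b``, ``\t``, ``\n``, ``\r`` and ``\0`` are replaced by their
    backslash escape sequences.
    """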
    trans_dict = {
        "\b": r"\b",
        "\t": r"\t",
        "\n": r"\n",
        "\r": r"\r",
        "'": r"\'",
        '"': r"\"",
        "\\": r"\\",
        "\0": r"\0",
    }
    return "".join(trans_dict[ch] if ch in trans_dict else ch for ch in src)


def to_odps_scalar(s):
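    """Render a Python scalar as an ODPS SQL literal: ``None`` and NaN become
    ``NULL``, strings are escaped and quoted, and datetime-like values are
    wrapped in ``CAST(... AS DATETIME)`` or ``CAST(... AS TIMESTAMP)``.
    """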
    try:
        from pandas import Timestamp as pd_Timestamp
    except ImportError:
        pd_Timestamp = type("DummyType", (object,), {})

    if s is None or (isinstance(s, float) and math.isnan(s)):
        return "NULL"
    if isinstance(s, six.string_types):
        return "'%s'" % escape_odps_string(s)
    elif isinstance(s, (datetime, pd_Timestamp)):
        microsec = s.microsecond
        nanosec = getattr(s, "nanosecond", 0)
        if microsec or nanosec:
            s = s.strftime("%Y-%m-%d %H:%M:%S.%f") + ("%03d" % nanosec)
            out_type = "TIMESTAMP"
        else:
            s = s.strftime("%Y-%m-%d %H:%M:%S")
            out_type = "DATETIME"
        return "CAST('%s' AS %s)" % (escape_odps_string(s), out_type)
    return str(s)


def replace_sql_parameters(sql, ns):
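    """Substitute ``:name`` style parameters in ``sql`` with values from the
    mapping ``ns``; names missing from ``ns`` are left untouched.

    Illustrative example::

        >>> replace_sql_parameters("select * from t where id = :id", {"id": 5})
        'select * from t where id = 5'
    """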
    param_re = re.compile(r":([a-zA-Z_][a-zA-Z0-9_]*)")

    def is_numeric(val):
        return isinstance(val, (six.integer_types, float))

    def is_sequence(val):
        return isinstance(val, (tuple, set, list))

    def format_string(val):
        return "'{0}'".format(escape_odps_string(str(val)))

    def format_numeric(val):
        return repr(val)

    def format_sequence(val):
        escaped = [
            format_numeric(v) if is_numeric(v) else format_string(v) for v in val
        ]
        return "({0})".format(", ".join(escaped))

    def replace(matched):
        name = matched.group(1)
        val = ns.get(name)
        if val is None:
            return matched.group(0)
        elif is_numeric(val):
            return format_numeric(val)
        elif is_sequence(val):
            return format_sequence(val)
        else:
            return format_string(val)

    return param_re.sub(replace, sql)


def is_main_process():
    return "main" in multiprocessing.current_process().name.lower()


survey_calls = dict()


def survey(func):
    @six.wraps(func)
    def wrapped(*args, **kwargs):
        arg_spec = getargspec(func)

        if "self" in arg_spec.args:
            func_cls = args[0].__class__
        else:
            func_cls = None

        if func_cls:
            func_sig = ".".join([func_cls.__module__, func_cls.__name__, func.__name__])
        else:
            func_sig = ".".join([func.__module__, func.__name__])

        add_survey_call(func_sig)
        return func(*args, **kwargs)

    return wrapped


def add_survey_call(group):
    if any(r.search(group) is not None for r in options.skipped_survey_regexes):
        return
    if group not in survey_calls:
        survey_calls[group] = 1
    else:
        survey_calls[group] += 1


def get_survey_calls():
    return copy.copy(survey_calls)


def clear_survey_calls():
    survey_calls.clear()


def require_package(pack_name):
    def _decorator(func):
        try:
            __import__(pack_name, fromlist=[""])
            return func
        except ImportError:
            return None

    return _decorator


def gen_repr_object(**kwargs):
    obj = type("ReprObject", (), {})
    text = kwargs.pop("text", None)
    if six.PY2 and isinstance(text, unicode):
        text = text.encode("utf-8")
    if text:
        setattr(obj, "text", text)
        setattr(obj, "__repr__", lambda self: text)
    for k, v in six.iteritems(kwargs):
        setattr(obj, k, v)
        # bind v as a default argument to avoid the late-binding closure pitfall
        setattr(obj, "_repr_{0}_".format(k), lambda self, _v=v: _v)
    if "gv" in kwargs:
        try:
            from graphviz import Source

            setattr(
                obj,
                "_repr_svg_",
                lambda self: Source(self._repr_gv_(), encoding="utf-8")._repr_svg_(),
            )
        except ImportError:
            pass
    return obj()


def build_pyodps_dir(*args):
    default_dir = os.path.join(os.path.expanduser("~"), ".pyodps")
    if sys.platform == "win32" and "APPDATA" in os.environ:
        win_default_dir = os.path.join(os.environ["APPDATA"], "pyodps")
        if os.path.exists(default_dir):
            shutil.move(default_dir, win_default_dir)
        default_dir = win_default_dir
    home_dir = os.environ.get("PYODPS_DIR") or default_dir
    return os.path.join(home_dir, *args)


def object_getattr(obj, attr, default=None):
    try:
        return object.__getattribute__(obj, attr)
    except AttributeError:
        return default


def attach_internal(cls):
    cls_path = cls.__module__ + "." + cls.__name__
    try:
        from .internal.core import MIXIN_TARGETS

        mixin_cls = MIXIN_TARGETS[cls_path]
        for method_name in dir(mixin_cls):
            if method_name.startswith("_"):
                continue
            att = getattr(mixin_cls, method_name)
            if six.PY2 and type(att).__name__ in ("instancemethod", "method"):
                att = att.__func__
            setattr(cls, method_name, att)
        return cls
    except ImportError:
        return cls


def is_main_thread():
    if hasattr(threading, "main_thread"):
        return threading.current_thread() is threading.main_thread()
    return threading.current_thread().__class__.__name__ == "_MainThread"


def write_log(msg):
    """Legacy method to keep compatibility"""
    logger.info(msg)


def split_quoted(s, delimiter=",", maxsplit=0):
    pattern = r"""((?:[^""" + delimiter + r""""']|"[^"]*"|'[^']*')+)"""
    return re.split(pattern, s, maxsplit=maxsplit)[1::2]


def gen_temp_table():
    return "%s%s" % (TEMP_TABLE_PREFIX, str(uuid.uuid4()).replace("-", "_"))


def hashable(obj):
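    """Recursively convert unhashable containers such as lists into tuples;
    mappings keep their type with converted values.

    Illustrative example::

        >>> hashable([1, [2, 3]])
        (1, (2, 3))
    """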
    if isinstance(obj, Hashable):
        items = obj
    elif isinstance(obj, Mapping):
        items = type(obj)((k, hashable(v)) for k, v in six.iteritems(obj))
    elif isinstance(obj, Iterable):
        items = tuple(hashable(item) for item in obj)
    else:
        raise TypeError(type(obj))

    return items


def thread_local_attribute(thread_local_name, default_value=None):
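    """Create a property whose value is stored per thread on the
    ``threading.local`` object referenced by ``self.<thread_local_name>``;
    a callable ``default_value`` provides a fresh default for each thread.

    Usage sketch (class and attribute names are illustrative)::

        class Context(object):
            _local = None
            cache = thread_local_attribute("_local", default_value=dict)
    """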
    attr_name = "_local_attr_%d" % random.randint(0, 99999999)

    def _get_thread_local(self):
        thread_local = getattr(self, thread_local_name, None)
        if thread_local is None:
            setattr(self, thread_local_name, threading.local())
            thread_local = getattr(self, thread_local_name)
        return thread_local

    def _getter(self):
        thread_local = _get_thread_local(self)
        if not hasattr(thread_local, attr_name) and callable(default_value):
            setattr(thread_local, attr_name, default_value())
        return getattr(thread_local, attr_name)

    def _setter(self, value):
        thread_local = _get_thread_local(self)
        setattr(thread_local, attr_name, value)

    return property(fget=_getter, fset=_setter)


def call_with_retry(func, *args, **kwargs):
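    """Call ``func`` and retry when it raises.

    Keyword-only options popped from ``kwargs``: ``retry_times`` (defaults to
    ``options.retry_times``), ``retry_timeout``, ``delay`` (defaults to
    ``options.retry_delay``), ``reset_func`` called before every retry,
    ``exc_type`` restricting which exceptions are caught, ``allow_interrupt``
    to re-raise ``KeyboardInterrupt`` instead of retrying, and ``no_raise`` to
    return ``sys.exc_info()`` instead of raising once retries are exhausted.
    """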
    retry_num = 0
    retry_times = kwargs.pop("retry_times", options.retry_times)
    retry_timeout = kwargs.pop("retry_timeout", None)
    delay = kwargs.pop("delay", options.retry_delay)
    reset_func = kwargs.pop("reset_func", None)
    exc_type = kwargs.pop("exc_type", BaseException)
    allow_interrupt = kwargs.pop("allow_interrupt", True)
    no_raise = kwargs.pop("no_raise", False)

    start_time = monotonic() if retry_timeout is not None else None
    while True:
        try:
            return func(*args, **kwargs)
        except exc_type as ex:
            retry_num += 1
            if allow_interrupt and isinstance(ex, KeyboardInterrupt):
                raise
            time.sleep(delay)
            if (retry_times is not None and retry_num > retry_times) or (
                retry_timeout is not None
                and start_time is not None
                and monotonic() - start_time > retry_timeout
            ):
                if no_raise:
                    return sys.exc_info()
                raise
            if callable(reset_func):
                reset_func()


def get_id(n):
    if hasattr(n, "_node_id"):
        return n._node_id

    return id(n)


def strip_if_str(s):
    if isinstance(s, six.binary_type):
        s = to_str(s)
    if isinstance(s, six.string_types):
        return s.strip()
    return s


def with_wait_argument(func):
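    """Decorator that normalizes the legacy ``wait`` and ``async`` keyword
    arguments into ``async_``, and warns when ``async_`` is passed
    positionally or when unknown keyword arguments are supplied.
    """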
    func_spec = (
        compat.getfullargspec(func)
        if compat.getfullargspec
        else compat.getargspec(func)
    )
    args_set = set(func_spec.args)
    if hasattr(func_spec, "kwonlyargs"):
        args_set |= set(func_spec.kwonlyargs or [])
    has_varkw = (
        getattr(func_spec, "varkw", None) or getattr(func_spec, "keywords", None)
    ) is not None

    try:
        async_index = func_spec.args.index("async_")
    except ValueError:
        async_index = None

    @six.wraps(func)
    def wrapped(*args, **kwargs):
        if async_index is not None and len(args) >= async_index + 1:
            warnings.warn(
                "Please use async_ as a keyword argument, like obj.func(async_=True)",
                DeprecationWarning,
                stacklevel=2,
            )
            add_survey_call(".".join([func.__module__, func.__name__, "async_"]))
        elif "wait" in kwargs:
            kwargs["async_"] = not kwargs.pop("wait")
        elif "async" in kwargs:
            kwargs["async_"] = kwargs.pop("async")

        if not has_varkw and kwargs:
            no_args_match = [key for key in kwargs if key not in args_set]
            if no_args_match:
                warnings.warn(
                    "Arguments %s are not supported and will be ignored. "
                    "Please check argument spelling." % (", ".join(no_args_match)),
                    stacklevel=2,
                )
            for arg in no_args_match:
                kwargs.pop(arg, None)
        return func(*args, **kwargs)

    return wrapped


def split_sql_by_semicolon(sql_statement):
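    r"""Split a script into individual SQL statements on semicolons while
    respecting quotes, brackets and ``--`` / ``/* */`` comments; comments are
    stripped from the returned statements.

    Illustrative example::

        >>> split_sql_by_semicolon("select 1; -- first\nselect 2;")
        ['select 1;', 'select 2;']
    """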
    sql_statement = sql_statement.replace("\r\n", "\n").replace("\r", "\n")
    left_brackets = {"}": "{", "]": "[", ")": "("}

    def cut_statement(stmt_start, stmt_end=None):
        stmt_end = stmt_end or len(sql_statement)
        parts = []
        left = stmt_start
        for comm_start, comm_end in comment_blocks:
            if comm_end <= stmt_start:
                continue
            if comm_start > stmt_end:
                break
            parts.append(sql_statement[left:comm_start])
            left = comm_end
        parts.append(sql_statement[left:stmt_end])
        combined_lines = "".join(parts).splitlines()
        return "\n".join(line.rstrip() for line in combined_lines).strip()

    start, pos = 0, 0
    statements = []
    comment_sign = None
    comment_pos = None
    comment_blocks = []
    quote_sign = None
    bracket_stack = []
    while pos < len(sql_statement):
        ch = sql_statement[pos]
        dch = sql_statement[pos : pos + 2] if pos + 1 < len(sql_statement) else None
        if quote_sign is None and comment_sign is None:
            if ch in ("{", "[", "("):
                # start of brackets
                bracket_stack.append(ch)
                pos += 1
            elif ch in ("}", "]", ")"):
                # end of brackets
                assert bracket_stack[-1] == left_brackets[ch]
                bracket_stack.pop()
                pos += 1
            elif ch in ('"', "'", "`"):
                # start of quote
                quote_sign = ch
                pos += 1
            elif dch in ("--", "/*"):
                # start of line or block comments
                comment_sign = dch
                comment_pos = pos
                pos += 2
            elif ch == ";" and not bracket_stack:
                # semicolon without brackets, quotes and comments
                part_statement = cut_statement(start, pos + 1)
                if part_statement and part_statement != ";":
                    statements.append(part_statement)
                pos += 1
                start = pos
            else:
                pos += 1
        elif quote_sign is not None and ch == quote_sign:
            quote_sign = None
            pos += 1
        elif quote_sign is not None and ch == "\\":
            # skip escape char
            pos += 2
        elif comment_sign == "--" and ch == "\n":
            # line comment ends
            comment_sign = None
            comment_blocks.append((comment_pos, pos))
            pos += 1
        elif comment_sign == "/*" and dch == "*/":
            # block comment ends
            comment_sign = None
            comment_blocks.append((comment_pos, pos + 2))
            pos += 2
        else:
            pos += 1
    part_statement = cut_statement(start)
    if part_statement and part_statement != ";":
        statements.append(part_statement)
    return statements


def strip_backquotes(name_str):
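    """Remove surrounding backquotes from an identifier and unescape doubled
    backquotes.

    Illustrative example::

        >>> strip_backquotes("`my``table`")
        'my`table'
    """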
    name_str = name_str.strip()
    if not name_str.startswith("`") or not name_str.endswith("`"):
        return name_str
    return name_str[1:-1].replace("``", "`")


def split_backquoted(s, sep=" ", maxsplit=-1):
    if "`" in sep:
        raise ValueError("Separator cannot contain backquote")
    results = []
    pos = 0
    cur_token = ""
    quoted = False
    while pos < len(s):
        ch = s[pos]
        if ch == "`":
            quoted = not quoted
            cur_token += ch
            pos += 1
        elif (
            not quoted
            and (maxsplit < 0 or len(results) < maxsplit)
            and s[pos : pos + len(sep)] == sep
        ):
            results.append(cur_token)
            cur_token = ""
            pos += len(sep)
        else:
            cur_token += ch
            pos += 1
    results.append(cur_token)
    return results


_convert_host_hash = (
    "6fe9b6c02efc24159c09863c1aadffb5",
    "9b75728355160c10b5eb75e4bf105a76",
)
_default_host_hash = "c7f4116fbf820f99284dcc89f340b372"


def get_default_logview_endpoint(default_endpoint, odps_endpoint):
    try:
        if odps_endpoint is None:
            return default_endpoint

        parsed_host = compat.urlparse(odps_endpoint).hostname
        if parsed_host is None:
            return default_endpoint

        default_host = compat.urlparse(default_endpoint).hostname
        hashed_host = md5_hexdigest(to_binary(parsed_host))
        hashed_default_host = md5_hexdigest(to_binary(default_host))

        if (
            hashed_default_host == _default_host_hash
            and hashed_host in _convert_host_hash
        ):
            suffix = r"\1" + string.ascii_letters[1::-1] * 2 + parsed_host[-8:]
            return re.sub(r"([a-z])[a-z]{3}\.[a-z]{3}$", suffix, default_endpoint)
        return default_endpoint
    except:
        return default_endpoint


def is_job_insight_available(odps_endpoint):
    try:
        if odps_endpoint is None:
            return False

        parsed_host = compat.urlparse(odps_endpoint).hostname
        if parsed_host is None:
            return False
        hashed_host = md5_hexdigest(to_binary(parsed_host))
        return hashed_host not in _convert_host_hash and (
            parsed_host.endswith(".aliyun-inc.com")
            or parsed_host.endswith(".aliyun.com")
        )
    except:
        return False


def _import_version_function():
    try:
        try:
            from importlib.metadata import version
        except ImportError:
            from importlib_metadata import version
        return version
    except ImportError:
        try:
            import pkg_resources
        except ImportError:
            return None
        return lambda x: pkg_resources.get_distribution(x).version


def get_package_version(pack_name, import_name=None):
    version_func = _import_version_function()
    if version_func is not None:
        try:
            return version_func(pack_name)
        except:
            pass
    import_name = pack_name if import_name is True else import_name
    if import_name:
        mod = __import__(import_name)
        return mod.__version__
    return None


def show_versions():  # pragma: no cover
    import locale
    import platform

    uname_result = platform.uname()
    language_code, encoding = locale.getlocale()

    results = {
        "python": ".".join([str(i) for i in sys.version_info]),
        "python-bits": struct.calcsize("P") * 8,
        "OS": uname_result.system,
        "OS-release": uname_result.release,
        "Version": uname_result.version,
        "machine": uname_result.machine,
        "processor": uname_result.processor,
        "byteorder": sys.byteorder,
        "LC_ALL": os.environ.get("LC_ALL"),
        "LANG": os.environ.get("LANG"),
        "LOCALE": {"language-code": language_code, "encoding": encoding},
    }

    try:
        from .src import crc32c_c  # noqa: F401

        results["USE_CLIB"] = True
    except ImportError:
        results["USE_CLIB"] = False

    try:
        from . import internal  # noqa: F401

        results["HAS_INTERNAL"] = True
    except ImportError:
        results["HAS_INTERNAL"] = False

    packages = {
        "pyodps": "odps",
        "requests": "requests",
        "urllib3": "urllib3",
        "charset_normalizer": "charset_normalizer",
        "chardet": "chardet",
        "idna": "idna",
        "certifi": "certifi",
        "numpy": "numpy",
        "scipy": "scipy",
        "pandas": "pandas",
        "matplotlib": "matplotlib",
        "sqlalchemy": "sqlalchemy",
        "pytz": "pytz",
        "python-dateutil": "dateutil",
        "IPython": "IPython",
        "maxframe": "maxframe",
    }
    for pack_name, pack_imp in packages.items():
        try:
            results[pack_name] = get_package_version(pack_name, pack_imp)
        except (ImportError, AttributeError):
            pass
    key_size = 1 + max(len(key) for key in results.keys())
    for key, val in results.items():
        print(key + " " * (key_size - len(key)) + ": " + str(val))


def get_supported_python_tag(align=None):
    if align is None:
        align = options.align_supported_python_tag
    if align:
        if sys.version_info[:2] >= (3, 11):
            return "cp311"
        elif sys.version_info[:2] >= (3, 6):
            return "cp37"
        else:
            return "cp27"
    else:
        return "cp" + str(sys.version_info[0]) + str(sys.version_info[1])


def chain_generator(*iterable):
    for itr in iterable:
        for item in itr:
            yield item
