odps/utils.py (716 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import, print_function

import bisect
import calendar
import codecs
import copy
import glob
import hmac
import logging
import math
import multiprocessing
import os
import random
import re
import shutil
import string
import struct
import sys
import threading
import time
import traceback
import types
import uuid
import warnings
import xml.dom.minidom
from base64 import b64encode
from datetime import date, datetime, timedelta
from email.utils import formatdate, parsedate_tz
from hashlib import md5, sha1

try:
    from collections.abc import Hashable, Iterable, Mapping
except ImportError:
    from collections import Hashable, Mapping, Iterable

from . import compat, options
from .compat import FixedOffset, getargspec, parsedate_to_datetime, six, utc
from .lib.monotonic import monotonic

try:
    import zoneinfo
except ImportError:
    zoneinfo = None
try:
    import pytz
except ImportError:
    pytz = None

try:
    # C-accelerated implementations; pure-Python fallbacks are defined below
    # when this extension is unavailable.
    from .src.utils_c import (
        CMillisecondsConverter,
        to_binary,
        to_lower_str,
        to_str,
        to_text,
    )
except ImportError:
    CMillisecondsConverter = to_str = to_text = to_binary = to_lower_str = None


TEMP_TABLE_PREFIX = "tmp_pyodps_"

if six.PY3:  # make flake8 happy
    unicode = str

_IS_WINDOWS = sys.platform.lower().startswith("win")

logger = logging.getLogger(__name__)

# module-level sentinel object meaning "argument not given"
notset = object()


def deprecated(msg, cond=None):
    """Decorator marking a function as deprecated.

    Each call emits a ``DeprecationWarning`` (when ``cond`` is None or
    ``cond()`` is truthy) and then calls the wrapped function.

    Can be used bare (``@deprecated``) or with a message
    (``@deprecated("use xxx instead")``); a string ``msg`` is appended
    to the warning text.
    """

    def _decorator(func):
        @six.wraps(func)
        def _new_func(*args, **kwargs):
            warn_msg = "Call to deprecated function %s." % func.__name__
            if isinstance(msg, six.string_types):
                warn_msg += " " + msg
            if cond is None or cond():
                warnings.warn(warn_msg, category=DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return _new_func

    # bare usage: ``msg`` is actually the decorated function
    if isinstance(msg, (types.FunctionType, types.MethodType)):
        return _decorator(msg)
    return _decorator


class ExperimentalNotAllowed(Exception):
    """Raised when an experimental API is called while the
    ``PYODPS_EXPERIMENTAL`` environment variable is set to ``false``."""

    pass


def experimental(msg, cond=None):
    """Decorator marking a function as experimental.

    When ``cond`` is None or evaluates truthy for the call:

    * raises :class:`ExperimentalNotAllowed` if env ``PYODPS_EXPERIMENTAL``
      is ``false``;
    * otherwise emits a ``FutureWarning`` once per decorated function.

    ``cond`` may take no arguments, or the same arguments as the decorated
    function. Can be used bare or with a message string.
    """
    warn_cache = set()
    real_cond = cond

    if callable(cond):
        cond_spec = getargspec(cond)
        if not cond_spec.args and not cond_spec.varargs:
            # adapt a zero-argument condition to the call signature
            real_cond = lambda *_, **__: cond()

    def _decorator(func):
        @six.wraps(func)
        def _new_func(*args, **kwargs):
            if real_cond is None or real_cond(*args, **kwargs):
                if not str_to_bool(os.environ.get("PYODPS_EXPERIMENTAL", "true")):
                    err_msg = (
                        "Calling to experimental method %s is denied." % func.__name__
                    )
                    if isinstance(msg, six.string_types):
                        err_msg += " " + msg
                    raise ExperimentalNotAllowed(err_msg)

                if func not in warn_cache:
                    warn_msg = "Call to experimental function %s." % func.__name__
                    if isinstance(msg, six.string_types):
                        warn_msg += " " + msg
                    warnings.warn(warn_msg, category=FutureWarning, stacklevel=2)
                    warn_cache.add(func)
            return func(*args, **kwargs)

        # intentionally eliminate __doc__ for Volume 2
        _new_func.__doc__ = None
        return _new_func

    if isinstance(msg, (types.FunctionType, types.MethodType)):
        return _decorator(msg)
    return _decorator


def fixed_writexml(self, writer, indent="", addindent="", newl=""):
    """Replacement for ``xml.dom.minidom.Element.writexml``.

    Differences from the default: attributes are written sorted by name, an
    element whose only child is a text node is written on a single line
    (no indentation/newlines around the text), and childless elements are
    written as self-closing tags.
    """
    # indent = current indentation
    # addindent = indentation to add to higher levels
    # newl = newline string
    writer.write(indent + "<" + self.tagName)

    attrs = self._get_attributes()
    a_names = compat.lkeys(attrs)
    a_names.sort()

    for a_name in a_names:
        writer.write(" %s=\"" % a_name)
        xml.dom.minidom._write_data(writer, attrs[a_name].value)
        writer.write("\"")
    if self.childNodes:
        if (
            len(self.childNodes) == 1
            and self.childNodes[0].nodeType == xml.dom.minidom.Node.TEXT_NODE
        ):
            writer.write(">")
            self.childNodes[0].writexml(writer, "", "", "")
            writer.write("</%s>%s" % (self.tagName, newl))
            return
        writer.write(">%s" % (newl))
        for node in self.childNodes:
            node.writexml(writer, indent + addindent, addindent, newl)
        writer.write("%s</%s>%s" % (indent, self.tagName, newl))
    else:
        writer.write("/>%s" % (newl))


# replace minidom's function with ours
xml.dom.minidom.Element.writexml = fixed_writexml
xml_fixed = lambda: None


def hmac_sha1(secret, data):
    """Return the base64-encoded HMAC-SHA1 digest of ``data`` keyed by ``secret``."""
    return b64encode(hmac.new(secret, data, sha1).digest())


def md5_hexdigest(data):
    """Return the hex MD5 digest of ``data`` (converted to bytes first)."""
    return md5(to_binary(data)).hexdigest()


def rshift(val, n):
    """Right-shift ``val`` by ``n``; negative values are first mapped to
    their unsigned 32-bit representation (like Java's ``>>>``)."""
    return val >> n if val >= 0 else (val + 0x100000000) >> n


def long_bits_to_double(bits):
    """
    @type  bits: long
    @param bits: the bit pattern in IEEE 754 layout

    @rtype:  float
    @return: the double-precision floating-point value corresponding
             to the given bit pattern C{bits}.
    """
    return struct.unpack("d", struct.pack("Q", bits))[0]


def double_to_raw_long_bits(value):
    """
    @type  value: float
    @param value: a Python (double-precision) float value

    @rtype:  long
    @return: the IEEE 754 bit representation (64 bits as a long integer)
             of the given double-precision floating-point value.
    """
    # pack double into 64 bits, then unpack as long int
    return struct.unpack("Q", struct.pack("d", float(value)))[0]


def camel_to_underline(name):
    """Convert ``CamelCase`` to ``camel_case``."""
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()


def underline_to_capitalized(name):
    """Convert ``snake_case`` to ``SnakeCase``.

    Leading/trailing underscores are ignored; empty segments produced by
    consecutive underscores are tolerated (``s[:1]`` instead of ``s[0]``).
    """
    return "".join(s[:1].upper() + s[1:] for s in name.strip("_").split("_"))


def underline_to_camel(name):
    """Convert ``snake_case`` to ``snakeCase``."""
    parts = name.split("_")
    return parts[0] + "".join(v.title() for v in parts[1:])


def long_to_int(value):
    """Interpret ``value`` (assumed to be an unsigned 32-bit integer) as a
    signed 32-bit integer."""
    if value & 0x80000000:
        return int(-((value ^ 0xFFFFFFFF) + 1))
    else:
        return int(value)


def int_to_uint(v):
    """Map a negative signed 32-bit integer to its unsigned value."""
    if v < 0:
        return int(v + 2**32)
    return v


def long_to_uint(value):
    """Normalize ``value`` to an unsigned 32-bit integer via the signed form."""
    v = long_to_int(value)
    return int_to_uint(v)


def stringify_expt():
    """Format the exception currently being handled as a traceback string."""
    lines = traceback.format_exception(*sys.exc_info())
    return "\n".join(lines)


def str_to_printable(field_name, auto_quote=True):
    """Escape a name for display.

    Backslashes, quotes and control characters are escaped. If any character
    needs escaping (including spaces), the escaped result is wrapped in double
    quotes when ``auto_quote`` is True. Names that need no escaping are
    returned unchanged.
    """
    if not field_name:
        return field_name

    escapes = {
        "\\": "\\\\",
        '\'': '\\\'',
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        "\r": "\\r",
        "\t": "\\t",
        "\v": "\\v",
        " ": " ",
    }

    def _escape_char(c):
        if c in escapes:
            return escapes[c]
        elif c < " ":
            return "\\x%02x" % ord(c)
        else:
            return c

    need_escape = lambda c: c <= " " or c in escapes
    if any(need_escape(c) for c in field_name):
        ret = "".join(_escape_char(ch) for ch in field_name)
        return '"' + ret + '"' if auto_quote else ret
    return field_name


def indent(text, n_spaces):
    """Prefix every non-empty line of ``text`` with ``n_spaces`` spaces."""
    if n_spaces <= 0:
        return text
    block = " " * n_spaces
    return "\n".join((block + it) if len(it) > 0 else it for it in text.split("\n"))


def parse_rfc822(s, use_legacy_parsedate=None):
    """Parse an RFC 822 date string into a naive ``datetime``.

    In legacy mode (``options.use_legacy_parsedate`` by default) the
    timezone offset is dropped. Otherwise a timezone-aware value is
    converted to the local time of this machine; a naive value is returned
    as-is. Returns None when ``s`` is None.
    """
    if s is None:
        return None
    use_legacy_parsedate = (
        use_legacy_parsedate
        if use_legacy_parsedate is not None
        else options.use_legacy_parsedate
    )
    if use_legacy_parsedate:
        date_tuple = parsedate_tz(s)
        return datetime(*date_tuple[:6])

    time_obj = parsedate_to_datetime(s)
    if time_obj.tzinfo is None:
        return time_obj
    time_obj = time_obj.astimezone(utc)
    gmt_ts = calendar.timegm(time_obj.timetuple())
    return datetime.fromtimestamp(gmt_ts)


def gen_rfc822(dt=None, localtime=False, usegmt=False):
    """Format ``dt`` (interpreted as local time; now if None) as RFC 822."""
    if dt is not None:
        t = time.mktime(dt.timetuple())
    else:
        t = None
    return formatdate(t, localtime=localtime, usegmt=usegmt)


try:
    _antique_mills = time.mktime(datetime(1928, 1, 1).timetuple()) * 1000
except OverflowError:
    # some platforms cannot mktime dates before the epoch
    _antique_mills = (
        int((datetime(1928, 1, 1) - datetime.utcfromtimestamp(0)).total_seconds())
        * 1000
    )
_min_datetime_mills = int(
    (datetime.min - datetime.utcfromtimestamp(0)).total_seconds() * 1000
)
_antique_errmsg = (
    "Date older than 1928-01-01 and may contain errors. "
    "Ignore this error by configuring `options.allow_antique_date` to True."
)
_min_datetime_errmsg = (
    "Date exceed range Python can handle. If you are reading data with tunnel, read "
    "the value as None by setting options.tunnel.overflow_date_as_none to True, "
    "or convert the value into strings with SQL before processing them with Python."
)


def to_timestamp(dt, local_tz=None, is_dst=False):
    """Convert ``dt`` to whole seconds since the Unix epoch."""
    return int(to_milliseconds(dt, local_tz=local_tz, is_dst=is_dst) / 1000.0)


class MillisecondsConverter(object):
    """Convert between ``datetime`` objects and milliseconds since the epoch.

    ``local_tz`` semantics (falls back to ``options.local_timezone``, then
    True): a bool selects the platform local-time (True) or UTC (False)
    functions; anything else is a tz object or a zone-name string.

    Instances are cached per ``(cls, local_tz, is_dst)``. Because ``__new__``
    returns an instance of ``cls``, Python re-runs ``__init__`` on every
    construction, so cached instances re-read the current options.
    """

    _inst_cache = dict()

    @classmethod
    def _get_tz(cls, tz):
        """Resolve a zone-name string to a tz object; pass others through."""
        if isinstance(tz, six.string_types):
            if pytz is None and zoneinfo is None:
                raise ImportError(
                    "Package `pytz` is needed when specifying string-format time zone."
                )
            else:
                return get_zone_from_name(tz)
        else:
            return tz

    def __new__(cls, local_tz=None, is_dst=False):
        cache_key = (cls, local_tz, is_dst)
        if cache_key in cls._inst_cache:
            return cls._inst_cache[cache_key]
        o = super(MillisecondsConverter, cls).__new__(cls)
        o.__init__(local_tz, is_dst)
        cls._inst_cache[cache_key] = o
        return o

    def _windows_mktime(self, timetuple):
        """``mktime`` replacement handling pre-1970 dates on Windows."""
        if self._local_tz:
            fromtimestamp = datetime.fromtimestamp
            mktime = time.mktime
        else:
            fromtimestamp = datetime.utcfromtimestamp
            mktime = calendar.timegm
        if timetuple[0] > 1970:
            return mktime(timetuple)
        # compute the offset from the epoch manually for early dates
        dt = datetime(*timetuple[:6])
        epoch = fromtimestamp(0)
        return int((dt - epoch).total_seconds())

    def _windows_fromtimestamp(self, seconds):
        """``fromtimestamp`` replacement handling negative timestamps on Windows."""
        fromtimestamp = (
            datetime.fromtimestamp if self._local_tz else datetime.utcfromtimestamp
        )
        if seconds >= 0:
            return fromtimestamp(seconds)
        epoch = fromtimestamp(0)
        return epoch + timedelta(seconds=seconds)

    def __init__(self, local_tz=None, is_dst=False):
        self._local_tz = local_tz if local_tz is not None else options.local_timezone
        if self._local_tz is None:
            self._local_tz = True
        self._use_default_tz = type(self._local_tz) is bool

        self._allow_antique = options.allow_antique_date or _antique_mills is None
        self._is_dst = is_dst

        if self._local_tz:
            self._mktime = time.mktime
            self._fromtimestamp = datetime.fromtimestamp
        else:
            self._mktime = calendar.timegm
            self._fromtimestamp = datetime.utcfromtimestamp
        if _IS_WINDOWS:
            # special logic for negative timestamp under Windows
            self._mktime = self._windows_mktime
            self._fromtimestamp = self._windows_fromtimestamp

        self._tz = self._get_tz(self._local_tz) if not self._use_default_tz else None
        if hasattr(self._tz, "localize"):
            # pytz zones must be attached via localize() to get correct offsets
            self._localize = lambda dt: self._tz.localize(dt, is_dst=is_dst)
        else:
            self._localize = lambda dt: dt.replace(tzinfo=self._tz)

    def to_milliseconds(self, dt):
        """Convert ``dt`` to milliseconds since the epoch.

        :raises DatetimeOverflowError: for dates before 1928-01-01 unless
            antique dates are allowed.
        """
        from .errors import DatetimeOverflowError

        if not self._use_default_tz and dt.tzinfo is None:
            dt = self._localize(dt)

        if dt.tzinfo is not None:
            mills = int(
                (
                    calendar.timegm(dt.astimezone(compat.utc).timetuple())
                    + dt.microsecond / 1000000.0
                )
                * 1000
            )
        else:
            mills = int(
                (self._mktime(dt.timetuple()) + dt.microsecond / 1000000.0) * 1000
            )
        if not self._allow_antique and mills < _antique_mills:
            raise DatetimeOverflowError(_antique_errmsg)
        return mills

    def from_milliseconds(self, milliseconds):
        """Convert milliseconds since the epoch to a ``datetime``.

        Returns a naive datetime when using a default (bool) timezone,
        otherwise an aware datetime in the configured zone.

        :raises DatetimeOverflowError: for antique dates (unless allowed) or
            values below ``datetime.min``.
        """
        from .errors import DatetimeOverflowError

        if not self._allow_antique and milliseconds < _antique_mills:
            raise DatetimeOverflowError(_antique_errmsg)
        if milliseconds < _min_datetime_mills:
            raise DatetimeOverflowError(_min_datetime_errmsg)

        seconds = compat.long_type(math.floor(milliseconds / 1000))
        # Python's % is non-negative here, matching the floored seconds
        microseconds = compat.long_type(milliseconds) % 1000 * 1000
        if self._use_default_tz:
            return self._fromtimestamp(seconds).replace(microsecond=microseconds)
        else:
            return (
                datetime.utcfromtimestamp(seconds)
                .replace(microsecond=microseconds, tzinfo=compat.utc)
                .astimezone(self._tz)
            )


def to_milliseconds(dt, local_tz=None, is_dst=False, force_py=False):
    """Convert ``dt`` to epoch milliseconds, using the C converter if available."""
    cls = CMillisecondsConverter
    if force_py or cls is None:
        cls = MillisecondsConverter
    f = cls(local_tz, is_dst=is_dst)
    return f.to_milliseconds(dt)


def to_days(dt):
    """Return the number of days between 1970-01-01 and date ``dt``."""
    start_day = date(1970, 1, 1)
    return (dt - start_day).days


def to_date(delta_day):
    """Return the date ``delta_day`` days after 1970-01-01."""
    start_day = date(1970, 1, 1)
    return start_day + timedelta(delta_day)


def to_datetime(milliseconds, local_tz=None, force_py=False):
    """Convert epoch milliseconds to ``datetime``, using the C converter if available."""
    cls = CMillisecondsConverter
    if force_py or cls is None:
        cls = MillisecondsConverter
    f = cls(local_tz)
    return f.from_milliseconds(milliseconds)


def strptime_with_tz(dt, format="%Y-%m-%d %H:%M:%S"):
    """Parse ``dt`` with ``format``; on failure, treat the last space-separated
    token as a ``[+-]HHMM`` offset and return an aware datetime
    (``FixedOffset`` in minutes)."""
    try:
        return datetime.strptime(dt, format)
    except ValueError:
        naive_date_str, _, offset_str = dt.rpartition(" ")
        naive_dt = datetime.strptime(naive_date_str, format)
        offset = int(offset_str[-4:-2]) * 60 + int(offset_str[-2:])
        if offset_str[0] == "-":
            offset = -offset
        return naive_dt.replace(tzinfo=FixedOffset(offset))


if to_binary is None or to_text is None or to_str is None or to_lower_str is None:
    # pure-Python fallbacks used when the C extension cannot be imported

    def to_binary(text, encoding="utf-8"):
        """Convert ``text`` to bytes (None passes through)."""
        if text is None:
            return text
        if isinstance(text, six.text_type):
            return text.encode(encoding)
        elif isinstance(text, (six.binary_type, bytearray)):
            return bytes(text)
        else:
            return str(text).encode(encoding) if six.PY3 else str(text)

    def to_text(binary, encoding="utf-8"):
        """Convert ``binary`` to text (None passes through)."""
        if binary is None:
            return binary
        if isinstance(binary, (six.binary_type, bytearray)):
            return binary.decode(encoding)
        elif isinstance(binary, six.text_type):
            return binary
        else:
            return str(binary) if six.PY3 else str(binary).decode(encoding)

    def to_str(text, encoding="utf-8"):
        """Convert to the native ``str`` type (text on PY3, bytes on PY2)."""
        return (
            to_text(text, encoding=encoding)
            if six.PY3
            else to_binary(text, encoding=encoding)
        )

    def to_lower_str(s, encoding="utf-8"):
        """Convert to native ``str`` and lower-case it (None passes through)."""
        if s is None:
            return None
        return to_str(s, encoding).lower()


def get_zone_from_name(tzname):
    """Build a tz object from a zone name, preferring ``zoneinfo`` over ``pytz``."""
    return zoneinfo.ZoneInfo(tzname) if zoneinfo else pytz.timezone(tzname)


def get_zone_name(tz):
    """Return the zone name of a ``zoneinfo`` (``key``) or ``pytz`` (``zone``) tz."""
    return getattr(tz, "key", None) or getattr(tz, "zone", None)


# fix encoding conversion problem under windows
if sys.platform == "win32":

    def _replace_default_encoding(func):
        """Wrap ``func`` so a missing encoding defaults to
        ``options.display.encoding``."""

        def _fun(s, encoding=None):
            return func(s, encoding=encoding or options.display.encoding)

        _fun.__name__ = func.__name__
        _fun.__doc__ = func.__doc__
        return _fun

    to_binary = _replace_default_encoding(to_binary)
    to_text = _replace_default_encoding(to_text)
    to_str = _replace_default_encoding(to_str)


def is_lambda(f):
    """Return True if ``f`` is a function created by a ``lambda`` expression."""
    lam = lambda: 0
    return isinstance(f, type(lam)) and f.__name__ == lam.__name__


def str_to_kv(string, typ=None):
    """Parse ``"k1:v1,k2:v2"`` into a dict, optionally converting values with ``typ``."""
    d = dict()
    for pair in string.split(","):
        k, v = pair.split(":", 1)
        if typ:
            v = typ(v)
        d[k] = v
    return d


def interval_select(val, intervals, targets):
    """Pick the target for ``val`` given sorted interval bounds (``bisect_left``)."""
    return targets[bisect.bisect_left(intervals, val)]


def is_namedtuple(obj):
    """Return True if ``obj`` looks like a namedtuple instance."""
    return isinstance(obj, tuple) and hasattr(obj, "_fields")


def str_to_bool(s):
    """Parse ``"true"``/``"false"`` (case/space-insensitive).

    Bools and None pass through; other strings raise ``ValueError``.
    """
    if isinstance(s, bool) or s is None:
        return s
    s = s.lower().strip()
    if s == "true":
        return True
    elif s == "false":
        return False
    else:
        raise ValueError(s)


def bool_to_str(s):
    """Return ``str(s).lower()``, e.g. ``True`` -> ``"true"``."""
    return str(s).lower()


def get_root_dir():
    """Return the directory containing this module."""
    return os.path.dirname(sys.modules[__name__].__file__)


def load_text_file(path):
    """Read a UTF-8 file located at ``get_root_dir() + path``.

    ``path`` is concatenated directly, so it should start with a separator.
    Returns None when the file does not exist.
    """
    file_path = get_root_dir() + path
    if not os.path.exists(file_path):
        return None
    with codecs.open(file_path, encoding="utf-8") as f:
        return f.read()


def load_file_paths(pattern):
    """Glob ``pattern`` relative to the package directory."""
    file_path = os.path.normpath(get_root_dir() + pattern)
    return glob.glob(file_path)


def load_static_file_paths(path):
    """Glob ``path`` under the package ``static`` directory."""
    return load_file_paths("/static/" + path)


def load_text_files(pattern, func=None):
    """Read all UTF-8 files matching ``pattern`` (relative to the package dir).

    :param func: optional predicate on the file *name*; files for which it
        returns falsy are skipped.
    :return: dict mapping file name to its content.
    """
    glob_pattern = os.path.normpath(get_root_dir() + pattern)
    content_dict = dict()
    for file_path in glob.glob(glob_pattern):
        _, fn = os.path.split(file_path)
        if func and not func(fn):
            continue
        with codecs.open(file_path, encoding="utf-8") as f:
            content_dict[fn] = f.read()
    return content_dict


def load_static_text_file(path):
    """Read a text file from the package ``static`` directory."""
    return load_text_file("/static/" + path)


def load_internal_static_text_file(path):
    """Read a text file from the package ``internal/static`` directory."""
    return load_text_file("/internal/static/" + path)


def load_static_text_files(pattern, func=None):
    """Read text files matching ``pattern`` under the ``static`` directory."""
    return load_text_files("/static/" + pattern, func)


def init_progress_bar(val=1, use_console=True):
    """Create a ``ProgressBar``.

    Uses the widget-backed bar when traitlets is importable and widgets are
    available; otherwise a console bar when ``use_console`` is True, else None.
    """
    try:
        from traitlets import TraitError

        ipython = True
    except ImportError:
        try:
            from IPython.utils.traitlets import TraitError

            ipython = True
        except ImportError:
            ipython = False

    from .console import ProgressBar, is_widgets_available

    if not ipython:
        bar = ProgressBar(val) if use_console else None
    else:
        try:
            if is_widgets_available():
                bar = ProgressBar(val, True)
            else:
                bar = ProgressBar(val) if use_console else None
        except TraitError:
            bar = ProgressBar(val) if use_console else None

    return bar


def init_progress_ui(val=1, lock=False, use_console=True, mock=False):
    """Create a progress UI facade.

    A bar is only created in the main thread and when ``mock`` is False.
    With ``mock`` every UI method is a no-op returning None; with ``lock``
    UI methods are serialized through a ``threading.Lock``.
    """
    from .ui import ProgressGroupUI, html_notify

    progress_group = None
    bar = None
    if not mock and is_main_thread():
        bar = init_progress_bar(val=val, use_console=use_console)
        if bar and bar._ipython_widget:
            try:
                progress_group = ProgressGroupUI(bar._ipython_widget)
            except Exception:
                # best effort: the progress bar still works without a group UI
                logger.debug("Failed to create progress group UI", exc_info=True)

    _lock = threading.Lock() if lock else None

    def ui_method(func):
        """Apply the mock / lock policy to a UI method."""

        def inner(*args, **kwargs):
            if mock:
                return
            if _lock:
                with _lock:
                    return func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return inner

    class ProgressUI(object):
        @ui_method
        def update(self, value=None):
            if bar:
                bar.update(value=value)

        @ui_method
        def current_progress(self):
            if bar and hasattr(bar, "_current_value"):
                return bar._current_value

        @ui_method
        def inc(self, value):
            if bar and hasattr(bar, "_current_value"):
                current_val = bar._current_value
                bar.update(current_val + value)

        @ui_method
        def status(self, prefix, suffix="", clear_keys=False):
            if progress_group:
                if clear_keys:
                    progress_group.clear_keys()
                progress_group.prefix = prefix
                progress_group.suffix = suffix

        @ui_method
        def add_keys(self, keys):
            if progress_group:
                progress_group.add_keys(keys)

        @ui_method
        def remove_keys(self, keys):
            if progress_group:
                progress_group.remove_keys(keys)

        @ui_method
        def update_group(self):
            if progress_group:
                progress_group.update()

        @ui_method
        def notify(self, msg):
            html_notify(msg)

        @ui_method
        def close(self):
            if bar:
                bar.close()
            if progress_group:
                progress_group.close()

    return ProgressUI()


def escape_odps_string(src):
    """Backslash-escape quotes, backslashes and control chars for ODPS SQL literals."""
    trans_dict = {
        "\b": r"\b",
        "\t": r"\t",
        "\n": r"\n",
        "\r": r"\r",
        "'": r"\'",
        '"': r"\"",
        "\\": r"\\",
        "\0": r"\0",
    }
    return "".join(trans_dict[ch] if ch in trans_dict else ch for ch in src)


def to_odps_scalar(s):
    """Render a Python value as an ODPS SQL scalar literal.

    None/NaN -> ``NULL``; strings -> quoted escaped literal; datetimes (and
    pandas Timestamps) -> ``CAST('...' AS DATETIME|TIMESTAMP)`` depending on
    sub-second precision; anything else -> ``str(s)``.
    """
    try:
        from pandas import Timestamp as pd_Timestamp
    except ImportError:
        pd_Timestamp = type("DummyType", (object,), {})

    if s is None or (isinstance(s, float) and math.isnan(s)):
        return "NULL"
    if isinstance(s, six.string_types):
        return "'%s'" % escape_odps_string(s)
    elif isinstance(s, (datetime, pd_Timestamp)):
        microsec = s.microsecond
        nanosec = getattr(s, "nanosecond", 0)
        if microsec or nanosec:
            s = s.strftime("%Y-%m-%d %H:%M:%S.%f") + ("%03d" % nanosec)
            out_type = "TIMESTAMP"
        else:
            s = s.strftime("%Y-%m-%d %H:%M:%S")
            out_type = "DATETIME"
        return "CAST('%s' AS %s)" % (escape_odps_string(s), out_type)
    return str(s)


def replace_sql_parameters(sql, ns):
    """Substitute ``:name`` placeholders in ``sql`` with literals from ``ns``.

    Numbers are inserted via ``repr``, tuples/sets/lists as parenthesized
    lists, anything else as an escaped string literal. Placeholders whose
    value is missing or None are left untouched.
    """
    param_re = re.compile(r":([a-zA-Z_][a-zA-Z0-9_]*)")

    def is_numeric(val):
        return isinstance(val, (six.integer_types, float))

    def is_sequence(val):
        return isinstance(val, (tuple, set, list))

    def format_string(val):
        return "'{0}'".format(escape_odps_string(str(val)))

    def format_numeric(val):
        return repr(val)

    def format_sequence(val):
        escaped = [
            format_numeric(v) if is_numeric(v) else format_string(v) for v in val
        ]
        return "({0})".format(", ".join(escaped))

    def replace(matched):
        name = matched.group(1)
        val = ns.get(name)
        if val is None:
            return matched.group(0)
        elif is_numeric(val):
            return format_numeric(val)
        elif is_sequence(val):
            return format_sequence(val)
        else:
            return format_string(val)

    return param_re.sub(replace, sql)


def is_main_process():
    """Return True if the current multiprocessing process name contains "main"."""
    return "main" in multiprocessing.current_process().name.lower()


# call counters collected by ``survey`` / ``add_survey_call``
survey_calls = dict()


def survey(func):
    """Decorator counting calls of ``func`` in ``survey_calls``.

    The key is ``module.Class.method`` when ``func`` takes ``self``,
    otherwise ``module.function``.
    """
    # the argspec of ``func`` never changes; compute it once
    arg_spec = getargspec(func)
    has_self = "self" in arg_spec.args

    @six.wraps(func)
    def wrapped(*args, **kwargs):
        func_cls = args[0].__class__ if has_self else None

        if func_cls:
            func_sig = ".".join([func_cls.__module__, func_cls.__name__, func.__name__])
        else:
            func_sig = ".".join([func.__module__, func.__name__])

        add_survey_call(func_sig)

        return func(*args, **kwargs)

    return wrapped


def add_survey_call(group):
    """Increment the counter for ``group`` unless it matches a skip regex."""
    if any(r.search(group) is not None for r in options.skipped_survey_regexes):
        return
    survey_calls[group] = survey_calls.get(group, 0) + 1


def get_survey_calls():
    """Return a shallow copy of the collected call counters."""
    return copy.copy(survey_calls)


def clear_survey_calls():
    """Reset all collected call counters."""
    survey_calls.clear()


def require_package(pack_name):
    """Decorator returning ``func`` if ``pack_name`` is importable, else None."""

    def _decorator(func):
        try:
            __import__(pack_name, fromlist=[""])
            return func
        except ImportError:
            return None

    return _decorator


def gen_repr_object(**kwargs):
    """Build an object exposing rich-repr hooks.

    ``text`` becomes ``__repr__``; every other keyword ``k`` becomes an
    attribute plus a ``_repr_<k>_`` method returning it. When ``gv`` is
    given and graphviz is importable, ``_repr_svg_`` renders it.
    """
    obj = type("ReprObject", (), {})
    text = kwargs.pop("text", None)
    if six.PY2 and isinstance(text, unicode):
        text = text.encode("utf-8")
    if text:
        setattr(obj, "text", text)
        setattr(obj, "__repr__", lambda self: text)
    for k, v in six.iteritems(kwargs):
        setattr(obj, k, v)
        # bind ``v`` now: a plain closure would see only the last loop value
        setattr(obj, "_repr_{0}_".format(k), lambda self, _v=v: _v)
    if "gv" in kwargs:
        try:
            from graphviz import Source

            setattr(
                obj,
                "_repr_svg_",
                lambda self: Source(self._repr_gv_(), encoding="utf-8")._repr_svg_(),
            )
        except ImportError:
            pass
    return obj()


def build_pyodps_dir(*args):
    """Return a path under the PyODPS home directory.

    The home is ``$PYODPS_DIR`` if set, else ``~/.pyodps`` (on Windows with
    ``APPDATA``: ``%APPDATA%/pyodps``, migrating an existing ``~/.pyodps``).
    """
    default_dir = os.path.join(os.path.expanduser("~"), ".pyodps")
    if sys.platform == "win32" and "APPDATA" in os.environ:
        win_default_dir = os.path.join(os.environ["APPDATA"], "pyodps")
        if os.path.exists(default_dir):
            shutil.move(default_dir, win_default_dir)
        default_dir = win_default_dir
    home_dir = os.environ.get("PYODPS_DIR") or default_dir
    return os.path.join(home_dir, *args)


def object_getattr(obj, attr, default=None):
    """Look up ``attr`` via ``object.__getattribute__`` (bypassing overrides)."""
    try:
        return object.__getattribute__(obj, attr)
    except AttributeError:
        return default


def attach_internal(cls):
    """Copy public attributes of the internal mixin registered for ``cls``.

    Does nothing when the ``internal`` package is not importable.
    """
    cls_path = cls.__module__ + "." + cls.__name__
    try:
        from .internal.core import MIXIN_TARGETS

        mixin_cls = MIXIN_TARGETS[cls_path]
        for method_name in dir(mixin_cls):
            if method_name.startswith("_"):
                continue
            att = getattr(mixin_cls, method_name)
            if six.PY2 and type(att).__name__ in ("instancemethod", "method"):
                att = att.__func__
            setattr(cls, method_name, att)
        return cls
    except ImportError:
        return cls


def is_main_thread():
    """Return True when called from the main thread."""
    if hasattr(threading, "main_thread"):
        return threading.current_thread() is threading.main_thread()
    return threading.current_thread().__class__.__name__ == "_MainThread"


def write_log(msg):
    """Legacy method to keep compatibility"""
    logger.info(msg)


def split_quoted(s, delimiter=",", maxsplit=0):
    """Split ``s`` on ``delimiter`` outside single/double quotes.

    NOTE: ``delimiter`` is inserted into a regex character class unescaped.
    """
    pattern = r"""((?:[^""" + delimiter + r""""']|"[^"]*"|'[^']*')+)"""
    return re.split(pattern, s, maxsplit=maxsplit)[1::2]


def gen_temp_table():
    """Return a unique temporary table name prefixed with ``TEMP_TABLE_PREFIX``."""
    return "%s%s" % (TEMP_TABLE_PREFIX, str(uuid.uuid4()).replace("-", "_"))


def hashable(obj):
    """Recursively convert non-hashable iterables to tuples.

    NOTE(review): mappings keep their original type (e.g. a dict stays a
    dict, which is itself unhashable) -- confirm this is intended.
    """
    if isinstance(obj, Hashable):
        items = obj
    elif isinstance(obj, Mapping):
        items = type(obj)((k, hashable(v)) for k, v in six.iteritems(obj))
    elif isinstance(obj, Iterable):
        items = tuple(hashable(item) for item in obj)
    else:
        raise TypeError(type(obj))

    return items


def thread_local_attribute(thread_local_name, default_value=None):
    """Create a property whose value is stored per thread.

    Values live on a ``threading.local`` stored in the instance attribute
    ``thread_local_name`` (created lazily). A callable ``default_value`` is
    called to initialise the value on first read in each thread.
    """
    attr_name = "_local_attr_%d" % random.randint(0, 99999999)

    def _get_thread_local(self):
        thread_local = getattr(self, thread_local_name, None)
        if thread_local is None:
            setattr(self, thread_local_name, threading.local())
            thread_local = getattr(self, thread_local_name)
        return thread_local

    def _getter(self):
        thread_local = _get_thread_local(self)
        if not hasattr(thread_local, attr_name) and callable(default_value):
            setattr(thread_local, attr_name, default_value())
        return getattr(thread_local, attr_name)

    def _setter(self, value):
        thread_local = _get_thread_local(self)
        setattr(thread_local, attr_name, value)

    return property(fget=_getter, fset=_setter)


def call_with_retry(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, retrying on failure.

    Control keyword arguments (removed before calling ``func``):

    * ``retry_times``: max retries (default ``options.retry_times``; None = no limit)
    * ``retry_timeout``: give up when the next attempt would start after this
      many seconds
    * ``delay``: seconds to sleep between attempts (default ``options.retry_delay``)
    * ``reset_func``: called before each retry
    * ``exc_type``: exception type(s) that trigger a retry (default ``BaseException``)
    * ``allow_interrupt``: re-raise ``KeyboardInterrupt`` immediately (default True)
    * ``no_raise``: return ``sys.exc_info()`` instead of raising when retries
      are exhausted
    """
    retry_num = 0
    retry_times = kwargs.pop("retry_times", options.retry_times)
    retry_timeout = kwargs.pop("retry_timeout", None)
    delay = kwargs.pop("delay", options.retry_delay)
    reset_func = kwargs.pop("reset_func", None)
    exc_type = kwargs.pop("exc_type", BaseException)
    allow_interrupt = kwargs.pop("allow_interrupt", True)
    no_raise = kwargs.pop("no_raise", False)

    start_time = monotonic() if retry_timeout is not None else None
    while True:
        try:
            return func(*args, **kwargs)
        except exc_type as ex:
            retry_num += 1
            # check interrupt / exhaustion *before* sleeping so that Ctrl-C
            # and the final failure are not delayed by ``delay`` seconds
            if allow_interrupt and isinstance(ex, KeyboardInterrupt):
                raise
            if (retry_times is not None and retry_num > retry_times) or (
                retry_timeout is not None
                and start_time is not None
                and monotonic() - start_time + delay > retry_timeout
            ):
                if no_raise:
                    return sys.exc_info()
                raise
            time.sleep(delay)
            if callable(reset_func):
                reset_func()


def get_id(n):
    """Return ``n._node_id`` if present, else ``id(n)``."""
    if hasattr(n, "_node_id"):
        return n._node_id
    return id(n)


def strip_if_str(s):
    """Strip whitespace from strings/bytes (bytes converted to ``str``); other values pass through."""
    if isinstance(s, six.binary_type):
        s = to_str(s)
    if isinstance(s, six.string_types):
        return s.strip()
    return s


def with_wait_argument(func):
    """Decorator normalizing ``wait=``/``async=`` keywords into ``async_``.

    Also warns when ``async_`` is passed positionally, and warns about and
    drops keyword arguments ``func`` does not accept (unless it has
    ``**kwargs``).
    """
    func_spec = (
        compat.getfullargspec(func) if compat.getfullargspec else compat.getargspec(func)
    )
    args_set = set(func_spec.args)
    if hasattr(func_spec, "kwonlyargs"):
        args_set |= set(func_spec.kwonlyargs or [])
    has_varkw = (
        getattr(func_spec, "varkw", None) or getattr(func_spec, "keywords", None)
    ) is not None
    try:
        async_index = func_spec.args.index("async_")
    except ValueError:
        async_index = None

    @six.wraps(func)
    def wrapped(*args, **kwargs):
        if async_index is not None and len(args) >= async_index + 1:
            warnings.warn(
                "Please use async_ as a keyword argument, like obj.func(async_=True)",
                DeprecationWarning,
                stacklevel=2,
            )
            add_survey_call(".".join([func.__module__, func.__name__, "async_"]))
        elif "wait" in kwargs:
            kwargs["async_"] = not kwargs.pop("wait")
        elif "async" in kwargs:
            kwargs["async_"] = kwargs.pop("async")

        if not has_varkw and kwargs:
            no_args_match = [key for key in kwargs if key not in args_set]
            if no_args_match:
                warnings.warn(
                    "Arguments %s not supported, ignored by default. "
                    "Please check argument spellings." % (", ".join(no_args_match)),
                    stacklevel=2,
                )
                for arg in no_args_match:
                    kwargs.pop(arg, None)
        return func(*args, **kwargs)

    return wrapped


def split_sql_by_semicolon(sql_statement):
    """Split a SQL script into statements on top-level semicolons.

    Semicolons inside quotes, brackets and comments are ignored. Comments
    are removed from the returned statements, trailing whitespace is
    stripped per line, and each statement keeps its terminating ``;``.
    """
    sql_statement = sql_statement.replace("\r\n", "\n").replace("\r", "\n")
    left_brackets = {"}": "{", "]": "[", ")": "("}

    def cut_statement(stmt_start, stmt_end=None):
        """Extract ``sql_statement[stmt_start:stmt_end]`` with comments removed."""
        stmt_end = stmt_end or len(sql_statement)
        parts = []
        left = stmt_start
        for comm_start, comm_end in comment_blocks:
            if comm_end <= stmt_start:
                continue
            if comm_start > stmt_end:
                break
            parts.append(sql_statement[left:comm_start])
            left = comm_end
        parts.append(sql_statement[left:stmt_end])
        combined_lines = "".join(parts).splitlines()
        return "\n".join(line.rstrip() for line in combined_lines).strip()

    start, pos = 0, 0
    statements = []
    comment_sign = None
    comment_pos = None
    comment_blocks = []
    quote_sign = None
    bracket_stack = []
    while pos < len(sql_statement):
        ch = sql_statement[pos]
        dch = sql_statement[pos : pos + 2] if pos + 1 < len(sql_statement) else None
        if quote_sign is None and comment_sign is None:
            if ch in ("{", "[", "("):
                # start of brackets
                bracket_stack.append(ch)
                pos += 1
            elif ch in ("}", "]", ")"):
                # end of brackets
                assert bracket_stack[-1] == left_brackets[ch]
                bracket_stack.pop()
                pos += 1
            elif ch in ('"', "'", "`"):
                # start of quote
                quote_sign = ch
                pos += 1
            elif dch in ("--", "/*"):
                # start of line or block comments
                comment_sign = dch
                comment_pos = pos
                pos += 2
            elif ch == ";" and not bracket_stack:
                # semicolon without brackets, quotes and comments
                part_statement = cut_statement(start, pos + 1)
                if part_statement and part_statement != ";":
                    statements.append(part_statement)
                pos += 1
                start = pos
            else:
                pos += 1
        elif quote_sign is not None and ch == quote_sign:
            quote_sign = None
            pos += 1
        elif quote_sign is not None and ch == "\\":
            # skip escape char
            pos += 2
        elif comment_sign == "--" and ch == "\n":
            # line comment ends
            comment_sign = None
            comment_blocks.append((comment_pos, pos))
            pos += 1
        elif comment_sign == "/*" and dch == "*/":
            # block comment ends
            comment_sign = None
            comment_blocks.append((comment_pos, pos + 2))
            pos += 2
        else:
            pos += 1

    if comment_sign is not None:
        # a comment still open at end of input (e.g. a trailing ``-- ...``
        # without newline) extends to the end; record it so it is removed
        comment_blocks.append((comment_pos, len(sql_statement)))

    part_statement = cut_statement(start)
    if part_statement and part_statement != ";":
        statements.append(part_statement)
    return statements


def strip_backquotes(name_str):
    """Strip surrounding backquotes and unescape doubled backquotes."""
    name_str = name_str.strip()
    if not name_str.startswith("`") or not name_str.endswith("`"):
        return name_str
    return name_str[1:-1].replace("``", "`")


def split_backquoted(s, sep=" ", maxsplit=-1):
    """Split ``s`` on ``sep`` outside backquoted sections.

    :param maxsplit: maximum number of splits; negative means unlimited.
    :raises ValueError: if ``sep`` is empty or contains a backquote.
    """
    if not sep:
        # an empty separator would never advance ``pos`` (infinite loop)
        raise ValueError("empty separator")
    if "`" in sep:
        raise ValueError("Separator cannot contain backquote")
    results = []
    pos = 0
    cur_token = ""
    quoted = False
    while pos < len(s):
        ch = s[pos]
        if ch == "`":
            quoted = not quoted
            cur_token += ch
            pos += 1
        elif (
            not quoted
            and (maxsplit < 0 or len(results) < maxsplit)
            and s[pos : pos + len(sep)] == sep
        ):
            results.append(cur_token)
            cur_token = ""
            pos += len(sep)
        else:
            cur_token += ch
            pos += 1
    results.append(cur_token)
    return results


_convert_host_hash = (
    "6fe9b6c02efc24159c09863c1aadffb5",
    "9b75728355160c10b5eb75e4bf105a76",
)
_default_host_hash = "c7f4116fbf820f99284dcc89f340b372"


def get_default_logview_endpoint(default_endpoint, odps_endpoint):
    """Possibly rewrite ``default_endpoint`` based on the ODPS endpoint host.

    The rewrite only happens when the MD5 hashes of both hosts match the
    module-level constants; any failure falls back to ``default_endpoint``.
    """
    try:
        if odps_endpoint is None:
            return default_endpoint
        parsed_host = compat.urlparse(odps_endpoint).hostname
        if parsed_host is None:
            return default_endpoint
        default_host = compat.urlparse(default_endpoint).hostname
        hashed_host = md5_hexdigest(to_binary(parsed_host))
        hashed_default_host = md5_hexdigest(to_binary(default_host))
        if (
            hashed_default_host == _default_host_hash
            and hashed_host in _convert_host_hash
        ):
            suffix = r"\1" + string.ascii_letters[1::-1] * 2 + parsed_host[-8:]
            return re.sub(r"([a-z])[a-z]{3}\.[a-z]{3}$", suffix, default_endpoint)
        return default_endpoint
    except Exception:
        # best effort: never fail because of endpoint rewriting
        logger.debug("Failed to compute logview endpoint", exc_info=True)
        return default_endpoint


def is_job_insight_available(odps_endpoint):
    """Return True if the endpoint host is an ``aliyun``/``aliyun-inc`` domain
    not listed in ``_convert_host_hash``; False on any failure."""
    try:
        if odps_endpoint is None:
            return False
        parsed_host = compat.urlparse(odps_endpoint).hostname
        if parsed_host is None:
            return False
        hashed_host = md5_hexdigest(to_binary(parsed_host))
        return hashed_host not in _convert_host_hash and (
            parsed_host.endswith(".aliyun-inc.com")
            or parsed_host.endswith(".aliyun.com")
        )
    except Exception:
        return False


def _import_version_function():
    """Return a ``version(dist_name)`` callable, or None if none is available."""
    try:
        try:
            from importlib.metadata import version
        except ImportError:
            from importlib_metadata import version
        return version
    except ImportError:
        try:
            import pkg_resources
        except ImportError:
            return None
        return lambda x: pkg_resources.get_distribution(x).version


def get_package_version(pack_name, import_name=None):
    """Return the installed version of distribution ``pack_name``.

    Falls back to ``__version__`` of module ``import_name`` (``True`` means
    same as ``pack_name``) when metadata lookup fails; returns None when no
    fallback module is given.
    """
    version_func = _import_version_function()
    if version_func is not None:
        try:
            return version_func(pack_name)
        except Exception:
            # metadata not found; fall back to the module attribute below
            pass

    import_name = pack_name if import_name is True else import_name
    if import_name:
        mod = __import__(import_name)
        return mod.__version__
    return None


def show_versions():  # pragma: no cover
    """Print platform, locale and dependency version information."""
    import locale
    import platform

    uname_result = platform.uname()
    language_code, encoding = locale.getlocale()

    results = {
        "python": ".".join([str(i) for i in sys.version_info]),
        "python-bits": struct.calcsize("P") * 8,
        "OS": uname_result.system,
        "OS-release": uname_result.release,
        "Version": uname_result.version,
        "machine": uname_result.machine,
        "processor": uname_result.processor,
        "byteorder": sys.byteorder,
        "LC_ALL": os.environ.get("LC_ALL"),
        "LANG": os.environ.get("LANG"),
        "LOCALE": {"language-code": language_code, "encoding": encoding},
    }

    try:
        from .src import crc32c_c  # noqa: F401

        results["USE_CLIB"] = True
    except ImportError:
        results["USE_CLIB"] = False

    try:
        from . import internal  # noqa: F401

        results["HAS_INTERNAL"] = True
    except ImportError:
        results["HAS_INTERNAL"] = False

    # distribution name -> importable module name
    packages = {
        "pyodps": "odps",
        "requests": "requests",
        "urllib3": "urllib3",
        "charset_normalizer": "charset_normalizer",
        "chardet": "chardet",
        "idna": "idna",
        "certifi": "certifi",
        "numpy": "numpy",
        "scipy": "scipy",
        "pandas": "pandas",
        "matplotlib": "matplotlib",
        "sqlalchemy": "sqlalchemy",
        "pytz": "pytz",
        "python-dateutil": "dateutil",
        "IPython": "IPython",
        "maxframe": "maxframe",
    }
    for pack_name, pack_imp in packages.items():
        try:
            results[pack_name] = get_package_version(pack_name, pack_imp)
        except (ImportError, AttributeError):
            pass

    key_size = 1 + max(len(key) for key in results.keys())
    for key, val in results.items():
        print(key + " " * (key_size - len(key)) + ": " + str(val))


def get_supported_python_tag(align=None):
    """Return a CPython wheel tag for the running interpreter.

    With ``align`` (default ``options.align_supported_python_tag``) the tag
    is collapsed to ``cp311`` (>=3.11), ``cp37`` (>=3.6) or ``cp27``;
    otherwise the exact ``cpXY`` tag is returned.
    """
    if align is None:
        align = options.align_supported_python_tag
    if align:
        if sys.version_info[:2] >= (3, 11):
            return "cp311"
        elif sys.version_info[:2] >= (3, 6):
            return "cp37"
        else:
            return "cp27"
    else:
        return "cp" + str(sys.version_info[0]) + str(sys.version_info[1])


def chain_generator(*iterable):
    """Yield items from each iterable in turn (like ``itertools.chain``)."""
    for itr in iterable:
        for item in itr:
            yield item