core/maxframe/utils.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio.events
import concurrent.futures
import contextvars
import dataclasses
import datetime
import enum
import functools
import importlib
import inspect
import io
import itertools
import numbers
import os
import pkgutil
import random
import struct
import sys
import threading
import time
import tokenize as pytokenize
import types
import weakref
import zlib
from collections.abc import Hashable, Mapping
from contextlib import contextmanager
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import msgpack
import numpy as np
import pandas as pd
import traitlets
from tornado import httpclient, web
from tornado.simple_httpclient import HTTPTimeoutError
from ._utils import ( # noqa: F401 # pylint: disable=unused-import
NamedType,
Timer,
TypeDispatcher,
ceildiv,
get_user_call_point,
new_random_id,
register_tokenizer,
reset_id_random_seed,
to_binary,
to_str,
to_text,
tokenize,
tokenize_int,
)
from .lib.version import parse as parse_version
from .typing_ import TileableType, TimeoutType
# make flake8 happy by referencing these imports
NamedType = NamedType
TypeDispatcher = TypeDispatcher
tokenize = tokenize
register_tokenizer = register_tokenizer
ceildiv = ceildiv
reset_id_random_seed = reset_id_random_seed
new_random_id = new_random_id
get_user_call_point = get_user_call_point
_is_ci = (os.environ.get("CI") or "0").lower() in ("1", "true")
pd_release_version: Tuple[int] = parse_version(pd.__version__).release
try:
from pandas._libs import lib as _pd__libs_lib
from pandas._libs.lib import NoDefault, no_default
_raw__reduce__ = type(NoDefault).__reduce__
def _no_default__reduce__(self):
if self is not NoDefault:
return _raw__reduce__(self)
else: # pragma: no cover
return getattr, (_pd__libs_lib, "NoDefault")
if hasattr(_pd__libs_lib, "_NoDefault"): # pragma: no cover
# need to patch __reduce__ to make sure it can be properly unpickled
type(NoDefault).__reduce__ = _no_default__reduce__
else:
        # _NoDefault only exists since pandas 1.5.0; register an alias on older versions for pickle compatibility
_pd__libs_lib._NoDefault = NoDefault
except ImportError: # pragma: no cover
class NoDefault(enum.Enum):
no_default = "NO_DEFAULT"
def __repr__(self) -> str:
return "<no_default>"
no_default = NoDefault.no_default
try:
# register for pickle compatibility
from pandas._libs import lib as _pd__libs_lib
_pd__libs_lib.NoDefault = NoDefault
except (ImportError, AttributeError):
pass
try:
import pyarrow as pa
except ImportError:
pa = None
try:
from pandas import ArrowDtype as PandasArrowDtype # noqa: F401
ARROW_DTYPE_NOT_SUPPORTED = False
except ImportError:
ARROW_DTYPE_NOT_SUPPORTED = True
class classproperty:
def __init__(self, f):
self.f = f
def __get__(self, obj, owner):
return self.f(owner)
def implements(f: Callable):
def decorator(g):
g.__doc__ = f.__doc__
return g
return decorator
class AttributeDict(dict):
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(f"'AttributeDict' object has no attribute {item}")
def on_serialize_shape(shape: Tuple[int]):
if shape:
return tuple(s if not np.isnan(s) else -1 for s in shape)
return shape
def on_deserialize_shape(shape: Tuple[int]):
if shape:
return tuple(s if s != -1 else np.nan for s in shape)
return shape
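# Illustrative round trip (a sketch of the two helpers above): unknown dimensions
# (NaN) are mapped to -1 before serialization and restored afterwards.
#     >>> on_serialize_shape((10, np.nan))
#     (10, -1)
#     >>> on_deserialize_shape((10, -1))
#     (10, nan)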
def on_serialize_numpy_type(value: np.dtype):
if value is pd.NaT:
value = None
return value.item() if isinstance(value, np.generic) else value
def on_serialize_nsplits(value: Tuple[Tuple[int]]):
if value is None:
return None
new_nsplits = []
for dim_splits in value:
new_nsplits.append(tuple(None if pd.isna(v) else v for v in dim_splits))
return tuple(new_nsplits)
def has_unknown_shape(*tiled_tileables: TileableType) -> bool:
for tileable in tiled_tileables:
if getattr(tileable, "shape", None) is None:
continue
if any(pd.isnull(s) for s in tileable.shape):
return True
if any(pd.isnull(s) for s in itertools.chain(*tileable.nsplits)):
return True
return False
def calc_nsplits(chunk_idx_to_shape: Dict[Tuple[int], Tuple[int]]) -> Tuple[Tuple[int]]:
"""
Calculate a tiled entity's nsplits.
Parameters
----------
chunk_idx_to_shape : Dict type, {chunk_idx: chunk_shape}
Returns
-------
nsplits
"""
ndim = len(next(iter(chunk_idx_to_shape)))
tileable_nsplits = []
# for each dimension, record chunk shape whose index is zero on other dimensions
for i in range(ndim):
splits = []
for index, shape in chunk_idx_to_shape.items():
if all(idx == 0 for j, idx in enumerate(index) if j != i):
splits.append(shape[i])
tileable_nsplits.append(tuple(splits))
return tuple(tileable_nsplits)
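# Illustrative usage (a sketch; the 2x2 chunk grid below is hypothetical): for each
# dimension, only chunks whose other indices are zero contribute a split size.
#     >>> calc_nsplits({(0, 0): (3, 4), (0, 1): (3, 2), (1, 0): (2, 4), (1, 1): (2, 2)})
#     ((3, 2), (4, 2))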
def copy_tileables(tileables: List[TileableType], **kwargs):
inputs = kwargs.pop("inputs", None)
copy_key = kwargs.pop("copy_key", True)
copy_id = kwargs.pop("copy_id", True)
if kwargs:
raise TypeError(f"got un unexpected keyword argument '{next(iter(kwargs))}'")
if len(tileables) > 1:
# cannot handle tileables with different operators here
# try to copy separately if so
if len({t.op for t in tileables}) != 1:
raise TypeError("All tileables' operators should be same.")
op = tileables[0].op.copy().reset_key()
if copy_key:
op._key = tileables[0].op.key
kws = []
for t in tileables:
params = t.params.copy()
if copy_key:
params["_key"] = t.key
if copy_id:
params["_id"] = t.id
params.update(t.extra_params)
kws.append(params)
inputs = inputs or op.inputs
return op.new_tileables(inputs, kws=kws, output_limit=len(kws))
def get_dtype(dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]):
if pd.api.types.is_extension_array_dtype(dtype):
return dtype
elif dtype is pd.Timestamp or dtype is datetime.datetime:
return np.dtype("datetime64[ns]")
elif dtype is pd.Timedelta or dtype is datetime.timedelta:
return np.dtype("timedelta64[ns]")
else:
return np.dtype(dtype)
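# Illustrative usage (a sketch of the branches above):
#     >>> get_dtype(pd.Timestamp)
#     dtype('<M8[ns]')
#     >>> get_dtype("float32")
#     dtype('float32')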
def serialize_serializable(serializable, compress: bool = False):
from .serialization import serialize
bio = io.BytesIO()
header, buffers = serialize(serializable)
buf_sizes = [getattr(buf, "nbytes", len(buf)) for buf in buffers]
header[0]["buf_sizes"] = buf_sizes
s_header = msgpack.dumps(header)
bio.write(struct.pack("<Q", len(s_header)))
bio.write(s_header)
for buf in buffers:
bio.write(buf)
ser_graph = bio.getvalue()
if compress:
ser_graph = zlib.compress(ser_graph)
return ser_graph
def deserialize_serializable(ser_serializable: bytes):
from .serialization import deserialize
bio = io.BytesIO(ser_serializable)
    s_header_length = struct.unpack("<Q", bio.read(8))[0]
header2 = msgpack.loads(bio.read(s_header_length))
buffers2 = [bio.read(s) for s in header2[0]["buf_sizes"]]
return deserialize(header2, buffers2)
def skip_na_call(func: Callable):
@functools.wraps(func)
def new_func(x):
return func(x) if x is not None else None
return new_func
def url_path_join(*pieces):
"""Join components of url into a relative url
Use to prevent double slash when joining subpath. This will leave the
initial and final / in place
"""
initial = pieces[0].startswith("/")
final = pieces[-1].endswith("/")
stripped = [s.strip("/") for s in pieces]
result = "/".join(s for s in stripped if s)
if initial:
result = "/" + result
if final:
result = result + "/"
if result == "//":
result = "/"
return result
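# Illustrative usage (a sketch; the path pieces are hypothetical):
#     >>> url_path_join("/api/", "session", "abc/")
#     '/api/session/abc/'
#     >>> url_path_join("/", "/")
#     '/'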
def random_ports(port: int, n: int):
"""Generate a list of n random ports near the given port.
The first 5 ports will be sequential, and the remaining n-5 will be
randomly selected in the range [port-2*n, port+2*n].
"""
for i in range(min(5, n)):
yield port + i
for i in range(n - 5):
yield max(1, port + random.randint(-2 * n, 2 * n))
def build_temp_table_name(session_id: str, tileable_key: str) -> str:
return f"tmp_mf_{session_id}_{tileable_key}"
def build_temp_intermediate_table_name(session_id: str, tileable_key: str) -> str:
temp_table = build_temp_table_name(session_id, tileable_key)
return f"{temp_table}_intermediate"
def build_session_volume_name(session_id: str) -> str:
return f"mf_vol_{session_id.replace('-', '_')}"
async def wait_http_response(
url: str, *, request_timeout: TimeoutType = None, **kwargs
) -> httpclient.HTTPResponse:
start_time = time.time()
while request_timeout is None or time.time() - start_time < request_timeout:
        timeout_left = (
            min(10.0, request_timeout - (time.time() - start_time))
            if request_timeout
            else None
        )
try:
return await httpclient.AsyncHTTPClient().fetch(
url, request_timeout=timeout_left, **kwargs
)
except HTTPTimeoutError:
pass
raise TimeoutError
def get_handler_timeout_value(handler: web.RequestHandler) -> TimeoutType:
wait = bool(int(handler.get_argument("wait", "0")))
timeout = float(handler.get_argument("timeout", "0"))
if wait and abs(timeout) < 1e-6:
timeout = None
elif not wait:
timeout = 0
return timeout
def format_timeout_params(timeout: TimeoutType) -> str:
if timeout is None:
return "?wait=1"
elif abs(timeout) < 1e-6:
return "?wait=0"
else:
return f"?wait=1&timeout={timeout}"
_PrimitiveType = TypeVar("_PrimitiveType")
def create_sync_primitive(
cls: Type[_PrimitiveType], loop: asyncio.AbstractEventLoop
) -> _PrimitiveType:
"""
Create an asyncio sync primitive (locks, events, etc.)
in a certain event loop.
"""
if sys.version_info[1] < 10:
return cls(loop=loop)
    # The loop parameter was removed in Python 3.10, so we temporarily set the
    # target loop as current while constructing the primitive instead.
old_loop = asyncio.get_event_loop()
try:
asyncio.set_event_loop(loop)
primitive = cls()
finally:
asyncio.set_event_loop(old_loop)
return primitive
class ToThreadCancelledError(asyncio.CancelledError):
def __init__(self, *args, result=None):
super().__init__(*args)
self._result = result
@property
def result(self):
return self._result
_ToThreadRetType = TypeVar("_ToThreadRetType")
class ToThreadMixin:
_thread_pool_size = 1
_counter = itertools.count().__next__
def __del__(self):
if hasattr(self, "_pool"):
kw = {"wait": False}
if sys.version_info[:2] >= (3, 9):
kw["cancel_futures"] = True
self._pool.shutdown(**kw)
async def to_thread(
self,
func: Callable[..., _ToThreadRetType],
*args,
wait_on_cancel: bool = False,
timeout: float = None,
**kwargs,
) -> _ToThreadRetType:
if not hasattr(self, "_pool"):
self._pool = concurrent.futures.ThreadPoolExecutor(
self._thread_pool_size,
thread_name_prefix=f"{type(self).__name__}Pool-{self._counter()}",
)
loop = asyncio.events.get_running_loop()
ctx = contextvars.copy_context()
func_call = functools.partial(ctx.run, func, *args, **kwargs)
fut = loop.run_in_executor(self._pool, func_call)
try:
coro = fut
if wait_on_cancel:
coro = asyncio.shield(coro)
if timeout is not None:
coro = asyncio.wait_for(coro, timeout)
return await coro
except (asyncio.CancelledError, asyncio.TimeoutError) as ex:
if not wait_on_cancel:
raise
result = await fut
raise ToThreadCancelledError(*ex.args, result=result)
def ensure_async_call(
self,
func: Callable[..., _ToThreadRetType],
*args,
wait_on_cancel: bool = False,
**kwargs,
) -> Awaitable[_ToThreadRetType]:
if asyncio.iscoroutinefunction(func):
return func(*args, **kwargs)
return self.to_thread(func, *args, wait_on_cancel=wait_on_cancel, **kwargs)
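# Illustrative usage (a sketch; ``Downloader`` and ``fetch_blocking`` are
# hypothetical names, not part of this module):
#     class Downloader(ToThreadMixin):
#         async def fetch(self, url):
#             # run the blocking call in the mixin's private thread pool,
#             # raising asyncio.TimeoutError after 5 seconds
#             return await self.to_thread(fetch_blocking, url, timeout=5.0)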
def config_odps_default_options():
from odps import options as odps_options
odps_options.sql.settings = {
"odps.longtime.instance": "false",
"odps.sql.session.select.only": "false",
"metaservice.client.cache.enable": "false",
"odps.sql.session.result.cache.enable": "false",
"odps.sql.submit.mode": "script",
"odps.sql.job.max.time.hours": 72,
}
def to_hashable(obj: Any) -> Hashable:
if isinstance(obj, Mapping):
items = type(obj)((k, to_hashable(v)) for k, v in obj.items())
elif not isinstance(obj, str) and isinstance(obj, Iterable):
items = tuple(to_hashable(item) for item in obj)
elif isinstance(obj, Hashable):
items = obj
else:
raise TypeError(type(obj))
return items
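# Illustrative usage (a sketch): nested lists become nested tuples so they can be
# used as cache keys.
#     >>> to_hashable([1, [2, 3]])
#     (1, (2, 3))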
def estimate_pandas_size(
pd_obj, max_samples: int = 10, min_sample_rows: int = 100
) -> int:
if len(pd_obj) <= min_sample_rows or isinstance(pd_obj, pd.RangeIndex):
return sys.getsizeof(pd_obj)
if isinstance(pd_obj, pd.MultiIndex):
# MultiIndex's sample size can't be used to estimate
return sys.getsizeof(pd_obj)
from .dataframe.arrays import ArrowDtype
def _is_fast_dtype(dtype):
if isinstance(dtype, np.dtype):
return np.issubdtype(dtype, np.number)
else:
return isinstance(dtype, ArrowDtype)
dtypes = []
is_series = False
if isinstance(pd_obj, pd.DataFrame):
dtypes.extend(pd_obj.dtypes)
index_obj = pd_obj.index
elif isinstance(pd_obj, pd.Series):
dtypes.append(pd_obj.dtype)
index_obj = pd_obj.index
is_series = True
else:
index_obj = pd_obj
# handling possible MultiIndex
if hasattr(index_obj, "dtypes"):
dtypes.extend(index_obj.dtypes)
else:
dtypes.append(index_obj.dtype)
if all(_is_fast_dtype(dtype) for dtype in dtypes):
return sys.getsizeof(pd_obj)
indices = np.sort(np.random.choice(len(pd_obj), size=max_samples, replace=False))
iloc = pd_obj if isinstance(pd_obj, pd.Index) else pd_obj.iloc
if isinstance(index_obj, pd.MultiIndex):
# MultiIndex's sample size is much greater than expected, thus we calculate
# the size separately.
index_size = sys.getsizeof(pd_obj.index)
if is_series:
sample_frame_size = iloc[indices].memory_usage(deep=True, index=False)
else:
sample_frame_size = iloc[indices].memory_usage(deep=True, index=False).sum()
return index_size + sample_frame_size * len(pd_obj) // max_samples
else:
sample_size = sys.getsizeof(iloc[indices])
return sample_size * len(pd_obj) // max_samples
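# Illustrative behavior (a sketch, not an exact contract): small objects (at most
# ``min_sample_rows`` rows) are measured directly via sys.getsizeof, while larger
# objects with non-numeric data are extrapolated from ``max_samples`` random rows,
# so the result is an approximation rather than an exact byte count.
#     >>> df = pd.DataFrame({"a": range(50)})
#     >>> estimate_pandas_size(df) == sys.getsizeof(df)
#     True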
class ModulePlaceholder:
def __init__(self, mod_name: str):
self._mod_name = mod_name
def _raises(self):
raise AttributeError(f"{self._mod_name} is required but not installed.")
def __getattr__(self, key):
self._raises()
def __call__(self, *_args, **_kwargs):
self._raises()
def lazy_import(
name: str,
package: str = None,
globals: Dict = None, # pylint: disable=redefined-builtin
locals: Dict = None, # pylint: disable=redefined-builtin
rename: str = None,
placeholder: bool = False,
):
rename = rename or name
prefix_name = name.split(".", 1)[0]
globals = globals or inspect.currentframe().f_back.f_globals
class LazyModule(object):
def __init__(self):
self._on_loads = []
def __getattr__(self, item):
if item.startswith("_pytest") or item in ("__bases__", "__test__"):
raise AttributeError(item)
real_mod = importlib.import_module(name, package=package)
if rename in globals:
globals[rename] = real_mod
elif locals is not None:
locals[rename] = real_mod
ret = getattr(real_mod, item)
for on_load_func in self._on_loads:
on_load_func()
            # make sure on_load hooks are only executed once
self._on_loads = []
return ret
def add_load_handler(self, func: Callable):
self._on_loads.append(func)
return func
if pkgutil.find_loader(prefix_name) is not None:
return LazyModule()
elif placeholder:
return ModulePlaceholder(prefix_name)
else:
return None
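# Illustrative usage (a sketch; ``cupy`` is just an example of an optional
# dependency): the real import is deferred until the first attribute access.
#     cupy = lazy_import("cupy")
#     if cupy is not None:          # None when the package cannot be found
#         arr = cupy.ones((2, 2))   # the actual ``import cupy`` happens here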
def sbytes(x: Any) -> bytes:
    # NB: bytes() in Python 3 has different semantics than in Python 2, see help(bytes)
from numbers import Number
if x is None or isinstance(x, Number):
return bytes(str(x), encoding="ascii")
elif isinstance(x, list):
return bytes("[" + ", ".join([str(k) for k in x]) + "]", encoding="utf-8")
elif isinstance(x, tuple):
return bytes("(" + ", ".join([str(k) for k in x]) + ")", encoding="utf-8")
elif isinstance(x, str):
return bytes(x, encoding="utf-8")
else:
return bytes(x)
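# Illustrative usage (a sketch of the conversions above):
#     >>> sbytes(3)
#     b'3'
#     >>> sbytes([1, 2])
#     b'[1, 2]'
#     >>> sbytes("abc")
#     b'abc'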
def is_full_slice(slc: Any) -> bool:
"""Check if the input is a full slice ((:) or (0:))"""
return (
isinstance(slc, slice)
and (slc.start == 0 or slc.start is None)
and slc.stop is None
and slc.step is None
)
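# Illustrative usage (a sketch):
#     >>> is_full_slice(slice(None))
#     True
#     >>> is_full_slice(slice(1, None))
#     False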
_enter_counter = 0
_initial_session = None
def enter_current_session(func: Callable):
@functools.wraps(func)
def wrapped(cls, ctx, op):
from .session import AbstractSession, get_default_session
global _enter_counter, _initial_session
# skip in some test cases
if not hasattr(ctx, "get_current_session"):
return func(cls, ctx, op)
with AbstractSession._lock:
if _enter_counter == 0:
                # to handle nested calls, only set the initial session
                # on the first call
session = ctx.get_current_session()
_initial_session = get_default_session()
session.as_default()
_enter_counter += 1
try:
result = func(cls, ctx, op)
finally:
with AbstractSession._lock:
_enter_counter -= 1
if _enter_counter == 0:
# set previous session when counter is 0
if _initial_session:
_initial_session.as_default()
else:
AbstractSession.reset_default()
return result
return wrapped
_func_token_cache = weakref.WeakKeyDictionary()
def _get_func_token_values(func):
if hasattr(func, "__code__"):
tokens = [func.__code__.co_code]
if func.__closure__ is not None:
cvars = tuple([x.cell_contents for x in func.__closure__])
tokens.append(cvars)
return tokens
else:
tokens = []
while isinstance(func, functools.partial):
tokens.extend([func.args, func.keywords])
func = func.func
if hasattr(func, "__code__"):
tokens.extend(_get_func_token_values(func))
elif isinstance(func, types.BuiltinFunctionType):
tokens.extend([func.__module__, func.__qualname__])
else:
tokens.append(func)
return tokens
def get_func_token(func):
try:
token = _func_token_cache.get(func)
if token is None:
fields = _get_func_token_values(func)
token = tokenize(*fields)
_func_token_cache[func] = token
return token
except TypeError: # cannot create weak reference to func like 'numpy.ufunc'
return tokenize(*_get_func_token_values(func))
_io_quiet_local = threading.local()
_io_quiet_lock = threading.Lock()
class _QuietIOWrapper:
def __init__(self, wrapped):
self.wrapped = wrapped
def __getattr__(self, item):
return getattr(self.wrapped, item)
def write(self, d):
if getattr(_io_quiet_local, "is_wrapped", False):
return 0
return self.wrapped.write(d)
@contextmanager
def quiet_stdio():
"""Quiets standard outputs when inferring types of functions"""
with _io_quiet_lock:
_io_quiet_local.is_wrapped = True
sys.stdout = _QuietIOWrapper(sys.stdout)
sys.stderr = _QuietIOWrapper(sys.stderr)
try:
yield
finally:
with _io_quiet_lock:
sys.stdout = sys.stdout.wrapped
sys.stderr = sys.stderr.wrapped
if not isinstance(sys.stdout, _QuietIOWrapper):
_io_quiet_local.is_wrapped = False
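# Illustrative usage (a sketch): writes to stdout/stderr inside the context are
# swallowed, which keeps user functions from emitting output during type inference.
#     with quiet_stdio():
#         print("this line is not written anywhere")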
# from https://github.com/ericvsmith/dataclasses/blob/master/dataclass_tools.py
# released under Apache License 2.0
def dataslots(cls):
# Need to create a new class, since we can't set __slots__
# after a class has been created.
# Make sure __slots__ isn't already set.
if "__slots__" in cls.__dict__: # pragma: no cover
raise TypeError(f"{cls.__name__} already specifies __slots__")
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in dataclasses.fields(cls))
cls_dict["__slots__"] = field_names
for field_name in field_names:
        # Remove our attributes, if present. Their defaults are still
        # recorded on the dataclass fields.
cls_dict.pop(field_name, None)
# Remove __dict__ itself.
cls_dict.pop("__dict__", None)
# And finally create the class.
qualname = getattr(cls, "__qualname__", None)
cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
if qualname is not None:
cls.__qualname__ = qualname
return cls
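# Illustrative usage (a sketch; ``Point`` is a hypothetical dataclass): apply
# ``dataslots`` on top of ``dataclasses.dataclass`` to obtain a slotted class.
#     @dataslots
#     @dataclasses.dataclass
#     class Point:
#         x: int = 0
#         y: int = 0
#     # Point.__slots__ == ('x', 'y'); instances no longer carry a __dict__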
def adapt_docstring(doc: str) -> str:
"""
Adapt numpy-style docstrings to MaxFrame docstring.
This util function will add MaxFrame imports, replace object references
and add execute calls. Note that check is needed after replacement.
"""
if doc is None:
return None
lines = []
first_prompt = True
prev_prompt = False
has_numpy = "np." in doc
has_pandas = "pd." in doc
for line in doc.splitlines():
sp = line.strip()
if sp.startswith(">>>") or sp.startswith("..."):
prev_prompt = True
if first_prompt:
first_prompt = False
indent = "".join(itertools.takewhile(lambda x: x in (" ", "\t"), line))
if has_numpy:
lines.extend([indent + ">>> import maxframe.tensor as mt"])
if has_pandas:
lines.extend([indent + ">>> import maxframe.dataframe as md"])
line = line.replace("np.", "mt.").replace("pd.", "md.")
elif prev_prompt:
prev_prompt = False
if sp:
lines[-1] += ".execute()"
lines.append(line)
return "\n".join(lines)
def stringify_path(path: Union[str, os.PathLike]) -> str:
"""
Convert *path* to a string or unicode path if possible.
"""
if isinstance(path, str):
return path
# checking whether path implements the filesystem protocol
try:
return path.__fspath__()
except AttributeError:
raise TypeError("not a path-like object")
_memory_size_indices = {"": 0, "k": 1, "m": 2, "g": 3, "t": 4}
def parse_readable_size(value: Union[str, int, float]) -> Tuple[float, bool]:
if isinstance(value, numbers.Number):
return float(value), False
value = value.strip().lower()
num_pos = 0
while num_pos < len(value) and value[num_pos] in "0123456789.-":
num_pos += 1
value, suffix = value[:num_pos], value[num_pos:]
suffix = suffix.strip()
if suffix.endswith("%"):
return float(value) / 100, True
try:
return float(value) * (1024 ** _memory_size_indices[suffix[:1]]), False
except (ValueError, KeyError):
raise ValueError(f"Unknown limitation value: {value}")
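# Illustrative usage (a sketch): plain numbers pass through, k/m/g/t suffixes are
# treated as binary multiples, and percentages return a ratio with the flag set.
#     >>> parse_readable_size("10k")
#     (10240.0, False)
#     >>> parse_readable_size("25%")
#     (0.25, True)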
def remove_suffix(value: str, suffix: str) -> Tuple[str, bool]:
"""
Remove a suffix from a given string if it exists.
Parameters
----------
value : str
The original string.
suffix : str
The suffix to be removed.
Returns
-------
Tuple[str, bool]
A tuple containing the modified string and a boolean indicating whether the suffix was found.
"""
# Check if the suffix is an empty string
if len(suffix) == 0:
# If the suffix is empty, return the original string with True
return value, True
# Check if the length of the value is less than the length of the suffix
if len(value) < len(suffix):
# If the value is shorter than the suffix, it cannot have the suffix
return value, False
# Check if the suffix matches the end of the value
match = value.endswith(suffix)
# If the suffix is found, remove it; otherwise, return the original string
if match:
return value[: -len(suffix)], match
else:
return value, match
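# Illustrative usage (a sketch; the table names are hypothetical):
#     >>> remove_suffix("result_table_intermediate", "_intermediate")
#     ('result_table', True)
#     >>> remove_suffix("result_table", "_intermediate")
#     ('result_table', False)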
def find_objects(nested: Union[List, Dict], types: Union[Type, Tuple[Type]]) -> List:
found = []
stack = [nested]
while len(stack) > 0:
it = stack.pop()
if isinstance(it, types):
found.append(it)
continue
if isinstance(it, (list, tuple, set)):
stack.extend(list(it)[::-1])
elif isinstance(it, dict):
stack.extend(list(it.values())[::-1])
return found
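# Illustrative usage (a sketch; the nested structure is hypothetical): collects all
# objects of the requested types from nested lists/tuples/sets/dicts.
#     >>> find_objects([1, {"a": [2, "x"]}, "y"], str)
#     ['x', 'y']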
def replace_objects(nested: Union[List, Dict], mapping: Mapping) -> Union[List, Dict]:
if not mapping:
return nested
if isinstance(nested, dict):
vals = list(nested.values())
else:
vals = list(nested)
new_vals = []
for val in vals:
if isinstance(val, (dict, list, tuple, set)):
new_val = replace_objects(val, mapping)
else:
try:
new_val = mapping.get(val, val)
except TypeError:
new_val = val
new_vals.append(new_val)
if isinstance(nested, dict):
return type(nested)((k, v) for k, v in zip(nested.keys(), new_vals))
else:
return type(nested)(new_vals)
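# Illustrative usage (a sketch): values found in ``mapping`` are swapped while the
# container structure is preserved; unhashable values are left untouched.
#     >>> replace_objects([1, [2, 3]], {2: 20})
#     [1, [20, 3]]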
def trait_from_env(
trait_name: str, env: str, trait: Optional[traitlets.TraitType] = None
):
if trait is None:
prev_locals = inspect.stack()[1].frame.f_locals
trait = prev_locals[trait_name]
default_value = trait.default_value
sub_trait: traitlets.TraitType = getattr(trait, "_trait", None)
def default_value_simple(self):
env_val = os.getenv(env, default_value)
if isinstance(env_val, (str, bytes)):
return trait.from_string(env_val)
return env_val
def default_value_list(self):
env_val = os.getenv(env, default_value)
if env_val is None or isinstance(env_val, traitlets.Sentinel):
return env_val
parts = env_val.split(",") if env_val else []
if sub_trait:
return [sub_trait.from_string(s) for s in parts]
else:
return parts
if isinstance(trait, traitlets.List):
default_value_fun = default_value_list
else: # pragma: no cover
default_value_fun = default_value_simple
default_value_fun.__name__ = trait_name + "_default"
return traitlets.default(trait_name)(default_value_fun)
def relay_future(
dest: Union[asyncio.Future, concurrent.futures.Future],
src: Union[asyncio.Future, concurrent.futures.Future],
) -> None:
def cb(fut: Union[asyncio.Future, concurrent.futures.Future]):
try:
dest.set_result(fut.result())
except BaseException as ex:
dest.set_exception(ex)
src.add_done_callback(cb)
_arrow_type_constructors = {}
if pa:
_arrow_type_constructors = {
"bool": pa.bool_,
"list": lambda x: pa.list_(dict(x)["item"]),
"map": lambda x: pa.map_(*x),
"struct": pa.struct,
"fixed_size_binary": pa.binary,
"halffloat": pa.float16,
"float": pa.float32,
"double": pa.float64,
"decimal": pa.decimal128,
}
_plain_arrow_types = """
null
int8 int16 int32 int64
uint8 uint16 uint32 uint64
float16 float32 float64
date32 date64
decimal128 decimal256
string utf8 binary
time32 time64 duration timestamp
month_day_nano_interval
"""
for _type_name in _plain_arrow_types.split():
try:
_arrow_type_constructors[_type_name] = getattr(pa, _type_name)
except AttributeError: # pragma: no cover
pass
def arrow_type_from_str(type_str: str) -> pa.DataType:
"""
    Convert arrow type representations (for instance, list<item: int64>)
    into arrow DataType instances.
"""
# enable consecutive brackets to be tokenized
type_str = type_str.replace("<", "< ").replace(">", " >")
token_iter = pytokenize.tokenize(io.BytesIO(type_str.encode()).readline)
value_stack, op_stack = [], []
def _pop_make_type(with_args: bool = False, combined: bool = True) -> None:
"""
        Pop the top of the value stack, create a DataType instance and push it back.
        Parameters
        ----------
        with_args: bool
            if True, pop the next item (a parameter list) from the value stack
            and pass it to the DataType constructor
        combined: bool
            if True, pass the popped parameter list as a single argument;
            otherwise unpack its elements as separate positional arguments
"""
args = () if not with_args else (value_stack.pop(-1),)
if not combined:
args = args[0]
type_name = value_stack.pop(-1)
if isinstance(type_name, pa.DataType):
value_stack.append(type_name)
elif type_name in _arrow_type_constructors:
value_stack.append(_arrow_type_constructors[type_name](*args))
else: # pragma: no cover
value_stack.append(type_name)
for token in token_iter:
if token.type == pytokenize.OP:
if token.string == ":":
op_stack.append(token.string)
elif token.string == ",":
# gather previous sub-types
if op_stack[-1] in ("<", ":"):
_pop_make_type()
if op_stack[-1] == ":":
# parameterized sub-types need to be represented as tuples
op_stack.pop(-1)
values = value_stack[-2:]
value_stack = value_stack[:-2]
value_stack.append(tuple(values))
# put generated item into the parameter list
val = value_stack.pop(-1)
value_stack[-1].append(val)
elif token.string in ("<", "[", "("):
# pushes an empty parameter list for future use
value_stack.append([])
op_stack.append(token.string)
elif token.string in (")", "]"):
# put generated item into the parameter list
val = value_stack.pop(-1)
value_stack[-1].append(val)
# make DataType (i.e., fixed_size_binary / decimal) given args
_pop_make_type(with_args=True, combined=False)
op_stack.pop(-1)
elif token.string == ">":
_pop_make_type()
if op_stack[-1] == ":":
# parameterized sub-types need to be represented as tuples
op_stack.pop(-1)
values = value_stack[-2:]
value_stack = value_stack[:-2]
value_stack.append(tuple(values))
# put generated item into the parameter list
val = value_stack.pop(-1)
value_stack[-1].append(val)
# make DataType (i.e., list / map / struct) given args
_pop_make_type(True)
op_stack.pop(-1)
elif token.type == pytokenize.NAME:
value_stack.append(token.string)
elif token.type == pytokenize.NUMBER:
value_stack.append(int(token.string))
elif token.type == pytokenize.ENDMARKER:
# make final type
_pop_make_type()
if len(value_stack) > 1:
raise ValueError(f"Cannot parse type {type_str}")
return value_stack[-1]
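# Illustrative usage (a sketch, assuming pyarrow is installed):
#     >>> arrow_type_from_str("list<item: int64>") == pa.list_(pa.int64())
#     True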
def get_python_tag():
# todo add implementation suffix for non-GIL tags when PEP703 is ready
version_info = sys.version_info
return f"cp{version_info[0]}{version_info[1]}"
def get_item_if_scalar(val: Any) -> Any:
if isinstance(val, np.ndarray) and val.shape == ():
return val.item()
return val
def collect_leaf_operators(root) -> List[Type]:
result = []
def _collect(op_type):
if len(op_type.__subclasses__()) == 0:
result.append(op_type)
for subclass in op_type.__subclasses__():
_collect(subclass)
_collect(root)
return result
@contextmanager
def sync_pyodps_options():
from odps.config import option_context as pyodps_option_context
from .config import options
with pyodps_option_context() as cfg:
cfg.local_timezone = options.local_timezone
if options.session.enable_schema:
cfg.enable_schema = options.session.enable_schema
yield
def str_to_bool(s: Optional[str]) -> Optional[bool]:
return s.lower().strip() in ("true", "1") if s is not None else None
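# Illustrative usage (a sketch): "true"/"1" (case-insensitive) map to True, any
# other string to False, and None passes through unchanged.
#     >>> str_to_bool("True"), str_to_bool("0"), str_to_bool(None)
#     (True, False, None)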
def is_empty(val):
if isinstance(val, (pd.DataFrame, pd.Series, pd.Index)):
return val.empty
return not bool(val)
def get_default_table_properties():
return {"storagestrategy": "archive"}