odps/df/expr/utils.py (161 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import itertools
import inspect
import traceback
import threading
from collections import OrderedDict
from datetime import datetime
from decimal import Decimal
from ...compat import six
from ...models import FileResource, TableResource
from .. import types
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
def add_method(expr, methods):
    """Attach every entry of ``methods`` to ``expr`` as an attribute.

    :param expr: object (typically an expression class) receiving the attributes
    :param methods: mapping of attribute name to the value to set
    """
    for attr_name, attr_value in six.iteritems(methods):
        setattr(expr, attr_name, attr_value)
def same_data_source(*exprs):
    """Check whether all given expressions report the same data sources.

    Each expression's ``data_source()`` result is sorted and compared with
    that of the first expression.

    :param exprs: objects exposing a ``data_source()`` method
    :return: ``False`` as soon as one expression differs from the first one,
        ``True`` otherwise (including when no expression is passed)
    """
    reference = None
    for expr in exprs:
        current = sorted(expr.data_source())
        if reference is None:
            # the first expression defines what the others are compared to
            reference = current
        elif reference != current:
            return False
    return True
def highest_precedence_data_type(*data_types):
    """Pick the data type with the highest precedence among ``data_types``.

    When all arguments are the same type, that type is returned directly.
    Otherwise the types are ranked by a fixed precedence list and the
    highest-ranked one is returned, provided it reports via
    ``can_implicit_cast`` that it accepts every other given type.

    :param data_types: data types to compare
    :return: the highest-precedence data type
    :raises TypeError: if some type cannot be implicitly cast to the result
    :raises KeyError: if a type is missing from the precedence list
    """
    distinct_types = set(data_types)
    if len(distinct_types) == 1:
        return distinct_types.pop()

    # position in this list is the precedence: later entries win
    ordered_types = [
        types.string, types.boolean, types.int8, types.int16, types.int32, types.int64,
        types.datetime, types.decimal, types.float32, types.float64,
    ]
    rank_of = dict((tp, rank) for rank, tp in enumerate(ordered_types))

    ranked = sorted([(rank_of[tp], tp) for tp in distinct_types])
    winner = ranked[-1][1]

    for tp in distinct_types:
        if tp != winner and not winner.can_implicit_cast(tp):
            raise TypeError(
                'Type cast error: %s cannot implicitly cast to %s' % (tp, winner))

    return winner
def get_attrs(node):
    """Collect the slot names declared along the ``Node`` part of an MRO.

    :param node: a class or an instance; for an instance its type is used
    :return: tuple of names taken from ``__slots__`` of every class in the MRO
        that is a ``Node`` subclass, in MRO order, without duplicates and
        without names starting with ``'__'``
    """
    from .core import Node

    tp = node if inspect.isclass(node) else type(node)

    if inspect.getmro(tp) is None:
        tp = type(tp)

    slot_names = itertools.chain.from_iterable(
        cls.__slots__ for cls in inspect.getmro(tp) if issubclass(cls, Node))

    # OrderedDict.fromkeys drops duplicates while keeping first-seen order
    unique_names = OrderedDict.fromkeys(
        name for name in slot_names if not name.startswith('__'))
    return tuple(unique_names)
def get_collection_resources(resources):
    """Validate ``resources`` and return the collection expressions among them.

    Every ``CollectionExpr`` found has ``cache()`` called on it before the
    list is returned.

    :param resources: sequence of ``TableResource``, ``FileResource`` or
        ``CollectionExpr`` objects, or ``None``
    :return: list of the ``CollectionExpr`` items, or ``None`` when
        ``resources`` is ``None`` or empty
    :raises ValueError: if an item is of none of the accepted types
    """
    from .expressions import CollectionExpr

    accepted_types = (TableResource, FileResource, CollectionExpr)
    if resources:
        for resource in resources:
            if not isinstance(resource, accepted_types):
                raise ValueError('resources must be ODPS file or table Resources or collections')

    if resources is None or len(resources) == 0:
        return None

    collections = [resource for resource in resources if isinstance(resource, CollectionExpr)]
    # we should execute the expressions by setting cache=True
    for coll in collections:
        coll.cache()
    return collections
def get_executed_collection_project_table_name(collection):
    """Return ``'<project>.<table>'`` for a collection backed by a table.

    The table is taken from ``collection._source_data`` when that is a
    ``Table``; otherwise from the backend context cache when the collection
    is cached there as a ``Table``.

    :param collection: object to inspect
    :return: the qualified table name, or ``None`` when ``collection`` is not
        a ``CollectionExpr`` or no ``Table`` is found for it
    """
    from .expressions import CollectionExpr
    from ...models import Table
    from ..backends.context import context

    def _qualified_name(table):
        # project name and table name joined by a dot
        return table.project.name + '.' + table.name

    if not isinstance(collection, CollectionExpr):
        return None

    source_data = collection._source_data
    if source_data is not None and isinstance(source_data, Table):
        return _qualified_name(source_data)

    if context.is_cached(collection):
        cached = context.get_cached(collection)
        if isinstance(cached, Table):
            return _qualified_name(cached)

    return None
def is_called_by_inspector():
    """Tell whether the current call stack passes through IPython's inspector.

    :return: ``True`` if some frame's file name contains both ``'oinspect'``
        and ``'ipython'`` (case-insensitively), ``False`` otherwise
    """
    for frame in traceback.extract_stack():
        # the first item of each stack entry is the file name
        filename = frame[0].lower()
        if 'oinspect' in filename and 'ipython' in filename:
            return True
    return False
def to_list(field):
    """Normalize ``field`` into a list.

    :param field: a string, an iterable, or any other single value
    :return: ``list(field)`` for a non-string iterable; otherwise a
        one-element list holding ``field`` itself
    """
    # strings are iterable but must be kept whole, so test for them first
    is_text = isinstance(field, six.string_types)
    if not is_text and isinstance(field, Iterable):
        return list(field)
    return [field]
# Module-wide counter starting at 1, and the lock that guards it.
_lock = threading.Lock()
_index = itertools.count(1)


def new_id():
    """Return the next value of the module-wide counter.

    The counter is advanced while holding ``_lock``.

    :return: the next integer id
    """
    with _lock:
        allocated = next(_index)
    return allocated
def select_fields(collection):
    """Return the fields selected by a collection expression.

    :param collection: collection expression to inspect
    :return: ``collection.unique_fields`` for a ``DistinctCollectionExpr``;
        ``collection.fields`` for projection, summary, group-by, mutate and
        row-applied collections; ``None`` for anything else
    """
    from .expressions import ProjectCollectionExpr, Summary
    from .collections import DistinctCollectionExpr, RowAppliedCollectionExpr
    from .groupby import GroupByCollectionExpr, MutateCollectionExpr

    # the order of the checks below is kept as in the original chain
    if isinstance(collection, (ProjectCollectionExpr, Summary)):
        return collection.fields
    if isinstance(collection, DistinctCollectionExpr):
        return collection.unique_fields

    other_field_holders = (
        GroupByCollectionExpr, MutateCollectionExpr, RowAppliedCollectionExpr,
    )
    if isinstance(collection, other_field_holders):
        return collection.fields
    return None
def is_changed(collection, column):
    """Tell if ``column`` is changed before reaching the generated ``collection``.

    Starting from ``column.input``, the DAG of ``collection`` is followed
    through its ``CollectionExpr`` successors until ``collection`` is reached.
    For every collection on that path that selects fields, the column is
    tracked by name through the ``Column`` fields.

    :param collection: the generated collection
    :param column: column whose ``input`` is the collection to start from
    :return: ``False`` if ``column.input`` is ``collection`` or the column can
        be followed by name along the whole path; ``True`` if some collection
        on the path selects fields but none is a ``Column`` whose
        ``source_name`` matches; ``None`` if looking up successors in the DAG
        raises ``KeyError``
    """
    from .expressions import CollectionExpr, Column

    column_name = column.source_name
    src_collection = column.input

    if src_collection is collection:
        return False

    dag = collection.to_dag(copy=False, validate=False)

    # collect the collections between the source (exclusive) and the target (inclusive)
    path = []
    current = src_collection
    while current is not collection:
        try:
            parents = [succ for succ in dag.successors(current)
                       if isinstance(succ, CollectionExpr)]
        except KeyError:
            return None
        assert len(parents) == 1
        current = parents[0]
        path.append(current)

    tracked_name = column_name
    for step in path:
        fields = select_fields(step)
        if not fields:
            continue
        by_source_name = dict(
            (field.source_name, field) for field in fields if isinstance(field, Column))
        if tracked_name not in by_source_name:
            return True
        # the column may have been renamed here; keep following the new name
        tracked_name = by_source_name[tracked_name].name

    return False
# Maps Python types that may appear as a function's return annotation to the
# corresponding data type from ``types``; looked up by get_annotation_rtype.
annotation_rtypes = {
    int: types.int64,
    str: types.string,
    float: types.float64,
    bool: types.boolean,
    datetime: types.datetime,
    Decimal: types.decimal,
}
def _union_member_types(ret_type):
    """Return the member types of a union annotation, or None if not a union."""
    try:
        from typing import Union
    except ImportError:
        # interpreter without the ``typing`` module
        Union = None
    try:
        from types import UnionType  # PEP 604 unions (``X | Y``)
    except ImportError:
        UnionType = None

    # Anything lacking a tuple ``__args__`` (plain classes, typing special
    # forms such as NoReturn, ...) cannot be unwrapped as a union.
    members = getattr(ret_type, '__args__', None)
    if not isinstance(members, tuple):
        return None

    if Union is not None:
        if getattr(ret_type, '__origin__', None) is Union:
            return members
        if type(ret_type) is type(Union):
            return members
    if UnionType is not None and isinstance(ret_type, UnionType):
        return members
    return None


def get_annotation_rtype(func):
    """Infer a data type from the return annotation of ``func``.

    A direct match in ``annotation_rtypes`` is used when available. A union
    with exactly one member other than ``NoneType`` (``Optional[X]``,
    ``Union[X, None]`` or ``X | None``) is resolved through that member.

    :param func: callable (or any object) possibly carrying ``__annotations__``
    :return: the mapped data type, or ``None`` when there is no return
        annotation or it cannot be mapped
    """
    annotations = getattr(func, '__annotations__', None)
    if not annotations:
        return None

    ret_type = annotations.get('return')
    try:
        if ret_type in annotation_rtypes:
            return annotation_rtypes[ret_type]
    except TypeError:
        # unhashable annotation (e.g. a list literal) cannot be a key here
        pass

    members = _union_member_types(ret_type)
    if members is None:
        return None

    actual_types = [typo for typo in members if typo is not type(None)]
    if len(actual_types) == 1:
        return annotation_rtypes.get(actual_types[0])
    return None
def get_proxied_expr(expr):
    """Return the object stored in ``expr._proxy`` if set, else ``expr``.

    The attribute is read with ``object.__getattribute__`` so that any
    attribute-access override on ``expr``'s class is bypassed.

    :param expr: object possibly carrying a ``_proxy`` attribute
    :return: the ``_proxy`` value when it exists and is not ``None``,
        otherwise ``expr`` itself
    """
    try:
        proxy = object.__getattribute__(expr, '_proxy')
    except AttributeError:
        return expr
    if proxy is None:
        return expr
    return proxy