odps/df/expr/utils.py (161 lines of code) (raw):

#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2022 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import import itertools import inspect import traceback import threading from collections import OrderedDict from datetime import datetime from decimal import Decimal from ...compat import six from ...models import FileResource, TableResource from .. import types try: from collections.abc import Iterable except ImportError: from collections import Iterable def add_method(expr, methods): for k, v in six.iteritems(methods): setattr(expr, k, v) def same_data_source(*exprs): curr_data_source = None for expr in exprs: data_source = sorted(list(expr.data_source())) if curr_data_source is None: curr_data_source = data_source else: if curr_data_source != data_source: return False return True def highest_precedence_data_type(*data_types): data_types = set(data_types) if len(data_types) == 1: return data_types.pop() precedences = dict((t, idx) for idx, t in enumerate( [types.string, types.boolean, types.int8, types.int16, types.int32, types.int64, types.datetime, types.decimal, types.float32, types.float64])) type_precedences = [(precedences[data_type], data_type) for data_type in data_types] highest_data_type = sorted(type_precedences)[-1][1] for data_type in data_types: if data_type != highest_data_type and not highest_data_type.can_implicit_cast(data_type): raise TypeError( 'Type cast error: %s cannot implicitly cast to %s' % (data_type, highest_data_type)) return highest_data_type def get_attrs(node): from .core import Node tp = type(node) if not inspect.isclass(node) else node if inspect.getmro(tp) is None: tp = type(tp) return tuple(OrderedDict.fromkeys( it for it in itertools.chain(*(cls.__slots__ for cls in inspect.getmro(tp) if issubclass(cls, Node))) if not it.startswith('__'))) def get_collection_resources(resources): from .expressions import CollectionExpr if resources: for res in resources: if not isinstance(res, (TableResource, FileResource, CollectionExpr)): raise ValueError('resources must be ODPS file or table Resources or collections') if resources is not None and len(resources) > 0: ret = [res for res in resources if isinstance(res, CollectionExpr)] [r.cache() for r in ret] # we should execute the expressions by setting cache=True return ret def get_executed_collection_project_table_name(collection): from .expressions import CollectionExpr from ...models import Table from ..backends.context import context if not isinstance(collection, CollectionExpr): return if collection._source_data is not None and \ isinstance(collection._source_data, Table): source_data = collection._source_data return source_data.project.name + '.' + source_data.name if context.is_cached(collection) and \ isinstance(context.get_cached(collection), Table): source_data = context.get_cached(collection) return source_data.project.name + '.' + source_data.name def is_called_by_inspector(): return any(1 for v in traceback.extract_stack() if 'oinspect' in v[0].lower() and 'ipython' in v[0].lower()) def to_list(field): if isinstance(field, six.string_types): return [field, ] if isinstance(field, Iterable): return list(field) return [field, ] _lock = threading.Lock() _index = itertools.count(1) def new_id(): with _lock: return next(_index) def select_fields(collection): from .expressions import ProjectCollectionExpr, Summary from .collections import DistinctCollectionExpr, RowAppliedCollectionExpr from .groupby import GroupByCollectionExpr, MutateCollectionExpr if isinstance(collection, (ProjectCollectionExpr, Summary)): return collection.fields elif isinstance(collection, DistinctCollectionExpr): return collection.unique_fields elif isinstance(collection, (GroupByCollectionExpr, MutateCollectionExpr)): return collection.fields elif isinstance(collection, RowAppliedCollectionExpr): return collection.fields def is_changed(collection, column): # if the column is changed before the generated collection from .expressions import CollectionExpr, Column column_name = column.source_name src_collection = column.input if src_collection is collection: return False dag = collection.to_dag(copy=False, validate=False) coll = src_collection colls = [] while coll is not collection: try: parents = [p for p in dag.successors(coll) if isinstance(p, CollectionExpr)] except KeyError: return assert len(parents) == 1 coll = parents[0] colls.append(coll) name = column_name for coll in colls: fields = select_fields(coll) if fields: col_names = dict((field.source_name, field) for field in fields if isinstance(field, Column)) if name in col_names: name = col_names[name].name else: return True return False annotation_rtypes = { int: types.int64, str: types.string, float: types.float64, bool: types.boolean, datetime: types.datetime, Decimal: types.decimal, } def get_annotation_rtype(func): if hasattr(func, '__annotations__'): try: from typing import Union except ImportError: Union = None ret_type = func.__annotations__.get('return') if ret_type in annotation_rtypes: return annotation_rtypes.get(ret_type) elif hasattr(ret_type, '__origin__') and ret_type.__origin__ is Union: actual_types = [typo for typo in ret_type.__args__ if typo is not type(None)] if len(actual_types) == 1: return annotation_rtypes.get(actual_types[0]) elif Union is not None and type(ret_type) is type(Union): actual_types = [typo for typo in ret_type.__args__ if typo is not type(None)] if len(actual_types) == 1: return annotation_rtypes.get(actual_types[0]) return None def get_proxied_expr(expr): try: obj = object.__getattribute__(expr, '_proxy') return obj if obj is not None else expr except AttributeError: return expr