odps/df/expr/composites.py (205 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .expressions import Scalar, Expr, TableSchema, SequenceExpr, ListSequenceExpr, DictSequenceExpr, \
    ListScalar, DictScalar, Column
from .element import AnyOp, ElementWise
from .collections import RowAppliedCollectionExpr
from .. import types as df_types
from ..utils import to_collection
from . import utils


def _scalar(val, tp=None):
    """
    Wrap a plain Python value into a ``Scalar`` expression.

    :param val: value to wrap. ``None`` yields ``None``, an existing ``Expr``
                is returned untouched, and a tuple / list is converted
                element by element into a container of the same type.
    :param tp: value type handed to every ``Scalar`` that gets created
    :return: ``None``, an expression, or a tuple / list of those
    """
    if val is None:
        return
    if isinstance(val, Expr):
        return val
    if isinstance(val, (tuple, list)):
        return type(val)(_scalar(it, tp=tp) for it in val)
    else:
        return Scalar(_value=val, _value_type=tp)


def explode(expr, *args, **kwargs):
    """
    Expand list or dict data into multiple rows

    :param expr: list / dict sequence / scalar
    :param args: names of the generated columns, given either as separate
                 positional arguments or as a single list / tuple / set.
                 When omitted, names are derived from ``expr.name``.
    :param pos: keyword only, defaults to False. When True, a list is
                exploded with ``POSEXPLODE`` and an ``int64`` position
                column is emitted before the value column. Not supported
                for dicts.
    :param keep_nulls: keyword only, defaults to False. Forwarded as
                       ``_keep_nulls`` to the resulting collection.
    :return: a ``RowAppliedCollectionExpr`` carrying the exploded schema
    :raises ValueError: if the number of names does not fit the exploded
                        columns, if ``pos`` is requested on a dict, or if
                        ``expr`` is neither a list nor a dict sequence
    """
    if not isinstance(expr, Column):
        # Anything that is not already a column is first turned into a
        # collection, from which the column of the same name is taken.
        expr = to_collection(expr)[expr.name]

    if isinstance(expr, SequenceExpr):
        dtype = expr.data_type
    else:
        dtype = expr.value_type

    func_name = 'EXPLODE'
    if args and isinstance(args[0], (list, tuple, set)):
        names = list(args[0])
    else:
        names = args
    pos = kwargs.get('pos', False)
    if isinstance(expr, ListSequenceExpr):
        if pos:
            func_name = 'POSEXPLODE'
            typos = [df_types.int64, dtype.value_type]
            if not names:
                names = [expr.name + '_pos', expr.name]
            if len(names) == 1:
                # A single name labels the value column; the position
                # column name is derived from it.
                names = [names[0] + '_pos', names[0]]
            if len(names) != 2:
                # Both one name and two names are accepted in this branch,
                # hence the message must not claim "exactly 1".
                raise ValueError("The length of parameter 'names' should be 1 or 2.")
        else:
            typos = [dtype.value_type]
            if not names:
                names = [expr.name]
            if len(names) != 1:
                raise ValueError("The length of parameter 'names' should be exactly 1.")
    elif isinstance(expr, DictSequenceExpr):
        if pos:
            raise ValueError('Cannot support explosion with pos on dicts.')
        typos = [dtype.key_type, dtype.value_type]
        if not names:
            names = [expr.name + '_key', expr.name + '_value']
        if len(names) != 2:
            raise ValueError("The length of parameter 'names' should be exactly 2.")
    else:
        raise ValueError('Cannot explode expression with type %s' % type(expr).__name__)
    schema = TableSchema.from_lists(names, typos)
    return RowAppliedCollectionExpr(_input=expr.input, _func=func_name, _schema=schema,
                                    _fields=[expr], _keep_nulls=kwargs.get('keep_nulls', False))


def composite_op(expr, output_expr_cls, output_type=None, **kwargs):
    """
    Build a composite operation node on top of ``expr``.

    The node is created as a sequence (``_data_type`` is passed) when any
    argument listed in ``output_expr_cls._args`` is a ``SequenceExpr`` or a
    list / tuple containing one; otherwise it is created as a scalar
    (``_value_type`` is passed).

    :param expr: input expression, passed to the node as ``_input``
    :param output_expr_cls: node class to instantiate
    :param output_type: type of the result; defaults to the type of ``expr``
    :param kwargs: additional node arguments, forwarded unchanged
    :return: instance of ``output_expr_cls``
    """
    input_args = kwargs.copy()
    input_args['_input'] = expr

    def check_is_sequence(arg):
        a = input_args.get(arg)
        if isinstance(a, SequenceExpr):
            return True
        elif isinstance(a, (list, tuple)):
            return any(isinstance(la, SequenceExpr) for la in a)
        else:
            return False

    is_sequence = any(check_is_sequence(a) for a in output_expr_cls._args)
    if is_sequence:
        output_type = output_type or expr.data_type
        return output_expr_cls(_data_type=output_type, _input=expr, **kwargs)
    else:
        output_type = output_type or expr.value_type
        return output_expr_cls(_value_type=output_type, _input=expr, **kwargs)


class CompositeOp(ElementWise):
    """Element-wise operation on a list / dict value."""

    def accept(self, visitor):
        visitor.visit_composite_op(self)


class CompositeBuilderOp(AnyOp):
    """Operation that builds a list / dict value from its arguments."""

    @property
    def node_name(self):
        return self.__class__.__name__

    def accept(self, visitor):
        visitor.visit_composite_op(self)


class ListDictLength(CompositeOp):
    """Node created by ``len`` on a list or dict."""
    __slots__ = ()


class ListDictGetItem(CompositeOp):
    """Node created by ``__getitem__`` on a list or dict."""
    # NOTE(review): `_negative_handled` is not set anywhere in this module;
    # presumably it is maintained by a later processing stage -- confirm.
    _args = '_input', '_key', '_negative_handled'


class ListContains(CompositeOp):
    """Node created by ``contains`` on a list."""
    _args = '_input', '_value',


class ListSort(CompositeOp):
    """Node created by ``sort`` on a list."""
    __slots__ = ()


class DictKeys(CompositeOp):
    """Node created by ``keys`` on a dict."""
    __slots__ = ()


class DictValues(CompositeOp):
    """Node created by ``values`` on a dict."""
    __slots__ = ()


def _scan_inputs(seq, dtype=None):
    """
    Normalize builder inputs and work out their common element type.

    :param seq: non-empty sequence of values and / or expressions
    :param dtype: requested element type. When given, every input type must
                  be implicitly castable to it. When omitted, the type is
                  inferred from the inputs.
    :return: tuple ``(inputs, type_keyword, dtype)`` where ``inputs`` is the
             list of inputs converted with ``_scalar``, ``type_keyword`` is
             ``'_data_type'`` if any input is a ``SequenceExpr`` and
             ``'_value_type'`` otherwise, and ``dtype`` is the element type
    :raises TypeError: if ``seq`` is empty, if an input cannot be cast to
                       ``dtype``, or if no common type can be inferred
    """
    if not seq:
        raise TypeError('Inputs should not be empty')
    seq = [_scalar(a) for a in seq]
    arg_types = set()
    arg_cat = set()
    for a in seq:
        if isinstance(a, SequenceExpr):
            arg_types.add(a.data_type)
            arg_cat.add(SequenceExpr)
        else:
            arg_types.add(a.value_type)
            arg_cat.add(Scalar)
    if dtype is not None:
        if not all(dtype.can_implicit_cast(t) for t in arg_types):
            raise TypeError('Not all given value can be implicitly casted')
    else:
        if len(arg_types) == 1:
            dtype = arg_types.pop()
            if isinstance(dtype, df_types.Integer):
                # Integers other than int64 that fit into int32 are
                # normalized to int32.
                if dtype != df_types.int64 and df_types.int32.can_implicit_cast(dtype):
                    dtype = df_types.int32
            elif isinstance(dtype, df_types.Float):
                dtype = df_types.float64
        else:
            # Mixed input types: pick the first of int32, int64 and float64
            # that every input can be implicitly cast to.
            if all(df_types.int32.can_implicit_cast(t) for t in arg_types):
                dtype = df_types.int32
            elif all(df_types.int64.can_implicit_cast(t) for t in arg_types):
                dtype = df_types.int64
            elif all(df_types.float64.can_implicit_cast(t) for t in arg_types):
                dtype = df_types.float64
            else:
                raise TypeError('Types of inputs should be the same')
    if SequenceExpr in arg_cat:
        return seq, '_data_type', dtype
    else:
        return seq, '_value_type', dtype


class ListBuilder(CompositeBuilderOp):
    """Node created by ``make_list``."""
    _args = '_values',


class DictBuilder(CompositeBuilderOp):
    """Node created by ``make_dict``."""
    _args = '_keys', '_values'


def _len(expr):
    """
    Retrieve length of a list or dict sequence / scalar.

    :param expr: list or dict sequence / scalar
    :return: ``ListDictLength`` expression of type ``int64``
    """
    return composite_op(expr, ListDictLength, df_types.int64)


def _getitem(expr, key):
    """
    Retrieve an item of a list or dict by index / key.

    :param expr: list or dict sequence / scalar
    :param key: index or key; plain values are wrapped with ``_scalar``
    :return: ``ListDictGetItem`` expression typed with the value type of
             the container
    """
    if isinstance(expr, SequenceExpr):
        dtype = expr.data_type.value_type
    else:
        dtype = expr.value_type.value_type
    return composite_op(expr, ListDictGetItem, dtype, _key=_scalar(key))


def _sort(expr):
    """
    Retrieve sorted list

    :param expr: list sequence / scalar
    :return: ``ListSort`` expression with the same type as ``expr``
    """
    return composite_op(expr, ListSort)


def _contains(expr, value):
    """
    Check whether certain value is in the inspected list

    :param expr: list sequence / scalar
    :param value: value to inspect
    :return: ``ListContains`` expression of type ``boolean``
    """
    return composite_op(expr, ListContains, df_types.boolean, _value=_scalar(value))


def _keys(expr):
    """
    Retrieve keys of a dict

    :param expr: dict sequence / scalar
    :return: ``DictKeys`` expression typed as a list of the dict's key type
    """
    if isinstance(expr, SequenceExpr):
        dtype = expr.data_type
    else:
        dtype = expr.value_type
    return composite_op(expr, DictKeys, df_types.List(dtype.key_type))


def _values(expr):
    """
    Retrieve values of a dict

    :param expr: dict sequence / scalar
    :return: ``DictValues`` expression typed as a list of the dict's value
             type
    """
    if isinstance(expr, SequenceExpr):
        dtype = expr.data_type
    else:
        dtype = expr.value_type
    return composite_op(expr, DictValues, df_types.List(dtype.value_type))


def make_list(*args, **kwargs):
    """
    Build a list expression from the given values.

    :param args: elements of the list, plain values or expressions
    :param type: keyword only; element type of the list. Inferred from the
                 elements when omitted.
    :return: ``ListBuilder`` expression
    :raises TypeError: if no element is given or the element types are
                       incompatible
    """
    dtype = kwargs.get('type')
    if dtype is not None:
        dtype = df_types.validate_data_type(dtype)

    kwargs = dict()
    kwargs['_values'], k, typ = _scan_inputs(args, dtype)
    kwargs[k] = df_types.List(typ)
    return ListBuilder(**kwargs)


def make_dict(*args, **kwargs):
    """
    Build a dict expression from alternating keys and values.

    :param args: ``key1, value1, key2, value2, ...``
    :param key_type: keyword only; type of the keys, inferred when omitted
    :param value_type: keyword only; type of the values, inferred when
                       omitted
    :return: ``DictBuilder`` expression
    :raises ValueError: if an odd number of positional arguments is given
    :raises TypeError: if no argument is given or the key / value types are
                       incompatible
    """
    if len(args) % 2 != 0:
        raise ValueError('Num of inputs to build a dict should be even')
    key_type = kwargs.get('key_type')
    if key_type is not None:
        key_type = df_types.validate_data_type(key_type)
    value_type = kwargs.get('value_type')
    if value_type is not None:
        value_type = df_types.validate_data_type(value_type)

    kwargs = dict()
    keys = list(args[0::2])
    values = list(args[1::2])
    kwargs['_keys'], k1, key_type = _scan_inputs(keys, key_type)
    kwargs['_values'], k2, value_type = _scan_inputs(values, value_type)
    # The result is a sequence as soon as either the keys or the values
    # contain a sequence.
    k = '_data_type' if '_data_type' in (k1, k2) else '_value_type'
    kwargs[k] = df_types.Dict(key_type, value_type)
    return DictBuilder(**kwargs)


# Methods attached to list-typed sequences and scalars.
_list_methods = dict(
    __getitem__=_getitem,
    len=_len,
    sort=_sort,
    contains=_contains,
    explode=explode,
)

# Methods attached to dict-typed sequences and scalars.
_dict_methods = dict(
    __getitem__=_getitem,
    len=_len,
    keys=_keys,
    values=_values,
    explode=explode,
)

utils.add_method(ListSequenceExpr, _list_methods)
utils.add_method(ListScalar, _list_methods)
utils.add_method(DictSequenceExpr, _dict_methods)
utils.add_method(DictScalar, _dict_methods)