odps/df/expr/composites.py (205 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .expressions import Scalar, Expr, TableSchema, SequenceExpr, ListSequenceExpr, DictSequenceExpr, \
ListScalar, DictScalar, Column
from .element import AnyOp, ElementWise
from .collections import RowAppliedCollectionExpr
from .. import types as df_types
from ..utils import to_collection
from . import utils
def _scalar(val, tp=None):
if val is None:
return
if isinstance(val, Expr):
return val
if isinstance(val, (tuple, list)):
return type(val)(_scalar(it, tp=tp) for it in val)
else:
return Scalar(_value=val, _value_type=tp)
def explode(expr, *args, **kwargs):
    """
    Expand list or dict data into multiple rows

    :param expr: list / dict sequence / scalar
    :param args: names of the generated columns, passed either as separate
                 positional arguments or as a single list / tuple / set.
                 Lists take one name (one or two when ``pos`` is set), dicts
                 take two (key, value). When omitted, names are derived from
                 the name of ``expr``.
    :param pos: keyword argument, ``False`` by default. When true, a list is
                exploded with ``POSEXPLODE`` and an extra int64 position column
                is placed before the value column. Not supported for dicts.
    :param keep_nulls: keyword argument, ``False`` by default; forwarded to the
                       resulting collection as ``_keep_nulls``
    :return: a ``RowAppliedCollectionExpr`` whose schema holds the exploded columns
    :raises ValueError: if the number of names does not fit the exploded type,
                        if ``pos`` is requested for a dict, or if ``expr`` is
                        neither a list nor a dict sequence
    """
    if not isinstance(expr, Column):
        # re-fetch the expression as a column of its own collection
        expr = to_collection(expr)[expr.name]

    if isinstance(expr, SequenceExpr):
        dtype = expr.data_type
    else:
        dtype = expr.value_type

    func_name = 'EXPLODE'
    if args and isinstance(args[0], (list, tuple, set)):
        names = list(args[0])
    else:
        names = list(args)
    pos = kwargs.get('pos', False)

    if isinstance(expr, ListSequenceExpr):
        if pos:
            func_name = 'POSEXPLODE'
            typos = [df_types.int64, dtype.value_type]
            if not names:
                names = [expr.name + '_pos', expr.name]
            if len(names) == 1:
                # a single name labels the value; derive the position name from it
                names = [names[0] + '_pos', names[0]]
            if len(names) != 2:
                raise ValueError("The length of parameter 'names' should be 1 or 2.")
        else:
            typos = [dtype.value_type]
            if not names:
                names = [expr.name]
            if len(names) != 1:
                raise ValueError("The length of parameter 'names' should be exactly 1.")
    elif isinstance(expr, DictSequenceExpr):
        if pos:
            raise ValueError('Cannot support explosion with pos on dicts.')
        typos = [dtype.key_type, dtype.value_type]
        if not names:
            names = [expr.name + '_key', expr.name + '_value']
        if len(names) != 2:
            raise ValueError("The length of parameter 'names' should be exactly 2.")
    else:
        raise ValueError('Cannot explode expression with type %s' % type(expr).__name__)

    schema = TableSchema.from_lists(names, typos)
    return RowAppliedCollectionExpr(_input=expr.input, _func=func_name, _schema=schema,
                                    _fields=[expr], _keep_nulls=kwargs.get('keep_nulls', False))
def composite_op(expr, output_expr_cls, output_type=None, **kwargs):
    """
    Build a composite expression node on top of ``expr``.

    The node is created as a sequence (``_data_type``) when any argument listed
    in ``output_expr_cls._args`` is a ``SequenceExpr`` or a list / tuple that
    contains one; otherwise it is created as a scalar (``_value_type``).

    :param expr: input expression, passed to the node as ``_input``
    :param output_expr_cls: class of the node to create
    :param output_type: type of the result; falls back to the type of ``expr``
    :param kwargs: additional node arguments, forwarded as given
    :return: instance of ``output_expr_cls``
    """
    candidates = dict(kwargs)
    candidates['_input'] = expr

    def _holds_sequence(arg_name):
        arg = candidates.get(arg_name)
        if isinstance(arg, SequenceExpr):
            return True
        if isinstance(arg, (list, tuple)):
            return any(isinstance(item, SequenceExpr) for item in arg)
        return False

    if any(_holds_sequence(arg_name) for arg_name in output_expr_cls._args):
        resolved_type = output_type if output_type else expr.data_type
        return output_expr_cls(_data_type=resolved_type, _input=expr, **kwargs)

    resolved_type = output_type if output_type else expr.value_type
    return output_expr_cls(_value_type=resolved_type, _input=expr, **kwargs)
class CompositeOp(ElementWise):
    """
    Base class of the element-wise operations defined in this module on
    list / dict expressions.
    """

    def accept(self, visitor):
        # every composite operation is dispatched to the same visitor hook
        visitor.visit_composite_op(self)
class CompositeBuilderOp(AnyOp):
    """
    Base class of the operations in this module that build a list / dict
    expression out of individual values.
    """

    @property
    def node_name(self):
        # the node is named after its concrete class
        return self.__class__.__name__

    def accept(self, visitor):
        # builders share the visitor hook used by the other composite operations
        visitor.visit_composite_op(self)
class ListDictLength(CompositeOp):
    """
    Length of a list or dict; created by ``_len`` with an int64 result type.
    """
    __slots__ = ()
class ListDictGetItem(CompositeOp):
    """
    Item lookup on a list or dict; created by ``_getitem``, which supplies ``_key``.
    """
    # NOTE(review): `_negative_handled` is never set in this module; presumably it is
    # filled in by later processing of negative list indices -- confirm.
    _args = '_input', '_key', '_negative_handled'
class ListContains(CompositeOp):
    """
    Membership test on a list; created by ``_contains``, which supplies ``_value``.
    """
    _args = '_input', '_value',
class ListSort(CompositeOp):
    """
    Sorted version of a list; created by ``_sort`` and typed like its input.
    """
    __slots__ = ()
class DictKeys(CompositeOp):
    """
    Keys of a dict; created by ``_keys`` as a list of the dict's key type.
    """
    __slots__ = ()
class DictValues(CompositeOp):
    """
    Values of a dict; created by ``_values`` as a list of the dict's value type.
    """
    __slots__ = ()
def _scan_inputs(seq, dtype=None):
if not seq:
raise TypeError('Inputs should not be empty')
seq = [_scalar(a) for a in seq]
arg_types = set()
arg_cat = set()
for a in seq:
if isinstance(a, SequenceExpr):
arg_types.add(a.data_type)
arg_cat.add(SequenceExpr)
else:
arg_types.add(a.value_type)
arg_cat.add(Scalar)
if dtype is not None:
if not all(dtype.can_implicit_cast(t) for t in arg_types):
raise TypeError('Not all given value can be implicitly casted')
else:
if len(arg_types) == 1:
dtype = arg_types.pop()
if isinstance(dtype, df_types.Integer):
if dtype != df_types.int64 and df_types.int32.can_implicit_cast(dtype):
dtype = df_types.int32
elif isinstance(dtype, df_types.Float):
dtype = df_types.float64
else:
if all(df_types.int32.can_implicit_cast(t) for t in arg_types):
dtype = df_types.int32
elif all(df_types.int64.can_implicit_cast(t) for t in arg_types):
dtype = df_types.int64
elif all(df_types.float64.can_implicit_cast(t) for t in arg_types):
dtype = df_types.float64
else:
raise TypeError('Types of inputs should be the same')
if SequenceExpr in arg_cat:
return seq, '_data_type', dtype
else:
return seq, '_value_type', dtype
class ListBuilder(CompositeBuilderOp):
    """
    Builds a list out of ``_values``; created by ``make_list``.
    """
    _args = '_values',
class DictBuilder(CompositeBuilderOp):
    """
    Builds a dict out of ``_keys`` and ``_values``; created by ``make_dict``.
    """
    _args = '_keys', '_values'
def _len(expr):
    """
    Get the number of elements held by a list or dict sequence / scalar.

    :param expr: list or dict sequence / scalar
    :return: length expression typed as int64
    """
    length_type = df_types.int64
    return composite_op(expr, ListDictLength, output_type=length_type)
def _getitem(expr, key):
    """
    Fetch an item of a list or dict by index / key.

    :param expr: list or dict sequence / scalar
    :param key: index or key to look up; plain values are wrapped into scalars
    :return: expression typed with the value type of the container
    """
    container_type = expr.data_type if isinstance(expr, SequenceExpr) else expr.value_type
    item_type = container_type.value_type
    return composite_op(expr, ListDictGetItem, item_type, _key=_scalar(key))
def _sort(expr):
    """
    Get the sorted version of a list.

    :param expr: list sequence / scalar
    :return: expression with the same type as ``expr``
    """
    sorted_expr = composite_op(expr, ListSort)
    return sorted_expr
def _contains(expr, value):
    """
    Test whether a value is an element of the inspected list.

    :param expr: list sequence / scalar
    :param value: value to look for; plain values are wrapped into scalars
    :return: boolean expression
    """
    needle = _scalar(value)
    return composite_op(expr, ListContains, output_type=df_types.boolean, _value=needle)
def _keys(expr):
    """
    Get the keys of a dict as a list.

    :param expr: dict sequence / scalar
    :return: list expression whose elements have the key type of the dict
    """
    dict_type = expr.data_type if isinstance(expr, SequenceExpr) else expr.value_type
    key_list_type = df_types.List(dict_type.key_type)
    return composite_op(expr, DictKeys, key_list_type)
def _values(expr):
    """
    Get the values of a dict as a list.

    :param expr: dict sequence / scalar
    :return: list expression whose elements have the value type of the dict
    """
    dict_type = expr.data_type if isinstance(expr, SequenceExpr) else expr.value_type
    value_list_type = df_types.List(dict_type.value_type)
    return composite_op(expr, DictValues, value_list_type)
def make_list(*args, **kwargs):
    """
    Build a list expression out of the given elements.

    :param args: elements of the list, plain values or expressions
    :param type: optional keyword argument, element type of the list; inferred
                 from the elements when omitted
    :return: ``ListBuilder`` expression
    """
    element_type = kwargs.get('type')
    if element_type is not None:
        element_type = df_types.validate_data_type(element_type)

    values, type_key, element_type = _scan_inputs(args, element_type)
    builder_args = {'_values': values}
    builder_args[type_key] = df_types.List(element_type)
    return ListBuilder(**builder_args)
def make_dict(*args, **kwargs):
    """
    Build a dict expression out of alternating keys and values.

    :param args: ``key1, value1, key2, value2, ...``; plain values or expressions
    :param key_type: optional keyword argument, type of the keys; inferred when omitted
    :param value_type: optional keyword argument, type of the values; inferred when omitted
    :return: ``DictBuilder`` expression
    :raises ValueError: if the number of positional arguments is odd
    """
    if len(args) % 2:
        raise ValueError('Num of inputs to build a dict should be even')

    key_type = kwargs.get('key_type')
    if key_type is not None:
        key_type = df_types.validate_data_type(key_type)

    value_type = kwargs.get('value_type')
    if value_type is not None:
        value_type = df_types.validate_data_type(value_type)

    # even positions carry the keys, odd positions the values
    raw_keys = list(args[0::2])
    raw_values = list(args[1::2])

    keys, key_kind, key_type = _scan_inputs(raw_keys, key_type)
    values, value_kind, value_type = _scan_inputs(raw_values, value_type)

    # a sequence on either side makes the whole dict a sequence
    if key_kind == '_data_type' or value_kind == '_data_type':
        type_key = '_data_type'
    else:
        type_key = '_value_type'

    builder_args = {'_keys': keys, '_values': values}
    builder_args[type_key] = df_types.Dict(key_type, value_type)
    return DictBuilder(**builder_args)
# methods attached to list sequences and list scalars
_list_methods = {
    '__getitem__': _getitem,
    'len': _len,
    'sort': _sort,
    'contains': _contains,
    'explode': explode,
}

# methods attached to dict sequences and dict scalars
_dict_methods = {
    '__getitem__': _getitem,
    'len': _len,
    'keys': _keys,
    'values': _values,
    'explode': explode,
}

# register list methods first, then dict methods
utils.add_method(ListSequenceExpr, _list_methods)
utils.add_method(ListScalar, _list_methods)

utils.add_method(DictSequenceExpr, _dict_methods)
utils.add_method(DictScalar, _dict_methods)