odps/df/expr/arithmetic.py (273 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime

from .expressions import *
from .element import AnyOp, ElementWise
from . import utils
from .. import types


class BinOp(AnyOp):
    """Base node for operators with a left and a right operand.

    The operand attribute names are listed in ``_args`` (``_lhs``, ``_rhs``).
    """

    __slots__ = ()
    _args = '_lhs', '_rhs'

    @property
    def node_name(self):
        # The concrete operator class name, e.g. ``Add`` or ``Equal``.
        return self.__class__.__name__

    @property
    def name(self):
        """Explicit name if set; otherwise the name of the only sequence operand.

        Returns ``None`` when no explicit name is set and the number of
        ``SequenceExpr`` operands is not exactly one.
        """
        if self._name:
            return self._name

        sequence_operands = [arg for arg in self.args if isinstance(arg, SequenceExpr)]
        return sequence_operands[0].name if len(sequence_operands) == 1 else None

    def accept(self, visitor):
        visitor.visit_binary_op(self)


class UnaryOp(ElementWise):
    """Base node for operators with a single operand."""

    __slots__ = ()

    def accept(self, visitor):
        visitor.visit_unary_op(self)


# Operator families; the concrete nodes below subclass one of these.

class Arithmetic(BinOp):
    """Base class for arithmetic binary nodes."""

    __slots__ = ()


class Comparison(BinOp):
    """Base class for comparison binary nodes."""

    __slots__ = ()


class LogicalBinOp(BinOp):
    """Base class for boolean connective nodes."""

    __slots__ = ()


# Unary operators.

class Negate(UnaryOp):
    __slots__ = ()


class Invert(UnaryOp):
    __slots__ = ()


class Abs(UnaryOp):
    __slots__ = ()


# Arithmetic operators.

class Add(Arithmetic):
    __slots__ = ()


# NOTE: the misspelled class name is kept as-is; it is presumably part of the
# public node API referenced elsewhere.
class Substract(Arithmetic):
    __slots__ = ()


class Multiply(Arithmetic):
    __slots__ = ()


class Divide(Arithmetic):
    __slots__ = ()


class FloorDivide(Arithmetic):
    __slots__ = ()


class Mod(Arithmetic):
    __slots__ = ()


class Power(Arithmetic):
    __slots__ = ()


# Comparison operators.

class Greater(Comparison):
    __slots__ = ()


class Less(Comparison):
    __slots__ = ()


class Equal(Comparison):
    __slots__ = ()


class NotEqual(Comparison):
    __slots__ = ()


class GreaterEqual(Comparison):
    __slots__ = ()


class LessEqual(Comparison):
    __slots__ = ()
class Or(LogicalBinOp):
    """Logical disjunction node (built by ``_or`` / ``_ror``)."""

    __slots__ = ()


class And(LogicalBinOp):
    """Logical conjunction node (built by ``_and`` / ``_rand``)."""

    __slots__ = ()


def _get_type(other):
    """Return ``(dtype, operand)`` for the non-receiver side of an operator.

    Sequence and scalar expressions are passed through together with their
    declared type; any other value is first wrapped in a ``Scalar`` literal.
    """
    if isinstance(other, SequenceExpr):
        return other._data_type, other
    if not isinstance(other, Scalar):
        other = Scalar(_value=other)
    return other._value_type, other


def _make_node(output_expr_cls, output_type, **operands):
    """Instantiate ``output_expr_cls`` as either a sequence or a scalar node.

    The node is a sequence (typed through ``_data_type``) if any operand is a
    ``SequenceExpr``, otherwise a scalar (typed through ``_value_type``). The
    type keyword is passed first, followed by the operands in call order.
    """
    is_sequence = any(isinstance(operand, SequenceExpr) for operand in operands.values())
    kwargs = {'_data_type' if is_sequence else '_value_type': output_type}
    kwargs.update(operands)
    return output_expr_cls(**kwargs)


def _arithmetic(expr, other, output_expr_cls, reverse=False, output_type=None):
    """Build an arithmetic node for ``expr`` and ``other``.

    :param expr: the expression the operator method was invoked on
    :param other: the other operand; plain values become ``Scalar`` literals
    :param output_expr_cls: the node class to instantiate
    :param reverse: when true, ``other`` becomes the left-hand operand
    :param output_type: result dtype; ``None`` means the highest-precedence
        type of both operands
    :return: the new node, or ``None`` if ``expr`` is neither a
        ``SequenceExpr`` nor a ``Scalar``
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None

    other_type, other = _get_type(other)
    if output_type is None:
        output_type = utils.highest_precedence_data_type(expr.dtype, other_type)

    lhs, rhs = (other, expr) if reverse else (expr, other)
    return _make_node(output_expr_cls, output_type, _lhs=lhs, _rhs=rhs)


def _reversed_arithmetic(expr, other, output_expr_cls, output_type=None):
    """Reflected variant of ``_arithmetic``: ``other`` is the left operand."""
    return _arithmetic(expr, other, output_expr_cls, reverse=True, output_type=output_type)


def _cmp(expr, other, output_expr_cls):
    """Build a boolean-typed comparison node ``expr <op> other``.

    Returns ``None`` if ``expr`` is neither a ``SequenceExpr`` nor a ``Scalar``.
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None

    other_type, other = _get_type(other)
    # Result intentionally unused: operands are cast to a common type to be
    # compared, so this presumably rejects incompatible operand types early
    # (confirm in utils.highest_precedence_data_type).
    utils.highest_precedence_data_type(expr.dtype, other_type)
    return _make_node(output_expr_cls, types.boolean, _lhs=expr, _rhs=other)


def _unary(expr, output_expr_cls):
    """Build a unary node that keeps the operand's dtype.

    Returns ``None`` if ``expr`` is neither a ``SequenceExpr`` nor a ``Scalar``.
    """
    if not isinstance(expr, (SequenceExpr, Scalar)):
        return None
    return _make_node(output_expr_cls, expr.dtype, _input=expr)


def _logic(expr, other, output_expr_cls):
    """Build a boolean connective node; both operands must be boolean.

    :raises TypeError: if either operand is not boolean-typed, or ``expr`` is
        neither a ``SequenceExpr`` nor a ``Scalar``
    """
    if isinstance(expr, (SequenceExpr, Scalar)):
        _, other = _get_type(other)
        if expr.dtype == types.boolean and other.dtype == types.boolean:
            return _make_node(output_expr_cls, types.boolean, _lhs=expr, _rhs=other)
    raise TypeError('Logic operation needs boolean operand')


def _is_datetime(expr):
    """Tell whether ``expr`` is a datetime-typed expression or a ``datetime`` value.

    NOTE(review): only ``types.datetime`` is recognized. Date and timestamp
    expressions share ``_datetime_methods`` but are not matched here, so e.g.
    subtracting two of them keeps the inferred result type. Confirm this is
    intended.
    """
    return expr.dtype == types.datetime if isinstance(expr, Expr) else isinstance(expr, datetime)


def _check_not_two_datetimes(expr, other):
    """Reject adding a datetime to another datetime."""
    if _is_datetime(expr) and _is_datetime(other):
        raise ExpressionError('Cannot add two datetimes')


def _datetime_diff_type(expr, other):
    """``int64`` when both operands are datetimes, else ``None`` (infer the type)."""
    return types.int64 if _is_datetime(expr) and _is_datetime(other) else None


def _division_type(expr, other):
    """``float64`` for integer / integer division, else ``None`` (infer the type)."""
    # ``and`` short-circuits, so ``other`` is only inspected for integer receivers.
    if isinstance(expr.dtype, types.Integer) and isinstance(_get_type(other)[0], types.Integer):
        return types.float64
    return None


def _add(expr, other):
    _check_not_two_datetimes(expr, other)
    return _arithmetic(expr, other, Add)


def _radd(expr, other):
    _check_not_two_datetimes(expr, other)
    return _reversed_arithmetic(expr, other, Add)


def _sub(expr, other):
    return _arithmetic(expr, other, Substract, output_type=_datetime_diff_type(expr, other))


def _rsub(expr, other):
    return _reversed_arithmetic(expr, other, Substract,
                                output_type=_datetime_diff_type(expr, other))


def _eq(expr, other):
    return _cmp(expr, other, Equal)


def _ne(expr, other):
    return _cmp(expr, other, NotEqual)


def _gt(expr, other):
    return _cmp(expr, other, Greater)


def _lt(expr, other):
    return _cmp(expr, other, Less)


def _le(expr, other):
    return _cmp(expr, other, LessEqual)


def _ge(expr, other):
    return _cmp(expr, other, GreaterEqual)


def _mul(expr, other):
    return _arithmetic(expr, other, Multiply)


def _rmul(expr, other):
    return _reversed_arithmetic(expr, other, Multiply)


def _div(expr, other):
    return _arithmetic(expr, other, Divide, output_type=_division_type(expr, other))


def _rdiv(expr, other):
    return _reversed_arithmetic(expr, other, Divide, output_type=_division_type(expr, other))


def _mod(expr, other):
    return _arithmetic(expr, other, Mod)


def _rmod(expr, other):
    return _reversed_arithmetic(expr, other, Mod)


def _floordiv(expr, other):
    return _arithmetic(expr, other, FloorDivide)


def _rfloordiv(expr, other):
    return _reversed_arithmetic(expr, other, FloorDivide)


def _pow(expr, other):
    return _arithmetic(expr, other, Power)


def _rpow(expr, other):
    return _reversed_arithmetic(expr, other, Power)


def _or(expr, other):
    return _logic(expr, other, Or)


def _ror(expr, other):
    # NOTE(review): unlike the other reflected builders, operands are not
    # swapped here, so the receiver stays on the left-hand side.
    return _or(expr, other)


def _and(expr, other):
    return _logic(expr, other, And)


def _rand(expr, other):
    # NOTE(review): operands are not swapped; see ``_ror``.
    return _and(expr, other)


def _neg(expr):
    # Negating a Negate node unwraps it instead of stacking another node.
    return expr.input if isinstance(expr, Negate) else _unary(expr, Negate)


def _invert(expr):
    # Inverting an Invert node unwraps it instead of stacking another node.
    return expr.input if isinstance(expr, Invert) else _unary(expr, Invert)


def _abs(expr):
    # abs() of an Abs node is returned unchanged.
    return expr if isinstance(expr, Abs) else _unary(expr, Abs)


# Operator tables attached to expression classes via utils.add_method.

_number_methods = {
    '_add': _add,
    '_radd': _radd,
    '_sub': _sub,
    '_rsub': _rsub,
    '_mul': _mul,
    '_rmul': _rmul,
    '_div': _div,
    '_rdiv': _rdiv,
    '_floordiv': _floordiv,
    '_rfloordiv': _rfloordiv,
    '_mod': _mod,
    '_rmod': _rmod,
    '_pow': _pow,
    '_rpow': _rpow,
    '_neg': _neg,
    '_abs': _abs,
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

_int_number_methods = {
    '_invert': _invert,
}

_string_methods = {
    '_add': _add,
    '_radd': _radd,
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

_boolean_methods = {
    '_or': _or,
    '_ror': _ror,
    '_and': _and,
    '_rand': _rand,
    '_eq': _eq,
    '_ne': _ne,
    '_invert': _invert,
    '_neg': _neg,
}

_datetime_methods = {
    '_add': _add,  # TODO: to check
    '_radd': _radd,
    '_sub': _sub,
    '_rsub': _rsub,
    '_eq': _eq,
    '_ne': _ne,
    '_gt': _gt,
    '_lt': _lt,
    '_le': _le,
    '_ge': _ge,
}

for _expr_cls, _methods in (
    (StringSequenceExpr, _string_methods),
    (StringScalar, _string_methods),
    (BooleanSequenceExpr, _boolean_methods),
    (BooleanScalar, _boolean_methods),
    (DatetimeSequenceExpr, _datetime_methods),
    (DatetimeScalar, _datetime_methods),
):
    utils.add_method(_expr_cls, _methods)
# Drop the loop temporaries so the module namespace is unchanged.
del _expr_cls, _methods
# Date and timestamp expressions share the datetime operator table.
for _expr_cls in (DateSequenceExpr, DateScalar, TimestampSequenceExpr, TimestampScalar):
    utils.add_method(_expr_cls, _datetime_methods)
# Drop the loop temporary so the module namespace is unchanged.
del _expr_cls

# Every numeric expression class gets the full numeric operator table...
for number_sequence in number_sequences + number_scalars:
    utils.add_method(number_sequence, _number_methods)

# ...and integer classes additionally get bitwise inversion.
for int_number_sequence in int_number_sequences + int_number_scalars:
    utils.add_method(int_number_sequence, _int_number_methods)