odps/expressions/core.py (153 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import requests try: import pyarrow as pa except ImportError: pa = None try: import pandas as pd except ImportError: pd = None from ..compat import six from ..serializers import ( JSONNodeField, JSONNodeReferenceField, JSONNodesReferencesField, JSONSerializableModel, ) from ..types import is_record, validate_data_type, validate_value from ..utils import to_odps_scalar from .functions import ExprFunction _name_to_expr_clses = {} class Expression(JSONSerializableModel): __slots__ = ("final",) _type = JSONNodeField("type") @classmethod def _load_expr_classes(cls): if not _name_to_expr_clses: for val in globals().values(): if ( not isinstance(val, type) or not issubclass(val, Expression) or val is Expression ): continue cls_name = val.__name__[0].lower() + val.__name__[1:] _name_to_expr_clses[cls_name] = val return _name_to_expr_clses @classmethod def _get_expr_class(cls, expr): expr_name = next(iter(expr.keys())) return cls._load_expr_classes()[expr_name] @classmethod def deserial(cls, content, obj=None, **kw): if obj is None: inst_cls = cls._get_expr_class(content) obj = inst_cls(_parent=kw.get("_parent")) content = next(iter(content.values())) return super(Expression, cls).deserial(content, obj=obj, **kw) @property def type(self): return validate_data_type(self._type) def eval(self, data): raise NotImplementedError def to_str(self, ref_to_str=None): raise NotImplementedError def __str__(self): return self.to_str() def __repr__(self): return "%s(%s)" % (type(self).__name__, str(self)) def _make_final_result(self, data, res): from ..tunnel.io.types import odps_type_to_arrow_type if not self.final: return res if is_record(data): return res elif pd and isinstance(data, pd.DataFrame): if not isinstance(res, pd.Series): return pd.Series([res] * len(data)) return res elif pa and isinstance(data, (pa.RecordBatch, pa.Table)): if isinstance(res, (pa.Array, pa.ChunkedArray)): return res return pa.array( [res] * data.num_rows, type=odps_type_to_arrow_type(self.type) ) class FunctionCall(Expression): __slots__ = ("args", "references") name = JSONNodeField("name") def to_str(self, ref_to_str=None): ref_to_str = ref_to_str or {} if not hasattr(self, "args"): return super(FunctionCall, self).to_str(ref_to_str) func_cls = self.get_function_cls() return func_cls.to_str([s.to_str(ref_to_str) for s in self.args]) def get_function_cls(self): return ExprFunction.get_cls(self.name) def eval(self, data): func_cls = self.get_function_cls() args = [a.eval(data) for a in self.args] res = func_cls.call(*args) return self._make_final_result(data, res) class LeafExprDesc(Expression): class Reference(JSONSerializableModel): name = JSONNodeField("name") constant = JSONNodeField("constant", default=None) reference = JSONNodeReferenceField(Reference, "reference", default=None) def eval(self, data): if self.constant: val = validate_value(self.constant, self.type) elif self.reference: if pa and isinstance(data, (pa.RecordBatch, pa.Table)): name_to_idx = { c.lower(): idx for idx, c in enumerate(data.schema.names) } val = data.column(name_to_idx[self.reference.name.lower()]) elif pd and isinstance(data, pd.DataFrame): lower_to_name = {c.lower(): c for c in data.columns} val = data[lower_to_name[self.reference.name.lower()]] else: val = data[self.reference.name] else: raise NotImplementedError("Expression cannot be called") return self._make_final_result(data, val) def to_str(self, ref_to_str=None): ref_to_str = ref_to_str or {} if self.constant: val = validate_value(self.constant, self.type) return to_odps_scalar(val) elif self.reference: default_str = "`%s`" % self.reference.name return ref_to_str.get(self.reference.name, default_str) else: raise NotImplementedError("Expression cannot be accepted") class VisitedExpressions(JSONSerializableModel): expressions = JSONNodesReferencesField(Expression, "expressions") @classmethod def parse(cls, response, obj=None, **kw): if isinstance(response, Expression): return response if isinstance(response, requests.Response): # PY2 prefer bytes, while PY3 prefer str response = response.content.decode() if six.PY3 else response.content if isinstance(response, six.string_types): response = json.loads(response) if isinstance(response, list): response = {"expressions": response} parsed = super(VisitedExpressions, cls).parse(response, obj=obj, **kw) res_stack = [] for expr in parsed.expressions: expr.final = False if isinstance(expr, FunctionCall): arg_count = expr.get_function_cls().arg_count expr.args = res_stack[-arg_count:] res_stack = res_stack[:-arg_count] res_stack.append(expr) assert len(res_stack) == 1 res_stack[0].final = True res_stack[0].references = [ t.reference.name for t in parsed.expressions if isinstance(t, LeafExprDesc) and t.reference is not None ] return res_stack[0] parse = VisitedExpressions.parse