odps/expressions/core.py (153 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import requests
try:
import pyarrow as pa
except ImportError:
pa = None
try:
import pandas as pd
except ImportError:
pd = None
from ..compat import six
from ..serializers import (
JSONNodeField,
JSONNodeReferenceField,
JSONNodesReferencesField,
JSONSerializableModel,
)
from ..types import is_record, validate_data_type, validate_value
from ..utils import to_odps_scalar
from .functions import ExprFunction
# Lazily-populated registry mapping the JSON key of each expression node
# (the class name with its first character lower-cased, e.g. ``functionCall``)
# to the Expression subclass that deserializes it.
_name_to_expr_clses = {}


class Expression(JSONSerializableModel):
    """Base class of expression nodes deserialized from ODPS JSON.

    Each serialized node is a single-key mapping whose key names the concrete
    subclass and whose value holds the node's fields. The ``final`` slot marks
    the root node of a parsed expression; only the root broadcasts a scalar
    result to the row count of the input data.
    """

    __slots__ = ("final",)

    _type = JSONNodeField("type")

    @classmethod
    def _load_expr_classes(cls):
        """Return the JSON-key -> subclass registry, building it on first use.

        Every strict subclass of ``Expression`` found in this module's globals
        is registered under its class name with the first character lowered.
        """
        if not _name_to_expr_clses:
            for val in globals().values():
                if (
                    not isinstance(val, type)
                    or not issubclass(val, Expression)
                    or val is Expression
                ):
                    continue
                cls_name = val.__name__[0].lower() + val.__name__[1:]
                _name_to_expr_clses[cls_name] = val
        return _name_to_expr_clses

    @classmethod
    def _get_expr_class(cls, expr):
        """Return the subclass named by the single key of serialized node ``expr``.

        :raises KeyError: if the key does not name a registered subclass.
        """
        expr_name = next(iter(expr.keys()))
        return cls._load_expr_classes()[expr_name]

    @classmethod
    def deserial(cls, content, obj=None, **kw):
        """Deserialize a node into an expression instance.

        When ``obj`` is not supplied, the concrete class is chosen from the
        key of the single-key mapping ``content``, instantiated, and the
        mapping's value is handed to the base deserializer.
        """
        if obj is None:
            inst_cls = cls._get_expr_class(content)
            obj = inst_cls(_parent=kw.get("_parent"))
            content = next(iter(content.values()))
        return super(Expression, cls).deserial(content, obj=obj, **kw)

    @property
    def type(self):
        """The node's data type, normalized from the serialized ``type`` field."""
        return validate_data_type(self._type)

    def eval(self, data):
        """Evaluate the node against ``data``; implemented by subclasses."""
        raise NotImplementedError

    def to_str(self, ref_to_str=None):
        """Render the node as text; implemented by subclasses."""
        raise NotImplementedError

    def __str__(self):
        return self.to_str()

    def __repr__(self):
        return "%s(%s)" % (type(self).__name__, str(self))

    def _make_final_result(self, data, res):
        """Shape the evaluation result ``res`` of this node for input ``data``.

        Non-root nodes return ``res`` unchanged. For the root node, a scalar
        result is broadcast to one value per row when ``data`` is a pandas
        DataFrame or a pyarrow RecordBatch/Table. Results that are already
        columnar are returned as-is. Results for records or any other input
        are also returned unchanged.
        """
        # ``final`` is only assigned when parsing a whole expression; treat an
        # unset slot as a non-root node instead of raising AttributeError.
        if not getattr(self, "final", False):
            return res
        if is_record(data):
            return res
        if pd is not None and isinstance(data, pd.DataFrame):
            if isinstance(res, pd.Series):
                return res
            # Reuse the frame's index so the broadcast column lines up with
            # columns selected from ``data`` (which carry ``data.index``).
            return pd.Series([res] * len(data), index=data.index)
        if pa is not None and isinstance(data, (pa.RecordBatch, pa.Table)):
            if isinstance(res, (pa.Array, pa.ChunkedArray)):
                return res
            # Imported here, where it is actually needed, as in the original
            # function-local import (avoids a module-level import of tunnel io).
            from ..tunnel.io.types import odps_type_to_arrow_type

            return pa.array(
                [res] * data.num_rows, type=odps_type_to_arrow_type(self.type)
            )
        # Any other input (e.g. a plain mapping): nothing to broadcast against;
        # previously this fell off the end and returned None.
        return res
class FunctionCall(Expression):
    """Expression node applying a registered ``ExprFunction`` to child nodes.

    ``args`` (the operand nodes) and ``references`` (names of columns the
    whole expression reads; set on the root node only) are attached after
    deserialization when the expression tree is rebuilt.
    """

    __slots__ = ("args", "references")

    name = JSONNodeField("name")

    def get_function_cls(self):
        """Look up the ``ExprFunction`` implementation registered as ``self.name``."""
        return ExprFunction.get_cls(self.name)

    def to_str(self, ref_to_str=None):
        """Render the call as text by rendering each operand first.

        :param ref_to_str: optional mapping of column name -> replacement text
            used when rendering column references among the operands.
        """
        mapping = ref_to_str or {}
        if hasattr(self, "args"):
            impl = self.get_function_cls()
            rendered = [operand.to_str(mapping) for operand in self.args]
            return impl.to_str(rendered)
        # Operands not attached yet: defer to the base implementation.
        return super(FunctionCall, self).to_str(mapping)

    def eval(self, data):
        """Evaluate every operand against ``data`` and apply the function."""
        impl = self.get_function_cls()
        operand_values = [operand.eval(data) for operand in self.args]
        outcome = impl.call(*operand_values)
        return self._make_final_result(data, outcome)
class LeafExprDesc(Expression):
    """Leaf expression node: either a literal constant or a column reference."""

    class Reference(JSONSerializableModel):
        """Serialized reference to an input column by name."""

        name = JSONNodeField("name")

    constant = JSONNodeField("constant", default=None)
    reference = JSONNodeReferenceField(Reference, "reference", default=None)

    def eval(self, data):
        """Evaluate the leaf against ``data``.

        A constant is validated against the node type. A reference selects the
        named column. The lookup is case-insensitive for pyarrow
        RecordBatch/Table and pandas DataFrame input; any other input is
        indexed directly with ``data[name]``.

        :raises NotImplementedError: if the node has neither a constant nor
            a reference.
        """
        # Compare against None explicitly: falsy literals such as 0, "" or
        # False are legitimate constants and must not fall through.
        if self.constant is not None:
            val = validate_value(self.constant, self.type)
        elif self.reference is not None:
            ref_name = self.reference.name
            if pa is not None and isinstance(data, (pa.RecordBatch, pa.Table)):
                name_to_idx = {
                    c.lower(): idx for idx, c in enumerate(data.schema.names)
                }
                val = data.column(name_to_idx[ref_name.lower()])
            elif pd is not None and isinstance(data, pd.DataFrame):
                lower_to_name = {c.lower(): c for c in data.columns}
                val = data[lower_to_name[ref_name.lower()]]
            else:
                val = data[ref_name]
        else:
            raise NotImplementedError("Expression cannot be called")
        return self._make_final_result(data, val)

    def to_str(self, ref_to_str=None):
        """Render the leaf as text.

        Constants are validated and rendered with ``to_odps_scalar``.
        References are looked up in ``ref_to_str`` by column name and default
        to the back-quoted column name.

        :raises NotImplementedError: if the node has neither a constant nor
            a reference.
        """
        ref_to_str = ref_to_str or {}
        # Same explicit None checks as ``eval`` so falsy constants render.
        if self.constant is not None:
            return to_odps_scalar(validate_value(self.constant, self.type))
        if self.reference is not None:
            ref_name = self.reference.name
            return ref_to_str.get(ref_name, "`%s`" % ref_name)
        raise NotImplementedError("Expression cannot be accepted")
class VisitedExpressions(JSONSerializableModel):
    """Flat, operands-first list of serialized expression nodes.

    ``parse`` rebuilds the expression tree from this list with a stack:
    leaves are pushed, and each function call pops its operands.
    """

    expressions = JSONNodesReferencesField(Expression, "expressions")

    @classmethod
    def parse(cls, response, obj=None, **kw):
        """Parse a serialized expression and return its root node.

        :param response: an ``Expression`` (returned unchanged), a
            ``requests.Response``, a JSON string, a list of serialized nodes,
            or a mapping holding that list under ``"expressions"``.
        :return: the root ``Expression``. Its ``final`` attribute is set to
            True, and ``references`` lists the column names of all reference
            leaves in serialization order.
        :raises ValueError: if a function call has fewer operands available
            than its ``arg_count``, or if the nodes do not reduce to exactly
            one root.
        """
        if isinstance(response, Expression):
            return response
        if isinstance(response, requests.Response):
            # PY2 prefer bytes, while PY3 prefer str
            response = response.content.decode() if six.PY3 else response.content
        if isinstance(response, six.string_types):
            response = json.loads(response)
        if isinstance(response, list):
            response = {"expressions": response}

        parsed = super(VisitedExpressions, cls).parse(response, obj=obj, **kw)

        res_stack = []
        for expr in parsed.expressions:
            expr.final = False
            if isinstance(expr, FunctionCall):
                arg_count = expr.get_function_cls().arg_count
                if arg_count > len(res_stack):
                    raise ValueError(
                        "Function %r expects %d arguments but only %d operands "
                        "are available" % (expr.name, arg_count, len(res_stack))
                    )
                # Split on an explicit index: ``res_stack[-0:]`` would take the
                # whole stack for a zero-argument function.
                split = len(res_stack) - arg_count
                expr.args = res_stack[split:]
                del res_stack[split:]
            res_stack.append(expr)

        # Explicit check instead of ``assert``, which is stripped under -O.
        if len(res_stack) != 1:
            raise ValueError(
                "Malformed expression: expected exactly one root node, got %d"
                % len(res_stack)
            )
        root = res_stack[0]
        root.final = True
        root.references = [
            t.reference.name
            for t in parsed.expressions
            if isinstance(t, LeafExprDesc) and t.reference is not None
        ]
        return root


# Module-level shortcut: ``parse(response)`` returns the root expression.
parse = VisitedExpressions.parse