odps/df/backends/odpssql/rewriter.py (143 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ``operator`` and ``reduce`` are used by ``Rewriter.visit_join``; import them
# explicitly instead of relying on the star imports below to re-export them.
import operator
from functools import reduce

from ..rewriter import BaseRewriter
from ...expr.reduction import Cat, GroupedCat
from ...expr.window import *
from ...expr.merge import *
from ...expr.expressions import LateralViewCollectionExpr
from ...expr.element import Func
from ...expr.utils import get_attrs
from ....errors import NoSuchObject
from ...utils import is_source_collection
from ....models import Table

# NOTE(review): ``types``, ``SequenceExpr``, ``Scalar``, ``TableSchema``,
# ``ProjectCollectionExpr``, ``JoinCollectionExpr`` and
# ``JoinProjectCollectionExpr`` are used below but not imported by name;
# they presumably come from the star imports above -- confirm before
# removing either star import.


class Rewriter(BaseRewriter):
    """Expression rewriter for the ODPS SQL backend.

    Each ``visit_*`` hook inspects one node of the expression DAG and may
    replace nodes through ``self._sub`` so that the SQL compiler receives a
    shape it can handle.
    """

    def visit_project_collection(self, expr):
        """Delegate reduction rewriting inside a projection to the base class."""
        self._rewrite_reduction_in_projection(expr)

    def visit_filter_collection(self, expr):
        """Delegate reduction rewriting inside a filter to the base class."""
        self._rewrite_reduction_in_filter(expr)

    def visit_join(self, expr):
        """Normalise a join node.

        * A list of predicates is folded into a single AND-ed predicate.
        * A nested join on the right-hand side is wrapped in a
          ``JoinProjectCollectionExpr``.
        * If any join along the left spine reports ``column_conflict`` and
          ``expr`` has no projection/join parent, ``expr`` is replaced by a
          projection of itself.
        """
        if expr._predicate and isinstance(expr._predicate, list):
            expr._predicate = reduce(operator.and_, expr._predicate)

        node = expr.rhs
        parents = self._parents(node)
        if isinstance(node, JoinCollectionExpr):
            projection = JoinProjectCollectionExpr(
                _input=node, _schema=node.schema,
                _fields=node._fetch_fields())
            self._sub(node, projection, parents)
        elif isinstance(node, JoinProjectCollectionExpr):
            self._sub(node.input, node, parents)

        def has_column_conflict(join_node):
            # Follow the left spine of nested joins; stop at the first
            # non-join node or at the first join reporting a conflict.
            if not isinstance(join_node, JoinCollectionExpr):
                return False
            if join_node.column_conflict:
                return True
            return has_column_conflict(join_node.lhs)

        if has_column_conflict(expr):
            parents = self._parents(expr)
            if not parents or not any(
                    isinstance(parent, (ProjectCollectionExpr, JoinCollectionExpr))
                    for parent in parents):
                # ``expr[expr]`` presumably projects the join onto all of its
                # own columns -- confirm against CollectionExpr.__getitem__.
                to_sub = expr[expr]
                self._sub(expr, to_sub, parents)

    def visit_lateral_view(self, expr):
        """Wrap a lateral view in a ``ProjectCollectionExpr``.

        The wrap is skipped when some parent is already a projection that is
        not itself a ``LateralViewCollectionExpr``.
        """
        parents = self._parents(expr)
        if not parents or not any(
                isinstance(parent, ProjectCollectionExpr)
                and not isinstance(parent, LateralViewCollectionExpr)
                for parent in parents):
            to_sub = ProjectCollectionExpr(
                _input=expr, _schema=expr.schema, _fields=expr._fields
            )
            self._sub(expr, to_sub, parents)

    def _handle_function(self, expr, raw_inputs):
        """Route decimal values around a user-defined function via strings.

        Since Python UDF cannot support decimal field, decimal inputs are
        replaced with their string cast.  If the output is decimal, the node
        is rebuilt to produce string and the result is cast back to decimal.
        Nodes of type ``Func`` are left untouched.

        :param expr: function-like node (sequence, scalar or collection).
        :param raw_inputs: the input expressions of ``expr``.
        """
        def no_output_decimal():
            if isinstance(expr, (SequenceExpr, Scalar)):
                return expr.dtype != types.decimal
            else:
                return all(t != types.decimal for t in expr.schema.types)

        if isinstance(expr, Func):
            return

        if all(arg.dtype != types.decimal for arg in raw_inputs) and \
                no_output_decimal():
            return

        # Snapshot taken before substitution.  Presumably ``_sub`` rewrites
        # the input slots of ``expr`` in place, so this keeps the original
        # (pre-cast) inputs -- confirm against BaseRewriter._sub.
        inputs = list(raw_inputs)
        for arg in raw_inputs:
            if arg.dtype == types.decimal:
                self._sub(arg, arg.astype('string'), parents=[expr, ])
        if hasattr(expr, '_raw_inputs'):
            expr._raw_inputs = inputs
        else:
            assert len(inputs) == 1
            expr._raw_input = inputs[0]

        attrs = get_attrs(expr)
        attr_values = dict((attr, getattr(expr, attr, None)) for attr in attrs)
        if isinstance(expr, (SequenceExpr, Scalar)):
            if expr.dtype == types.decimal:
                # Rebuild the node with a string output type.
                if isinstance(expr, SequenceExpr):
                    attr_values['_data_type'] = types.string
                    attr_values['_source_data_type'] = types.string
                else:
                    attr_values['_value_type'] = types.string
                    attr_values['_source_value_type'] = types.string
            sub = type(expr)._new(**attr_values)

            if expr.dtype == types.decimal:
                sub = sub.astype('decimal')
        else:
            names = expr.schema.names
            tps = expr.schema.types
            cast_names = set()
            if any(tp == types.decimal for tp in tps):
                # Swap every decimal output column to string and remember it.
                new_tps = []
                for name, tp in zip(names, tps):
                    if tp != types.decimal:
                        new_tps.append(tp)
                        continue
                    new_tps.append(types.string)
                    cast_names.add(name)
                if len(cast_names) > 0:
                    attr_values['_schema'] = TableSchema.from_lists(names, new_tps)
            sub = type(expr)(**attr_values)

            if len(cast_names) > 0:
                # Re-project, casting the swapped columns back to decimal.
                fields = []
                for name in names:
                    if name in cast_names:
                        fields.append(sub[name].astype('decimal'))
                    else:
                        fields.append(name)
                sub = sub[fields]

        self._sub(expr, sub)

    def visit_function(self, expr):
        """Apply decimal handling to a function node over ``expr._inputs``."""
        self._handle_function(expr, expr._inputs)

    def visit_reshuffle(self, expr):
        """Wrap a join input of a reshuffle in a ``JoinProjectCollectionExpr``."""
        if isinstance(expr.input, JoinCollectionExpr):
            sub = JoinProjectCollectionExpr(
                _input=expr.input, _schema=expr.input.schema,
                _fields=expr.input._fetch_fields())
            self._sub(expr.input, sub)

    def visit_apply_collection(self, expr):
        """Prepare an apply-collection node.

        A mapjoin/skewjoin input is wrapped in a ``JoinProjectCollectionExpr``.
        Decimal handling is then applied over ``expr._fields``.
        """
        if (
            isinstance(expr._input, JoinCollectionExpr)
            and (expr._input._mapjoin or expr._input._skewjoin)
        ):
            node = expr._input
            projection = JoinProjectCollectionExpr(
                _input=node, _schema=node.schema,
                _fields=node._fetch_fields())
            self._sub(node, projection)
        self._handle_function(expr, expr._fields)

    def visit_user_defined_aggregator(self, expr):
        """Apply decimal handling to a UDAF over its single input."""
        self._handle_function(expr, [expr.input, ])

    def visit_column(self, expr):
        """Mark non-string partition columns of source tables as string-typed.

        Only columns whose input is a source collection backed by a ``Table``
        are considered.  ``NoSuchObject`` raised during the lookup is ignored.
        """
        if is_source_collection(expr.input) and isinstance(expr._input._source_data, Table):
            try:
                if expr.input._source_data.table_schema.is_partition(expr.source_name) and \
                        expr.dtype != types.string:
                    expr._source_data_type = types.string
            except NoSuchObject:
                return

    def visit_reduction(self, expr):
        """For ``Cat``/``GroupedCat`` with ``_na_rep`` set, fill nulls in the input first."""
        if isinstance(expr, (Cat, GroupedCat)):
            if expr._na_rep is not None:
                filled = expr.input.fillna(expr._na_rep)
                self._sub(expr.input, filled, parents=(expr, ))