odps/df/backends/odpssql/rewriter.py (143 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..rewriter import BaseRewriter
from ...expr.reduction import Cat, GroupedCat
from ...expr.window import *
from ...expr.merge import *
from ...expr.expressions import LateralViewCollectionExpr
from ...expr.element import Func
from ...expr.utils import get_attrs
from ....errors import NoSuchObject
from ...utils import is_source_collection
from ....models import Table
class Rewriter(BaseRewriter):
    """Expression rewriter used by the ``odpssql`` backend.

    Every ``visit_*`` hook receives a single expression node and, where
    required, replaces that node (or one of its inputs) through
    ``self._sub``.
    """

    def visit_project_collection(self, expr):
        """Hand a projection node to the reduction-in-projection rewrite."""
        self._rewrite_reduction_in_projection(expr)

    def visit_filter_collection(self, expr):
        """Hand a filter node to the reduction-in-filter rewrite."""
        self._rewrite_reduction_in_filter(expr)

    def visit_join(self, expr):
        """Normalize a join node.

        * A non-empty list predicate is folded into one expression with
          ``operator.and_``.
        * A ``JoinCollectionExpr`` on the right-hand side is wrapped in a
          ``JoinProjectCollectionExpr``; an existing
          ``JoinProjectCollectionExpr`` is substituted for its own input.
        * If this join, or a join reached by following ``lhs``, reports
          ``column_conflict``, the join is replaced with ``expr[expr]``
          unless one of its parents is already a projection or a join.
        """
        predicate = expr._predicate
        if predicate and isinstance(predicate, list):
            expr._predicate = reduce(operator.and_, predicate)

        # Only the right-hand side is inspected here.
        right = expr.rhs
        right_parents = self._parents(right)
        if isinstance(right, JoinCollectionExpr):
            wrapped = JoinProjectCollectionExpr(
                _input=right,
                _schema=right.schema,
                _fields=right._fetch_fields(),
            )
            self._sub(right, wrapped, right_parents)
        elif isinstance(right, JoinProjectCollectionExpr):
            self._sub(right.input, right, right_parents)

        # Follow the chain of joins through ``lhs`` and stop at the first
        # one that reports a column conflict.
        conflict = False
        cursor = expr
        while isinstance(cursor, JoinCollectionExpr):
            if cursor.column_conflict:
                conflict = True
                break
            cursor = cursor.lhs

        if not conflict:
            return

        parents = self._parents(expr)
        covered = parents and any(
            isinstance(parent, (ProjectCollectionExpr, JoinCollectionExpr))
            for parent in parents
        )
        if not covered:
            self._sub(expr, expr[expr], parents)

    def visit_lateral_view(self, expr):
        """Wrap a lateral view in a ``ProjectCollectionExpr``.

        Nothing is done when at least one parent is a
        ``ProjectCollectionExpr`` that is not itself a
        ``LateralViewCollectionExpr``.
        """
        parents = self._parents(expr)

        def is_plain_projection(parent):
            return isinstance(parent, ProjectCollectionExpr) and not isinstance(
                parent, LateralViewCollectionExpr
            )

        if parents and any(is_plain_projection(parent) for parent in parents):
            return

        projection = ProjectCollectionExpr(
            _input=expr, _schema=expr.schema, _fields=expr._fields
        )
        self._sub(expr, projection, parents)

    def _handle_function(self, expr, raw_inputs):
        """Route decimal inputs and outputs of a function node via string.

        Python UDFs cannot handle decimal fields, so every decimal input is
        replaced by its string cast. A decimal output is declared as string
        on a rebuilt node and then cast back to decimal.

        :param expr: the function-like expression being visited
        :param raw_inputs: the inputs of ``expr`` to inspect
        """
        if isinstance(expr, Func):
            return

        def output_free_of_decimal():
            if not isinstance(expr, (SequenceExpr, Scalar)):
                return all(tp != types.decimal for tp in expr.schema.types)
            return expr.dtype != types.decimal

        inputs_free_of_decimal = all(
            arg.dtype != types.decimal for arg in raw_inputs
        )
        if inputs_free_of_decimal and output_free_of_decimal():
            return

        # Snapshot taken before the substitutions below; this is what gets
        # stored back on the node as its raw input(s).
        original_inputs = list(raw_inputs)
        for arg in raw_inputs:
            if arg.dtype == types.decimal:
                self._sub(arg, arg.astype('string'), parents=[expr, ])

        if not hasattr(expr, '_raw_inputs'):
            assert len(original_inputs) == 1
            expr._raw_input = original_inputs[0]
        else:
            expr._raw_inputs = original_inputs

        attr_values = {
            attr: getattr(expr, attr, None) for attr in get_attrs(expr)
        }

        if isinstance(expr, (SequenceExpr, Scalar)):
            decimal_output = expr.dtype == types.decimal
            if decimal_output:
                if isinstance(expr, SequenceExpr):
                    type_keys = ('_data_type', '_source_data_type')
                else:
                    type_keys = ('_value_type', '_source_value_type')
                for key in type_keys:
                    attr_values[key] = types.string
            replacement = type(expr)._new(**attr_values)
            if decimal_output:
                replacement = replacement.astype('decimal')
        else:
            names, tps = expr.schema.names, expr.schema.types
            cast_names = set()
            if any(tp == types.decimal for tp in tps):
                new_tps = []
                for name, tp in zip(names, tps):
                    if tp != types.decimal:
                        new_tps.append(tp)
                    else:
                        new_tps.append(types.string)
                        cast_names.add(name)
                if cast_names:
                    attr_values['_schema'] = TableSchema.from_lists(
                        names, new_tps
                    )
            replacement = type(expr)(**attr_values)
            if cast_names:
                # Cast the string-typed columns back; others are selected
                # by name unchanged.
                fields = [
                    replacement[name].astype('decimal')
                    if name in cast_names else name
                    for name in names
                ]
                replacement = replacement[fields]

        self._sub(expr, replacement)

    def visit_function(self, expr):
        """Apply the decimal handling to a function node and its inputs."""
        self._handle_function(expr, expr._inputs)

    def visit_reshuffle(self, expr):
        """Wrap a join feeding a reshuffle in a join projection."""
        source = expr.input
        if not isinstance(source, JoinCollectionExpr):
            return
        projection = JoinProjectCollectionExpr(
            _input=source,
            _schema=source.schema,
            _fields=source._fetch_fields(),
        )
        self._sub(source, projection)

    def visit_apply_collection(self, expr):
        """Rewrite an apply node.

        A join input flagged ``_mapjoin`` or ``_skewjoin`` is first wrapped
        in a join projection; the decimal handling then runs over the
        node's fields.
        """
        source = expr._input
        if isinstance(source, JoinCollectionExpr) and (
            source._mapjoin or source._skewjoin
        ):
            projection = JoinProjectCollectionExpr(
                _input=source,
                _schema=source.schema,
                _fields=source._fetch_fields(),
            )
            self._sub(source, projection)
        self._handle_function(expr, expr._fields)

    def visit_user_defined_aggregator(self, expr):
        """Apply the decimal handling to an aggregator and its one input."""
        self._handle_function(expr, [expr.input, ])

    def visit_column(self, expr):
        """Mark non-string partition columns as string at the source.

        Applies only to columns whose input is a source collection backed
        by a ``Table``. ``NoSuchObject`` raised during the check is
        swallowed and the column is left untouched.
        """
        if not is_source_collection(expr.input):
            return
        if not isinstance(expr._input._source_data, Table):
            return
        try:
            table_schema = expr.input._source_data.table_schema
            if table_schema.is_partition(expr.source_name) and \
                    expr.dtype != types.string:
                expr._source_data_type = types.string
        except NoSuchObject:
            return

    def visit_reduction(self, expr):
        """Fill missing values of a ``Cat``/``GroupedCat`` input.

        When ``_na_rep`` is set, the input is replaced by
        ``input.fillna(_na_rep)``.
        """
        if not isinstance(expr, (Cat, GroupedCat)):
            return
        if expr._na_rep is None:
            return
        filled = expr.input.fillna(expr._na_rep)
        self._sub(expr.input, filled, parents=(expr, ))