odps/df/backends/rewriter.py (135 lines of code) (raw):

#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2022 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import itertools from .core import Backend from ..expr.expressions import Summary from ..expr.reduction import SequenceReduction from ..expr.window import * # don't remove from ..expr.merge import * from ..utils import traverse_until_source class BaseRewriter(Backend): def __init__(self, expr_dag, traversed=None): self._dag = expr_dag self._indexer = itertools.count(0) self._traversed = traversed or set() def rewrite(self): for node in self._iter(): self._traversed.add(id(node)) self._visit_node(node) return self._dag.root def _iter(self): for node in traverse_until_source(self._dag, top_down=True, traversed=self._traversed): yield node while True: all_traversed = True for node in traverse_until_source(self._dag, top_down=True): if id(node) not in self._traversed: all_traversed = False yield node if all_traversed: break def _visit_node(self, node): try: node.accept(self) except NotImplementedError: return def _sub(self, expr, to_sub, parents=None): self._dag.substitute(expr, to_sub, parents=parents) def _parents(self, expr): return self._dag.successors(expr) def _rewrite_reduction_in_projection(self, expr): # FIXME how to handle nested reduction? if isinstance(expr, Summary): return collection = expr.input sink_selects = [] columns = set() to_replace = [] windows_rewrite = False for field in expr.fields: has_window = False traversed = set() for path in field.all_path(collection, strict=True): for node in path: if id(node) in traversed: continue else: traversed.add(id(node)) if isinstance(node, SequenceReduction) and not node._need_cache: windows_rewrite = True has_window = True win = self._reduction_to_window(node) window_name = '%s_%s' % (win.name, next(self._indexer)) sink_selects.append(win.rename(window_name)) to_replace.append((node, window_name)) break elif isinstance(node, Column): if node.input is not collection: continue if node.name in columns: to_replace.append((node, node.name)) continue columns.add(node.name) select_field = collection[node.source_name] if node.is_renamed(): select_field = select_field.rename(node.name) sink_selects.append(select_field) to_replace.append((node, node.name)) if has_window: field._name = field.name if not windows_rewrite: return get = lambda x: x.name if not isinstance(x, six.string_types) else x projected = collection[sorted(sink_selects, key=get)] projected.optimize_banned = True # TO prevent from optimizing expr.substitute(collection, projected, dag=self._dag) for col, col_name in to_replace: self._sub(col, projected[col_name].rename(col.name)) def _rewrite_reduction_in_filter(self, expr): # FIXME how to handle nested reduction? collection = expr.input sink_selects = [] columns = set() to_replace = [] windows_rewrite = False traversed = set() for path in expr.predicate.all_path(collection, strict=True): for node in path: if id(node) in traversed: continue else: traversed.add(id(node)) if isinstance(node, SequenceReduction): windows_rewrite = True win = self._reduction_to_window(node) window_name = '%s_%s' % (win.name, next(self._indexer)) sink_selects.append(win.rename(window_name)) to_replace.append((node, window_name)) break elif isinstance(node, Column): if node.input is not collection: continue if node.name in columns: to_replace.append((node, node.name)) continue columns.add(node.name) select_field = collection[node.source_name] if node.is_renamed(): select_field = select_field.rename(node.name) sink_selects.append(select_field) to_replace.append((node, node.name)) for column_name in expr.schema.names: if column_name in columns: continue columns.add(column_name) sink_selects.append(column_name) if not windows_rewrite: return get = lambda x: x.name if not isinstance(x, six.string_types) else x projected = collection[sorted(sink_selects, key=get)] projected.optimize_banned = True # TO prevent from optimizing expr.substitute(collection, projected, dag=self._dag) for col, col_name in to_replace: self._sub(col, projected[col_name]) to_sub = expr[expr.schema.names] self._sub(expr, to_sub) def _reduction_to_window(self, expr): clazz = 'Cum' + expr.node_name return globals()[clazz](_input=expr.input, _data_type=expr.dtype)