odps/df/backends/optimize/core.py (196 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from ..core import Backend
from ...expr.core import ExprDictionary
# NOTE(review): several names used below (e.g. options, reduce, operator,
# TableSchema, Column, CollectionExpr) are not imported explicitly and are
# presumably provided by this wildcard import -- confirm before pruning it.
from ...expr.expressions import *
from ...expr.groupby import GroupByCollectionExpr
from ...expr.reduction import SequenceReduction, GroupedSequenceReduction
from ...expr.merge import JoinCollectionExpr
from ...expr.window import Window
from ...expr.utils import select_fields
from ...utils import traverse_until_source
from .... import utils
from .columnpruning import ColumnPruning
from .predicatepushdown import PredicatePushdown
from .utils import change_input


class Optimizer(Backend):
    """Rule-based rewriter for a DataFrame expression DAG.

    ``optimize`` optionally runs column pruning and predicate pushdown, then
    dispatches every node to the ``visit_*`` hooks below twice (top-down, then
    bottom-up). The hooks compact chains of projections, merge stacked filters
    and fold filters/projections into a ``GroupBy``. Rewrites mutate the DAG
    and the expression objects in place.
    """

    def __init__(self, dag):
        # Expression DAG that every substitution is applied to.
        self._dag = dag

    def optimize(self):
        """Run the enabled optimizations and return ``self._dag.root``."""
        if options.df.optimize:
            if options.df.optimizes.cp:
                ColumnPruning(self._dag).prune()
            if options.df.optimizes.pp:
                PredicatePushdown(self._dag).pushdown()

        # Visit top-down first, then from down up again -- presumably so that
        # rewrites made in the first pass can enable further ones.
        self._accept_all(traverse_until_source(self._dag, top_down=True))
        self._accept_all(traverse_until_source(self._dag))

        return self._dag.root

    def _accept_all(self, nodes):
        """Dispatch every node in *nodes* to this visitor.

        ``NotImplementedError`` raised by ``accept`` is skipped so that the
        remaining nodes are still visited (presumably it is what ``Backend``
        raises for node types without a hook here -- confirm).
        """
        for node in nodes:
            try:
                node.accept(self)
            except NotImplementedError:
                continue

    def _sub(self, expr, to_sub, parents=None):
        """Replace *expr* with *to_sub* in the DAG."""
        self._dag.substitute(expr, to_sub, parents=parents)

    def visit_filter_collection(self, expr):
        """Fold a filter into a GroupBy's HAVING, or merge stacked filters.

        When optimization is enabled this always ends by raising
        ``NotImplementedError``; ``optimize`` treats that as "continue".
        """
        if not options.df.optimize:
            return

        source = expr.input
        if isinstance(source, GroupByCollectionExpr) and \
                not source.optimize_banned:
            # move filter on GroupBy to GroupBy's having
            grouped = source
            predicate = self._do_compact(expr.predicate, grouped)
            if predicate is None:
                predicate = expr.predicate

            having = grouped.having
            if having is not None:
                predicates = having & predicate
            else:
                predicates = predicate

            grouped._having = predicates
            self._sub(expr, grouped)
        elif isinstance(source, FilterCollectionExpr):
            # Collect the whole run of nested filters, outermost first.
            filters = [expr]
            node = source
            while isinstance(node, FilterCollectionExpr):
                filters.append(node)
                node = node.input

            self._compact_filters(*filters)

        # NOTE(review): raised even after a successful rewrite; optimize()
        # ignores it, but other callers may rely on it -- confirm first.
        raise NotImplementedError

    @classmethod
    def get_compact_filters(cls, dag, *filters):
        """Build a single filter equivalent to a stack of nested filters.

        :param dag: the expression DAG the filters belong to
        :param filters: nested ``FilterCollectionExpr`` nodes ordered from the
            outermost (``filters[0]``) to the innermost (``filters[-1]``)
        :return: a new ``FilterCollectionExpr`` over the innermost filter's
            input whose predicate ANDs every filter's predicate, innermost
            first
        """
        source = filters[-1].input

        # Keep the (n, col) parameter names of the original lambda in case
        # change_input passes them by keyword.
        def get_field(n, col):
            return source[col]

        for flt in filters:
            # Presumably rebinds the references *flt* makes to its own input
            # onto ``source`` (see .utils.change_input) -- verify there.
            change_input(flt, flt.input, source, get_field, dag)

        predicate = reduce(operator.and_, [f.predicate for f in reversed(filters)])
        return FilterCollectionExpr(source, predicate, _schema=source.schema)

    def _compact_filters(self, *filters):
        """Replace the outermost of *filters* with one merged filter."""
        new_filter = self.get_compact_filters(self._dag, *filters)
        self._sub(filters[0], new_filter)

    def visit_project_collection(self, expr):
        """Compact projections beneath *expr*, then fold it into a GroupBy."""
        # Summary does not attend here
        if not options.df.optimize:
            return

        # _visit_need_compact_collection returns None or an expression; test
        # for None explicitly instead of relying on expression truthiness.
        compacted = self._visit_need_compact_collection(expr)
        if compacted is not None:
            expr = compacted

        if isinstance(expr, ProjectCollectionExpr) and \
                isinstance(expr.input, GroupByCollectionExpr) and \
                not expr.input.optimize_banned:
            # compact projection into Groupby
            grouped = expr.input

            # stop compact: a window function or sequence reduction appears
            # between the projection and the GroupBy.
            if any(isinstance(n, (Window, SequenceReduction))
                   for n in expr.traverse(top_down=True, unique=True,
                                          stop_cond=lambda x: x is grouped)):
                return

            selects = []
            for field in expr._fields:
                compacted_field = self._do_compact(field, grouped)
                selects.append(field if compacted_field is None else compacted_field)

            grouped._aggregations = grouped._fields = selects
            grouped._schema = TableSchema.from_lists(
                [f.name for f in selects], [f.dtype for f in selects]
            )
            self._sub(expr, grouped)

    def visit_groupby(self, expr):
        """Compact projections beneath a GroupBy.

        Skipped when the projections sit on top of a join, and for a single
        grouped reduction applied to a whole collection (``df.groupby(...).count()``).
        """
        if not options.df.optimize:
            return

        # we do not do compact on the projections from Join
        source = expr.input
        while isinstance(source, ProjectCollectionExpr):
            source = source._input
        if isinstance(source, JoinCollectionExpr):
            return

        if len(expr._aggregations) == 1 and \
                isinstance(expr._aggregations[0], GroupedSequenceReduction) and \
                isinstance(expr._aggregations[0]._input, CollectionExpr):
            # we just skip the case: df.groupby(***).count()
            return

        self._visit_need_compact_collection(expr)

    def visit_distinct(self, expr):
        """Compact projections beneath a Distinct."""
        if not options.df.optimize:
            return

        self._visit_need_compact_collection(expr)

    def visit_apply_collection(self, expr):
        """Compact projections beneath an apply.

        Skipped for lateral views and for applies over a map join or skew join.
        """
        if not options.df.optimize:
            return
        if expr._lateral_view:
            return
        if (
            isinstance(expr.input, JoinCollectionExpr)
            and (expr.input._mapjoin or expr.input._skewjoin)
        ):
            return

        self._visit_need_compact_collection(expr)

    def _visit_need_compact_collection(self, expr):
        """Compact the projections beneath *expr* and substitute the result.

        :return: the compacted expression, or ``None`` when there was no
            projection to compact
        """
        compacted = self._compact(expr)
        if compacted is None:
            return None

        if expr is not compacted:
            self._sub(expr, compacted)
        return compacted

    def _compact(self, expr):
        """Inline the chain of projections directly beneath *expr* into it.

        Walks down from *expr* and collects consecutive compactable
        ``ProjectCollectionExpr`` nodes. It stops at the first collection that
        is not compactable, or whose path to the previously collected node is
        rejected by ``_can_compact_path``.

        :return: the result of ``_do_compact``, or ``None`` when no
            projection could be collected
        """
        to_compact = [expr]

        for node in traverse_until_source(expr, top_down=True, unique=True):
            if node is expr or not isinstance(node, CollectionExpr):
                continue

            # We do not handle collection with Scalar column or window function here
            # TODO think way to compact in this situation
            if not isinstance(node, ProjectCollectionExpr) or \
                    node.optimize_banned or \
                    any(isinstance(n, Window) for n in node._fields):
                break
            if not self._can_compact_path(node, to_compact[-1]):
                break
            to_compact.append(node)

        if len(to_compact) <= 1:
            return None

        # Pass the collected projections deepest first; _do_compact stops
        # traversing at the last one, i.e. the one nearest to *expr*.
        return self._do_compact(expr, *reversed(to_compact[1:]))

    @staticmethod
    def _can_compact_path(node, target):
        """Return False if any expression on ``node.all_path(target)`` blocks
        compaction.

        Compaction is blocked by a collection with a column whose input is a
        ``LateralViewCollectionExpr``, or by a ``SequenceReduction``.
        """
        for it in itertools.chain.from_iterable(node.all_path(target)):
            if isinstance(it, CollectionExpr) and \
                    any(isinstance(n.input, LateralViewCollectionExpr)
                        for n in it.columns):
                return False
            if isinstance(it, SequenceReduction):
                return False
        return True

    def _do_compact(self, expr, *to_compact):
        """Inline the fields of the collections in *to_compact* into *expr*.

        Every ``Column`` that consumes one of those collections is replaced by
        a copy of the field it names (renamed if the column was renamed). Any
        other consumer is re-pointed at the collection's input.

        :param expr: the expression to rewrite
        :param to_compact: collections to inline; the traversal of *expr*
            stops at the last one
        :return: *expr*, or its replacement if *expr* itself was a substituted
            column
        """
        retval = expr

        collection_dict = ExprDictionary()
        for coll in to_compact:
            collection_dict[coll] = True

        for node in expr.traverse(top_down=True, unique=True,
                                  stop_cond=lambda x: x is to_compact[-1]):
            if node not in collection_dict:
                continue

            # Iterate over a snapshot: the substitutions below rewire the DAG
            # and may change node's successors while we loop.
            for parent in list(self._dag.successors(node)):
                if isinstance(parent, Column):
                    col_name = parent.source_name or parent.name
                    field = self._get_field(node, col_name)
                    if parent.is_renamed():
                        field = field.rename(parent.name)
                    else:
                        field = field.copy()
                    self._sub(parent, field)
                    if parent is retval:
                        retval = field
                else:
                    parent.substitute(node, node.input, dag=self._dag)

        return retval

    def _get_fields(self, collection):
        """Return the selected fields of *collection*, plus HAVING if present."""
        fields = select_fields(collection)
        if isinstance(collection, GroupByCollectionExpr) and \
                collection._having is not None:
            # add GroupbyCollectionExpr.having to broadcast fields
            fields.append(collection._having)
        return fields

    def _get_field(self, collection, name):
        """Return the field of *collection* that produces the column *name*.

        For a GroupBy the name is looked up in ``_name_to_exprs()``. Otherwise
        the schema index is looked up by the lower-cased name first, falling
        back to the exact name.
        """
        # FIXME: consider name with upper letters
        name = utils.to_str(name)

        if isinstance(collection, GroupByCollectionExpr):
            return collection._name_to_exprs()[name]

        name_idxes = collection.schema._name_indexes
        if name.lower() in name_idxes:
            idx = name_idxes[name.lower()]
        else:
            idx = name_idxes[name]
        return self._get_fields(collection)[idx]