odps/df/backends/analyzer.py (252 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from .core import Backend
from ..utils import traverse_until_source
from ..expr.expressions import Scalar, SequenceExpr, CollectionExpr
from ..expr.reduction import GroupedSequenceReduction
from ..expr.element import Switch
from .. import output
from ... import compat
from ...models import TableSchema
from .utils import refresh_dynamic
from ..types import DynamicSchema
from ...compat import six


class BaseAnalyzer(Backend):
    """
    Analyzer is used before optimizing; it analyzes operations that are
    not supported by the execution backend.

    The helper methods below build substitute expressions out of simpler
    operations (moments, cut, value_counts, pivot, pivot_table).
    """

    def __init__(self, expr_dag, traversed=None, on_sub=None):
        """
        :param expr_dag: expression DAG to walk and rewrite
        :param traversed: optional set of ``id(node)`` values already visited;
            a new set is created when omitted
        :param on_sub: optional callback invoked as ``on_sub(expr, sub)``
            after every substitution made through :meth:`_sub`
        """
        self._dag = expr_dag
        self._indexer = itertools.count(0)
        # Test against None explicitly so that a caller-supplied empty set
        # is shared rather than silently replaced by a private one.
        self._traversed = traversed if traversed is not None else set()
        self._on_sub = on_sub

    def analyze(self):
        """
        Visit every node produced by :meth:`_iter`, then return the DAG root.
        """
        for node in self._iter():
            # Record the node before visiting it: _iter() relies on this
            # bookkeeping to know when nothing is left to visit.
            self._traversed.add(id(node))
            self._visit_node(node)
        return self._dag.root

    def _iter(self):
        """
        Yield nodes of the DAG top-down until no unvisited node remains.

        After the first pass the DAG is traversed again and again, because
        visiting a node may substitute new nodes into it. Termination depends
        on the consumer adding ``id(node)`` to ``self._traversed`` for each
        yielded node, as :meth:`analyze` does.
        """
        for node in traverse_until_source(self._dag, top_down=True,
                                          traversed=self._traversed):
            yield node

        while True:
            all_traversed = True
            for node in traverse_until_source(self._dag, top_down=True):
                if id(node) not in self._traversed:
                    all_traversed = False
                    yield node
            if all_traversed:
                break

    def _visit_node(self, node):
        """
        Dispatch ``node`` to this visitor, ignoring ``NotImplementedError``.
        """
        try:
            node.accept(self)
        except NotImplementedError:
            # Deliberate best-effort: such nodes are simply skipped.
            return

    def _sub(self, expr, sub, parents=None):
        """
        Replace ``expr`` with ``sub`` in the DAG and notify ``on_sub`` if set.
        """
        self._dag.substitute(expr, sub, parents=parents)
        if self._on_sub:
            self._on_sub(expr, sub)

    @staticmethod
    def _get_moment_sub_expr(expr, _input, order, center):
        """
        Build an expression for the ``order``-th moment of ``_input`` using
        only means and powers.

        :param expr: the original expression; when it is a
            ``GroupedSequenceReduction`` every mean is converted to a grouped
            reduction over ``expr._grouped``
        :param _input: the expression whose moment is computed
        :param order: order of the moment
        :param center: build the central moment when truthy, otherwise the
            raw moment
        :return: the substitute expression
        """
        def _group_mean(e):
            m = e.mean()
            if isinstance(expr, GroupedSequenceReduction):
                m = m.to_grouped_reduction(expr._grouped)
            return m

        def _order(e, o):
            # Avoid emitting a needless ``** 1``.
            return e if o == 1 else e ** o

        # The zeroth moment is 1, centered or not.
        if order == 0:
            return Scalar(1)
        if not center:
            return _group_mean(_input ** order)
        # The first central moment is 0.
        if order == 1:
            return Scalar(0)

        # Binomial expansion of mean((x - mean(x)) ** order):
        #   sum over o of (-1)**o * C(order, o) * mean(x**(order-o)) * mean(x)**o
        sub = _group_mean(_input ** order)  # the o == 0 term
        divided = 1
        divisor = 1
        for o in compat.irange(1, order):
            # divided // divisor is the binomial coefficient C(order, o),
            # built up incrementally.
            divided *= order - o + 1
            divisor *= o
            part_item = divided // divisor * _group_mean(_order(_input, order - o)) \
                * (_order(_group_mean(_input), o))
            if o & 1:
                sub -= part_item
            else:
                sub += part_item
        # The o == order term, where mean(x ** 0) is 1.
        part_item = _group_mean(_input) ** order
        if order & 1:
            sub -= part_item
        else:
            sub += part_item
        return sub

    @classmethod
    def _get_cut_sub_expr(cls, expr):
        """
        Translate a cut expression into a ``Switch`` with one condition per
        bin.

        Reads ``bins``, ``labels``, ``right``, ``include_lowest``,
        ``include_under`` and ``include_over`` from ``expr``. The result has
        no default branch (``_default=None``).

        :param expr: the cut expression to translate
        :return: a ``Switch`` expression with the same dtype as ``expr``
        """
        kw = dict()
        if isinstance(expr, SequenceExpr):
            kw['_data_type'] = expr.dtype
        else:
            kw['_value_type'] = expr.dtype

        conditions = []
        thens = []

        if expr.include_under:
            # Extra bucket for values below the first edge. The first edge
            # itself belongs here only when the first regular interval
            # excludes it.
            first_edge = expr.bins[0]
            if expr.right and not expr.include_lowest:
                conditions.append(expr.input <= first_edge)
            else:
                conditions.append(expr.input < first_edge)
            thens.append(expr.labels[0])

        # With an "under" bucket, regular intervals start at labels[1].
        label_offset = 1 if expr.include_under else 0
        for i, upper_bin in enumerate(expr.bins[1:]):
            lower_bin = expr.bins[i]
            if not expr.right or (i == 0 and expr.include_lowest):
                condition = lower_bin <= expr.input
            else:
                condition = lower_bin < expr.input

            if expr.right:
                condition = (condition & (expr.input <= upper_bin))
            else:
                condition = (condition & (expr.input < upper_bin))

            conditions.append(condition)
            thens.append(expr.labels[i + label_offset])

        if expr.include_over:
            # Extra bucket for values above the last edge.
            last_edge = expr.bins[-1]
            if expr.right:
                conditions.append(last_edge < expr.input)
            else:
                conditions.append(last_edge <= expr.input)
            thens.append(expr.labels[-1])

        return Switch(_conditions=conditions, _thens=thens,
                      _default=None, _input=None, **kw)

    @classmethod
    def _get_value_counts_sub_expr(cls, expr):
        """
        Rewrite value_counts as a group-by count, optionally sorted by the
        count and filtered to rows whose key is not null.

        :param expr: the value_counts expression
        :return: the substitute collection expression
        """
        collection = expr.input
        by = expr._by

        sub = collection.groupby(by).agg(count=collection.count())
        if expr._sort.value:
            sub = sub.sort('count', ascending=expr._ascending.value)
        if expr._dropna.value:
            sub = sub.filter(sub[by.name].notnull())
        return sub

    def _get_pivot_sub_expr(self, expr):
        """
        Build a dynamic-schema placeholder collection for a pivot.

        The output columns depend on the distinct values of the pivot column,
        so the placeholder depends on a ``distinct`` expression. Its callback
        receives that result, sets the concrete schema on the new expression
        and substitutes it with a map-reduce that spreads values into columns.

        :param expr: the pivot expression
        :return: a ``CollectionExpr`` with a ``DynamicSchema`` of the group
            columns and a dependency on the distinct-columns expression
        """
        columns_expr = expr.input.distinct([c.copy() for c in expr._columns])

        group_names = [g.name for g in expr._group]
        group_types = [g.dtype for g in expr._group]
        exprs = [expr]

        def callback(result, new_expr):
            expr = exprs[0]
            columns = [r[0] for r in result]

            if len(expr._values) > 1:
                # Several value fields: one "<value>_<column>" per pair.
                names = group_names + \
                    ['{0}_{1}'.format(v.name, c)
                     for v in expr._values for c in columns]
                types = group_types + \
                    list(itertools.chain(*[[n.dtype] * len(columns)
                                           for n in expr._values]))
            else:
                names = group_names + columns
                types = group_types + [expr._values[0].dtype] * len(columns)
            new_expr._schema = TableSchema.from_lists(names, types)

            column_name = expr._columns[0].name  # column's size can only be 1
            values_names = [v.name for v in expr._values]

            @output(names, types)
            def reducer(keys):
                # One slot per (value field, pivot column) pair.
                values = [None] * len(columns) * len(values_names)

                def h(row, done):
                    col = getattr(row, column_name)
                    for val_idx, value_name in enumerate(values_names):
                        val = getattr(row, value_name)
                        idx = len(columns) * val_idx + columns.index(col)
                        if values[idx] is not None:
                            raise ValueError(
                                'Row contains duplicate entries, rows: {0}, column: {1}'.format(keys, col))
                        values[idx] = val
                    if done:
                        yield keys + tuple(values)

                return h

            fields = expr._group + expr._columns + expr._values
            pivoted = expr.input.select(fields).map_reduce(reducer=reducer, group=group_names)
            self._sub(new_expr, pivoted)

            # trigger refresh of dynamic operations
            refresh_dynamic(pivoted, self._dag)

        return CollectionExpr(_schema=DynamicSchema.from_lists(group_names, group_types),
                              _deps=[(columns_expr, callback)])

    def _get_pivot_table_sub_expr_without_columns(self, expr):
        """
        Rewrite a pivot_table that has no pivot columns as a group-by
        aggregation, with one "<value>_<agg name>" output per pair.

        :param expr: the pivot_table expression
        :return: the aggregated collection expression
        """
        def get_agg(field, agg_func, agg_func_name, fill_value):
            from ..expr.expressions import ReprWrapper

            if isinstance(agg_func, six.string_types):
                aggregated = field.eval(agg_func, rewrite=False)
                if isinstance(aggregated, ReprWrapper):
                    aggregated = aggregated()
            else:
                aggregated = field.agg(agg_func)
            if fill_value is not None:
                # fillna returns a new expression; the result must be kept,
                # otherwise fill_value is silently ignored.
                aggregated = aggregated.fillna(fill_value)
            return aggregated.rename('{0}_{1}'.format(field.name, agg_func_name))

        grouped = expr.input.groupby(expr._group)
        aggs = [
            get_agg(value, agg_func, agg_func_name, expr.fill_value)
            for agg_func, agg_func_name in zip(expr._agg_func, expr._agg_func_names)
            for value in expr._values
        ]
        return grouped.aggregate(aggs, sort_by_name=False)

    def _get_pivot_table_sub_expr_with_columns(self, expr):
        """
        Build a dynamic-schema placeholder collection for a pivot_table that
        has pivot columns.

        The callback receives the distinct values of the pivot column and
        builds one conditional aggregation per (aggregator, value field,
        column value). It then sets the concrete schema and substitutes the
        placeholder with the group-by aggregation.

        :param expr: the pivot_table expression
        :return: a ``CollectionExpr`` with a ``DynamicSchema`` of the group
            columns and a dependency on the distinct-columns expression
        """
        columns_expr = expr.input.distinct([c.copy() for c in expr._columns])

        group_names = [g.name for g in expr._group]
        group_types = [g.dtype for g in expr._group]
        exprs = [expr]

        def make_null_skipping_agg(agg_cls):
            # Bind ``func`` in its own scope. Defining the wrapper class
            # directly inside the loop would make every wrapper share one
            # closure variable and end up using the last aggregator.
            func = agg_cls()

            class ActualAgg(object):
                def buffer(self):
                    return func.buffer()

                def __call__(self, buffer, value):
                    # Rows outside the current column are fed as null by the
                    # ifelse below and must not reach the user aggregator.
                    if value is None:
                        return
                    func(buffer, value)

                def merge(self, buffer, pbuffer):
                    func.merge(buffer, pbuffer)

                def getvalue(self, buffer):
                    return func.getvalue(buffer)

            return ActualAgg

        def callback(result, new_expr):
            expr = exprs[0]
            columns = [r[0] for r in result]

            names = list(group_names)
            tps = list(group_types)
            aggs = []
            for agg_func_name, agg_func in zip(expr._agg_func_names, expr._agg_func):
                for value_col in expr._values:
                    for col in columns:
                        base = '{0}_'.format(col) if col is not None else ''
                        name = '{0}{1}_{2}'.format(base, value_col.name, agg_func_name)
                        names.append(name)
                        tps.append(value_col.dtype)

                        # Unwrap values exposing ``item()`` before comparing.
                        col = col.item() if hasattr(col, 'item') else col
                        # Keep the value only on rows of this column; other
                        # rows become a null scalar of the same type.
                        field = (expr._columns[0] == col).ifelse(
                            value_col, Scalar(_value_type=value_col.dtype))
                        if isinstance(agg_func, six.string_types):
                            agg = getattr(field, agg_func)()
                        else:
                            agg = field.agg(make_null_skipping_agg(agg_func))
                        if expr.fill_value is not None:
                            agg = agg.fillna(expr.fill_value)
                        agg = agg.rename(name)
                        aggs.append(agg)

            new_expr._schema = TableSchema.from_lists(names, tps)

            pivoted = expr.input.groupby(expr._group).aggregate(aggs, sort_by_name=False)
            self._sub(new_expr, pivoted)

            # trigger refresh of dynamic operations
            refresh_dynamic(pivoted, self._dag)

        return CollectionExpr(_schema=DynamicSchema.from_lists(group_names, group_types),
                              _deps=[(columns_expr, callback)])

    def _get_pivot_table_sub_expr(self, expr):
        """
        Dispatch to the pivot_table rewrite with or without pivot columns.
        """
        if expr._columns is None:
            return self._get_pivot_table_sub_expr_without_columns(expr)
        return self._get_pivot_table_sub_expr_with_columns(expr)