#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

from .core import Backend
from ..utils import traverse_until_source
from ..expr.expressions import Scalar, SequenceExpr, CollectionExpr
from ..expr.reduction import GroupedSequenceReduction
from ..expr.element import Switch
from .. import output
from ... import compat
from ...models import TableSchema
from .utils import refresh_dynamic
from ..types import DynamicSchema
from ...compat import six


class BaseAnalyzer(Backend):
    """
    Analyzer runs before optimization; it rewrites operations that the
    target execution backend does not support into equivalent supported
    expressions.
    """

def __init__(self, expr_dag, traversed=None, on_sub=None):
self._dag = expr_dag
self._indexer = itertools.count(0)
self._traversed = traversed or set()
        self._on_sub = on_sub

def analyze(self):
for node in self._iter():
self._traversed.add(id(node))
self._visit_node(node)
        return self._dag.root

def _iter(self):
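        # First pass: walk the DAG top-down from the root. Visiting a node may
        # substitute it with a rewritten sub-expression, which can splice new,
        # unvisited nodes into the DAG, so the loop below keeps re-walking the
        # DAG until every node has been traversed.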
for node in traverse_until_source(self._dag, top_down=True,
traversed=self._traversed):
yield node
while True:
all_traversed = True
for node in traverse_until_source(self._dag, top_down=True):
if id(node) not in self._traversed:
all_traversed = False
yield node
if all_traversed:
                break

def _visit_node(self, node):
try:
node.accept(self)
except NotImplementedError:
            return

def _sub(self, expr, sub, parents=None):
self._dag.substitute(expr, sub, parents=parents)
if self._on_sub:
            self._on_sub(expr, sub)

@staticmethod
def _get_moment_sub_expr(expr, _input, order, center):
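        """Build an expression computing the ``order``-th (optionally central)
        moment of ``_input`` out of plain means and powers, for backends that
        lack a native moment aggregation.
        """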
def _group_mean(e):
m = e.mean()
if isinstance(expr, GroupedSequenceReduction):
m = m.to_grouped_reduction(expr._grouped)
            return m

def _order(e, o):
if o == 1:
return e
else:
                return e ** o

if not center:
if order == 0:
sub = Scalar(1)
else:
sub = _group_mean(_input ** order)
else:
if order == 0:
sub = Scalar(1)
elif order == 1:
sub = Scalar(0)
else:
sub = _group_mean(_input ** order)
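                # Expand the central moment with the binomial theorem:
                #   E[(X - m)^k] = sum_{o=0}^{k} C(k, o) * (-1)^o * E[X^(k-o)] * m^o
                # where m = E[X]; `divided // divisor` accumulates the binomial
                # coefficient C(k, o) across the iterations below.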
divided = 1
divisor = 1
for o in compat.irange(1, order):
divided *= order - o + 1
divisor *= o
part_item = divided // divisor * _group_mean(_order(_input, order - o)) \
* (_order(_group_mean(_input), o))
if o & 1:
sub -= part_item
else:
sub += part_item
part_item = _group_mean(_input) ** order
if order & 1:
sub -= part_item
else:
sub += part_item
        return sub

@classmethod
def _get_cut_sub_expr(cls, expr):
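        """Rewrite a ``cut`` (binning) expression into an equivalent
        ``Switch``, i.e. a SQL ``CASE WHEN`` chain.

        A rough sketch of the rewrite (bins and labels are hypothetical)::

            df.val.cut([0, 10, 20], labels=['low', 'high'])
            # becomes roughly
            # CASE WHEN (0 < val) AND (val <= 10) THEN 'low'
            #      WHEN (10 < val) AND (val <= 20) THEN 'high'
            # END
        """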
is_seq = isinstance(expr, SequenceExpr)
kw = dict()
if is_seq:
kw['_data_type'] = expr.dtype
else:
kw['_value_type'] = expr.dtype
conditions = []
thens = []
if expr.include_under:
bin = expr.bins[0]
if expr.right and not expr.include_lowest:
conditions.append(expr.input <= bin)
else:
conditions.append(expr.input < bin)
thens.append(expr.labels[0])
for i, bin in enumerate(expr.bins[1:]):
lower_bin = expr.bins[i]
if not expr.right or (i == 0 and expr.include_lowest):
condition = lower_bin <= expr.input
else:
condition = lower_bin < expr.input
if expr.right:
condition = (condition & (expr.input <= bin))
else:
condition = (condition & (expr.input < bin))
conditions.append(condition)
if expr.include_under:
thens.append(expr.labels[i + 1])
else:
thens.append(expr.labels[i])
if expr.include_over:
bin = expr.bins[-1]
if expr.right:
conditions.append(bin < expr.input)
else:
conditions.append(bin <= expr.input)
thens.append(expr.labels[-1])
return Switch(_conditions=conditions, _thens=thens,
                      _default=None, _input=None, **kw)

@classmethod
def _get_value_counts_sub_expr(cls, expr):
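        """Rewrite ``value_counts`` into a plain groupby/aggregate.

        Roughly (the column name is hypothetical)::

            df.category.value_counts()
            # ~ df.groupby('category').agg(count=df.count())
            #       .sort('count', ascending=False)
        """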
collection = expr.input
by = expr._by
sort = expr._sort.value
ascending = expr._ascending.value
dropna = expr._dropna.value
sub = collection.groupby(by).agg(count=collection.count())
if sort:
sub = sub.sort('count', ascending=ascending)
if dropna:
sub = sub.filter(sub[by.name].notnull())
        return sub

def _get_pivot_sub_expr(self, expr):
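        """Rewrite ``pivot`` as a two-phase plan: first evaluate the distinct
        values of the pivot column, then (in the callback below) build the
        concrete output schema and a map-reduce that spreads each group's
        values across the new columns.
        """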
columns_expr = expr.input.distinct([c.copy() for c in expr._columns])
group_names = [g.name for g in expr._group]
group_types = [g.dtype for g in expr._group]
        exprs = [expr]

def callback(result, new_expr):
expr = exprs[0]
columns = [r[0] for r in result]
if len(expr._values) > 1:
names = group_names + \
['{0}_{1}'.format(v.name, c)
for v in expr._values for c in columns]
types = group_types + \
list(itertools.chain(*[[n.dtype] * len(columns)
for n in expr._values]))
else:
names = group_names + columns
types = group_types + [expr._values[0].dtype] * len(columns)
new_expr._schema = TableSchema.from_lists(names, types)
            column_name = expr._columns[0].name  # pivot supports a single pivot column only
values_names = [v.name for v in expr._values]
@output(names, types)
def reducer(keys):
values = [None] * len(columns) * len(values_names)
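                # `h` consumes the rows of one group; each row fills the slot
                # for its (value column, pivot value) pair, and the assembled
                # output row is emitted once the group is exhausted.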
def h(row, done):
col = getattr(row, column_name)
for val_idx, value_name in enumerate(values_names):
val = getattr(row, value_name)
idx = len(columns) * val_idx + columns.index(col)
if values[idx] is not None:
raise ValueError(
'Row contains duplicate entries, rows: {0}, column: {1}'.format(keys, col))
values[idx] = val
if done:
yield keys + tuple(values)
                return h

fields = expr._group + expr._columns + expr._values
pivoted = expr.input.select(fields).map_reduce(reducer=reducer, group=group_names)
self._sub(new_expr, pivoted)
# trigger refresh of dynamic operations
refresh_dynamic(pivoted, self._dag)
return CollectionExpr(_schema=DynamicSchema.from_lists(group_names, group_types),
                              _deps=[(columns_expr, callback)])

def _get_pivot_table_sub_expr_without_columns(self, expr):
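        """Rewrite ``pivot_table`` without a ``columns`` argument into a plain
        groupby/aggregate, emitting one aggregation per (value, agg_func) pair.
        """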
def get_agg(field, agg_func, agg_func_name, fill_value):
from ..expr.expressions import ReprWrapper
if isinstance(agg_func, six.string_types):
aggregated = field.eval(agg_func, rewrite=False)
if isinstance(aggregated, ReprWrapper):
aggregated = aggregated()
else:
aggregated = field.agg(agg_func)
            if fill_value is not None:
                aggregated = aggregated.fillna(fill_value)
            return aggregated.rename('{0}_{1}'.format(field.name, agg_func_name))

grouped = expr.input.groupby(expr._group)
aggs = []
for agg_func, agg_func_name in zip(expr._agg_func, expr._agg_func_names):
for value in expr._values:
agg = get_agg(value, agg_func, agg_func_name, expr.fill_value)
aggs.append(agg)
        return grouped.aggregate(aggs, sort_by_name=False)

def _get_pivot_table_sub_expr_with_columns(self, expr):
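        """Rewrite ``pivot_table`` with a ``columns`` argument: first evaluate
        the distinct pivot values, then (in the callback below) emit one
        aggregation per (pivot value, value column, agg_func) combination,
        masking rows that do not match the pivot value with NULL.
        """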
columns_expr = expr.input.distinct([c.copy() for c in expr._columns])
group_names = [g.name for g in expr._group]
group_types = [g.dtype for g in expr._group]
        exprs = [expr]

def callback(result, new_expr):
expr = exprs[0]
columns = [r[0] for r in result]
names = list(group_names)
tps = list(group_types)
aggs = []
for agg_func_name, agg_func in zip(expr._agg_func_names, expr._agg_func):
for value_col in expr._values:
for col in columns:
base = '{0}_'.format(col) if col is not None else ''
name = '{0}{1}_{2}'.format(base, value_col.name, agg_func_name)
names.append(name)
tps.append(value_col.dtype)
col = col.item() if hasattr(col, 'item') else col
field = (expr._columns[0] == col).ifelse(
value_col, Scalar(_value_type=value_col.dtype))
if isinstance(agg_func, six.string_types):
agg = getattr(field, agg_func)()
else:
func = agg_func()
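                            # Wrap the user-defined aggregator so that the
                            # NULLs produced by the ifelse mask above are
                            # skipped instead of being fed to the aggregation.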
class ActualAgg(object):
def buffer(self):
return func.buffer()
def __call__(self, buffer, value):
if value is None:
return
func(buffer, value)
def merge(self, buffer, pbuffer):
func.merge(buffer, pbuffer)
def getvalue(self, buffer):
return func.getvalue(buffer)
agg = field.agg(ActualAgg)
if expr.fill_value is not None:
agg = agg.fillna(expr.fill_value)
agg = agg.rename(name)
aggs.append(agg)
new_expr._schema = TableSchema.from_lists(names, tps)
pivoted = expr.input.groupby(expr._group).aggregate(aggs, sort_by_name=False)
self._sub(new_expr, pivoted)
# trigger refresh of dynamic operations
refresh_dynamic(pivoted, self._dag)
return CollectionExpr(_schema=DynamicSchema.from_lists(group_names, group_types),
                              _deps=[(columns_expr, callback)])

def _get_pivot_table_sub_expr(self, expr):
if expr._columns is None:
return self._get_pivot_table_sub_expr_without_columns(expr)
else:
return self._get_pivot_table_sub_expr_with_columns(expr)