def _get_pivot_table_handler()

in odps/df/backends/pd/compiler.py [0:0]


    def _get_pivot_table_handler(self, expr):
        """Build the execution callback for a pivot-table expression.

        The returned ``handle(kw)`` callable combines the group, column and
        value inputs through ``self._merge_values``, runs
        ``DataFrame.pivot_table`` on the result, rebuilds a flat DataFrame with
        generated column names, and substitutes ``expr`` in the expression DAG
        with a collection wrapping that DataFrame.

        :param expr: pivot-table expression; its ``_group``, ``_columns``,
            ``_values``, ``_agg_func``, ``_agg_func_names``, ``fill_value`` and
            ``_schema`` attributes are used.
        :return: a function taking ``kw`` and returning the tuple
            ``(substituted collection, result DataFrame)``.
        """
        from ...expr.query import ExprVisitor

        def _distinct_count(x):
            return np.size(np.unique(x))

        def _quantile(x, prob):
            # ``prob`` is a fraction while ``np.percentile`` takes a percentage
            return np.percentile(x, prob * 100)

        class WrappedNumpyFunction(object):
            """Callable wrapper used to tag functions resolved by the visitor."""

            def __init__(self, target):
                self._fun = target

            def __call__(self, *args, **kwargs):
                return self._fun(*args, **kwargs)

        class AggFuncVisitor(ExprVisitor):
            """Visitor that evaluates an aggregation expression on ``np_object``."""

            def __init__(self, np_object, env):
                super(AggFuncVisitor, self).__init__(env)
                self.np_object = np_object

            def get_named_object(self, obj_name):
                # A few names get dedicated implementations; everything else
                # is looked up as an attribute of the numpy module.
                if obj_name == 'count':
                    target = np.size
                elif obj_name == 'nunique':
                    target = _distinct_count
                elif obj_name == 'quantile':
                    target = _quantile
                else:
                    target = getattr(np, obj_name)
                return WrappedNumpyFunction(target)

            def visit_Call(self, node):
                callee = self.visit(node.func)
                positional = [self.visit(arg_node) for arg_node in node.args]
                if isinstance(callee, WrappedNumpyFunction):
                    # wrapped functions take the aggregated data as first argument
                    positional.insert(0, self.np_object)
                keywords = OrderedDict()
                for kw_node in node.keywords:
                    keywords[kw_node.arg] = self.visit(kw_node.value)
                return callee(*positional, **keywords)

        def get_real_aggfunc(aggfunc):
            """Turn an aggregation spec (string, class or callable) into a callable.

            NOTE(review): the inner callables keep their original names
            (``func``, ``agg_eval``, lambda) because pandas may use a
            callable's ``__name__`` when labelling results -- confirm before
            renaming them.
            """
            if not isinstance(aggfunc, six.string_types):
                if not inspect.isclass(aggfunc):
                    # already a plain callable (or anything else): pass through
                    return aggfunc

                aggregator = aggfunc()

                def func(x):
                    # drive the aggregator through buffer / iterate / getvalue
                    state = aggregator.buffer()
                    for item in x:
                        aggregator(state, item)
                    return aggregator.getvalue(state)

                return func

            if aggfunc == 'count':
                return np.size
            if aggfunc == 'nunique':
                return lambda x: np.size(np.unique(x))
            if hasattr(np, aggfunc):
                return getattr(np, aggfunc)

            # Not a numpy attribute name: evaluate the string with the visitor.
            def agg_eval(x):
                return AggFuncVisitor(x, {}).eval(aggfunc, rewrite=False)

            return agg_eval

        def handle(kw):
            column_exprs = expr._columns or []
            merged = self._merge_values(
                expr._group + column_exprs + expr._values, kw)

            index_names = self._get_names(expr._group)
            column_names = self._get_names(expr._columns)
            value_names = self._get_names(expr._values)
            agg_callables = [get_real_aggfunc(f) for f in expr._agg_func]
            pivoted = merged.pivot_table(
                index=index_names,
                columns=column_names,
                values=value_names,
                aggfunc=agg_callables,
                fill_value=expr.fill_value,
            )

            # Capture the column levels before reset_index alters the columns.
            if isinstance(pivoted.columns, pd.MultiIndex):
                levels = pivoted.columns.levels
            else:
                levels = [pivoted.columns]
            pivoted.reset_index(inplace=True)

            names = self._get_names(expr._group, True)
            tps = [group_expr.dtype for group_expr in expr._group]
            if expr._columns:
                pivot_labels = levels[-1]
            else:
                pivot_labels = [None, ]

            # One output column per (aggregation, value, pivot label) triple.
            for func_name in expr._agg_func_names:
                for value_expr in expr._values:
                    for label in pivot_labels:
                        if label is not None:
                            prefix = '{0}_'.format(label)
                        else:
                            prefix = ''
                        names.append('{0}{1}_{2}'.format(
                            prefix, value_expr.name, func_name))
                        tps.append(value_expr.dtype)

            if expr._columns:
                # the schema is only rebuilt when pivot columns are present
                expr._schema = TableSchema.from_lists(names, tps)

            res = pd.DataFrame(pivoted.values, columns=names)
            to_sub = CollectionExpr(_source_data=res, _schema=expr._schema)
            self._expr_dag.substitute(expr, to_sub)

            # trigger refresh of dynamic operations
            def compile_unvisited(root):
                for node in traverse_until_source(root, unique=True):
                    if node not in self._expr_to_dag_node:
                        node.accept(self)

            refresh_dynamic(to_sub, self._expr_dag, func=compile_unvisited)

            return to_sub, res

        return handle