in core/maxframe/dataframe/reduction/core.py [0:0]
def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
    """Compile a reduction callable into pre/agg/post execution steps.

    The callable is executed against a mock object so that it records an
    operator graph; the graph is then validated (only element-wise style
    operators are allowed around the atomic aggregations) and partitioned
    into the functions to run before, during and after aggregation.
    Compiled results are memoized in ``_func_compile_cache``.
    """
    from ...tensor.arithmetic.core import TensorBinOp, TensorUnaryOp
    from ...tensor.misc import TensorWhere
    from ..arithmetic.core import DataFrameBinOp, DataFrameUnaryOp
    from ..datasource.dataframe import DataFrameDataSource
    from ..datasource.series import SeriesDataSource
    from ..indexing.where import DataFrameWhere

    # serve repeated compilations of the same function from the cache
    func_token = tokenize(func, self._axis, func_name, ndim)
    cached = _func_compile_cache.get(func_token)
    if cached is not None:
        return cached

    custom_reduction = func if isinstance(func, CustomReduction) else None

    self._check_function_valid(func)

    # run the function on a mock object so it records an operator graph
    try:
        func_ret = self._build_mock_return_object(func, float, ndim=ndim)
    except (TypeError, AttributeError):
        # we may encounter lambda x: x.str.cat(...), use an object series to test
        # NOTE(review): the fallback always probes with ndim=1 while the primary
        # path uses ndim=ndim — confirm this is intended for 2-D inputs
        func_ret = self._build_mock_return_object(func, object, ndim=1)
    output_limit = getattr(func, "output_limit", None) or 1

    if not isinstance(func_ret, ENTITY_TYPE):
        raise ValueError(
            f"Custom function should return a MaxFrame object, not {type(func_ret)}"
        )
    # a genuine reduction must lower the dimensionality of its input
    if func_ret.ndim >= ndim:
        raise ValueError("Function not a reduction")

    agg_graph = func_ret.build_graph()
    # atomic nodes are the aggregation operators themselves
    agg_tileables = {node for node in agg_graph if getattr(node.op, "is_atomic", False)}

    elementwise_ops = (
        DataFrameUnaryOp,
        DataFrameBinOp,
        TensorUnaryOp,
        TensorBinOp,
        TensorWhere,
        DataFrameWhere,
    )
    # before the aggregation, data sources are also acceptable
    pre_allowed_ops = elementwise_ops + (DataFrameDataSource, SeriesDataSource)

    # check operators before aggregation
    for node in agg_graph.dfs(
        list(agg_tileables), visit_predicate="all", reverse=True
    ):
        if node not in agg_tileables and not isinstance(node.op, pre_allowed_ops):
            raise ValueError(f"Cannot support operator {type(node.op)} in aggregation")
    # check operators after aggregation
    for node in agg_graph.dfs(list(agg_tileables), visit_predicate="all"):
        if node not in agg_tileables and not isinstance(node.op, elementwise_ops):
            raise ValueError(f"Cannot support operator {type(node.op)} in aggregation")

    pre_funcs, agg_funcs, post_funcs = [], [], []
    visited_inputs = set()
    # collect aggregations and their inputs
    for node in agg_tileables:
        agg_input_key = node.inputs[0].key

        # count/size aggregate partial results with a sum; others keep their name
        step_func_name = getattr(node.op, "_func_name")
        map_func_name = step_func_name
        agg_func_name = (
            "sum" if step_func_name in ("count", "size") else step_func_name
        )

        # build agg description
        agg_funcs.append(
            ReductionAggStep(
                agg_input_key,
                func_name,
                step_func_name,
                map_func_name,
                agg_func_name,
                custom_reduction,
                node.key,
                output_limit,
                node.op.get_reduction_args(axis=self._axis),
            )
        )

        # emit the pre-aggregation function once per distinct agg input
        if agg_input_key not in visited_inputs:
            visited_inputs.add(agg_input_key)
            initial_inputs = list(node.inputs[0].build_graph().iter_indep())
            assert len(initial_inputs) == 1
            func_idl, _ = self._generate_function_idl(node.inputs[0])
            pre_funcs.append(
                ReductionPreStep(
                    initial_inputs[0].key,
                    agg_input_key,
                    None,
                    msgpack.dumps(func_idl),
                )
            )

    # collect function output after agg
    func_idl, input_keys = self._generate_function_idl(func_ret)
    post_funcs.append(
        ReductionPostStep(
            input_keys, func_ret.key, func_name, None, msgpack.dumps(func_idl)
        )
    )

    # bound the cache before inserting the freshly compiled entry
    if len(_func_compile_cache) > 100:  # pragma: no cover
        _func_compile_cache.pop(next(iter(_func_compile_cache.keys())))
    result = _func_compile_cache[func_token] = ReductionSteps(
        pre_funcs, agg_funcs, post_funcs
    )
    return result