def _compile_function()

in core/maxframe/dataframe/reduction/core.py [0:0]


    def _compile_function(self, func, func_name=None, ndim=1) -> ReductionSteps:
        """
        Compile a user-defined reduction function into reduction steps.

        ``func`` is applied to a mock MaxFrame object and the resulting
        tileable graph is split around its atomic operators (those whose op
        has a truthy ``is_atomic``):

        * pre steps: unary / binary / where operators feeding each atomic
          aggregation, one step per distinct aggregation input;
        * agg steps: one per atomic aggregation operator;
        * post step: unary / binary / where operators combining the
          aggregation results into the final output.

        Compiled results are memoized in ``_func_compile_cache`` keyed on the
        token of ``func``, the reduction axis, ``func_name`` and ``ndim``.

        Parameters
        ----------
        func : callable
            Function to compile. If it is a ``CustomReduction`` instance it is
            also recorded in every generated aggregation step.
        func_name : str, optional
            Name recorded in the generated aggregation and post steps.
        ndim : int
            Passed to ``_build_mock_return_object``; the object returned by
            ``func`` must have an ``ndim`` strictly smaller than this.

        Returns
        -------
        ReductionSteps
            The pre, aggregation and post steps of the reduction.

        Raises
        ------
        ValueError
            If ``func`` does not return a MaxFrame object, does not reduce
            dimensionality, uses an unsupported operator before or after the
            aggregation, or an aggregation input depends on more than one
            independent source.
        """
        from ...tensor.arithmetic.core import TensorBinOp, TensorUnaryOp
        from ...tensor.misc import TensorWhere
        from ..arithmetic.core import DataFrameBinOp, DataFrameUnaryOp
        from ..datasource.dataframe import DataFrameDataSource
        from ..datasource.series import SeriesDataSource
        from ..indexing.where import DataFrameWhere

        func_token = tokenize(func, self._axis, func_name, ndim)
        cached = _func_compile_cache.get(func_token)
        if cached is not None:
            return cached
        custom_reduction = func if isinstance(func, CustomReduction) else None

        self._check_function_valid(func)

        try:
            func_ret = self._build_mock_return_object(func, float, ndim=ndim)
        except (TypeError, AttributeError):
            # we may encounter lambda x: x.str.cat(...), use an object series to test
            func_ret = self._build_mock_return_object(func, object, ndim=1)
        output_limit = getattr(func, "output_limit", None) or 1

        if not isinstance(func_ret, ENTITY_TYPE):
            raise ValueError(
                f"Custom function should return a MaxFrame object, not {type(func_ret)}"
            )
        if func_ret.ndim >= ndim:
            raise ValueError("Function not a reduction")

        agg_graph = func_ret.build_graph()
        # An insertion-ordered dict used as an ordered set: membership tests
        # behave like a set, but iteration (and hence the order of the
        # generated steps) follows graph iteration order instead of hash order.
        agg_tileables = dict.fromkeys(
            t for t in agg_graph if getattr(t.op, "is_atomic", False)
        )

        # operators allowed both before and after the aggregation
        elementwise_op_types = (
            DataFrameUnaryOp,
            DataFrameBinOp,
            TensorUnaryOp,
            TensorBinOp,
            TensorWhere,
            DataFrameWhere,
        )
        # before the aggregation, data source operators are allowed as well
        pre_agg_op_types = elementwise_op_types + (
            DataFrameDataSource,
            SeriesDataSource,
        )

        def check_operators(tileables, allowed_op_types):
            # raise on the first non-aggregation tileable with a disallowed op
            for t in tileables:
                if t not in agg_tileables and not isinstance(t.op, allowed_op_types):
                    raise ValueError(f"Cannot support operator {type(t.op)} in aggregation")

        # check operators before aggregation
        check_operators(
            agg_graph.dfs(list(agg_tileables), visit_predicate="all", reverse=True),
            pre_agg_op_types,
        )
        # check operators after aggregation
        check_operators(
            agg_graph.dfs(list(agg_tileables), visit_predicate="all"),
            elementwise_op_types,
        )

        pre_funcs, agg_funcs, post_funcs = [], [], []
        visited_inputs = set()
        # collect aggregations and their inputs
        for t in agg_tileables:
            agg_input = t.inputs[0]
            agg_input_key = agg_input.key

            # count / size keep their own name for the map stage but are
            # combined with "sum" in the agg stage
            step_func_name = t.op._func_name
            if step_func_name in ("count", "size"):
                map_func_name, agg_func_name = step_func_name, "sum"
            else:
                map_func_name, agg_func_name = step_func_name, step_func_name

            # build agg description
            agg_funcs.append(
                ReductionAggStep(
                    agg_input_key,
                    func_name,
                    step_func_name,
                    map_func_name,
                    agg_func_name,
                    custom_reduction,
                    t.key,
                    output_limit,
                    t.op.get_reduction_args(axis=self._axis),
                )
            )

            # build one pre step per distinct aggregation input
            if agg_input_key in visited_inputs:
                continue
            visited_inputs.add(agg_input_key)
            initial_inputs = list(agg_input.build_graph().iter_indep())
            # Explicit check rather than ``assert``: extra data sources are
            # accepted by the operator check above, so this is reachable from
            # user code and must not disappear under ``python -O``.
            if len(initial_inputs) != 1:
                raise ValueError(
                    f"Cannot support {len(initial_inputs)} independent inputs "
                    f"in aggregation, expect exactly 1"
                )
            input_key = initial_inputs[0].key

            func_idl, _ = self._generate_function_idl(agg_input)
            pre_funcs.append(
                ReductionPreStep(
                    input_key, agg_input_key, None, msgpack.dumps(func_idl)
                )
            )

        # collect function output after agg
        func_idl, input_keys = self._generate_function_idl(func_ret)
        post_funcs.append(
            ReductionPostStep(
                input_keys, func_ret.key, func_name, None, msgpack.dumps(func_idl)
            )
        )
        if len(_func_compile_cache) > 100:  # pragma: no cover
            # keep the cache bounded by evicting the first key in iteration order
            _func_compile_cache.pop(next(iter(_func_compile_cache)))
        result = _func_compile_cache[func_token] = ReductionSteps(
            pre_funcs, agg_funcs, post_funcs
        )
        return result