core/maxframe/codegen.py (352 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import base64 import dataclasses import logging from collections import defaultdict from enum import Enum from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union from odps.types import OdpsSchema from odps.utils import camel_to_underline from .core import OperatorType, Tileable, TileableGraph from .core.operator import Fetch, Operator from .extension import iter_extensions from .io.odpsio import build_dataframe_table_meta from .io.odpsio.schema import pandas_to_odps_schema from .lib import wrapped_pickle as pickle from .protocol import DataFrameTableMeta, ResultInfo from .serialization import PickleContainer from .serialization.serializables import Serializable, StringField from .typing_ import PandasObjectTypes from .udf import MarkedFunction, PythonPackOptions if TYPE_CHECKING: from odpsctx import ODPSSessionContext logger = logging.getLogger(__name__) @dataclasses.dataclass class CodeGenResult: code: str input_key_to_variables: Dict[str, str] output_key_to_variables: Dict[str, str] output_key_to_result_infos: Dict[str, ResultInfo] constants: Dict[str, Any] class AbstractUDF(Serializable): _session_id: str = StringField("session_id") def __init__(self, session_id: Optional[str] = None, **kw): super().__init__(_session_id=session_id, **kw) @property def name(self) -> str: return camel_to_underline(type(self).__name__) @property def session_id(self): return getattr(self, "_session_id", None) @session_id.setter def session_id(self, value: str): self._session_id = value @abc.abstractmethod def register(self, odps: "ODPSSessionContext", overwrite: bool = False): raise NotImplementedError @abc.abstractmethod def unregister(self, odps: "ODPSSessionContext"): raise NotImplementedError @abc.abstractmethod def collect_pythonpack(self) -> List[PythonPackOptions]: raise NotImplementedError @abc.abstractmethod def load_pythonpack_resources(self, odps_ctx: "ODPSSessionContext") -> None: raise NotImplementedError class UserCodeMixin: __slots__ = () @classmethod def obj_to_python_expr(cls, obj: Any = None) -> str: """ Parameters ---------- obj The object to convert to python expr. Returns ------- str : The str type content equals to the object when use in the python code directly. """ if obj is None: return "None" if isinstance(obj, (int, float)): return repr(obj) if isinstance(obj, bool): return "True" if obj else "False" if isinstance(obj, bytes): base64_bytes = base64.b64encode(obj) return f"base64.b64decode({base64_bytes})" if isinstance(obj, str): return repr(obj) if isinstance(obj, list): return ( f"[{', '.join([cls.obj_to_python_expr(element) for element in obj])}]" ) if isinstance(obj, dict): items = ( f"{repr(key)}: {cls.obj_to_python_expr(value)}" for key, value in obj.items() ) return f"{{{', '.join(items)}}}" if isinstance(obj, tuple): return f"({', '.join([cls.obj_to_python_expr(sub_obj) for sub_obj in obj])}{',' if len(obj) == 1 else ''})" if isinstance(obj, set): return ( f"{{{', '.join([cls.obj_to_python_expr(sub_obj) for sub_obj in obj])}}}" if obj else "set()" ) if isinstance(obj, PickleContainer): return UserCodeMixin.generate_pickled_codes(obj, None) raise ValueError(f"not support arg type {type(obj)}") @classmethod def generate_pickled_codes( cls, code_to_pickle: Any, unpicked_data_var_name: Union[str, None] = "pickled_data", ) -> str: """ Generate pickled codes. The final pickled variable is called 'pickled_data'. Parameters ---------- code_to_pickle: Any The code to be pickled. unpicked_data_var_name: str The variables in code used to hold the loads object from the cloudpickle Returns ------- str : The code snippets of pickling, the final variable is called 'pickled_data' by default. """ pickled, buffers = cls.dump_pickled_data(code_to_pickle) pickle_loads_expr = f"cloudpickle.loads({cls.obj_to_python_expr(pickled)}, buffers={cls.obj_to_python_expr(buffers)})" if unpicked_data_var_name: return f"{unpicked_data_var_name} = {pickle_loads_expr}" return pickle_loads_expr @staticmethod def dump_pickled_data( code_to_pickle: Any, ) -> Tuple[List[bytes], List[bytes]]: if isinstance(code_to_pickle, MarkedFunction): code_to_pickle = code_to_pickle.func if isinstance(code_to_pickle, PickleContainer): buffers = code_to_pickle.get_buffers() pickled = buffers[0] buffers = buffers[1:] else: pickled = pickle.dumps(code_to_pickle, protocol=pickle.DEFAULT_PROTOCOL) buffers = [] return pickled, buffers class BigDagCodeContext(metaclass=abc.ABCMeta): def __init__(self, session_id: str = None, subdag_id: str = None): self._session_id = session_id self._subdag_id = subdag_id self._tileable_key_to_variables = dict() self.constants = dict() self._data_table_meta_cache = dict() self._odps_schema_cache = dict() self._udfs = dict() self._tileable_key_to_result_infos = dict() self._next_var_id = 0 self._next_const_id = 0 @property def session_id(self) -> str: return self._session_id def register_udf(self, udf: AbstractUDF): udf.session_id = self._session_id self._udfs[udf.name] = udf def get_udfs(self) -> List[AbstractUDF]: return list(self._udfs.values()) def get_input_tileable_variable(self, tileable: Tileable) -> str: """ Get or create the variable name for an input tileable. It should be used on the RIGHT side of the assignment. """ return self._get_tileable_variable(tileable) def get_output_tileable_variable(self, tileable: Tileable) -> str: """ Get or create the variable name for an output tileable. It should be used on the LEFT side of the assignment. """ return self._get_tileable_variable(tileable) def _get_tileable_variable(self, tileable: Tileable) -> str: try: return self._tileable_key_to_variables[tileable.key] except KeyError: var_name = self.next_var_name() self._tileable_key_to_variables[tileable.key] = var_name return var_name def next_var_name(self) -> str: var_name = f"var_{self._next_var_id}" self._next_var_id += 1 return var_name def get_odps_schema( self, data: PandasObjectTypes, unknown_as_string: bool = False ) -> OdpsSchema: """ Get the corresponding ODPS schema of the input df_obj. Parameters ---------- data : The pandas data object. unknown_as_string : Whether mapping the unknown data type to a temp string value. Returns ------- OdpsSchema : The OdpsSchema of df_obj. """ if data.key not in self._odps_schema_cache: odps_schema, table_meta = pandas_to_odps_schema(data, unknown_as_string) self._data_table_meta_cache[data.key] = table_meta self._odps_schema_cache[data.key] = odps_schema return self._odps_schema_cache[data.key] def get_pandas_data_table_meta(self, data: PandasObjectTypes) -> DataFrameTableMeta: if data.key not in self._data_table_meta_cache: self._data_table_meta_cache[data.key] = build_dataframe_table_meta(data) return self._data_table_meta_cache[data.key] def register_operator_constants(self, const_val, var_name: str = None) -> str: if var_name is None: if ( isinstance(const_val, (int, str, bytes, bool, float)) or const_val is None ): return repr(const_val) var_name = f"const_{self._next_const_id}" self._next_const_id += 1 self.constants[var_name] = const_val return var_name def put_tileable_result_info( self, tileable: Tileable, result_info: ResultInfo ) -> None: self._tileable_key_to_result_infos[tileable.key] = result_info def get_tileable_result_infos(self) -> Dict[str, ResultInfo]: return self._tileable_key_to_result_infos class EngineAcceptance(Enum): """ DENY: The operator is not accepted by the current engine. ACCEPT: The operator is accepted by the current engine, and doesn't break from here. BREAK: The operator is accepted by the current engine, but should break from here. PREDECESSOR: The acceptance of the operator is decided by engines of its predecessors. If acceptance of all predecessors are SUCCESSOR, the acceptance of current operator is SUCCESSOR. Otherwise the engine selected in predecessors with highest priority is used. SUCCESSOR: The acceptance of the operator is decided by engines of its successors. If the operator has no successors, the acceptance will be treated as ACCEPT. Otherwise the engine selected in successors with highest priority is used. """ DENY = 0 ACCEPT = 1 BREAK = 2 PREDECESSOR = 3 SUCCESSOR = 4 @classmethod def _missing_(cls, pred: bool) -> "EngineAcceptance": """ A convenience method to get ACCEPT or DENY result via the input predicate. Parameters ---------- pred : bool The predicate variable. Returns ------- EngineAcceptance : Returns ACCEPT if the predicate is true, otherwise returns DENY. """ return cls.ACCEPT if pred else cls.DENY class BigDagOperatorAdapter(metaclass=abc.ABCMeta): # todo handle refcount issue when generated code is being executed def accepts(self, op: Operator) -> EngineAcceptance: return EngineAcceptance.ACCEPT @abc.abstractmethod def generate_code(self, op: OperatorType, context: BigDagCodeContext) -> List[str]: raise NotImplementedError def generate_comment( self, op: OperatorType, context: BigDagCodeContext ) -> List[str]: """ Generate the comment codes before actual ones. Parameters ---------- op : Operator The operator instance. context : BigDagCodeContext The BigDagCodeContext instance. Returns ------- result: List[str] The comment codes, one per line. """ return list() def generate_pre_op_code( self, op: Operator, context: BigDagCodeContext ) -> List[str]: """ Generate the codes before actually handling the operator. This method is usually implemented in the base class of each engine. Parameters ---------- op : Operator The operator instance. context : BigDagCodeContext The BigDagCodeContext instance. Returns ------- result: List[str] The codes generated before one operator actually handled, one per line. """ return list() def generate_post_op_code( self, op: Operator, context: BigDagCodeContext ) -> List[str]: """ Generate the codes after actually handling the operator. This method is usually implemented in the base class of each engine. Parameters ---------- op : Operator The operator instance. context : BigDagCodeContext The BigDagCodeContext instance. Returns ------- result: List[str] The codes generated after one operator actually handled, one per line. """ return list() _engine_to_codegen: Dict[str, Type["BigDagCodeGenerator"]] = dict() def register_engine_codegen(type_: Type["BigDagCodeGenerator"]): _engine_to_codegen[type_.engine_type] = type_ return type_ BUILTIN_ENGINE_SPE = "SPE" BUILTIN_ENGINE_MCSQL = "MCSQL" class BigDagCodeGenerator(metaclass=abc.ABCMeta): _context: BigDagCodeContext engine_type: Optional[str] = None engine_priority: int = 0 _extension_loaded = False _generate_comments_enabled: bool = True def __init__(self, session_id: str, subdag_id: str = None): self._session_id = session_id self._subdag_id = subdag_id self._context = self._init_context(session_id, subdag_id) self._generate_comments_enabled = True @classmethod def _load_engine_extensions(cls): if cls._extension_loaded: return for name, ep in iter_extensions(): _engine_to_codegen[name.upper()] = ep.get_codegen() cls._extension_loaded = True @classmethod def get_engine_types(cls) -> List[str]: cls._load_engine_extensions() engines = sorted( _engine_to_codegen.values(), key=lambda x: x.engine_priority, reverse=True ) return [e.engine_type for e in engines] @classmethod def get_by_engine_type(cls, engine_type: str) -> Type["BigDagCodeGenerator"]: cls._load_engine_extensions() return _engine_to_codegen[engine_type] @abc.abstractmethod def get_op_adapter( self, op_type: Type[OperatorType] ) -> Type[BigDagOperatorAdapter]: raise NotImplementedError @abc.abstractmethod def _init_context(self, session_id: str, subdag_id: str) -> BigDagCodeContext: raise NotImplementedError def _generate_delete_code(self, var_name: str) -> List[str]: return [] def generate_code(self, dag: TileableGraph) -> List[str]: """ Generate the code of the input dag. Parameters ---------- dag : TileableGraph The input DAG instance. Returns ------- List[str] : The code lines. """ code_lines = [] visited_op_key = set() result_key_set = set(t.key for t in dag.result_tileables) out_refcounts = dict() for tileable in dag.topological_iter(): op: OperatorType = tileable.op if op.key in visited_op_key or isinstance(op, Fetch): continue visited_op_key.add(op.key) adapter = self.get_op_adapter(type(op))() code_lines.extend(adapter.generate_pre_op_code(op, self._context)) if self._generate_comments_enabled: code_lines.extend(adapter.generate_comment(op, self._context)) code_lines.extend(adapter.generate_code(op, self._context)) code_lines.extend(adapter.generate_post_op_code(op, self._context)) code_lines.append("") # Append an empty line to separate operators # record refcounts for out_t in op.outputs: if out_t.key in result_key_set: continue if dag.count_successors(out_t) == 0: delete_code = self._generate_delete_code( self._context.get_input_tileable_variable(out_t) ) code_lines.extend(delete_code) else: out_refcounts[out_t.key] = dag.count_successors(out_t) # check if refs of inputs are no longer needed for inp_t in op.inputs: if inp_t.key not in out_refcounts: continue out_refcounts[inp_t.key] -= 1 if out_refcounts[inp_t.key] == 0: delete_code = self._generate_delete_code( self._context.get_input_tileable_variable(inp_t) ) code_lines.extend(delete_code) out_refcounts.pop(inp_t.key) return code_lines def generate(self, dag: TileableGraph) -> CodeGenResult: code_lines = self.generate_code(dag) input_key_to_vars = dict() for tileable in dag.topological_iter(): op: OperatorType = tileable.op if isinstance(op, Fetch): fetch_tileable = self._context.get_input_tileable_variable(tileable) input_key_to_vars[op.outputs[0].key] = fetch_tileable result_variables = { t.key: self._context.get_input_tileable_variable(t) for t in dag.results } return CodeGenResult( code="\n".join(code_lines), input_key_to_variables=input_key_to_vars, output_key_to_variables=result_variables, constants=self._context.constants, output_key_to_result_infos=self._context.get_tileable_result_infos(), ) def run_pythonpacks( self, odps_ctx: "ODPSSessionContext", python_tag: str, is_production: bool = False, schedule_id: Optional[str] = None, hints: Optional[dict] = None, priority: Optional[int] = None, ) -> Dict[str, PythonPackOptions]: key_to_packs = defaultdict(list) for udf in self._context.get_udfs(): for pack in udf.collect_pythonpack(): key_to_packs[pack.key].append(pack) distinct_packs = [] for packs in key_to_packs.values(): distinct_packs.append(packs[0]) inst_id_to_req = {} for pack in distinct_packs: inst = odps_ctx.run_pythonpack( requirements=pack.requirements, prefer_binary=pack.prefer_binary, pre_release=pack.pre_release, force_rebuild=pack.force_rebuild, no_audit_wheel=pack.no_audit_wheel, python_tag=python_tag, is_production=is_production, schedule_id=schedule_id, hints=hints, priority=priority, ) # fulfill instance id of pythonpacks with same keys for same_pack in key_to_packs[pack.key]: same_pack.pack_instance_id = inst.id inst_id_to_req[inst.id] = pack return inst_id_to_req def register_udfs(self, odps_ctx: "ODPSSessionContext"): for udf in self._context.get_udfs(): logger.info("[Session=%s] Registering UDF %s", self._session_id, udf.name) udf.register(odps_ctx, True) def unregister_udfs(self, odps_ctx: "ODPSSessionContext"): for udf in self._context.get_udfs(): logger.info("[Session=%s] Unregistering UDF %s", self._session_id, udf.name) udf.unregister(odps_ctx) def get_udfs(self) -> List[AbstractUDF]: return self._context.get_udfs()