odps/df/backends/engine.py (237 lines of code) (raw):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

try:
    import pandas as pd
except (ImportError, ValueError):
    # pandas is optional; code paths that need it will fail on use of `pd`
    pass

from .core import Engine
from .odpssql.engine import ODPSSQLEngine
from .pd.engine import PandasEngine
from .seahawks.engine import SeahawksEngine
from .seahawks.models import SeahawksTable
from .sqlalchemy.engine import SQLAlchemyEngine
from .selecter import available_engines, Engines, EngineSelecter
from .formatter import ExprExecutionGraphFormatter
from .context import context
from .utils import process_persist_kwargs
from .. import Scalar
from ..expr.core import ExprDAG
from ..expr.expressions import CollectionExpr
from ..expr.merge import JoinCollectionExpr, UnionCollectionExpr
from ..expr.element import IsIn
from ...models import Table
from ... import options
from ...utils import gen_repr_object


def get_default_engine(*exprs):
    """
    Pick the engine used to run the given expressions.

    If any expression already carries an ``_engine``, that engine is returned
    right away. Otherwise the data sources of every expression are inspected
    to collect the candidate engine types and to find an ODPS entry object,
    and a :class:`MixedEngine` wrapping both is returned.

    When no ODPS entry can be found from the data sources, one is created
    from the global ``options`` provided that account, endpoint and default
    project are all configured.

    :param exprs: expressions to be executed
    :return: the engine bound to an expression, or a new ``MixedEngine``
    :raises NotImplementedError: if the only available engine of an
        expression is not one of ODPS, ALGO, PANDAS or SQLALCHEMY
    """
    # imported here rather than at module level (presumably to avoid a
    # circular import -- the original code does the same)
    from ... import ODPS

    odps = None
    engines = set()
    for expr in exprs:
        if expr._engine:
            return expr._engine

        srcs = list(expr.data_source())
        expr_engines = list(available_engines(srcs))
        engines.update(set(expr_engines))

        if len(expr_engines) == 1:
            engine = expr_engines[0]
            src = srcs[0]
            if engine in (Engines.ODPS, Engines.ALGO):
                expr_odps = src.project.odps
            elif engine in (Engines.PANDAS, Engines.SQLALCHEMY):
                expr_odps = None
            else:
                raise NotImplementedError
        else:
            # Take the ODPS entry from the first source that has a project.
            # A default is supplied so that sources without any project do
            # not leak a bare StopIteration out of this function.
            table_src = next((it for it in srcs if hasattr(it, 'project')), None)
            expr_odps = table_src.project.odps if table_src is not None else None

        if expr_odps is not None:
            odps = expr_odps

    if odps is None and options.account is not None and \
            options.endpoint is not None and options.default_project is not None:
        odps = ODPS._from_account(
            options.account, options.default_project,
            endpoint=options.endpoint, tunnel_endpoint=options.tunnel.endpoint,
            overwrite_global=False,
        )

    return MixedEngine(odps, list(engines))


class MixedEngine(Engine):
    """
    Engine that routes each expression DAG to one of several backend engines
    (ODPS SQL, pandas, Seahawks, SQLAlchemy or the algorithm engine) as
    chosen by an ``EngineSelecter``, and inserts extra steps when a single
    expression mixes data sources handled by different backends.
    """

    def __init__(self, odps, engines=None):
        """
        :param odps: ODPS entry object handed to every backend engine,
            may be None
        :param engines: list of candidate engine types, stored as-is
        """
        self._odps = odps
        self._engines = engines
        # names of temporary tables generated through `_gen_table_name`
        self._generated_table_names = []

        self._selecter = EngineSelecter()

        self._pandas_engine = PandasEngine(self._odps)
        self._odpssql_engine = ODPSSQLEngine(self._odps)
        self._seahawks_engine = SeahawksEngine(self._odps)
        self._sqlalchemy_engine = SQLAlchemyEngine(self._odps)

        from ...ml.engine import OdpsAlgoEngine
        self._xflow_engine = OdpsAlgoEngine(self._odps)

    def stop(self):
        """Stop the pandas, ODPS SQL, Seahawks and algorithm backends."""
        self._pandas_engine.stop()
        self._odpssql_engine.stop()
        self._seahawks_engine.stop()
        self._xflow_engine.stop()

    def _gen_table_name(self):
        """Generate a table name via the ODPS SQL backend and record it."""
        table_name = self._odpssql_engine._gen_table_name()
        self._generated_table_names.append(table_name)
        return table_name

    def _get_backend(self, expr_dag):
        """Return the backend engine instance the selecter picks for ``expr_dag``."""
        engine = self._selecter.select(expr_dag)
        if engine == Engines.ODPS:
            return self._odpssql_engine
        elif engine == Engines.PANDAS:
            return self._pandas_engine
        elif engine == Engines.SQLALCHEMY:
            return self._sqlalchemy_engine
        elif engine == Engines.ALGO:
            return self._xflow_engine
        else:
            assert engine == Engines.SEAHAWKS
            return self._seahawks_engine

    def _delegate(self, method, expr_dag, dag, expr, **kwargs):
        """Call ``method`` (looked up by name) on the backend chosen for ``expr_dag``."""
        return getattr(self._get_backend(expr_dag), method)(expr_dag, dag, expr, **kwargs)

    def _cache(self, expr_dag, dag, expr, **kwargs):
        """Forward ``_cache`` to the selected backend."""
        return self._delegate('_cache', expr_dag, dag, expr, **kwargs)

    def _handle_dep(self, expr_dag, dag, expr, **kwargs):
        """Forward ``_handle_dep`` to the selected backend."""
        return self._delegate('_handle_dep', expr_dag, dag, expr, **kwargs)

    def _execute(self, expr_dag, dag, expr, **kwargs):
        """Forward ``_execute`` to the selected backend."""
        return self._delegate('_execute', expr_dag, dag, expr, **kwargs)

    def _persist(self, name, expr_dag, dag, expr, **kwargs):
        """Forward ``_persist`` to the selected backend."""
        return self._get_backend(expr_dag)._persist(name, expr_dag, dag, expr, **kwargs)

    def _handle_join_or_union(self, expr_dag, dag, _, **kwargs):
        """
        Prepare a join/union whose two sides have different data sources.

        The side that has no ODPS data source (the right side when the left
        one has) is persisted into a generated ODPS table, and it is replaced
        in ``expr_dag`` by a collection reading from that table.

        :return: whatever the backend's ``_persist`` returns, or None when
            the root does not mix data sources
        """
        root = expr_dag.root
        if not self._selecter.has_diff_data_sources(root, no_cache=True):
            return

        if self._selecter.has_odps_data_source(root._lhs):
            to_execute = root._rhs
        else:
            to_execute = root._lhs

        table_name = self._gen_table_name()
        odps_table = self._odps.get_table(
            table_name, schema=self._odpssql_engine._ctx.default_schema)
        sub = CollectionExpr(_source_data=odps_table, _schema=to_execute.schema)
        sub.add_deps(to_execute)
        expr_dag.substitute(to_execute, sub)

        # prevent the kwargs come from `persist`
        process_persist_kwargs(kwargs)

        execute_dag = ExprDAG(to_execute, dag=expr_dag)

        return self._get_backend(execute_dag)._persist(
            table_name, execute_dag, dag, to_execute, **kwargs)

    def _handle_isin(self, expr_dag, dag, expr, **kwargs):
        """
        Prepare an ``isin`` whose value sequence has a different data source.

        The first element of ``expr._values`` is executed on its own backend;
        once the result is available, the callback rewrites ``expr._values``
        as a tuple of ``Scalar`` built from the first column of the result.

        :return: the node returned by the backend's ``_execute``, or None
            when the root does not mix data sources
        """
        if not self._selecter.has_diff_data_sources(expr_dag.root, no_cache=True):
            return

        seq = expr._values[0]
        expr._values = None
        execute_dag = ExprDAG(seq, dag=expr_dag)

        execute_node = self._get_backend(execute_dag)._execute(
            execute_dag, dag, seq, **kwargs)

        def callback(res):
            vals = res[:, 0].tolist()
            expr._values = tuple(Scalar(val) for val in vals)

        execute_node.callback = callback

        return execute_node

    def _handle_function(self, expr_dag, dag, _, **kwargs):
        """
        Prepare a function expression whose collection resources live on a
        different kind of data source than its input.

        * input from ODPS, resource not from ODPS: the resource is persisted
          into a generated ODPS table and substituted by a collection over it;
        * input not from ODPS, resource from ODPS: the resource is executed by
          the ODPS SQL engine and substituted by a pandas-backed collection
          whose data is filled in by a callback once the result is available.

        :raises NotImplementedError: in the second case when the input has no
            pandas data source
        """
        root = expr_dag.root

        # if expr input comes from an ODPS table
        is_root_input_from_odps = \
            self._selecter.has_odps_data_source(root.children()[0])

        def make_callback(target):
            # Bind `target` for each resource separately. A closure defined
            # directly in the loop would look up the loop variable when it is
            # finally invoked, so every callback would fill the last
            # substituted collection only.
            def callback(res):
                target._source_data = res.values

            return callback

        for collection in root._collection_resources:
            # if collection resource comes from an ODPS table
            is_source_from_odps = self._selecter.has_odps_data_source(collection)

            if is_root_input_from_odps and not is_source_from_odps:
                table_name = self._gen_table_name()
                odps_table = self._odps.get_table(
                    table_name, schema=self._odpssql_engine._ctx.default_schema)
                sub = CollectionExpr(_source_data=odps_table, _schema=collection.schema)
                sub.add_deps(collection)
                expr_dag.substitute(collection, sub)

                # prevent the kwargs come from `persist`
                process_persist_kwargs(kwargs)

                execute_dag = ExprDAG(collection, dag=expr_dag)
                self._get_backend(execute_dag)._persist(
                    table_name, execute_dag, dag, collection, **kwargs)
            elif not is_root_input_from_odps and is_source_from_odps:
                if not self._selecter.has_pandas_data_source(root.children()[0]):
                    raise NotImplementedError

                # placeholder frame, replaced by the callback below
                sub = CollectionExpr(_source_data=pd.DataFrame(), _schema=collection.schema)
                sub.add_deps(collection)
                expr_dag.substitute(collection, sub)

                execute_node = self._odpssql_engine._execute(
                    ExprDAG(collection, dag=expr_dag), dag, collection, **kwargs)
                execute_node.callback = make_callback(sub)
            else:
                continue

    def _dispatch(self, expr_dag, expr, ctx):
        """
        Return the handler to be applied to ``expr``.

        Algorithm expressions are dispatched by the algorithm engine. For
        joins/unions, ``isin`` and function expressions that mix data sources
        the matching ``_handle_*`` method is chained in front of the handler
        of the base class; otherwise the base class handler is returned
        unchanged.

        :return: a callable, or None when there is nothing to do
        """
        from ...ml.expr import AlgoExprMixin

        funcs = []

        if isinstance(expr, AlgoExprMixin):
            return self._xflow_engine._dispatch(expr_dag, expr, ctx)

        handle = None
        if isinstance(expr, (JoinCollectionExpr, UnionCollectionExpr)) and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_join_or_union
        elif isinstance(expr, IsIn) and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_isin
        elif hasattr(expr, '_func') and \
                getattr(expr, '_collection_resources', None) is not None and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_function
        if handle is not None:
            funcs.append(handle)

        h = super(MixedEngine, self)._dispatch(expr_dag, expr, ctx)
        if h is not None:
            if handle is None:
                return h
            funcs.append(h)

        if funcs:
            def f(*args, **kwargs):
                # run the mixed-source handler first, then the base handler
                for func in funcs:
                    func(*args, **kwargs)

            return f

    def _get_cached_sub_expr(self, cached_expr, ctx=None):
        """
        Return the cached replacement of ``cached_expr``, ignoring Seahawks
        tables once ODPS has been forced.
        """
        ctx = ctx or context
        data = ctx.get_cached(cached_expr)
        if self._selecter.force_odps and isinstance(data, SeahawksTable):
            # skip seahawks heap table
            return
        return super(MixedEngine, self)._get_cached_sub_expr(cached_expr, ctx=ctx)

    def visualize(self, expr):
        """
        Compile ``expr`` and return a repr object holding the SVG of its
        execution graph. The recorded generated table names are cleared
        afterwards, whether or not rendering succeeds.
        """
        dag = self.compile(expr)
        try:
            formatter = ExprExecutionGraphFormatter(dag)
            return gen_repr_object(svg=formatter._repr_svg_())
        finally:
            self._generated_table_names = []

    def compile(self, *expr_args_kwargs):
        """
        Compile the given expressions into an execution DAG.

        Compilation is restarted with ``force_odps`` switched on when a
        persist is requested, when ODPS and Seahawks are both selected, or
        when a backend raises ``NotImplementedError``. If Seahawks is
        selected, the returned DAG carries ``fallback`` / ``need_fallback``
        hooks that recompile with ODPS forced.

        :raises NotImplementedError: if compilation is unsupported even with
            ODPS forced
        """
        src_expr_args_kwargs = tuple(expr_args_kwargs)
        exprs_dags, expr_args_kwargs = self._process(*expr_args_kwargs)

        if not self._selecter.force_odps and \
                any(it[0] == '_persist' for it in expr_args_kwargs):
            self._selecter.force_odps = True
            return self.compile(*src_expr_args_kwargs)

        engine_types = [self._selecter.select(expr_dag) for expr_dag in exprs_dags]
        if any(engine_type == Engines.ODPS for engine_type in engine_types):
            self._selecter.force_odps = True
            if any(engine_type == Engines.SEAHAWKS for engine_type in engine_types):
                return self.compile(*src_expr_args_kwargs)

        try:
            dag = super(MixedEngine, self)._compile_dag(expr_args_kwargs, exprs_dags)
        except NotImplementedError:
            if self._selecter.force_odps:
                # ODPS was already forced, so retrying would repeat the same
                # compilation and recurse without end; surface the error.
                raise
            self._selecter.force_odps = True
            return self.compile(*src_expr_args_kwargs)

        if any(engine_type == Engines.SEAHAWKS for engine_type in engine_types):
            def fallback():
                self._selecter.force_odps = True
                return self.compile(*src_expr_args_kwargs)

            def need_fallback(e):
                try:
                    import sqlalchemy
                    exceptions = (NotImplementedError, sqlalchemy.exc.DatabaseError)
                except ImportError:
                    exceptions = (NotImplementedError,)
                if not isinstance(e, exceptions):
                    return False
                if not isinstance(e, NotImplementedError) and \
                        'AXF Exception' not in str(e):
                    # the seahawks error
                    return False
                return True

            if not all(not hasattr(n, 'verify') or n.verify()
                       for n in dag.indep_nodes()):
                # we only verify the independent nodes
                dag = fallback()

            dag.fallback = fallback
            dag.need_fallback = need_fallback

        return dag