odps/df/backends/engine.py (237 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
try:
import pandas as pd
except (ImportError, ValueError):
pass
from .core import Engine
from .odpssql.engine import ODPSSQLEngine
from .pd.engine import PandasEngine
from .seahawks.engine import SeahawksEngine
from .seahawks.models import SeahawksTable
from .sqlalchemy.engine import SQLAlchemyEngine
from .selecter import available_engines, Engines, EngineSelecter
from .formatter import ExprExecutionGraphFormatter
from .context import context
from .utils import process_persist_kwargs
from .. import Scalar
from ..expr.core import ExprDAG
from ..expr.expressions import CollectionExpr
from ..expr.merge import JoinCollectionExpr, UnionCollectionExpr
from ..expr.element import IsIn
from ...models import Table
from ... import options
from ...utils import gen_repr_object
def get_default_engine(*exprs):
    """Return the engine that should execute ``exprs``.

    If any expression already has an engine attached (``expr._engine``), that
    engine is returned immediately. Otherwise a :class:`MixedEngine` is built
    over the union of engines available for all the expressions' data
    sources. Its ODPS entry object comes from a data source's project when one
    is found. Failing that, it is created from the global ``options``
    (account, endpoint and default project must all be set).

    :param exprs: expressions that are going to be executed
    :return: the engine to execute them with
    :raises NotImplementedError: if an expression has exactly one available
        engine and it is not ODPS, ALGO, PANDAS or SQLALCHEMY
    """
    from ... import ODPS

    odps = None
    engines = set()
    for expr in exprs:
        if expr._engine:
            return expr._engine

        srcs = list(expr.data_source())
        expr_engines = list(available_engines(srcs))
        engines.update(expr_engines)

        if len(expr_engines) == 1:
            engine = expr_engines[0]
            src = srcs[0]
            if engine in (Engines.ODPS, Engines.ALGO):
                expr_odps = src.project.odps
            elif engine in (Engines.PANDAS, Engines.SQLALCHEMY):
                expr_odps = None
            else:
                raise NotImplementedError
        else:
            # Mixed (or absent) sources: borrow the ODPS entry of the first
            # source bound to a project. There may be none (e.g. pandas mixed
            # with SQLAlchemy), in which case no ODPS entry is taken from here
            # rather than letting a bare StopIteration escape.
            table_src = next((it for it in srcs if hasattr(it, 'project')), None)
            expr_odps = table_src.project.odps if table_src is not None else None

        if expr_odps is not None:
            odps = expr_odps

    if odps is None and options.account is not None and \
            options.endpoint is not None and options.default_project is not None:
        odps = ODPS._from_account(
            options.account, options.default_project,
            endpoint=options.endpoint,
            tunnel_endpoint=options.tunnel.endpoint,
            overwrite_global=False,
        )

    return MixedEngine(odps, list(engines))
class MixedEngine(Engine):
    """Engine that routes each expression DAG to a suitable backend.

    The backends are pandas, ODPS SQL, Seahawks, SQLAlchemy and the ODPS
    algorithm (XFlow) engine. When a single expression mixes data sources
    (e.g. a join between a pandas DataFrame and an ODPS table), the
    ``_handle_*`` steps move data between backends before the operation
    itself runs.
    """

    def __init__(self, odps, engines=None):
        self._odps = odps
        self._engines = engines
        # names of temporary tables created through ``_gen_table_name``
        self._generated_table_names = []
        self._selecter = EngineSelecter()

        self._pandas_engine = PandasEngine(self._odps)
        self._odpssql_engine = ODPSSQLEngine(self._odps)
        self._seahawks_engine = SeahawksEngine(self._odps)
        self._sqlalchemy_engine = SQLAlchemyEngine(self._odps)

        # imported lazily (presumably to avoid an import cycle with the ml package)
        from ...ml.engine import OdpsAlgoEngine
        self._xflow_engine = OdpsAlgoEngine(self._odps)

    def stop(self):
        """Stop every backend engine owned by this mixed engine."""
        backends = (
            self._pandas_engine,
            self._odpssql_engine,
            self._seahawks_engine,
            self._xflow_engine,
            # previously never stopped; stopped last so the original order of
            # the other backends is preserved
            self._sqlalchemy_engine,
        )
        for backend in backends:
            # guard in case a backend does not implement ``stop``
            stop = getattr(backend, 'stop', None)
            if stop is not None:
                stop()

    def _gen_table_name(self):
        """Generate a temporary table name via the ODPS SQL engine and record it."""
        table_name = self._odpssql_engine._gen_table_name()
        self._generated_table_names.append(table_name)
        return table_name

    def _get_backend(self, expr_dag):
        """Return the backend engine the selecter picks for ``expr_dag``."""
        engine = self._selecter.select(expr_dag)
        if engine == Engines.ODPS:
            return self._odpssql_engine
        elif engine == Engines.PANDAS:
            return self._pandas_engine
        elif engine == Engines.SQLALCHEMY:
            return self._sqlalchemy_engine
        elif engine == Engines.ALGO:
            return self._xflow_engine
        else:
            assert engine == Engines.SEAHAWKS
            return self._seahawks_engine

    def _delegate(self, method, expr_dag, dag, expr, **kwargs):
        """Call ``method`` on the backend selected for ``expr_dag``."""
        return getattr(self._get_backend(expr_dag), method)(expr_dag, dag, expr, **kwargs)

    def _cache(self, expr_dag, dag, expr, **kwargs):
        return self._delegate('_cache', expr_dag, dag, expr, **kwargs)

    def _handle_dep(self, expr_dag, dag, expr, **kwargs):
        return self._delegate('_handle_dep', expr_dag, dag, expr, **kwargs)

    def _execute(self, expr_dag, dag, expr, **kwargs):
        return self._delegate('_execute', expr_dag, dag, expr, **kwargs)

    def _persist(self, name, expr_dag, dag, expr, **kwargs):
        # not routed through ``_delegate`` because of the leading ``name`` argument
        return self._get_backend(expr_dag)._persist(name, expr_dag, dag, expr, **kwargs)

    def _handle_join_or_union(self, expr_dag, dag, _, **kwargs):
        """Move the non-ODPS side of a mixed-source join/union into ODPS.

        Does nothing unless the DAG root mixes data sources. Otherwise it
        picks the side that has no ODPS data source (left if it has none,
        else right). It substitutes that side in ``expr_dag`` with a
        collection backed by a new temporary ODPS table, and persists the
        side into that table.

        :return: whatever the chosen backend's ``_persist`` returns, or None
        """
        root = expr_dag.root
        if not self._selecter.has_diff_data_sources(root, no_cache=True):
            return

        to_execute = root._lhs if not self._selecter.has_odps_data_source(root._lhs) \
            else root._rhs
        table_name = self._gen_table_name()
        odps_table = self._odps.get_table(
            table_name, schema=self._odpssql_engine._ctx.default_schema)
        sub = CollectionExpr(_source_data=odps_table, _schema=to_execute.schema)
        sub.add_deps(to_execute)
        expr_dag.substitute(to_execute, sub)

        # prevent the kwargs come from `persist`
        process_persist_kwargs(kwargs)

        execute_dag = ExprDAG(to_execute, dag=expr_dag)
        return self._get_backend(execute_dag)._persist(
            table_name, execute_dag, dag, to_execute, **kwargs)

    def _handle_isin(self, expr_dag, dag, expr, **kwargs):
        """Evaluate the value sequence of a mixed-source ``IsIn`` separately.

        Does nothing unless the DAG root mixes data sources. Otherwise it
        executes ``expr._values[0]`` on its own backend. When the result
        arrives, the callback replaces ``expr._values`` with one ``Scalar``
        per value in the result's first column.

        :return: the execution node for the value sequence, or None
        """
        if not self._selecter.has_diff_data_sources(expr_dag.root, no_cache=True):
            return

        seq = expr._values[0]
        # cleared until the callback fills in the computed values
        expr._values = None
        execute_dag = ExprDAG(seq, dag=expr_dag)
        execute_node = self._get_backend(execute_dag)._execute(
            execute_dag, dag, seq, **kwargs)

        def callback(res):
            vals = res[:, 0].tolist()
            expr._values = tuple(Scalar(val) for val in vals)

        execute_node.callback = callback
        return execute_node

    @staticmethod
    def _make_source_data_callback(target):
        """Return a callback that stores an execution result's values on ``target``.

        A factory is used so each callback binds its own ``target``. A closure
        defined directly inside a loop would only see the loop variable's
        final value.
        """
        def callback(res):
            target._source_data = res.values
        return callback

    def _handle_function(self, expr_dag, dag, _, **kwargs):
        """Align a function's collection resources with the input's backend.

        For each collection resource of the DAG root:

        * the input comes from ODPS but the resource does not: the resource
          is persisted into a temporary ODPS table, and a collection over
          that table replaces it in ``expr_dag``;
        * the input does not come from ODPS but the resource does: the input
          must come from pandas (otherwise ``NotImplementedError``). The
          resource is replaced by a pandas-backed placeholder, executed on
          ODPS SQL, and the result's values are written into the placeholder
          by a callback;
        * otherwise the resource is left untouched.

        :raises NotImplementedError: for an ODPS resource used by an input
            that is neither from ODPS nor from pandas
        """
        root = expr_dag.root

        # if expr input comes from an ODPS table
        is_root_input_from_odps = \
            self._selecter.has_odps_data_source(root.children()[0])

        for collection in root._collection_resources:
            # if collection resource comes from an ODPS table
            is_source_from_odps = self._selecter.has_odps_data_source(collection)

            if is_root_input_from_odps and not is_source_from_odps:
                table_name = self._gen_table_name()
                odps_table = self._odps.get_table(
                    table_name, schema=self._odpssql_engine._ctx.default_schema)
                sub = CollectionExpr(_source_data=odps_table, _schema=collection.schema)
                sub.add_deps(collection)
                expr_dag.substitute(collection, sub)

                # prevent the kwargs come from `persist`
                process_persist_kwargs(kwargs)

                execute_dag = ExprDAG(collection, dag=expr_dag)
                self._get_backend(execute_dag)._persist(
                    table_name, execute_dag, dag, collection, **kwargs)
            elif not is_root_input_from_odps and is_source_from_odps:
                if not self._selecter.has_pandas_data_source(root.children()[0]):
                    raise NotImplementedError
                sub = CollectionExpr(_source_data=pd.DataFrame(),
                                     _schema=collection.schema)
                sub.add_deps(collection)
                expr_dag.substitute(collection, sub)

                execute_node = self._odpssql_engine._execute(
                    ExprDAG(collection, dag=expr_dag), dag, collection, **kwargs)
                # bind this iteration's ``sub`` explicitly (late-binding fix)
                execute_node.callback = self._make_source_data_callback(sub)

    def _dispatch(self, expr_dag, expr, ctx):
        """Return the handler to run for ``expr``, or None.

        Algorithm expressions are dispatched to the XFlow engine directly.
        A mixed-source join/union, ``IsIn`` or function-with-resources first
        gets its cross-backend handler. If the base class also yields a
        handler, both run in order.
        """
        from ...ml.expr import AlgoExprMixin

        funcs = []
        if isinstance(expr, AlgoExprMixin):
            return self._xflow_engine._dispatch(expr_dag, expr, ctx)

        handle = None
        if isinstance(expr, (JoinCollectionExpr, UnionCollectionExpr)) and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_join_or_union
        elif isinstance(expr, IsIn) and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_isin
        elif hasattr(expr, '_func') and \
                getattr(expr, '_collection_resources', None) is not None and \
                self._selecter.has_diff_data_sources(expr):
            handle = self._handle_function
        if handle is not None:
            funcs.append(handle)

        h = super(MixedEngine, self)._dispatch(expr_dag, expr, ctx)
        if h is not None:
            if handle is None:
                return h
            funcs.append(h)

        if funcs:
            def f(*args, **kwargs):
                for func in funcs:
                    func(*args, **kwargs)
            return f

    def _get_cached_sub_expr(self, cached_expr, ctx=None):
        """Return the cached sub-expression, skipping Seahawks heap tables
        while ODPS execution is forced."""
        ctx = ctx or context
        data = ctx.get_cached(cached_expr)
        if self._selecter.force_odps and isinstance(data, SeahawksTable):
            # skip seahawks heap table
            return
        return super(MixedEngine, self)._get_cached_sub_expr(cached_expr, ctx=ctx)

    def visualize(self, expr):
        """Compile ``expr`` and return a repr object rendering its graph as SVG."""
        dag = self.compile(expr)
        try:
            formatter = ExprExecutionGraphFormatter(dag)
            return gen_repr_object(svg=formatter._repr_svg_())
        finally:
            # reset the temporary table names recorded during compilation
            self._generated_table_names = []

    def compile(self, *expr_args_kwargs):
        """Compile the given expressions into an execution DAG.

        The compilation is retried with ODPS forced (``force_odps``) when:
        a ``_persist`` is requested; ODPS and Seahawks engines are both
        selected; or the DAG compilation raises ``NotImplementedError``. When
        Seahawks is selected, the returned DAG also gets ``fallback`` (recompile
        on ODPS) and ``need_fallback`` (decide from an exception whether to
        fall back) attributes.

        :raises NotImplementedError: if compilation fails even though ODPS
            was already forced when this attempt started
        """
        src_expr_args_kwargs = tuple(expr_args_kwargs)
        # Remember whether ODPS was already forced when this attempt began.
        # Retrying from that same state cannot change the outcome, and would
        # otherwise recurse until RecursionError.
        was_forced = self._selecter.force_odps
        exprs_dags, expr_args_kwargs = self._process(*expr_args_kwargs)
        if not self._selecter.force_odps and any(it[0] == '_persist' for it in expr_args_kwargs):
            self._selecter.force_odps = True
            return self.compile(*src_expr_args_kwargs)

        engine_types = [self._selecter.select(expr_dag) for expr_dag in exprs_dags]
        if any(engine_type == Engines.ODPS for engine_type in engine_types):
            self._selecter.force_odps = True
            if any(engine_type == Engines.SEAHAWKS for engine_type in engine_types):
                return self.compile(*src_expr_args_kwargs)

        try:
            dag = super(MixedEngine, self)._compile_dag(expr_args_kwargs, exprs_dags)
        except NotImplementedError:
            if was_forced:
                raise
            self._selecter.force_odps = True
            return self.compile(*src_expr_args_kwargs)

        if any(engine_type == Engines.SEAHAWKS for engine_type in engine_types):
            def fallback():
                self._selecter.force_odps = True
                return self.compile(*src_expr_args_kwargs)

            def need_fallback(e):
                try:
                    import sqlalchemy
                    exceptions = (NotImplementedError, sqlalchemy.exc.DatabaseError)
                except ImportError:
                    exceptions = (NotImplementedError,)

                if not isinstance(e, exceptions):
                    return False
                if not isinstance(e, NotImplementedError) and \
                        'AXF Exception' not in str(e):
                    # the seahawks error
                    return False
                return True

            if not all(not hasattr(n, 'verify') or n.verify()
                       for n in dag.indep_nodes()):
                # we only verify the independent nodes
                dag = fallback()
            dag.fallback = fallback
            dag.need_fallback = need_fallback

        return dag