odps/df/backends/core.py:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import itertools
import os
import sys
import threading
import time
import types
from operator import itemgetter

from ...compat import six, Enum, Iterable
from ...models import Resource
from ...config import options
from ...dag import DAG
from ...utils import init_progress_ui
from ...ui.progress import create_instance_group
from ...compat import futures
from ...types import PartitionSpec
from .. import utils
from ..expr.expressions import Expr, CollectionExpr, Scalar
from ..expr.core import ExprDictionary, ExprDAG
from .context import context, ExecuteContext
from .errors import DagDependencyError, CompileError
from .formatter import ExprExecutionGraphFormatter


class EngineTypes(Enum):
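    """Identifiers of the execution engines a DataFrame expression can run on."""
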
ODPS = 'ODPS'
PANDAS = 'PANDAS'
SEAHAWKS = 'SEAHAWKS'
SQLALCHEMY = 'SQLALCHEMY'
ALGO = 'ALGO'


class Backend(object):
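    """Visitor interface over the DataFrame expression AST.

    Concrete backends (for instance the ODPS SQL or pandas compilers)
    override the ``visit_*`` methods for the node types they support;
    every method here simply raises ``NotImplementedError``.
    """
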
    def visit_source_collection(self, expr):
        raise NotImplementedError

    def visit_project_collection(self, expr):
        raise NotImplementedError

    def visit_apply_collection(self, expr):
        raise NotImplementedError

    def visit_lateral_view(self, expr):
        raise NotImplementedError

    def visit_filter_collection(self, expr):
        raise NotImplementedError

    def visit_filter_partition_collection(self, expr):
        raise NotImplementedError

    def visit_algo(self, expr):
        raise NotImplementedError

    def visit_slice_collection(self, expr):
        raise NotImplementedError

    def visit_element_op(self, expr):
        raise NotImplementedError

    def visit_binary_op(self, expr):
        raise NotImplementedError

    def visit_unary_op(self, expr):
        raise NotImplementedError

    def visit_math(self, expr):
        raise NotImplementedError

    def visit_string_op(self, expr):
        raise NotImplementedError

    def visit_datetime_op(self, expr):
        raise NotImplementedError

    def visit_composite_op(self, expr):
        raise NotImplementedError

    def visit_groupby(self, expr):
        raise NotImplementedError

    def visit_mutate(self, expr):
        raise NotImplementedError

    def visit_reshuffle(self, expr):
        raise NotImplementedError

    def visit_value_counts(self, expr):
        raise NotImplementedError

    def visit_sort(self, expr):
        raise NotImplementedError

    def visit_sort_column(self, expr):
        raise NotImplementedError

    def visit_distinct(self, expr):
        raise NotImplementedError

    def visit_sample(self, expr):
        raise NotImplementedError

    def visit_pivot(self, expr):
        raise NotImplementedError

    def visit_reduction(self, expr):
        raise NotImplementedError

    def visit_user_defined_aggregator(self, expr):
        raise NotImplementedError

    def visit_column(self, expr):
        raise NotImplementedError

    def visit_function(self, expr):
        raise NotImplementedError

    def visit_builtin_function(self, expr):
        raise NotImplementedError

    def visit_sequence(self, expr):
        raise NotImplementedError

    def visit_cum_window(self, expr):
        raise NotImplementedError

    def visit_rank_window(self, expr):
        raise NotImplementedError

    def visit_shift_window(self, expr):
        raise NotImplementedError

    def visit_scalar(self, expr):
        raise NotImplementedError

    def visit_join(self, expr):
        raise NotImplementedError

    def visit_cast(self, expr):
        raise NotImplementedError

    def visit_union(self, expr):
        raise NotImplementedError

    def visit_concat(self, expr):
        raise NotImplementedError

    def visit_append_id(self, expr):
        raise NotImplementedError

    def visit_split(self, expr):
        raise NotImplementedError

    def visit_extract_kv(self, expr):
        raise NotImplementedError


class ExecuteNode(object):
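    """A single executable node in an :class:`ExecuteDAG`.

    Wraps the expression DAG to run; ``result_index`` marks which
    user-requested result this node produces, and ``callback``, when set,
    receives the result after the node finishes.
    """
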
def __init__(self, expr_dag, result_index=None, callback=None):
self.expr_dag = expr_dag
self.result_index = result_index
self.callback = callback

    @property
    def expr(self):
        return self.expr_dag.root

    def run(self, **execute_kw):
        raise NotImplementedError

    def __call__(self, ui=None, progress_proportion=None):
        res = self.run(ui=ui, progress_proportion=progress_proportion)
        if self.callback:
            self.callback(res)
        return res

    def __repr__(self):
        raise NotImplementedError

    def _repr_html_(self):
        raise NotImplementedError


class ExecuteDAG(DAG):
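    """A DAG of :class:`ExecuteNode` objects produced by an :class:`Engine`.

    Nodes may run sequentially (``_run``) or on a thread pool
    (``_run_in_parallel``); both paths can retry through a fallback DAG
    when a node fails with a recoverable error.
    """
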
def _run(self, ui, progress_proportion=1.0):
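        """Run all nodes sequentially in topological order, dividing the
        progress budget evenly among them, and rerun through the fallback
        DAG if a recoverable error occurs."""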
curr_progress = ui.current_progress() or 0
try:
calls = self.topological_sort()
results = [None] * len(calls)
result_idx = dict()
for i, call in enumerate(calls):
res = call(ui=ui, progress_proportion=progress_proportion / len(calls))
results[i] = res
if call.result_index is not None:
result_idx[call.result_index] = i
return [results[result_idx[idx]] for idx in sorted(result_idx)]
except Exception as e:
if self._can_fallback() and self._need_fallback(e):
ui.update(curr_progress)
return self.fallback()._run(ui, progress_proportion)
raise

    def _run_in_parallel(self, ui, n_parallel, wait=True, timeout=None, progress_proportion=1.0):
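        """Run the DAG on a thread pool with ``n_parallel`` workers.

        Every node gets a future; a node is submitted as soon as all of its
        predecessors' futures are done, and a failure propagates to its
        successors as ``DagDependencyError``. With ``wait=False`` the result
        futures are returned directly so the caller can poll them.
        """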
submits_lock = threading.RLock()
submits = dict()
user_wait = dict()
result_wait = dict()
results = dict()
curr_progress = ui.current_progress() or 0

        def actual_run(dag=None, is_fallback=False):
dag = dag or self
calls = dag.topological_sort()
result_calls = sorted([c for c in calls if c.result_index is not None],
key=lambda x: x.result_index)
fallback = threading.Event()
if is_fallback:
ui.update(curr_progress)

            def close_ui(*_):
with submits_lock:
if all(call in submits and call in results for call in result_calls):
ui.close()

            executor = futures.ThreadPoolExecutor(max_workers=n_parallel)
for call in calls:
if call.result_index is not None and is_fallback:
                    # when falling back, reuse the original future, since it
                    # has already been handed out to the user
future = result_wait[call.result_index]
else:
future = futures.Future()
user_wait[call] = future
if call.result_index is not None:
future.add_done_callback(close_ui)
if not is_fallback:
result_wait[call.result_index] = future

            for call in calls:
def run(func):
try:
if fallback.is_set():
                            raise DagDependencyError('Node execution failed due to fallback')
                        if func.result_index is None or not is_fallback:
user_wait[func].set_running_or_notify_cancel()
prevs = dag.predecessors(func)
if prevs:
fs = [user_wait[p] for p in prevs]
for f in fs:
if f.exception():
raise DagDependencyError('Node execution failed due to failure of '
'previous node, exception: %s' % f.exception())
res = func(ui=ui, progress_proportion=progress_proportion / len(calls))
results[func] = res
user_wait[func].set_result(res)
return res
except:
e, tb = sys.exc_info()[1:]
if not is_fallback and self._can_fallback() and self._need_fallback(e):
if not fallback.is_set():
fallback.set()
new_dag = dag.fallback()
actual_run(new_dag, True)
if not fallback.is_set():
results[func] = (e, tb)
if six.PY2:
user_wait[func].set_exception_info(e, tb)
else:
user_wait[func].set_exception(e)
raise
finally:
with submits_lock:
for f in dag.successors(func):
if f in submits:
continue
prevs = dag.predecessors(f)
if all(p in submits and user_wait[p].done() for p in prevs):
submits[f] = executor.submit(run, f)

                if not dag.predecessors(call):
with submits_lock:
submits[call] = executor.submit(run, call)
if wait:
dones, _ = futures.wait(user_wait.values())
for done in dones:
done.result()
                return [results[c] for c in result_calls]
if timeout:
futures.wait(user_wait.values(), timeout=timeout)
actual_run()
if wait:
return [it[1].result() for it in sorted(result_wait.items(), key=itemgetter(0))]
else:
return [it[1] for it in sorted(result_wait.items(), key=itemgetter(0))]

    def execute(self, ui=None, async_=False, n_parallel=1, timeout=None,
close_and_notify=True, progress_proportion=1.0, **kw):
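        """Entry point for running the DAG: blocks and returns results unless
        ``async_`` is set, in which case the run is submitted and futures are
        returned. The legacy ``async`` keyword is still honored via ``kw``."""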
async_ = kw.get('async', async_)
ui = ui or init_progress_ui(lock=async_)
succeeded = False
if not async_:
try:
if n_parallel <= 1:
results = self._run(ui, progress_proportion)
else:
results = self._run_in_parallel(ui, n_parallel, progress_proportion=progress_proportion)
succeeded = True
return results
finally:
if close_and_notify or succeeded:
ui.close()
if succeeded:
ui.notify('DataFrame execution succeeded')
else:
ui.notify('DataFrame execution failed')
else:
try:
fs = self._run_in_parallel(ui, n_parallel, wait=not async_, timeout=timeout,
progress_proportion=progress_proportion)
succeeded = True
return fs
finally:
if succeeded:
ui.notify('DataFrame execution submitted')
else:
                    ui.notify('DataFrame execution failed to submit')

    def _can_fallback(self):
        return hasattr(self, 'fallback') and self.fallback is not None

    def _need_fallback(self, e):
        return hasattr(self, 'need_fallback') and self.need_fallback(e)

    def __repr__(self):
        return ExprExecutionGraphFormatter(self)._to_str()

    def _repr_html_(self):
        return ExprExecutionGraphFormatter(self)._to_html()


class Engine(object):
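    """Base class of DataFrame execution engines.

    An engine turns user expressions into an :class:`ExecuteDAG`: the
    expressions are copied into DAGs, analyzed and optimized, compiled into
    execute nodes, and finally run (``execute``), written into tables
    (``persist``) or combined into a single run (``batch``).
    """
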
def stop(self):
pass

    @classmethod
def _convert_table(cls, expr):
if isinstance(expr, Expr):
return utils.to_collection(expr)
expr_dag = expr
root = utils.to_collection(expr_dag.root)
if root is not expr_dag.root:
new_expr_dag = ExprDAG(root, dag=expr_dag)
new_expr_dag.ensure_all_nodes_in_dag()
return new_expr_dag
return expr_dag

    def _cache(self, expr_dag, dag, expr, **kwargs):
        # should return a (data, execute_node) pair, consumed by _dispatch
        raise NotImplementedError

    def _dispatch(self, expr_dag, expr, ctx):
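        """Pick special handling for ``expr`` during compilation: return a
        handler that computes and caches it when caching is requested,
        substitute it with already-cached data, or return ``_handle_dep``
        for expressions carrying explicit dependencies. ``None`` means no
        special treatment is needed."""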
if expr._need_cache:
if not ctx.is_cached(expr):
def h():
def inner(*args, **kwargs):
ret = self._cache(*args, **kwargs)
if ret:
data, node = ret
ctx.cache(expr, data)
return node
return inner
return h()
else:
cached = ctx.get_cached(expr)
if isinstance(expr, CollectionExpr):
cached = cached.copy()
expr_dag.substitute(expr, cached)
elif expr._deps:
return self._handle_dep

    def _new_analyzer(self, expr_dag, on_sub=None):
raise NotImplementedError

    def _new_rewriter(self, expr_dag):
return

    def _analyze(self, expr_dag, dag, **kwargs):
from .optimize import Optimizer
def sub_has_dep(_, sub):
if sub._deps is not None:
kw = dict(kwargs)
kw['finish'] = False
kw.pop('head', None)
kw.pop('tail', None)
self._handle_dep(ExprDAG(sub, dag=expr_dag), dag, sub, **kw)
# analyze first
self._new_analyzer(expr_dag, on_sub=sub_has_dep).analyze()
# optimize
return Optimizer(expr_dag).optimize()

    def _rewrite(self, expr_dag):
        # rewrite if a rewriter exists
rewriter = self._new_rewriter(expr_dag)
if rewriter:
return rewriter.rewrite()
return expr_dag.root

    def _new_execute_node(self, expr_dag):
return ExecuteNode(expr_dag)

    def _handle_dep(self, expr_dag, dag, expr, **kwargs):
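        """Schedule the explicit dependencies of ``root`` as extra execute
        nodes. Each entry of ``_deps`` may be ``node``, ``(node, callback)``
        or ``(node, action, callback)``; the action defaults to ``_execute``
        and the callback, when given, receives the dependency's result and
        the depending expression."""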
root = expr_dag.root
execute_nodes = []
for dep in root._deps:
if isinstance(dep, tuple):
if len(dep) == 3:
node, action, callback = dep
else:
node, callback = dep
action = '_execute'
else:
node, action, callback = dep, '_execute', None
            if callback:
                # bind the current callback as a default argument so each
                # dependency keeps its own callback
                def dep_callback(res, callback=callback):
                    callback(res, expr)
            else:
                dep_callback = None
execute_node = getattr(self, action)(ExprDAG(node, dag=expr_dag), dag, node,
analyze=False, **kwargs)
execute_node.callback = dep_callback
execute_nodes.append(execute_node)
return execute_nodes

    def _handle_expr_args_kwargs(self, expr_args_kwargs):
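        """Normalize positional arguments: a single iterable of expressions,
        or bare expressions, becomes a list of uniform
        ``(action, expr, args, kwargs)`` tuples with ``_execute`` as the
        default action."""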
if len(expr_args_kwargs) == 1 and not isinstance(expr_args_kwargs[0], Expr) and \
all(isinstance(it, Expr) for it in expr_args_kwargs[0]):
expr_args_kwargs = expr_args_kwargs[0]
if all(isinstance(it, Expr) for it in expr_args_kwargs):
expr_args_kwargs = [('_execute', it, (), {}) for it in expr_args_kwargs]
return expr_args_kwargs

    def _process(self, *expr_args_kwargs):
expr_args_kwargs = self._handle_expr_args_kwargs(expr_args_kwargs)
def h(e):
if isinstance(e, Scalar) and e.name is None:
return e.rename('__rand_%s' % int(time.time()))
if isinstance(e, CollectionExpr) and hasattr(e, '_proxy') and \
e._proxy is not None:
return e._proxy
return e
src_exprs = [h(it[1]) for it in expr_args_kwargs]
exprs_dags = self._build_expr_dag([self._convert_table(e) for e in src_exprs])
return exprs_dags, expr_args_kwargs

    def _compile_dag(self, expr_args_kwargs, exprs_dags):
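        """Walk every expression DAG, letting ``_dispatch`` schedule caching
        and dependency nodes first, then create the node for the requested
        action itself and record which user result it produces."""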
ctx = ExecuteContext() # expr -> new_expr
dag = ExecuteDAG()
for idx, it, expr_dag in zip(itertools.count(0), expr_args_kwargs, exprs_dags):
action, src_expr, args, kwargs = it
for node in expr_dag.traverse():
if hasattr(self, '_selecter') and not self._selecter.force_odps and hasattr(node, '_algo'):
raise NotImplementedError
h = self._dispatch(expr_dag, node, ctx)
if h:
kw = dict(kwargs)
kw['finish'] = False
if node is expr_dag.root:
node_dag = expr_dag
else:
node_dag = ExprDAG(node, dag=expr_dag)
h(node_dag, dag, node, **kw)
args = args + (expr_dag, dag, src_expr)
n = getattr(self, action)(*args, **kwargs)
n.result_index = idx
return dag

    def compile(self, *expr_args_kwargs):
exprs_dags, expr_args_kwargs = self._process(*expr_args_kwargs)
return self._compile_dag(expr_args_kwargs, exprs_dags)

    def _action(self, *exprs_args_kwargs, **kwargs):
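        """Shared driver behind ``execute``/``persist``/``batch``: compile
        the requested actions into one DAG and run it, returning the single
        result or, in batch mode, the list of results (futures when
        ``async_`` is set)."""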
ui = kwargs.pop('ui', None)
progress_proportion = kwargs.pop('progress_proportion', 1.0)
async_ = kwargs.pop('async_', kwargs.pop('async', False))
n_parallel = kwargs.pop('n_parallel', 1)
timeout = kwargs.pop('timeout', None)
batch = kwargs.pop('batch', False)
action = kwargs.pop('action', None)
def transform(*exprs_args_kw):
for expr_args_kwargs in exprs_args_kw:
if len(expr_args_kwargs) == 3:
expr, args, kw = expr_args_kwargs
act = action
else:
act, expr, args, kw = expr_args_kwargs
kw = kw.copy()
kw.update(kwargs)
yield act, expr, args, kw
dag = self.compile(*transform(*exprs_args_kwargs))
try:
res = self._execute_dag(dag, ui=ui, async_=async_, n_parallel=n_parallel,
timeout=timeout, progress_proportion=progress_proportion)
except KeyboardInterrupt:
self.stop()
sys.exit(1)
if not batch:
return res[0]
return res

    def _do_execute(self, expr_dag, expr, **kwargs):
raise NotImplementedError

    def _execute(self, expr_dag, dag, expr, **kwargs):
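        """Wrap ``expr_dag`` in an execute node whose ``run`` invokes
        ``_do_execute`` under a per-node progress group, and wire the node
        into the execute DAG."""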
# analyze first
analyze = kwargs.pop('analyze', True)
if analyze:
kw = dict(kwargs)
kw.pop('execute_kw', None)
self._analyze(expr_dag, dag, **kw)
engine = self
execute_node = self._new_execute_node(expr_dag)
group_key = kwargs.get('group') or self._create_progress_group(expr)
def run(s, **execute_kw):
kw = dict(kwargs)
kw.update(kw.pop('execute_kw', dict()))
kw.update(execute_kw)
kw['group'] = group_key
if 'ui' in kw:
kw['ui'].add_keys(group_key)
result = engine._do_execute(expr_dag, expr, **kw)
if 'ui' in kw:
kw['ui'].remove_keys(group_key)
return result
execute_node.run = types.MethodType(run, execute_node)
self._add_node(execute_node, dag)
return execute_node

    @classmethod
def _reorder(cls, expr, table, cast=False, with_partitions=False):
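        """Align ``expr``'s columns with the target ``table`` before writing:
        verify each source column exists in the table, reject casts that
        cannot happen implicitly (unless ``cast`` is True), reorder columns
        to the table order and fill columns missing from the expression
        with nulls."""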
from .odpssql.engine import types as odps_engine_types
from .. import NullScalar
df_schema = odps_engine_types.odps_schema_to_df_schema(table.table_schema)
expr_schema = expr.schema.to_ignorecase_schema()
expr_table_schema = odps_engine_types.df_schema_to_odps_schema(expr_schema)
case_dict = dict((n.lower(), n) for n in expr.schema.names)
for col in expr_table_schema.columns:
if col.name.lower() not in table.table_schema:
raise CompileError('Column(%s) does not exist in target table %s, '
'writing cannot be performed.' % (col.name, table.name))
t_col = table.table_schema[col.name.lower()]
if not cast and not t_col.type.can_implicit_cast(col.type):
raise CompileError('Cannot implicitly cast column %s from %s to %s.' % (
col.name, col.type, t_col.type))
if (
table.table_schema.names == expr_schema.names
and df_schema.types[:len(table.table_schema.names)] == expr_schema.types
):
return expr
def field(name):
expr_name = case_dict[name]
if expr[expr_name].dtype == df_schema[name].type:
return expr[expr_name]
elif df_schema[name].type.can_implicit_cast(expr[expr_name].dtype) or cast:
return expr[expr_name].astype(df_schema[name].type)
else:
                raise CompileError('Column %s\'s type does not match, expect %s, got %s' % (
                    expr_name, df_schema[name].type, expr[expr_name].dtype))
names = [c.name for c in table.table_schema.columns] if with_partitions else table.table_schema.names
return expr[[field(name) if name in expr_schema else NullScalar(df_schema[name].type).rename(name)
for name in names]]

    @classmethod
def _get_partition(cls, partition, table=None):
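        """Normalize ``partition`` (str, dict or ``PartitionSpec``) into a
        ``PartitionSpec`` and, when ``table`` is given, validate it against
        the table's partition columns and reorder its fields to match them.

        For illustration only (hypothetical partition names), both
        ``'ds=20230101,hh=08'`` and ``{'ds': '20230101', 'hh': '08'}``
        normalize to the same spec.
        """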
if isinstance(partition, dict):
p_spec = PartitionSpec()
for name, val in six.iteritems(partition):
p_spec[name] = val
elif isinstance(partition, PartitionSpec):
p_spec = partition
else:
if not isinstance(partition, six.string_types):
                raise TypeError('`partition` should be a str, dict or PartitionSpec, '
                                'got {0} instead'.format(type(partition)))
p_spec = PartitionSpec(partition)
if table is not None:
part_names = [p.name for p in table.table_schema.partitions]
for name in part_names:
if name not in p_spec:
                    raise ValueError('Table has partition column {0} '
                                     'which is not specified in `partition`'.format(name))
for name in p_spec.keys:
if name not in table.table_schema._partition_schema:
                    raise ValueError('Table does not have partition column {0} '
                                     'which is specified in `partition`'.format(name))
if p_spec.keys != part_names:
old_p_spec = p_spec
p_spec = PartitionSpec()
for n in part_names:
p_spec[n] = old_p_spec[n]
return p_spec

    def _do_persist(self, expr_dag, expr, name, **kwargs):
raise NotImplementedError

    def _persist(self, name, expr_dag, dag, expr, **kwargs):
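        """Like ``_execute``, but the node's ``run`` calls ``_do_persist``
        to write the result into the table called ``name``."""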
# analyze first
analyze = kwargs.pop('analyze', True)
if analyze:
self._analyze(expr_dag, dag, **kwargs)
engine = self
execute_node = self._new_execute_node(expr_dag)
group_key = self._create_progress_group(expr)
def run(s, **execute_kw):
kw = dict(kwargs)
kw.update(execute_kw)
kw['group'] = group_key
if 'ui' in kw:
kw['ui'].add_keys(group_key)
result = engine._do_persist(expr_dag, expr, name, **kw)
if 'ui' in kw:
kw['ui'].remove_keys(group_key)
return result
execute_node.run = types.MethodType(run, execute_node)
self._add_node(execute_node, dag)
return execute_node

    @classmethod
def _handle_params(cls, *expr_args_kwargs, **kwargs):
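        """Distinguish the call styles of ``execute``/``persist``: a single
        expression, an iterable of expressions (batch mode with shared
        arguments), or pre-built ``(expr, args, kwargs)`` tuples (also
        batch mode)."""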
if isinstance(expr_args_kwargs[0], Expr):
return [(expr_args_kwargs[0], expr_args_kwargs[1:], {})], kwargs
elif isinstance(expr_args_kwargs[0], Iterable) and \
all(isinstance(e, Expr) for e in expr_args_kwargs[0]):
args = expr_args_kwargs[1:]
kwargs['batch'] = True
return [(e, args, {}) for e in expr_args_kwargs[0]], kwargs
else:
kwargs['batch'] = True
return expr_args_kwargs, kwargs

    @staticmethod
def _create_ui(**kwargs):
existing_ui = kwargs.get('ui')
if existing_ui:
return existing_ui
async_ = kwargs.get('async_', kwargs.get('async', False))
ui = init_progress_ui(lock=async_, use_console=not async_)
ui.status('Preparing')
return ui

    @staticmethod
def _create_progress_group(expr):
node_name = getattr(expr, 'node_name', expr.__class__.__name__)
return create_instance_group(node_name)

    def execute(self, *exprs_args_kwargs, **kwargs):
exprs_args_kwargs, kwargs = self._handle_params(*exprs_args_kwargs, **kwargs)
kwargs['ui'] = self._create_ui(**kwargs)
kwargs['action'] = '_execute'
return self._action(*exprs_args_kwargs, **kwargs)

    def persist(self, *exprs_args_kwargs, **kwargs):
exprs_args_kwargs, kwargs = self._handle_params(*exprs_args_kwargs, **kwargs)
kwargs['ui'] = self._create_ui(**kwargs)
kwargs['action'] = '_persist'
return self._action(*exprs_args_kwargs, **kwargs)

    def batch(self, *action_exprs_args_kwargs, **kwargs):
args = []
for action_expr_args_kwargs in action_exprs_args_kwargs:
action, others = action_expr_args_kwargs[0], action_expr_args_kwargs[1:]
action = '_%s' % action if not action.startswith('_') else action
args.append((action, ) + tuple(others))
kwargs = kwargs.copy()
kwargs['batch'] = True
return self._action(*args, **kwargs)

    def _get_cached_sub_expr(self, cached_expr, ctx=None):
ctx = ctx or context
data = ctx.get_cached(cached_expr)
return cached_expr.get_cached(data)

    def _build_expr_dag(self, exprs, on_copy=None):
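        """Copy every expression into its own ``ExprDAG``, remembering cached
        sub-expressions encountered during the copy and substituting them
        with their cached counterparts afterwards."""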
cached_exprs = ExprDictionary()
def find_cached(_, n):
if context.is_cached(n) and hasattr(n, 'get_cached'):
cached_exprs[n] = True
if on_copy is not None:
if not isinstance(on_copy, Iterable):
on_copy = (on_copy, )
else:
on_copy = tuple(on_copy)
on_copy = on_copy + (find_cached, )
else:
on_copy = find_cached
res = tuple(expr.to_dag(copy=True, on_copy=on_copy, validate=False)
for expr in exprs)
for cached in cached_exprs:
sub = self._get_cached_sub_expr(cached)
if sub is not None:
for dag in res:
if dag.contains_node(cached):
dag.substitute(cached, sub)
return res

    def _add_node(self, dag_node, dag):
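        """Insert ``dag_node`` into the execute DAG and add an edge to or
        from every existing node whose expression is a descendant or
        ancestor of the new node's expression."""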
nodes = dag.nodes()
dag.add_node(dag_node)
for node in nodes:
node_expr = node.expr
if dag_node.expr.is_ancestor(node_expr):
dag.add_edge(node, dag_node)
elif node_expr.is_ancestor(dag_node.expr):
dag.add_edge(dag_node, node)

    @classmethod
def _execute_dag(cls, dag, ui=None, async_=False, n_parallel=1, timeout=None, close_and_notify=True,
progress_proportion=1.0, **kw):
async_ = kw.pop('async', async_)
return dag.execute(ui=ui, async_=async_, n_parallel=n_parallel, timeout=timeout,
close_and_notify=close_and_notify, progress_proportion=progress_proportion)

    def _get_libraries(self, libraries):
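        """Normalize ``libraries`` (a resource object, a resource name, a
        local path, or an iterable of these) into a deduplicated list:
        existing ODPS resources are resolved by name, while local ``.py``
        files and directories are kept as paths. ``options.df.libraries``
        is appended when configured."""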
def conv(libs):
if libs is None:
return None
if isinstance(libs, (six.binary_type, six.text_type, Resource)):
return conv([libs, ])
new_libs = []
for lib in libs:
if not isinstance(lib, (Resource, six.string_types)):
raise ValueError('Resource %s not acceptable: illegal input type %s.'
% (repr(lib), type(lib).__name__))
if isinstance(lib, Resource):
new_libs.append(lib)
elif '/' not in lib and self._odps.exist_resource(lib):
new_libs.append(self._odps.get_resource(lib))
elif os.path.isfile(lib) and lib.endswith('.py'):
new_libs.append(lib)
elif os.path.isdir(lib):
new_libs.append(lib)
else:
raise ValueError('Resource %s not found.' % repr(lib))
return new_libs
libraries = conv(libraries) or []
if options.df.libraries is not None:
libraries.extend(conv(options.df.libraries))
if len(libraries) == 0:
return
return list(set(libraries))