#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import time
import uuid

from .runners import create_node_runner
from .expr import AlgoCollectionExpr, ODPSModelExpr, ModelDataCollectionExpr, MetricsResultExpr
from .utils import is_temp_table
from ..df.backends.context import context
from ..df.backends.analyzer import BaseAnalyzer
from ..df.backends.engine import Engine
from ..df.backends.odpssql import types
from ..df.backends.odpssql.types import df_schema_to_odps_schema
from ..df.backends.errors import CompileError
from ..df.backends.utils import refresh_dynamic
from ..df import DataFrame
from ..df.expr.collections import Node, CollectionExpr, Scalar
from ..df.expr.core import ExprDAG
from ..df.expr.dynamic import DynamicMixin
from ..df.utils import is_source_collection, is_constant_scalar
from .. import options, tempobj, utils
from ..compat import six, futures
from ..errors import ODPSError
from ..models import Partition, TableSchema
from ..ui import fetch_instance_group, reload_instance_status


class OdpsAlgoContext(object):
def __init__(self, odps):
self._odps = odps
self._node_caches = dict()

    def register_exec(self, idx, parameters):
pass


class OdpsAlgoAnalyzer(BaseAnalyzer):
def visit_algo(self, expr):
pass


class OdpsAlgoEngine(Engine):
def __init__(self, odps):
self._odps = odps
self._ctx = OdpsAlgoContext(odps)
self._instances = []

    def _dispatch(self, expr_dag, expr, ctx):
if expr._need_cache and not ctx.is_cached(expr):
# when the expr should be disk-persisted, skip
if expr is expr_dag.root and not expr._mem_cache:
return None
return super(OdpsAlgoEngine, self)._dispatch(expr_dag, expr, ctx)

    def stop(self):
for inst in self._instances:
try:
self._odps.stop_instance(inst.id)
except ODPSError:
pass
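
    # Generate a unique temp table name and register it for automatic cleanup;
    # in dry-run mode a deterministic name derived from the node name is
    # returned instead and nothing is registered.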
def _gen_table_name(self, expr):
if options.ml.dry_run:
if isinstance(expr, Node):
node_name = expr.node_name
else:
node_name = str(expr)
return '%s_%s' % (utils.TEMP_TABLE_PREFIX, utils.camel_to_underline(node_name))
table_name = '%s%s_%s' % (utils.TEMP_TABLE_PREFIX, int(time.time()),
str(uuid.uuid4()).replace('-', '_'))
tempobj.register_temp_table(self._odps, table_name)
return table_name
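
    # Generate a temp model name: an md5 digest truncated so the prefixed name
    # stays within 32 characters, registered for automatic cleanup.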
def _gen_model_name(self, expr):
from .utils import TEMP_MODEL_PREFIX
if options.ml.dry_run:
if isinstance(expr, Node):
node_name = expr.node_name
else:
node_name = str(expr)
            return '%s%s' % (TEMP_MODEL_PREFIX, utils.camel_to_underline(node_name))
model_id_str = utils.to_binary(str(int(time.time())) + '_' + str(uuid.uuid4()).replace('-', '_'))
digest = hashlib.md5(model_id_str).hexdigest()
model_name = TEMP_MODEL_PREFIX + digest[-(32 - len(TEMP_MODEL_PREFIX)):]
tempobj.register_temp_model(self._odps, model_name)
return model_name

    def _reload_ui(self, group, instance, ui):
if group:
reload_instance_status(self._odps, group, instance.id)
ui.update_group()
return fetch_instance_group(group).instances.get(instance.id)

    def _run(self, algo_name, params, metas, engine_kw, ui, **kw):
runner = create_node_runner(self, algo_name, params, metas, engine_kw, ui, **kw)
runner.execute()

    def _new_analyzer(self, expr_dag, on_sub=None):
return OdpsAlgoAnalyzer(expr_dag, on_sub=on_sub)
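
    # Wrap a persisted model into an ODPSModelExpr: either a plain offline
    # model, or a tables-model whose data items are exposed as collection exprs.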
def _build_model(self, expr, model_name):
if expr._is_offline_model:
model = self._odps.get_offline_model(model_name)
return ODPSModelExpr(_source_data=model, _is_offline_model=True)
model_params = expr._model_params.copy()
for meta in ['predictor', 'recommender']:
meta_val = getattr(expr, '_' + meta, None)
if meta_val:
model_params[meta] = meta_val
model = self._odps.get_tables_model(model_name, tables=list(six.iterkeys(expr._model_collections)))
model._params = model_params
sub = ODPSModelExpr(_source_data=model, _is_offline_model=False,
_model_params=expr._model_params.copy(), _predictor=expr._predictor)
data_exprs = dict()
for k, v in six.iteritems(expr._model_collections):
data_exprs[k] = ModelDataCollectionExpr(_mlattr_model=sub, _data_item=k)
data_exprs[k]._source_data = self._odps.get_table(data_exprs[k].table_name())
sub._model_collections = data_exprs
return sub
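
    # Substitute the DAG root with a source-backed expr (temp table, model or
    # scalar value holder) and return it with the node that materializes it.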
def _cache(self, expr_dag, dag, expr, **kwargs):
is_source_model = isinstance(expr, ODPSModelExpr) and expr_dag.root._source_data is not None
        # drop the `partition` and `partitions` kwargs that may come from `persist`
kwargs.pop('partition', None)
kwargs.pop('partitions', None)
if is_source_collection(expr_dag.root) or \
is_constant_scalar(expr_dag.root) or \
is_source_model:
return
execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)
if isinstance(expr, CollectionExpr):
table_name = self._gen_table_name(expr)
table = self._odps.get_table(table_name)
root = expr_dag.root
sub = CollectionExpr(_source_data=table, _schema=expr.schema)
sub.add_deps(root)
expr_dag.substitute(root, sub)
kw = dict(kwargs)
kw['lifecycle'] = options.temp_lifecycle
execute_node = self._persist(table_name, execute_dag, dag, expr, **kw)
def callback(result):
if getattr(expr, 'is_extra_expr', False):
sub._source_data = result._source_data
if isinstance(expr, DynamicMixin):
sub._schema = types.odps_schema_to_df_schema(table.table_schema)
refresh_dynamic(sub, expr_dag)
execute_node.callback = callback
elif isinstance(expr, ODPSModelExpr):
model_name = self._gen_model_name(expr)
sub = self._build_model(expr, model_name)
root = expr_dag.root
sub.add_deps(root)
expr_dag.substitute(root, sub)
kw = dict(kwargs)
if 'lifecycle' in kw:
del kw['lifecycle']
execute_node = self._persist(model_name, execute_dag, dag, expr, **kw)
else:
assert isinstance(expr, Scalar) # sequence is not cache-able
            class ValueHolder(object):
                pass
sub = Scalar(_value_type=expr.dtype)
sub._value = ValueHolder()
execute_node = self._execute(execute_dag, dag, expr, **kwargs)
def callback(res):
sub._value = res
execute_node.callback = callback
return sub, execute_node
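
    # Record the persist target (table or model name, plus project when it
    # differs from the default) on the expr for later use by callbacks.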
def _write_persist_kw(self, name, expr, **kwargs):
if isinstance(expr, CollectionExpr):
persist_kw = kwargs.copy()
persist_kw['_table'] = name
project = persist_kw.pop('project', None)
            if project is not None and self._odps.project != project:
persist_kw['_project'] = project
expr.persist_kw = persist_kw
elif isinstance(expr, ODPSModelExpr):
persist_kw = kwargs.copy()
persist_kw['_model'] = name
project = persist_kw.pop('project', None)
if project is not None and self._odps.project != project:
persist_kw['_project'] = project
expr.persist_kw = persist_kw

    def _persist(self, name, expr_dag, dag, expr, **kwargs):
self._write_persist_kw(name, expr, **kwargs)
return super(OdpsAlgoEngine, self)._persist(name, expr_dag, dag, expr, **kwargs)
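
    # True when the node outputs models only: metrics exprs never qualify,
    # and any collection output disqualifies the node.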
@staticmethod
def _is_output_model_only(src_expr):
if isinstance(src_expr, MetricsResultExpr):
return False
output_exprs = src_expr.outputs()
return not any(1 for out_expr in six.itervalues(output_exprs) if isinstance(out_expr, CollectionExpr))
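
    # Pre-create the algorithm's output tables in parallel when its metadata
    # requests it via 'buildTables'.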
def _build_output_tables(self, expr):
from .expr.exporters import get_output_table_name
if not utils.str_to_bool(expr.algo_meta.get('buildTables', False)):
return
def create_output_table(table_name, table_schema):
lifecycle = options.temp_lifecycle if is_temp_table(table_name) else options.lifecycle
self._odps.create_table(table_name, table_schema, lifecycle=lifecycle)
table_names, table_schemas = [], []
for out_name, out_expr in six.iteritems(expr.outputs()):
if getattr(out_expr, '_algo', None) is None:
continue
tn = get_output_table_name(expr, out_name)
if tn:
ts = getattr(out_expr, '_algo_schema', None) or out_expr._schema
table_names.append(tn)
table_schemas.append(df_schema_to_odps_schema(ts))
executor = futures.ThreadPoolExecutor(10)
list(executor.map(create_output_table, table_names, table_schemas))
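
    # Execute a node and fetch its result: offline models are downloaded as
    # PMML (directly or via the modeltransfer XFlow), tables-models fetch each
    # backing table, metrics exprs run their calculator, and plain collections
    # are persisted to a temp table and then read back.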
def _do_execute(self, expr_dag, src_expr, **kwargs):
expr = expr_dag.root
kwargs['_output_models_only'] = self._is_output_model_only(src_expr)
kw = kwargs.copy()
if isinstance(src_expr, ODPSModelExpr):
ui = kw.pop('ui')
progress_proportion = kw.pop('progress_proportion', 1)
download_progress = progress_proportion
ui_group = kw.pop('group', None)
if hasattr(src_expr, '_source_data'):
result_expr = src_expr
else:
if not context.is_cached(src_expr):
temp_name = self._gen_model_name(src_expr)
download_progress = 0.1 * progress_proportion
self._do_persist(expr_dag, src_expr, temp_name, ui=ui,
progress_proportion=0.9 * progress_proportion, group=ui_group, **kw)
result_expr = src_expr.get_cached(context.get_cached(src_expr))
if result_expr._is_offline_model:
from .expr.models.pmml import PmmlResult
from .runners import XFlowNodeRunner
model = result_expr._source_data
if not options.ml.use_model_transfer:
pmml = model.get_model()
return PmmlResult(pmml)
else:
volume_name = options.ml.model_volume
if not self._odps.exist_volume(volume_name):
self._odps.create_parted_volume(volume_name)
vol_part = hashlib.md5(utils.to_binary(model.name)).hexdigest()
tempobj.register_temp_volume_partition(self._odps, (volume_name, vol_part))
algo_params = {
'modelName': model.name,
'volumeName': volume_name,
'partition': vol_part,
'format': 'pmml'
}
runner = XFlowNodeRunner(self, 'modeltransfer', algo_params, {}, {},
ui=ui, progress_proportion=download_progress, group=ui_group)
runner.execute()
pmml = self._odps.open_volume_reader(volume_name, vol_part, model.name + '.xml').read()
self._odps.delete_volume_partition(volume_name, vol_part)
return PmmlResult(utils.to_str(pmml))
else:
from .expr.models.base import TablesModelResult
results = dict()
frac = 1.0 / len(result_expr._model_collections)
for key, item in six.iteritems(result_expr._model_collections):
                    result = item.execute(ui=ui, progress_proportion=frac * download_progress,
                                          group=ui_group)
results[key] = result
return TablesModelResult(result_expr._model_params, results)
elif isinstance(src_expr, MetricsResultExpr):
if not src_expr.executed:
expr.tables = dict((pt.name, self._gen_table_name(src_expr)) for pt in src_expr.output_ports)
gen_params = expr.convert_params(src_expr)
ui = kw.pop('ui')
progress_proportion = kw.pop('progress_proportion', 1)
ui_group = kw.pop('group', None)
engine_kw = getattr(src_expr, '_engine_kw', {})
engine_kw['lifecycle'] = options.temp_lifecycle
if hasattr(src_expr, '_cases'):
kw['_cases'] = src_expr._cases
self._run(src_expr._algo, gen_params, src_expr.algo_meta, engine_kw, ui,
progress_proportion=progress_proportion, group=ui_group, **kw)
src_expr.executed = True
if options.ml.dry_run:
return None
else:
if hasattr(src_expr, '_result_callback'):
callback = src_expr._result_callback
else:
callback = lambda v: v
return callback(expr.calculator(self._odps))
else:
temp_name = self._gen_table_name(src_expr)
persist_kw = kwargs.copy()
persist_kw['_table'] = temp_name
expr.persist_kw = persist_kw
ui = kw.pop('ui')
progress_proportion = kw.pop('progress_proportion', 1)
ui_group = kw.pop('group', None)
kw['lifecycle'] = options.temp_lifecycle
df = self._do_persist(expr_dag, src_expr, temp_name, ui=ui,
progress_proportion=0.9 * progress_proportion, group=ui_group, **kw)
return df.execute(ui=ui, progress_proportion=0.1 * progress_proportion, group=ui_group)
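
    # When persisting requires an INSERT (the target table already exists, or
    # a partition / partitions spec is given), redirect the algorithm output
    # to a temp table and return a callback that compiles the INSERT statement
    # to run once the algorithm finishes.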
def _handle_expr_persist(self, out_expr):
from ..df.backends.engine import ODPSSQLEngine
class ODPSEngine(ODPSSQLEngine):
def compile(self, expr, prettify=True, libraries=None):
expr = self._convert_table(expr)
expr_dag = expr.to_dag()
self._analyze(expr_dag, expr)
new_expr = self._rewrite(expr_dag)
sql = self._compile(new_expr, prettify=prettify, libraries=libraries)
if isinstance(sql, list):
return '\n'.join(sql)
return sql
if isinstance(out_expr, CollectionExpr):
partition = out_expr.persist_kw.get('partition')
partitions = out_expr.persist_kw.get('partitions')
drop_table = out_expr.persist_kw.get('drop_table', False)
create_table = out_expr.persist_kw.get('create_table', True)
drop_partition = out_expr.persist_kw.get('drop_partition', False)
create_partition = out_expr.persist_kw.get('create_partition', False)
overwrite = out_expr.persist_kw.get('overwrite', True)
cast = out_expr.persist_kw.get('cast', False)
expr_table = out_expr.persist_kw['_table']
expr_project = out_expr.persist_kw.get('_project')
expr_table_path = expr_table if expr_project is None else expr_project + '.' + expr_table
if partitions is None and partition is None:
if drop_table:
self._odps.delete_table(expr_table, project=expr_project, if_exists=True)
                if self._odps.exist_table(expr_table, project=expr_project):
temp_table_name = self._gen_table_name(out_expr)
out_expr.persist_kw['_table'] = temp_table_name
out_expr.persist_kw['_project'] = None
def callback():
                        t = self._odps.get_table(expr_table, project=expr_project)
if t.table_schema.partitions:
                            raise CompileError('Cannot insert into partitioned table %s without '
                                               'specifying `partition` or `partitions`.' % expr_table)
expr = self._odps.get_table(temp_table_name).to_df()
expr = self._reorder(expr, t, cast=cast)
sql = ODPSEngine(self._odps).compile(expr, prettify=False)
action_str = 'OVERWRITE' if overwrite else 'INTO'
return 'INSERT {0} TABLE {1} \n{2}'.format(action_str, expr_table_path, sql)
return callback
else:
return None
elif partition is not None:
temp_table_name = self._gen_table_name(out_expr)
out_expr.persist_kw['_table'] = temp_table_name
out_expr.persist_kw['_project'] = None
def callback():
                t = self._odps.get_table(expr_table, project=expr_project)
for col in out_expr.schema.columns:
if col.name.lower() not in t.table_schema:
raise CompileError('Column(%s) does not exist in target table %s, '
'writing cannot be performed.' % (col.name, t.name))
if drop_partition:
t.delete_partition(partition, if_exists=True)
if create_partition:
t.create_partition(partition, if_not_exists=True)
                expr = self._odps.get_table(temp_table_name).to_df()
expr = self._reorder(expr, t, cast=cast)
sql = ODPSEngine(self._odps).compile(expr, prettify=False)
action_str = 'OVERWRITE' if overwrite else 'INTO'
return 'INSERT {0} TABLE {1} PARTITION({2}) {3}'.format(
action_str, expr_table_path, partition, sql,
)
return callback
else:
temp_table_name = self._gen_table_name(out_expr)
out_expr.persist_kw['_table'] = temp_table_name
out_expr.persist_kw['_project'] = None
if isinstance(partitions, tuple):
partitions = list(partitions)
if not isinstance(partitions, list):
partitions = [partitions, ]
def callback():
t = self._odps.get_table(temp_table_name)
schema = t.table_schema
columns = [c for c in schema.columns if c.name not in partitions]
ps = [Partition(name=pt, type=schema.get_type(pt)) for pt in partitions]
if drop_table:
self._odps.delete_table(expr_table, project=expr_project, if_exists=True)
if create_table:
lifecycle = options.temp_lifecycle if is_temp_table(expr_table) else options.lifecycle
self._odps.create_table(expr_table, TableSchema(columns=columns, partitions=ps),
project=expr_project, lifecycle=lifecycle)
expr = t.to_df()
expr = self._reorder(expr, t, cast=cast, with_partitions=True)
sql = ODPSEngine(self._odps).compile(expr, prettify=False)
action_str = 'OVERWRITE' if overwrite else 'INTO'
return 'INSERT {0} TABLE {1} PARTITION({2}) {3}'.format(
action_str, expr_table_path, ', '.join(partitions), sql,
)
return callback
elif isinstance(out_expr, ODPSModelExpr):
drop_model = out_expr.persist_kw.get('drop_model', False)
expr_model = out_expr.persist_kw['_model']
if drop_model:
if out_expr._is_offline_model:
self._odps.delete_offline_model(expr_model, if_exists=True)
else:
self._odps.delete_tables_model(expr_model, if_exists=True)
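
    # Persist a node's outputs: assign target names for all required outputs,
    # run the algorithm, execute any deferred INSERT callbacks, then cache and
    # return the result of the requested output.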
def _do_persist(self, expr_dag, src_expr, name, partitions=None, partition=None, project=None,
drop_table=False, create_table=True, drop_partition=False, create_partition=False,
**kwargs):
from .runners import SQLNodeRunner
from .enums import PortType
expr = expr_dag.root
kwargs['_output_models_only'] = self._is_output_model_only(src_expr)
output_exprs = src_expr.outputs()
shared_kw = src_expr.shared_kw
shared_kw['required_outputs'] = dict()
if hasattr(src_expr, 'output_ports'):
for out_port in src_expr.output_ports:
if not out_port.required and out_port.name not in output_exprs:
continue
if out_port.name in output_exprs:
out_expr = output_exprs[out_port.name]
if not getattr(out_expr, 'persist_kw', None):
expr_name = self._gen_table_name(out_expr) if isinstance(out_expr, CollectionExpr) \
else self._gen_model_name(expr)
self._write_persist_kw(expr_name, out_expr, **kwargs)
else:
expr_name = self._gen_table_name(src_expr.node_name) if out_port.type == PortType.DATA \
else self._gen_model_name(src_expr.node_name)
shared_kw['required_outputs'][out_port.name] = expr_name
src_expr.shared_kw = shared_kw
kw = kwargs.copy()
ui = kw.pop('ui')
progress_proportion = kw.pop('progress_proportion', 1)
ui_group = kw.pop('group', None)
        engine_kw = getattr(src_expr, '_engine_kw', {})
if kw.get('lifecycle'):
engine_kw['lifecycle'] = kw['lifecycle']
elif options.lifecycle:
engine_kw['lifecycle'] = options.lifecycle
if hasattr(src_expr, '_cases'):
kw['_cases'] = src_expr._cases
if not options.ml.dry_run:
self._build_output_tables(src_expr)
sql_callbacks = []
expr.wait_execution()
if not src_expr.executed:
for out_expr in six.itervalues(output_exprs):
callback = self._handle_expr_persist(out_expr)
if callback is not None:
sql_callbacks.append(callback)
gen_params = expr.convert_params(src_expr)
if not src_expr.executed:
prog_ratio = 1
sub_ratio = 0
if sql_callbacks:
prog_ratio = 0.8
sub_ratio = (1 - prog_ratio) * progress_proportion / len(sql_callbacks)
try:
self._run(src_expr._algo, gen_params, src_expr.algo_meta, engine_kw, ui,
progress_proportion=prog_ratio * progress_proportion, group=ui_group, **kw)
for cb in sql_callbacks:
sql = cb()
runner = SQLNodeRunner(self, 'SQL', dict(sql=sql), dict(), dict(), ui,
progress_proportion=sub_ratio, group=ui_group)
runner.execute()
finally:
src_expr.executed = True
if getattr(src_expr, 'is_extra_expr', False):
t = src_expr._table_callback(self._odps, src_expr)
context.cache(src_expr, t)
if options.ml.dry_run:
df = CollectionExpr(_source_data=t, _schema=src_expr._schema)
else:
df = DataFrame(t)
df._ml_fields = src_expr._ml_fields
return df
ret = None
for out_name, out_expr in six.iteritems(output_exprs):
r = self._cache_expr_result(out_expr)
if out_name == src_expr._output_name:
ret = r
return ret
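
    # Register the persisted result in the expr cache and wrap it for the
    # caller: models become ODPSModelExprs (tables-model params are stored as
    # table comments), tables become DataFrames.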
def _cache_expr_result(self, src_expr):
if isinstance(src_expr, ODPSModelExpr):
name = src_expr.persist_kw['_model']
model_expr = self._build_model(src_expr, name)
context.cache(src_expr, model_expr._source_data)
if not model_expr._is_offline_model:
params_str = utils.escape_odps_string(json.dumps(model_expr._source_data.params))
for k, v in six.iteritems(model_expr._model_collections):
if not options.ml.dry_run:
self._odps.run_sql("alter table %s set comment '%s'" % (v._source_data.name, params_str))
context.cache(src_expr._model_collections[k], v._source_data)
return model_expr
else:
name = src_expr.persist_kw['_table']
project = src_expr.persist_kw.get('_project')
t = self._odps.get_table(name, project=project)
context.cache(src_expr, t)
if options.ml.dry_run:
df = CollectionExpr(_source_data=t, _schema=src_expr._schema)
else:
df = DataFrame(t)
df._ml_fields = src_expr._ml_fields
return df