#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import json
import time
import uuid

from .runners import create_node_runner
from .expr import AlgoCollectionExpr, ODPSModelExpr, ModelDataCollectionExpr, MetricsResultExpr
from .utils import is_temp_table
from ..df.backends.context import context
from ..df.backends.analyzer import BaseAnalyzer
from ..df.backends.engine import Engine
from ..df.backends.odpssql import types
from ..df.backends.odpssql.types import df_schema_to_odps_schema
from ..df.backends.errors import CompileError
from ..df.backends.utils import refresh_dynamic
from ..df import DataFrame
from ..df.expr.collections import Node, CollectionExpr, Scalar
from ..df.expr.core import ExprDAG
from ..df.expr.dynamic import DynamicMixin
from ..df.utils import is_source_collection, is_constant_scalar
from .. import options, tempobj, utils
from ..compat import six, futures
from ..errors import ODPSError
from ..models import Partition, TableSchema
from ..ui import fetch_instance_group, reload_instance_status


class OdpsAlgoContext(object):
    """Per-engine context for algorithm execution.

    Keeps a reference to the ODPS entry object and an (initially empty)
    dictionary of node caches.  ``register_exec`` is currently a no-op hook.
    """

    def __init__(self, odps):
        self._node_caches = {}
        self._odps = odps

    def register_exec(self, idx, parameters):
        """Hook for recording an execution; currently does nothing."""
        return None


class OdpsAlgoAnalyzer(BaseAnalyzer):
    """Analyzer for algorithm DAGs; adds a ``visit_algo`` hook that does nothing."""

    def visit_algo(self, expr):
        """Visit an algorithm node; no analysis is performed for it here."""
        return None


class OdpsAlgoEngine(Engine):
    """Engine that executes PyODPS ML algorithm expressions on ODPS.

    Algorithm nodes are submitted through node runners created by
    ``create_node_runner``; their outputs are written into (temporary) tables
    or models whose handles are cached in the execution ``context`` so that
    downstream expressions can read them.
    """

    def __init__(self, odps):
        """Bind the engine to the ODPS entry object ``odps``."""
        self._odps = odps
        self._ctx = OdpsAlgoContext(odps)
        # NOTE(review): nothing in this class appends to this list; presumably
        # node runners record submitted instances here so ``stop`` can kill them
        self._instances = []

    def _dispatch(self, expr_dag, expr, ctx):
        """Dispatch ``expr`` unless it is an uncached root that must be disk-persisted."""
        if expr._need_cache and not ctx.is_cached(expr):
            # when the expr should be disk-persisted, skip
            if expr is expr_dag.root and not expr._mem_cache:
                return None
        return super(OdpsAlgoEngine, self)._dispatch(expr_dag, expr, ctx)

    def stop(self):
        """Stop every instance in ``self._instances`` on a best-effort basis.

        An ``ODPSError`` raised while stopping one instance is ignored so that
        the remaining instances are still stopped.
        """
        for inst in self._instances:
            try:
                self._odps.stop_instance(inst.id)
            except ODPSError:
                pass

    def _gen_table_name(self, expr):
        """Return a name for a temporary table holding the output of ``expr``.

        In dry-run mode the name is derived from the node name (or
        ``str(expr)`` for non-node inputs) and nothing is registered.
        Otherwise the name is made unique with a timestamp and a UUID and is
        registered with ``tempobj`` as a temporary table.
        """
        if options.ml.dry_run:
            if isinstance(expr, Node):
                node_name = expr.node_name
            else:
                node_name = str(expr)
            return '%s_%s' % (utils.TEMP_TABLE_PREFIX, utils.camel_to_underline(node_name))

        table_name = '%s%s_%s' % (utils.TEMP_TABLE_PREFIX, int(time.time()),
                                  str(uuid.uuid4()).replace('-', '_'))
        tempobj.register_temp_table(self._odps, table_name)
        return table_name

    def _gen_model_name(self, expr):
        """Return a name for a temporary model produced by ``expr``.

        In dry-run mode the name is derived from the node name (or
        ``str(expr)`` for non-node inputs) and nothing is registered.
        Otherwise the name is ``TEMP_MODEL_PREFIX`` followed by the tail of
        the MD5 digest of a timestamp plus a UUID, truncated so the whole name
        is 32 characters long, and it is registered as a temporary model.
        """
        from .utils import TEMP_MODEL_PREFIX
        if options.ml.dry_run:
            if isinstance(expr, Node):
                node_name = expr.node_name
            else:
                node_name = str(expr)
            # model names carry the model prefix, consistent with the
            # non-dry-run branch below (previously the table prefix was used)
            return '%s%s' % (TEMP_MODEL_PREFIX, utils.camel_to_underline(node_name))

        model_id_str = utils.to_binary(str(int(time.time())) + '_' + str(uuid.uuid4()).replace('-', '_'))
        digest = hashlib.md5(model_id_str).hexdigest()

        model_name = TEMP_MODEL_PREFIX + digest[-(32 - len(TEMP_MODEL_PREFIX)):]
        tempobj.register_temp_model(self._odps, model_name)
        return model_name

    def _reload_ui(self, group, instance, ui):
        """Refresh ``instance`` in UI group ``group`` and return its updated entry.

        Returns ``None`` when ``group`` is falsy.
        """
        if group:
            reload_instance_status(self._odps, group, instance.id)
            ui.update_group()
            return fetch_instance_group(group).instances.get(instance.id)

    def _run(self, algo_name, params, metas, engine_kw, ui, **kw):
        """Create a node runner for algorithm ``algo_name`` and execute it."""
        runner = create_node_runner(self, algo_name, params, metas, engine_kw, ui, **kw)
        runner.execute()

    def _new_analyzer(self, expr_dag, on_sub=None):
        """Return the analyzer used for algorithm DAGs."""
        return OdpsAlgoAnalyzer(expr_dag, on_sub=on_sub)

    def _build_model(self, expr, model_name):
        """Build an ``ODPSModelExpr`` that reads the stored model ``model_name``.

        Offline models are wrapped directly.  For tables models, the model
        params (plus ``predictor``/``recommender`` metadata when set on
        ``expr``) are attached to the model object, and every model data
        collection of ``expr`` is mapped to a ``ModelDataCollectionExpr``
        bound to its backing table.
        """
        if expr._is_offline_model:
            model = self._odps.get_offline_model(model_name)
            return ODPSModelExpr(_source_data=model, _is_offline_model=True)

        model_params = expr._model_params.copy()
        for meta in ['predictor', 'recommender']:
            meta_val = getattr(expr, '_' + meta, None)
            if meta_val:
                model_params[meta] = meta_val

        model = self._odps.get_tables_model(model_name, tables=list(six.iterkeys(expr._model_collections)))
        model._params = model_params

        sub = ODPSModelExpr(_source_data=model, _is_offline_model=False,
                            _model_params=expr._model_params.copy(), _predictor=expr._predictor)
        data_exprs = dict()
        for k in six.iterkeys(expr._model_collections):
            data_exprs[k] = ModelDataCollectionExpr(_mlattr_model=sub, _data_item=k)
            data_exprs[k]._source_data = self._odps.get_table(data_exprs[k].table_name())
        sub._model_collections = data_exprs
        return sub

    def _cache(self, expr_dag, dag, expr, **kwargs):
        """Substitute the DAG root with a placeholder that will hold ``expr``'s result.

        Returns ``None`` when the root is already a source collection, a
        constant scalar or a source model (nothing to cache).  Otherwise
        returns ``(sub, execute_node)``: ``sub`` replaces the root in
        ``expr_dag`` and ``execute_node`` is the node added to ``dag`` that
        produces the result.  Collections are persisted to a temporary table,
        models to a temporary model, and scalars are executed with a callback
        that stores the value into ``sub``.
        """

        is_source_model = isinstance(expr, ODPSModelExpr) and expr_dag.root._source_data is not None

        # prevent the `partition` and `partitions` kwargs come from `persist`
        kwargs.pop('partition', None)
        kwargs.pop('partitions', None)

        if is_source_collection(expr_dag.root) or \
                is_constant_scalar(expr_dag.root) or \
                is_source_model:
            return

        execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)

        if isinstance(expr, CollectionExpr):
            table_name = self._gen_table_name(expr)
            table = self._odps.get_table(table_name)
            root = expr_dag.root
            sub = CollectionExpr(_source_data=table, _schema=expr.schema)
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            kw = dict(kwargs)
            kw['lifecycle'] = options.temp_lifecycle

            execute_node = self._persist(table_name, execute_dag, dag, expr, **kw)

            def callback(result):
                # extra exprs point to whatever table the persist produced
                if getattr(expr, 'is_extra_expr', False):
                    sub._source_data = result._source_data
                # dynamic schemas are only known after the table is written
                if isinstance(expr, DynamicMixin):
                    sub._schema = types.odps_schema_to_df_schema(table.table_schema)
                    refresh_dynamic(sub, expr_dag)

            execute_node.callback = callback
        elif isinstance(expr, ODPSModelExpr):
            model_name = self._gen_model_name(expr)
            sub = self._build_model(expr, model_name)
            root = expr_dag.root
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            # a lifecycle is not passed on when persisting models
            kw = dict(kwargs)
            kw.pop('lifecycle', None)

            execute_node = self._persist(model_name, execute_dag, dag, expr, **kw)
        else:
            assert isinstance(expr, Scalar)  # sequence is not cache-able

            class ValueHolder(object): pass
            sub = Scalar(_value_type=expr.dtype)
            # placeholder value until the callback stores the real result
            sub._value = ValueHolder()

            execute_node = self._execute(execute_dag, dag, expr, **kwargs)

            def callback(res):
                sub._value = res
            execute_node.callback = callback

        return sub, execute_node

    def _write_persist_kw(self, name, expr, **kwargs):
        """Attach persist options for ``expr`` as ``expr.persist_kw``.

        ``name`` is stored as ``_table`` for collections and ``_model`` for
        models.  A ``project`` keyword is moved to ``_project`` only when it
        names a project other than the engine's default one; an absent
        ``_project`` means the default project.  Other expression types are
        left untouched.
        """
        if isinstance(expr, CollectionExpr):
            name_key = '_table'
        elif isinstance(expr, ODPSModelExpr):
            name_key = '_model'
        else:
            return

        persist_kw = kwargs.copy()
        persist_kw[name_key] = name
        project = persist_kw.pop('project', None)
        # same rule for tables and models (the collection branch used to
        # record ``_project=None`` while the model branch skipped it)
        if project is not None and self._odps.project != project:
            persist_kw['_project'] = project
        expr.persist_kw = persist_kw

    def _persist(self, name, expr_dag, dag, expr, **kwargs):
        """Record persist options on ``expr`` and delegate to the base engine."""
        self._write_persist_kw(name, expr, **kwargs)
        return super(OdpsAlgoEngine, self)._persist(name, expr_dag, dag, expr, **kwargs)

    @staticmethod
    def _is_output_model_only(src_expr):
        """Return True when none of ``src_expr``'s outputs is a collection.

        Metrics results are never considered model-only.
        """
        if isinstance(src_expr, MetricsResultExpr):
            return False
        output_exprs = src_expr.outputs()
        return not any(1 for out_expr in six.itervalues(output_exprs) if isinstance(out_expr, CollectionExpr))

    def _build_output_tables(self, expr):
        """Pre-create output tables of ``expr`` when its algo meta asks for it.

        Only runs when ``algo_meta['buildTables']`` is true.  Each algorithm
        output that has an output table name gets a table with its
        ``_algo_schema`` (falling back to ``_schema``); temporary tables use
        ``options.temp_lifecycle``, others ``options.lifecycle``.
        """
        from .expr.exporters import get_output_table_name

        if not utils.str_to_bool(expr.algo_meta.get('buildTables', False)):
            return

        def create_output_table(table_name, table_schema):
            lifecycle = options.temp_lifecycle if is_temp_table(table_name) else options.lifecycle
            self._odps.create_table(table_name, table_schema, lifecycle=lifecycle)

        table_names, table_schemas = [], []
        for out_name, out_expr in six.iteritems(expr.outputs()):
            if getattr(out_expr, '_algo', None) is None:
                continue
            tn = get_output_table_name(expr, out_name)
            if tn:
                ts = getattr(out_expr, '_algo_schema', None) or out_expr._schema
                table_names.append(tn)
                table_schemas.append(df_schema_to_odps_schema(ts))

        if not table_names:
            return

        # the context manager shuts the pool down once all tables are
        # created, instead of leaking its worker threads
        with futures.ThreadPoolExecutor(10) as executor:
            list(executor.map(create_output_table, table_names, table_schemas))

    def _do_execute(self, expr_dag, src_expr, **kwargs):
        """Execute ``src_expr`` and return its local result.

        * Models: persisted to a temporary model first when not yet cached;
          offline models are returned as ``PmmlResult`` (optionally via a
          model-transfer volume), tables models as ``TablesModelResult``.
        * Metrics results: the algorithm is run once into temporary tables and
          ``expr.calculator`` computes the value (``None`` in dry run).
        * Anything else: persisted to a temporary table, then downloaded.

        ``kwargs`` must contain ``ui``; ``progress_proportion`` and ``group``
        are optional.
        """
        expr = expr_dag.root

        kwargs['_output_models_only'] = self._is_output_model_only(src_expr)

        kw = kwargs.copy()
        if isinstance(src_expr, ODPSModelExpr):
            ui = kw.pop('ui')
            progress_proportion = kw.pop('progress_proportion', 1)
            # share of the progress left for fetching the model result
            download_progress = progress_proportion
            ui_group = kw.pop('group', None)

            if hasattr(src_expr, '_source_data'):
                result_expr = src_expr
            else:
                if not context.is_cached(src_expr):
                    temp_name = self._gen_model_name(src_expr)
                    # training takes 90% of the progress, fetching the rest
                    download_progress = 0.1 * progress_proportion
                    self._do_persist(expr_dag, src_expr, temp_name, ui=ui,
                                     progress_proportion=0.9 * progress_proportion, group=ui_group, **kw)

                result_expr = src_expr.get_cached(context.get_cached(src_expr))

            if result_expr._is_offline_model:
                from .expr.models.pmml import PmmlResult
                from .runners import XFlowNodeRunner

                model = result_expr._source_data
                if not options.ml.use_model_transfer:
                    pmml = model.get_model()
                    return PmmlResult(pmml)
                else:
                    volume_name = options.ml.model_volume
                    if not self._odps.exist_volume(volume_name):
                        self._odps.create_parted_volume(volume_name)

                    vol_part = hashlib.md5(utils.to_binary(model.name)).hexdigest()
                    tempobj.register_temp_volume_partition(self._odps, (volume_name, vol_part))
                    algo_params = {
                        'modelName': model.name,
                        'volumeName': volume_name,
                        'partition': vol_part,
                        'format': 'pmml'
                    }
                    runner = XFlowNodeRunner(self, 'modeltransfer', algo_params, {}, {},
                                             ui=ui, progress_proportion=download_progress, group=ui_group)
                    runner.execute()
                    pmml = self._odps.open_volume_reader(volume_name, vol_part, model.name + '.xml').read()
                    self._odps.delete_volume_partition(volume_name, vol_part)
                    return PmmlResult(utils.to_str(pmml))
            else:
                from .expr.models.base import TablesModelResult
                results = dict()
                # max(..., 1) avoids ZeroDivisionError for models without data tables
                frac = 1.0 / max(len(result_expr._model_collections), 1)
                for key, item in six.iteritems(result_expr._model_collections):
                    # split the fetch budget among model tables, like the
                    # offline-model branch does with ``download_progress``
                    result = item.execute(ui=ui, progress_proportion=frac * download_progress,
                                          group=ui_group)
                    results[key] = result
                return TablesModelResult(result_expr._model_params, results)
        elif isinstance(src_expr, MetricsResultExpr):
            if not src_expr.executed:
                expr.tables = dict((pt.name, self._gen_table_name(src_expr)) for pt in src_expr.output_ports)
                gen_params = expr.convert_params(src_expr)

                ui = kw.pop('ui')
                progress_proportion = kw.pop('progress_proportion', 1)
                ui_group = kw.pop('group', None)
                engine_kw = getattr(src_expr, '_engine_kw', None)
                # tolerate both a missing and an explicit ``None`` ``_engine_kw``
                if engine_kw is None:
                    engine_kw = dict()

                engine_kw['lifecycle'] = options.temp_lifecycle
                if hasattr(src_expr, '_cases'):
                    kw['_cases'] = src_expr._cases

                self._run(src_expr._algo, gen_params, src_expr.algo_meta, engine_kw, ui,
                          progress_proportion=progress_proportion, group=ui_group, **kw)

                src_expr.executed = True

            if options.ml.dry_run:
                return None
            else:
                if hasattr(src_expr, '_result_callback'):
                    callback = src_expr._result_callback
                else:
                    callback = lambda v: v
                return callback(expr.calculator(self._odps))
        else:
            temp_name = self._gen_table_name(src_expr)
            persist_kw = kwargs.copy()
            persist_kw['_table'] = temp_name
            expr.persist_kw = persist_kw

            ui = kw.pop('ui')
            progress_proportion = kw.pop('progress_proportion', 1)
            ui_group = kw.pop('group', None)

            kw['lifecycle'] = options.temp_lifecycle
            df = self._do_persist(expr_dag, src_expr, temp_name, ui=ui,
                                  progress_proportion=0.9 * progress_proportion, group=ui_group, **kw)
            return df.execute(ui=ui, progress_proportion=0.1 * progress_proportion, group=ui_group)

    def _handle_expr_persist(self, out_expr):
        """Prepare ``out_expr`` for being written to its persist target.

        For collections, when the algorithm cannot write straight into the
        target (an existing table, or a ``partition``/``partitions`` spec),
        the output is redirected to a fresh temporary table and a callback is
        returned.  When invoked after the algorithm has run, the callback
        prepares the target and returns the ``INSERT`` SQL copying the
        temporary table into it.  Returns ``None`` when no follow-up SQL is
        needed.  For models, an existing model is dropped first when
        ``drop_model`` is set.
        """
        from ..df.backends.engine import ODPSSQLEngine

        class ODPSEngine(ODPSSQLEngine):
            # SQL engine whose ``compile`` returns one string even when the
            # underlying compiler produces a list of statements
            def compile(self, expr, prettify=True, libraries=None):
                expr = self._convert_table(expr)
                expr_dag = expr.to_dag()
                self._analyze(expr_dag, expr)
                new_expr = self._rewrite(expr_dag)
                sql = self._compile(new_expr, prettify=prettify, libraries=libraries)
                if isinstance(sql, list):
                    return '\n'.join(sql)
                return sql

        if isinstance(out_expr, CollectionExpr):
            partition = out_expr.persist_kw.get('partition')
            partitions = out_expr.persist_kw.get('partitions')
            drop_table = out_expr.persist_kw.get('drop_table', False)
            create_table = out_expr.persist_kw.get('create_table', True)
            drop_partition = out_expr.persist_kw.get('drop_partition', False)
            create_partition = out_expr.persist_kw.get('create_partition', False)
            overwrite = out_expr.persist_kw.get('overwrite', True)
            cast = out_expr.persist_kw.get('cast', False)

            expr_table = out_expr.persist_kw['_table']
            expr_project = out_expr.persist_kw.get('_project')
            expr_table_path = expr_table if expr_project is None else expr_project + '.' + expr_table

            if partitions is None and partition is None:
                if drop_table:
                    self._odps.delete_table(expr_table, project=expr_project, if_exists=True)

                # look the target up in its own project, not the default one
                if self._odps.exist_table(expr_table, project=expr_project):
                    temp_table_name = self._gen_table_name(out_expr)
                    out_expr.persist_kw['_table'] = temp_table_name
                    out_expr.persist_kw['_project'] = None

                    def callback():
                        t = self._odps.get_table(expr_table, project=expr_project)
                        if t.table_schema.partitions:
                            raise CompileError('Cannot insert into partition table %s without specifying '
                                               '`partition` or `partitions`.' % expr_table_path)
                        expr = self._odps.get_table(temp_table_name).to_df()
                        expr = self._reorder(expr, t, cast=cast)

                        sql = ODPSEngine(self._odps).compile(expr, prettify=False)
                        action_str = 'OVERWRITE' if overwrite else 'INTO'
                        return 'INSERT {0} TABLE {1} \n{2}'.format(action_str, expr_table_path, sql)

                    return callback
                else:
                    return None
            elif partition is not None:
                temp_table_name = self._gen_table_name(out_expr)

                out_expr.persist_kw['_table'] = temp_table_name
                out_expr.persist_kw['_project'] = None

                def callback():
                    # validate and prepare the *target* table; the temporary
                    # table only supplies the rows to insert
                    t = self._odps.get_table(expr_table, project=expr_project)

                    for col in out_expr.schema.columns:
                        if col.name.lower() not in t.table_schema:
                            raise CompileError('Column(%s) does not exist in target table %s, '
                                               'writing cannot be performed.' % (col.name, t.name))

                    if drop_partition:
                        t.delete_partition(partition, if_exists=True)
                    if create_partition:
                        t.create_partition(partition, if_not_exists=True)

                    expr = self._odps.get_table(temp_table_name).to_df()
                    expr = self._reorder(expr, t, cast=cast)
                    sql = ODPSEngine(self._odps).compile(expr, prettify=False)
                    action_str = 'OVERWRITE' if overwrite else 'INTO'

                    return 'INSERT {0} TABLE {1} PARTITION({2}) {3}'.format(
                        action_str, expr_table_path, partition, sql,
                    )
                return callback
            else:
                temp_table_name = self._gen_table_name(out_expr)

                out_expr.persist_kw['_table'] = temp_table_name
                out_expr.persist_kw['_project'] = None

                if isinstance(partitions, tuple):
                    partitions = list(partitions)
                if not isinstance(partitions, list):
                    partitions = [partitions, ]

                def callback():
                    t = self._odps.get_table(temp_table_name)
                    schema = t.table_schema

                    # partition columns are taken from the temporary table
                    columns = [c for c in schema.columns if c.name not in partitions]
                    ps = [Partition(name=pt, type=schema.get_type(pt)) for pt in partitions]
                    if drop_table:
                        self._odps.delete_table(expr_table, project=expr_project, if_exists=True)
                    if create_table:
                        lifecycle = options.temp_lifecycle if is_temp_table(expr_table) else options.lifecycle
                        # if_not_exists: writing into an existing table must not
                        # fail just because ``create_table`` defaults to True
                        self._odps.create_table(expr_table, TableSchema(columns=columns, partitions=ps),
                                                project=expr_project, lifecycle=lifecycle,
                                                if_not_exists=True)

                    expr = t.to_df()
                    expr = self._reorder(expr, t, cast=cast, with_partitions=True)
                    sql = ODPSEngine(self._odps).compile(expr, prettify=False)
                    action_str = 'OVERWRITE' if overwrite else 'INTO'

                    return 'INSERT {0} TABLE {1} PARTITION({2}) {3}'.format(
                        action_str, expr_table_path, ', '.join(partitions), sql,
                    )
                return callback
        elif isinstance(out_expr, ODPSModelExpr):
            drop_model = out_expr.persist_kw.get('drop_model', False)
            expr_model = out_expr.persist_kw['_model']

            if drop_model:
                if out_expr._is_offline_model:
                    self._odps.delete_offline_model(expr_model, if_exists=True)
                else:
                    self._odps.delete_tables_model(expr_model, if_exists=True)

    def _do_persist(self, expr_dag, src_expr, name, partitions=None, partition=None, project=None,
                    drop_table=False, create_table=True, drop_partition=False, create_partition=False,
                    **kwargs):
        """Run the algorithm behind ``src_expr`` and cache its outputs.

        ``name`` and the table/partition keywords are accepted here so they
        are not forwarded through ``kwargs`` to the runner; per-output persist
        options are read from each output's ``persist_kw`` instead.
        ``kwargs`` must contain ``ui``; ``progress_proportion``, ``group`` and
        ``lifecycle`` are optional.

        The algorithm runs at most once: ``src_expr.executed`` is set even
        when running fails.  Returns the cached result of the output named
        ``src_expr._output_name`` (or a DataFrame for extra expressions).
        """
        from .runners import SQLNodeRunner
        from .enums import PortType

        expr = expr_dag.root
        kwargs['_output_models_only'] = self._is_output_model_only(src_expr)

        output_exprs = src_expr.outputs()
        shared_kw = src_expr.shared_kw
        shared_kw['required_outputs'] = dict()
        if hasattr(src_expr, 'output_ports'):
            for out_port in src_expr.output_ports:
                if not out_port.required and out_port.name not in output_exprs:
                    continue

                if out_port.name in output_exprs:
                    out_expr = output_exprs[out_port.name]
                    if not getattr(out_expr, 'persist_kw', None):
                        # name outputs after the output expression itself, for
                        # models as well as tables (the DAG root was used before)
                        expr_name = self._gen_table_name(out_expr) if isinstance(out_expr, CollectionExpr) \
                            else self._gen_model_name(out_expr)
                        self._write_persist_kw(expr_name, out_expr, **kwargs)
                else:
                    # required port without a consuming expression still needs
                    # an output name for the algorithm to write to
                    expr_name = self._gen_table_name(src_expr.node_name) if out_port.type == PortType.DATA \
                        else self._gen_model_name(src_expr.node_name)
                    shared_kw['required_outputs'][out_port.name] = expr_name
        src_expr.shared_kw = shared_kw

        kw = kwargs.copy()
        ui = kw.pop('ui')
        progress_proportion = kw.pop('progress_proportion', 1)
        ui_group = kw.pop('group', None)
        engine_kw = getattr(src_expr, '_engine_kw', None)

        # an explicit lifecycle wins over the global option
        lifecycle = kw.get('lifecycle') or options.lifecycle
        if lifecycle:
            # ``_engine_kw`` may be missing; create it rather than failing on None
            if engine_kw is None:
                engine_kw = dict()
            engine_kw['lifecycle'] = lifecycle

        if hasattr(src_expr, '_cases'):
            kw['_cases'] = src_expr._cases

        if not options.ml.dry_run:
            self._build_output_tables(src_expr)

        sql_callbacks = []

        expr.wait_execution()

        if not src_expr.executed:
            for out_expr in six.itervalues(output_exprs):
                callback = self._handle_expr_persist(out_expr)
                if callback is not None:
                    sql_callbacks.append(callback)

        gen_params = expr.convert_params(src_expr)

        if not src_expr.executed:
            prog_ratio = 1
            sub_ratio = 0
            if sql_callbacks:
                # reserve 20% of the progress for the follow-up INSERT statements
                prog_ratio = 0.8
                sub_ratio = (1 - prog_ratio) * progress_proportion / len(sql_callbacks)
            try:
                self._run(src_expr._algo, gen_params, src_expr.algo_meta, engine_kw, ui,
                          progress_proportion=prog_ratio * progress_proportion, group=ui_group, **kw)

                for cb in sql_callbacks:
                    sql = cb()
                    runner = SQLNodeRunner(self, 'SQL', dict(sql=sql), dict(), dict(), ui,
                                           progress_proportion=sub_ratio, group=ui_group)
                    runner.execute()
            finally:
                src_expr.executed = True

        if getattr(src_expr, 'is_extra_expr', False):
            t = src_expr._table_callback(self._odps, src_expr)
            context.cache(src_expr, t)

            if options.ml.dry_run:
                df = CollectionExpr(_source_data=t, _schema=src_expr._schema)
            else:
                df = DataFrame(t)
            df._ml_fields = src_expr._ml_fields
            return df

        ret = None
        for out_name, out_expr in six.iteritems(output_exprs):
            r = self._cache_expr_result(out_expr)
            if out_name == src_expr._output_name:
                ret = r

        return ret

    def _cache_expr_result(self, src_expr):
        """Cache the persisted output ``src_expr`` and return an expression reading it.

        Models: the model expression is rebuilt from ``persist_kw['_model']``
        and cached; for tables models each data table is cached too, and (not
        in dry run) its comment is set to the JSON-encoded model params.
        Collections: the table named by ``persist_kw`` is cached and returned
        as a DataFrame (a schema-typed ``CollectionExpr`` in dry run)
        carrying ``_ml_fields``.
        """
        if isinstance(src_expr, ODPSModelExpr):
            name = src_expr.persist_kw['_model']
            model_expr = self._build_model(src_expr, name)
            context.cache(src_expr, model_expr._source_data)

            if not model_expr._is_offline_model:
                params_str = utils.escape_odps_string(json.dumps(model_expr._source_data.params))
                for k, v in six.iteritems(model_expr._model_collections):
                    if not options.ml.dry_run:
                        self._odps.run_sql("alter table %s set comment '%s'" % (v._source_data.name, params_str))
                    context.cache(src_expr._model_collections[k], v._source_data)

            return model_expr
        else:
            name = src_expr.persist_kw['_table']
            project = src_expr.persist_kw.get('_project')
            t = self._odps.get_table(name, project=project)
            context.cache(src_expr, t)
            if options.ml.dry_run:
                df = CollectionExpr(_source_data=t, _schema=src_expr._schema)
            else:
                df = DataFrame(t)
            df._ml_fields = src_expr._ml_fields
            return df
            return df
