#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import copy
import logging
import sys
import threading
import time
import types as tps
import uuid  # don't remove
import warnings

from ....compat import six
from ....config import options
from ....errors import ODPSError, NoPermission, ConnectTimeout, ParseError
from ....utils import TEMP_TABLE_PREFIX, get_supported_python_tag
from ....models import Table, TableSchema
from ....models.partition import Partition
from ....tempobj import register_temp_table
from ....ui import reload_instance_status, fetch_instance_group
from ...core import DataFrame
from ...expr.core import ExprDAG
from ...expr.dynamic import DynamicMixin
from ...expr.expressions import CollectionExpr, Scalar, Summary, SequenceExpr
from ...utils import is_source_collection, is_constant_scalar
from ...types import DynamicSchema, Unknown
from ..utils import refresh_dynamic, process_persist_kwargs
from ..core import Engine, ExecuteNode
from ..errors import CompileError
from ..frame import ResultFrame
from ..context import context
from . import types
from . import analyzer as ana
from . import rewriter as rwr
from .context import ODPSContext, UDF_CLASS_NAME
from .compiler import OdpsSQLCompiler
from .codegen import gen_udf
from .tunnel import TunnelEngine
from .models import MemCacheReference

logger = logging.getLogger(__name__)


class SQLExecuteNode(ExecuteNode):
    def _sql(self):
        raise NotImplementedError

    def __repr__(self):
        buf = six.StringIO()

        sql = self._sql()

        if sql:
            if isinstance(sql, list):
                sql = '\n'.join(sql)
            buf.write('SQL compiled: \n\n')
            buf.write(sql)
        else:
            buf.write('Use tunnel to download data')

        return buf.getvalue()

    def _repr_html_(self):
        buf = six.StringIO()

        sql = self._sql()

        if sql:
            if isinstance(sql, list):
                sql = '\n'.join(sql)
            buf.write('<h4>SQL compiled</h4>')
            buf.write('<code>%s</code>' % sql)
        else:
            buf.write('<p>Use tunnel to download data</p>')

        return buf.getvalue()


class ODPSSQLEngine(Engine):
    def __init__(self, odps):
        self._odps = odps
        self._ctx_local = threading.local()
        self._instances = []

    @property
    def _ctx(self):
        if not hasattr(self._ctx_local, '_ctx'):
            self._ctx_local._ctx = ODPSContext(self._odps)
        return self._ctx_local._ctx

    def stop(self):
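        """Stop every SQL instance submitted by this engine and close its context."""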
        for inst_id in self._instances:
            try:
                self._odps.stop_instance(inst_id)
            except ODPSError:
                pass
        self._ctx.close()

    def _new_execute_node(self, expr_dag):
        node = SQLExecuteNode(expr_dag)

        def _sql(*_):
            return self._compile_sql(node.expr_dag)

        node._sql = tps.MethodType(_sql, node)
        return node

    @staticmethod
    def _get_task_percent(task_progress):
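        """Average the terminated/total worker ratio over the stages of a task,
        returning 0 when no stage is reported."""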
        if len(task_progress.stages) > 0:
            all_percent = sum((float(stage.terminated_workers) / stage.total_workers)
                              for stage in task_progress.stages if stage.total_workers > 0)
            return all_percent / len(task_progress.stages)
        else:
            return 0

    def _reload_ui(self, group, instance, ui):
        if group:
            reload_instance_status(self._odps, group, instance.id)
            ui.update_group()
            return fetch_instance_group(group).instances.get(instance.id)

    def _run(self, sql, ui, progress_proportion=1, hints=None, priority=None,
             running_cluster=None, group=None, libraries=None, schema=None, image=None):
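        """Prepare resources and UDFs, submit the SQL as an ODPS instance, then poll
        it until termination while updating the progress UI. Returns the instance."""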
        libraries = self._ctx.prepare_resources(self._get_libraries(libraries))
        self._ctx.create_udfs(libraries=libraries)

        hints = hints or dict()
        if self._ctx.get_udf_count() > 0 and sys.version_info[:2] >= (3, 6):
            hints['odps.sql.jobconf.odps2'] = True
            hints['odps.sql.python.version'] = get_supported_python_tag()
        image = image or options.df.image
        if image:
            hints['odps.session.image'] = image
        hints.update(self._ctx.get_udf_sql_hints())

        instance = self._odps.run_sql(sql, hints=hints, priority=priority, name='PyODPSDataFrameTask',
                                      running_cluster=running_cluster, default_schema=schema)

        self._instances.append(instance.id)
        logger.info(
            'Instance ID: %s\n  Log view: %s', instance.id, instance.get_logview_address()
        )
        ui.status('Executing', 'execution details')

        percent = 0
        last_log_progress = 0
        progress_time = start_time = time.time()
        while not instance.is_terminated(retry=True):
            inst_progress = self._reload_ui(group, instance, ui)

            if inst_progress:
                last_percent = percent
                if len(inst_progress.tasks) > 0:
                    percent = sum(self._get_task_percent(task)
                                  for task in six.itervalues(inst_progress.tasks)) \
                              / len(inst_progress.tasks)
                else:
                    percent = 0
                percent = min(1, max(percent, last_percent))
                ui.inc((percent - last_percent) * progress_proportion)

            if logger.getEffectiveLevel() <= logging.INFO:
                try:
                    check_time = time.time()
                    task_progresses = inst_progress.tasks
                    total_progress = sum(
                        stage.finished_percentage
                        for progress in task_progresses.values()
                        for stage in progress.stages
                    )
                    if check_time - start_time >= options.progress_time_interval and (
                        total_progress - last_log_progress >= options.progress_percentage_gap
                        or check_time - progress_time >= options.progress_time_interval
                    ):
                        output_parts = [instance.id] + [
                            progress.get_stage_progress_formatted_string()
                            for progress in task_progresses.values()
                        ]
                        logger.info(" ".join(output_parts))
                        last_log_progress = total_progress
                        progress_time = check_time
                except:
                    # make sure progress display does not affect execution
                    pass

            time.sleep(1)

        instance.wait_for_success()

        self._reload_ui(group, instance, ui)
        if percent < 1:
            ui.inc((1 - percent) * progress_proportion)

        return instance

    def _handle_cases(self, *args, **kwargs):
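        """Try to satisfy the request directly through the table tunnel, bypassing SQL."""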
        tunnel_engine = TunnelEngine(self._odps)
        return tunnel_engine.execute(*args, **kwargs)

    def _gen_table_name(self):
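        """Generate a unique temporary table name and register it for later cleanup."""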
        table_name = '%s%s_%s' % (TEMP_TABLE_PREFIX, int(time.time()),
                                  str(uuid.uuid4()).replace('-', '_'))
        register_temp_table(self._odps, table_name, schema=self._ctx.default_schema)
        return table_name

    def _new_analyzer(self, expr_dag, on_sub=None):
        return ana.Analyzer(expr_dag, on_sub=on_sub)

    def _new_rewriter(self, expr_dag):
        return rwr.Rewriter(expr_dag)

    def _compile_sql(self, expr_dag, prettify=True):
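        """Rewrite the expression DAG and compile its root into ODPS SQL."""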
        self._rewrite(expr_dag)

        return self._compile(expr_dag.root, prettify=prettify)

    def _compile(self, expr, prettify=False, libraries=None):
        backend = OdpsSQLCompiler(self._ctx, beautify=prettify)

        libraries = self._ctx.prepare_resources(self._get_libraries(libraries))
        self._ctx.register_udfs(*gen_udf(expr, UDF_CLASS_NAME, libraries=libraries))

        return backend.compile(expr)

    def _mem_cache(self, expr_dag, expr):
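        """Substitute a mem-cached expression in the DAG with a MemCacheReference
        source and register the SQL that reproduces its content."""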
        engine = self
        root = expr

        def _sub():
            ref_name = self._ctx.get_mem_cache_ref_name(root)
            sub = CollectionExpr(_source_data=MemCacheReference(root._id, ref_name),
                                 _schema=expr_dag.root.schema, _id=expr_dag.root._id)
            expr_dag.substitute(expr_dag.root, sub)
            return sub

        if self._ctx.is_expr_mem_cached(expr):
            if is_source_collection(expr) and \
                    isinstance(expr._source_data, MemCacheReference):
                return
            _sub()
            return

        class MemCacheCompiler(OdpsSQLCompiler):
            def visit_source_collection(self, expr):
                if isinstance(expr._source_data, MemCacheReference):
                    alias = self._ctx.register_collection(expr)
                    from_clause = '{0} {1}'.format(expr._source_data.ref_name, alias)
                    self.add_from_clause(expr, from_clause)
                    self._ctx.add_expr_compiled(expr, from_clause)
                    engine._ctx.register_mem_cache_dep(root, expr)
                else:
                    super(MemCacheCompiler, self).visit_source_collection(expr)

        compiler = MemCacheCompiler(self._ctx, indent_size=0)
        sql = compiler.compile(expr_dag.root).replace('\n', '')
        self._ctx.register_mem_cache_sql(root, sql)
        sub = _sub()

        return sub, None

    def _cache(self, expr_dag, dag, expr, **kwargs):
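        """Cache an intermediate result: collections are persisted into a temporary
        table and scalars are executed into a value holder; the original node is
        substituted by the cached source in the DAG."""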
        if isinstance(expr, CollectionExpr) and expr._mem_cache:
            return self._mem_cache(expr_dag, expr)

        # drop kwargs that were passed in from `persist`
        process_persist_kwargs(kwargs)

        if is_source_collection(expr_dag.root) or \
                is_constant_scalar(expr_dag.root):
            return

        execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)

        if isinstance(expr, CollectionExpr):
            table_name = self._gen_table_name()
            table = self._odps.get_table(table_name, schema=self._ctx.default_schema)
            root = expr_dag.root
            sub = CollectionExpr(_source_data=table, _schema=expr.schema)
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            kw = dict(kwargs)
            kw['lifecycle'] = options.temp_lifecycle

            execute_node = self._persist(table_name, execute_dag, dag, expr, **kw)

            def callback(_):
                if isinstance(expr, DynamicMixin):
                    sub._schema = types.odps_schema_to_df_schema(table.table_schema)
                    refresh_dynamic(sub, expr_dag)

            execute_node.callback = callback
        else:
            assert isinstance(expr, Scalar)  # sequences are not cacheable

            class ValueHolder(object):
                pass
            sub = Scalar(_value_type=expr.dtype)
            sub._value = ValueHolder()
            root = expr_dag.root
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            execute_node = self._execute(execute_dag, dag, expr, **kwargs)

            def callback(res):
                sub._value = res
            execute_node.callback = callback

        return sub, execute_node

    @classmethod
    def _join_sql(cls, sql):
        if isinstance(sql, list):
            return '\n'.join(sql)
        return sql

    def _do_execute(self, expr_dag, expr, ui=None, progress_proportion=1,
                    lifecycle=None, head=None, tail=None, hints=None,
                    priority=None, running_cluster=None, schema=None, **kw):
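        """Execute the expression: try the tunnel first when enabled, otherwise compile
        to SQL, run it as an instance and fetch the result."""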
        lifecycle = lifecycle or options.temp_lifecycle
        group = kw.get('group')
        libraries = kw.pop('libraries', None)
        image = kw.pop('image', None)
        use_tunnel = kw.get('use_tunnel', True)

        self._ctx.default_schema = schema or self._ctx.default_schema
        if self._odps.is_schema_namespace_enabled(hints):
            self._ctx.default_schema = self._ctx.default_schema or "default"

        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root

        if isinstance(expr, Scalar) and expr.value is not None:
            ui.inc(progress_proportion)
            return expr.value

        no_permission = False
        if options.df.optimizes.tunnel:
            force_tunnel = kw.get('_force_tunnel', False)
            try:
                result = self._handle_cases(expr, ui, progress_proportion=progress_proportion,
                                            head=head, tail=tail)
            except KeyboardInterrupt:
                ui.status('Halt by interruption')
                sys.exit(1)
            except (NoPermission, ConnectTimeout) as ex:
                result = None
                no_permission = True
                if head:
                    expr = expr[:head]
                warnings.warn('Failed to download data via table tunnel; results will be '
                              'limited to 10000 records.\nCause: ' + str(ex))
            if force_tunnel or result is not None:
                return result

        try:
            sql = self._compile(expr, libraries=libraries)
            if types.get_local_use_odps2_types(self._odps.get_project()):
                hints = copy.copy(hints or {})
                hints["odps.sql.type.system.odps2"] = "true"

            cache_data = None
            if not no_permission and isinstance(expr, CollectionExpr) and not isinstance(expr, Summary):
                # when the tunnel cannot handle the query, write the result into a temporary table
                tmp_table_name = '%s%s' % (TEMP_TABLE_PREFIX, str(uuid.uuid4()).replace('-', '_'))
                register_temp_table(self._odps, tmp_table_name, schema=self._ctx.default_schema)
                cache_data = self._odps.get_table(tmp_table_name, schema=self._ctx.default_schema)

                lifecycle_str = 'LIFECYCLE {0} '.format(lifecycle) if lifecycle is not None else ''
                format_sql = lambda s: 'CREATE TABLE {0} {1}AS \n{2}'.format(tmp_table_name, lifecycle_str, s)
                if isinstance(sql, list):
                    sql[-1] = format_sql(sql[-1])
                else:
                    sql = format_sql(sql)

            sql = self._join_sql(sql)

            logger.info('SQL compiled:\n%s', sql)

            try:
                instance = self._run(sql, ui, progress_proportion=progress_proportion * 0.9, hints=hints,
                                     priority=priority, running_cluster=running_cluster, group=group,
                                     libraries=libraries, image=image, schema=self._ctx.default_schema)
            finally:
                self._ctx.close()  # clear udfs and resources generated

            res = self._fetch(expr, src_expr, instance, ui,
                              cache_data=cache_data, head=head, tail=tail,
                              use_tunnel=use_tunnel, group=group,
                              progress_proportion=progress_proportion * 0.1,
                              finish=kw.get('finish', True))
        finally:
            types.set_local_use_odps2_types(None)

        if kw.get('ret_instance', False) is True:
            return instance, res
        return res

    def _fetch(self, expr, src_expr, instance, ui, progress_proportion=1,
               cache_data=None, head=None, tail=None, use_tunnel=True,
               group=None, finish=True):
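        """Download results, preferring the tunnel on the cached temporary table and
        falling back to ``head`` or the instance reader when the tunnel is unavailable."""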
        if isinstance(expr, (CollectionExpr, Summary)):
            df_schema = expr._schema
            schema = types.df_schema_to_odps_schema(expr._schema, ignorecase=True)
        elif isinstance(expr, SequenceExpr):
            df_schema = TableSchema.from_lists([expr.name], [expr._data_type])
            schema = types.df_schema_to_odps_schema(df_schema, ignorecase=True)
        else:
            df_schema = None
            schema = None

        if cache_data is not None:
            if group and finish:
                ui.remove_keys(group)
            if use_tunnel:
                try:
                    if finish:
                        ui.status('Start to use tunnel to download results...')
                    with cache_data.open_reader(reopen=True) as reader:
                        if head:
                            reader = reader[:head]
                        elif tail:
                            start = max(reader.count - tail, 0)
                            reader = reader[start:]
                        try:
                            return ResultFrame([r.values for r in reader], schema=df_schema)
                        finally:
                            context.cache(src_expr, cache_data)
                            # reset schema
                            if isinstance(src_expr, CollectionExpr) and \
                                    (isinstance(src_expr._schema, DynamicSchema) or
                                     any(isinstance(col.type, Unknown) for col in src_expr._schema.columns)):
                                src_expr._schema = df_schema
                            ui.inc(progress_proportion)
                except ODPSError as ex:
                    # some projects disable tunnel download; ignore the error and fall back
                    warnings.warn('Failed to download data via table tunnel; results will be '
                                  'limited to 10000 records.\nCause: ' + str(ex))

            if tail:
                raise NotImplementedError

            try:
                if finish:
                    ui.status('Start to use head to download results...')
                return ResultFrame(cache_data.head(head or 10000), schema=df_schema)
            finally:
                context.cache(src_expr, cache_data)
                ui.inc(progress_proportion)

        with instance.open_reader(schema=schema, use_tunnel=False) as reader:
            ui.status('Start to read instance results...')
            if not isinstance(src_expr, Scalar):
                if head:
                    reader = reader[:head]
                elif tail:
                    start = max(reader.count - tail, 0)
                    reader = reader[start:]
                try:
                    return ResultFrame([r.values for r in reader], schema=df_schema)
                finally:
                    context.cache(src_expr, cache_data)
                    ui.inc(progress_proportion)
            else:
                ui.inc(progress_proportion)
                odps_type = types.df_type_to_odps_type(src_expr._value_type, project=instance.project)
                res = types.odps_types.validate_value(reader[0][0], odps_type)
                context.cache(src_expr, res)
                return res

    def _do_persist(self, expr_dag, expr, name, partitions=None, partition=None, project=None, ui=None,
                    progress_proportion=1, lifecycle=None, hints=None, priority=None,
                    running_cluster=None, overwrite=True, drop_table=False, create_table=True,
                    drop_partition=False, create_partition=None, cast=False, schema=None, **kw):
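        """Persist the expression into a (possibly partitioned) table through an
        ``INSERT OVERWRITE/INTO`` statement and return a DataFrame on the target table."""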
        group = kw.get('group')
        libraries = kw.pop('libraries', None)
        image = kw.pop('image', None)

        if isinstance(name, Partition):
            partition = name.partition_spec
            name = name.table
        if isinstance(name, Table):
            table = name
            project = table.project.name
            if table.get_schema():
                schema = table.get_schema().name
            name = table.name

        lifecycle = options.temp_lifecycle if name.startswith(TEMP_TABLE_PREFIX) else lifecycle

        self._ctx.default_schema = schema or self._ctx.default_schema
        if self._odps.is_schema_namespace_enabled(hints):
            self._ctx.default_schema = self._ctx.default_schema or "default"

        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root

        should_cache = False

        if drop_table:
            self._odps.delete_table(name, project=project, schema=schema, if_exists=True)

        if project is not None or schema is not None:
            project = project or self._ctx._odps.project
            schema = schema or self._ctx.default_schema
        if project is None:
            table_name = name
        elif schema is None:
            table_name = '%s.`%s`' % (project, name)
        else:
            table_name = '%s.%s.`%s`' % (project, schema, name)

        project_obj = self._odps.get_project(project)
        if partitions is None and partition is None:
            # the non-partitioned table
            if drop_partition:
                raise ValueError('Cannot drop partition for a non-partitioned table')
            if create_partition:
                raise ValueError('Cannot create partition for a non-partitioned table')

            if self._odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = self._odps.get_table(name, project=project, schema=schema)
                if t.table_schema.partitions:
                    raise CompileError('Cannot insert into partitioned table %s without specifying '
                                       '`partition` or `partitions`.' % name)
                expr = self._reorder(expr, t, cast=cast)
            else:
                # We don't use `CREATE TABLE ... AS` because it will report `table already exists`
                # when the service retries.
                if isinstance(expr, CollectionExpr):
                    t_schema = types.df_schema_to_odps_schema(expr.schema, ignorecase=True)
                else:
                    col_name = expr.name
                    tp = types.df_type_to_odps_type(expr.dtype, project=project_obj)
                    t_schema = TableSchema.from_lists([col_name, ], [tp, ])
                self._odps.create_table(name, TableSchema(columns=t_schema.columns),
                                        project=project, schema=schema, lifecycle=lifecycle)

            sql = self._compile(expr, prettify=False, libraries=libraries)
            action_str = 'OVERWRITE' if overwrite else 'INTO'
            format_sql = lambda s: 'INSERT {0} TABLE {1} \n{2}'.format(action_str, table_name, s)
            if isinstance(sql, list):
                sql[-1] = format_sql(sql[-1])
            else:
                sql = format_sql(sql)

            should_cache = True
        elif partition is not None:
            if self._odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = self._odps.get_table(name, project=project, schema=schema)
                partition = self._get_partition(partition, t)

                if drop_partition:
                    t.delete_partition(partition, if_exists=True)
                if create_partition:
                    t.create_partition(partition, if_not_exists=True)
            else:
                partition = self._get_partition(partition)
                column_names = [n for n in expr.schema.names if n not in partition]
                column_types = [
                    types.df_type_to_odps_type(expr.schema[n].type, project=project_obj)
                    for n in column_names
                ]
                partition_names = [n for n in partition.keys]
                partition_types = ['string'] * len(partition_names)
                t = self._odps.create_table(
                    name, TableSchema.from_lists(
                        column_names, column_types, partition_names, partition_types
                    ),
                    project=project, lifecycle=lifecycle, schema=schema)
                if create_partition is None or create_partition is True:
                    t.create_partition(partition)

            expr = self._reorder(expr, t, cast=cast)
            sql = self._compile(expr, prettify=False, libraries=libraries)

            action_str = 'OVERWRITE' if overwrite else 'INTO'
            format_sql = lambda s: 'INSERT {0} TABLE {1} PARTITION({2}) \n{3}'.format(
                action_str, table_name, partition, s
            )
            if isinstance(sql, list):
                sql[-1] = format_sql(sql[-1])
            else:
                sql = format_sql(sql)
        else:
            if isinstance(partitions, tuple):
                partitions = list(partitions)
            if not isinstance(partitions, list):
                partitions = [partitions, ]

            if isinstance(expr, CollectionExpr):
                t_schema = types.df_schema_to_odps_schema(expr.schema, ignorecase=True)
            else:
                col_name = expr.name
                tp = types.df_type_to_odps_type(expr.dtype, project=project_obj)
                t_schema = TableSchema.from_lists([col_name, ], [tp, ])

            for p in partitions:
                if p not in t_schema:
                    raise ValueError(
                        'Partition field `%s` does not exist in the DataFrame schema' % p)

            columns = [c for c in t_schema.columns if c.name not in partitions]
            ps = [TableSchema.TablePartition(name=pt, type=t_schema.get_type(pt)) for pt in partitions]
            if self._odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = self._odps.get_table(name, project=project, schema=schema)
            else:
                t = self._odps.create_table(name, TableSchema(columns=columns, partitions=ps),
                                            project=project, lifecycle=lifecycle, schema=schema)
            if drop_partition:
                raise ValueError('Cannot drop partitions when `partitions` is specified')
            if create_partition:
                raise ValueError('Cannot create partitions when `partitions` is specified')
            expr = expr[[c.name for c in expr.schema if c.name not in partitions] + partitions]

            expr = self._reorder(expr, t, cast=cast, with_partitions=True)
            sql = self._compile(expr, prettify=False, libraries=libraries)

            action_str = 'OVERWRITE' if overwrite else 'INTO'
            format_sql = lambda s: 'INSERT {0} TABLE {1} PARTITION({2}) \n{3}'.format(
                action_str, table_name, ', '.join(partitions), s
            )
            if isinstance(sql, list):
                sql[-1] = format_sql(sql[-1])
            else:
                sql = format_sql(sql)

        sql = self._join_sql(sql)

        logger.info('SQL compiled:\n%s', sql)

        try:
            instance = self._run(sql, ui, progress_proportion=progress_proportion, hints=hints,
                                 priority=priority, running_cluster=running_cluster, group=group,
                                 libraries=libraries, image=image, schema=schema)
        except ParseError as ex:
            logger.error("Failed to run DF generated SQL: %s:\n%s", str(ex), sql)
            raise
        finally:
            self._ctx.close()  # clear udfs and resources generated

        t = self._odps.get_table(name, project=project, schema=schema)
        if should_cache and not is_source_collection(src_expr):
            # TODO: support cache partition
            context.cache(src_expr, t)
        if partition:
            filters = []
            df = DataFrame(t)
            for k in partition.keys:
                # the actual partition value type and the column type may not match
                filters.append(df[k] == Scalar(partition[k]).cast(df[k].dtype))
            res = df.filter(*filters)
        else:
            res = DataFrame(t)
        if kw.get('ret_instance', False) is True:
            return instance, res
        return res
