#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import logging
import types as tps
from decimal import Decimal

from .compiler import SQLAlchemyCompiler
from . import analyzer as ana
from . import rewriter as rwr
from .types import df_schema_to_sqlalchemy_columns
from ..core import ExecuteNode, Engine
from ..errors import CompileError
from ..utils import write_table, refresh_dynamic
from ..frame import ResultFrame
from ..context import context
from ... import DataFrame
from ...utils import is_source_collection, is_constant_scalar
from ...expr.expressions import *
from ...expr.core import ExprDAG
from ...expr.dynamic import DynamicMixin
from ...types import DynamicSchema, Unknown
from ...backends.odpssql.types import df_schema_to_odps_schema, df_type_to_odps_type
from ....types import PartitionSpec
from ....utils import gen_temp_table
from ....models import TableSchema, Partition

# Module-level logger; records the compiled SQL before execution.
logger = logging.getLogger(name=__name__)


class SQLExecuteNode(ExecuteNode):
    """Execute node whose compiled form is a SQL statement.

    ``_sql`` is only a placeholder here; ``SQLAlchemyEngine._new_execute_node``
    binds a concrete implementation onto each node instance.
    """

    def _sql(self):
        """Return the compiled SQL for this node (supplied by the engine)."""
        raise NotImplementedError

    def _sql_text(self):
        """Return ``self._sql()`` coerced to text.

        The engine's ``_sql`` returns the result of SQLAlchemy's
        ``compile()``, which is a compiled object rather than a string, while
        ``StringIO.write`` only accepts text; convert it before writing.
        """
        sql = self._sql()
        if not isinstance(sql, six.string_types):
            sql = six.text_type(sql)
        return sql

    def __repr__(self):
        buf = six.StringIO()

        sql = self._sql_text()
        buf.write('MPP SQL compiled: \n\n')
        buf.write(sql)

        return buf.getvalue()

    def _repr_html_(self):
        buf = six.StringIO()

        sql = self._sql_text()
        # escape markup characters so SQL such as ``a < b`` or string
        # literals containing tags render verbatim instead of as HTML
        sql = sql.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
        buf.write('<h4>MPP SQL compiled</h4>')
        buf.write('<code>%s</code>' % sql)

        return buf.getvalue()


# Open connection per SQLAlchemy engine, filled lazily by
# SQLAlchemyEngine._get_or_create_conn; entries are never removed here.
_engine_to_connections = dict()


class SQLAlchemyEngine(Engine):
    """DataFrame engine that compiles expressions into SQLAlchemy clauses.

    Collection and sequence results are written into a temporary table (the
    DDL is built by ``SACreateTempTableAs``) and optionally fetched back;
    scalar results are executed directly. One connection per SQLAlchemy
    engine is cached in the module-level ``_engine_to_connections`` dict.
    """

    def __init__(self, odps=None):
        # ODPS entry object; used by ``_do_persist`` when no ``odps`` is given
        self._odps = odps

    def _new_execute_node(self, expr_dag):
        """Create a ``SQLExecuteNode`` bound to ``expr_dag``.

        The node's ``_sql`` compiles ``expr_dag`` on demand; its ``verify``
        returns False when that compilation raises ``NotImplementedError``
        and True otherwise.
        """
        node = SQLExecuteNode(expr_dag)

        def _sql(*_):
            return self._compile_sql(expr_dag)

        def verify(*_):
            try:
                _sql(*_)
                return True
            except NotImplementedError:
                return False

        # bind onto this instance, overriding the class-level placeholders
        node._sql = tps.MethodType(_sql, node)
        node.verify = tps.MethodType(verify, node)
        return node

    def _new_analyzer(self, expr_dag, on_sub=None):
        """Return the SQLAlchemy-backend analyzer for ``expr_dag``."""
        return ana.Analyzer(expr_dag, on_sub=on_sub)

    def _new_rewriter(self, expr_dag):
        """Return the SQLAlchemy-backend rewriter for ``expr_dag``."""
        return rwr.Rewriter(expr_dag)

    def _compile_sql(self, expr_dag):
        """Rewrite ``expr_dag``, compile its root and return the compiled SQL."""
        self._rewrite(expr_dag)

        sa = self._compile(expr_dag.root, expr_dag)
        return self._sa_to_sql(sa)

    def _sa_to_sql(self, sa):
        """Compile SQLAlchemy clause ``sa``, inlining bound literals if possible.

        When literal rendering raises ``NotImplementedError`` the clause is
        compiled with bind placeholders instead. The SQLAlchemy compiled
        object is returned (not ``str``); use ``str()`` to get the SQL text.
        """
        try:
            return sa.compile(compile_kwargs={"literal_binds": True})
        except NotImplementedError:
            return sa.compile()

    def _status_ui(self, ui):
        """Report on ``ui`` that execution via sqlalchemy is being attempted."""
        ui.status('Try to execute by sqlalchemy...', clear_keys=True)

    @classmethod
    def _get_or_create_conn(cls, engine):
        """Return the cached connection for ``engine``, connecting on first use."""
        if engine in _engine_to_connections:
            return _engine_to_connections[engine]
        conn = engine.connect()
        _engine_to_connections[engine] = conn
        return conn

    def _get_cast_func(self, tp):
        """Return a converter matching DataFrame type ``tp``.

        ``int`` for ``types.Integer`` instances, ``float`` for ``types.Float``
        instances, and an identity function for anything else.
        """
        df_types_to_builtin_types = {
            types.Integer: int,
            types.Float: float
        }
        for df_type, builtin_type in six.iteritems(df_types_to_builtin_types):
            if isinstance(tp, df_type):
                return lambda value: builtin_type(value)
        return lambda value: value

    def _get_sa_table(self, table_name, engine, schema):
        """Declare a SQLAlchemy ``Table`` named ``table_name`` for ``schema``.

        The table is declared on a fresh ``MetaData`` bound to ``engine``;
        this only builds the Python object and issues no DDL.
        """
        import sqlalchemy

        metadata = sqlalchemy.MetaData(bind=engine)
        columns = df_schema_to_sqlalchemy_columns(schema, engine=engine)
        table = sqlalchemy.Table(table_name, metadata, *columns, extend_existing=True)

        return table

    def _get_result(self, table_name, engine, schema, head, tail):
        """Select rows from ``table_name`` and return the execution result.

        ``head`` limits the selection to its first rows; otherwise ``tail``
        skips all but the last ``tail`` rows, which costs an extra COUNT
        query executed through the bound (implicit) execution API.
        """
        import sqlalchemy

        table = self._get_sa_table(table_name, engine, schema)
        sa = sqlalchemy.select([table.alias('t1')])
        if head:
            sa = sa.limit(head)
        elif tail:
            count = sa.alias('t2').count().execute().scalar()
            skip = max(count - tail, 0)
            if skip:
                sa = sa.offset(skip)

        conn = self._get_or_create_conn(engine)
        return conn.execute(sa)

    @classmethod
    def _create_table(cls, table_name, sa, expr):
        """Return the statement creating temp table ``table_name`` from ``sa``.

        ``expr`` is not used by this implementation.
        """
        from .ext import SACreateTempTableAs

        return SACreateTempTableAs(table_name, sa)

    def _convert_result(self, result, schema):
        """Materialise ``result`` rows as lists, converting stray ``Decimal``s.

        The driver may hand back ``Decimal`` values for columns whose
        DataFrame type is not decimal; those values are converted with
        ``_get_cast_func`` (``int`` / ``float`` / unchanged). ``None`` (SQL
        NULL) and values already of a native type are left as they are.

        :param result: iterable of result rows
        :param schema: DataFrame schema, indexed by column position
        :return: list of rows, each a list
        """
        res = [list(r) for r in result]
        if not res:
            return res

        # (column position, cast function) for every non-decimal column
        casts = []
        for i in range(len(res[0])):
            col_type = schema[i].type
            if col_type != types.decimal:
                casts.append((i, self._get_cast_func(col_type)))

        for row in res:
            for i, cast in casts:
                value = row[i]
                # inspect every value rather than only the first row, so a
                # leading NULL does not disable conversion and NULLs are
                # never passed to int()/float()
                if isinstance(value, Decimal):
                    row[i] = cast(value)

        return res

    def _run(self, sa, ui, expr_dag, src_expr, progress_proportion=1,
             head=None, tail=None, fetch=True, tmp_table_name=None,
             execution_options=None):
        """Execute compiled clause ``sa`` and return its result.

        For collections/sequences the result is stored in a temporary table
        (``tmp_table_name`` or a generated name) and ``(table_name, rows)``
        is returned, ``rows`` being None when ``fetch`` is false. Otherwise
        the scalar value is returned. ``ui`` progress is advanced by
        ``progress_proportion`` even when execution fails.
        """
        self._status_ui(ui)

        schema = expr_dag.root.schema
        execution_options = dict() if execution_options is None else execution_options
        try:
            if isinstance(src_expr, (CollectionExpr, SequenceExpr)):
                if tmp_table_name is None:
                    tmp_table_name = gen_temp_table()
                to_execute = self._create_table(tmp_table_name, sa, expr_dag.root)
                # ``_sa_to_sql`` returns a compiled object, not ``str``; let the
                # logger format it (``str + Compiled`` raises TypeError)
                logger.info('Sql compiled:\n%s', self._sa_to_sql(to_execute))
                conn = self._get_or_create_conn(sa.bind)
                conn.execution_options(**execution_options).execute(to_execute)

                if fetch:
                    result = self._get_result(tmp_table_name, sa.bind, schema, head, tail)
                    res = self._convert_result(result, schema)
                    return tmp_table_name, res
                else:
                    return tmp_table_name, None
            else:
                logger.info('Sql compiled:\n%s', self._sa_to_sql(sa))

                conn = self._get_or_create_conn(sa.bind)
                res = conn.execution_options(**execution_options).execute(sa).scalar()
                if src_expr.dtype != types.decimal and isinstance(res, Decimal):
                    return self._get_cast_func(src_expr.dtype)(res)
                return res
        finally:
            ui.inc(progress_proportion)

    def _cache(self, expr_dag, dag, expr, **kwargs):
        """Substitute ``expr_dag.root`` by a placeholder filled on execution.

        Source collections and constant scalars are not cached (returns
        None). A collection root is replaced by a ``CollectionExpr`` over a
        new temporary table that the returned node fills without fetching;
        a scalar root is replaced by a ``Scalar`` whose value the node's
        callback sets.

        :return: ``(substitute_expr, execute_node)``
        """
        if is_source_collection(expr_dag.root) or \
                is_constant_scalar(expr_dag.root):
            return

        execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)

        if isinstance(expr, CollectionExpr):
            table_name = gen_temp_table()
            table = self._get_table(table_name, expr_dag)
            root = expr_dag.root
            sub = CollectionExpr(_source_data=table, _schema=expr.schema)
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            kw = dict(kwargs)

            execute_node = self._execute(execute_dag, dag, expr,
                                         execute_kw={'fetch': False, 'temp_table_name': table_name},
                                         **kw)

            def callback(res):
                # dynamic schemas are only known after execution
                if isinstance(expr, DynamicMixin):
                    sub._schema = res.schema
                    refresh_dynamic(sub, expr_dag)

            execute_node.callback = callback
        else:
            assert isinstance(expr, Scalar)  # sequence is not cache-able

            class ValueHolder(object):
                pass

            sub = Scalar(_value_type=expr.dtype)
            # non-None placeholder until the callback stores the real value
            sub._value = ValueHolder()
            root = expr_dag.root
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            execute_node = self._execute(execute_dag, dag, expr, **kwargs)

            def callback(res):
                sub._value = res

            execute_node.callback = callback

        return sub, execute_node

    def _compile(self, expr, dag):
        """Compile ``expr`` (a node of ``dag``) into a SQLAlchemy clause."""
        compiler = SQLAlchemyCompiler(dag)
        return compiler.compile(expr)

    def _get_table(self, table_name, expr_dag, bind=None):
        """Return a SQLAlchemy table ``table_name`` with the root's schema.

        When ``bind`` is None, the bind of the first source collection in
        ``expr_dag`` that has one is used (``StopIteration`` if none has).
        """
        if bind is None:
            bind = next(e for e in expr_dag.traverse()
                        if is_source_collection(e) and e._source_data.bind)._source_data.bind

        return self._get_sa_table(table_name, bind, expr_dag.root.schema)

    def _do_execute(self, expr_dag, expr, ui=None, progress_proportion=1,
                    head=None, tail=None, **kwargs):
        """Compile and run ``expr_dag``, caching the result in ``context``.

        Recognised ``kwargs``: ``fetch`` (default True), ``temp_table_name``
        and ``execution_options`` (default
        ``options.df.sqlalchemy.execution_options``).

        :return: the scalar value for scalars; otherwise a ``ResultFrame``
            when fetching, or the SQLAlchemy table when ``fetch`` is false.
        """
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root

        if isinstance(expr, Scalar) and expr.value is not None:
            ui.inc(progress_proportion)
            return expr.value

        sqlalchemy_expr = self._compile(expr, expr_dag)

        fetch = kwargs.pop('fetch', True)
        temp_table_name = kwargs.pop('temp_table_name', None)
        execution_options = kwargs.pop('execution_options',
                                       options.df.sqlalchemy.execution_options)
        result = self._run(sqlalchemy_expr, ui, expr_dag, src_expr,
                           progress_proportion=progress_proportion,
                           head=head, tail=tail, fetch=fetch,
                           tmp_table_name=temp_table_name,
                           execution_options=execution_options)

        if not isinstance(src_expr, Scalar):
            # reset schema
            if isinstance(src_expr, CollectionExpr) and \
                    (isinstance(src_expr._schema, DynamicSchema) or
                         any(isinstance(col.type, Unknown) for col in src_expr._schema.columns)):
                src_expr._schema = expr_dag.root.schema
            table_name, result = result
            table = self._get_table(table_name, expr_dag, sqlalchemy_expr.bind)
            context.cache(src_expr, table)
            if fetch:
                return ResultFrame(result, schema=expr_dag.root.schema)
            else:
                return table
        else:
            context.cache(src_expr, result)
            return result

    def _do_persist(self, expr_dag, expr, name, partitions=None, partition=None,
                    odps=None, project=None, ui=None,
                    progress_proportion=1, execute_percent=0.5, lifecycle=None,
                    overwrite=True, drop_table=False, create_table=True,
                    drop_partition=False, create_partition=False, cast=False, **kwargs):
        """Execute ``expr_dag`` and write its rows into ODPS table ``name``.

        The expression is executed with this engine, the target table is
        fetched or created according to ``partitions`` / ``partition`` /
        ``create_table``, and the rows are uploaded with ``write_table``.
        ``execute_percent`` splits ``progress_proportion`` between execution
        and upload. ``lifecycle`` and ``drop_table`` are accepted but not
        used by this implementation.

        :return: a ``DataFrame`` over the table, filtered to ``partition``
            when one was given.
        :raises ValueError: on conflicting partition options or a partition
            field missing from the result schema.
        :raises CompileError: when writing into an existing partitioned
            table without ``partition`` or ``partitions``.
        """
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        src_expr = expr
        expr = expr_dag.root
        odps = odps or self._odps

        # only checks that pandas is installed; the module is not used here
        try:
            import pandas
        except ImportError:
            raise DependencyNotInstalledError('persist requires for pandas')

        df = self._do_execute(expr_dag, src_expr, ui=ui,
                              progress_proportion=progress_proportion * execute_percent, **kwargs)
        schema = TableSchema(columns=df.columns)

        if partitions is not None:
            if drop_partition:
                raise ValueError('Cannot drop partitions when specify `partitions`')
            if create_partition:
                raise ValueError('Cannot create partitions when specify `partitions`')

            if isinstance(partitions, tuple):
                partitions = list(partitions)
            if not isinstance(partitions, list):
                partitions = [partitions, ]

            for p in partitions:
                if p not in schema:
                    raise ValueError(
                            'Partition field(%s) does not exist in DataFrame schema' % p)

            # move the partition fields out of the columns into partitions
            schema = df_schema_to_odps_schema(schema)
            columns = [c for c in schema.columns if c.name not in partitions]
            ps = [Partition(name=t, type=schema.get_type(t)) for t in partitions]
            schema = TableSchema(columns=columns, partitions=ps)

            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)
            else:
                t = odps.create_table(name, schema, project=project)
        elif partition is not None:
            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)
                partition = self._get_partition(partition, t)

                if drop_partition:
                    t.delete_partition(partition, if_exists=True)
                if create_partition:
                    t.create_partition(partition, if_not_exists=True)
            else:
                partition = self._get_partition(partition)
                project_obj = odps.get_project(project)
                column_names = [n for n in expr.schema.names if n not in partition]
                column_types = [
                    df_type_to_odps_type(expr.schema[n].type, project=project_obj)
                    for n in column_names
                ]
                partition_names = [n for n in partition.keys]
                # new partition columns are always created as strings
                partition_types = ['string'] * len(partition_names)
                t = odps.create_table(
                    name, TableSchema.from_lists(
                        column_names, column_types, partition_names, partition_types
                    ),
                    project=project,
                )
                if create_partition is None or create_partition is True:
                    t.create_partition(partition)
        else:
            if drop_partition:
                raise ValueError('Cannot drop partition for non-partition table')
            if create_partition:
                raise ValueError('Cannot create partition for non-partition table')

            if odps.exist_table(name, project=project) or not create_table:
                t = odps.get_table(name, project=project)
                if t.table_schema.partitions:
                    # the message has a ``%s`` placeholder; fill in the table name
                    raise CompileError(
                        ('Cannot insert into partition table %s without specifying '
                         '`partition` or `partitions`.') % name
                    )
            else:
                t = odps.create_table(name, df_schema_to_odps_schema(schema), project=project)

        write_table(df, t, ui=ui, cast=cast, overwrite=overwrite, partitions=partitions, partition=partition,
                    progress_proportion=progress_proportion * (1 - execute_percent))

        if partition:
            filters = []
            df = DataFrame(t)
            for k in partition.keys:
                filters.append(df[k] == partition[k])
            return df.filter(*filters)
        return DataFrame(t)
