#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
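
"""Local pandas backend: compiles PyODPS DataFrame expression DAGs into pandas
operations, executes them in-process, and persists results to ODPS tables."""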

import os
import sys
import tarfile
import zipfile

from .... import compat, options
from ....compat import six
from ....errors import ODPSError
from ....lib.importer import CompressImporter
from ....models import Table, TableSchema, Partition as TableSchemaPartition
from ....models.partition import Partition
from ... import DataFrame
from ...expr.core import ExprDictionary
from ...expr.expressions import CollectionExpr, Scalar
from ...expr.dynamic import DynamicMixin
from ...backends.odpssql.types import df_schema_to_odps_schema, df_type_to_odps_type
from ...utils import is_source_collection, is_constant_scalar
from ...types import DynamicSchema, Unknown
from ..context import context
from ..core import Engine, ExecuteNode, ExprDAG
from ..errors import CompileError
from ..frame import ResultFrame
from ..utils import refresh_dynamic, write_table
from . import analyzer as ana
from .compiler import PandasCompiler
from .types import pd_to_df_schema


class PandasExecuteNode(ExecuteNode):
    def __repr__(self):
        return 'Local execution by pandas backend'

    def _repr_html_(self):
        return '<p>Local execution by pandas backend</p>'


def with_thirdparty_libs(fun):
    def wrapped(self, *args, **kwargs):
        libraries = self._get_libraries(kwargs.get('libraries'))
        importer = self._build_library_importer(libraries)
        if importer is not None:
            sys.meta_path.append(importer)

        try:
            return fun(self, *args, **kwargs)
        finally:
            if importer is not None:
                sys.meta_path = [p for p in sys.meta_path if p is not importer]

    wrapped.__name__ = fun.__name__
    wrapped.__doc__ = fun.__doc__
    return wrapped


class PandasEngine(Engine):
    def __init__(self, odps=None):
        self._odps = odps
        self._file_objs = []

    def _new_execute_node(self, expr_dag):
        return PandasExecuteNode(expr_dag)

    def _run(self, expr_dag, pd_dag, ui=None, progress_proportion=1, **_):
        ui.status('Trying to execute with local pandas...', clear_keys=True)

        results = ExprDictionary()
        while True:
            topos = pd_dag.topological_sort()

            no_sub = True
            for node in topos:
                expr, func = node
                if expr in results:
                    continue
                res = func(results)
                if isinstance(res, tuple):
                    src = expr
                    expr = res[0]
                    res = res[1]
                    results[src] = res
                    results[expr] = res

                    # break because the DAG has changed and must be re-sorted
                    no_sub = False
                    break
                results[expr] = res
            if no_sub:
                break

        ui.inc(progress_proportion)

        try:
            return results[expr_dag.root]
        except KeyError:
            if len(results) == 1:
                return compat.lvalues(results)[0]
            raise
        finally:
            for fo in self._file_objs:
                try:
                    fo.close()
                except Exception:
                    pass
            self._file_objs = []

    def _new_analyzer(self, expr_dag, on_sub=None):
        return ana.Analyzer(expr_dag)

    def _compile(self, expr_dag):
        backend = PandasCompiler(expr_dag)
        return backend.compile(expr_dag.root)

    def _cache(self, expr_dag, dag, expr, **kwargs):
        import pandas as pd

        if is_source_collection(expr_dag.root) or \
                is_constant_scalar(expr_dag.root):
            return

        execute_dag = ExprDAG(expr_dag.root, dag=expr_dag)

        if isinstance(expr, CollectionExpr):
            root = expr_dag.root
            sub = CollectionExpr(_source_data=pd.DataFrame(), _schema=expr.schema)
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            execute_node = self._execute(execute_dag, dag, expr, ret_df=True, **kwargs)

            def callback(res):
                for col in res.columns:
                    sub._source_data[col] = res[col]
                if isinstance(expr, DynamicMixin):
                    sub._schema = pd_to_df_schema(res)
                    refresh_dynamic(sub, expr_dag)

            execute_node.callback = callback
        else:
            assert isinstance(expr, Scalar)  # sequence is not cache-able

            class ValueHolder(object):
                pass

            sub = Scalar(_value_type=expr.dtype)
            # placeholder value; replaced with the real result by the callback below
            sub._value = ValueHolder()
            root = expr_dag.root
            sub.add_deps(root)
            expr_dag.substitute(root, sub)

            execute_node = self._execute(execute_dag, dag, expr, **kwargs)

            def callback(res):
                sub._value = res

            execute_node.callback = callback

        return sub, execute_node

    def _build_library_importer(self, libraries):
        if libraries is None:
            return None

        def _open_file(*args, **kwargs):
            handle = open(*args, **kwargs)
            self._file_objs.append(handle)
            return handle

        readers = []

        for lib in libraries:
            if isinstance(lib, six.string_types):
                lib = os.path.abspath(lib)
                file_dict = dict()
                if os.path.isfile(lib):
                    file_dict[lib.replace(os.path.sep, '/')] = _open_file(lib, 'rb')
                else:
                    for root, dirs, files in os.walk(lib):
                        for f in files:
                            fpath = os.path.join(root, f)
                            file_dict[fpath.replace(os.path.sep, '/')] = _open_file(fpath, 'rb')
                readers.append(file_dict)
            else:
                lib_name = lib.name
                if lib_name.endswith('.zip') or lib_name.endswith('.egg') or lib_name.endswith('.whl'):
                    readers.append(zipfile.ZipFile(lib.open(mode='rb')))
                elif lib_name.endswith('.tar') or lib_name.endswith('.tar.gz') or lib_name.endswith('.tar.bz2'):
                    if lib_name.endswith('.tar'):
                        mode = 'r'
                    else:
                        mode = 'r:gz' if lib_name.endswith('.tar.gz') else 'r:bz2'
                    readers.append(tarfile.open(fileobj=six.BytesIO(lib.open(mode='rb').read()), mode=mode))
                elif lib_name.endswith('.py'):
                    tarbinary = six.BytesIO()
                    tar = tarfile.open(fileobj=tarbinary, mode='w:gz')
                    fbin = lib.open(mode='rb').read()
                    info = tarfile.TarInfo(name='pyodps_lib/' + lib_name)
                    info.size = len(fbin)
                    tar.addfile(info, fileobj=six.BytesIO(fbin))
                    tar.close()
                    readers.append(tarfile.open(fileobj=six.BytesIO(tarbinary.getvalue()), mode='r:gz'))
                else:
                    raise ValueError(
                        'Unknown library type: library name should end with '
                        '.zip, .egg, .whl, .tar, .tar.gz, .tar.bz2 or .py')

        return CompressImporter(*readers, extract_binary=True, supersede=options.df.supersede_libraries)

    @with_thirdparty_libs
    def _do_execute(self, expr_dag, expr, ui=None, progress_proportion=1,
                    head=None, tail=None, **kw):
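        # Convert and rewrite the expression DAG, compile it to a pandas DAG,
        # run it locally, then cache the result and wrap it for the caller.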
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        ret_df = kw.pop('ret_df', False)

        src_expr = expr

        pd_dag = self._compile(expr_dag)
        df = self._run(expr_dag, pd_dag, ui=ui, progress_proportion=progress_proportion,
                       **kw)

        if not isinstance(src_expr, Scalar):
            need_cache = False
            if not isinstance(src_expr, CollectionExpr) or getattr(src_expr, "_source_data", None) is None:
                # only cache when expression is not data source
                need_cache = True
                context.cache(src_expr, df)
            # reset schema
            if isinstance(src_expr, CollectionExpr) and \
                    (isinstance(src_expr._schema, DynamicSchema) or
                         any(isinstance(col.type, Unknown) for col in src_expr._schema.columns)):
                src_expr._schema = expr_dag.root.schema
            if head:
                df = df[:head]
            elif tail:
                df = df[-tail:]
            if ret_df:
                return df

            if need_cache:
                df = df.copy()
            return ResultFrame(df, schema=expr_dag.root.schema)
        else:
            res = df.values[0][0]
            context.cache(src_expr, res)
            return res

    @with_thirdparty_libs
    def _do_persist(self, expr_dag, expr, name, ui=None, project=None,
                    partitions=None, partition=None, odps=None, lifecycle=None,
                    progress_proportion=1, execute_percent=0.5, overwrite=True,
                    drop_table=False, create_table=True, drop_partition=False,
                    create_partition=False, cast=False, schema=None, **kwargs):
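        # Execute the expression locally first, then create or reuse the
        # target ODPS table (plain, statically or dynamically partitioned)
        # and write the local result into it.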
        expr_dag = self._convert_table(expr_dag)
        self._rewrite(expr_dag)

        if isinstance(name, Partition):
            partition = name.partition_spec
            name = name.table
        if isinstance(name, Table):
            table = name
            project = table.project.name
            table_ns_schema = table.get_schema()
            if table_ns_schema:
                schema = table_ns_schema.name
            name = table.name

        src_expr = expr
        expr = expr_dag.root
        odps = odps or self._odps
        if odps is None:
            raise ODPSError('ODPS entrance should be provided')
        schema = schema or odps.schema

        df = self._do_execute(expr_dag, src_expr, ui=ui,
                              progress_proportion=progress_proportion * execute_percent, **kwargs)
        t_schema = TableSchema(columns=df.columns)

        if drop_table:
            odps.delete_table(name, project=project, schema=schema, if_exists=True)

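        # Three target layouts: partition columns taken from the data
        # (`partitions`), a single static partition (`partition`), or a
        # non-partitioned table.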
        if partitions is not None:
            if drop_partition:
                raise ValueError('Cannot drop partitions when `partitions` is specified')
            if create_partition:
                raise ValueError('Cannot create partitions when `partitions` is specified')

            if isinstance(partitions, tuple):
                partitions = list(partitions)
            if not isinstance(partitions, list):
                partitions = [partitions, ]

            for p in partitions:
                if p not in t_schema:
                    raise ValueError(
                        'Partition field (%s) does not exist in the DataFrame schema' % p)

            t_schema = df_schema_to_odps_schema(t_schema)
            columns = [c for c in t_schema.columns if c.name not in partitions]
            ps = [TableSchemaPartition(name=pt, type=t_schema.get_type(pt)) for pt in partitions]
            t_schema = TableSchema(columns=columns, partitions=ps)

            if odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = odps.get_table(name, project=project, schema=schema)
            else:
                t = odps.create_table(name, t_schema, project=project, schema=schema, lifecycle=lifecycle)
        elif partition is not None:
            if odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = odps.get_table(name, project=project, schema=schema)
                partition = self._get_partition(partition, t)

                if drop_partition:
                    t.delete_partition(partition, if_exists=True)
                if create_partition:
                    t.create_partition(partition, if_not_exists=True)
            else:
                partition = self._get_partition(partition)
                project_obj = odps.get_project(project)
                column_names = [n for n in expr.schema.names if n not in partition]
                column_types = [
                    df_type_to_odps_type(expr.schema[n].type, project=project_obj)
                    for n in column_names
                ]
                partition_names = [n for n in partition.keys]
                partition_types = ['string'] * len(partition_names)
                t = odps.create_table(
                    name, TableSchema.from_lists(
                        column_names, column_types, partition_names, partition_types
                    ),
                    project=project, lifecycle=lifecycle, schema=schema
                )
                if create_partition is None or create_partition is True:
                    t.create_partition(partition)
        else:
            if drop_partition:
                raise ValueError('Cannot drop a partition for a non-partitioned table')
            if create_partition:
                raise ValueError('Cannot create a partition for a non-partitioned table')

            if odps.exist_table(name, project=project, schema=schema) or not create_table:
                t = odps.get_table(name, project=project, schema=schema)
                if t.table_schema.partitions:
                    raise CompileError('Cannot insert into partitioned table %s without specifying '
                                       '`partition` or `partitions`.' % name)
            else:
                t = odps.create_table(
                    name,
                    df_schema_to_odps_schema(t_schema),
                    project=project,
                    lifecycle=lifecycle,
                    schema=schema,
                )

        write_table(df, t, ui=ui, cast=cast, overwrite=overwrite, partitions=partitions, partition=partition,
                    progress_proportion=progress_proportion*(1-execute_percent))

        if partition:
            filters = []
            df = DataFrame(t)
            for k in partition.keys:
                filters.append(df[k] == partition[k])
            return df.filter(*filters)
        return DataFrame(t)
