#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time
import uuid
import logging

import requests
from typing import List
from mars.core.context import get_context
from mars.oscar.errors import ActorNotExist
from mars.core import OutputType
from mars.dataframe.operands import DataFrameOperandMixin, DataFrameOperand
from mars.dataframe.utils import build_concatenated_rows_frame, parse_index
from mars.serialization.serializables import (
    StringField,
    SeriesField,
    BoolField,
    DictField,
    Int64Field,
    ListField,
    FieldTypes,
)

from ....config import options
from ....utils import to_str
from ..cupid_service import CupidServiceClient

logger = logging.getLogger(__name__)

_output_type_kw = dict(_output_types=[OutputType.dataframe])


class DataFrameWriteTable(DataFrameOperand, DataFrameOperandMixin):
    _op_type_ = 123460

    dtypes = SeriesField("dtypes")

    odps_params = DictField("odps_params", default=None)
    table_name = StringField("table_name", default=None)
    partition_spec = StringField("partition_spec", default=None)
    partition_columns = ListField("partition_columns", FieldTypes.string, default=None)
    overwrite = BoolField("overwrite", default=None)
    write_batch_size = Int64Field("write_batch_size", default=None)
    unknown_as_string = BoolField("unknown_as_string", default=None)
    tunnel_quota_name = StringField("tunnel_quota_name", default=None)

    def __init__(self, **kw):
        kw.update(_output_type_kw)
        super(DataFrameWriteTable, self).__init__(**kw)

    @property
    def retryable(self):
        return "CUPID_SERVICE_SOCKET" not in os.environ

    def __call__(self, x):
        shape = (0,) * len(x.shape)
        index_value = parse_index(x.index_value.to_pandas()[:0], x.key, "index")
        columns_value = parse_index(
            x.columns_value.to_pandas()[:0], x.key, "columns", store_data=True
        )
        return self.new_dataframe(
            [x],
            shape=shape,
            dtypes=x.dtypes[:0],
            index_value=index_value,
            columns_value=columns_value,
        )

    @classmethod
    def _tile_cupid(cls, op):
        from mars.dataframe.utils import build_concatenated_rows_frame

        cupid_client = CupidServiceClient()
        upload_handle = cupid_client.create_table_upload_session(
            op.odps_params, op.table_name
        )

        input_df = build_concatenated_rows_frame(op.inputs[0])
        out_df = op.outputs[0]

        out_chunks = []
        out_chunk_shape = (0,) * len(input_df.shape)
        blocks = {}
        for chunk in input_df.chunks:
            block_id = str(int(time.time())) + "_" + str(uuid.uuid4()).replace("-", "")
            chunk_op = DataFrameWriteTableSplit(
                dtypes=op.dtypes,
                table_name=op.table_name,
                odps_params=op.odps_params,
                unknown_as_string=op.unknown_as_string,
                partition_spec=op.partition_spec,
                cupid_handle=to_str(upload_handle),
                block_id=block_id,
                write_batch_size=op.write_batch_size,
            )
            out_chunk = chunk_op.new_chunk(
                [chunk],
                shape=out_chunk_shape,
                index=chunk.index,
                index_value=out_df.index_value,
                dtypes=chunk.dtypes,
            )
            out_chunks.append(out_chunk)
            blocks[block_id] = op.partition_spec

        # build commit tree
        combine_size = 8
        chunks = out_chunks
        while len(chunks) >= combine_size:
            new_chunks = []
            for i in range(0, len(chunks), combine_size):
                chks = chunks[i : i + combine_size]
                if len(chks) == 1:
                    chk = chks[0]
                else:
                    chk_op = DataFrameWriteTableCommit(
                        dtypes=op.dtypes, is_terminal=False
                    )
                    chk = chk_op.new_chunk(
                        chks,
                        shape=out_chunk_shape,
                        index_value=out_df.index_value,
                        dtypes=op.dtypes,
                    )
                new_chunks.append(chk)
            chunks = new_chunks

        assert len(chunks) < combine_size

        commit_table_op = DataFrameWriteTableCommit(
            dtypes=op.dtypes,
            table_name=op.table_name,
            blocks=blocks,
            cupid_handle=to_str(upload_handle),
            overwrite=op.overwrite,
            odps_params=op.odps_params,
            is_terminal=True,
        )
        commit_table_chunk = commit_table_op.new_chunk(
            chunks,
            shape=out_chunk_shape,
            dtypes=op.dtypes,
            index_value=out_df.index_value,
            index=(0,) * len(out_chunk_shape),
        )

        new_op = op.copy()
        return new_op.new_dataframes(
            op.inputs,
            shape=out_df.shape,
            index_value=out_df.index_value,
            dtypes=out_df.dtypes,
            columns_value=out_df.columns_value,
            chunks=[commit_table_chunk],
            nsplits=((0,),) * len(out_chunk_shape),
        )

    @classmethod
    def _tile_tunnel(cls, op):
        from odps import ODPS

        out_df = op.outputs[0]

        if op.overwrite:
            o = ODPS(
                op.odps_params["access_id"],
                op.odps_params["secret_access_key"],
                project=op.odps_params["project"],
                endpoint=op.odps_params["endpoint"],
            )
            data_target = o.get_table(op.table_name)
            if op.partition_spec:
                data_target = data_target.get_partition(op.partition_spec)
            data_target.truncate()

        in_df = build_concatenated_rows_frame(op.inputs[0])
        logger.info("Tile table %s[%s]", op.table_name, op.partition_spec)
        recorder_name = str(uuid.uuid4())
        out_chunks = []
        for chunk in in_df.chunks:
            chunk_op = DataFrameWriteTableSplit(
                dtypes=op.dtypes,
                table_name=op.table_name,
                odps_params=op.odps_params,
                partition_spec=op.partition_spec,
                commit_recorder_name=recorder_name,
                tunnel_quota_name=op.tunnel_quota_name,
            )
            index_value = parse_index(chunk.index_value.to_pandas()[:0], chunk)
            out_chunk = chunk_op.new_chunk(
                [chunk],
                shape=(0, 0),
                index_value=index_value,
                columns_value=out_df.columns_value,
                dtypes=out_df.dtypes,
                index=chunk.index,
            )
            out_chunks.append(out_chunk)
        ctx = get_context()
        ctx.create_remote_object(recorder_name, _TunnelCommitRecorder, len(out_chunks))
        new_op = op.copy()
        params = out_df.params.copy()
        params.update(
            dict(chunks=out_chunks, nsplits=((0,) * in_df.chunk_shape[0], (0,)))
        )
        return new_op.new_tileables([in_df], **params)

    @classmethod
    def tile(cls, op):
        if "CUPID_SERVICE_SOCKET" in os.environ:
            return cls._tile_cupid(op)
        else:
            return cls._tile_tunnel(op)


class DataFrameWriteTableSplit(DataFrameOperand, DataFrameOperandMixin):
    _op_type_ = 123461

    dtypes = SeriesField("dtypes")
    table_name = StringField("table_name")
    partition_spec = StringField("partition_spec", default=None)
    cupid_handle = StringField("cupid_handle", default=None)
    block_id = StringField("block_id", default=None)
    write_batch_size = Int64Field("write_batch_size", default=None)
    unknown_as_string = BoolField("unknown_as_string", default=None)
    commit_recorder_name = StringField("commit_recorder_name", default=None)

    # for tunnel
    odps_params = DictField("odps_params", default=None)
    tunnel_quota_name = StringField("tunnel_quota_name", default=None)

    def __init__(self, **kw):
        kw.update(_output_type_kw)
        super(DataFrameWriteTableSplit, self).__init__(**kw)

    @property
    def retryable(self):
        return "CUPID_SERVICE_SOCKET" not in os.environ

    @classmethod
    def _execute_in_cupid(cls, ctx, op):
        import os

        import pandas as pd
        from odps import ODPS
        from odps.accounts import BearerTokenAccount

        cupid_client = CupidServiceClient()
        to_store_data = ctx[op.inputs[0].key]

        bearer_token = cupid_client.get_bearer_token()
        account = BearerTokenAccount(bearer_token)
        project = os.environ.get("ODPS_PROJECT_NAME", None)
        odps_params = op.odps_params.copy()
        if project:
            odps_params["project"] = project
        endpoint = os.environ.get("ODPS_RUNTIME_ENDPOINT") or odps_params["endpoint"]
        o = ODPS(
            None,
            None,
            account=account,
            project=odps_params["project"],
            endpoint=endpoint,
        )
        table_obj = o.get_table(op.table_name)
        odps_schema = table_obj.table_schema
        project_name = table_obj.project.name
        schema_name = table_obj.get_schema().name if table_obj.get_schema() is not None else None
        table_name = table_obj.name

        writer_config = dict(
            _table_name=table_name,
            _project_name=project_name,
            _schema_name=schema_name,
            _table_schema=odps_schema,
            _partition_spec=op.partition_spec,
            _block_id=op.block_id,
            _handle=op.cupid_handle,
        )
        cupid_client.write_table_data(writer_config, to_store_data, op.write_batch_size)
        ctx[op.outputs[0].key] = pd.DataFrame()

    @classmethod
    def _execute_arrow_tunnel(cls, ctx, op):
        from odps import ODPS
        from odps.tunnel import TableTunnel
        import pyarrow as pa
        import pandas as pd

        project = os.environ.get("ODPS_PROJECT_NAME", None)
        odps_params = op.odps_params.copy()
        if project:
            odps_params["project"] = project
        endpoint = os.environ.get("ODPS_RUNTIME_ENDPOINT") or odps_params["endpoint"]
        o = ODPS(
            odps_params["access_id"],
            odps_params["secret_access_key"],
            project=odps_params["project"],
            endpoint=endpoint,
        )

        t = o.get_table(op.table_name)
        tunnel = TableTunnel(o, project=t.project, quota_name=op.tunnel_quota_name)
        retry_times = options.retry_times
        init_sleep_secs = 1
        split_index = op.inputs[0].index
        logger.info(
            "Start creating upload session for table %s split index %s retry_times %s.",
            op.table_name,
            split_index,
            retry_times,
        )
        retries = 0
        while True:
            try:
                if op.partition_spec is not None:
                    upload_session = tunnel.create_upload_session(
                        t.name, partition_spec=op.partition_spec
                    )
                else:
                    upload_session = tunnel.create_upload_session(t.name)
                break
            except:
                if retries >= retry_times:
                    raise
                retries += 1
                sleep_secs = retries * init_sleep_secs
                logger.exception(
                    "Create upload session failed, sleep %s seconds and retry it",
                    sleep_secs,
                    exc_info=1,
                )
                time.sleep(sleep_secs)
        logger.info(
            "Start writing table %s. split_index: %s tunnel_session: %s",
            op.table_name,
            split_index,
            upload_session.id,
        )
        retries = 0
        while True:
            try:
                writer = upload_session.open_arrow_writer(0)
                arrow_rb = pa.RecordBatch.from_pandas(ctx[op.inputs[0].key])
                writer.write(arrow_rb)
                writer.close()
                break
            except:
                if retries >= retry_times:
                    raise
                retries += 1
                sleep_secs = retries * init_sleep_secs
                logger.exception(
                    "Write data failed, sleep %s seconds and retry it",
                    sleep_secs,
                    exc_info=1,
                )
                time.sleep(sleep_secs)
        recorder_name = op.commit_recorder_name
        try:
            recorder = ctx.get_remote_object(recorder_name)
        except ActorNotExist:
            while True:
                logger.info(
                    "Writing table %s has been finished, waiting to be canceled by speculative scheduler",
                    op.table_name,
                )
                time.sleep(3)
        can_commit, can_destroy = recorder.try_commit(split_index)
        if can_commit:
            # FIXME If this commit failed or the process crashed, the whole write will still raise error.
            # But this situation is very rare so we skip the error handling.
            logger.info(
                "Committing to table %s with upload session %s", op.table_name, upload_session.id
            )
            upload_session.commit([0])
            logger.info(
                "Finish writing table %s. split_index: %s tunnel_session: %s",
                op.table_name,
                split_index,
                upload_session.id,
            )
        else:
            logger.info(
                "Skip writing table %s. split_index: %s", op.table_name, split_index
            )
        if can_destroy:
            try:
                ctx.destroy_remote_object(recorder_name)
                logger.info("Delete remote object %s", recorder_name)
            except ActorNotExist:
                pass
        logger.info(
            "Committing to table %s with upload session %s", op.table_name, upload_session.id
        )
        upload_session.commit([0])
        logger.info(
            "Finish writing table %s. split_index: %s tunnel_session: %s",
            op.table_name,
            split_index,
            upload_session.id,
        )
        ctx[op.outputs[0].key] = pd.DataFrame()

    @classmethod
    def execute(cls, ctx, op):
        if op.cupid_handle is not None:
            cls._execute_in_cupid(ctx, op)
        else:
            cls._execute_arrow_tunnel(ctx, op)


class _TunnelCommitRecorder:
    _commit_status: List[bool]

    def __init__(self, n_chunk: int):
        self._n_chunk = n_chunk
        self._commit_status = {}

    def try_commit(self, index: tuple):
        if index in self._commit_status:
            return False, len(self._commit_status) == self._n_chunk
        else:
            self._commit_status[index] = True
            return True, len(self._commit_status) == self._n_chunk


class DataFrameWriteTableCommit(DataFrameOperand, DataFrameOperandMixin):
    _op_type_ = 123462

    dtypes = SeriesField("dtypes")

    odps_params = DictField("odps_params")
    table_name = StringField("table_name")
    overwrite = BoolField("overwrite", default=False)
    blocks = DictField("blocks", default=None)
    cupid_handle = StringField("cupid_handle", default=None)
    is_terminal = BoolField("is_terminal", default=None)

    def __init__(self, **kw):
        kw.update(_output_type_kw)
        super(DataFrameWriteTableCommit, self).__init__(**kw)

    @classmethod
    def execute(cls, ctx, op):
        import pandas as pd
        from ..cupid_service import CupidServiceClient

        if op.is_terminal:
            odps_params = op.odps_params.copy()
            project = os.environ.get("ODPS_PROJECT_NAME", None)
            if project:
                odps_params["project"] = project

            client = CupidServiceClient()
            client.commit_table_upload_session(
                odps_params, op.table_name, op.cupid_handle, op.blocks, op.overwrite
            )

        ctx[op.outputs[0].key] = pd.DataFrame()


def write_odps_table(
    df,
    table,
    partition=None,
    overwrite=False,
    unknown_as_string=None,
    odps_params=None,
    write_batch_size=None,
    tunnel_quota_name=None,
):
    op = DataFrameWriteTable(
        dtypes=df.dtypes,
        odps_params=odps_params,
        table_name=table.full_table_name,
        unknown_as_string=unknown_as_string,
        partition_spec=partition,
        overwrite=overwrite,
        write_batch_size=write_batch_size,
        tunnel_quota_name=tunnel_quota_name,
    )
    return op(df)
