# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import pyarrow as pa
from odps import ODPS
from odps.apis.storage_api import (
    StorageApiArrowClient,
    TableBatchScanResponse,
    TableBatchWriteResponse,
)
from odps.tunnel import TableDownloadSession, TableDownloadStatus, TableTunnel
from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz
from odps.utils import call_with_retry

try:
    import pyarrow.compute as pac
except ImportError:
    pac = None

from ...config import options
from ...env import ODPS_STORAGE_API_ENDPOINT
from ...utils import is_empty, sync_pyodps_options
from .schema import odps_schema_to_arrow_schema

# Accepted forms of a partition selection: a list of partition spec strings,
# a single partition spec string, or None when no partition is specified.
PartitionsType = Union[List[str], str, None]

# Default value of the ``row_batch_size`` argument of ``open_reader``.
_DEFAULT_ROW_BATCH_SIZE = 4096
# Upper bound on the number of download session ids kept in
# ``TunnelTableIO._down_session_ids``.
_DOWNLOAD_ID_CACHE_SIZE = 100


class ODPSTableIO(ABC):
    """Abstract interface for opening readers and writers on ODPS tables.

    Instantiating ``ODPSTableIO`` directly returns a concrete implementation:
    ``HaloTableIO`` when ``options.use_common_table`` is truthy or the
    environment variable named by ``ODPS_STORAGE_API_ENDPOINT`` is set,
    ``TunnelTableIO`` otherwise. Instantiating a subclass creates an instance
    of that subclass as usual.
    """

    def __new__(cls, odps: ODPS):
        # Subclasses are constructed normally; only the base class dispatches.
        if cls is not ODPSTableIO:
            return super().__new__(cls)

        use_storage_api = (
            options.use_common_table or ODPS_STORAGE_API_ENDPOINT in os.environ
        )
        return HaloTableIO(odps) if use_storage_api else TunnelTableIO(odps)

    def __init__(self, odps: ODPS):
        self._odps = odps

    @classmethod
    def _get_reader_schema(
        cls,
        table_schema: OdpsSchema,
        columns: Optional[List[str]] = None,
        partition_columns: Union[None, bool, List[str]] = None,
    ) -> OdpsSchema:
        """Build the schema a reader will produce.

        Args:
            table_schema: schema of the table being read.
            columns: data column names to include; when empty, all of
                ``table_schema.simple_columns`` are used.
            partition_columns: ``True`` to append every partition column,
                a list of names to append just those, or a falsy value to
                append none.

        Returns:
            An ``OdpsSchema`` holding the selected data columns followed by
            the selected partition columns.
        """
        if is_empty(columns):
            data_names = [col.name for col in table_schema.simple_columns]
        else:
            data_names = columns

        if partition_columns is True:
            part_names = [c.name for c in table_schema.partitions]
        else:
            part_names = partition_columns or []

        return OdpsSchema([table_schema[name] for name in data_names + part_names])

    @abstractmethod
    def open_reader(
        self,
        full_table_name: str,
        partitions: PartitionsType = None,
        columns: Optional[List[str]] = None,
        partition_columns: Union[None, bool, List[str]] = None,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        reverse_range: bool = False,
        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
    ):
        """Open a reader on the given table; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def open_writer(
        self,
        full_table_name: str,
        partition: Optional[str] = None,
        overwrite: bool = True,
    ):
        """Open a writer on the given table; implemented by subclasses."""
        raise NotImplementedError


class TunnelMultiPartitionReader:
    """Sequential Arrow reader over one or more partitions of a table.

    Partitions are visited in the given order and treated as one concatenated
    row sequence: ``start`` is an offset into that sequence and ``count`` caps
    the total number of rows requested across all partitions. A reader is
    opened on each partition lazily, once the previous one is exhausted.
    """

    def __init__(
        self,
        odps_entry: ODPS,
        table_name: str,
        partitions: PartitionsType,
        columns: Optional[List[str]] = None,
        partition_columns: Optional[List[str]] = None,
        start: Optional[int] = None,
        count: Optional[int] = None,
        partition_to_download_ids: Optional[Dict[str, str]] = None,
    ):
        """
        Args:
            odps_entry: ODPS entry object used to look up the table.
            table_name: name of the table to read.
            partitions: a partition spec string, a list of them, or ``None``.
                With ``None``, every partition of a partitioned table is read,
                or the table itself when it has no partition columns.
            columns: data columns to read; see
                ``ODPSTableIO._get_reader_schema`` for the default.
            partition_columns: partition columns to include in the result.
            start: offset of the first row to read; ``None`` means 0.
            count: maximal number of rows to read; ``None`` means no limit.
            partition_to_download_ids: optional mapping from partition spec
                to a download id passed on when opening that partition.
        """
        self._odps_entry = odps_entry
        self._table = odps_entry.get_table(table_name)
        self._columns = columns

        odps_schema = ODPSTableIO._get_reader_schema(
            self._table.table_schema, columns, partition_columns
        )
        self._schema = odps_schema_to_arrow_schema(odps_schema)

        self._start = start or 0
        self._count = count
        # rows still allowed to be returned; None when unbounded
        self._row_left = count

        self._cur_reader = None
        self._reader_iter = None
        self._cur_partition_id = -1
        # number of rows contained in partitions located before the current one
        self._reader_start_pos = 0

        if partitions is None:
            if not self._table.table_schema.partitions:
                self._partitions = [None]
            else:
                self._partitions = [str(pt) for pt in self._table.partitions]
        elif isinstance(partitions, str):
            self._partitions = [partitions]
        else:
            self._partitions = partitions

        self._partition_cols = partition_columns
        self._partition_to_download_ids = partition_to_download_ids or dict()

    @property
    def count(self) -> Optional[int]:
        """Requested row count, or ``None`` when more than one partition is read."""
        if len(self._partitions) > 1:
            return None
        return self._count

    def _open_next_reader(self):
        """Advance to the next partition that has rows at or after ``start``.

        Sets ``self._cur_reader`` to ``None`` when the row budget is used up
        or no partition is left.
        """
        if self._cur_reader is not None:
            # the reader being left behind now counts as "before" the next one
            self._reader_start_pos += self._cur_reader.count

        if (
            self._row_left is not None and self._row_left <= 0
        ) or 1 + self._cur_partition_id >= len(self._partitions):
            self._cur_reader = None
            return

        while 1 + self._cur_partition_id < len(self._partitions):
            self._cur_partition_id += 1

            part_str = self._partitions[self._cur_partition_id]
            req_columns = self._schema.names
            with sync_pyodps_options():
                self._cur_reader = self._table.open_reader(
                    part_str,
                    columns=req_columns,
                    arrow=True,
                    download_id=self._partition_to_download_ids.get(part_str),
                    append_partitions=True,
                )
            if self._cur_reader.count + self._reader_start_pos > self._start:
                # Partitions following the one that contains ``start`` must be
                # read from their first row, hence the clamp: without it the
                # offset turns negative once ``_reader_start_pos`` exceeds
                # ``_start``.
                start = max(self._start - self._reader_start_pos, 0)
                if self._row_left is None:
                    count = None
                else:
                    # Cap by the rows still owed rather than by the original
                    # request, so a bounded read spanning several partitions
                    # does not exceed ``count`` rows in total.
                    count = min(self._row_left, self._cur_reader.count - start)

                with sync_pyodps_options():
                    self._reader_iter = self._cur_reader.read(start, count)
                break
            # whole partition lies before ``start``: skip it
            self._reader_start_pos += self._cur_reader.count
        else:
            self._cur_reader = None

    def read(self):
        """Return the next record batch, or ``None`` when nothing is left."""
        with sync_pyodps_options():
            if self._cur_reader is None:
                self._open_next_reader()
                if self._cur_reader is None:
                    return None
            while self._cur_reader is not None:
                try:
                    batch = next(self._reader_iter)
                    if batch is not None:
                        if self._row_left is not None:
                            self._row_left -= batch.num_rows
                        return batch
                except StopIteration:
                    # current partition exhausted, move on to the next one
                    self._open_next_reader()
            return None

    def read_all(self) -> pa.Table:
        """Read all remaining batches into one table.

        Returns an empty table with the reader schema when no batch is read.
        """
        batches = []
        while True:
            batch = self.read()
            if batch is None:
                break
            batches.append(batch)
        if not batches:
            return self._schema.empty_table()
        return pa.Table.from_batches(batches)


class TunnelTableIO(ODPSTableIO):
    """Table IO implementation built on table tunnel readers and writers."""

    # Maps (full_table_name, partition) to the id of the most recent download
    # session created for it. Shared by all instances; the oldest entries are
    # dropped once _DOWNLOAD_ID_CACHE_SIZE is reached.
    _down_session_ids = OrderedDict()

    @classmethod
    def create_download_sessions(
        cls,
        odps_entry: ODPS,
        full_table_name: str,
        partitions: List[Optional[str]] = None,
    ) -> Dict[Optional[str], TableDownloadSession]:
        """Create (or reuse) one tunnel download session per partition.

        Args:
            odps_entry: ODPS entry object used to access the table.
            full_table_name: name of the table to download from.
            partitions: ``None``, a single partition spec string, or a list
                of partition specs.

        Returns:
            Mapping from each requested partition to its download session.
        """
        table = odps_entry.get_table(full_table_name)
        tunnel = TableTunnel(odps_entry, quota_name=options.tunnel_quota_name)
        parts = (
            [partitions]
            if partitions is None or isinstance(partitions, str)
            else partitions
        )
        part_to_session = dict()
        for part in parts:
            part_key = (full_table_name, part)
            down_session = None

            if part_key in cls._down_session_ids:
                # try to reuse the cached session; fall back to a new one
                # when it is no longer in Normal status
                down_id = cls._down_session_ids[part_key]
                down_session = tunnel.create_download_session(
                    table, async_mode=True, partition_spec=part, download_id=down_id
                )
                if down_session.status != TableDownloadStatus.Normal:
                    down_session = None

            if down_session is None:
                down_session = tunnel.create_download_session(
                    table, async_mode=True, partition_spec=part
                )

            while len(cls._down_session_ids) >= _DOWNLOAD_ID_CACHE_SIZE:
                cls._down_session_ids.popitem(False)
            cls._down_session_ids[part_key] = down_session.id
            part_to_session[part] = down_session
        return part_to_session

    @contextmanager
    def open_reader(
        self,
        full_table_name: str,
        partitions: PartitionsType = None,
        columns: Optional[List[str]] = None,
        partition_columns: Union[None, bool, List[str]] = None,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        reverse_range: bool = False,
        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
    ):
        """Yield a ``TunnelMultiPartitionReader`` over the requested rows.

        ``start`` and ``stop`` follow slice conventions: negative values count
        from the end of the data and ``None`` means "up to the boundary".
        When ``reverse_range`` is true they describe a descending slice
        (``start`` inclusive, ``stop`` exclusive); the reader still produces
        the covered rows in table order.

        Args:
            full_table_name: name of the table to read.
            partitions: partition(s) to read, see ``PartitionsType``.
            columns: data columns to read.
            partition_columns: ``True`` for all partition columns, or a list
                of partition column names to include.
            start: first row index of the range.
            stop: end row index of the range (exclusive).
            reverse_range: whether ``start``/``stop`` describe a descending
                range.
            row_batch_size: accepted for interface compatibility; not used by
                this implementation.
        """
        with sync_pyodps_options():
            table = self._odps.get_table(full_table_name)

        if partition_columns is True:
            partition_columns = [c.name for c in table.table_schema.partitions]

        total_records = None
        part_to_down_id = None
        # The total row count is only needed to resolve negative indices or
        # the implicit start of a reverse range.
        if (
            (start is not None and start < 0)
            or (stop is not None and stop < 0)
            or (reverse_range and start is None)
        ):
            with sync_pyodps_options():
                tunnel_sessions = self.create_download_sessions(
                    self._odps, full_table_name, partitions
                )
                part_to_down_id = {
                    pt: session.id for (pt, session) in tunnel_sessions.items()
                }
                total_records = sum(
                    session.count for session in tunnel_sessions.values()
                )

        count = None
        if start is not None or stop is not None:
            # Resolve caller-supplied negative indices BEFORE filling in
            # defaults: the reverse-range default stop of -1 is a sentinel for
            # "one before the first row" and must not be treated as an index
            # relative to the end of the data.
            if start is not None and start < 0:
                start = total_records + start
            if stop is not None and stop < 0:
                stop = total_records + stop

            if reverse_range:
                if start is None:
                    start = total_records - 1
                if stop is None:
                    stop = -1
                # rows stop + 1 .. start, expressed as a forward range
                count = start - stop
                start = stop + 1
            else:
                if start is None:
                    start = 0
                count = stop - start if stop is not None else None

        yield TunnelMultiPartitionReader(
            self._odps,
            full_table_name,
            partitions=partitions,
            columns=columns,
            partition_columns=partition_columns,
            start=start,
            count=count,
            partition_to_download_ids=part_to_down_id,
        )

    @contextmanager
    def open_writer(
        self,
        full_table_name: str,
        partition: Optional[str] = None,
        overwrite: bool = True,
    ):
        """Yield an Arrow writer on the table opened via ``table.open_writer``.

        Args:
            full_table_name: name of the table to write to.
            partition: partition to write into; when given, the writer is
                opened with ``create_partition=True``.
            overwrite: passed through to ``table.open_writer``.
        """
        table = self._odps.get_table(full_table_name)
        with sync_pyodps_options():
            with table.open_writer(
                partition=partition,
                arrow=True,
                create_partition=partition is not None,
                overwrite=overwrite,
            ) as writer:
                yield writer


class HaloTableArrowReader:
    """Arrow reader pulling record batches from a storage API read session.

    Data is requested one piece at a time, either split by split or, when a
    ``start`` is given, as consecutive row ranges of at most
    ``row_batch_size`` rows.
    """

    def __init__(
        self,
        client: StorageApiArrowClient,
        scan_info: TableBatchScanResponse,
        odps_schema: OdpsSchema,
        start: Optional[int] = None,
        count: Optional[int] = None,
        row_batch_size: Optional[int] = None,
    ):
        self._client = client
        self._scan_info = scan_info
        self._odps_schema = odps_schema
        self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)

        # requested range and the number of rows already requested from it
        self._start = start
        self._count = count
        self._cursor = 0
        self._row_batch_size = row_batch_size

        # index of the last request issued and the reader it produced
        self._cur_split_id = -1
        self._cur_reader = None

    @property
    def count(self) -> int:
        return self._count

    def _open_next_reader(self):
        """Issue the next read request, or clear the reader when finished."""
        from odps.apis.storage_api import ReadRowsRequest

        next_split = self._cur_split_id + 1
        if 0 <= self._scan_info.split_count <= next_split or (
            self._count is not None and self._cursor >= self._count
        ):
            # either every split was consumed (scan by split) or the whole
            # requested range was already asked for (scan by range)
            self._cur_reader = None
            return

        range_kw = {}
        if self._start is not None:
            offset = self._cursor
            range_kw = {
                "row_index": self._start + offset,
                "row_count": min(self._row_batch_size, self._count - offset),
            }
            self._cursor = min(self._count, offset + self._row_batch_size)

        request = ReadRowsRequest(
            session_id=self._scan_info.session_id,
            split_index=next_split,
            **range_kw,
        )
        self._cur_reader = call_with_retry(self._client.read_rows_arrow, request)
        self._cur_split_id += 1

    def _convert_timezone(self, batch: pa.RecordBatch) -> pa.RecordBatch:
        """Cast timestamp columns of ``batch`` for the configured timezone.

        Columns whose ODPS type is ``timestamp_ntz`` are cast to timestamps
        without timezone, other timestamp columns to timestamps carrying
        ``options.local_timezone``. Batches without timestamp columns are
        returned as they are.
        """
        timezone = options.local_timezone
        has_timestamp = any(
            isinstance(tp, pa.TimestampType) for tp in batch.schema.types
        )
        if not has_timestamp:
            return batch

        names = batch.schema.names
        converted = []
        for idx in range(batch.num_columns):
            col = batch.column(idx)
            if isinstance(col.type, pa.TimestampType):
                unit = col.type.unit
                if self._odps_schema[names[idx]].type == timestamp_ntz:
                    col = col.cast(pa.timestamp(unit))
                elif hasattr(pac, "local_timestamp"):
                    col = col.cast(pa.timestamp(unit, timezone))
                else:
                    # fall back to pandas when pyarrow.compute does not
                    # provide ``local_timestamp``
                    pd_col = col.to_pandas().dt.tz_convert(timezone)
                    col = pa.Array.from_pandas(pd_col).cast(
                        pa.timestamp(unit, timezone)
                    )
            converted.append(col)

        return pa.RecordBatch.from_arrays(converted, names=batch.schema.names)

    def read(self):
        """Return the next converted record batch, or ``None`` at the end."""
        if self._cur_reader is None:
            self._open_next_reader()
        while self._cur_reader is not None:
            batch = self._cur_reader.read()
            if batch is not None:
                return self._convert_timezone(batch)
            # current piece exhausted, request the following one
            self._open_next_reader()
        return None

    def read_all(self) -> pa.Table:
        """Read every remaining batch and combine them into one table."""
        batches = []
        batch = self.read()
        while batch is not None:
            batches.append(batch)
            batch = self.read()
        if batches:
            return pa.Table.from_batches(batches)
        return self._arrow_schema.empty_table()


class HaloTableArrowWriter:
    """Arrow writer sending record batches to a storage API write session."""

    def __init__(
        self,
        client: StorageApiArrowClient,
        write_info: TableBatchWriteResponse,
        odps_schema: OdpsSchema,
    ):
        self._client = client
        self._write_info = write_info
        self._odps_schema = odps_schema
        self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
        # created by open()
        self._writer = None

    def open(self):
        """Create the underlying writer for the write session."""
        from odps.apis.storage_api import WriteRowsRequest

        request = WriteRowsRequest(self._write_info.session_id)
        self._writer = call_with_retry(self._client.write_rows_arrow, request)

    @classmethod
    def _localize_timezone(cls, col, tz=None):
        """Attach a timezone to a timestamp column that has none.

        When ``tz`` is not given, ``options.local_timezone`` is used, or the
        zone reported by ``tzlocal`` if that option is ``None``. Columns that
        already carry a timezone are returned unchanged.
        """
        from odps.lib import tzlocal

        if tz is None:
            configured = options.local_timezone
            tz = str(tzlocal.get_localzone() if configured is None else configured)

        if col.type.tz is not None:
            return col
        if not hasattr(pac, "assume_timezone"):
            # fall back to pandas when pyarrow.compute lacks assume_timezone
            localized = col.to_pandas().dt.tz_localize(tz)
            return pa.Array.from_pandas(localized)
        return pac.assume_timezone(col, tz)

    def _convert_schema(self, batch: pa.RecordBatch):
        """Make ``batch`` conform to the Arrow schema of the target table.

        Timestamp columns are localized (``timestamp_ntz`` columns as UTC) and
        every column whose type differs from the column at the same position
        of the target schema is cast to that type.
        """
        target_types = self._arrow_schema.types
        if batch.schema == self._arrow_schema and not any(
            isinstance(tp, pa.TimestampType) for tp in target_types
        ):
            # nothing to convert
            return batch

        names = batch.schema.names
        converted = []
        for idx in range(batch.num_columns):
            col = batch.column(idx)

            if isinstance(col.type, pa.TimestampType):
                is_ntz = self._odps_schema[names[idx]].type == timestamp_ntz
                col = (
                    self._localize_timezone(col, "UTC")
                    if is_ntz
                    else self._localize_timezone(col)
                )

            expected_type = self._arrow_schema.types[idx]
            if col.type != expected_type:
                col = col.cast(expected_type)
            converted.append(col)
        return pa.RecordBatch.from_arrays(converted, names=batch.schema.names)

    def write(self, batch):
        """Write a record batch, or every batch of a ``pa.Table``."""
        chunks = batch.to_batches() if isinstance(batch, pa.Table) else [batch]
        for chunk in chunks:
            self._writer.write(self._convert_schema(chunk))

    def close(self):
        """Finish writing and return the commit message.

        Raises:
            IOError: when the writer reports that finishing did not succeed.
        """
        commit_msg, is_success = self._writer.finish()
        if is_success:
            return commit_msg
        raise IOError(commit_msg)


class HaloTableIO(ODPSTableIO):
    """Table IO implementation built on the storage API Arrow client."""

    # Endpoint read once, at class creation, from the environment variable
    # named by ODPS_STORAGE_API_ENDPOINT; None when the variable is not set.
    _storage_api_endpoint = os.getenv(ODPS_STORAGE_API_ENDPOINT)

    @staticmethod
    def _convert_partitions(partitions: PartitionsType) -> Optional[List[str]]:
        """Normalize partitions into a list of ``k=v/k=v`` strings.

        ``None`` gives an empty list; a single string or ``PartitionSpec`` is
        treated as a one-element list.
        """
        if partitions is None:
            return []
        elif isinstance(partitions, (str, PartitionSpec)):
            partitions = [partitions]
        return [
            "/".join(f"{k}={v}" for k, v in PartitionSpec(pt).items())
            for pt in partitions
        ]

    @contextmanager
    def open_reader(
        self,
        full_table_name: str,
        partitions: PartitionsType = None,
        columns: Optional[List[str]] = None,
        partition_columns: Union[None, bool, List[str]] = None,
        start: Optional[int] = None,
        stop: Optional[int] = None,
        reverse_range: bool = False,
        row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
    ):
        """Create a read session and yield a ``HaloTableArrowReader`` on it.

        ``start`` and ``stop`` follow slice conventions: negative values count
        from the end of the data and ``None`` means "up to the boundary".
        When ``reverse_range`` is true they describe a descending slice
        (``start`` inclusive, ``stop`` exclusive); the reader still produces
        the covered rows in table order.

        Args:
            full_table_name: name of the table to read.
            partitions: partition(s) to read, see ``PartitionsType``.
            columns: data columns to read; defaults to all simple columns.
            partition_columns: ``True`` for all partition columns, or a list
                of partition column names to include.
            start: first row index of the range.
            stop: end row index of the range (exclusive).
            reverse_range: whether ``start``/``stop`` describe a descending
                range.
            row_batch_size: maximal number of rows per request when reading
                by row range.
        """
        from odps.apis.storage_api import (
            SessionRequest,
            SessionStatus,
            SplitOptions,
            TableBatchScanRequest,
        )

        table = self._odps.get_table(full_table_name)
        client = StorageApiArrowClient(
            self._odps,
            table,
            rest_endpoint=self._storage_api_endpoint,
            quota_name=options.tunnel_quota_name,
        )

        # row ranges require splitting by row offset instead of by size
        split_option = SplitOptions.SplitMode.SIZE
        if start is not None or stop is not None:
            split_option = SplitOptions.SplitMode.ROW_OFFSET

        scan_kw = {
            "required_partitions": self._convert_partitions(partitions),
            "split_options": SplitOptions.get_default_options(split_option),
        }
        columns = columns or [c.name for c in table.table_schema.simple_columns]
        scan_kw["required_data_columns"] = columns
        if partition_columns is True:
            scan_kw["required_partition_columns"] = [
                c.name for c in table.table_schema.partitions
            ]
        else:
            scan_kw["required_partition_columns"] = partition_columns

        # todo add more options for partition column handling
        req = TableBatchScanRequest(**scan_kw)
        resp = call_with_retry(client.create_read_session, req)

        session_id = resp.session_id
        status = resp.session_status
        while status == SessionStatus.INIT:
            # wait before asking again, so no time is spent sleeping once the
            # final status is known
            time.sleep(1.0)
            resp = call_with_retry(client.get_read_session, SessionRequest(session_id))
            status = resp.session_status

        assert status == SessionStatus.NORMAL

        count = None
        if start is not None or stop is not None:
            total_records = resp.record_count
            # Resolve caller-supplied negative indices BEFORE filling in
            # defaults: the reverse-range default stop of -1 is a sentinel for
            # "one before the first row" and must not be treated as an index
            # relative to the end of the data.
            if start is not None and start < 0:
                start = total_records + start
            if stop is not None and stop < 0:
                stop = total_records + stop

            if reverse_range:
                if start is None:
                    start = total_records - 1
                if stop is None:
                    stop = -1
                # rows stop + 1 .. start, expressed as a forward range
                count = start - stop
                start = stop + 1
            else:
                if start is None:
                    start = 0
                if stop is None:
                    stop = total_records
                count = stop - start

        reader_schema = self._get_reader_schema(
            table.table_schema, columns, partition_columns
        )
        yield HaloTableArrowReader(
            client,
            resp,
            odps_schema=reader_schema,
            start=start,
            count=count,
            row_batch_size=row_batch_size,
        )

    @contextmanager
    def open_writer(
        self,
        full_table_name: str,
        partition: Optional[str] = None,
        overwrite: bool = True,
    ):
        """Create a write session, yield a writer and commit afterwards.

        The session is committed only when the ``with`` body finishes without
        raising.

        Args:
            full_table_name: name of the table to write to.
            partition: partition to write into, if any.
            overwrite: passed through to the write session request.

        Raises:
            IOError: from ``HaloTableArrowWriter.close`` when finishing the
                writer does not succeed.
        """
        from odps.apis.storage_api import (
            SessionRequest,
            SessionStatus,
            TableBatchWriteRequest,
        )

        table = self._odps.get_table(full_table_name)
        client = StorageApiArrowClient(
            self._odps,
            table,
            rest_endpoint=self._storage_api_endpoint,
            quota_name=options.tunnel_quota_name,
        )

        part_strs = self._convert_partitions(partition)
        part_str = part_strs[0] if part_strs else None
        req = TableBatchWriteRequest(partition_spec=part_str, overwrite=overwrite)
        resp = call_with_retry(client.create_write_session, req)

        session_id = resp.session_id
        writer = HaloTableArrowWriter(client, resp, table.table_schema)
        writer.open()

        yield writer

        commit_msg = writer.close()
        resp = call_with_retry(
            client.commit_write_session,
            SessionRequest(session_id=session_id),
            [commit_msg],
        )
        while resp.session_status == SessionStatus.COMMITTING:
            # pause between polls instead of querying the service in a tight
            # loop; same interval as used when waiting for a read session
            time.sleep(1.0)
            resp = call_with_retry(
                client.get_write_session, SessionRequest(session_id=session_id)
            )
        assert resp.session_status == SessionStatus.COMMITTED
