#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import functools
import json
import struct
import sys
import warnings
from collections import OrderedDict
from decimal import Decimal
from io import BytesIO, IOBase, StringIO

try:
    import numpy as np
except ImportError:
    np = None
try:
    import pandas as pd
except (ImportError, ValueError):
    pd = None
try:
    import pyarrow as pa
except (AttributeError, ImportError):
    pa = None
try:
    import pyarrow.compute as pac
except (AttributeError, ImportError):
    pac = None

from ... import compat, types, utils
from ...config import options
from ...errors import ChecksumError, DatetimeOverflowError
from ...lib.monotonic import monotonic
from ...models import Record
from ...readers import AbstractRecordReader
from ...types import PartitionSpec
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..pb import wire_format
from ..pb.decoder import Decoder
from ..pb.errors import DecodeError
from ..wireconstants import ProtoWireConstants
from .types import odps_schema_to_arrow_schema

try:
    if not options.force_py:
        from .reader_c import BaseTunnelRecordReader, convert_legacy_decimal_bytes
    else:
        BaseTunnelRecordReader = convert_legacy_decimal_bytes = None
except ImportError as e:
    if options.force_c:
        raise e
    BaseTunnelRecordReader = convert_legacy_decimal_bytes = None

MICRO_SEC_PER_SEC = 1000000

if BaseTunnelRecordReader is None:

    class BaseTunnelRecordReader(AbstractRecordReader):
        def __init__(
            self,
            schema,
            stream_creator,
            columns=None,
            partition_spec=None,
            append_partitions=False,
        ):
            self._schema = schema
            if columns is None:
                self._columns = (
                    self._schema.columns
                    if append_partitions
                    else self._schema.simple_columns
                )
            else:
                self._columns = [self._schema[c] for c in columns]

            self._enable_client_metrics = options.tunnel.enable_client_metrics
            self._server_metrics_string = None
            self._local_wall_time_ms = 0
            self._acc_network_time_ms = 0
            self._injected_error = None

            self._curr_cursor = 0
            self._stream_creator = stream_creator
            self._reader = None
            self._reopen_reader()

            if self._enable_client_metrics:
                ts = monotonic()

            self._read_limit = options.table_read_limit
            self._to_datetime = utils.MillisecondsConverter().from_milliseconds
            self._to_datetime_utc = utils.MillisecondsConverter(
                local_tz=False
            ).from_milliseconds
            self._to_date = utils.to_date
            self._partition_spec = (
                PartitionSpec(partition_spec) if partition_spec else None
            )
            self._append_partitions = append_partitions

            map_as_ordered_dict = options.map_as_ordered_dict
            if map_as_ordered_dict is None:
                map_as_ordered_dict = sys.version_info[:2] <= (3, 6)
            self._map_dict_hook = OrderedDict if map_as_ordered_dict else dict

            struct_as_ordered_dict = options.struct_as_ordered_dict
            if struct_as_ordered_dict is None:
                struct_as_ordered_dict = sys.version_info[:2] <= (3, 6)
            self._struct_dict_hook = OrderedDict if struct_as_ordered_dict else dict

            if self._enable_client_metrics:
                self._local_wall_time_ms += compat.long_type(
                    MICRO_SEC_PER_SEC * (monotonic() - ts)
                )

        def _mode(self):
            return "py"

        @property
        def count(self):
            return self._curr_cursor

        @property
        def _network_wall_time_ms(self):
            return self._reader.network_wall_time_ms + self._acc_network_time_ms

        def _inject_error(self, cursor, exc):
            self._injected_error = (cursor, exc)

        def _reopen_reader(self):
            if self._enable_client_metrics:
                ts = monotonic()

            stream = self._stream_creator(self._curr_cursor)
            if self._enable_client_metrics:
                self._acc_network_time_ms += compat.long_type(
                    MICRO_SEC_PER_SEC * (monotonic() - ts)
                )
                if self._reader is not None:
                    self._acc_network_time_ms += self._reader.network_wall_time_ms

            self._last_n_bytes = len(self._reader) if self._curr_cursor != 0 else 0
            self._reader = Decoder(
                stream, record_network_time=self._enable_client_metrics
            )
            self._crc = Checksum()
            self._crccrc = Checksum()
            self._attempt_row_count = 0

            if self._enable_client_metrics:
                self._local_wall_time_ms += compat.long_type(
                    MICRO_SEC_PER_SEC * (monotonic() - ts)
                )

        def _read_field(self, data_type):
            if data_type == types.float_:
                val = self._reader.read_float()
                self._crc.update_float(val)
            elif data_type == types.double:
                val = self._reader.read_double()
                self._crc.update_double(val)
            elif data_type == types.boolean:
                val = self._reader.read_bool()
                self._crc.update_bool(val)
            elif data_type in types.integer_types:
                val = self._reader.read_sint64()
                self._crc.update_long(val)
            elif data_type == types.string:
                val = self._reader.read_string()
                self._crc.update(val)
            elif data_type == types.binary:
                val = self._reader.read_string()
                self._crc.update(val)
            elif data_type == types.datetime:
                val = self._reader.read_sint64()
                self._crc.update_long(val)
                try:
                    val = self._to_datetime(val)
                except DatetimeOverflowError:
                    if not options.tunnel.overflow_date_as_none:
                        raise
                    val = None
            elif data_type == types.date:
                val = self._reader.read_sint64()
                self._crc.update_long(val)
                val = self._to_date(val)
            elif data_type == types.timestamp or data_type == types.timestamp_ntz:
                to_datetime = (
                    self._to_datetime_utc
                    if data_type == types.timestamp_ntz
                    else self._to_datetime
                )
                l_val = self._reader.read_sint64()
                self._crc.update_long(l_val)
                nano_secs = self._reader.read_sint32()
                self._crc.update_int(nano_secs)
                if pd is None:
                    raise ImportError(
                        "To use TIMESTAMP in pyodps, you need to install pandas."
                    )
                try:
                    val = pd.Timestamp(to_datetime(l_val * 1000)) + pd.Timedelta(
                        nanoseconds=nano_secs
                    )
                except DatetimeOverflowError:
                    if not options.tunnel.overflow_date_as_none:
                        raise
                    val = None
            elif data_type == types.interval_day_time:
                l_val = self._reader.read_sint64()
                self._crc.update_long(l_val)
                nano_secs = self._reader.read_sint32()
                self._crc.update_int(nano_secs)
                if pd is None:
                    raise ImportError(
                        "To use INTERVAL_DAY_TIME in pyodps, you need to install pandas."
                    )
                val = pd.Timedelta(seconds=l_val, nanoseconds=nano_secs)
            elif data_type == types.interval_year_month:
                l_val = self._reader.read_sint64()
                self._crc.update_long(l_val)
                return compat.Monthdelta(l_val)
            elif data_type == types.json:
                sval = self._reader.read_string()
                val = json.loads(sval)
                self._crc.update(sval)
            elif isinstance(data_type, (types.Char, types.Varchar)):
                val = self._reader.read_string()
                self._crc.update(val)
            elif isinstance(data_type, types.Decimal):
                val = self._reader.read_string()
                self._crc.update(val)
            elif isinstance(data_type, types.Array):
                val = self._read_array(data_type.value_type)
            elif isinstance(data_type, types.Map):
                keys = self._read_array(data_type.key_type)
                values = self._read_array(data_type.value_type)
                val = self._map_dict_hook(zip(keys, values))
            elif isinstance(data_type, types.Struct):
                val = self._read_struct(data_type)
            else:
                raise IOError("Unsupported type %s" % data_type)
            return val

        def _read_array(self, value_type):
            res = []

            size = self._reader.read_uint32()
            for _ in range(size):
                if self._reader.read_bool():
                    res.append(None)
                else:
                    res.append(self._read_field(value_type))

            return res

        def _read_struct(self, value_type):
            res_list = [None] * len(value_type.field_types)
            for idx, field_type in enumerate(value_type.field_types.values()):
                if not self._reader.read_bool():
                    res_list[idx] = self._read_field(field_type)
            if options.struct_as_dict:
                return self._struct_dict_hook(
                    zip(value_type.field_types.keys(), res_list)
                )
            else:
                return value_type.namedtuple_type(*res_list)

        def _read_single_record(self):
            if (
                self._injected_error is not None
                and self._injected_error[0] == self._curr_cursor
            ):
                self._injected_error = None
                raise self._injected_error[1]

            if self._read_limit is not None and self.count >= self._read_limit:
                warnings.warn(
                    "Number of lines read via tunnel already reaches the limitation.",
                    RuntimeWarning,
                )
                return None

            record = Record(self._columns, max_field_size=(1 << 63) - 1)

            while True:
                index, _ = self._reader.read_field_number_and_wire_type()

                if index == 0:
                    continue
                if index == ProtoWireConstants.TUNNEL_END_RECORD:
                    checksum = utils.long_to_int(self._crc.getvalue())
                    if int(self._reader.read_uint32()) != utils.int_to_uint(checksum):
                        raise ChecksumError("Checksum invalid")
                    self._crc.reset()
                    self._crccrc.update_int(checksum)
                    break

                if index == ProtoWireConstants.TUNNEL_META_COUNT:
                    if self._attempt_row_count != self._reader.read_sint64():
                        raise IOError("count does not match")

                    (
                        idx_of_checksum,
                        wire_type,
                    ) = self._reader.read_field_number_and_wire_type()
                    if ProtoWireConstants.TUNNEL_META_CHECKSUM != idx_of_checksum:
                        if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED:
                            raise IOError("Invalid stream data.")
                        self._crc.update_int(idx_of_checksum)

                        self._server_metrics_string = self._reader.read_string()
                        self._crc.update(self._server_metrics_string)

                        idx_of_checksum = (
                            self._reader.read_field_number_and_wire_type()[0]
                        )
                        if idx_of_checksum != ProtoWireConstants.TUNNEL_END_METRICS:
                            raise IOError("Invalid stream data.")
                        checksum = utils.long_to_int(self._crc.getvalue())
                        self._crc.reset()
                        if utils.int_to_uint(checksum) != int(
                            self._reader.read_uint32()
                        ):
                            raise ChecksumError("Checksum invalid.")

                        idx_of_checksum = (
                            self._reader.read_field_number_and_wire_type()[0]
                        )
                    if ProtoWireConstants.TUNNEL_META_CHECKSUM != idx_of_checksum:
                        raise IOError("Invalid stream data.")
                    if int(self._crccrc.getvalue()) != self._reader.read_uint32():
                        raise ChecksumError("Checksum invalid.")

                    return

                if index > len(self._columns):
                    raise IOError(
                        "Invalid protobuf tag. Perhaps the datastream "
                        "from server is crushed."
                    )

                self._crc.update_int(index)

                i = index - 1
                record[i] = self._read_field(self._columns[i].type)

            if self._append_partitions and self._partition_spec is not None:
                for k, v in self._partition_spec.items():
                    try:
                        record[k] = v
                    except KeyError:
                        # skip non-existing fields
                        pass

            self._curr_cursor += 1
            self._attempt_row_count += 1
            return record

        def read(self):
            if self._enable_client_metrics:
                ts = monotonic()
            result = utils.call_with_retry(
                self._read_single_record, reset_func=self._reopen_reader
            )
            if self._enable_client_metrics:
                self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
            return result

        def reads(self):
            return self.__iter__()

        @property
        def n_bytes(self):
            return self._last_n_bytes + len(self._reader)

        def get_total_bytes(self):
            return self.n_bytes


class TunnelRecordReader(BaseTunnelRecordReader, AbstractRecordReader):
    """
    Reader object to read data from ODPS in records. Should be created
    with :meth:`TableDownloadSession.open_record_reader`.

    :Example:

    .. code-block:: python

        from odps.tunnel import TableTunnel

            tunnel = TableTunnel(o)
            download_session = tunnel.create_download_session('my_table', partition_spec='pt=test')

            # create a TunnelRecordReader
            with download_session.open_record_reader(0, download_session.count) as reader:
                for record in reader:
                    print(record.values)
    """

    def __next__(self):
        record = self.read()
        if record is None:
            raise StopIteration
        return record

    next = __next__

    @property
    def schema(self):
        return self._schema

    @property
    def metrics(self):
        if self._server_metrics_string is None:
            return None
        return TunnelMetrics.from_server_json(
            type(self).__name__,
            self._server_metrics_string,
            self._local_wall_time_ms,
            self._network_wall_time_ms,
        )

    def close(self):
        if hasattr(self._schema, "close"):
            self._schema.close()

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()


try:
    TunnelRecordReader.read.__doc__ = """
    Read next record.

    :return: A record object
    :rtype: :class:`~odps.models.Record`
    """
except:
    pass


class ArrowStreamReader(IOBase):
    def __init__(self, raw_reader, arrow_schema):
        self._reader = raw_reader

        self._crc = Checksum()
        self._crccrc = Checksum()
        self._pos = 0
        self._chunk_size = None

        self._buffers = collections.deque()
        self._buffers.append(BytesIO(arrow_schema.serialize().to_pybytes()))

    def __len__(self):
        return len(self._reader)

    def readable(self):
        return True

    @staticmethod
    def _read_unint32(b):
        return struct.unpack("!I", b)

    def _read_chunk_size(self):
        try:
            i = self._read_unint32(self._reader.read(4))
            self._pos += 4
            return i[0]  # unpack() result is a 1-element tuple.
        except struct.error as e:
            raise DecodeError(e)

    def _read_chunk(self):
        read_size = self._chunk_size + 4
        b = self._reader.read(read_size)
        if 0 < len(b) < 4:
            raise ChecksumError("Checksum invalid")
        self._pos += len(b)
        self._crc.update(b[:-4])
        self._crccrc.update(b[:-4])
        return b

    def _fill_next_buffer(self):
        if self._chunk_size is None:
            self._chunk_size = self._read_chunk_size()

        b = self._read_chunk()
        data = b[:-4]
        crc_data = b[-4:]
        if len(b) == 0:
            return

        if len(b) < self._chunk_size + 4:
            # is last chunk
            read_checksum = self._read_unint32(crc_data)[0]
            checksum = int(self._crccrc.getvalue())
            if checksum != read_checksum:
                raise ChecksumError("Checksum invalid")
            self._pos += len(data) + 4
            self._buffers.append(BytesIO(data))
            self._crccrc.reset()
        else:
            checksum = int(self._crc.getvalue())
            read_checksum = self._read_unint32(crc_data)[0]
            if checksum != read_checksum:
                raise ChecksumError("Checksum invalid")
            self._crc.reset()
            self._buffers.append(BytesIO(data))

    def read(self, nbytes=None):
        tot_size = 0
        bufs = []
        while nbytes is None or tot_size < nbytes:
            if not self._buffers:
                self._fill_next_buffer()
                if not self._buffers:
                    break

            to_read = nbytes - tot_size if nbytes is not None else None
            buf = self._buffers[0].read(to_read)
            if not buf:
                self._buffers.popleft()
            else:
                bufs.append(buf)
                tot_size += len(buf)
        if len(bufs) == 1:
            return bufs[0]
        return b"".join(bufs)

    def close(self):
        if hasattr(self._reader, "close"):
            self._reader.close()


class TunnelArrowReader(object):
    """
    Reader object to read data from ODPS in Arrow format. Should be created
    with :meth:`TableDownloadSession.open_arrow_reader`.

    :Example:

    .. code-block:: python

        from odps.tunnel import TableTunnel

            tunnel = TableTunnel(o)
            download_session = tunnel.create_download_session('my_table', partition_spec='pt=test')

            # create a TunnelArrowReader
            with download_session.open_arrow_reader(0, download_session.count) as reader:
                for batch in reader:
                    print(batch.to_pandas())
    """

    def __init__(
        self,
        schema,
        stream_creator,
        columns=None,
        partition_spec=None,
        append_partitions=False,
        use_ipc_stream=False,
    ):
        if pa is None:
            raise ValueError("To use arrow reader you need to install pyarrow")

        self._raw_schema = schema

        raw_arrow_schema = odps_schema_to_arrow_schema(schema)
        if columns is None:
            self._schema = schema
            self._arrow_schema = self._raw_arrow_schema = raw_arrow_schema
        else:
            self._schema = types.OdpsSchema([schema[c] for c in columns])
            self._raw_arrow_schema = pa.schema(
                [s for s in raw_arrow_schema if s.name in columns]
            )
            self._arrow_schema = odps_schema_to_arrow_schema(self._schema)
        self._columns = columns

        self._append_partitions = append_partitions
        self._partition_spec = partition_spec

        self._pos = 0
        self._stream_creator = stream_creator
        self._use_ipc_stream = use_ipc_stream
        self._reopen_reader()

        self._to_datetime = utils.MillisecondsConverter().from_milliseconds
        self._read_limit = options.table_read_limit

        self.closed = False

        self._pd_column_converters = dict()
        # True if need to convert numeric columns with None as float types
        self._coerce_numpy_columns = set()
        for col in schema.simple_columns:
            arrow_type = self._arrow_schema.field(
                self._arrow_schema.get_field_index(col.name)
            ).type
            if isinstance(col.type, (types.Map, types.Array, types.Struct)):
                self._pd_column_converters[col.name] = ArrowRecordFieldConverter(
                    col.type, convert_ts=False
                )
            elif options.tunnel.pd_cast_mode == "numpy" and (
                pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type)
            ):
                self._coerce_numpy_columns.add(col.name)

        self._injected_error = None

    def _reopen_reader(self):
        self._last_n_bytes = len(self._reader) if self._pos != 0 else 0
        input_stream = self._stream_creator(self._pos)
        self._arrow_stream = None
        if self._use_ipc_stream:
            self._reader = input_stream
        else:
            self._reader = ArrowStreamReader(input_stream, self._raw_arrow_schema)

    def _inject_error(self, cursor, exc):
        self._injected_error = (cursor, exc)

    @property
    def schema(self):
        return self._schema

    def _read_next_raw_batch(self):
        if self._injected_error is not None and self._injected_error[0] <= self._pos:
            self._injected_error = None
            raise self._injected_error[1]

        if self._arrow_stream is None:
            self._arrow_stream = pa.ipc.open_stream(self._reader)

        if self._read_limit is not None and self._pos >= self._read_limit:
            warnings.warn(
                "Number of lines read via tunnel already reaches the limitation.",
                RuntimeWarning,
            )
            return None

        try:
            batch = self._arrow_stream.read_next_batch()
            self._pos += batch.num_rows
        except pa.ArrowTypeError as ex:
            if str(ex) != "Did not pass numpy.dtype object":
                raise
            else:
                raise pa.ArrowTypeError(
                    "Error caused by version mismatch. Try install numpy<1.20 or "
                    "upgrade your pyarrow version. Original message: " + str(ex)
                )
        except StopIteration:
            return None
        return batch

    def _convert_timezone(self, batch):
        from ...lib import tzlocal

        if not any(isinstance(tp, pa.TimestampType) for tp in batch.schema.types):
            return batch

        timezone = raw_timezone = options.local_timezone
        if timezone is True or timezone is None:
            timezone = str(tzlocal.get_localzone())

        cols = []
        for idx in range(batch.num_columns):
            col = batch.column(idx)
            name = batch.schema.names[idx]
            if not isinstance(col.type, pa.TimestampType):
                cols.append(col)
                continue

            if timezone is False or self._schema[name].type == types.timestamp_ntz:
                col = col.cast(pa.timestamp(col.type.unit))
            else:
                col = col.cast(pa.timestamp(col.type.unit, timezone))
                if raw_timezone is None or raw_timezone is True:
                    if hasattr(pac, "local_timestamp"):
                        col = pac.local_timestamp(col)
                    else:
                        col = pa.Array.from_pandas(
                            col.to_pandas().dt.tz_localize(None)
                        ).cast(pa.timestamp(col.type.unit))
            cols.append(col)

        return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)

    def _append_partition_cols(self, batch):
        col_set = set(self._columns or [c.name for c in self._schema.columns])
        pt_obj = (
            types.PartitionSpec(self._partition_spec) if self._partition_spec else None
        )

        sel_col_set = set(self._columns or [])
        if pt_obj and any(c in sel_col_set for c in pt_obj.keys()):
            # append partitions selected in columns argument
            self._append_partitions = True
        if not pt_obj or not self._append_partitions:
            return batch

        batch_cols = list(batch.columns)
        batch_col_names = list(batch.schema.names)
        for key, val in pt_obj.items():
            if key not in col_set:
                continue
            val = types.validate_value(val, self._schema[key].type)
            batch_cols.append(pa.array(np.repeat([val], batch.num_rows)))
            batch_col_names.append(key)
        return pa.RecordBatch.from_arrays(batch_cols, names=batch_col_names)

    def read_next_batch(self):
        """
        Read next Arrow RecordBatch from tunnel.

        :return: Arrow RecordBatch
        """
        if self._reader is None:
            return None

        batch = utils.call_with_retry(
            self._read_next_raw_batch, reset_func=self._reopen_reader
        )
        if batch is None:
            return None

        batch = self._append_partition_cols(batch)
        if self._columns and self._columns != batch.schema.names:
            col_to_array = dict()
            col_name_set = set(self._columns)
            for name, arr in zip(batch.schema.names, batch.columns):
                if name not in col_name_set:
                    continue
                col_to_array[name] = arr
            arrays = [col_to_array[name] for name in self._columns]
            batch = pa.RecordBatch.from_arrays(arrays, names=self._columns)
        batch = self._convert_timezone(batch)
        return batch

    def read(self):
        """
        Read all data from tunnel and forms an Arrow Table.

        :return: Arrow Table
        """
        batches = []
        while True:
            batch = self.read_next_batch()
            if batch is None:
                break
            batches.append(batch)

        if not batches:
            return self._arrow_schema.empty_table()
        return pa.Table.from_batches(batches)

    def __iter__(self):
        return self

    def __next__(self):
        """
        Read next Arrow RecordBatch from tunnel.
        """
        batch = self.read_next_batch()
        if batch is None:
            raise StopIteration
        return batch

    @property
    def count(self):
        return self._pos

    @property
    def n_bytes(self):
        return self._last_n_bytes + len(self._reader)

    def get_total_bytes(self):
        return self.n_bytes

    def close(self):
        if hasattr(self._reader, "close"):
            self._reader.close()

    def _convert_batch_to_pandas(self, batch):
        series_list = []
        type_mapper = pd.ArrowDtype if options.tunnel.pd_cast_mode == "arrow" else None
        if not self._pd_column_converters and not self._coerce_numpy_columns:
            return batch.to_pandas(types_mapper=type_mapper)
        for col_name, arrow_column in zip(batch.schema.names, batch.columns):
            if col_name not in self._pd_column_converters:
                series = arrow_column.to_pandas(types_mapper=type_mapper)
                if col_name in self._coerce_numpy_columns and series.dtype == np.dtype(
                    "O"
                ):
                    if pa.types.is_floating(arrow_column.type):
                        col_dtype = arrow_column.type.to_pandas_dtype()
                    else:
                        col_dtype = np.dtype(float)
                    series = series.astype(col_dtype)
                series_list.append(series)
            else:
                try:
                    series = arrow_column.to_pandas(types_mapper=type_mapper)
                except pa.ArrowNotImplementedError:
                    dtype = type_mapper(arrow_column.type) if type_mapper else None
                    series = pd.Series(
                        arrow_column.to_pylist(), name=col_name, dtype=dtype
                    )
                series_list.append(series.map(self._pd_column_converters[col_name]))
        return pd.concat(series_list, axis=1)

    def to_pandas(self):
        """
        Read all data from tunnel and convert to a Pandas DataFrame.
        """
        import pandas as pd

        batches = []
        while True:
            batch = self.read_next_batch()
            if batch is None:
                break
            batches.append(self._convert_batch_to_pandas(batch))

        if not batches:
            return self._arrow_schema.empty_table().to_pandas()
        return pd.concat(batches, axis=0, ignore_index=True)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()


_reflective = lambda x: x


if convert_legacy_decimal_bytes is None:

    def convert_legacy_decimal_bytes(value):
        """
        Legacy decimal memory layout:
            int8_t  mNull;
            int8_t  mSign;
            int8_t  mIntg;
            int8_t  mFrac; only 0, 1, 2
            int32_t mData[6];
            int8_t mPadding[4]; //For Memory Align
        """
        if value is None:
            return None

        is_null, sign, intg, frac = struct.unpack("<4b", value[:4])
        if is_null:  # pragma: no cover
            return None
        if intg + frac == 0:
            return Decimal("0")

        sio = BytesIO() if compat.PY27 else StringIO()
        if sign > 0:
            sio.write("-")
        intg_nums = struct.unpack("<%dI" % intg, value[12 : 12 + intg * 4])
        intg_val = "".join("%09d" % d for d in reversed(intg_nums)).lstrip("0")
        sio.write(intg_val or "0")
        if frac > 0:
            sio.write(".")
            frac_nums = struct.unpack("<%dI" % frac, value[12 - frac * 4 : 12])
            sio.write("".join("%09d" % d for d in reversed(frac_nums)))
        return Decimal(sio.getvalue())


class ArrowRecordFieldConverter(object):
    _sensitive_types = (
        types.Datetime,
        types.Timestamp,
        types.TimestampNTZ,
        types.Array,
        types.Map,
        types.Struct,
    )

    def __init__(self, odps_type, arrow_type=None, convert_ts=True):
        self._mills_converter = utils.MillisecondsConverter()
        self._mills_converter_utc = utils.MillisecondsConverter(local_tz=False)
        self._convert_ts = convert_ts

        self._converter = self._build_converter(odps_type, arrow_type)

    def _convert_datetime(self, value):
        if value is None:
            return None
        mills = self._mills_converter.to_milliseconds(value)
        return self._mills_converter.from_milliseconds(mills)

    def _convert_ts_timestamp(self, value, ntz=False):
        if value is None:
            return None

        if not ntz:
            converter = self._mills_converter
        else:  # TimestampNtz
            converter = self._mills_converter_utc

        microsec = value.microsecond
        nanosec = value.nanosecond
        secs = int(converter.to_milliseconds(value.to_pydatetime()) / 1000)
        value = pd.Timestamp(converter.from_milliseconds(secs * 1000))
        return value.replace(microsecond=microsec, nanosecond=nanosec)

    def _convert_struct_timestamp(self, value, ntz=False):
        if value is None:
            return None

        if not ntz:
            converter = self._mills_converter
        else:  # TimestampNtz
            converter = self._mills_converter_utc

        ts = pd.Timestamp(converter.from_milliseconds(value["sec"] * 1000))
        nanos = value["nano"]
        return ts.replace(microsecond=nanos // 1000, nanosecond=nanos % 1000)

    @staticmethod
    def _convert_struct_timedelta(value):
        if value is None:
            return None

        nanos = value["nano"]
        return pd.Timedelta(
            seconds=value["sec"], microseconds=nanos // 1000, nanoseconds=nanos % 1000
        )

    @staticmethod
    def _convert_struct(value, field_converters, tuple_type, use_ordered_dict=False):
        if value is None:
            return None

        result_iter = ((k, field_converters[k](v)) for k, v in value.items())
        if tuple_type is not None:
            return tuple_type(**dict(result_iter))
        elif not use_ordered_dict:
            return dict(result_iter)
        else:
            d = dict(result_iter)
            return OrderedDict([(k, d[k]) for k in field_converters if k in d])

    def _build_converter(self, odps_type, arrow_type=None):
        import pyarrow as pa

        arrow_decimal_types = (pa.Decimal128Type,)
        if hasattr(pa, "Decimal256Type"):
            arrow_decimal_types += (pa.Decimal256Type,)

        if self._convert_ts and isinstance(odps_type, types.Datetime):
            return self._convert_datetime
        elif isinstance(odps_type, types.Timestamp):
            if isinstance(arrow_type, pa.StructType):
                return self._convert_struct_timestamp
            elif self._convert_ts:
                return self._convert_ts_timestamp
            else:
                return _reflective
        elif isinstance(odps_type, types.TimestampNTZ):
            if isinstance(arrow_type, pa.StructType):
                return functools.partial(self._convert_struct_timestamp, ntz=True)
            elif self._convert_ts:
                return functools.partial(self._convert_ts_timestamp, ntz=True)
            else:
                return _reflective
        elif (
            isinstance(odps_type, types.Decimal)
            and isinstance(arrow_type, pa.FixedSizeBinaryType)
            and not isinstance(arrow_type, arrow_decimal_types)
        ):
            return convert_legacy_decimal_bytes
        elif isinstance(odps_type, types.IntervalDayTime) and isinstance(
            arrow_type, pa.StructType
        ):
            return self._convert_struct_timedelta
        elif isinstance(odps_type, types.Array):
            arrow_value_type = getattr(arrow_type, "value_type", None)
            sub_converter = self._build_converter(
                odps_type.value_type, arrow_value_type
            )
            if sub_converter is _reflective:
                return _reflective
            return (
                lambda value: [sub_converter(x) for x in value]
                if value is not None
                else None
            )
        elif isinstance(odps_type, types.Map):
            arrow_key_type = getattr(arrow_type, "key_type", None)
            arrow_value_type = getattr(arrow_type, "item_type", None)

            key_converter = self._build_converter(odps_type.key_type, arrow_key_type)
            value_converter = self._build_converter(
                odps_type.value_type, arrow_value_type
            )
            dict_hook = OrderedDict if odps_type._use_ordered_dict else dict
            if key_converter is _reflective and value_converter is _reflective:
                return dict_hook

            return (
                lambda value: dict_hook(
                    [(key_converter(k), value_converter(v)) for k, v in value]
                )
                if value is not None
                else None
            )
        elif isinstance(odps_type, types.Struct):
            field_converters = OrderedDict()
            for field_name, field_type in odps_type.field_types.items():
                arrow_field_type = None
                if arrow_type is not None:
                    arrow_field_type = arrow_type[field_name].type
                field_converters[field_name] = self._build_converter(
                    field_type, arrow_field_type
                )

            if options.struct_as_dict:
                tuple_type = None
            else:
                tuple_type = odps_type.namedtuple_type
            use_ordered_dict = odps_type._use_ordered_dict
            return functools.partial(
                self._convert_struct,
                field_converters=field_converters,
                tuple_type=tuple_type,
                use_ordered_dict=use_ordered_dict,
            )
        else:
            return _reflective

    def __call__(self, value):
        if value is None:
            return None
        return self._converter(value)


class ArrowRecordReader(AbstractRecordReader):
    _complex_types_to_convert = (types.Array, types.Map, types.Struct)

    def __init__(self, arrow_reader, make_compat=True, read_all=False):
        self._arrow_reader = arrow_reader
        self._batch_pos = 0
        self._total_pos = 0
        self._cur_batch = None
        self._make_compat = make_compat

        self._field_converters = None

        if read_all:
            self._cur_batch = arrow_reader.read()

    def _convert_record(self, arrow_values):
        py_values = [x.as_py() for x in arrow_values]
        if not self._make_compat:
            return py_values
        else:
            return [
                converter(value)
                for value, converter in zip(py_values, self._field_converters)
            ]

    def read(self):
        if self._cur_batch is None or self._batch_pos >= self._cur_batch.num_rows:
            self._cur_batch = self._arrow_reader.read_next_batch()
            self._batch_pos = 0

            if self._cur_batch is None or self._cur_batch.num_rows == 0:
                return None

        if self._field_converters is None:
            table_schema = self._arrow_reader.schema
            self._field_converters = [
                ArrowRecordFieldConverter(table_schema[col_name].type, arrow_type)
                for col_name, arrow_type in zip(
                    self._cur_batch.schema.names, self._cur_batch.schema.types
                )
            ]
        tp = tuple(col[self._batch_pos] for col in self._cur_batch.columns)
        self._batch_pos += 1
        self._total_pos += 1

        return Record(schema=self.schema, values=self._convert_record(tp))

    def to_pandas(self, start=None, count=None, **kw):
        step = kw.get("step") or 1
        return self._arrow_reader.to_pandas().iloc[start : start + count : step]

    def close(self):
        self._arrow_reader.close()

    def __next__(self):
        rec = self.read()
        if rec is None:
            raise StopIteration
        return rec

    next = __next__

    @property
    def count(self):
        return self._total_pos

    @property
    def schema(self):
        return self._arrow_reader.schema

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()
