#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import functools
import json
import struct
import sys
import warnings
from collections import OrderedDict
from decimal import Decimal
from io import BytesIO, IOBase, StringIO
try:
import numpy as np
except ImportError:
np = None
try:
import pandas as pd
except (ImportError, ValueError):
pd = None
try:
import pyarrow as pa
except (AttributeError, ImportError):
pa = None
try:
import pyarrow.compute as pac
except (AttributeError, ImportError):
pac = None
from ... import compat, types, utils
from ...config import options
from ...errors import ChecksumError, DatetimeOverflowError
from ...lib.monotonic import monotonic
from ...models import Record
from ...readers import AbstractRecordReader
from ...types import PartitionSpec
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..pb import wire_format
from ..pb.decoder import Decoder
from ..pb.errors import DecodeError
from ..wireconstants import ProtoWireConstants
from .types import odps_schema_to_arrow_schema
try:
if not options.force_py:
from .reader_c import BaseTunnelRecordReader, convert_legacy_decimal_bytes
else:
BaseTunnelRecordReader = convert_legacy_decimal_bytes = None
except ImportError as e:
if options.force_c:
raise e
BaseTunnelRecordReader = convert_legacy_decimal_bytes = None
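# When the Cython implementations in reader_c are unavailable (or options.force_py
# is set), the pure-Python fallbacks defined below are used instead.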
MICRO_SEC_PER_SEC = 1000000
if BaseTunnelRecordReader is None:
class BaseTunnelRecordReader(AbstractRecordReader):
def __init__(
self,
schema,
stream_creator,
columns=None,
partition_spec=None,
append_partitions=False,
):
self._schema = schema
if columns is None:
self._columns = (
self._schema.columns
if append_partitions
else self._schema.simple_columns
)
else:
self._columns = [self._schema[c] for c in columns]
self._enable_client_metrics = options.tunnel.enable_client_metrics
self._server_metrics_string = None
self._local_wall_time_ms = 0
self._acc_network_time_ms = 0
self._injected_error = None
self._curr_cursor = 0
self._stream_creator = stream_creator
self._reader = None
self._reopen_reader()
if self._enable_client_metrics:
ts = monotonic()
self._read_limit = options.table_read_limit
self._to_datetime = utils.MillisecondsConverter().from_milliseconds
self._to_datetime_utc = utils.MillisecondsConverter(
local_tz=False
).from_milliseconds
self._to_date = utils.to_date
self._partition_spec = (
PartitionSpec(partition_spec) if partition_spec else None
)
self._append_partitions = append_partitions
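            # Plain dicts only guarantee insertion order from Python 3.7 on, so
            # OrderedDict is the default hook for map and struct values on older
            # interpreters unless overridden via options.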
map_as_ordered_dict = options.map_as_ordered_dict
if map_as_ordered_dict is None:
map_as_ordered_dict = sys.version_info[:2] <= (3, 6)
self._map_dict_hook = OrderedDict if map_as_ordered_dict else dict
struct_as_ordered_dict = options.struct_as_ordered_dict
if struct_as_ordered_dict is None:
struct_as_ordered_dict = sys.version_info[:2] <= (3, 6)
self._struct_dict_hook = OrderedDict if struct_as_ordered_dict else dict
if self._enable_client_metrics:
self._local_wall_time_ms += compat.long_type(
MICRO_SEC_PER_SEC * (monotonic() - ts)
)
def _mode(self):
return "py"
@property
def count(self):
return self._curr_cursor
@property
def _network_wall_time_ms(self):
return self._reader.network_wall_time_ms + self._acc_network_time_ms
def _inject_error(self, cursor, exc):
self._injected_error = (cursor, exc)
def _reopen_reader(self):
if self._enable_client_metrics:
ts = monotonic()
stream = self._stream_creator(self._curr_cursor)
if self._enable_client_metrics:
self._acc_network_time_ms += compat.long_type(
MICRO_SEC_PER_SEC * (monotonic() - ts)
)
if self._reader is not None:
self._acc_network_time_ms += self._reader.network_wall_time_ms
self._last_n_bytes = len(self._reader) if self._curr_cursor != 0 else 0
self._reader = Decoder(
stream, record_network_time=self._enable_client_metrics
)
self._crc = Checksum()
self._crccrc = Checksum()
self._attempt_row_count = 0
if self._enable_client_metrics:
self._local_wall_time_ms += compat.long_type(
MICRO_SEC_PER_SEC * (monotonic() - ts)
)
def _read_field(self, data_type):
if data_type == types.float_:
val = self._reader.read_float()
self._crc.update_float(val)
elif data_type == types.double:
val = self._reader.read_double()
self._crc.update_double(val)
elif data_type == types.boolean:
val = self._reader.read_bool()
self._crc.update_bool(val)
elif data_type in types.integer_types:
val = self._reader.read_sint64()
self._crc.update_long(val)
elif data_type == types.string:
val = self._reader.read_string()
self._crc.update(val)
elif data_type == types.binary:
val = self._reader.read_string()
self._crc.update(val)
elif data_type == types.datetime:
val = self._reader.read_sint64()
self._crc.update_long(val)
try:
val = self._to_datetime(val)
except DatetimeOverflowError:
if not options.tunnel.overflow_date_as_none:
raise
val = None
elif data_type == types.date:
val = self._reader.read_sint64()
self._crc.update_long(val)
val = self._to_date(val)
elif data_type == types.timestamp or data_type == types.timestamp_ntz:
to_datetime = (
self._to_datetime_utc
if data_type == types.timestamp_ntz
else self._to_datetime
)
l_val = self._reader.read_sint64()
self._crc.update_long(l_val)
nano_secs = self._reader.read_sint32()
self._crc.update_int(nano_secs)
if pd is None:
raise ImportError(
"To use TIMESTAMP in pyodps, you need to install pandas."
)
try:
val = pd.Timestamp(to_datetime(l_val * 1000)) + pd.Timedelta(
nanoseconds=nano_secs
)
except DatetimeOverflowError:
if not options.tunnel.overflow_date_as_none:
raise
val = None
elif data_type == types.interval_day_time:
l_val = self._reader.read_sint64()
self._crc.update_long(l_val)
nano_secs = self._reader.read_sint32()
self._crc.update_int(nano_secs)
if pd is None:
raise ImportError(
"To use INTERVAL_DAY_TIME in pyodps, you need to install pandas."
)
val = pd.Timedelta(seconds=l_val, nanoseconds=nano_secs)
elif data_type == types.interval_year_month:
l_val = self._reader.read_sint64()
self._crc.update_long(l_val)
return compat.Monthdelta(l_val)
elif data_type == types.json:
sval = self._reader.read_string()
val = json.loads(sval)
self._crc.update(sval)
elif isinstance(data_type, (types.Char, types.Varchar)):
val = self._reader.read_string()
self._crc.update(val)
elif isinstance(data_type, types.Decimal):
val = self._reader.read_string()
self._crc.update(val)
elif isinstance(data_type, types.Array):
val = self._read_array(data_type.value_type)
elif isinstance(data_type, types.Map):
keys = self._read_array(data_type.key_type)
values = self._read_array(data_type.value_type)
val = self._map_dict_hook(zip(keys, values))
elif isinstance(data_type, types.Struct):
val = self._read_struct(data_type)
else:
raise IOError("Unsupported type %s" % data_type)
return val
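        # Arrays are encoded as a uint32 element count followed, for each element,
        # by a bool null flag and, when the flag is unset, the serialized value.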
def _read_array(self, value_type):
res = []
size = self._reader.read_uint32()
for _ in range(size):
if self._reader.read_bool():
res.append(None)
else:
res.append(self._read_field(value_type))
return res
def _read_struct(self, value_type):
res_list = [None] * len(value_type.field_types)
for idx, field_type in enumerate(value_type.field_types.values()):
if not self._reader.read_bool():
res_list[idx] = self._read_field(field_type)
if options.struct_as_dict:
return self._struct_dict_hook(
zip(value_type.field_types.keys(), res_list)
)
else:
return value_type.namedtuple_type(*res_list)
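        # Wire layout of a record stream: each value is a protobuf field whose field
        # number is the 1-based column index. TUNNEL_END_RECORD closes a record with
        # its checksum, while TUNNEL_META_COUNT terminates the stream with the total
        # row count, an optional server metrics string and the final checksums.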
def _read_single_record(self):
if (
self._injected_error is not None
and self._injected_error[0] == self._curr_cursor
):
                # clear the injection before raising so it only fires once
                _, exc_to_raise = self._injected_error
                self._injected_error = None
                raise exc_to_raise
if self._read_limit is not None and self.count >= self._read_limit:
warnings.warn(
"Number of lines read via tunnel already reaches the limitation.",
RuntimeWarning,
)
return None
record = Record(self._columns, max_field_size=(1 << 63) - 1)
while True:
index, _ = self._reader.read_field_number_and_wire_type()
if index == 0:
continue
if index == ProtoWireConstants.TUNNEL_END_RECORD:
checksum = utils.long_to_int(self._crc.getvalue())
if int(self._reader.read_uint32()) != utils.int_to_uint(checksum):
raise ChecksumError("Checksum invalid")
self._crc.reset()
self._crccrc.update_int(checksum)
break
if index == ProtoWireConstants.TUNNEL_META_COUNT:
if self._attempt_row_count != self._reader.read_sint64():
raise IOError("count does not match")
(
idx_of_checksum,
wire_type,
) = self._reader.read_field_number_and_wire_type()
if ProtoWireConstants.TUNNEL_META_CHECKSUM != idx_of_checksum:
if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED:
raise IOError("Invalid stream data.")
self._crc.update_int(idx_of_checksum)
self._server_metrics_string = self._reader.read_string()
self._crc.update(self._server_metrics_string)
idx_of_checksum = (
self._reader.read_field_number_and_wire_type()[0]
)
if idx_of_checksum != ProtoWireConstants.TUNNEL_END_METRICS:
raise IOError("Invalid stream data.")
checksum = utils.long_to_int(self._crc.getvalue())
self._crc.reset()
if utils.int_to_uint(checksum) != int(
self._reader.read_uint32()
):
raise ChecksumError("Checksum invalid.")
idx_of_checksum = (
self._reader.read_field_number_and_wire_type()[0]
)
if ProtoWireConstants.TUNNEL_META_CHECKSUM != idx_of_checksum:
raise IOError("Invalid stream data.")
if int(self._crccrc.getvalue()) != self._reader.read_uint32():
raise ChecksumError("Checksum invalid.")
return
if index > len(self._columns):
raise IOError(
"Invalid protobuf tag. Perhaps the datastream "
"from server is crushed."
)
self._crc.update_int(index)
i = index - 1
record[i] = self._read_field(self._columns[i].type)
if self._append_partitions and self._partition_spec is not None:
for k, v in self._partition_spec.items():
try:
record[k] = v
except KeyError:
# skip non-existing fields
pass
self._curr_cursor += 1
self._attempt_row_count += 1
return record
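        # read() wraps _read_single_record in a retry loop; when a retry happens,
        # _reopen_reader recreates the stream at the current cursor so that rows
        # already returned are not read again.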
def read(self):
if self._enable_client_metrics:
ts = monotonic()
result = utils.call_with_retry(
self._read_single_record, reset_func=self._reopen_reader
)
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
return result
def reads(self):
return self.__iter__()
@property
def n_bytes(self):
return self._last_n_bytes + len(self._reader)
def get_total_bytes(self):
return self.n_bytes
class TunnelRecordReader(BaseTunnelRecordReader, AbstractRecordReader):
"""
Reader object to read data from ODPS in records. Should be created
with :meth:`TableDownloadSession.open_record_reader`.
:Example:
.. code-block:: python
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
download_session = tunnel.create_download_session('my_table', partition_spec='pt=test')
# create a TunnelRecordReader
with download_session.open_record_reader(0, download_session.count) as reader:
for record in reader:
print(record.values)
"""
def __next__(self):
record = self.read()
if record is None:
raise StopIteration
return record
next = __next__
@property
def schema(self):
return self._schema
@property
def metrics(self):
if self._server_metrics_string is None:
return None
return TunnelMetrics.from_server_json(
type(self).__name__,
self._server_metrics_string,
self._local_wall_time_ms,
self._network_wall_time_ms,
)
def close(self):
if hasattr(self._schema, "close"):
self._schema.close()
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
try:
TunnelRecordReader.read.__doc__ = """
    Read next record.

    :return: A record object
    :rtype: :class:`~odps.models.Record`
    """
except Exception:
pass
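# ArrowStreamReader turns the raw tunnel response into a stream consumable by
# pyarrow's IPC reader: the serialized Arrow schema is emitted first, then the
# payload is consumed in checksummed chunks. A 4-byte big-endian chunk size prefixes
# the payload, every chunk carries a trailing 4-byte checksum, and the shorter final
# chunk holds the checksum of all data read.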
class ArrowStreamReader(IOBase):
def __init__(self, raw_reader, arrow_schema):
self._reader = raw_reader
self._crc = Checksum()
self._crccrc = Checksum()
self._pos = 0
self._chunk_size = None
self._buffers = collections.deque()
self._buffers.append(BytesIO(arrow_schema.serialize().to_pybytes()))
def __len__(self):
return len(self._reader)
def readable(self):
return True
    @staticmethod
    def _read_uint32(b):
        return struct.unpack("!I", b)
def _read_chunk_size(self):
try:
            i = self._read_uint32(self._reader.read(4))
self._pos += 4
return i[0] # unpack() result is a 1-element tuple.
except struct.error as e:
raise DecodeError(e)
def _read_chunk(self):
read_size = self._chunk_size + 4
b = self._reader.read(read_size)
if 0 < len(b) < 4:
raise ChecksumError("Checksum invalid")
self._pos += len(b)
self._crc.update(b[:-4])
self._crccrc.update(b[:-4])
return b
def _fill_next_buffer(self):
if self._chunk_size is None:
self._chunk_size = self._read_chunk_size()
b = self._read_chunk()
data = b[:-4]
crc_data = b[-4:]
if len(b) == 0:
return
if len(b) < self._chunk_size + 4:
# is last chunk
            read_checksum = self._read_uint32(crc_data)[0]
checksum = int(self._crccrc.getvalue())
if checksum != read_checksum:
raise ChecksumError("Checksum invalid")
self._pos += len(data) + 4
self._buffers.append(BytesIO(data))
self._crccrc.reset()
else:
checksum = int(self._crc.getvalue())
            read_checksum = self._read_uint32(crc_data)[0]
if checksum != read_checksum:
raise ChecksumError("Checksum invalid")
self._crc.reset()
self._buffers.append(BytesIO(data))
def read(self, nbytes=None):
tot_size = 0
bufs = []
while nbytes is None or tot_size < nbytes:
if not self._buffers:
self._fill_next_buffer()
if not self._buffers:
break
to_read = nbytes - tot_size if nbytes is not None else None
buf = self._buffers[0].read(to_read)
if not buf:
self._buffers.popleft()
else:
bufs.append(buf)
tot_size += len(buf)
if len(bufs) == 1:
return bufs[0]
return b"".join(bufs)
def close(self):
if hasattr(self._reader, "close"):
self._reader.close()
class TunnelArrowReader(object):
"""
Reader object to read data from ODPS in Arrow format. Should be created
with :meth:`TableDownloadSession.open_arrow_reader`.
:Example:
.. code-block:: python
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
download_session = tunnel.create_download_session('my_table', partition_spec='pt=test')
# create a TunnelArrowReader
with download_session.open_arrow_reader(0, download_session.count) as reader:
for batch in reader:
print(batch.to_pandas())
"""
def __init__(
self,
schema,
stream_creator,
columns=None,
partition_spec=None,
append_partitions=False,
use_ipc_stream=False,
):
if pa is None:
raise ValueError("To use arrow reader you need to install pyarrow")
self._raw_schema = schema
raw_arrow_schema = odps_schema_to_arrow_schema(schema)
if columns is None:
self._schema = schema
self._arrow_schema = self._raw_arrow_schema = raw_arrow_schema
else:
self._schema = types.OdpsSchema([schema[c] for c in columns])
self._raw_arrow_schema = pa.schema(
[s for s in raw_arrow_schema if s.name in columns]
)
self._arrow_schema = odps_schema_to_arrow_schema(self._schema)
self._columns = columns
self._append_partitions = append_partitions
self._partition_spec = partition_spec
self._pos = 0
self._stream_creator = stream_creator
self._use_ipc_stream = use_ipc_stream
self._reopen_reader()
self._to_datetime = utils.MillisecondsConverter().from_milliseconds
self._read_limit = options.table_read_limit
self.closed = False
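        # Per-column converters for complex types (array/map/struct) that need extra
        # post-processing after RecordBatch.to_pandas().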
self._pd_column_converters = dict()
        # numeric columns whose pandas series must be coerced to a float dtype when
        # nulls leave them with an object dtype
self._coerce_numpy_columns = set()
for col in schema.simple_columns:
arrow_type = self._arrow_schema.field(
self._arrow_schema.get_field_index(col.name)
).type
if isinstance(col.type, (types.Map, types.Array, types.Struct)):
self._pd_column_converters[col.name] = ArrowRecordFieldConverter(
col.type, convert_ts=False
)
elif options.tunnel.pd_cast_mode == "numpy" and (
pa.types.is_integer(arrow_type) or pa.types.is_floating(arrow_type)
):
self._coerce_numpy_columns.add(col.name)
self._injected_error = None
def _reopen_reader(self):
self._last_n_bytes = len(self._reader) if self._pos != 0 else 0
input_stream = self._stream_creator(self._pos)
self._arrow_stream = None
if self._use_ipc_stream:
self._reader = input_stream
else:
self._reader = ArrowStreamReader(input_stream, self._raw_arrow_schema)
def _inject_error(self, cursor, exc):
self._injected_error = (cursor, exc)
@property
def schema(self):
return self._schema
def _read_next_raw_batch(self):
if self._injected_error is not None and self._injected_error[0] <= self._pos:
            # clear the injection before raising so it only fires once
            _, exc_to_raise = self._injected_error
            self._injected_error = None
            raise exc_to_raise
if self._arrow_stream is None:
self._arrow_stream = pa.ipc.open_stream(self._reader)
if self._read_limit is not None and self._pos >= self._read_limit:
warnings.warn(
"Number of lines read via tunnel already reaches the limitation.",
RuntimeWarning,
)
return None
try:
batch = self._arrow_stream.read_next_batch()
self._pos += batch.num_rows
except pa.ArrowTypeError as ex:
if str(ex) != "Did not pass numpy.dtype object":
raise
else:
raise pa.ArrowTypeError(
"Error caused by version mismatch. Try install numpy<1.20 or "
"upgrade your pyarrow version. Original message: " + str(ex)
)
except StopIteration:
return None
return batch
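    # _convert_timezone casts timestamp columns to options.local_timezone (or the
    # zone detected via tzlocal when unset); timestamp_ntz columns and timezone=False
    # keep a zone-free type, and when no explicit zone is configured the tz-aware
    # values are converted back to naive local timestamps.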
def _convert_timezone(self, batch):
from ...lib import tzlocal
if not any(isinstance(tp, pa.TimestampType) for tp in batch.schema.types):
return batch
timezone = raw_timezone = options.local_timezone
if timezone is True or timezone is None:
timezone = str(tzlocal.get_localzone())
cols = []
for idx in range(batch.num_columns):
col = batch.column(idx)
name = batch.schema.names[idx]
if not isinstance(col.type, pa.TimestampType):
cols.append(col)
continue
if timezone is False or self._schema[name].type == types.timestamp_ntz:
col = col.cast(pa.timestamp(col.type.unit))
else:
col = col.cast(pa.timestamp(col.type.unit, timezone))
if raw_timezone is None or raw_timezone is True:
if hasattr(pac, "local_timestamp"):
col = pac.local_timestamp(col)
else:
col = pa.Array.from_pandas(
col.to_pandas().dt.tz_localize(None)
).cast(pa.timestamp(col.type.unit))
cols.append(col)
return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
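    # _append_partition_cols materializes partition key values as constant columns
    # appended to each batch when append_partitions is set or when partition keys are
    # explicitly selected via the columns argument.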
def _append_partition_cols(self, batch):
col_set = set(self._columns or [c.name for c in self._schema.columns])
pt_obj = (
types.PartitionSpec(self._partition_spec) if self._partition_spec else None
)
sel_col_set = set(self._columns or [])
if pt_obj and any(c in sel_col_set for c in pt_obj.keys()):
# append partitions selected in columns argument
self._append_partitions = True
if not pt_obj or not self._append_partitions:
return batch
batch_cols = list(batch.columns)
batch_col_names = list(batch.schema.names)
for key, val in pt_obj.items():
if key not in col_set:
continue
val = types.validate_value(val, self._schema[key].type)
batch_cols.append(pa.array(np.repeat([val], batch.num_rows)))
batch_col_names.append(key)
return pa.RecordBatch.from_arrays(batch_cols, names=batch_col_names)
def read_next_batch(self):
"""
Read next Arrow RecordBatch from tunnel.
:return: Arrow RecordBatch
"""
if self._reader is None:
return None
batch = utils.call_with_retry(
self._read_next_raw_batch, reset_func=self._reopen_reader
)
if batch is None:
return None
batch = self._append_partition_cols(batch)
if self._columns and self._columns != batch.schema.names:
col_to_array = dict()
col_name_set = set(self._columns)
for name, arr in zip(batch.schema.names, batch.columns):
if name not in col_name_set:
continue
col_to_array[name] = arr
arrays = [col_to_array[name] for name in self._columns]
batch = pa.RecordBatch.from_arrays(arrays, names=self._columns)
batch = self._convert_timezone(batch)
return batch
def read(self):
"""
Read all data from tunnel and forms an Arrow Table.
:return: Arrow Table
"""
batches = []
while True:
batch = self.read_next_batch()
if batch is None:
break
batches.append(batch)
if not batches:
return self._arrow_schema.empty_table()
return pa.Table.from_batches(batches)
def __iter__(self):
return self
def __next__(self):
"""
Read next Arrow RecordBatch from tunnel.
"""
batch = self.read_next_batch()
if batch is None:
raise StopIteration
return batch
@property
def count(self):
return self._pos
@property
def n_bytes(self):
return self._last_n_bytes + len(self._reader)
def get_total_bytes(self):
return self.n_bytes
def close(self):
if hasattr(self._reader, "close"):
self._reader.close()
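    # Pandas conversion: complex columns go through their ArrowRecordFieldConverter,
    # and in "numpy" cast mode integer/float columns that come back with an object
    # dtype are coerced to a float dtype.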
def _convert_batch_to_pandas(self, batch):
series_list = []
type_mapper = pd.ArrowDtype if options.tunnel.pd_cast_mode == "arrow" else None
if not self._pd_column_converters and not self._coerce_numpy_columns:
return batch.to_pandas(types_mapper=type_mapper)
for col_name, arrow_column in zip(batch.schema.names, batch.columns):
if col_name not in self._pd_column_converters:
series = arrow_column.to_pandas(types_mapper=type_mapper)
if col_name in self._coerce_numpy_columns and series.dtype == np.dtype(
"O"
):
if pa.types.is_floating(arrow_column.type):
col_dtype = arrow_column.type.to_pandas_dtype()
else:
col_dtype = np.dtype(float)
series = series.astype(col_dtype)
series_list.append(series)
else:
try:
series = arrow_column.to_pandas(types_mapper=type_mapper)
except pa.ArrowNotImplementedError:
dtype = type_mapper(arrow_column.type) if type_mapper else None
series = pd.Series(
arrow_column.to_pylist(), name=col_name, dtype=dtype
)
series_list.append(series.map(self._pd_column_converters[col_name]))
return pd.concat(series_list, axis=1)
def to_pandas(self):
"""
Read all data from tunnel and convert to a Pandas DataFrame.
"""
import pandas as pd
batches = []
while True:
batch = self.read_next_batch()
if batch is None:
break
batches.append(self._convert_batch_to_pandas(batch))
if not batches:
return self._arrow_schema.empty_table().to_pandas()
return pd.concat(batches, axis=0, ignore_index=True)
def __enter__(self):
return self
def __exit__(self, *_):
self.close()
_reflective = lambda x: x
if convert_legacy_decimal_bytes is None:
def convert_legacy_decimal_bytes(value):
"""
Legacy decimal memory layout:
int8_t mNull;
int8_t mSign;
int8_t mIntg;
int8_t mFrac; only 0, 1, 2
int32_t mData[6];
int8_t mPadding[4]; //For Memory Align
"""
if value is None:
return None
is_null, sign, intg, frac = struct.unpack("<4b", value[:4])
if is_null: # pragma: no cover
return None
if intg + frac == 0:
return Decimal("0")
sio = BytesIO() if compat.PY27 else StringIO()
if sign > 0:
sio.write("-")
intg_nums = struct.unpack("<%dI" % intg, value[12 : 12 + intg * 4])
intg_val = "".join("%09d" % d for d in reversed(intg_nums)).lstrip("0")
sio.write(intg_val or "0")
if frac > 0:
sio.write(".")
frac_nums = struct.unpack("<%dI" % frac, value[12 - frac * 4 : 12])
sio.write("".join("%09d" % d for d in reversed(frac_nums)))
return Decimal(sio.getvalue())
class ArrowRecordFieldConverter(object):
_sensitive_types = (
types.Datetime,
types.Timestamp,
types.TimestampNTZ,
types.Array,
types.Map,
types.Struct,
)
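    # Converts values deserialized from Arrow into the Python objects PyODPS records
    # expect: struct-encoded timestamps/intervals become pd.Timestamp/pd.Timedelta,
    # legacy fixed-size-binary decimals become Decimal, and array/map/struct values
    # are converted element by element.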
def __init__(self, odps_type, arrow_type=None, convert_ts=True):
self._mills_converter = utils.MillisecondsConverter()
self._mills_converter_utc = utils.MillisecondsConverter(local_tz=False)
self._convert_ts = convert_ts
self._converter = self._build_converter(odps_type, arrow_type)
def _convert_datetime(self, value):
if value is None:
return None
mills = self._mills_converter.to_milliseconds(value)
return self._mills_converter.from_milliseconds(mills)
def _convert_ts_timestamp(self, value, ntz=False):
if value is None:
return None
if not ntz:
converter = self._mills_converter
else: # TimestampNtz
converter = self._mills_converter_utc
microsec = value.microsecond
nanosec = value.nanosecond
secs = int(converter.to_milliseconds(value.to_pydatetime()) / 1000)
value = pd.Timestamp(converter.from_milliseconds(secs * 1000))
return value.replace(microsecond=microsec, nanosecond=nanosec)
def _convert_struct_timestamp(self, value, ntz=False):
if value is None:
return None
if not ntz:
converter = self._mills_converter
else: # TimestampNtz
converter = self._mills_converter_utc
ts = pd.Timestamp(converter.from_milliseconds(value["sec"] * 1000))
nanos = value["nano"]
return ts.replace(microsecond=nanos // 1000, nanosecond=nanos % 1000)
@staticmethod
def _convert_struct_timedelta(value):
if value is None:
return None
nanos = value["nano"]
return pd.Timedelta(
seconds=value["sec"], microseconds=nanos // 1000, nanoseconds=nanos % 1000
)
@staticmethod
def _convert_struct(value, field_converters, tuple_type, use_ordered_dict=False):
if value is None:
return None
result_iter = ((k, field_converters[k](v)) for k, v in value.items())
if tuple_type is not None:
return tuple_type(**dict(result_iter))
elif not use_ordered_dict:
return dict(result_iter)
else:
d = dict(result_iter)
return OrderedDict([(k, d[k]) for k in field_converters if k in d])
def _build_converter(self, odps_type, arrow_type=None):
import pyarrow as pa
arrow_decimal_types = (pa.Decimal128Type,)
if hasattr(pa, "Decimal256Type"):
arrow_decimal_types += (pa.Decimal256Type,)
if self._convert_ts and isinstance(odps_type, types.Datetime):
return self._convert_datetime
elif isinstance(odps_type, types.Timestamp):
if isinstance(arrow_type, pa.StructType):
return self._convert_struct_timestamp
elif self._convert_ts:
return self._convert_ts_timestamp
else:
return _reflective
elif isinstance(odps_type, types.TimestampNTZ):
if isinstance(arrow_type, pa.StructType):
return functools.partial(self._convert_struct_timestamp, ntz=True)
elif self._convert_ts:
return functools.partial(self._convert_ts_timestamp, ntz=True)
else:
return _reflective
elif (
isinstance(odps_type, types.Decimal)
and isinstance(arrow_type, pa.FixedSizeBinaryType)
and not isinstance(arrow_type, arrow_decimal_types)
):
return convert_legacy_decimal_bytes
elif isinstance(odps_type, types.IntervalDayTime) and isinstance(
arrow_type, pa.StructType
):
return self._convert_struct_timedelta
elif isinstance(odps_type, types.Array):
arrow_value_type = getattr(arrow_type, "value_type", None)
sub_converter = self._build_converter(
odps_type.value_type, arrow_value_type
)
if sub_converter is _reflective:
return _reflective
return (
lambda value: [sub_converter(x) for x in value]
if value is not None
else None
)
elif isinstance(odps_type, types.Map):
arrow_key_type = getattr(arrow_type, "key_type", None)
arrow_value_type = getattr(arrow_type, "item_type", None)
key_converter = self._build_converter(odps_type.key_type, arrow_key_type)
value_converter = self._build_converter(
odps_type.value_type, arrow_value_type
)
dict_hook = OrderedDict if odps_type._use_ordered_dict else dict
if key_converter is _reflective and value_converter is _reflective:
return dict_hook
return (
lambda value: dict_hook(
[(key_converter(k), value_converter(v)) for k, v in value]
)
if value is not None
else None
)
elif isinstance(odps_type, types.Struct):
field_converters = OrderedDict()
for field_name, field_type in odps_type.field_types.items():
arrow_field_type = None
if arrow_type is not None:
arrow_field_type = arrow_type[field_name].type
field_converters[field_name] = self._build_converter(
field_type, arrow_field_type
)
if options.struct_as_dict:
tuple_type = None
else:
tuple_type = odps_type.namedtuple_type
use_ordered_dict = odps_type._use_ordered_dict
return functools.partial(
self._convert_struct,
field_converters=field_converters,
tuple_type=tuple_type,
use_ordered_dict=use_ordered_dict,
)
else:
return _reflective
def __call__(self, value):
if value is None:
return None
return self._converter(value)
class ArrowRecordReader(AbstractRecordReader):
_complex_types_to_convert = (types.Array, types.Map, types.Struct)
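    # Adapts a TunnelArrowReader to the record interface: rows are pulled from Arrow
    # batches one at a time and converted into Record objects, optionally through
    # ArrowRecordFieldConverter when make_compat is enabled.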
def __init__(self, arrow_reader, make_compat=True, read_all=False):
self._arrow_reader = arrow_reader
self._batch_pos = 0
self._total_pos = 0
self._cur_batch = None
self._make_compat = make_compat
self._field_converters = None
if read_all:
self._cur_batch = arrow_reader.read()
def _convert_record(self, arrow_values):
py_values = [x.as_py() for x in arrow_values]
if not self._make_compat:
return py_values
else:
return [
converter(value)
for value, converter in zip(py_values, self._field_converters)
]
def read(self):
if self._cur_batch is None or self._batch_pos >= self._cur_batch.num_rows:
self._cur_batch = self._arrow_reader.read_next_batch()
self._batch_pos = 0
if self._cur_batch is None or self._cur_batch.num_rows == 0:
return None
if self._field_converters is None:
table_schema = self._arrow_reader.schema
self._field_converters = [
ArrowRecordFieldConverter(table_schema[col_name].type, arrow_type)
for col_name, arrow_type in zip(
self._cur_batch.schema.names, self._cur_batch.schema.types
)
]
tp = tuple(col[self._batch_pos] for col in self._cur_batch.columns)
self._batch_pos += 1
self._total_pos += 1
return Record(schema=self.schema, values=self._convert_record(tp))
    def to_pandas(self, start=None, count=None, **kw):
        step = kw.get("step") or 1
        # guard against None defaults so slicing does not fail when start or count is omitted
        end = None if count is None else (start or 0) + count
        return self._arrow_reader.to_pandas().iloc[start:end:step]
def close(self):
self._arrow_reader.close()
def __next__(self):
rec = self.read()
if rec is None:
raise StopIteration
return rec
next = __next__
@property
def count(self):
return self._total_pos
@property
def schema(self):
return self._arrow_reader.schema
def __enter__(self):
return self
def __exit__(self, *_):
self.close()