odps/tunnel/io/reader_c.pyx (575 lines of code) (raw):

# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cimport cython

import decimal
import json
import sys
import warnings
from collections import OrderedDict

from cpython.datetime cimport import_datetime
from libc.stdint cimport *
from libc.string cimport *

from ...lib.monotonic import monotonic

from ...src.types_c cimport BaseRecord
from ...src.utils_c cimport CMillisecondsConverter, to_date
from ..checksum_c cimport Checksum
from ..pb.decoder_c cimport CDecoder

from ... import compat, options, types, utils
from ...errors import ChecksumError, DatetimeOverflowError
from ...models import Record
from ...readers import AbstractRecordReader  # noqa
from ...types import PartitionSpec
from ..pb import wire_format
from ..wireconstants import ProtoWireConstants

DEF MAX_READ_SIZE_LIMIT = (1 << 63) - 1
DEF MICRO_SEC_PER_SEC = 1_000_000L

# Wire tags of the tunnel stream, cached as C integers for fast comparison.
cdef:
    uint32_t WIRE_TUNNEL_META_COUNT = ProtoWireConstants.TUNNEL_META_COUNT
    uint32_t WIRE_TUNNEL_META_CHECKSUM = ProtoWireConstants.TUNNEL_META_CHECKSUM
    uint32_t WIRE_TUNNEL_END_RECORD = ProtoWireConstants.TUNNEL_END_RECORD
    uint32_t WIRE_TUNNEL_END_METRICS = ProtoWireConstants.TUNNEL_END_METRICS

    uint32_t WIRETYPE_LENGTH_DELIMITED = wire_format.WIRETYPE_LENGTH_DELIMITED

# ODPS type ids, cached as C integers so field-reader dispatch avoids
# Python-level attribute lookups.
cdef:
    int64_t BOOL_TYPE_ID = types.boolean._type_id
    int64_t DATETIME_TYPE_ID = types.datetime._type_id
    int64_t DATE_TYPE_ID = types.date._type_id
    int64_t STRING_TYPE_ID = types.string._type_id
    int64_t FLOAT_TYPE_ID = types.float_._type_id
    int64_t DOUBLE_TYPE_ID = types.double._type_id
    int64_t BIGINT_TYPE_ID = types.bigint._type_id
    int64_t BINARY_TYPE_ID = types.binary._type_id
    int64_t TIMESTAMP_TYPE_ID = types.timestamp._type_id
    int64_t INTERVAL_DAY_TIME_TYPE_ID = types.interval_day_time._type_id
    int64_t INTERVAL_YEAR_MONTH_TYPE_ID = types.interval_year_month._type_id
    int64_t DECIMAL_TYPE_ID = types.Decimal._type_id
    int64_t JSON_TYPE_ID = types.Json._type_id
    int64_t TIMESTAMP_NTZ_TYPE_ID = types.timestamp_ntz._type_id

    int64_t ARRAY_TYPE_ID = types.Array._type_id
    int64_t MAP_TYPE_ID = types.Map._type_id
    int64_t STRUCT_TYPE_ID = types.Struct._type_id

# pandas.Timestamp / pandas.Timedelta, imported lazily on first use so that
# pandas is only required when timestamp / interval columns are read.
cdef:
    object pd_timestamp = None
    object pd_timedelta = None

import_datetime()


cdef class BaseTunnelRecordReader:
    """
    Cython implementation of the tunnel record reader.

    Records are decoded from a protobuf-style stream obtained via
    ``stream_creator(cursor)``. Every decoded field updates a per-record
    checksum (``_crc``); each completed record folds that checksum into a
    stream-level checksum (``_crccrc``), and both are verified against the
    values sent by the server. When decoding fails, :meth:`read` reopens the
    stream at the current cursor and retries up to ``options.retry_times``
    times.

    NOTE(review): the ``*_wall_time_ms`` counters accumulate
    ``MICRO_SEC_PER_SEC * seconds`` and ``ns // 1000``, i.e. microseconds,
    despite the ``_ms`` suffix -- confirm the unit expected by consumers.
    """

    def __init__(
        self,
        object schema,
        object stream_creator,
        object columns=None,
        object partition_spec=None,
        bint append_partitions=True,
    ):
        cdef double ts

        self._enable_client_metrics = options.tunnel.enable_client_metrics
        self._server_metrics_string = None
        self._c_local_wall_time_ms = 0
        self._c_acc_network_time_ms = 0

        if self._enable_client_metrics:
            ts = monotonic()

        self._schema = schema
        if columns is None:
            self._columns = (
                self._schema.columns
                if append_partitions
                else self._schema.simple_columns
            )
        else:
            self._columns = [self._schema[c] for c in columns]
        self._reader_schema = types.OdpsSchema(columns=self._columns)
        self._schema_snapshot = self._reader_schema.build_snapshot()
        self._n_columns = len(self._columns)
        self._partition_vals = []
        self._append_partitions = append_partitions

        partition_spec = (
            PartitionSpec(partition_spec) if partition_spec is not None else None
        )

        # one decoder object per column, indexed by (wire field number - 1)
        self._field_readers = [None] * self._schema_snapshot._col_count
        for idx, col_type in enumerate(self._schema_snapshot._col_types):
            self._field_readers[idx] = _build_field_reader(self, col_type)

        # partition values are not on the wire; remember (column index, value)
        # pairs so they can be filled into every record after decoding
        for i in range(self._n_columns):
            if partition_spec is not None and self._columns[i].name in partition_spec:
                self._partition_vals.append((i, partition_spec[self._columns[i].name]))

        if self._enable_client_metrics:
            self._c_local_wall_time_ms += <long>(
                MICRO_SEC_PER_SEC * (<double>monotonic() - ts)
            )

        self._curr_cursor = 0
        self._stream_creator = stream_creator
        self._reader = None
        self._reopen_reader()

        if self._enable_client_metrics:
            ts = monotonic()

        self._read_limit = (
            -1 if options.table_read_limit is None else options.table_read_limit
        )

        if self._enable_client_metrics:
            self._c_local_wall_time_ms += <long>(
                MICRO_SEC_PER_SEC * (<double>monotonic() - ts)
            )

        # test hook: see _inject_error
        self._n_injected_error_cursor = -1
        self._injected_error_exc = None

    @cython.cdivision(True)
    def _reopen_reader(self):
        """
        (Re)open the stream at ``self._curr_cursor`` and reset per-attempt state.

        A new ``CDecoder`` replaces the current one, both checksums are
        recreated and the per-attempt row counter is reset to zero.
        """
        cdef object stream
        cdef double ts

        if self._enable_client_metrics:
            ts = monotonic()

        stream = self._stream_creator(self._curr_cursor)
        if self._enable_client_metrics:
            self._c_acc_network_time_ms += <long>(
                MICRO_SEC_PER_SEC * (<double>monotonic() - ts)
            )
            if self._reader is not None:
                # keep the network time recorded by the decoder being replaced
                self._c_acc_network_time_ms += (
                    self._reader._network_wall_time_ns // 1000
                )

        # ``n_bytes`` is ``_last_n_bytes + self._reader.position()``, so the
        # bytes consumed by the decoder being replaced must be captured
        # *before* it is swapped out. Reading ``position()`` from the freshly
        # created decoder (as was done previously) discarded them.
        if self._curr_cursor != 0 and self._reader is not None:
            self._last_n_bytes += self._reader.position()
        else:
            self._last_n_bytes = 0

        self._reader = CDecoder(stream, record_network_time=self._enable_client_metrics)
        self._crc = Checksum()
        self._crccrc = Checksum()
        self._attempt_row_count = 0

        if self._enable_client_metrics:
            self._c_local_wall_time_ms += <long>(
                MICRO_SEC_PER_SEC * (<double>monotonic() - ts)
            )

    def _inject_error(self, cursor, exc):
        """Raise ``exc`` (once) when the reader is about to read row ``cursor``."""
        self._n_injected_error_cursor = cursor
        self._injected_error_exc = exc

    def _mode(self):
        return "c"

    @property
    def count(self):
        """Number of records successfully read so far."""
        return self._curr_cursor

    cdef int _set_record_list_value(self, list record, int i, object value) except? -1:
        # validate / convert ``value`` against the schema of column ``i``
        record[i] = self._schema_snapshot.validate_value(i, value, MAX_READ_SIZE_LIMIT)
        return 0

    cdef _read(self):
        # Decode a single record from the stream.
        #
        # Returns the decoded Record, or None when the read limit has been
        # reached or the end-of-stream metadata (row count + checksum) has
        # been consumed. Raises ChecksumError / IOError on corrupted data.
        cdef:
            int index
            int checksum
            int idx_of_checksum
            int i
            int32_t wire_type
            BaseRecord record
            list rec_list

        if self._n_injected_error_cursor == self._curr_cursor:
            self._n_injected_error_cursor = -1
            raise self._injected_error_exc

        if self._curr_cursor >= self._read_limit > 0:
            warnings.warn(
                "Number of lines read via tunnel already reaches the limitation.",
                RuntimeWarning,
            )
            return None

        record = Record(schema=self._reader_schema, max_field_size=MAX_READ_SIZE_LIMIT)
        rec_list = record._c_values

        while True:
            index = self._reader.read_field_number(NULL)

            if index == 0:
                continue

            if index == WIRE_TUNNEL_END_RECORD:
                # end of record: verify and fold the per-record checksum
                checksum = <int32_t>self._crc.c_getvalue()
                if self._reader.read_uint32() != <uint32_t>checksum:
                    raise ChecksumError("Checksum invalid")
                self._crc.c_reset()
                self._crccrc.c_update_int(checksum)
                break

            if index == WIRE_TUNNEL_META_COUNT:
                # end of stream: row count, optional server metrics, then the
                # checksum over all record checksums
                if self._attempt_row_count != self._reader.read_sint64():
                    raise IOError("count does not match")

                idx_of_checksum = self._reader.read_field_number(&wire_type)
                if WIRE_TUNNEL_META_CHECKSUM != idx_of_checksum:
                    # server metrics section precedes the final checksum
                    if wire_type != WIRETYPE_LENGTH_DELIMITED:
                        raise IOError("Invalid stream data.")
                    self._crc.c_update_int(idx_of_checksum)

                    self._server_metrics_string = self._reader.read_string()
                    self._crc.c_update(
                        self._server_metrics_string, len(self._server_metrics_string)
                    )

                    idx_of_checksum = self._reader.read_field_number(NULL)
                    if idx_of_checksum != WIRE_TUNNEL_END_METRICS:
                        raise IOError("Invalid stream data.")
                    checksum = <int32_t>self._crc.c_getvalue()
                    if <uint32_t>checksum != self._reader.read_uint32():
                        raise ChecksumError("Checksum invalid.")
                    self._crc.c_reset()

                    idx_of_checksum = self._reader.read_field_number(NULL)
                if WIRE_TUNNEL_META_CHECKSUM != idx_of_checksum:
                    raise IOError("Invalid stream data.")
                if self._crccrc.c_getvalue() != self._reader.read_uint32():
                    raise ChecksumError("Checksum invalid.")

                return

            if index > self._n_columns:
                raise IOError(
                    "Invalid protobuf tag. Perhaps the datastream "
                    "from server is crushed."
                )

            self._crc.c_update_int(index)

            # wire field numbers are 1-based column positions
            i = index - 1
            (<AbstractFieldReader>self._field_readers[i]).read(rec_list, i)

        if self._append_partitions:
            for idx, val in self._partition_vals:
                rec_list[idx] = val

        self._attempt_row_count += 1
        self._curr_cursor += 1
        return record

    cpdef read(self):
        """
        Read the next record, reopening the stream and retrying on failure.

        :return: the next Record, or None when the stream (or read limit) is
            exhausted
        :raises: the last error once ``options.retry_times`` retries are used up
        """
        cdef:
            int retry_num = 0
            double ts
            object result

        if self._enable_client_metrics:
            ts = monotonic()
        while True:
            try:
                result = self._read()
                if self._enable_client_metrics:
                    self._c_local_wall_time_ms += <long>(
                        MICRO_SEC_PER_SEC * (<double>monotonic() - ts)
                    )
                return result
            except Exception:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit must propagate instead of reopening the stream.
                retry_num += 1
                if retry_num > options.retry_times:
                    raise
                self._reopen_reader()

    def reads(self):
        return self.__iter__()

    @property
    def n_bytes(self):
        """Bytes consumed so far, including decoders replaced by retries."""
        return self._last_n_bytes + self._reader.position()

    def get_total_bytes(self):
        return self.n_bytes

    @property
    def _local_wall_time_ms(self):
        return self._c_local_wall_time_ms

    @property
    @cython.cdivision(True)
    def _network_wall_time_ms(self):
        return self._reader._network_wall_time_ns // 1000 + self._c_acc_network_time_ms


cdef _build_field_reader(BaseTunnelRecordReader record_reader, object data_type):
    """
    Return the field reader decoding values of ``data_type``.

    :raises IOError: when ``data_type`` has no reader
    """
    cdef int data_type_id = data_type._type_id

    import_datetime()

    if data_type_id == FLOAT_TYPE_ID:
        return FloatFieldReader(record_reader)
    elif data_type_id == BIGINT_TYPE_ID:
        return BigintFieldReader(record_reader)
    elif data_type_id == DOUBLE_TYPE_ID:
        return DoubleFieldReader(record_reader)
    elif data_type_id == STRING_TYPE_ID:
        return StringFieldReader(record_reader)
    elif data_type_id == BOOL_TYPE_ID:
        return BoolFieldReader(record_reader)
    elif data_type_id == DATETIME_TYPE_ID:
        return DatetimeFieldReader(record_reader)
    elif data_type_id == BINARY_TYPE_ID:
        return StringFieldReader(record_reader)
    elif data_type_id == TIMESTAMP_TYPE_ID:
        return TimestampFieldReader(record_reader)
    elif data_type_id == TIMESTAMP_NTZ_TYPE_ID:
        return TimestampNTZFieldReader(record_reader)
    elif data_type_id == DATE_TYPE_ID:
        return DateFieldReader(record_reader)
    elif data_type_id == INTERVAL_DAY_TIME_TYPE_ID:
        return IntervalDayTimeFieldReader(record_reader)
    elif data_type_id == INTERVAL_YEAR_MONTH_TYPE_ID:
        return IntervalYearMonthFieldReader(record_reader)
    elif data_type_id == JSON_TYPE_ID:
        return JsonFieldReader(record_reader)
    elif data_type_id == DECIMAL_TYPE_ID:
        return DecimalFieldReader(record_reader)
    elif data_type_id == ARRAY_TYPE_ID:
        return ArrayFieldReader(record_reader, data_type)
    elif data_type_id == MAP_TYPE_ID:
        return MapFieldReader(record_reader, data_type)
    elif data_type_id == STRUCT_TYPE_ID:
        return StructFieldReader(record_reader, data_type)
    elif isinstance(data_type, (types.Char, types.Varchar)):
        return StringFieldReader(record_reader)
    else:
        raise IOError("Unsupported type %s" % data_type)


cdef class AbstractFieldReader:
    """
    Base class of per-type field decoders.

    Subclasses implement ``_read_raw``, which decodes one value from the
    owning record reader's decoder and updates its per-record checksum.
    ``need_validate`` selects whether the value is passed through the
    schema validation in ``_set_record_list_value`` before being stored.
    """

    need_validate = None

    cdef BaseTunnelRecordReader _record_reader
    cdef bint _need_validate

    def __init__(self, BaseTunnelRecordReader record_reader):
        self._record_reader = record_reader
        self._need_validate = self.need_validate

    cdef object _read_raw(self):
        raise NotImplementedError

    cdef inline int read(self, list dest, int idx) except? -1:
        # decode one value and store it into dest[idx]
        if not self._need_validate:
            dest[idx] = self._read_raw()
        else:
            self._record_reader._set_record_list_value(
                dest, idx, self._read_raw()
            )
        return 0


cdef class BigintFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef int64_t val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)
        return val


cdef class FloatFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef float val = self._record_reader._reader.read_float()
        self._record_reader._crc.c_update_float(val)
        return val


cdef class DoubleFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef double val = self._record_reader._reader.read_double()
        self._record_reader._crc.c_update_double(val)
        return val


cdef class BoolFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef bint val = self._record_reader._reader.read_bool()
        self._record_reader._crc.c_update_bool(val)
        return val


cdef class StringFieldReader(AbstractFieldReader):
    # raw bytes are converted by the schema validation step
    need_validate = True

    cdef object _read_raw(self):
        cdef bytes val = self._record_reader._reader.read_string()
        self._record_reader._crc.c_update(val, len(val))
        return val


cdef class DecimalFieldReader(AbstractFieldReader):
    # raw bytes are converted by the schema validation step
    need_validate = True

    cdef object _read_raw(self):
        cdef bytes val = self._record_reader._reader.read_string()
        self._record_reader._crc.c_update(val, len(val))
        return val


cdef class JsonFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef bytes val = self._record_reader._reader.read_string()
        self._record_reader._crc.c_update(val, len(val))
        return json.loads(val)


cdef class DatetimeFieldReader(AbstractFieldReader):
    """Decode a datetime stored as a milliseconds value."""

    need_validate = False

    cdef CMillisecondsConverter _mills_converter
    cdef bint _overflow_date_as_none

    def __init__(self, BaseTunnelRecordReader record_reader):
        super(DatetimeFieldReader, self).__init__(record_reader)
        self._mills_converter = CMillisecondsConverter()
        self._overflow_date_as_none = options.tunnel.overflow_date_as_none

    cdef object _read_raw(self):
        cdef int64_t val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)
        try:
            return self._mills_converter.from_milliseconds(val)
        except DatetimeOverflowError:
            if not self._overflow_date_as_none:
                raise
            return None


cdef class DateFieldReader(AbstractFieldReader):
    need_validate = False

    cdef object _read_raw(self):
        cdef int64_t val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)
        return to_date(val)


cdef class BaseTimestampFieldReader(AbstractFieldReader):
    """
    Decode a timestamp sent as (seconds, nanoseconds) into ``pandas.Timestamp``.

    Subclasses set ``_ntz``; when true the converter is built with
    ``local_tz=False``.
    """

    need_validate = False
    _ntz = None

    cdef CMillisecondsConverter _mills_converter
    cdef bint _overflow_date_as_none

    def __init__(self, BaseTunnelRecordReader record_reader):
        super(BaseTimestampFieldReader, self).__init__(record_reader)
        self._overflow_date_as_none = options.tunnel.overflow_date_as_none
        if self._ntz:
            self._mills_converter = CMillisecondsConverter(local_tz=False)
        else:
            self._mills_converter = CMillisecondsConverter()

    cdef object _read_raw(self):
        cdef:
            int64_t val
            int32_t nano_secs

        global pd_timestamp, pd_timedelta

        if pd_timestamp is None:
            import pandas as pd
            pd_timestamp = pd.Timestamp
            pd_timedelta = pd.Timedelta

        val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)

        nano_secs = self._record_reader._reader.read_sint32()
        self._record_reader._crc.c_update_int(nano_secs)

        try:
            # ``val`` is in seconds; the converter expects milliseconds
            return (
                pd_timestamp(self._mills_converter.from_milliseconds(val * 1000))
                + pd_timedelta(nanoseconds=nano_secs)
            )
        except DatetimeOverflowError:
            if not self._overflow_date_as_none:
                raise
            return None


cdef class TimestampFieldReader(BaseTimestampFieldReader):
    _ntz = False


cdef class TimestampNTZFieldReader(BaseTimestampFieldReader):
    _ntz = True


cdef class IntervalDayTimeFieldReader(AbstractFieldReader):
    """Decode a (seconds, nanoseconds) pair into ``pandas.Timedelta``."""

    need_validate = False

    cdef _read_raw(self):
        cdef:
            int64_t val
            int32_t nano_secs

        global pd_timestamp, pd_timedelta

        if pd_timedelta is None:
            import pandas as pd
            pd_timestamp = pd.Timestamp
            pd_timedelta = pd.Timedelta

        val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)

        nano_secs = self._record_reader._reader.read_sint32()
        self._record_reader._crc.c_update_int(nano_secs)

        return pd_timedelta(seconds=val, nanoseconds=nano_secs)


cdef class IntervalYearMonthFieldReader(AbstractFieldReader):
    need_validate = False

    cdef _read_raw(self):
        cdef int64_t val = self._record_reader._reader.read_sint64()
        self._record_reader._crc.c_update_long(val)
        return compat.Monthdelta(val)


cdef class ArrayFieldReader(AbstractFieldReader):
    """
    Decode an array: an element count followed, for each element, by a null
    flag and (when not null) the element value.
    """

    need_validate = True

    cdef AbstractFieldReader _element_reader

    def __init__(self, BaseTunnelRecordReader record_reader, object data_type):
        super(ArrayFieldReader, self).__init__(record_reader)
        self._element_reader = _build_field_reader(record_reader, data_type.value_type)

    cdef _read_raw(self):
        cdef:
            uint32_t idx, size
            list res

        size = self._record_reader._reader.read_uint32()
        res = [None] * size
        for idx in range(size):
            # a true flag marks a null element, which stays None
            if not self._record_reader._reader.read_bool():
                res[idx] = self._element_reader._read_raw()
        return res


cdef class MapFieldReader(AbstractFieldReader):
    """Decode a map encoded as a key array followed by a value array."""

    need_validate = True

    cdef AbstractFieldReader _keys_reader, _values_reader
    cdef bint _use_ordered_dict

    def __init__(self, BaseTunnelRecordReader record_reader, object data_type):
        super(MapFieldReader, self).__init__(record_reader)
        self._keys_reader = ArrayFieldReader(
            record_reader, types.Array(data_type.key_type)
        )
        self._values_reader = ArrayFieldReader(
            record_reader, types.Array(data_type.value_type)
        )
        self._use_ordered_dict = data_type._use_ordered_dict

    cdef _read_raw(self):
        cdef list keys, values

        keys = self._keys_reader._read_raw()
        values = self._values_reader._read_raw()
        if self._use_ordered_dict:
            return OrderedDict(zip(keys, values))
        else:
            return dict(zip(keys, values))


cdef class StructFieldReader(AbstractFieldReader):
    """
    Decode a struct: for each field in declaration order, a null flag and
    (when not null) the field value. Returns a dict / OrderedDict or the
    type's namedtuple depending on the type's ``_struct_as_dict`` setting.
    """

    need_validate = True

    cdef:
        bint _struct_as_dict
        bint _use_ordered_dict
        list _field_readers
        list _field_keys
        list _field_types
        object _nt_type

    def __init__(self, BaseTunnelRecordReader record_reader, object data_type):
        cdef int idx, field_count

        super(StructFieldReader, self).__init__(record_reader)

        self._struct_as_dict = data_type._struct_as_dict
        self._use_ordered_dict = data_type._use_ordered_dict
        self._nt_type = data_type.namedtuple_type

        field_count = len(data_type.field_types)
        self._field_keys = [None] * field_count
        self._field_types = [None] * field_count
        self._field_readers = [None] * field_count
        for idx, (field_key, field_type) in enumerate(data_type.field_types.items()):
            self._field_keys[idx] = field_key
            self._field_types[idx] = field_type
            self._field_readers[idx] = _build_field_reader(record_reader, field_type)

    cdef _read_raw(self):
        cdef:
            int idx
            int n_fields = len(self._field_types)
            list res_list = [None] * n_fields
            dict_hook = OrderedDict if self._use_ordered_dict else dict

        for idx in range(n_fields):
            # a true flag marks a null field, which stays None
            if not self._record_reader._reader.read_bool():
                res_list[idx] = (<AbstractFieldReader>self._field_readers[idx])._read_raw()

        if self._struct_as_dict:
            return dict_hook(zip(self._field_keys, res_list))
        else:
            return self._nt_type(*res_list)


# Legacy decimal layout constants: values are stored as int32 words holding
# 9 decimal digits each, DECIMAL_FRAC_CNT fractional words followed by
# DECIMAL_INTG_CNT integer words.
cdef int DECIMAL_FRAC_CNT = 2
cdef int DECIMAL_INTG_CNT = 4
cdef int DECIMAL_PREC_CNT = DECIMAL_INTG_CNT + DECIMAL_FRAC_CNT
cdef int DECIMAL_DIG_NUMS = 9
cdef int DECIMAL_FRAC_DIGS = DECIMAL_DIG_NUMS * DECIMAL_FRAC_CNT
cdef int DECIMAL_INTG_DIGS = DECIMAL_DIG_NUMS * DECIMAL_INTG_CNT
cdef int DECIMAL_PREC_DIGS = DECIMAL_DIG_NUMS * DECIMAL_PREC_CNT


@cython.cdivision(True)
cdef inline int32_t decimal_print_dig(
    char* buf, const int32_t* val, int count, bint tail = False
) nogil:
    # Write ``count`` 9-digit words of ``val`` into ``buf`` right to left,
    # ``buf`` pointing at the rightmost digit of word 0; only non-zero digits
    # are written, so ``buf`` must be pre-filled with '0'.
    #
    # Without ``tail`` the result is the number of digits from the leftmost
    # non-zero digit up to ``buf`` (integer part). With ``tail`` it is the
    # number of digits from the start of the region up to the rightmost
    # non-zero digit (fractional part, trailing zeros trimmed).
    cdef char* src = buf - count * DECIMAL_DIG_NUMS if tail else buf + 1
    cdef char* ret = src
    cdef char* ptr
    cdef int32_t i, data, r

    for i in range(count):
        ptr = buf
        data = val[i]
        while data != 0:
            r = data // 10
            ptr[0] = data - r * 10 + ord("0")
            # without tail keep moving left to the leftmost non-zero digit;
            # with tail keep the first (i.e. rightmost) non-zero digit seen,
            # detected by ``ret`` still pointing at a '0' filler
            if ptr[0] != ord("0") and (not tail or ret[0] == ord("0")):
                ret = ptr
            data = r
            ptr -= 1
        buf -= DECIMAL_DIG_NUMS
    return src - ret if src >= ret else ret - src


cpdef convert_legacy_decimal_bytes(bytes value, int32_t frac = 0):
    """
    Convert a legacy decimal struct into ``decimal.Decimal``.

    Legacy decimal memory layout:
        int8_t  mNull;
        int8_t  mSign;
        int8_t  mIntg;
        int8_t  mFrac; only 0, 1, 2
        int32_t mData[6];
        int8_t  mPadding[4]; //For Memory Align

    :param value: raw struct bytes, or None
    :param frac: minimum number of fractional digits to keep (at most 18)
    :return: ``decimal.Decimal``, or None when ``value`` is None or the null
        flag is set
    :raises ValueError: when ``value`` is too short for the layout above or
        declares more integer words than the layout holds
    """
    if value is None:
        return None
    if len(value) < 4:
        raise ValueError(
            "Legacy decimal value too short: %d bytes" % len(value)
        )

    cdef const char *src_ptr = <const char *>value
    cdef bint is_null = src_ptr[0]
    cdef bint sign = src_ptr[1]
    cdef int mintg = src_ptr[2]
    cdef int mfrac = src_ptr[3]
    cdef const char *data = src_ptr + 4

    cdef int32_t dec_cnt
    # layout of buf: [sign/pad][36 integer digits]['.'][18 fraction digits]...
    cdef char buf[9 * (2 + 4) + 4]
    cdef char *buf_ptr = buf
    memset(buf_ptr, ord("0"), sizeof(buf))

    if is_null:  # pragma: no cover
        return None

    if mintg + mfrac == 0:  # IsZero
        buf[1] = ord(".")
        dec_cnt = 20  # "0.000000000000000000"
        if frac > 0:
            dec_cnt = dec_cnt if frac + 2 > dec_cnt else frac + 2
        else:
            dec_cnt = 1
        return decimal.Decimal(buf[0:dec_cnt].decode())

    # Guard the digit buffer and the source words: more than
    # DECIMAL_INTG_CNT integer words would make decimal_print_dig write
    # before ``buf``, and a short value would be read past its end.
    if mintg > DECIMAL_INTG_CNT:
        raise ValueError(
            "Invalid legacy decimal: %d integer words (at most %d)"
            % (mintg, DECIMAL_INTG_CNT)
        )
    if len(value) < 4 + DECIMAL_PREC_CNT * 4:
        raise ValueError(
            "Legacy decimal value too short: %d bytes" % len(value)
        )

    cdef int32_t icnt = decimal_print_dig(
        buf_ptr + DECIMAL_INTG_DIGS, <const int32_t *>data + DECIMAL_FRAC_CNT, mintg
    )
    cdef char *start = (
        buf_ptr + DECIMAL_INTG_DIGS + 1 - icnt
        if icnt > 0
        else buf_ptr + DECIMAL_INTG_DIGS
    )
    if sign:
        start -= 1
        start[0] = ord("-")

    cdef int32_t fcnt = decimal_print_dig(
        buf_ptr + DECIMAL_PREC_DIGS + 1, <const int32_t *>data, DECIMAL_FRAC_CNT, True
    )
    if frac <= DECIMAL_FRAC_DIGS:
        frac = frac if frac > 0 else 0
    else:
        frac = DECIMAL_FRAC_DIGS
    fcnt = max(fcnt, frac)

    buf[DECIMAL_INTG_DIGS + 1] = ord(".")
    # integer digits (with sign) + optional '.' and fractional digits
    dec_cnt = buf_ptr + DECIMAL_INTG_DIGS + 1 - start + (fcnt + 1 if fcnt > 0 else 0)
    return decimal.Decimal(start[0:dec_cnt].decode())