odps/tunnel/io/stream.py (508 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import threading
import zlib
from ... import compat, errors, options
from ...compat import BytesIO, Enum, Semaphore, six
from ...lib.monotonic import monotonic
from ..errors import TunnelError
try:
from urllib3.exceptions import ReadTimeoutError
except ImportError:
from requests import ReadTimeout as ReadTimeoutError
MICRO_SEC_PER_SEC = 1000000
# used for test case to force thread io
_FORCE_THREAD = False
if compat.LESS_PY32:
mv_to_bytes = lambda v: bytes(bytearray(v))
else:
mv_to_bytes = bytes
if compat.six.PY3:
def cast_memoryview(v):
if not isinstance(v, memoryview):
v = memoryview(v)
return v.cast("B")
else:
def cast_memoryview(v):
if not isinstance(v, memoryview):
v = memoryview(v)
return v
class RequestsIO(object):
CHUNK_SIZE = 256 * 1024
def __new__(cls, *args, **kwargs):
if cls is RequestsIO:
if (
not isinstance(threading.current_thread(), threading._MainThread)
or _FORCE_THREAD
):
return object.__new__(ThreadRequestsIO)
elif GreenletRequestsIO is not None:
return object.__new__(GreenletRequestsIO)
else:
return object.__new__(ThreadRequestsIO)
else:
return object.__new__(cls)
def __init__(self, post_call, chunk_size=None, record_io_time=False):
self._buf = BytesIO()
self._resp = None
self._async_err = None
self._chunk_size = chunk_size or self.CHUNK_SIZE
self._post_call = post_call
self._wait_obj = None
self._record_io_time = record_io_time
self._io_time_ms = 0
self._io_start_time = 0
self._io_end_time = 0
def _async_func(self):
try:
if self._record_io_time:
self._io_start_time = monotonic()
self._resp = self._post_call(self.data_generator())
except:
self._async_err = sys.exc_info()
self._wait_obj = None
def _reraise_errors(self):
if self._async_err is not None:
ex_type, ex_value, tb = self._async_err
six.reraise(ex_type, ex_value, tb)
@property
def io_time_ms(self):
return self._io_time_ms
def data_generator(self):
if self._record_io_time:
self._io_time_ms += int(
MICRO_SEC_PER_SEC * (monotonic() - self._io_start_time)
)
chunk_size = self._chunk_size
while True:
data = self.get()
if data is not None:
data = memoryview(data)
while data:
to_send = mv_to_bytes(data[:chunk_size])
data = data[chunk_size:]
if self._record_io_time:
ts = monotonic()
yield to_send
if self._record_io_time:
self._io_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
else:
break
if self._record_io_time:
self._io_end_time = monotonic()
def start(self):
pass
def get(self):
raise NotImplementedError
def put(self, data):
raise NotImplementedError
def write(self, data):
self._buf.write(data)
if self._buf.tell() >= self._chunk_size:
chunk = self._buf.getvalue()
self._buf = BytesIO()
self.put(chunk)
self._reraise_errors()
def flush(self):
if self._buf.tell():
chunk = self._buf.getvalue()
self._buf = BytesIO()
self.put(chunk)
self._reraise_errors()
def finish(self):
self.flush()
self.put(None)
wait_obj = self._wait_obj
if wait_obj and wait_obj.is_alive():
wait_obj.join()
if self._record_io_time:
self._io_time_ms += int(
MICRO_SEC_PER_SEC * (monotonic() - self._io_end_time)
)
self._reraise_errors()
return self._resp
class ThreadRequestsIO(RequestsIO):
def __init__(self, post_call, chunk_size=None, record_io_time=False):
super(ThreadRequestsIO, self).__init__(
post_call, chunk_size, record_io_time=record_io_time
)
self._last_data = None
self._sem_put = Semaphore(1)
self._sem_get = Semaphore(0)
self._wait_obj = threading.Thread(target=self._async_func)
self._wait_obj.daemon = True
self._acquire_timeout = options.connect_timeout
def _async_func(self):
try:
super(ThreadRequestsIO, self)._async_func()
finally:
# make sure subsequent put() call does not get stuck
self._sem_put.release()
def start(self):
self._wait_obj.start()
def get(self):
self._sem_get.acquire()
data = self._last_data
self._sem_put.release()
return data
def put(self, data):
self._reraise_errors()
assert self._wait_obj is not None and self._wait_obj.is_alive()
try:
rc = self._sem_put.acquire(timeout=self._acquire_timeout)
if not rc:
raise TimeoutError("Wait for data semaphore timed out")
self._reraise_errors()
self._last_data = data
except:
self._last_data = None
raise
finally:
self._sem_get.release()
try:
from greenlet import greenlet
class GreenletRequestsIO(RequestsIO):
def __init__(self, post_call, chunk_size=None, record_io_time=False):
super(GreenletRequestsIO, self).__init__(
post_call, chunk_size, record_io_time=record_io_time
)
self._cur_greenlet = greenlet.getcurrent()
self._writer_greenlet = greenlet(self._async_func)
self._last_data = None
self._writer_greenlet.switch()
def get(self):
self._cur_greenlet.switch()
return self._last_data
def put(self, data):
self._last_data = data
# handover control
self._writer_greenlet.switch()
except ImportError:
GreenletRequestsIO = None
class CompressOption(object):
class CompressAlgorithm(Enum):
ODPS_RAW = "RAW"
ODPS_ZLIB = "ZLIB"
ODPS_SNAPPY = "SNAPPY"
ODPS_ZSTD = "ZSTD"
ODPS_LZ4 = "LZ4"
ODPS_ARROW_LZ4 = "ARROW_LZ4"
def get_encoding(self, legacy=True):
cls = type(self)
if legacy:
if self == cls.ODPS_RAW:
return None
elif self == cls.ODPS_ZLIB:
return "deflate"
elif self == cls.ODPS_ZSTD:
return "zstd"
elif self == cls.ODPS_LZ4:
return "x-lz4-frame"
elif self == cls.ODPS_SNAPPY:
return "x-snappy-framed"
elif self == cls.ODPS_ARROW_LZ4:
return "x-odps-lz4-frame"
else:
raise TunnelError("invalid compression option")
else:
if self == cls.ODPS_RAW:
return None
elif self == cls.ODPS_ZSTD:
return "ZSTD"
elif self == cls.ODPS_LZ4 or self == cls.ODPS_ARROW_LZ4:
return "LZ4_FRAME"
else:
raise TunnelError("invalid compression option")
@classmethod
def from_encoding(cls, encoding):
encoding = encoding.lower() if encoding else None
if encoding is None or encoding == "identity":
return cls.ODPS_RAW
elif encoding == "deflate":
return cls.ODPS_ZLIB
elif encoding == "zstd":
return cls.ODPS_ZSTD
elif encoding == "x-lz4-frame":
return cls.ODPS_LZ4
elif encoding == "x-snappy-framed":
return cls.ODPS_SNAPPY
elif encoding == "x-odps-lz4-frame" or encoding == "lz4_frame":
return cls.ODPS_ARROW_LZ4
else:
raise TunnelError("invalid encoding name %s" % encoding)
def __init__(
self, compress_algo=CompressAlgorithm.ODPS_ZLIB, level=None, strategy=None
):
compress_algo = compress_algo or self.CompressAlgorithm.ODPS_ZLIB
if isinstance(compress_algo, CompressOption.CompressAlgorithm):
self.algorithm = compress_algo
else:
self.algorithm = CompressOption.CompressAlgorithm(compress_algo.upper())
self.level = level or 1
self.strategy = strategy or 0
_lz4_algorithms = (
CompressOption.CompressAlgorithm.ODPS_LZ4,
CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4,
)
def get_compress_stream(buffer, compress_option=None):
algo = getattr(compress_option, "algorithm", None)
if algo is None or algo == CompressOption.CompressAlgorithm.ODPS_RAW:
return buffer
elif algo == CompressOption.CompressAlgorithm.ODPS_ZLIB:
return DeflateOutputStream(buffer, level=compress_option.level)
elif algo == CompressOption.CompressAlgorithm.ODPS_ZSTD:
return ZstdOutputStream(buffer, level=compress_option.level)
elif algo == CompressOption.CompressAlgorithm.ODPS_SNAPPY:
return SnappyOutputStream(buffer, level=compress_option.level)
elif algo in _lz4_algorithms:
return LZ4OutputStream(buffer, level=compress_option.level)
else:
raise errors.InvalidArgument("Invalid compression algorithm %s." % algo)
def get_decompress_stream(resp, compress_option=None, requests=True):
algo = getattr(compress_option, "algorithm", None)
if algo is None or algo == CompressOption.CompressAlgorithm.ODPS_RAW:
stream_cls = RequestsInputStream # create a file-like object from body
elif algo == CompressOption.CompressAlgorithm.ODPS_ZLIB:
stream_cls = DeflateRequestsInputStream
elif algo == CompressOption.CompressAlgorithm.ODPS_ZSTD:
stream_cls = ZstdRequestsInputStream
elif algo == CompressOption.CompressAlgorithm.ODPS_SNAPPY:
stream_cls = SnappyRequestsInputStream
elif algo in _lz4_algorithms:
stream_cls = LZ4RequestsInputStream
else:
raise errors.InvalidArgument("Invalid compression algorithm %s." % algo)
if not requests:
stream_cls = stream_cls.get_raw_input_stream_class()
return stream_cls(resp)
class CompressOutputStream(object):
def __init__(self, output, level=1):
self._compressor = self._get_compressor(level=level)
self._output = output
def _get_compressor(self, level=1):
raise NotImplementedError
def write(self, data):
if self._compressor:
compressed_data = self._compressor.compress(data)
if compressed_data:
self._output.write(compressed_data)
else:
pass # buffering
else:
self._output.write(data)
def flush(self):
if self._compressor:
remaining = self._compressor.flush()
if remaining:
self._output.write(remaining)
class DeflateOutputStream(CompressOutputStream):
def _get_compressor(self, level=1):
return zlib.compressobj(level)
class SnappyOutputStream(CompressOutputStream):
def _get_compressor(self, level=1):
try:
import snappy
except ImportError:
raise errors.DependencyNotInstalledError(
"python-snappy library is required for snappy support"
)
return snappy.StreamCompressor()
class ZstdOutputStream(CompressOutputStream):
def _get_compressor(self, level=1):
try:
import zstandard
except ImportError:
raise errors.DependencyNotInstalledError(
"zstandard library is required for zstd support"
)
return zstandard.ZstdCompressor().compressobj()
class LZ4OutputStream(CompressOutputStream):
def _get_compressor(self, level=1):
try:
import lz4.frame
except ImportError:
raise errors.DependencyNotInstalledError(
"lz4 library is required for lz4 support"
)
self._begun = False
return lz4.frame.LZ4FrameCompressor(compression_level=level)
def write(self, data):
if not self._begun:
self._output.write(self._compressor.begin())
self._begun = True
super(LZ4OutputStream, self).write(data)
class SimpleInputStream(object):
READ_BLOCK_SIZE = 1024 * 64
def __init__(self, input):
self._input = input
self._internal_buffer = memoryview(b"")
self._buffered_len = 0
self._buffered_pos = 0
self._pos = 0
self._closed = False
@staticmethod
def readable():
return True
def __len__(self):
return self._pos
def read(self, limit):
if self._closed:
raise IOError("closed")
if limit <= self._buffered_len - self._buffered_pos:
mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
self._buffered_pos += len(mv)
self._pos += len(mv)
return mv_to_bytes(mv)
bufs = list()
size_left = limit
while size_left > 0:
content = self._internal_read(size_left)
if not content:
break
bufs.append(content)
size_left -= len(content)
ret = bytes().join(bufs)
self._pos += len(ret)
return ret
def peek(self):
if self._buffered_pos == self._buffered_len:
self._refill_buffer()
if self._buffered_pos == self._buffered_len:
# still nothing can be read
return None
return self._internal_buffer[self._buffered_pos]
def readinto(self, b):
if self._closed:
raise IOError("closed")
b = cast_memoryview(b)
limit = len(b)
if limit <= self._buffered_len - self._buffered_pos:
mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
self._buffered_pos += len(mv)
b[:limit] = mv
self._pos += len(mv)
return len(mv)
pos = 0
while pos < limit:
rsize = self._internal_readinto(b, pos)
if not rsize:
break
pos += rsize
self._pos += pos
return pos
def _internal_read(self, limit):
if self._buffered_pos == self._buffered_len:
self._refill_buffer()
mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + limit]
self._buffered_pos += len(mv)
return mv_to_bytes(mv)
def _internal_readinto(self, b, start):
if self._buffered_pos == self._buffered_len:
self._refill_buffer()
size = len(b) - start
mv = self._internal_buffer[self._buffered_pos : self._buffered_pos + size]
size = len(mv)
self._buffered_pos += size
b[start : start + size] = mv
return size
def _refill_buffer(self):
self._buffered_pos = 0
self._buffered_len = 0
buffer = []
while True:
content = self._buffer_next_chunk()
if content is None:
break
if content:
length = len(content)
self._buffered_len += length
buffer.append(content)
break
if len(buffer) == 1:
self._internal_buffer = memoryview(buffer[0])
else:
self._internal_buffer = memoryview(bytes().join(buffer))
def _read_block(self):
content = self._input.read(self.READ_BLOCK_SIZE)
return content if content else None
def _buffer_next_chunk(self):
return self._read_block()
@property
def closed(self):
return self._closed
def close(self):
self._closed = True
class DecompressInputStream(SimpleInputStream):
def __init__(self, input):
super(DecompressInputStream, self).__init__(input)
self._decompressor = self._get_decompressor()
def _get_decompressor(self):
raise NotImplementedError
def _buffer_next_chunk(self):
data = self._read_block()
if data is None:
return None
if data:
return self._decompressor.decompress(data)
else:
return self._decompressor.flush()
class DeflateInputStream(DecompressInputStream):
def _get_decompressor(self):
return zlib.decompressobj(zlib.MAX_WBITS)
class SnappyInputStream(DecompressInputStream):
def _get_decompressor(self):
try:
import snappy
except ImportError:
raise errors.DependencyNotInstalledError(
"python-snappy library is required for snappy support"
)
return snappy.StreamDecompressor()
class ZstdInputStream(DecompressInputStream):
def _get_decompressor(self):
try:
import zstandard
except ImportError:
raise errors.DependencyNotInstalledError(
"zstandard library is required for zstd support"
)
return zstandard.ZstdDecompressor().decompressobj()
class LZ4InputStream(DecompressInputStream):
def _get_decompressor(self):
try:
import lz4.frame
except ImportError:
raise errors.DependencyNotInstalledError(
"lz4 library is required for lz4 support"
)
return lz4.frame.LZ4FrameDecompressor()
class RawRequestsStreamMixin(object):
_decode_content = False
@classmethod
def get_raw_input_stream_class(cls):
for base in cls.__mro__:
if (
base is not cls
and base is not RawRequestsStreamMixin
and issubclass(cls, SimpleInputStream)
):
return base
return None
def _read_block(self):
try:
content = self._input.raw.read(
self.READ_BLOCK_SIZE, decode_content=self._decode_content
)
return content if content else None
except ReadTimeoutError:
if callable(options.tunnel_read_timeout_callback):
options.tunnel_read_timeout_callback(*sys.exc_info())
raise
class RequestsInputStream(RawRequestsStreamMixin, SimpleInputStream):
_decode_content = True
# Requests automatically decompress gzip data!
class DeflateRequestsInputStream(RawRequestsStreamMixin, SimpleInputStream):
_decode_content = True
@classmethod
def get_raw_input_stream_class(cls):
return DeflateInputStream
class SnappyRequestsInputStream(RawRequestsStreamMixin, SnappyInputStream):
pass
class ZstdRequestsInputStream(RawRequestsStreamMixin, ZstdInputStream):
pass
class LZ4RequestsInputStream(RawRequestsStreamMixin, LZ4InputStream):
pass