odps/apis/storage_api/storage_api.py

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Client for interacting with the MaxCompute Storage API."""

import collections
import json
import logging
from enum import Enum
from hashlib import md5
from io import BytesIO, IOBase
from typing import List, Union

try:
    import pyarrow as pa
except ImportError:
    pa = None
from requests import codes

from ... import ODPS, options, serializers
from ...models import Table
from ...models.core import JSONRemoteModel
from ...tunnel.io import RequestsIO
from ...utils import to_binary

STORAGE_VERSION = "1"
URL_PREFIX = "/api/storage/v" + STORAGE_VERSION

logger = logging.getLogger(__name__)


class Status(Enum):
    INIT = "INIT"
    OK = "OK"
    WAIT = "WAIT"
    RUNNING = "RUNNING"


class SessionStatus(Enum):
    INIT = "INIT"
    NORMAL = "NORMAL"
    CRITICAL = "CRITICAL"
    EXPIRED = "EXPIRED"
    COMMITTING = "COMMITTING"
    COMMITTED = "COMMITTED"


class SplitOptions(JSONRemoteModel):
    class SplitMode(str, Enum):
        SIZE = "Size"
        PARALLELISM = "Parallelism"
        ROW_OFFSET = "RowOffset"
        BUCKET = "Bucket"

    split_mode = serializers.JSONNodeField(
        "SplitMode", parse_callback=lambda s: SplitOptions.SplitMode(s)
    )
    split_number = serializers.JSONNodeField("SplitNumber")
    cross_partition = serializers.JSONNodeField("CrossPartition")

    def __init__(self, **kwargs):
        super(SplitOptions, self).__init__(**kwargs)
        self.split_mode = self.split_mode or SplitOptions.SplitMode.SIZE
        self.split_number = self.split_number or 256 * 1024 * 1024
        self.cross_partition = (
            self.cross_partition if self.cross_partition is not None else True
        )

    @classmethod
    def get_default_options(cls, mode):
        options = SplitOptions()
        options.cross_partition = True
        if mode == SplitOptions.SplitMode.SIZE:
            options.split_mode = SplitOptions.SplitMode.SIZE
            options.split_number = 256 * 1024 * 1024
        elif mode == SplitOptions.SplitMode.PARALLELISM:
            options.split_mode = SplitOptions.SplitMode.PARALLELISM
            options.split_number = 32
        elif mode == SplitOptions.SplitMode.ROW_OFFSET:
            options.split_mode = SplitOptions.SplitMode.ROW_OFFSET
            options.split_number = 0
        elif mode == SplitOptions.SplitMode.BUCKET:
            options.split_mode = SplitOptions.SplitMode.BUCKET
        return options


class ArrowOptions(JSONRemoteModel):
    class TimestampUnit(str, Enum):
        SECOND = "second"
        MILLI = "milli"
        MICRO = "micro"
        NANO = "nano"

    timestamp_unit = serializers.JSONNodeField(
        "TimestampUnit", parse_callback=lambda s: ArrowOptions.TimestampUnit(s)
    )
    date_time_unit = serializers.JSONNodeField(
        "DatetimeUnit", parse_callback=lambda s: ArrowOptions.TimestampUnit(s)
    )

    def __init__(self, **kwargs):
        super(ArrowOptions, self).__init__(**kwargs)
        self.timestamp_unit = self.timestamp_unit or ArrowOptions.TimestampUnit.NANO
        self.date_time_unit = self.date_time_unit or ArrowOptions.TimestampUnit.MILLI


class Column(JSONRemoteModel):
    name = serializers.JSONNodeField("Name")
    type = serializers.JSONNodeField("Type")
    comment = serializers.JSONNodeField("Comment")
    nullable = serializers.JSONNodeField("Nullable")


class DataSchema(JSONRemoteModel):
    data_columns = serializers.JSONNodesReferencesField(Column, "DataColumns")
    partition_columns = serializers.JSONNodesReferencesField(
        Column, "PartitionColumns"
    )


class DataFormat(JSONRemoteModel):
    type = serializers.JSONNodeField("Type")
    version = serializers.JSONNodeField("Version")


class DynamicPartitionOptions(JSONRemoteModel):
    invalid_strategy = serializers.JSONNodeField("InvalidStrategy")
    invalid_limit = serializers.JSONNodeField("InvalidLimit")
    dynamic_partition_limit = serializers.JSONNodeField("DynamicPartitionLimit")

    def __init__(self, **kwargs):
        super(DynamicPartitionOptions, self).__init__(**kwargs)
        self.invalid_strategy = self.invalid_strategy or "Exception"
        self.invalid_limit = self.invalid_limit or 1
        self.dynamic_partition_limit = self.dynamic_partition_limit or 512


class Order(JSONRemoteModel):
    name = serializers.JSONNodeField("Name")
    sort_direction = serializers.JSONNodeField("SortDirection")


class RequiredDistribution(JSONRemoteModel):
    type = serializers.JSONNodeField("Type")
    cluster_keys = serializers.JSONNodeField("ClusterKeys")
    buckets_number = serializers.JSONNodeField("BucketsNumber")


class Compression(Enum):
    UNCOMPRESSED = 0
    ZSTD = 1
    LZ4_FRAME = 2

    def to_string(self):
        if self.value == 0:
            return None
        elif self.value == 1:
            return "zstd"
        elif self.value == 2:
            return "lz4"
        else:
            return "unknown"


class TableBatchScanRequest(serializers.JSONSerializableModel):
    required_data_columns = serializers.JSONNodeField("RequiredDataColumns")
    required_partition_columns = serializers.JSONNodeField("RequiredPartitionColumns")
    required_partitions = serializers.JSONNodeField("RequiredPartitions")
    required_bucket_ids = serializers.JSONNodeField("RequiredBucketIds")
    split_options = serializers.JSONNodeReferenceField(SplitOptions, "SplitOptions")
    arrow_options = serializers.JSONNodeReferenceField(ArrowOptions, "ArrowOptions")
    filter_predicate = serializers.JSONNodeField("FilterPredicate")

    def __init__(self, **kwargs):
        super(TableBatchScanRequest, self).__init__(**kwargs)
        self.required_data_columns = self.required_data_columns or []
        self.required_partition_columns = self.required_partition_columns or []
        self.required_partitions = self.required_partitions or []
        self.required_bucket_ids = self.required_bucket_ids or []
        self.split_options = self.split_options or SplitOptions()
        self.arrow_options = self.arrow_options or ArrowOptions()
        self.filter_predicate = self.filter_predicate or ""


class TableBatchScanResponse(serializers.JSONSerializableModel):
    __slots__ = ["status", "request_id"]

    session_id = serializers.JSONNodeField("SessionId")
    session_type = serializers.JSONNodeField("SessionType")
    session_status = serializers.JSONNodeField(
        "SessionStatus", parse_callback=lambda s: SessionStatus(s.upper())
    )
    expiration_time = serializers.JSONNodeField("ExpirationTime")
    split_count = serializers.JSONNodeField("SplitsCount")
    record_count = serializers.JSONNodeField("RecordCount")
    data_schema = serializers.JSONNodeReferenceField(DataSchema, "DataSchema")
    supported_data_format = serializers.JSONNodesReferencesField(
        DataFormat, "SupportedDataFormat"
    )

    def __init__(self):
        super(TableBatchScanResponse, self).__init__()
        self.status = Status.INIT
        self.request_id = ""


class SessionRequest(object):
    def __init__(self, session_id, refresh=False):
        self.session_id = session_id
        self.refresh = refresh


class TableBatchWriteRequest(serializers.JSONSerializableModel):
    dynamic_partition_options = serializers.JSONNodeReferenceField(
        DynamicPartitionOptions, "DynamicPartitionOptions"
    )
    arrow_options = serializers.JSONNodeReferenceField(ArrowOptions, "ArrowOptions")
    overwrite = serializers.JSONNodeField("Overwrite")
    partition_spec = serializers.JSONNodeField("PartitionSpec")
    support_write_cluster = serializers.JSONNodeField("SupportWriteCluster")

    def __init__(self, **kwargs):
        super(TableBatchWriteRequest, self).__init__(**kwargs)
        self.partition_spec = self.partition_spec or ""
        self.arrow_options = self.arrow_options or ArrowOptions()
        self.dynamic_partition_options = (
            self.dynamic_partition_options or DynamicPartitionOptions()
        )
        self.overwrite = self.overwrite if self.overwrite is not None else True
        self.support_write_cluster = self.support_write_cluster or False


class TableBatchWriteResponse(serializers.JSONSerializableModel):
    __slots__ = ["status", "request_id"]

    session_status = serializers.JSONNodeField(
        "SessionStatus", parse_callback=lambda s: SessionStatus(s.upper())
    )
    expiration_time = serializers.JSONNodeField("ExpirationTime")
    session_id = serializers.JSONNodeField("SessionId")
    data_schema = serializers.JSONNodeReferenceField(DataSchema, "DataSchema")
    supported_data_format = serializers.JSONNodesReferencesField(
        DataFormat, "SupportedDataFormat"
    )
    max_block_num = serializers.JSONNodeField("MaxBlockNumber")
    required_ordering = serializers.JSONNodesReferencesField(Order, "RequiredOrdering")
    required_distribution = serializers.JSONNodeReferenceField(
        RequiredDistribution, "RequiredDistribution"
    )

    def __init__(self):
        super(TableBatchWriteResponse, self).__init__()
        self.status = Status.INIT
        self.request_id = ""


class ReadRowsRequest(object):
    def __init__(
        self,
        session_id,
        split_index=0,
        row_index=0,
        row_count=0,
        max_batch_rows=4096,
        compression=Compression.LZ4_FRAME,
        data_format=DataFormat(),
    ):
        self.session_id = session_id
        self.split_index = split_index
        self.row_index = row_index
        self.row_count = row_count
        self.max_batch_rows = max_batch_rows
        self.compression = compression
        self.data_format = data_format


class ReadRowsResponse(object):
    def __init__(self):
        self.status = Status.INIT
        self.request_id = ""


class WriteRowsRequest(object):
    def __init__(
        self,
        session_id,
        block_number=0,
        attempt_number=0,
        bucket_id=0,
        compression=Compression.LZ4_FRAME,
        data_format=DataFormat(),
    ):
        self.session_id = session_id
        self.block_number = block_number
        self.attempt_number = attempt_number
        self.bucket_id = bucket_id
        self.compression = compression
        self.data_format = data_format


class WriteRowsResponse(object):
    def __init__(self):
        self.status = Status.INIT
        self.request_id = ""
        self.commit_message = ""


def update_request_id(response, resp):
    if "x-odps-request-id" in resp.headers:
        response.request_id = resp.headers["x-odps-request-id"]


class StreamReader(IOBase):
    """Stream reader."""

    def __init__(self, download):
        self._stopped = False
        raw_reader = download()
        self._raw_reader = raw_reader
        # need to confirm read size
        self._chunk_size = 65536
        self._buffers = collections.deque()

    def readable(self):
        """Check whether this stream reader has been closed or not.

        Returns:
            Readable or not.
        """
        return not self._stopped

    def _read_chunk(self):
        buf = self._raw_reader.raw.read(self._chunk_size)
        return buf

    def _fill_next_buffer(self):
        data = self._read_chunk()
        if len(data) == 0:
            return
        self._buffers.append(BytesIO(data))

    def read(self, nbytes=None):
        """Read stream data from the server.

        Args:
            nbytes: The number of bytes to be read. All data will be read at once
                if set to None.

        Returns:
            Stream data. An empty byte string means all the data has been read
            or an error occurred.
        """
""" if self._stopped: return b"" total_size = 0 bufs = [] while nbytes is None or total_size < nbytes: if not self._buffers: self._fill_next_buffer() if not self._buffers: break to_read = nbytes - total_size if nbytes is not None else None buf = self._buffers[0].read(to_read) if not buf: self._buffers.popleft() else: bufs.append(buf) total_size += len(buf) return b"".join(bufs) def get_status(self): """Get the status of the stream reader. Returns: Status.OK or Status.RUNNING. """ if not self._stopped: return Status.RUNNING else: return Status.OK def get_request_id(self): """Get the request id. Returns: Request id. """ if not self._stopped: logger.error("The reader is not closed yet, please wait") return None if ( self._raw_reader is not None and "x-odps-request-id" in self._raw_reader.headers ): return self._raw_reader.headers["x-odps-request-id"] else: return None def close(self): """If there is no data can be read from server, it will be called to close the stream reader.""" self._stopped = True class StreamWriter(IOBase): """Stream writer.""" def __init__(self, upload): self._req_io = RequestsIO(upload, chunk_size=options.chunk_size) self._req_io.start() self._res = None self._stopped = False def writable(self): """Check whether this stream writer has been closed or not. Returns: Writable or not. """ return not self._stopped def write(self, data): """Write data to the server. Returns: Success or not. """ if self._stopped: return False self._req_io.write(data) return True def finish(self): """The stream writer is not expected to write data if finish has been called. Returns: Commit message returned from the server. User should bring this message to do the write session commit. Success or not. """ self._stopped = True self._res = self._req_io.finish() if self._res is not None and self._res.status_code == codes["ok"]: resp_json = self._res.json() return resp_json["CommitMessage"], True else: return None, False def get_status(self): """Get the status of this stream writer. Returns: Status.OK or Status.RUNNING. """ if not self._stopped: return Status.RUNNING else: return Status.OK def get_request_id(self): """Get the request id. Returns: Request id. """ if not self._stopped: logger.error("The writer is not closed yet, please close first") return None if self._res is not None and "x-odps-request-id" in self._res.headers: return self._res.headers["x-odps-request-id"] else: return None class ArrowReader(object): """Arrow batch reader.""" def __init__(self, stream_reader): if pa is None: raise ValueError("To use arrow reader you need to install pyarrow") self._reader = stream_reader self._arrow_stream = None def _read_next_batch(self): if self._arrow_stream is None: self._arrow_stream = pa.ipc.open_stream(self._reader) try: batch = self._arrow_stream.read_next_batch() return batch except StopIteration: return None def read(self): """Read arrow batch from the server. Returns: Arrow record batch. None means all the data has been read or there is error occurred. """ if not self._reader.readable(): logger.error("Reader has been closed") return None batch = self._read_next_batch() if batch is None: self._reader.close() return batch def get_status(self): """Get the status of the arrow batch reader. Returns: Status.OK or Status.RUNNING. """ return self._reader.get_status() def get_request_id(self): """Get the request id. Returns: Request id. 
""" return self._reader.get_request_id() class ArrowWriter(object): """Arrow batch writer.""" def __init__(self, stream_writer, compression): self._arrow_writer = None self._compression = compression self._sink = stream_writer def write(self, record_batch): """Write one arrow batch to the server. Args: record_batch: The arrow batch to be written. Returns: Success or not. """ if not self._sink.writable(): logger.error("Writer has been closed") return False if self._arrow_writer is None: self._arrow_writer = pa.ipc.new_stream( self._sink, record_batch.schema, options=pa.ipc.IpcWriteOptions( compression=self._compression.to_string() ), ) self._arrow_writer.write_batch(record_batch) if not self._sink.writable(): logger.error("Writer has been closed as exception occurred") return False return True def finish(self): """The arrow writer is not expected to write data if finish has been called. Returns: Commit message returned from the server. User should bring this message to do the write session commit. Success ot not. """ if self._arrow_writer: self._arrow_writer.close() return self._sink.finish() def get_status(self): """Get the status of the arrow batch writer. Returns: Status.OK or Status.RUNNING. """ return self._sink.get_status() def get_request_id(self): """Get the request id. Returns: Request id. """ return self._sink.get_request_id() class StorageApiClient(object): """Client to bundle configuration needed for API requests.""" def __init__( self, odps: ODPS, table: Table, rest_endpoint: str = None, quota_name: str = None, tags: Union[None, str, List[str]] = None, ): if not isinstance(odps, ODPS) or not isinstance(table, Table): raise ValueError("Please input odps configuration") self._odps = odps self._table = table self._quota_name = quota_name self._rest_endpoint = rest_endpoint self._tunnel_rest = None self._tags = tags or options.tunnel.tags if isinstance(self._tags, str): self._tags = self._tags.split(",") @property def table(self): return self._table @property def tunnel_rest(self): if self._tunnel_rest is not None: return self._tunnel_rest from ...tunnel.tabletunnel import TableTunnel tunnel = TableTunnel( self._odps, endpoint=self._rest_endpoint, quota_name=self._quota_name ) self._tunnel_rest = tunnel.tunnel_rest return self._tunnel_rest def _get_resource(self, *args) -> str: endpoint = self.tunnel_rest.endpoint + URL_PREFIX url = self._table.table_resource(endpoint=endpoint, force_schema=True) return "/".join([url] + list(args)) def _fill_common_headers(self, raw_headers=None): headers = raw_headers or {} if self._tags: headers["odps-tunnel-tags"] = ",".join(self._tags) return headers def create_read_session( self, request: TableBatchScanRequest ) -> TableBatchScanResponse: """Create a read session. Args: request: Table split parameters sent to the server. Returns: Read session response returned from the server. 
""" if not isinstance(request, TableBatchScanRequest): raise ValueError( "Use TableBatchScanRequest class to build request for create read session interface" ) json_str = request.serialize() url = self._get_resource("sessions") headers = self._fill_common_headers({"Content-Type": "application/json"}) if json_str != "": headers["Content-MD5"] = md5(to_binary(json_str)).hexdigest() params = {"session_type": "batch_read"} if self._quota_name: params["quotaName"] = self._quota_name res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers) response = TableBatchScanResponse() response.parse(res, obj=response) response.status = ( Status.OK if res.status_code == codes["created"] else Status.WAIT ) update_request_id(response, res) return response def get_read_session(self, request: SessionRequest) -> TableBatchScanResponse: """Get the read session. Args: request: Read session parameters sent to the server. Returns: Read session response returned from the server. """ if not isinstance(request, SessionRequest): raise ValueError( "Use SessionRequest class to build request for get read session interface" ) url = self._get_resource("sessions", request.session_id) headers = self._fill_common_headers() params = {"session_type": "batch_read"} if self._quota_name: params["quotaName"] = self._quota_name if request.refresh: params["session_refresh"] = "true" res = self.tunnel_rest.get(url, params=params, headers=headers) response = TableBatchScanResponse() response.parse(res, obj=response) response.status = Status.OK update_request_id(response, res) return response def read_rows_stream(self, request: ReadRowsRequest) -> StreamReader: """Read one split of the read session. Stream means the data read from server is serialized arrow record batch. Args: request: Batch split parameters sent to the server. Returns: Stream reader. """ if not isinstance(request, ReadRowsRequest): raise ValueError( "Use ReadRowsRequest class to build request for read rows interface" ) url = self._get_resource("data") headers = self._fill_common_headers( { "Connection": "Keep-Alive", "Accept-Encoding": request.compression.name if request.compression != Compression.UNCOMPRESSED else "", } ) params = { "session_id": request.session_id, "max_batch_rows": str(request.max_batch_rows), "split_index": str(request.split_index), "row_count": str(request.row_count), "row_index": str(request.row_index), } if self._quota_name: params["quotaName"] = self._quota_name if request.data_format.type is not None: params["data_format_type"] = request.data_format.type if request.data_format.version is not None: params["data_format_version"] = request.data_format.version def download(): return self.tunnel_rest.get( url, stream=True, params=params, headers=headers ) return StreamReader(download) def create_write_session( self, request: TableBatchWriteRequest ) -> TableBatchWriteResponse: """Create a write session. Args: request: Table write parameters sent to the server. Returns: Write session response returned from the server. 
""" if not isinstance(request, TableBatchWriteRequest): raise ValueError( "Use TableBatchWriteRequest class to build request for create write session interface" ) json_str = request.serialize() url = self._get_resource("sessions") headers = self._fill_common_headers({"Content-Type": "application/json"}) if json_str != "": headers["Content-MD5"] = md5(to_binary(json_str)).hexdigest() params = {"session_type": "batch_write"} if self._quota_name: params["quotaName"] = self._quota_name res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers) response = TableBatchWriteResponse() response.parse(res, obj=response) response.status = Status.OK update_request_id(response, res) return response def get_write_session(self, request: SessionRequest) -> TableBatchWriteResponse: """Get a write session. Args: request: Write session parameters sent to the server. Returns: Write session response returned from the server. """ if not isinstance(request, SessionRequest): raise ValueError( "Use SessionRequest class to build request for get write session interface" ) url = self._get_resource("sessions", request.session_id) headers = self._fill_common_headers() params = {"session_type": "batch_write"} if self._quota_name: params["quotaName"] = self._quota_name res = self.tunnel_rest.get(url, params=params, headers=headers) response = TableBatchWriteResponse() response.parse(res, obj=response) response.status = Status.OK update_request_id(response, res) return response def write_rows_stream(self, request: WriteRowsRequest) -> StreamWriter: """Write one block of data to the write session. Stream means the data written to server is serialized arrow record batch. Args: request: Batch write parameters sent to the server. Returns: Stream writer. """ if not isinstance(request, WriteRowsRequest): raise ValueError( "Use WriteRowsRequest class to build request for write rows interface" ) url = self._get_resource("sessions", request.session_id, "data") headers = self._fill_common_headers( {"Content-Type": "application/octet-stream", "Transfer-Encoding": "chunked"} ) params = { "attempt_number": str(request.attempt_number), "block_number": str(request.block_number), } if self._quota_name: params["quotaName"] = self._quota_name if request.data_format.type != None: params["data_format_type"] = str(request.data_format.type) if request.data_format.version != None: params["data_format_version"] = str(request.data_format.version) def upload(data): return self.tunnel_rest.post(url, data=data, params=params, headers=headers) return StreamWriter(upload) def commit_write_session( self, request: SessionRequest, commit_msg: list ) -> TableBatchWriteResponse: """Commit the write session after write the last stream data. Args: request: Commit write session parameters sent to the server. commit_msg: Commit messages collected from the write_rows_stream(). Returns: Write session response returned from the server. 
""" if not isinstance(request, SessionRequest): raise ValueError( "Use SessionRequest class to build request for commit write session interface" ) if not isinstance(commit_msg, list): raise ValueError("Use list for commit message") commit_message_dict = {"CommitMessages": commit_msg} json_str = json.dumps(commit_message_dict) url = self._get_resource("commit") headers = self._fill_common_headers({"Content-Type": "application/json"}) params = {"session_id": request.session_id} if self._quota_name: params["quotaName"] = self._quota_name res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers) response = TableBatchWriteResponse() response.parse(res, obj=response) response.status = ( Status.OK if res.status_code == codes["created"] else Status.WAIT ) update_request_id(response, res) return response class StorageApiArrowClient(StorageApiClient): """Arrow batch client to bundle configuration needed for API requests.""" def read_rows_arrow(self, request: ReadRowsRequest) -> ArrowReader: """Read one split of the read session. Args: request: Arrow batch split parameters sent to the server. Returns: Arrow batch reader. """ if not isinstance(request, ReadRowsRequest): raise ValueError( "Use ReadRowsRequest class to build request for read rows interface" ) return ArrowReader(self.read_rows_stream(request)) def write_rows_arrow(self, request: WriteRowsRequest) -> ArrowWriter: """Write one block of data to the write session. Args: request: Arrow batch write parameters sent to the server. Returns: Arrow batch writer. """ if not isinstance(request, WriteRowsRequest): raise ValueError( "Use WriteRowsRequest class to build request for write rows interface" ) return ArrowWriter(self.write_rows_stream(request), request.compression)