# odps/apis/storage_api/storage_api.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Client for interacting with the MaxCompute Storage API."""
import collections
import json
import logging
from enum import Enum
from hashlib import md5
from io import BytesIO, IOBase
from typing import List, Union
try:
import pyarrow as pa
except ImportError:
pa = None
from requests import codes
from ... import ODPS, options, serializers
from ...models import Table
from ...models.core import JSONRemoteModel
from ...tunnel.io import RequestsIO
from ...utils import to_binary
STORAGE_VERSION = "1"
URL_PREFIX = "/api/storage/v" + STORAGE_VERSION
logger = logging.getLogger(__name__)
class Status(Enum):
INIT = "INIT"
OK = "OK"
WAIT = "WAIT"
RUNNING = "RUNNING"
class SessionStatus(Enum):
INIT = "INIT"
NORMAL = "NORMAL"
CRITICAL = "CRITICAL"
EXPIRED = "EXPIRED"
COMMITTING = "COMMITTING"
COMMITTED = "COMMITTED"
class SplitOptions(JSONRemoteModel):
class SplitMode(str, Enum):
SIZE = "Size"
PARALLELISM = "Parallelism"
ROW_OFFSET = "RowOffset"
BUCKET = "Bucket"
split_mode = serializers.JSONNodeField(
"SplitMode", parse_callback=lambda s: SplitOptions.SplitMode(s)
)
split_number = serializers.JSONNodeField("SplitNumber")
cross_partition = serializers.JSONNodeField("CrossPartition")
def __init__(self, **kwargs):
super(SplitOptions, self).__init__(**kwargs)
self.split_mode = self.split_mode or SplitOptions.SplitMode.SIZE
self.split_number = self.split_number or 256 * 1024 * 1024
self.cross_partition = (
self.cross_partition if self.cross_partition is not None else True
)
@classmethod
    def get_default_options(cls, mode):
        split_options = SplitOptions()
        split_options.cross_partition = True
        if mode == SplitOptions.SplitMode.SIZE:
            split_options.split_mode = SplitOptions.SplitMode.SIZE
            split_options.split_number = 256 * 1024 * 1024
        elif mode == SplitOptions.SplitMode.PARALLELISM:
            split_options.split_mode = SplitOptions.SplitMode.PARALLELISM
            split_options.split_number = 32
        elif mode == SplitOptions.SplitMode.ROW_OFFSET:
            split_options.split_mode = SplitOptions.SplitMode.ROW_OFFSET
            split_options.split_number = 0
        elif mode == SplitOptions.SplitMode.BUCKET:
            split_options.split_mode = SplitOptions.SplitMode.BUCKET
        return split_options
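    # A minimal usage sketch (illustrative only; the chosen mode and variable name
    # `opts` are hypothetical):
    #
    #     opts = SplitOptions.get_default_options(SplitOptions.SplitMode.ROW_OFFSET)
    #     # opts.split_mode == SplitOptions.SplitMode.ROW_OFFSET, opts.split_number == 0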
class ArrowOptions(JSONRemoteModel):
class TimestampUnit(str, Enum):
SECOND = "second"
MILLI = "milli"
MICRO = "micro"
NANO = "nano"
timestamp_unit = serializers.JSONNodeField(
"TimestampUnit", parse_callback=lambda s: ArrowOptions.TimestampUnit(s)
)
date_time_unit = serializers.JSONNodeField(
"DatetimeUnit", parse_callback=lambda s: ArrowOptions.TimestampUnit(s)
)
def __init__(self, **kwargs):
super(ArrowOptions, self).__init__(**kwargs)
self.timestamp_unit = self.timestamp_unit or ArrowOptions.TimestampUnit.NANO
self.date_time_unit = self.date_time_unit or ArrowOptions.TimestampUnit.MILLI
class Column(JSONRemoteModel):
name = serializers.JSONNodeField("Name")
type = serializers.JSONNodeField("Type")
comment = serializers.JSONNodeField("Comment")
nullable = serializers.JSONNodeField("Nullable")
class DataSchema(JSONRemoteModel):
data_columns = serializers.JSONNodesReferencesField(Column, "DataColumns")
partition_columns = serializers.JSONNodesReferencesField(Column, "PartitionColumns")
class DataFormat(JSONRemoteModel):
type = serializers.JSONNodeField("Type")
version = serializers.JSONNodeField("Version")
class DynamicPartitionOptions(JSONRemoteModel):
invalid_strategy = serializers.JSONNodeField("InvalidStrategy")
invalid_limit = serializers.JSONNodeField("InvalidLimit")
dynamic_partition_limit = serializers.JSONNodeField("DynamicPartitionLimit")
def __init__(self, **kwargs):
super(DynamicPartitionOptions, self).__init__(**kwargs)
self.invalid_strategy = self.invalid_strategy or "Exception"
self.invalid_limit = self.invalid_limit or 1
self.dynamic_partition_limit = self.dynamic_partition_limit or 512
class Order(JSONRemoteModel):
name = serializers.JSONNodeField("Name")
sort_direction = serializers.JSONNodeField("SortDirection")
class RequiredDistribution(JSONRemoteModel):
type = serializers.JSONNodeField("Type")
cluster_keys = serializers.JSONNodeField("ClusterKeys")
buckets_number = serializers.JSONNodeField("BucketsNumber")
class Compression(Enum):
UNCOMPRESSED = 0
ZSTD = 1
LZ4_FRAME = 2
def to_string(self):
if self.value == 0:
return None
elif self.value == 1:
return "zstd"
elif self.value == 2:
return "lz4"
else:
return "unknown"
class TableBatchScanRequest(serializers.JSONSerializableModel):
required_data_columns = serializers.JSONNodeField("RequiredDataColumns")
required_partition_columns = serializers.JSONNodeField("RequiredPartitionColumns")
required_partitions = serializers.JSONNodeField("RequiredPartitions")
required_bucket_ids = serializers.JSONNodeField("RequiredBucketIds")
split_options = serializers.JSONNodeReferenceField(SplitOptions, "SplitOptions")
arrow_options = serializers.JSONNodeReferenceField(ArrowOptions, "ArrowOptions")
filter_predicate = serializers.JSONNodeField("FilterPredicate")
def __init__(self, **kwargs):
super(TableBatchScanRequest, self).__init__(**kwargs)
self.required_data_columns = self.required_data_columns or []
self.required_partition_columns = self.required_partition_columns or []
self.required_partitions = self.required_partitions or []
self.required_bucket_ids = self.required_bucket_ids or []
self.split_options = self.split_options or SplitOptions()
self.arrow_options = self.arrow_options or ArrowOptions()
self.filter_predicate = self.filter_predicate or ""
class TableBatchScanResponse(serializers.JSONSerializableModel):
__slots__ = ["status", "request_id"]
session_id = serializers.JSONNodeField("SessionId")
session_type = serializers.JSONNodeField("SessionType")
session_status = serializers.JSONNodeField(
"SessionStatus", parse_callback=lambda s: SessionStatus(s.upper())
)
expiration_time = serializers.JSONNodeField("ExpirationTime")
split_count = serializers.JSONNodeField("SplitsCount")
record_count = serializers.JSONNodeField("RecordCount")
data_schema = serializers.JSONNodeReferenceField(DataSchema, "DataSchema")
supported_data_format = serializers.JSONNodesReferencesField(
DataFormat, "SupportedDataFormat"
)
def __init__(self):
super(TableBatchScanResponse, self).__init__()
self.status = Status.INIT
self.request_id = ""
class SessionRequest(object):
def __init__(self, session_id, refresh=False):
self.session_id = session_id
self.refresh = refresh
class TableBatchWriteRequest(serializers.JSONSerializableModel):
dynamic_partition_options = serializers.JSONNodeReferenceField(
DynamicPartitionOptions, "DynamicPartitionOptions"
)
arrow_options = serializers.JSONNodeReferenceField(ArrowOptions, "ArrowOptions")
overwrite = serializers.JSONNodeField("Overwrite")
partition_spec = serializers.JSONNodeField("PartitionSpec")
support_write_cluster = serializers.JSONNodeField("SupportWriteCluster")
def __init__(self, **kwargs):
super(TableBatchWriteRequest, self).__init__(**kwargs)
self.partition_spec = self.partition_spec or ""
self.arrow_options = self.arrow_options or ArrowOptions()
self.dynamic_partition_options = (
self.dynamic_partition_options or DynamicPartitionOptions()
)
self.overwrite = self.overwrite if self.overwrite is not None else True
self.support_write_cluster = self.support_write_cluster or False
class TableBatchWriteResponse(serializers.JSONSerializableModel):
__slots__ = ["status", "request_id"]
session_status = serializers.JSONNodeField(
"SessionStatus", parse_callback=lambda s: SessionStatus(s.upper())
)
expiration_time = serializers.JSONNodeField("ExpirationTime")
session_id = serializers.JSONNodeField("SessionId")
data_schema = serializers.JSONNodeReferenceField(DataSchema, "DataSchema")
supported_data_format = serializers.JSONNodesReferencesField(
DataFormat, "SupportedDataFormat"
)
max_block_num = serializers.JSONNodeField("MaxBlockNumber")
required_ordering = serializers.JSONNodesReferencesField(Order, "RequiredOrdering")
required_distribution = serializers.JSONNodeReferenceField(
RequiredDistribution, "RequiredDistribution"
)
def __init__(self):
super(TableBatchWriteResponse, self).__init__()
self.status = Status.INIT
self.request_id = ""
class ReadRowsRequest(object):
def __init__(
self,
session_id,
split_index=0,
row_index=0,
row_count=0,
max_batch_rows=4096,
compression=Compression.LZ4_FRAME,
        data_format=None,
):
self.session_id = session_id
self.split_index = split_index
self.row_index = row_index
self.row_count = row_count
self.max_batch_rows = max_batch_rows
self.compression = compression
        self.data_format = data_format or DataFormat()
class ReadRowsResponse(object):
def __init__(self):
self.status = Status.INIT
self.request_id = ""
class WriteRowsRequest(object):
def __init__(
self,
session_id,
block_number=0,
attempt_number=0,
bucket_id=0,
compression=Compression.LZ4_FRAME,
        data_format=None,
):
self.session_id = session_id
self.block_number = block_number
self.attempt_number = attempt_number
self.bucket_id = bucket_id
self.compression = compression
        self.data_format = data_format or DataFormat()
class WriteRowsResponse(object):
def __init__(self):
self.status = Status.INIT
self.request_id = ""
self.commit_message = ""
def update_request_id(response, resp):
if "x-odps-request-id" in resp.headers:
response.request_id = resp.headers["x-odps-request-id"]
class StreamReader(IOBase):
"""Stream reader."""
def __init__(self, download):
self._stopped = False
raw_reader = download()
self._raw_reader = raw_reader
# need to confirm read size
self._chunk_size = 65536
self._buffers = collections.deque()
def readable(self):
"""Check whether this stream reader has been closed or not.
Returns:
Readable or not.
"""
return not self._stopped
def _read_chunk(self):
buf = self._raw_reader.raw.read(self._chunk_size)
return buf
def _fill_next_buffer(self):
data = self._read_chunk()
if len(data) == 0:
return
self._buffers.append(BytesIO(data))
def read(self, nbytes=None):
"""Read stream data from the server.
Args:
nbytes: The number of bytes to be read. All data will be read at once if set to None.
Returns:
            Stream data. An empty bytes object means all the data has been read or an error occurred.
"""
if self._stopped:
return b""
total_size = 0
bufs = []
while nbytes is None or total_size < nbytes:
if not self._buffers:
self._fill_next_buffer()
if not self._buffers:
break
to_read = nbytes - total_size if nbytes is not None else None
buf = self._buffers[0].read(to_read)
if not buf:
self._buffers.popleft()
else:
bufs.append(buf)
total_size += len(buf)
return b"".join(bufs)
def get_status(self):
"""Get the status of the stream reader.
Returns:
Status.OK or Status.RUNNING.
"""
if not self._stopped:
return Status.RUNNING
else:
return Status.OK
def get_request_id(self):
"""Get the request id.
Returns:
Request id.
"""
if not self._stopped:
logger.error("The reader is not closed yet, please wait")
return None
if (
self._raw_reader is not None
and "x-odps-request-id" in self._raw_reader.headers
):
return self._raw_reader.headers["x-odps-request-id"]
else:
return None
def close(self):
"""If there is no data can be read from server, it will be called to close the stream reader."""
self._stopped = True
class StreamWriter(IOBase):
"""Stream writer."""
def __init__(self, upload):
self._req_io = RequestsIO(upload, chunk_size=options.chunk_size)
self._req_io.start()
self._res = None
self._stopped = False
def writable(self):
"""Check whether this stream writer has been closed or not.
Returns:
Writable or not.
"""
return not self._stopped
def write(self, data):
"""Write data to the server.
Returns:
Success or not.
"""
if self._stopped:
return False
self._req_io.write(data)
return True
def finish(self):
"""The stream writer is not expected to write data if finish has been called.
Returns:
Commit message returned from the server. User should bring this message to do the write session commit.
Success or not.
"""
self._stopped = True
self._res = self._req_io.finish()
if self._res is not None and self._res.status_code == codes["ok"]:
resp_json = self._res.json()
return resp_json["CommitMessage"], True
else:
return None, False
def get_status(self):
"""Get the status of this stream writer.
Returns:
Status.OK or Status.RUNNING.
"""
if not self._stopped:
return Status.RUNNING
else:
return Status.OK
def get_request_id(self):
"""Get the request id.
Returns:
Request id.
"""
if not self._stopped:
logger.error("The writer is not closed yet, please close first")
return None
if self._res is not None and "x-odps-request-id" in self._res.headers:
return self._res.headers["x-odps-request-id"]
else:
return None
class ArrowReader(object):
"""Arrow batch reader."""
def __init__(self, stream_reader):
if pa is None:
raise ValueError("To use arrow reader you need to install pyarrow")
self._reader = stream_reader
self._arrow_stream = None
def _read_next_batch(self):
if self._arrow_stream is None:
self._arrow_stream = pa.ipc.open_stream(self._reader)
try:
batch = self._arrow_stream.read_next_batch()
return batch
except StopIteration:
return None
def read(self):
"""Read arrow batch from the server.
Returns:
            Arrow record batch. None means all the data has been read or an error occurred.
"""
if not self._reader.readable():
logger.error("Reader has been closed")
return None
batch = self._read_next_batch()
if batch is None:
self._reader.close()
return batch
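    # A minimal read loop (assumes `reader` was obtained from
    # StorageApiArrowClient.read_rows_arrow(); variable names are hypothetical):
    #
    #     while True:
    #         batch = reader.read()
    #         if batch is None:
    #             break
    #         # process the pyarrow.RecordBatch, e.g. batch.num_rows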
def get_status(self):
"""Get the status of the arrow batch reader.
Returns:
Status.OK or Status.RUNNING.
"""
return self._reader.get_status()
def get_request_id(self):
"""Get the request id.
Returns:
Request id.
"""
return self._reader.get_request_id()
class ArrowWriter(object):
"""Arrow batch writer."""
def __init__(self, stream_writer, compression):
self._arrow_writer = None
self._compression = compression
self._sink = stream_writer
def write(self, record_batch):
"""Write one arrow batch to the server.
Args:
record_batch: The arrow batch to be written.
Returns:
Success or not.
"""
if not self._sink.writable():
logger.error("Writer has been closed")
return False
if self._arrow_writer is None:
self._arrow_writer = pa.ipc.new_stream(
self._sink,
record_batch.schema,
options=pa.ipc.IpcWriteOptions(
compression=self._compression.to_string()
),
)
self._arrow_writer.write_batch(record_batch)
if not self._sink.writable():
logger.error("Writer has been closed as exception occurred")
return False
return True
def finish(self):
"""The arrow writer is not expected to write data if finish has been called.
Returns:
Commit message returned from the server. User should bring this message
to do the write session commit.
Success ot not.
"""
if self._arrow_writer:
self._arrow_writer.close()
return self._sink.finish()
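    # A minimal write-and-finish sketch (assumes `writer` was obtained from
    # StorageApiArrowClient.write_rows_arrow() and `batch` is a pyarrow.RecordBatch
    # matching the session schema; names are hypothetical):
    #
    #     writer.write(batch)
    #     commit_msg, ok = writer.finish()
    #     if ok:
    #         # pass [commit_msg] to StorageApiClient.commit_write_session()
    #         ...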
def get_status(self):
"""Get the status of the arrow batch writer.
Returns:
Status.OK or Status.RUNNING.
"""
return self._sink.get_status()
def get_request_id(self):
"""Get the request id.
Returns:
Request id.
"""
return self._sink.get_request_id()
class StorageApiClient(object):
"""Client to bundle configuration needed for API requests."""
def __init__(
self,
odps: ODPS,
table: Table,
rest_endpoint: str = None,
quota_name: str = None,
tags: Union[None, str, List[str]] = None,
):
if not isinstance(odps, ODPS) or not isinstance(table, Table):
raise ValueError("Please input odps configuration")
self._odps = odps
self._table = table
self._quota_name = quota_name
self._rest_endpoint = rest_endpoint
self._tunnel_rest = None
self._tags = tags or options.tunnel.tags
if isinstance(self._tags, str):
self._tags = self._tags.split(",")
@property
def table(self):
return self._table
@property
def tunnel_rest(self):
if self._tunnel_rest is not None:
return self._tunnel_rest
from ...tunnel.tabletunnel import TableTunnel
tunnel = TableTunnel(
self._odps, endpoint=self._rest_endpoint, quota_name=self._quota_name
)
self._tunnel_rest = tunnel.tunnel_rest
return self._tunnel_rest
def _get_resource(self, *args) -> str:
endpoint = self.tunnel_rest.endpoint + URL_PREFIX
url = self._table.table_resource(endpoint=endpoint, force_schema=True)
return "/".join([url] + list(args))
def _fill_common_headers(self, raw_headers=None):
headers = raw_headers or {}
if self._tags:
headers["odps-tunnel-tags"] = ",".join(self._tags)
return headers
def create_read_session(
self, request: TableBatchScanRequest
) -> TableBatchScanResponse:
"""Create a read session.
Args:
request: Table split parameters sent to the server.
Returns:
Read session response returned from the server.
"""
if not isinstance(request, TableBatchScanRequest):
raise ValueError(
"Use TableBatchScanRequest class to build request for create read session interface"
)
json_str = request.serialize()
url = self._get_resource("sessions")
headers = self._fill_common_headers({"Content-Type": "application/json"})
if json_str != "":
headers["Content-MD5"] = md5(to_binary(json_str)).hexdigest()
params = {"session_type": "batch_read"}
if self._quota_name:
params["quotaName"] = self._quota_name
res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers)
response = TableBatchScanResponse()
response.parse(res, obj=response)
response.status = (
Status.OK if res.status_code == codes["created"] else Status.WAIT
)
update_request_id(response, res)
return response
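    # A minimal sketch of creating a read session and polling until the splits are
    # ready (assumes `client` is a StorageApiClient for the target table; the
    # partition value is hypothetical and the sleep requires `import time`):
    #
    #     req = TableBatchScanRequest()
    #     req.required_partitions = ["pt=20240101"]
    #     resp = client.create_read_session(req)
    #     while resp.session_status == SessionStatus.INIT:
    #         time.sleep(1)
    #         resp = client.get_read_session(SessionRequest(resp.session_id))
    #     # resp.split_count splits are now available for read_rows_stream()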
def get_read_session(self, request: SessionRequest) -> TableBatchScanResponse:
"""Get the read session.
Args:
request: Read session parameters sent to the server.
Returns:
Read session response returned from the server.
"""
if not isinstance(request, SessionRequest):
raise ValueError(
"Use SessionRequest class to build request for get read session interface"
)
url = self._get_resource("sessions", request.session_id)
headers = self._fill_common_headers()
params = {"session_type": "batch_read"}
if self._quota_name:
params["quotaName"] = self._quota_name
if request.refresh:
params["session_refresh"] = "true"
res = self.tunnel_rest.get(url, params=params, headers=headers)
response = TableBatchScanResponse()
response.parse(res, obj=response)
response.status = Status.OK
update_request_id(response, res)
return response
def read_rows_stream(self, request: ReadRowsRequest) -> StreamReader:
"""Read one split of the read session. Stream means the data read from server is serialized arrow record batch.
Args:
request: Batch split parameters sent to the server.
Returns:
Stream reader.
"""
if not isinstance(request, ReadRowsRequest):
raise ValueError(
"Use ReadRowsRequest class to build request for read rows interface"
)
url = self._get_resource("data")
headers = self._fill_common_headers(
{
"Connection": "Keep-Alive",
"Accept-Encoding": request.compression.name
if request.compression != Compression.UNCOMPRESSED
else "",
}
)
params = {
"session_id": request.session_id,
"max_batch_rows": str(request.max_batch_rows),
"split_index": str(request.split_index),
"row_count": str(request.row_count),
"row_index": str(request.row_index),
}
if self._quota_name:
params["quotaName"] = self._quota_name
if request.data_format.type is not None:
params["data_format_type"] = request.data_format.type
if request.data_format.version is not None:
params["data_format_version"] = request.data_format.version
def download():
return self.tunnel_rest.get(
url, stream=True, params=params, headers=headers
)
return StreamReader(download)
def create_write_session(
self, request: TableBatchWriteRequest
) -> TableBatchWriteResponse:
"""Create a write session.
Args:
request: Table write parameters sent to the server.
Returns:
Write session response returned from the server.
"""
if not isinstance(request, TableBatchWriteRequest):
raise ValueError(
"Use TableBatchWriteRequest class to build request for create write session interface"
)
json_str = request.serialize()
url = self._get_resource("sessions")
headers = self._fill_common_headers({"Content-Type": "application/json"})
if json_str != "":
headers["Content-MD5"] = md5(to_binary(json_str)).hexdigest()
params = {"session_type": "batch_write"}
if self._quota_name:
params["quotaName"] = self._quota_name
res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers)
response = TableBatchWriteResponse()
response.parse(res, obj=response)
response.status = Status.OK
update_request_id(response, res)
return response
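    # A minimal sketch of creating a write session (assumes `client` is a
    # StorageApiClient; the partition spec is hypothetical):
    #
    #     req = TableBatchWriteRequest()
    #     req.partition_spec = "pt=20240101"
    #     req.overwrite = False
    #     resp = client.create_write_session(req)
    #     # resp.session_id is then used by write_rows_stream()/write_rows_arrow()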
def get_write_session(self, request: SessionRequest) -> TableBatchWriteResponse:
"""Get a write session.
Args:
request: Write session parameters sent to the server.
Returns:
Write session response returned from the server.
"""
if not isinstance(request, SessionRequest):
raise ValueError(
"Use SessionRequest class to build request for get write session interface"
)
url = self._get_resource("sessions", request.session_id)
headers = self._fill_common_headers()
params = {"session_type": "batch_write"}
if self._quota_name:
params["quotaName"] = self._quota_name
res = self.tunnel_rest.get(url, params=params, headers=headers)
response = TableBatchWriteResponse()
response.parse(res, obj=response)
response.status = Status.OK
update_request_id(response, res)
return response
def write_rows_stream(self, request: WriteRowsRequest) -> StreamWriter:
"""Write one block of data to the write session. Stream means the data written to server is serialized arrow record batch.
Args:
request: Batch write parameters sent to the server.
Returns:
Stream writer.
"""
if not isinstance(request, WriteRowsRequest):
raise ValueError(
"Use WriteRowsRequest class to build request for write rows interface"
)
url = self._get_resource("sessions", request.session_id, "data")
headers = self._fill_common_headers(
{"Content-Type": "application/octet-stream", "Transfer-Encoding": "chunked"}
)
params = {
"attempt_number": str(request.attempt_number),
"block_number": str(request.block_number),
}
if self._quota_name:
params["quotaName"] = self._quota_name
        if request.data_format.type is not None:
            params["data_format_type"] = str(request.data_format.type)
        if request.data_format.version is not None:
            params["data_format_version"] = str(request.data_format.version)
def upload(data):
return self.tunnel_rest.post(url, data=data, params=params, headers=headers)
return StreamWriter(upload)
def commit_write_session(
self, request: SessionRequest, commit_msg: list
) -> TableBatchWriteResponse:
"""Commit the write session after write the last stream data.
Args:
request: Commit write session parameters sent to the server.
commit_msg: Commit messages collected from the write_rows_stream().
Returns:
Write session response returned from the server.
"""
if not isinstance(request, SessionRequest):
raise ValueError(
"Use SessionRequest class to build request for commit write session interface"
)
if not isinstance(commit_msg, list):
raise ValueError("Use list for commit message")
commit_message_dict = {"CommitMessages": commit_msg}
json_str = json.dumps(commit_message_dict)
url = self._get_resource("commit")
headers = self._fill_common_headers({"Content-Type": "application/json"})
params = {"session_id": request.session_id}
if self._quota_name:
params["quotaName"] = self._quota_name
res = self.tunnel_rest.post(url, data=json_str, params=params, headers=headers)
response = TableBatchWriteResponse()
response.parse(res, obj=response)
response.status = (
Status.OK if res.status_code == codes["created"] else Status.WAIT
)
update_request_id(response, res)
return response
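    # A minimal commit sketch (assumes `client`, a write session `session_id`, and
    # `commit_messages` collected from each writer's finish(); names are hypothetical):
    #
    #     resp = client.commit_write_session(SessionRequest(session_id), commit_messages)
    #     while resp.session_status == SessionStatus.COMMITTING:
    #         resp = client.get_write_session(SessionRequest(session_id))
    #     # resp.session_status should end up as SessionStatus.COMMITTED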
class StorageApiArrowClient(StorageApiClient):
"""Arrow batch client to bundle configuration needed for API requests."""
def read_rows_arrow(self, request: ReadRowsRequest) -> ArrowReader:
"""Read one split of the read session.
Args:
request: Arrow batch split parameters sent to the server.
Returns:
Arrow batch reader.
"""
if not isinstance(request, ReadRowsRequest):
raise ValueError(
"Use ReadRowsRequest class to build request for read rows interface"
)
return ArrowReader(self.read_rows_stream(request))
def write_rows_arrow(self, request: WriteRowsRequest) -> ArrowWriter:
"""Write one block of data to the write session.
Args:
request: Arrow batch write parameters sent to the server.
Returns:
Arrow batch writer.
"""
if not isinstance(request, WriteRowsRequest):
raise ValueError(
"Use WriteRowsRequest class to build request for write rows interface"
)
return ArrowWriter(self.write_rows_stream(request), request.compression)
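# A minimal end-to-end read sketch tying the pieces together (assumes an existing
# MaxCompute account, a table name, and pyarrow installed; all names and credentials
# below are illustrative placeholders):
#
#     o = ODPS(access_id, secret_access_key, project, endpoint)
#     client = StorageApiArrowClient(o, o.get_table("my_table"))
#     scan = TableBatchScanRequest()
#     resp = client.create_read_session(scan)
#     while resp.session_status == SessionStatus.INIT:
#         resp = client.get_read_session(SessionRequest(resp.session_id))
#     reader = client.read_rows_arrow(ReadRowsRequest(session_id=resp.session_id))
#     while True:
#         batch = reader.read()
#         if batch is None:
#             break
#         # process the pyarrow.RecordBatch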