# core/maxframe/io/odpsio/tableio.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Union

import pyarrow as pa
from odps import ODPS
from odps.apis.storage_api import (
    StorageApiArrowClient,
    TableBatchScanResponse,
    TableBatchWriteResponse,
)
from odps.tunnel import TableDownloadSession, TableDownloadStatus, TableTunnel
from odps.types import OdpsSchema, PartitionSpec, timestamp_ntz
from odps.utils import call_with_retry

try:
    import pyarrow.compute as pac
except ImportError:
    pac = None

from ...config import options
from ...env import ODPS_STORAGE_API_ENDPOINT
from ...utils import is_empty, sync_pyodps_options
from .schema import odps_schema_to_arrow_schema

PartitionsType = Union[List[str], str, None]
_DEFAULT_ROW_BATCH_SIZE = 4096
_DOWNLOAD_ID_CACHE_SIZE = 100


class ODPSTableIO(ABC):
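    """
    Facade for reading and writing MaxCompute tables as Arrow data.

    Instantiating ``ODPSTableIO(odps)`` returns one of its concrete
    implementations: ``HaloTableIO`` when the Storage API is enabled
    (``options.use_common_table`` is set or ``ODPS_STORAGE_API_ENDPOINT``
    is present in the environment), otherwise ``TunnelTableIO``.

    Illustrative sketch (``odps_entry`` and the table / column names are
    placeholders)::

        table_io = ODPSTableIO(odps_entry)
        with table_io.open_reader("my_table", columns=["col1"]) as reader:
            arrow_table = reader.read_all()
    """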
def __new__(cls, odps: ODPS):
if cls is ODPSTableIO:
if options.use_common_table or ODPS_STORAGE_API_ENDPOINT in os.environ:
return HaloTableIO(odps)
else:
return TunnelTableIO(odps)
return super().__new__(cls)
def __init__(self, odps: ODPS):
self._odps = odps
@classmethod
def _get_reader_schema(
cls,
table_schema: OdpsSchema,
columns: Optional[List[str]] = None,
partition_columns: Union[None, bool, List[str]] = None,
) -> OdpsSchema:
final_cols = []
columns = (
columns
if not is_empty(columns)
else [col.name for col in table_schema.simple_columns]
)
if partition_columns is True:
partition_columns = [c.name for c in table_schema.partitions]
else:
partition_columns = partition_columns or []
for col_name in columns + partition_columns:
final_cols.append(table_schema[col_name])
return OdpsSchema(final_cols)
@abstractmethod
def open_reader(
self,
full_table_name: str,
partitions: PartitionsType = None,
columns: Optional[List[str]] = None,
partition_columns: Union[None, bool, List[str]] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
reverse_range: bool = False,
row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
):
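        """
        Open an Arrow reader on the given table.

        ``start`` and ``stop`` select a row range; negative values are
        interpreted relative to the end of the table. ``reverse_range``
        marks a range specified from its upper bound (``start``) downwards.
        ``partition_columns`` may be ``True`` to append all partition
        columns, or an explicit list of partition column names.
        """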
raise NotImplementedError
@abstractmethod
def open_writer(
self,
full_table_name: str,
partition: Optional[str] = None,
overwrite: bool = True,
):
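        """
        Open an Arrow writer on the given table, optionally targeting a
        single partition. With ``overwrite=True`` existing data in the
        target is replaced.

        Illustrative sketch (names are placeholders)::

            with table_io.open_writer("my_table", partition="pt=20240101") as writer:
                writer.write(record_batch)
        """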
        raise NotImplementedError


class TunnelMultiPartitionReader:
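    """
    Arrow reader that iterates over one or more partitions of a table
    through tunnel readers, exposing them as a single record-batch stream
    limited to an optional start offset and row count.
    """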
def __init__(
self,
odps_entry: ODPS,
table_name: str,
partitions: PartitionsType,
columns: Optional[List[str]] = None,
partition_columns: Optional[List[str]] = None,
start: Optional[int] = None,
count: Optional[int] = None,
        partition_to_download_ids: Optional[Dict[str, str]] = None,
):
self._odps_entry = odps_entry
self._table = odps_entry.get_table(table_name)
self._columns = columns
odps_schema = ODPSTableIO._get_reader_schema(
self._table.table_schema, columns, partition_columns
)
self._schema = odps_schema_to_arrow_schema(odps_schema)
self._start = start or 0
self._count = count
self._row_left = count
self._cur_reader = None
self._reader_iter = None
self._cur_partition_id = -1
self._reader_start_pos = 0
if partitions is None:
if not self._table.table_schema.partitions:
self._partitions = [None]
else:
self._partitions = [str(pt) for pt in self._table.partitions]
elif isinstance(partitions, str):
self._partitions = [partitions]
else:
self._partitions = partitions
self._partition_cols = partition_columns
self._partition_to_download_ids = partition_to_download_ids or dict()
@property
def count(self) -> Optional[int]:
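        """Requested row count, or None when more than one partition is read."""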
if len(self._partitions) > 1:
return None
return self._count
def _open_next_reader(self):
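        """
        Advance to the next partition reader overlapping the requested row
        range; ``_cur_reader`` becomes None when the row budget is used up
        or no partitions remain.
        """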
if self._cur_reader is not None:
self._reader_start_pos += self._cur_reader.count
if (
self._row_left is not None and self._row_left <= 0
) or 1 + self._cur_partition_id >= len(self._partitions):
self._cur_reader = None
return
while 1 + self._cur_partition_id < len(self._partitions):
self._cur_partition_id += 1
part_str = self._partitions[self._cur_partition_id]
req_columns = self._schema.names
with sync_pyodps_options():
self._cur_reader = self._table.open_reader(
part_str,
columns=req_columns,
arrow=True,
download_id=self._partition_to_download_ids.get(part_str),
append_partitions=True,
)
if self._cur_reader.count + self._reader_start_pos > self._start:
start = self._start - self._reader_start_pos
if self._count is None:
count = None
else:
count = min(self._count, self._cur_reader.count - start)
with sync_pyodps_options():
self._reader_iter = self._cur_reader.read(start, count)
break
self._reader_start_pos += self._cur_reader.count
else:
self._cur_reader = None
def read(self):
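        """Return the next Arrow record batch, or None once all partitions are exhausted."""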
with sync_pyodps_options():
if self._cur_reader is None:
self._open_next_reader()
if self._cur_reader is None:
return None
while self._cur_reader is not None:
try:
batch = next(self._reader_iter)
if batch is not None:
if self._row_left is not None:
self._row_left -= batch.num_rows
return batch
except StopIteration:
self._open_next_reader()
return None
def read_all(self) -> pa.Table:
batches = []
while True:
batch = self.read()
if batch is None:
break
batches.append(batch)
if not batches:
return self._schema.empty_table()
        return pa.Table.from_batches(batches)


class TunnelTableIO(ODPSTableIO):
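    """
    Table IO implementation based on table tunnel.

    Tunnel download session ids are cached per (table, partition) pair so
    that subsequent reads can reuse existing sessions; at most
    ``_DOWNLOAD_ID_CACHE_SIZE`` entries are kept.
    """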
_down_session_ids = OrderedDict()
@classmethod
def create_download_sessions(
cls,
odps_entry: ODPS,
full_table_name: str,
        partitions: PartitionsType = None,
) -> Dict[Optional[str], TableDownloadSession]:
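        """
        Create tunnel download sessions for the given partitions, reusing
        cached session ids when they are still valid, and return them keyed
        by partition spec.
        """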
table = odps_entry.get_table(full_table_name)
tunnel = TableTunnel(odps_entry, quota_name=options.tunnel_quota_name)
parts = (
[partitions]
if partitions is None or isinstance(partitions, str)
else partitions
)
part_to_session = dict()
for part in parts:
part_key = (full_table_name, part)
down_session = None
if part_key in cls._down_session_ids:
down_id = cls._down_session_ids[part_key]
down_session = tunnel.create_download_session(
table, async_mode=True, partition_spec=part, download_id=down_id
)
if down_session.status != TableDownloadStatus.Normal:
down_session = None
if down_session is None:
down_session = tunnel.create_download_session(
table, async_mode=True, partition_spec=part
)
while len(cls._down_session_ids) >= _DOWNLOAD_ID_CACHE_SIZE:
cls._down_session_ids.popitem(False)
cls._down_session_ids[part_key] = down_session.id
part_to_session[part] = down_session
return part_to_session
@contextmanager
def open_reader(
self,
full_table_name: str,
partitions: PartitionsType = None,
columns: Optional[List[str]] = None,
partition_columns: Union[None, bool, List[str]] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
reverse_range: bool = False,
row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
):
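        """
        Open a multi-partition tunnel reader.

        Download sessions are created eagerly only when the total record
        count is needed to resolve the range, i.e. when ``start`` or
        ``stop`` is negative or a reverse range has no explicit start.
        """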
with sync_pyodps_options():
table = self._odps.get_table(full_table_name)
if partition_columns is True:
partition_columns = [c.name for c in table.table_schema.partitions]
total_records = None
part_to_down_id = None
if (
(start is not None and start < 0)
or (stop is not None and stop < 0)
or (reverse_range and start is None)
):
with sync_pyodps_options():
tunnel_sessions = self.create_download_sessions(
self._odps, full_table_name, partitions
)
part_to_down_id = {
pt: session.id for (pt, session) in tunnel_sessions.items()
}
total_records = sum(
session.count for session in tunnel_sessions.values()
)
count = None
if start is not None or stop is not None:
if reverse_range:
start = start if start is not None else total_records - 1
stop = stop if stop is not None else -1
else:
start = start if start is not None else 0
stop = stop if stop is not None else None
start = start if start >= 0 else total_records + start
stop = stop if stop is None or stop >= 0 else total_records + stop
if reverse_range:
count = start - stop
start = stop + 1
else:
count = stop - start if stop is not None and start is not None else None
yield TunnelMultiPartitionReader(
self._odps,
full_table_name,
partitions=partitions,
columns=columns,
partition_columns=partition_columns,
start=start,
count=count,
partition_to_download_ids=part_to_down_id,
)
@contextmanager
def open_writer(
self,
full_table_name: str,
partition: Optional[str] = None,
overwrite: bool = True,
):
table = self._odps.get_table(full_table_name)
with sync_pyodps_options():
with table.open_writer(
partition=partition,
arrow=True,
create_partition=partition is not None,
overwrite=overwrite,
) as writer:
                yield writer


class HaloTableArrowReader:
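    """
    Arrow reader on top of a Storage API batch scan session.

    Depending on the scan session it either iterates over splits or reads a
    contiguous row range in ``row_batch_size`` chunks, converting timestamp
    columns to the configured local time zone on the way.
    """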
def __init__(
self,
client: StorageApiArrowClient,
scan_info: TableBatchScanResponse,
odps_schema: OdpsSchema,
start: Optional[int] = None,
count: Optional[int] = None,
row_batch_size: Optional[int] = None,
):
self._client = client
self._scan_info = scan_info
self._cur_split_id = -1
self._cur_reader = None
self._odps_schema = odps_schema
self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
self._start = start
self._count = count
self._cursor = 0
self._row_batch_size = row_batch_size
@property
    def count(self) -> Optional[int]:
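        """Number of rows selected for reading, or None when no explicit row range was requested."""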
return self._count
def _open_next_reader(self):
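        """Open a reader for the next split, or for the next chunk of the requested row range."""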
from odps.apis.storage_api import ReadRowsRequest
if 0 <= self._scan_info.split_count <= self._cur_split_id + 1:
            # split-based scan: all splits have been consumed
self._cur_reader = None
return
elif self._count is not None and self._cursor >= self._count:
            # row-range scan: the requested rows have been consumed
self._cur_reader = None
return
read_rows_kw = {}
if self._start is not None:
read_rows_kw["row_index"] = self._start + self._cursor
read_rows_kw["row_count"] = min(
self._row_batch_size, self._count - self._cursor
)
self._cursor = min(self._count, self._cursor + self._row_batch_size)
req = ReadRowsRequest(
session_id=self._scan_info.session_id,
split_index=self._cur_split_id + 1,
**read_rows_kw,
)
self._cur_reader = call_with_retry(self._client.read_rows_arrow, req)
self._cur_split_id += 1
def _convert_timezone(self, batch: pa.RecordBatch) -> pa.RecordBatch:
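        """
        Convert timestamp columns to the configured local time zone.

        ``timestamp_ntz`` columns are cast to time-zone-naive timestamps
        instead; when ``pyarrow.compute`` lacks the needed kernels the
        conversion falls back to pandas.
        """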
timezone = options.local_timezone
if not any(isinstance(tp, pa.TimestampType) for tp in batch.schema.types):
return batch
cols = []
for idx in range(batch.num_columns):
col = batch.column(idx)
name = batch.schema.names[idx]
if not isinstance(col.type, pa.TimestampType):
cols.append(col)
continue
if self._odps_schema[name].type == timestamp_ntz:
col = col.cast(pa.timestamp(col.type.unit))
cols.append(col)
continue
if hasattr(pac, "local_timestamp"):
col = col.cast(pa.timestamp(col.type.unit, timezone))
else:
pd_col = col.to_pandas().dt.tz_convert(timezone)
col = pa.Array.from_pandas(pd_col).cast(
pa.timestamp(col.type.unit, timezone)
)
cols.append(col)
return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
def read(self):
if self._cur_reader is None:
self._open_next_reader()
if self._cur_reader is None:
return None
while self._cur_reader is not None:
batch = self._cur_reader.read()
if batch is not None:
return self._convert_timezone(batch)
self._open_next_reader()
return None
def read_all(self) -> pa.Table:
batches = []
while True:
batch = self.read()
if batch is None:
break
batches.append(batch)
if not batches:
return self._arrow_schema.empty_table()
        return pa.Table.from_batches(batches)


class HaloTableArrowWriter:
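    """
    Arrow writer on top of a Storage API batch write session.

    Record batches are adapted to the table schema before writing: naive
    timestamp columns are localized and columns are cast to the expected
    Arrow types.
    """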
def __init__(
self,
client: StorageApiArrowClient,
write_info: TableBatchWriteResponse,
odps_schema: OdpsSchema,
):
self._client = client
self._write_info = write_info
self._odps_schema = odps_schema
self._arrow_schema = odps_schema_to_arrow_schema(odps_schema)
self._writer = None
def open(self):
from odps.apis.storage_api import WriteRowsRequest
self._writer = call_with_retry(
self._client.write_rows_arrow,
WriteRowsRequest(self._write_info.session_id),
)
@classmethod
def _localize_timezone(cls, col, tz=None):
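        """Attach a time zone to a naive timestamp column, defaulting to the configured or local zone."""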
from odps.lib import tzlocal
if tz is None:
if options.local_timezone is None:
tz = str(tzlocal.get_localzone())
else:
tz = str(options.local_timezone)
if col.type.tz is not None:
return col
if hasattr(pac, "assume_timezone"):
col = pac.assume_timezone(col, tz)
return col
else:
col = col.to_pandas()
return pa.Array.from_pandas(col.dt.tz_localize(tz))
def _convert_schema(self, batch: pa.RecordBatch):
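        """Cast a record batch to the table's Arrow schema, localizing timestamp columns first."""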
if batch.schema == self._arrow_schema and not any(
isinstance(tp, pa.TimestampType) for tp in self._arrow_schema.types
):
return batch
cols = []
for idx in range(batch.num_columns):
col = batch.column(idx)
name = batch.schema.names[idx]
if isinstance(col.type, pa.TimestampType):
if self._odps_schema[name].type == timestamp_ntz:
col = self._localize_timezone(col, "UTC")
else:
col = self._localize_timezone(col)
if col.type != self._arrow_schema.types[idx]:
col = col.cast(self._arrow_schema.types[idx])
cols.append(col)
return pa.RecordBatch.from_arrays(cols, names=batch.schema.names)
def write(self, batch):
if isinstance(batch, pa.Table):
for b in batch.to_batches():
self._writer.write(self._convert_schema(b))
else:
self._writer.write(self._convert_schema(batch))
def close(self):
commit_msg, is_success = self._writer.finish()
if not is_success:
raise IOError(commit_msg)
        return commit_msg


class HaloTableIO(ODPSTableIO):
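    """
    Table IO implementation based on the MaxCompute Storage API ("common
    table"), selected when ``options.use_common_table`` is set or
    ``ODPS_STORAGE_API_ENDPOINT`` is configured.
    """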
_storage_api_endpoint = os.getenv(ODPS_STORAGE_API_ENDPOINT)
@staticmethod
def _convert_partitions(partitions: PartitionsType) -> Optional[List[str]]:
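        """Normalize partition specs into a list of ``k1=v1/k2=v2`` strings."""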
if partitions is None:
return []
elif isinstance(partitions, (str, PartitionSpec)):
partitions = [partitions]
return [
"/".join(f"{k}={v}" for k, v in PartitionSpec(pt).items())
for pt in partitions
]
@contextmanager
def open_reader(
self,
full_table_name: str,
partitions: PartitionsType = None,
columns: Optional[List[str]] = None,
partition_columns: Union[None, bool, List[str]] = None,
start: Optional[int] = None,
stop: Optional[int] = None,
reverse_range: bool = False,
row_batch_size: int = _DEFAULT_ROW_BATCH_SIZE,
):
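        """
        Create a Storage API read session and yield an Arrow reader over it.

        When a row range is requested, the session uses ROW_OFFSET split
        mode so rows can be addressed by index; otherwise SIZE-based splits
        are used.
        """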
from odps.apis.storage_api import (
SessionRequest,
SessionStatus,
SplitOptions,
TableBatchScanRequest,
)
table = self._odps.get_table(full_table_name)
client = StorageApiArrowClient(
self._odps,
table,
rest_endpoint=self._storage_api_endpoint,
quota_name=options.tunnel_quota_name,
)
split_option = SplitOptions.SplitMode.SIZE
if start is not None or stop is not None:
split_option = SplitOptions.SplitMode.ROW_OFFSET
scan_kw = {
"required_partitions": self._convert_partitions(partitions),
"split_options": SplitOptions.get_default_options(split_option),
}
columns = columns or [c.name for c in table.table_schema.simple_columns]
scan_kw["required_data_columns"] = columns
if partition_columns is True:
scan_kw["required_partition_columns"] = [
c.name for c in table.table_schema.partitions
]
else:
scan_kw["required_partition_columns"] = partition_columns
# todo add more options for partition column handling
req = TableBatchScanRequest(**scan_kw)
resp = call_with_retry(client.create_read_session, req)
session_id = resp.session_id
status = resp.session_status
while status == SessionStatus.INIT:
resp = call_with_retry(client.get_read_session, SessionRequest(session_id))
status = resp.session_status
time.sleep(1.0)
assert status == SessionStatus.NORMAL
count = None
if start is not None or stop is not None:
if reverse_range:
start = start if start is not None else resp.record_count - 1
stop = stop if stop is not None else -1
else:
start = start if start is not None else 0
stop = stop if stop is not None else resp.record_count
start = start if start >= 0 else resp.record_count + start
stop = stop if stop >= 0 else resp.record_count + stop
if reverse_range:
count = start - stop
start = stop + 1
else:
count = stop - start
reader_schema = self._get_reader_schema(
table.table_schema, columns, partition_columns
)
yield HaloTableArrowReader(
client,
resp,
odps_schema=reader_schema,
start=start,
count=count,
row_batch_size=row_batch_size,
)
@contextmanager
def open_writer(
self,
full_table_name: str,
partition: Optional[str] = None,
overwrite: bool = True,
):
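        """
        Create a Storage API write session, yield an Arrow writer for it and
        commit the session once the caller's block finishes, waiting for the
        commit to complete.
        """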
from odps.apis.storage_api import (
SessionRequest,
SessionStatus,
TableBatchWriteRequest,
)
table = self._odps.get_table(full_table_name)
client = StorageApiArrowClient(
self._odps,
table,
rest_endpoint=self._storage_api_endpoint,
quota_name=options.tunnel_quota_name,
)
part_strs = self._convert_partitions(partition)
part_str = part_strs[0] if part_strs else None
req = TableBatchWriteRequest(partition_spec=part_str, overwrite=overwrite)
resp = call_with_retry(client.create_write_session, req)
session_id = resp.session_id
writer = HaloTableArrowWriter(client, resp, table.table_schema)
writer.open()
yield writer
commit_msg = writer.close()
resp = call_with_retry(
client.commit_write_session,
SessionRequest(session_id=session_id),
[commit_msg],
)
while resp.session_status == SessionStatus.COMMITTING:
resp = call_with_retry(
client.get_write_session, SessionRequest(session_id=session_id)
)
assert resp.session_status == SessionStatus.COMMITTED