# odps/tunnel/tabletunnel.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import sys
import time
import requests
from .. import errors, options, serializers, types, utils
from ..compat import Enum, six
from ..lib.monotonic import monotonic
from ..models import Projects, Record, TableSchema
from ..types import Column
from .base import TUNNEL_VERSION, BaseTunnel
from .errors import TunnelError, TunnelReadTimeout, TunnelWriteTimeout
from .io.reader import ArrowRecordReader, TunnelArrowReader, TunnelRecordReader
from .io.stream import CompressOption, get_decompress_stream
from .io.writer import (
ArrowWriter,
BufferedArrowWriter,
BufferedRecordWriter,
RecordWriter,
StreamRecordWriter,
Upsert,
)
try:
import numpy as np
except ImportError:
np = None
try:
import pyarrow as pa
except ImportError:
pa = None
logger = logging.getLogger(__name__)
TUNNEL_DATA_TRANSFORM_VERSION = "v1"
DEFAULT_UPSERT_COMMIT_TIMEOUT = 120
def _wrap_upload_call(request_id):
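    # translate connection timeouts raised during an upload call into
    # TunnelWriteTimeout carrying the request id of the current session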
def wrapper(func):
@six.wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except requests.ConnectionError as ex:
ex_str = str(ex)
if "timed out" in ex_str:
raise TunnelWriteTimeout(ex_str, request_id=request_id)
else:
raise
return wrapped
return wrapper
class BaseTableTunnelSession(serializers.JSONSerializableModel):
@staticmethod
def get_common_headers(content_length=None, chunked=False, tags=None):
header = {
"odps-tunnel-date-transform": TUNNEL_DATA_TRANSFORM_VERSION,
"odps-tunnel-sdk-support-schema-evolution": "true",
"x-odps-tunnel-version": TUNNEL_VERSION,
}
if content_length is not None:
header["Content-Length"] = content_length
if chunked:
header.update(
{
"Transfer-Encoding": "chunked",
"Content-Type": "application/octet-stream",
}
)
tags = tags or options.tunnel.tags
if tags:
if isinstance(tags, six.string_types):
tags = tags.split(",")
header["odps-tunnel-tags"] = ",".join(tags)
return header
@staticmethod
def normalize_partition_spec(partition_spec):
if isinstance(partition_spec, six.string_types):
partition_spec = types.PartitionSpec(partition_spec)
if isinstance(partition_spec, types.PartitionSpec):
partition_spec = str(partition_spec).replace("'", "")
return partition_spec
def get_common_params(self, **kwargs):
params = {k: str(v) for k, v in kwargs.items()}
if getattr(self, "_quota_name", None):
params["quotaName"] = self._quota_name
if self._partition_spec is not None and len(self._partition_spec) > 0:
params["partition"] = self._partition_spec
return params
def check_tunnel_response(self, resp):
if not self._client.is_ok(resp):
e = TunnelError.parse(resp)
raise e
@classmethod
def _get_default_compress_option(cls):
if not options.tunnel.compress.enabled:
return None
return CompressOption(
compress_algo=options.tunnel.compress.algo,
level=options.tunnel.compress.level,
strategy=options.tunnel.compress.strategy,
)
def new_record(self, values=None):
"""
Generate a record of the current upload session.
        :param values: the values of this record
:type values: list
:return: record
:rtype: :class:`odps.models.Record`
:Example:
>>> session = TableTunnel(o).create_upload_session('test_table')
>>> record = session.new_record()
>>> record[0] = 'my_name'
>>> record[1] = 'my_id'
>>> record = session.new_record(['my_name', 'my_id'])
.. seealso:: :class:`odps.models.Record`
"""
return Record(
schema=self.schema,
values=values,
max_field_size=getattr(self, "max_field_size", None),
)
class TableDownloadSession(BaseTableTunnelSession):
"""
Tunnel session for downloading data from tables. Instances of this class
should be created by :meth:`TableTunnel.create_download_session`.
"""
__slots__ = (
"_client",
"_table",
"_partition_spec",
"_compress_option",
"_quota_name",
"_tags",
)
class Status(Enum):
Unknown = "UNKNOWN"
Normal = "NORMAL"
Closes = "CLOSES"
Expired = "EXPIRED"
Initiating = "INITIATING"
id = serializers.JSONNodeField("DownloadID")
status = serializers.JSONNodeField(
"Status", parse_callback=lambda s: TableDownloadSession.Status(s.upper())
)
count = serializers.JSONNodeField("RecordCount")
schema = serializers.JSONNodeReferenceField(TableSchema, "Schema")
quota_name = serializers.JSONNodeField("QuotaName")
def __init__(
self,
client,
table,
partition_spec,
download_id=None,
compress_option=None,
async_mode=True,
timeout=None,
quota_name=None,
tags=None,
**kw
):
super(TableDownloadSession, self).__init__()
self._client = client
self._table = table
self._partition_spec = self.normalize_partition_spec(partition_spec)
self._quota_name = quota_name
if "async_" in kw:
async_mode = kw.pop("async_")
if kw:
raise TypeError("Cannot accept arguments %s" % ", ".join(kw.keys()))
self._tags = tags or options.tunnel.tags
if isinstance(self._tags, six.string_types):
self._tags = self._tags.split(",")
if download_id is None:
self._init(async_mode=async_mode, timeout=timeout)
else:
self.id = download_id
self.reload()
self._compress_option = compress_option or self._get_default_compress_option()
logger.info("Tunnel session created: %r", self)
if options.tunnel_session_create_callback:
options.tunnel_session_create_callback(self)
def __repr__(self):
return "<TableDownloadSession id=%s project=%s table=%s partition_spec=%r>" % (
self.id,
self._table.project.name,
self._table.name,
self._partition_spec,
)
def _init(self, async_mode, timeout):
params = self.get_common_params(downloads="")
headers = self.get_common_headers(content_length=0, tags=self._tags)
if async_mode:
params["asyncmode"] = "true"
url = self._table.table_resource()
ts = monotonic()
try:
resp = self._client.post(
url, {}, params=params, headers=headers, timeout=timeout
)
except requests.exceptions.ReadTimeout:
if callable(options.tunnel_session_create_timeout_callback):
options.tunnel_session_create_timeout_callback(*sys.exc_info())
raise
self.check_tunnel_response(resp)
delay_time = 0.1
self.parse(resp, obj=self)
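        # poll with exponential backoff (capped at 5 seconds) until the session
        # leaves the Initiating status or the given timeout elapses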
while self.status == self.Status.Initiating:
if timeout and monotonic() - ts > timeout:
try:
raise TunnelReadTimeout(
"Waiting for tunnel ready timed out. id=%s, table=%s"
% (self.id, self._table.name)
)
except TunnelReadTimeout:
if callable(options.tunnel_session_create_timeout_callback):
options.tunnel_session_create_timeout_callback(*sys.exc_info())
raise
time.sleep(delay_time)
delay_time = min(delay_time * 2, 5)
self.reload()
if self.schema is not None:
self.schema.build_snapshot()
def reload(self):
params = self.get_common_params(downloadid=self.id)
headers = self.get_common_headers(content_length=0, tags=self._tags)
url = self._table.table_resource()
resp = self._client.get(url, params=params, headers=headers)
self.check_tunnel_response(resp)
self.parse(resp, obj=self)
if self.schema is not None:
self.schema.build_snapshot()
def _build_input_stream(
self, start, count, compress=False, columns=None, arrow=False
):
compress_option = self._compress_option or CompressOption()
actions = ["data"]
params = self.get_common_params(downloadid=self.id)
headers = self.get_common_headers(content_length=0, tags=self._tags)
if compress:
encoding = compress_option.algorithm.get_encoding()
if encoding:
headers["Accept-Encoding"] = encoding
params["rowrange"] = "(%s,%s)" % (start, count)
if columns is not None and len(columns) > 0:
col_name = lambda col: col.name if isinstance(col, types.Column) else col
params["columns"] = ",".join(col_name(col) for col in columns)
if arrow:
actions.append("arrow")
url = self._table.table_resource()
resp = self._client.get(
url, stream=True, actions=actions, params=params, headers=headers
)
self.check_tunnel_response(resp)
content_encoding = resp.headers.get("Content-Encoding")
if content_encoding is not None:
compress_algo = CompressOption.CompressAlgorithm.from_encoding(
content_encoding
)
if compress_algo != compress_option.algorithm:
compress_option = self._compress_option = CompressOption(
compress_algo, -1, 0
)
compress = True
else:
compress = False
option = compress_option if compress else None
return get_decompress_stream(resp, option)
def _open_reader(
self,
start,
count,
compress=None,
columns=None,
arrow=False,
reader_cls=None,
**kw
):
pt_cols = (
set(types.PartitionSpec(self._partition_spec).keys())
if self._partition_spec
else set()
)
reader_cols = [c for c in columns if c not in pt_cols] if columns else columns
if compress is None:
compress = self._compress_option is not None
stream_kw = dict(compress=compress, columns=reader_cols, arrow=arrow)
def stream_creator(cursor):
return self._build_input_stream(start + cursor, count - cursor, **stream_kw)
return reader_cls(self.schema, stream_creator, columns=columns, **kw)
def open_record_reader(
self, start, count, compress=False, columns=None, append_partitions=True
):
"""
Open a reader to read data as records from the tunnel.
:param int start: start row index
:param int count: number of rows to read
:param bool compress: whether to compress data
        :param list columns: list of column names to read
        :param bool append_partitions: whether to append partition values as columns
:return: a record reader
:rtype: :class:`TunnelRecordReader`
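        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table.
        >>> session = TableTunnel(o).create_download_session('test_table')
        >>> with session.open_record_reader(0, session.count) as reader:
        ...     for record in reader:
        ...         print(record[0])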
"""
return self._open_reader(
start,
count,
compress=compress,
columns=columns,
append_partitions=append_partitions,
partition_spec=self._partition_spec,
reader_cls=TunnelRecordReader,
)
def open_arrow_reader(
self, start, count, compress=False, columns=None, append_partitions=False
):
"""
Open a reader to read data as Arrow format from the tunnel.
:param int start: start row index
:param int count: number of rows to read
:param bool compress: whether to compress data
        :param list columns: list of column names to read
        :param bool append_partitions: whether to append partition values as columns
:return: an Arrow reader
:rtype: :class:`TunnelArrowReader`
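        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry,
        ``test_table`` an existing table, and ``to_pandas`` assumes pandas is
        installed alongside pyarrow.
        >>> session = TableTunnel(o).create_download_session('test_table')
        >>> with session.open_arrow_reader(0, session.count) as reader:
        ...     df = reader.to_pandas()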
"""
return self._open_reader(
start,
count,
compress=compress,
columns=columns,
arrow=True,
append_partitions=append_partitions,
partition_spec=self._partition_spec,
reader_cls=TunnelArrowReader,
)
class TableUploadSession(BaseTableTunnelSession):
"""
Tunnel session for uploading data to tables. Instances of this class
should be created by :meth:`TableTunnel.create_upload_session`.
"""
__slots__ = (
"_client",
"_table",
"_partition_spec",
"_compress_option",
"_create_partition",
"_overwrite",
"_quota_name",
"_tags",
)
class Status(Enum):
Unknown = "UNKNOWN"
Normal = "NORMAL"
Closing = "CLOSING"
Closed = "CLOSED"
Canceled = "CANCELED"
Expired = "EXPIRED"
Critical = "CRITICAL"
id = serializers.JSONNodeField("UploadID")
status = serializers.JSONNodeField(
"Status", parse_callback=lambda s: TableUploadSession.Status(s.upper())
)
blocks = serializers.JSONNodesField("UploadedBlockList", "BlockID")
schema = serializers.JSONNodeReferenceField(TableSchema, "Schema")
max_field_size = serializers.JSONNodeField("MaxFieldSize")
quota_name = serializers.JSONNodeField("QuotaName")
def __init__(
self,
client,
table,
partition_spec,
upload_id=None,
compress_option=None,
create_partition=None,
overwrite=False,
quota_name=None,
tags=None,
):
super(TableUploadSession, self).__init__()
self._client = client
self._table = table
self._partition_spec = self.normalize_partition_spec(partition_spec)
self._create_partition = create_partition
self._quota_name = quota_name
self._overwrite = overwrite
self._tags = tags or options.tunnel.tags
if isinstance(self._tags, six.string_types):
self._tags = self._tags.split(",")
if upload_id is None:
self._init()
else:
self.id = upload_id
self.reload()
self._compress_option = compress_option or self._get_default_compress_option()
logger.info("Tunnel session created: %r", self)
if options.tunnel_session_create_callback:
options.tunnel_session_create_callback(self)
def __repr__(self):
repr_args = "id=%s project=%s table=%s partition_spec=%r" % (
self.id,
self._table.project.name,
self._table.name,
self._partition_spec,
)
if self._overwrite:
repr_args += " overwrite=True"
return "<TableUploadSession %s>" % repr_args
def _create_or_reload_session(self, reload=False):
headers = self.get_common_headers(content_length=0, tags=self._tags)
params = self.get_common_params(reload=reload)
if self._create_partition:
params["create_partition"] = "true"
if not reload and self._overwrite:
params["overwrite"] = "true"
if reload:
params["uploadid"] = self.id
else:
params["uploads"] = 1
def _call_tunnel(func, *args, **kw):
resp = func(*args, **kw)
self.check_tunnel_response(resp)
return resp
url = self._table.table_resource()
if reload:
resp = utils.call_with_retry(
_call_tunnel, self._client.get, url, params=params, headers=headers
)
else:
resp = utils.call_with_retry(
_call_tunnel, self._client.post, url, {}, params=params, headers=headers
)
self.parse(resp, obj=self)
if self.schema is not None:
self.schema.build_snapshot()
def _init(self):
self._create_or_reload_session(reload=False)
def reload(self):
self._create_or_reload_session(reload=True)
@classmethod
def _iter_data_in_batches(cls, data):
pos = 0
chunk_size = options.chunk_size
while pos < len(data):
yield data[pos : pos + chunk_size]
pos += chunk_size
def _open_writer(
self,
block_id=None,
compress=None,
buffer_size=None,
writer_cls=None,
initial_block_id=None,
block_id_gen=None,
):
compress_option = self._compress_option or CompressOption()
params = self.get_common_params(uploadid=self.id)
headers = self.get_common_headers(chunked=True, tags=self._tags)
if compress is None:
compress = self._compress_option is not None
if compress:
# special: rewrite LZ4 to ARROW_LZ4 for arrow tunnels
if (
writer_cls is not None
and issubclass(writer_cls, ArrowWriter)
and compress_option.algorithm
== CompressOption.CompressAlgorithm.ODPS_LZ4
):
compress_option.algorithm = (
CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4
)
encoding = compress_option.algorithm.get_encoding()
if encoding:
headers["Content-Encoding"] = encoding
url = self._table.table_resource()
option = compress_option if compress else None
if block_id is None:
@_wrap_upload_call(self.id)
def upload_block(blockid, data):
params["blockid"] = blockid
def upload_func():
if isinstance(data, (bytes, bytearray)):
to_upload = self._iter_data_in_batches(data)
else:
to_upload = data
return self._client.put(
url, data=to_upload, params=params, headers=headers
)
return utils.call_with_retry(upload_func)
if writer_cls is ArrowWriter:
writer_cls = BufferedArrowWriter
params["arrow"] = ""
else:
writer_cls = BufferedRecordWriter
writer = writer_cls(
self.schema,
upload_block,
compress_option=option,
buffer_size=buffer_size,
block_id=initial_block_id,
block_id_gen=block_id_gen,
)
else:
params["blockid"] = block_id
@_wrap_upload_call(self.id)
def upload(data):
return self._client.put(url, data=data, params=params, headers=headers)
if writer_cls is ArrowWriter:
params["arrow"] = ""
writer = writer_cls(self.schema, upload, compress_option=option)
return writer
def open_record_writer(
self,
block_id=None,
compress=False,
buffer_size=None,
initial_block_id=None,
block_id_gen=None,
):
"""
Open a writer to write data in records to the tunnel.
:param int block_id: id of the block to write to. If not specified,
a :class:`BufferedRecordWriter` will be created.
:param int buffer_size: size of the buffer to use for buffered writers.
:param bool compress: whether to compress data
:return: a record writer
:rtype: :class:`RecordWriter` or :class:`BufferedRecordWriter`
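        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table with two string columns.
        >>> session = TableTunnel(o).create_upload_session('test_table')
        >>> with session.open_record_writer(0) as writer:
        ...     writer.write(session.new_record(['my_name', 'my_id']))
        >>> session.commit([0])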
"""
return self._open_writer(
block_id=block_id,
compress=compress,
buffer_size=buffer_size,
initial_block_id=initial_block_id,
block_id_gen=block_id_gen,
writer_cls=RecordWriter,
)
def open_arrow_writer(
self,
block_id=None,
compress=False,
buffer_size=None,
initial_block_id=None,
block_id_gen=None,
):
"""
Open a writer to write data in Arrow format to the tunnel.
:param int block_id: id of the block to write to. If not specified,
a :class:`BufferedArrowWriter` will be created.
:param int buffer_size: size of the buffer to use for buffered writers.
:param bool compress: whether to compress data
:return: an Arrow writer
:rtype: :class:`ArrowWriter` or :class:`BufferedArrowWriter`
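        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table with string columns ``name`` and
        ``id``. Requires pyarrow.
        >>> import pyarrow as pa
        >>> session = TableTunnel(o).create_upload_session('test_table')
        >>> batch = pa.RecordBatch.from_pydict({'name': ['my_name'], 'id': ['my_id']})
        >>> with session.open_arrow_writer(0) as writer:
        ...     writer.write(batch)
        >>> session.commit([0])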
"""
return self._open_writer(
block_id=block_id,
compress=compress,
buffer_size=buffer_size,
initial_block_id=initial_block_id,
block_id_gen=block_id_gen,
writer_cls=ArrowWriter,
)
def get_block_list(self):
self.reload()
return self.blocks
def commit(self, blocks):
"""
Commit written blocks to the tunnel. Can be called only once on a single session.
:param list blocks: list of block ids to commit
"""
if blocks is None:
raise ValueError("Invalid parameter: blocks.")
if isinstance(blocks, six.integer_types):
blocks = [blocks]
server_block_map = dict(
[(int(block_id), True) for block_id in self.get_block_list()]
)
client_block_map = dict([(int(block_id), True) for block_id in blocks])
if len(server_block_map) != len(client_block_map):
            raise TunnelError(
                "Block count mismatch: %s on server, %s on client. "
                "Make sure all block writers are closed or with-blocks have exited."
                % (len(server_block_map), len(client_block_map))
            )
for block_id in blocks:
if block_id not in server_block_map:
                raise TunnelError(
                    "Block does not exist on server, block id is %s" % (block_id,)
                )
self._complete_upload()
def _complete_upload(self):
headers = self.get_common_headers()
params = self.get_common_params(uploadid=self.id)
url = self._table.table_resource()
resp = utils.call_with_retry(
self._client.post,
url,
"",
params=params,
headers=headers,
exc_type=(
requests.Timeout,
requests.ConnectionError,
errors.InternalServerError,
),
)
self.parse(resp, obj=self)
class Slot(object):
def __init__(self, slot, server):
self._slot = slot
self._ip = None
self._port = None
self.set_server(server, True)
@property
def slot(self):
return self._slot
@property
def ip(self):
return self._ip
@property
def port(self):
return self._port
@property
def server(self):
return str(self._ip) + ":" + str(self._port)
def set_server(self, server, check_empty=False):
if len(server.split(":")) != 2:
raise TunnelError("Invalid slot format: {}".format(server))
ip, port = server.split(":")
if check_empty:
if (not ip) or (not port):
raise TunnelError("Empty server ip or port")
if ip:
self._ip = ip
if port:
self._port = int(port)
class TableStreamUploadSession(BaseTableTunnelSession):
"""
Tunnel session for uploading data in stream method to tables. Instances
of this class should be created by :meth:`TableTunnel.create_stream_upload_session`.
"""
__slots__ = (
"_client",
"_table",
"_partition_spec",
"_compress_option",
"_quota_name",
"_create_partition",
"_zorder_columns",
"_allow_schema_mismatch",
"_schema_version_reloader",
"_tags",
)
class Slots(object):
def __init__(self, slot_elements):
self._slots = []
self._cur_index = -1
for value in slot_elements:
if len(value) != 2:
raise TunnelError("Invalid slot routes")
self._slots.append(Slot(value[0], value[1]))
if len(self._slots) > 0:
self._cur_index = random.randint(0, len(self._slots))
self._iter = iter(self)
def __len__(self):
return len(self._slots)
def __next__(self):
return next(self._iter)
def __iter__(self):
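            # rotate through available slots round-robin; yield None when no slot exists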
while True:
if self._cur_index < 0:
yield None
else:
self._cur_index += 1
if self._cur_index >= len(self._slots):
self._cur_index = 0
yield self._slots[self._cur_index]
schema = serializers.JSONNodeReferenceField(TableSchema, "schema")
id = serializers.JSONNodeField("session_name")
status = serializers.JSONNodeField("status")
slots = serializers.JSONNodeField(
"slots", parse_callback=lambda val: TableStreamUploadSession.Slots(val)
)
quota_name = serializers.JSONNodeField("QuotaName")
schema_version = serializers.JSONNodeField("schema_version")
def __init__(
self,
client,
table,
partition_spec,
compress_option=None,
quota_name=None,
create_partition=False,
zorder_columns=None,
schema_version=None,
allow_schema_mismatch=True,
upload_id=None,
tags=None,
schema_version_reloader=None,
):
super(TableStreamUploadSession, self).__init__()
self._client = client
self._table = table
self._partition_spec = self.normalize_partition_spec(partition_spec)
self._quota_name = quota_name
self._create_partition = create_partition
self._zorder_columns = zorder_columns
self._allow_schema_mismatch = allow_schema_mismatch
self.schema_version = schema_version
self._schema_version_reloader = schema_version_reloader
self._tags = tags or options.tunnel.tags
if isinstance(self._tags, six.string_types):
self._tags = self._tags.split(",")
if upload_id is None:
if not allow_schema_mismatch and not schema_version:
self._init_with_latest_schema()
else:
self._init()
else:
self.id = upload_id
self.reload()
self._compress_option = compress_option or self._get_default_compress_option()
logger.info("Tunnel session created: %r", self)
if options.tunnel_session_create_callback:
options.tunnel_session_create_callback(self)
def __repr__(self):
return (
"<TableStreamUploadSession id=%s project=%s table=%s partition_spec=%s>"
% (
self.id,
self._table.project.name,
self._table.name,
self._partition_spec,
)
)
def _init(self):
params = self.get_common_params()
headers = self.get_common_headers(content_length=0, tags=self._tags)
if self._create_partition:
params["create_partition"] = "true"
if self.schema_version is not None:
params["schema_version"] = str(self.schema_version)
if self._zorder_columns:
cols = self._zorder_columns
if not isinstance(self._zorder_columns, six.string_types):
cols = ",".join(self._zorder_columns)
params["zorder_columns"] = cols
params["check_latest_schema"] = str(not self._allow_schema_mismatch).lower()
url = self._get_resource()
resp = self._client.post(url, {}, params=params, headers=headers)
self.check_tunnel_response(resp)
self.parse(resp, obj=self)
self._quota_name = self.quota_name
if self.schema is not None:
self.schema.build_snapshot()
def _init_with_latest_schema(self):
def init_with_table_version():
self.schema_version = self._schema_version_reloader()
self._init()
return utils.call_with_retry(
init_with_table_version, retry_times=None, exc_type=errors.NoSuchSchema
)
def _get_resource(self):
return self._table.table_resource() + "/streams"
def reload(self):
params = self.get_common_params(uploadid=self.id)
headers = self.get_common_headers(content_length=0, tags=self._tags)
url = self._get_resource()
resp = self._client.get(url, params=params, headers=headers)
self.check_tunnel_response(resp)
self.parse(resp, obj=self)
self._quota_name = self.quota_name
if self.schema is not None:
self.schema.build_snapshot()
def abort(self):
"""
Abort the upload session.
"""
params = self.get_common_params(uploadid=self.id)
slot = next(iter(self.slots))
headers = self.get_common_headers(content_length=0, tags=self._tags)
headers["odps-tunnel-routed-server"] = slot.server
url = self._get_resource()
resp = self._client.post(url, {}, params=params, headers=headers)
self.check_tunnel_response(resp)
def reload_slots(self, slot, server, slot_num):
if len(self.slots) != slot_num:
self.reload()
else:
slot.set_server(server)
def _open_writer(self, compress=False):
compress_option = self._compress_option or CompressOption()
slot = next(iter(self.slots))
headers = self.get_common_headers(chunked=True, tags=self._tags)
headers.update(
{
"odps-tunnel-slot-num": str(len(self.slots)),
"odps-tunnel-routed-server": slot.server,
}
)
if compress:
encoding = compress_option.algorithm.get_encoding()
if encoding:
headers["Content-Encoding"] = encoding
params = self.get_common_params(uploadid=self.id, slotid=slot.slot)
url = self._get_resource()
option = compress_option if compress else None
@_wrap_upload_call(self.id)
def upload_block(data):
return self._client.put(url, data=data, params=params, headers=headers)
writer = StreamRecordWriter(
self.schema, upload_block, session=self, slot=slot, compress_option=option
)
return writer
def open_record_writer(self, compress=False):
"""
Open a writer to write data in records to the tunnel.
:param bool compress: whether to compress data
:return: a record writer
        :rtype: :class:`StreamRecordWriter`
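        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table. Stream upload sessions need no
        explicit commit step.
        >>> session = TableTunnel(o).create_stream_upload_session('test_table')
        >>> with session.open_record_writer() as writer:
        ...     writer.write(session.new_record(['my_name', 'my_id']))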
"""
return self._open_writer(compress=compress)
class TableUpsertSession(BaseTableTunnelSession):
"""
Tunnel session for inserting or updating data to upsert tables. Instances
of this class should be created by :meth:`TableTunnel.create_upsert_session`.
"""
__slots__ = (
"_client",
"_table",
"_partition_spec",
"_compress_option",
"_slot_num",
"_commit_timeout",
"_quota_name",
"_lifecycle",
"_tags",
)
UPSERT_EXTRA_COL_NUM = 5
UPSERT_VERSION_KEY = "__version"
UPSERT_APP_VERSION_KEY = "__app_version"
UPSERT_OPERATION_KEY = "__operation"
UPSERT_KEY_COLS_KEY = "__key_cols"
UPSERT_VALUE_COLS_KEY = "__value_cols"
class Status(Enum):
Normal = "NORMAL"
Committing = "COMMITTING"
Committed = "COMMITTED"
Expired = "EXPIRED"
Critical = "CRITICAL"
Aborted = "ABORTED"
class Slots(object):
def __init__(self, slot_elements):
self._slots = []
self._buckets = {}
for value in slot_elements:
slot = Slot(value["slot_id"], value["worker_addr"])
self._slots.append(slot)
self._buckets.update({idx: slot for idx in value["buckets"]})
for idx in self._buckets.keys():
if idx > len(self._buckets):
raise TunnelError("Invalid bucket value: " + str(idx))
@property
def buckets(self):
return self._buckets
def __len__(self):
return len(self._slots)
schema = serializers.JSONNodeReferenceField(TableSchema, "schema")
id = serializers.JSONNodeField("id")
status = serializers.JSONNodeField(
"status", parse_callback=lambda s: TableUpsertSession.Status(s.upper())
)
slots = serializers.JSONNodeField(
"slots", parse_callback=lambda val: TableUpsertSession.Slots(val)
)
quota_name = serializers.JSONNodeField("quota_name")
hash_keys = serializers.JSONNodeField("hash_key")
hasher = serializers.JSONNodeField("hasher")
def __init__(
self,
client,
table,
partition_spec,
compress_option=None,
slot_num=1,
commit_timeout=DEFAULT_UPSERT_COMMIT_TIMEOUT,
lifecycle=None,
quota_name=None,
upsert_id=None,
tags=None,
):
super(TableUpsertSession, self).__init__()
self._client = client
self._table = table
self._partition_spec = self.normalize_partition_spec(partition_spec)
self._lifecycle = lifecycle
self._quota_name = quota_name
self._slot_num = slot_num
self._commit_timeout = commit_timeout
self._tags = tags or options.tunnel.tags
if isinstance(self._tags, six.string_types):
self._tags = self._tags.split(",")
if upsert_id is None:
self._init()
else:
self.id = upsert_id
self.reload()
self._compress_option = compress_option or self._get_default_compress_option()
logger.info("Upsert session created: %r", self)
if options.tunnel_session_create_callback:
options.tunnel_session_create_callback(self)
def __repr__(self):
return "<TableUpsertSession id=%s project=%s table=%s partition_spec=%s>" % (
self.id,
self._table.project.name,
self._table.name,
self._partition_spec,
)
@property
def endpoint(self):
return self._client.endpoint
@property
def buckets(self):
return self.slots.buckets
def _get_resource(self):
return self._table.table_resource() + "/upserts"
def _patch_schema(self):
if self.schema is None:
return
patch_schema = types.OdpsSchema(
[
Column(self.UPSERT_VERSION_KEY, "bigint"),
Column(self.UPSERT_APP_VERSION_KEY, "bigint"),
Column(self.UPSERT_OPERATION_KEY, "tinyint"),
Column(self.UPSERT_KEY_COLS_KEY, "array<bigint>"),
Column(self.UPSERT_VALUE_COLS_KEY, "array<bigint>"),
],
)
self.schema = self.schema.extend(patch_schema)
self.schema.build_snapshot()
def _init_or_reload(self, reload=False):
params = self.get_common_params()
headers = self.get_common_headers(content_length=0, tags=self._tags)
if not reload:
params["slotnum"] = str(self._slot_num)
else:
params["upsertid"] = self.id
url = self._get_resource()
if not reload:
if self._lifecycle:
params["lifecycle"] = self._lifecycle
resp = self._client.post(url, {}, params=params, headers=headers)
else:
resp = self._client.get(url, params=params, headers=headers)
if self._client.is_ok(resp):
self.parse(resp, obj=self)
self._patch_schema()
else:
e = TunnelError.parse(resp)
raise e
def _init(self):
self._init_or_reload()
def new_record(self, values=None):
if values:
            values = list(values) + [None] * self.UPSERT_EXTRA_COL_NUM
return super(TableUpsertSession, self).new_record(values)
def reload(self, init=False):
self._init_or_reload(reload=True)
def abort(self):
"""
Abort the current session.
"""
params = self.get_common_params(upsertid=self.id)
headers = self.get_common_headers(content_length=0, tags=self._tags)
headers["odps-tunnel-routed-server"] = self.slots.buckets[0].server
url = self._get_resource()
resp = self._client.delete(url, params=params, headers=headers)
self.check_tunnel_response(resp)
def open_upsert_stream(self, compress=False):
"""
Open an upsert stream to insert or update data in records to the tunnel.
:param bool compress: whether to compress data
:return: an upsert stream
:rtype: :class:`Upsert`
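        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_upsert_table`` an existing upsert table with a primary key
        column and a value column.
        >>> session = TableTunnel(o).create_upsert_session('test_upsert_table')
        >>> stream = session.open_upsert_stream()
        >>> stream.upsert(session.new_record(['my_key', 'my_value']))
        >>> stream.flush()
        >>> session.commit()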
"""
params = self.get_common_params(upsertid=self.id)
headers = self.get_common_headers(tags=self._tags)
compress_option = self._compress_option or CompressOption()
if not compress:
compress_option = None
else:
encoding = compress_option.algorithm.get_encoding()
if encoding:
headers["Content-Encoding"] = encoding
url = self._get_resource()
@_wrap_upload_call(self.id)
def upload_block(bucket, slot, record_count, data):
req_params = params.copy()
req_params.update(
dict(
bucketid=bucket,
slotid=str(slot.slot),
record_count=str(record_count),
)
)
req_headers = headers.copy()
req_headers["odps-tunnel-routed-server"] = slot.server
req_headers["Content-Length"] = len(data)
return self._client.put(
url, data=data, params=req_params, headers=req_headers
)
return Upsert(self.schema, upload_block, self, compress_option)
def commit(self, async_=False):
"""
        Commit the current session. Can be called only once on a single session.
        :param bool async_: if True, return immediately without waiting for the
            commit to finish on the server
"""
params = self.get_common_params(upsertid=self.id)
headers = self.get_common_headers(content_length=0, tags=self._tags)
headers["odps-tunnel-routed-server"] = self.slots.buckets[0].server
url = self._get_resource()
resp = self._client.post(url, params=params, headers=headers)
self.check_tunnel_response(resp)
self.reload()
if async_:
return
delay = 1
start = monotonic()
while self.status in (
TableUpsertSession.Status.Committing,
TableUpsertSession.Status.Normal,
):
try:
if monotonic() - start > self._commit_timeout:
raise TunnelError("Commit session timeout")
time.sleep(delay)
resp = self._client.post(url, params=params, headers=headers)
self.check_tunnel_response(resp)
self.reload()
delay = min(8, delay * 2)
except (errors.StreamSessionNotFound, errors.UpsertSessionNotFound):
self.status = TableUpsertSession.Status.Committed
if self.status != TableUpsertSession.Status.Committed:
raise TunnelError("commit session failed, status: " + self.status.value)
class TableTunnel(BaseTunnel):
"""
Table tunnel API Entry.
:param odps: ODPS Entry object
:param str project: project name
:param str endpoint: tunnel endpoint
:param str quota_name: name of tunnel quota
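    :Example:
    A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
    and ``test_table`` an existing table.
    >>> from odps.tunnel import TableTunnel
    >>> tunnel = TableTunnel(o)
    >>> session = tunnel.create_download_session('test_table')
    >>> with session.open_record_reader(0, session.count) as reader:
    ...     records = list(reader)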
"""
def _get_tunnel_table(self, table, schema=None):
project_odps = None
try:
project_odps = self._project.odps
if isinstance(table, six.string_types):
table = project_odps.get_table(table, project=self._project.name)
except:
pass
project_name = self._project.name
if not isinstance(table, six.string_types):
project_name = table.project.name or project_name
schema = schema or getattr(table.get_schema(), "name", None)
table = table.name
parent = Projects(client=self.tunnel_rest)[project_name]
# tailor project for resource locating only
parent._set_tunnel_defaults(odps_entry=project_odps)
if schema is not None:
parent = parent.schemas[schema]
return parent.tables[table]
@staticmethod
def _build_compress_option(compress_algo=None, level=None, strategy=None):
if compress_algo is None:
return None
return CompressOption(
compress_algo=compress_algo, level=level, strategy=strategy
)
def create_download_session(
self,
table,
async_mode=True,
partition_spec=None,
download_id=None,
compress_option=None,
compress_algo=None,
compress_level=None,
compress_strategy=None,
schema=None,
timeout=None,
tags=None,
**kw
):
"""
Create a download session for table.
:param table: table object to read
:type table: str | :class:`odps.models.Table`
:param partition_spec: partition spec to read
:type partition_spec: str | :class:`odps.types.PartitionSpec`
:param str download_id: existing download id
:param compress_option: compress option
:type compress_option: :class:`odps.tunnel.CompressOption`
:param str compress_algo: compress algorithm
:param int compress_level: compress level
:param str schema: name of schema of the table
        :param tags: tags of the download session
:type tags: str | list
:return: :class:`TableDownloadSession`
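        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table partitioned by ``pt``.
        >>> tunnel = TableTunnel(o)
        >>> session = tunnel.create_download_session('test_table', partition_spec='pt=test')
        >>> print(session.count)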
"""
table = self._get_tunnel_table(table, schema)
compress_option = compress_option or self._build_compress_option(
compress_algo=compress_algo,
level=compress_level,
strategy=compress_strategy,
)
if "async_" in kw:
async_mode = kw.pop("async_")
if kw:
raise TypeError("Cannot accept arguments %s" % ", ".join(kw.keys()))
return TableDownloadSession(
self.tunnel_rest,
table,
partition_spec,
download_id=download_id,
compress_option=compress_option,
async_mode=async_mode,
timeout=timeout,
quota_name=self._quota_name,
tags=tags,
)
def create_upload_session(
self,
table,
partition_spec=None,
upload_id=None,
compress_option=None,
compress_algo=None,
compress_level=None,
compress_strategy=None,
schema=None,
overwrite=False,
create_partition=False,
tags=None,
):
"""
Create an upload session for table.
        :param table: table object to write to
:type table: str | :class:`odps.models.Table`
:param partition_spec: partition spec
:type partition_spec: str | :class:`odps.types.PartitionSpec`
:param str upload_id: existing upload id
:param compress_option: compress option
:type compress_option: :class:`odps.tunnel.CompressOption`
:param str compress_algo: compress algorithm
:param int compress_level: compress level
:param str schema: name of schema of the table
:param bool overwrite: whether to overwrite the table
        :param bool create_partition: whether to create the partition if it does not exist
:param tags: tags of the upload session
:type tags: str | list
:return: :class:`TableUploadSession`
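        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry
        and ``test_table`` an existing table with two string columns.
        >>> tunnel = TableTunnel(o)
        >>> session = tunnel.create_upload_session('test_table')
        >>> with session.open_record_writer(0) as writer:
        ...     writer.write(session.new_record(['my_name', 'my_id']))
        >>> session.commit([0])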
"""
table = self._get_tunnel_table(table, schema)
compress_option = compress_option or self._build_compress_option(
compress_algo=compress_algo,
level=compress_level,
strategy=compress_strategy,
)
return TableUploadSession(
self.tunnel_rest,
table,
partition_spec,
upload_id=upload_id,
compress_option=compress_option,
overwrite=overwrite,
quota_name=self._quota_name,
create_partition=create_partition,
tags=tags,
)
def create_stream_upload_session(
self,
table,
partition_spec=None,
compress_option=None,
compress_algo=None,
compress_level=None,
compress_strategy=None,
schema=None,
schema_version=None,
zorder_columns=None,
upload_id=None,
tags=None,
allow_schema_mismatch=True,
create_partition=False,
):
"""
Create a stream upload session for table.
        :param table: table object to write to
:type table: str | :class:`odps.models.Table`
:param partition_spec: partition spec
:type partition_spec: str | :class:`odps.types.PartitionSpec`
:param str upload_id: existing upload id
:param compress_option: compress option
:type compress_option: :class:`odps.tunnel.CompressOption`
:param str compress_algo: compress algorithm
:param int compress_level: compress level
:param str schema: name of schema of the table
:param str schema_version: schema version of the upload
:param tags: tags of the upload session
:type tags: str | list
        :param bool allow_schema_mismatch: whether to allow the session schema to
            differ from the latest table schema
        :param bool create_partition: whether to create the partition if it does not exist
:return: :class:`TableStreamUploadSession`
"""
table = self._get_tunnel_table(table, schema)
compress_option = compress_option or self._build_compress_option(
compress_algo=compress_algo,
level=compress_level,
strategy=compress_strategy,
)
version_need_reloaded = [False]
def schema_version_reloader():
src_table = self._project.tables[table.name]
if version_need_reloaded[0]:
src_table.reload_extend_info()
version_need_reloaded[0] = True
return src_table.schema_version
return TableStreamUploadSession(
self.tunnel_rest,
table,
partition_spec,
compress_option=compress_option,
quota_name=self._quota_name,
schema_version=schema_version,
upload_id=upload_id,
tags=tags,
allow_schema_mismatch=allow_schema_mismatch,
schema_version_reloader=schema_version_reloader,
create_partition=create_partition,
zorder_columns=zorder_columns,
)
def create_upsert_session(
self,
table,
partition_spec=None,
slot_num=1,
commit_timeout=120,
compress_option=None,
compress_algo=None,
compress_level=None,
compress_strategy=None,
schema=None,
upsert_id=None,
tags=None,
):
"""
Create an upsert session for table.
        :param table: table object to write to
:type table: str | :class:`odps.models.Table`
:param partition_spec: partition spec
:type partition_spec: str | :class:`odps.types.PartitionSpec`
:param str upsert_id: existing upsert id
        :param int commit_timeout: timeout in seconds for committing the session
:param compress_option: compress option
:type compress_option: :class:`odps.tunnel.CompressOption`
:param str compress_algo: compress algorithm
:param int compress_level: compress level
:param str schema: name of schema of the table
:param tags: tags of the upload session
:type tags: str | list
:return: :class:`TableUpsertSession`
"""
table = self._get_tunnel_table(table, schema)
compress_option = compress_option or self._build_compress_option(
compress_algo=compress_algo,
level=compress_level,
strategy=compress_strategy,
)
return TableUpsertSession(
self.tunnel_rest,
table,
partition_spec,
slot_num=slot_num,
upsert_id=upsert_id,
commit_timeout=commit_timeout,
compress_option=compress_option,
quota_name=self._quota_name,
tags=tags,
)
def open_preview_reader(
self,
table,
partition_spec=None,
columns=None,
limit=None,
compress_option=None,
compress_algo=None,
compress_level=None,
compress_strategy=None,
arrow=True,
timeout=None,
make_compat=True,
read_all=False,
tags=None,
):
"""
Open a preview reader for table to read initial rows.
:param table: table object to read
:type table: str | :class:`odps.models.Table`
:param partition_spec: partition spec to read
:type partition_spec: str | :class:`odps.types.PartitionSpec`
:param columns: columns to read
:param int limit: number of rows to read, 10000 by default
:param compress_option: compress option
:type compress_option: :class:`odps.tunnel.CompressOption`
:param str compress_algo: compress algorithm
:param int compress_level: compress level
        :param bool arrow: if True, return an Arrow reader, otherwise return a record reader
        :param tags: tags of the preview request
        :type tags: str | list
        :return: an Arrow reader when ``arrow`` is True, otherwise a record reader
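        :Example:
        A minimal usage sketch; ``o`` is assumed to be an existing ODPS entry,
        ``test_table`` an existing table, and ``to_pandas`` assumes pandas is
        installed alongside pyarrow.
        >>> table = o.get_table('test_table')
        >>> reader = TableTunnel(o).open_preview_reader(table, limit=10)
        >>> df = reader.to_pandas()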
"""
if pa is None:
raise ImportError("Need pyarrow to run open_preview_reader.")
tunnel_table = self._get_tunnel_table(table)
compress_option = compress_option or self._build_compress_option(
compress_algo=compress_algo,
level=compress_level,
strategy=compress_strategy,
)
params = {"limit": str(limit) if limit else "-1"}
partition_spec = BaseTableTunnelSession.normalize_partition_spec(partition_spec)
if columns:
col_set = set(columns)
ordered_col = [c.name for c in table.table_schema if c.name in col_set]
params["columns"] = ",".join(ordered_col)
if partition_spec is not None and len(partition_spec) > 0:
params["partition"] = partition_spec
headers = BaseTableTunnelSession.get_common_headers(content_length=0, tags=tags)
if compress_option:
encoding = compress_option.algorithm.get_encoding(legacy=False)
if encoding:
headers["Accept-Encoding"] = encoding
url = tunnel_table.table_resource(force_schema=True) + "/preview"
resp = self.tunnel_rest.get(
url, stream=True, params=params, headers=headers, timeout=timeout
)
if not self.tunnel_rest.is_ok(resp): # pragma: no cover
e = TunnelError.parse(resp)
raise e
input_stream = get_decompress_stream(resp)
if input_stream.peek() is None:
# stream is empty, replace with empty stream
input_stream = None
def stream_creator(pos):
# part retry not supported currently
assert pos == 0
return input_stream
reader = TunnelArrowReader(
table.table_schema, stream_creator, columns=columns, use_ipc_stream=True
)
if not arrow:
reader = ArrowRecordReader(
reader, make_compat=make_compat, read_all=read_all
)
return reader