odps/tunnel/tabletunnel.py (1,167 lines of code):

#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2025 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging import random import sys import time import requests from .. import errors, options, serializers, types, utils from ..compat import Enum, six from ..lib.monotonic import monotonic from ..models import Projects, Record, TableSchema from ..types import Column from .base import TUNNEL_VERSION, BaseTunnel from .errors import TunnelError, TunnelReadTimeout, TunnelWriteTimeout from .io.reader import ArrowRecordReader, TunnelArrowReader, TunnelRecordReader from .io.stream import CompressOption, get_decompress_stream from .io.writer import ( ArrowWriter, BufferedArrowWriter, BufferedRecordWriter, RecordWriter, StreamRecordWriter, Upsert, ) try: import numpy as np except ImportError: np = None try: import pyarrow as pa except ImportError: pa = None logger = logging.getLogger(__name__) TUNNEL_DATA_TRANSFORM_VERSION = "v1" DEFAULT_UPSERT_COMMIT_TIMEOUT = 120 def _wrap_upload_call(request_id): def wrapper(func): @six.wraps(func) def wrapped(*args, **kwargs): try: return func(*args, **kwargs) except requests.ConnectionError as ex: ex_str = str(ex) if "timed out" in ex_str: raise TunnelWriteTimeout(ex_str, request_id=request_id) else: raise return wrapped return wrapper class BaseTableTunnelSession(serializers.JSONSerializableModel): @staticmethod def get_common_headers(content_length=None, chunked=False, tags=None): header = { "odps-tunnel-date-transform": TUNNEL_DATA_TRANSFORM_VERSION, "odps-tunnel-sdk-support-schema-evolution": "true", "x-odps-tunnel-version": TUNNEL_VERSION, } if content_length is not None: header["Content-Length"] = content_length if chunked: header.update( { "Transfer-Encoding": "chunked", "Content-Type": "application/octet-stream", } ) tags = tags or options.tunnel.tags if tags: if isinstance(tags, six.string_types): tags = tags.split(",") header["odps-tunnel-tags"] = ",".join(tags) return header @staticmethod def normalize_partition_spec(partition_spec): if isinstance(partition_spec, six.string_types): partition_spec = types.PartitionSpec(partition_spec) if isinstance(partition_spec, types.PartitionSpec): partition_spec = str(partition_spec).replace("'", "") return partition_spec def get_common_params(self, **kwargs): params = {k: str(v) for k, v in kwargs.items()} if getattr(self, "_quota_name", None): params["quotaName"] = self._quota_name if self._partition_spec is not None and len(self._partition_spec) > 0: params["partition"] = self._partition_spec return params def check_tunnel_response(self, resp): if not self._client.is_ok(resp): e = TunnelError.parse(resp) raise e @classmethod def _get_default_compress_option(cls): if not options.tunnel.compress.enabled: return None return CompressOption( compress_algo=options.tunnel.compress.algo, level=options.tunnel.compress.level, strategy=options.tunnel.compress.strategy, ) def new_record(self, values=None): """ Generate a record of the current upload session. 
:param values: the values of this records :type values: list :return: record :rtype: :class:`odps.models.Record` :Example: >>> session = TableTunnel(o).create_upload_session('test_table') >>> record = session.new_record() >>> record[0] = 'my_name' >>> record[1] = 'my_id' >>> record = session.new_record(['my_name', 'my_id']) .. seealso:: :class:`odps.models.Record` """ return Record( schema=self.schema, values=values, max_field_size=getattr(self, "max_field_size", None), ) class TableDownloadSession(BaseTableTunnelSession): """ Tunnel session for downloading data from tables. Instances of this class should be created by :meth:`TableTunnel.create_download_session`. """ __slots__ = ( "_client", "_table", "_partition_spec", "_compress_option", "_quota_name", "_tags", ) class Status(Enum): Unknown = "UNKNOWN" Normal = "NORMAL" Closes = "CLOSES" Expired = "EXPIRED" Initiating = "INITIATING" id = serializers.JSONNodeField("DownloadID") status = serializers.JSONNodeField( "Status", parse_callback=lambda s: TableDownloadSession.Status(s.upper()) ) count = serializers.JSONNodeField("RecordCount") schema = serializers.JSONNodeReferenceField(TableSchema, "Schema") quota_name = serializers.JSONNodeField("QuotaName") def __init__( self, client, table, partition_spec, download_id=None, compress_option=None, async_mode=True, timeout=None, quota_name=None, tags=None, **kw ): super(TableDownloadSession, self).__init__() self._client = client self._table = table self._partition_spec = self.normalize_partition_spec(partition_spec) self._quota_name = quota_name if "async_" in kw: async_mode = kw.pop("async_") if kw: raise TypeError("Cannot accept arguments %s" % ", ".join(kw.keys())) self._tags = tags or options.tunnel.tags if isinstance(self._tags, six.string_types): self._tags = self._tags.split(",") if download_id is None: self._init(async_mode=async_mode, timeout=timeout) else: self.id = download_id self.reload() self._compress_option = compress_option or self._get_default_compress_option() logger.info("Tunnel session created: %r", self) if options.tunnel_session_create_callback: options.tunnel_session_create_callback(self) def __repr__(self): return "<TableDownloadSession id=%s project=%s table=%s partition_spec=%r>" % ( self.id, self._table.project.name, self._table.name, self._partition_spec, ) def _init(self, async_mode, timeout): params = self.get_common_params(downloads="") headers = self.get_common_headers(content_length=0, tags=self._tags) if async_mode: params["asyncmode"] = "true" url = self._table.table_resource() ts = monotonic() try: resp = self._client.post( url, {}, params=params, headers=headers, timeout=timeout ) except requests.exceptions.ReadTimeout: if callable(options.tunnel_session_create_timeout_callback): options.tunnel_session_create_timeout_callback(*sys.exc_info()) raise self.check_tunnel_response(resp) delay_time = 0.1 self.parse(resp, obj=self) while self.status == self.Status.Initiating: if timeout and monotonic() - ts > timeout: try: raise TunnelReadTimeout( "Waiting for tunnel ready timed out. 
id=%s, table=%s" % (self.id, self._table.name) ) except TunnelReadTimeout: if callable(options.tunnel_session_create_timeout_callback): options.tunnel_session_create_timeout_callback(*sys.exc_info()) raise time.sleep(delay_time) delay_time = min(delay_time * 2, 5) self.reload() if self.schema is not None: self.schema.build_snapshot() def reload(self): params = self.get_common_params(downloadid=self.id) headers = self.get_common_headers(content_length=0, tags=self._tags) url = self._table.table_resource() resp = self._client.get(url, params=params, headers=headers) self.check_tunnel_response(resp) self.parse(resp, obj=self) if self.schema is not None: self.schema.build_snapshot() def _build_input_stream( self, start, count, compress=False, columns=None, arrow=False ): compress_option = self._compress_option or CompressOption() actions = ["data"] params = self.get_common_params(downloadid=self.id) headers = self.get_common_headers(content_length=0, tags=self._tags) if compress: encoding = compress_option.algorithm.get_encoding() if encoding: headers["Accept-Encoding"] = encoding params["rowrange"] = "(%s,%s)" % (start, count) if columns is not None and len(columns) > 0: col_name = lambda col: col.name if isinstance(col, types.Column) else col params["columns"] = ",".join(col_name(col) for col in columns) if arrow: actions.append("arrow") url = self._table.table_resource() resp = self._client.get( url, stream=True, actions=actions, params=params, headers=headers ) self.check_tunnel_response(resp) content_encoding = resp.headers.get("Content-Encoding") if content_encoding is not None: compress_algo = CompressOption.CompressAlgorithm.from_encoding( content_encoding ) if compress_algo != compress_option.algorithm: compress_option = self._compress_option = CompressOption( compress_algo, -1, 0 ) compress = True else: compress = False option = compress_option if compress else None return get_decompress_stream(resp, option) def _open_reader( self, start, count, compress=None, columns=None, arrow=False, reader_cls=None, **kw ): pt_cols = ( set(types.PartitionSpec(self._partition_spec).keys()) if self._partition_spec else set() ) reader_cols = [c for c in columns if c not in pt_cols] if columns else columns if compress is None: compress = self._compress_option is not None stream_kw = dict(compress=compress, columns=reader_cols, arrow=arrow) def stream_creator(cursor): return self._build_input_stream(start + cursor, count - cursor, **stream_kw) return reader_cls(self.schema, stream_creator, columns=columns, **kw) def open_record_reader( self, start, count, compress=False, columns=None, append_partitions=True ): """ Open a reader to read data as records from the tunnel. :param int start: start row index :param int count: number of rows to read :param bool compress: whether to compress data :columns: list of column names to read :append_partitions: whether to append partition values as columns :return: a record reader :rtype: :class:`TunnelRecordReader` """ return self._open_reader( start, count, compress=compress, columns=columns, append_partitions=append_partitions, partition_spec=self._partition_spec, reader_cls=TunnelRecordReader, ) def open_arrow_reader( self, start, count, compress=False, columns=None, append_partitions=False ): """ Open a reader to read data as Arrow format from the tunnel. 
:param int start: start row index :param int count: number of rows to read :param bool compress: whether to compress data :columns: list of column names to read :append_partitions: whether to append partition values as columns :return: an Arrow reader :rtype: :class:`TunnelArrowReader` """ return self._open_reader( start, count, compress=compress, columns=columns, arrow=True, append_partitions=append_partitions, partition_spec=self._partition_spec, reader_cls=TunnelArrowReader, ) class TableUploadSession(BaseTableTunnelSession): """ Tunnel session for uploading data to tables. Instances of this class should be created by :meth:`TableTunnel.create_upload_session`. """ __slots__ = ( "_client", "_table", "_partition_spec", "_compress_option", "_create_partition", "_overwrite", "_quota_name", "_tags", ) class Status(Enum): Unknown = "UNKNOWN" Normal = "NORMAL" Closing = "CLOSING" Closed = "CLOSED" Canceled = "CANCELED" Expired = "EXPIRED" Critical = "CRITICAL" id = serializers.JSONNodeField("UploadID") status = serializers.JSONNodeField( "Status", parse_callback=lambda s: TableUploadSession.Status(s.upper()) ) blocks = serializers.JSONNodesField("UploadedBlockList", "BlockID") schema = serializers.JSONNodeReferenceField(TableSchema, "Schema") max_field_size = serializers.JSONNodeField("MaxFieldSize") quota_name = serializers.JSONNodeField("QuotaName") def __init__( self, client, table, partition_spec, upload_id=None, compress_option=None, create_partition=None, overwrite=False, quota_name=None, tags=None, ): super(TableUploadSession, self).__init__() self._client = client self._table = table self._partition_spec = self.normalize_partition_spec(partition_spec) self._create_partition = create_partition self._quota_name = quota_name self._overwrite = overwrite self._tags = tags or options.tunnel.tags if isinstance(self._tags, six.string_types): self._tags = self._tags.split(",") if upload_id is None: self._init() else: self.id = upload_id self.reload() self._compress_option = compress_option or self._get_default_compress_option() logger.info("Tunnel session created: %r", self) if options.tunnel_session_create_callback: options.tunnel_session_create_callback(self) def __repr__(self): repr_args = "id=%s project=%s table=%s partition_spec=%r" % ( self.id, self._table.project.name, self._table.name, self._partition_spec, ) if self._overwrite: repr_args += " overwrite=True" return "<TableUploadSession %s>" % repr_args def _create_or_reload_session(self, reload=False): headers = self.get_common_headers(content_length=0, tags=self._tags) params = self.get_common_params(reload=reload) if self._create_partition: params["create_partition"] = "true" if not reload and self._overwrite: params["overwrite"] = "true" if reload: params["uploadid"] = self.id else: params["uploads"] = 1 def _call_tunnel(func, *args, **kw): resp = func(*args, **kw) self.check_tunnel_response(resp) return resp url = self._table.table_resource() if reload: resp = utils.call_with_retry( _call_tunnel, self._client.get, url, params=params, headers=headers ) else: resp = utils.call_with_retry( _call_tunnel, self._client.post, url, {}, params=params, headers=headers ) self.parse(resp, obj=self) if self.schema is not None: self.schema.build_snapshot() def _init(self): self._create_or_reload_session(reload=False) def reload(self): self._create_or_reload_session(reload=True) @classmethod def _iter_data_in_batches(cls, data): pos = 0 chunk_size = options.chunk_size while pos < len(data): yield data[pos : pos + chunk_size] pos += chunk_size def 
_open_writer( self, block_id=None, compress=None, buffer_size=None, writer_cls=None, initial_block_id=None, block_id_gen=None, ): compress_option = self._compress_option or CompressOption() params = self.get_common_params(uploadid=self.id) headers = self.get_common_headers(chunked=True, tags=self._tags) if compress is None: compress = self._compress_option is not None if compress: # special: rewrite LZ4 to ARROW_LZ4 for arrow tunnels if ( writer_cls is not None and issubclass(writer_cls, ArrowWriter) and compress_option.algorithm == CompressOption.CompressAlgorithm.ODPS_LZ4 ): compress_option.algorithm = ( CompressOption.CompressAlgorithm.ODPS_ARROW_LZ4 ) encoding = compress_option.algorithm.get_encoding() if encoding: headers["Content-Encoding"] = encoding url = self._table.table_resource() option = compress_option if compress else None if block_id is None: @_wrap_upload_call(self.id) def upload_block(blockid, data): params["blockid"] = blockid def upload_func(): if isinstance(data, (bytes, bytearray)): to_upload = self._iter_data_in_batches(data) else: to_upload = data return self._client.put( url, data=to_upload, params=params, headers=headers ) return utils.call_with_retry(upload_func) if writer_cls is ArrowWriter: writer_cls = BufferedArrowWriter params["arrow"] = "" else: writer_cls = BufferedRecordWriter writer = writer_cls( self.schema, upload_block, compress_option=option, buffer_size=buffer_size, block_id=initial_block_id, block_id_gen=block_id_gen, ) else: params["blockid"] = block_id @_wrap_upload_call(self.id) def upload(data): return self._client.put(url, data=data, params=params, headers=headers) if writer_cls is ArrowWriter: params["arrow"] = "" writer = writer_cls(self.schema, upload, compress_option=option) return writer def open_record_writer( self, block_id=None, compress=False, buffer_size=None, initial_block_id=None, block_id_gen=None, ): """ Open a writer to write data in records to the tunnel. :param int block_id: id of the block to write to. If not specified, a :class:`BufferedRecordWriter` will be created. :param int buffer_size: size of the buffer to use for buffered writers. :param bool compress: whether to compress data :return: a record writer :rtype: :class:`RecordWriter` or :class:`BufferedRecordWriter` """ return self._open_writer( block_id=block_id, compress=compress, buffer_size=buffer_size, initial_block_id=initial_block_id, block_id_gen=block_id_gen, writer_cls=RecordWriter, ) def open_arrow_writer( self, block_id=None, compress=False, buffer_size=None, initial_block_id=None, block_id_gen=None, ): """ Open a writer to write data in Arrow format to the tunnel. :param int block_id: id of the block to write to. If not specified, a :class:`BufferedArrowWriter` will be created. :param int buffer_size: size of the buffer to use for buffered writers. :param bool compress: whether to compress data :return: an Arrow writer :rtype: :class:`ArrowWriter` or :class:`BufferedArrowWriter` """ return self._open_writer( block_id=block_id, compress=compress, buffer_size=buffer_size, initial_block_id=initial_block_id, block_id_gen=block_id_gen, writer_cls=ArrowWriter, ) def get_block_list(self): self.reload() return self.blocks def commit(self, blocks): """ Commit written blocks to the tunnel. Can be called only once on a single session. 
:param list blocks: list of block ids to commit """ if blocks is None: raise ValueError("Invalid parameter: blocks.") if isinstance(blocks, six.integer_types): blocks = [blocks] server_block_map = dict( [(int(block_id), True) for block_id in self.get_block_list()] ) client_block_map = dict([(int(block_id), True) for block_id in blocks]) if len(server_block_map) != len(client_block_map): raise TunnelError( "Blocks not match, server: %s, tunnelServerClient: %s. " "Make sure all block writers closed or with-blocks exited." % (len(server_block_map), len(client_block_map)) ) for block_id in blocks: if block_id not in server_block_map: raise TunnelError( "Block not exists on server, block id is %s" % (block_id,) ) self._complete_upload() def _complete_upload(self): headers = self.get_common_headers() params = self.get_common_params(uploadid=self.id) url = self._table.table_resource() resp = utils.call_with_retry( self._client.post, url, "", params=params, headers=headers, exc_type=( requests.Timeout, requests.ConnectionError, errors.InternalServerError, ), ) self.parse(resp, obj=self) class Slot(object): def __init__(self, slot, server): self._slot = slot self._ip = None self._port = None self.set_server(server, True) @property def slot(self): return self._slot @property def ip(self): return self._ip @property def port(self): return self._port @property def server(self): return str(self._ip) + ":" + str(self._port) def set_server(self, server, check_empty=False): if len(server.split(":")) != 2: raise TunnelError("Invalid slot format: {}".format(server)) ip, port = server.split(":") if check_empty: if (not ip) or (not port): raise TunnelError("Empty server ip or port") if ip: self._ip = ip if port: self._port = int(port) class TableStreamUploadSession(BaseTableTunnelSession): """ Tunnel session for uploading data in stream method to tables. Instances of this class should be created by :meth:`TableTunnel.create_stream_upload_session`. 
""" __slots__ = ( "_client", "_table", "_partition_spec", "_compress_option", "_quota_name", "_create_partition", "_zorder_columns", "_allow_schema_mismatch", "_schema_version_reloader", "_tags", ) class Slots(object): def __init__(self, slot_elements): self._slots = [] self._cur_index = -1 for value in slot_elements: if len(value) != 2: raise TunnelError("Invalid slot routes") self._slots.append(Slot(value[0], value[1])) if len(self._slots) > 0: self._cur_index = random.randint(0, len(self._slots)) self._iter = iter(self) def __len__(self): return len(self._slots) def __next__(self): return next(self._iter) def __iter__(self): while True: if self._cur_index < 0: yield None else: self._cur_index += 1 if self._cur_index >= len(self._slots): self._cur_index = 0 yield self._slots[self._cur_index] schema = serializers.JSONNodeReferenceField(TableSchema, "schema") id = serializers.JSONNodeField("session_name") status = serializers.JSONNodeField("status") slots = serializers.JSONNodeField( "slots", parse_callback=lambda val: TableStreamUploadSession.Slots(val) ) quota_name = serializers.JSONNodeField("QuotaName") schema_version = serializers.JSONNodeField("schema_version") def __init__( self, client, table, partition_spec, compress_option=None, quota_name=None, create_partition=False, zorder_columns=None, schema_version=None, allow_schema_mismatch=True, upload_id=None, tags=None, schema_version_reloader=None, ): super(TableStreamUploadSession, self).__init__() self._client = client self._table = table self._partition_spec = self.normalize_partition_spec(partition_spec) self._quota_name = quota_name self._create_partition = create_partition self._zorder_columns = zorder_columns self._allow_schema_mismatch = allow_schema_mismatch self.schema_version = schema_version self._schema_version_reloader = schema_version_reloader self._tags = tags or options.tunnel.tags if isinstance(self._tags, six.string_types): self._tags = self._tags.split(",") if upload_id is None: if not allow_schema_mismatch and not schema_version: self._init_with_latest_schema() else: self._init() else: self.id = upload_id self.reload() self._compress_option = compress_option or self._get_default_compress_option() logger.info("Tunnel session created: %r", self) if options.tunnel_session_create_callback: options.tunnel_session_create_callback(self) def __repr__(self): return ( "<TableStreamUploadSession id=%s project=%s table=%s partition_spec=%s>" % ( self.id, self._table.project.name, self._table.name, self._partition_spec, ) ) def _init(self): params = self.get_common_params() headers = self.get_common_headers(content_length=0, tags=self._tags) if self._create_partition: params["create_partition"] = "true" if self.schema_version is not None: params["schema_version"] = str(self.schema_version) if self._zorder_columns: cols = self._zorder_columns if not isinstance(self._zorder_columns, six.string_types): cols = ",".join(self._zorder_columns) params["zorder_columns"] = cols params["check_latest_schema"] = str(not self._allow_schema_mismatch).lower() url = self._get_resource() resp = self._client.post(url, {}, params=params, headers=headers) self.check_tunnel_response(resp) self.parse(resp, obj=self) self._quota_name = self.quota_name if self.schema is not None: self.schema.build_snapshot() def _init_with_latest_schema(self): def init_with_table_version(): self.schema_version = self._schema_version_reloader() self._init() return utils.call_with_retry( init_with_table_version, retry_times=None, exc_type=errors.NoSuchSchema ) def 
_get_resource(self): return self._table.table_resource() + "/streams" def reload(self): params = self.get_common_params(uploadid=self.id) headers = self.get_common_headers(content_length=0, tags=self._tags) url = self._get_resource() resp = self._client.get(url, params=params, headers=headers) self.check_tunnel_response(resp) self.parse(resp, obj=self) self._quota_name = self.quota_name if self.schema is not None: self.schema.build_snapshot() def abort(self): """ Abort the upload session. """ params = self.get_common_params(uploadid=self.id) slot = next(iter(self.slots)) headers = self.get_common_headers(content_length=0, tags=self._tags) headers["odps-tunnel-routed-server"] = slot.server url = self._get_resource() resp = self._client.post(url, {}, params=params, headers=headers) self.check_tunnel_response(resp) def reload_slots(self, slot, server, slot_num): if len(self.slots) != slot_num: self.reload() else: slot.set_server(server) def _open_writer(self, compress=False): compress_option = self._compress_option or CompressOption() slot = next(iter(self.slots)) headers = self.get_common_headers(chunked=True, tags=self._tags) headers.update( { "odps-tunnel-slot-num": str(len(self.slots)), "odps-tunnel-routed-server": slot.server, } ) if compress: encoding = compress_option.algorithm.get_encoding() if encoding: headers["Content-Encoding"] = encoding params = self.get_common_params(uploadid=self.id, slotid=slot.slot) url = self._get_resource() option = compress_option if compress else None @_wrap_upload_call(self.id) def upload_block(data): return self._client.put(url, data=data, params=params, headers=headers) writer = StreamRecordWriter( self.schema, upload_block, session=self, slot=slot, compress_option=option ) return writer def open_record_writer(self, compress=False): """ Open a writer to write data in records to the tunnel. :param bool compress: whether to compress data :return: a record writer :rtype: :class:`RecordWriter` """ return self._open_writer(compress=compress) class TableUpsertSession(BaseTableTunnelSession): """ Tunnel session for inserting or updating data to upsert tables. Instances of this class should be created by :meth:`TableTunnel.create_upsert_session`. 
""" __slots__ = ( "_client", "_table", "_partition_spec", "_compress_option", "_slot_num", "_commit_timeout", "_quota_name", "_lifecycle", "_tags", ) UPSERT_EXTRA_COL_NUM = 5 UPSERT_VERSION_KEY = "__version" UPSERT_APP_VERSION_KEY = "__app_version" UPSERT_OPERATION_KEY = "__operation" UPSERT_KEY_COLS_KEY = "__key_cols" UPSERT_VALUE_COLS_KEY = "__value_cols" class Status(Enum): Normal = "NORMAL" Committing = "COMMITTING" Committed = "COMMITTED" Expired = "EXPIRED" Critical = "CRITICAL" Aborted = "ABORTED" class Slots(object): def __init__(self, slot_elements): self._slots = [] self._buckets = {} for value in slot_elements: slot = Slot(value["slot_id"], value["worker_addr"]) self._slots.append(slot) self._buckets.update({idx: slot for idx in value["buckets"]}) for idx in self._buckets.keys(): if idx > len(self._buckets): raise TunnelError("Invalid bucket value: " + str(idx)) @property def buckets(self): return self._buckets def __len__(self): return len(self._slots) schema = serializers.JSONNodeReferenceField(TableSchema, "schema") id = serializers.JSONNodeField("id") status = serializers.JSONNodeField( "status", parse_callback=lambda s: TableUpsertSession.Status(s.upper()) ) slots = serializers.JSONNodeField( "slots", parse_callback=lambda val: TableUpsertSession.Slots(val) ) quota_name = serializers.JSONNodeField("quota_name") hash_keys = serializers.JSONNodeField("hash_key") hasher = serializers.JSONNodeField("hasher") def __init__( self, client, table, partition_spec, compress_option=None, slot_num=1, commit_timeout=DEFAULT_UPSERT_COMMIT_TIMEOUT, lifecycle=None, quota_name=None, upsert_id=None, tags=None, ): super(TableUpsertSession, self).__init__() self._client = client self._table = table self._partition_spec = self.normalize_partition_spec(partition_spec) self._lifecycle = lifecycle self._quota_name = quota_name self._slot_num = slot_num self._commit_timeout = commit_timeout self._tags = tags or options.tunnel.tags if isinstance(self._tags, six.string_types): self._tags = self._tags.split(",") if upsert_id is None: self._init() else: self.id = upsert_id self.reload() self._compress_option = compress_option or self._get_default_compress_option() logger.info("Upsert session created: %r", self) if options.tunnel_session_create_callback: options.tunnel_session_create_callback(self) def __repr__(self): return "<TableUpsertSession id=%s project=%s table=%s partition_spec=%s>" % ( self.id, self._table.project.name, self._table.name, self._partition_spec, ) @property def endpoint(self): return self._client.endpoint @property def buckets(self): return self.slots.buckets def _get_resource(self): return self._table.table_resource() + "/upserts" def _patch_schema(self): if self.schema is None: return patch_schema = types.OdpsSchema( [ Column(self.UPSERT_VERSION_KEY, "bigint"), Column(self.UPSERT_APP_VERSION_KEY, "bigint"), Column(self.UPSERT_OPERATION_KEY, "tinyint"), Column(self.UPSERT_KEY_COLS_KEY, "array<bigint>"), Column(self.UPSERT_VALUE_COLS_KEY, "array<bigint>"), ], ) self.schema = self.schema.extend(patch_schema) self.schema.build_snapshot() def _init_or_reload(self, reload=False): params = self.get_common_params() headers = self.get_common_headers(content_length=0, tags=self._tags) if not reload: params["slotnum"] = str(self._slot_num) else: params["upsertid"] = self.id url = self._get_resource() if not reload: if self._lifecycle: params["lifecycle"] = self._lifecycle resp = self._client.post(url, {}, params=params, headers=headers) else: resp = self._client.get(url, params=params, 
            headers=headers)

        if self._client.is_ok(resp):
            self.parse(resp, obj=self)
            self._patch_schema()
        else:
            e = TunnelError.parse(resp)
            raise e

    def _init(self):
        self._init_or_reload()

    def new_record(self, values=None):
        if values:
            values = list(values) + [None] * 5
        return super(TableUpsertSession, self).new_record(values)

    def reload(self, init=False):
        self._init_or_reload(reload=True)

    def abort(self):
        """
        Abort the current session.
        """
        params = self.get_common_params(upsertid=self.id)
        headers = self.get_common_headers(content_length=0, tags=self._tags)
        headers["odps-tunnel-routed-server"] = self.slots.buckets[0].server

        url = self._get_resource()
        resp = self._client.delete(url, params=params, headers=headers)
        self.check_tunnel_response(resp)

    def open_upsert_stream(self, compress=False):
        """
        Open an upsert stream to insert or update data in records to the tunnel.

        :param bool compress: whether to compress data
        :return: an upsert stream
        :rtype: :class:`Upsert`
        """
        params = self.get_common_params(upsertid=self.id)
        headers = self.get_common_headers(tags=self._tags)

        compress_option = self._compress_option or CompressOption()
        if not compress:
            compress_option = None
        else:
            encoding = compress_option.algorithm.get_encoding()
            if encoding:
                headers["Content-Encoding"] = encoding

        url = self._get_resource()

        @_wrap_upload_call(self.id)
        def upload_block(bucket, slot, record_count, data):
            req_params = params.copy()
            req_params.update(
                dict(
                    bucketid=bucket,
                    slotid=str(slot.slot),
                    record_count=str(record_count),
                )
            )
            req_headers = headers.copy()
            req_headers["odps-tunnel-routed-server"] = slot.server
            req_headers["Content-Length"] = len(data)
            return self._client.put(
                url, data=data, params=req_params, headers=req_headers
            )

        return Upsert(self.schema, upload_block, self, compress_option)

    def commit(self, async_=False):
        """
        Commit the current session. Can be called only once on a single session.
        """
        params = self.get_common_params(upsertid=self.id)
        headers = self.get_common_headers(content_length=0, tags=self._tags)
        headers["odps-tunnel-routed-server"] = self.slots.buckets[0].server

        url = self._get_resource()
        resp = self._client.post(url, params=params, headers=headers)
        self.check_tunnel_response(resp)
        self.reload()
        if async_:
            return

        delay = 1
        start = monotonic()
        while self.status in (
            TableUpsertSession.Status.Committing,
            TableUpsertSession.Status.Normal,
        ):
            try:
                if monotonic() - start > self._commit_timeout:
                    raise TunnelError("Commit session timeout")
                time.sleep(delay)
                resp = self._client.post(url, params=params, headers=headers)
                self.check_tunnel_response(resp)
                self.reload()
                delay = min(8, delay * 2)
            except (errors.StreamSessionNotFound, errors.UpsertSessionNotFound):
                self.status = TableUpsertSession.Status.Committed
        if self.status != TableUpsertSession.Status.Committed:
            raise TunnelError("commit session failed, status: " + self.status.value)
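# Illustrative usage sketch (not part of the original module): a typical upsert
# flow driven end to end through TableTunnel and the TableUpsertSession defined
# above. ``o`` (an existing ODPS entry object) and the table name are
# hypothetical placeholders; column values are examples only.
#
# >>> tunnel = TableTunnel(o)
# >>> upsert_session = tunnel.create_upsert_session('upsert_table')
# >>> stream = upsert_session.open_upsert_stream()
# >>> record = upsert_session.new_record(['key1', 'value1'])
# >>> stream.upsert(record)
# >>> stream.close()
# >>> upsert_session.commit()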
class TableTunnel(BaseTunnel):
    """
    Table tunnel API Entry.

    :param odps: ODPS Entry object
    :param str project: project name
    :param str endpoint: tunnel endpoint
    :param str quota_name: name of tunnel quota
    """

    def _get_tunnel_table(self, table, schema=None):
        project_odps = None
        try:
            project_odps = self._project.odps
            if isinstance(table, six.string_types):
                table = project_odps.get_table(table, project=self._project.name)
        except:
            pass

        project_name = self._project.name
        if not isinstance(table, six.string_types):
            project_name = table.project.name or project_name
            schema = schema or getattr(table.get_schema(), "name", None)
            table = table.name

        parent = Projects(client=self.tunnel_rest)[project_name]
        # tailor project for resource locating only
        parent._set_tunnel_defaults(odps_entry=project_odps)
        if schema is not None:
            parent = parent.schemas[schema]
        return parent.tables[table]

    @staticmethod
    def _build_compress_option(compress_algo=None, level=None, strategy=None):
        if compress_algo is None:
            return None
        return CompressOption(
            compress_algo=compress_algo, level=level, strategy=strategy
        )

    def create_download_session(
        self,
        table,
        async_mode=True,
        partition_spec=None,
        download_id=None,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        schema=None,
        timeout=None,
        tags=None,
        **kw
    ):
        """
        Create a download session for table.

        :param table: table object to read
        :type table: str | :class:`odps.models.Table`
        :param bool async_mode: whether to create the session asynchronously
        :param partition_spec: partition spec to read
        :type partition_spec: str | :class:`odps.types.PartitionSpec`
        :param str download_id: existing download id
        :param compress_option: compress option
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm
        :param int compress_level: compress level
        :param str schema: name of schema of the table
        :param int timeout: timeout in seconds for creating the session
        :param tags: tags of the download session
        :type tags: str | list
        :return: :class:`TableDownloadSession`
        """
        table = self._get_tunnel_table(table, schema)
        compress_option = compress_option or self._build_compress_option(
            compress_algo=compress_algo,
            level=compress_level,
            strategy=compress_strategy,
        )
        if "async_" in kw:
            async_mode = kw.pop("async_")
        if kw:
            raise TypeError("Cannot accept arguments %s" % ", ".join(kw.keys()))
        return TableDownloadSession(
            self.tunnel_rest,
            table,
            partition_spec,
            download_id=download_id,
            compress_option=compress_option,
            async_mode=async_mode,
            timeout=timeout,
            quota_name=self._quota_name,
            tags=tags,
        )
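    # Example (illustrative sketch; ``o`` denotes an existing ODPS entry object
    # and the table name is a placeholder):
    #
    # >>> tunnel = TableTunnel(o)
    # >>> down_session = tunnel.create_download_session('test_table')
    # >>> with down_session.open_record_reader(0, down_session.count) as reader:
    # >>>     for record in reader:
    # >>>         print(record.values)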
    def create_upload_session(
        self,
        table,
        partition_spec=None,
        upload_id=None,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        schema=None,
        overwrite=False,
        create_partition=False,
        tags=None,
    ):
        """
        Create an upload session for table.

        :param table: table object to read
        :type table: str | :class:`odps.models.Table`
        :param partition_spec: partition spec
        :type partition_spec: str | :class:`odps.types.PartitionSpec`
        :param str upload_id: existing upload id
        :param compress_option: compress option
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm
        :param int compress_level: compress level
        :param str schema: name of schema of the table
        :param bool overwrite: whether to overwrite the table
        :param bool create_partition: whether to create the partition if it does not exist
        :param tags: tags of the upload session
        :type tags: str | list
        :return: :class:`TableUploadSession`
        """
        table = self._get_tunnel_table(table, schema)
        compress_option = compress_option or self._build_compress_option(
            compress_algo=compress_algo,
            level=compress_level,
            strategy=compress_strategy,
        )
        return TableUploadSession(
            self.tunnel_rest,
            table,
            partition_spec,
            upload_id=upload_id,
            compress_option=compress_option,
            overwrite=overwrite,
            quota_name=self._quota_name,
            create_partition=create_partition,
            tags=tags,
        )

    def create_stream_upload_session(
        self,
        table,
        partition_spec=None,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        schema=None,
        schema_version=None,
        zorder_columns=None,
        upload_id=None,
        tags=None,
        allow_schema_mismatch=True,
        create_partition=False,
    ):
        """
        Create a stream upload session for table.

        :param table: table object to read
        :type table: str | :class:`odps.models.Table`
        :param partition_spec: partition spec
        :type partition_spec: str | :class:`odps.types.PartitionSpec`
        :param str upload_id: existing upload id
        :param compress_option: compress option
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm
        :param int compress_level: compress level
        :param str schema: name of schema of the table
        :param str schema_version: schema version of the upload
        :param tags: tags of the upload session
        :type tags: str | list
        :param bool allow_schema_mismatch: whether to allow table schema to be mismatched
        :param bool create_partition: whether to create the partition if it does not exist
        :return: :class:`TableStreamUploadSession`
        """
        table = self._get_tunnel_table(table, schema)
        compress_option = compress_option or self._build_compress_option(
            compress_algo=compress_algo,
            level=compress_level,
            strategy=compress_strategy,
        )

        version_need_reloaded = [False]

        def schema_version_reloader():
            src_table = self._project.tables[table.name]
            if version_need_reloaded[0]:
                src_table.reload_extend_info()
            version_need_reloaded[0] = True
            return src_table.schema_version

        return TableStreamUploadSession(
            self.tunnel_rest,
            table,
            partition_spec,
            compress_option=compress_option,
            quota_name=self._quota_name,
            schema_version=schema_version,
            upload_id=upload_id,
            tags=tags,
            allow_schema_mismatch=allow_schema_mismatch,
            schema_version_reloader=schema_version_reloader,
            create_partition=create_partition,
            zorder_columns=zorder_columns,
        )
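    # Example (illustrative sketch; ``o`` denotes an existing ODPS entry object,
    # the table name and column values are placeholders). Block 0 is written and
    # then committed:
    #
    # >>> tunnel = TableTunnel(o)
    # >>> up_session = tunnel.create_upload_session('test_table')
    # >>> with up_session.open_record_writer(0) as writer:
    # >>>     record = up_session.new_record(['name_value', 0])
    # >>>     writer.write(record)
    # >>> up_session.commit([0])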
    def create_upsert_session(
        self,
        table,
        partition_spec=None,
        slot_num=1,
        commit_timeout=120,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        schema=None,
        upsert_id=None,
        tags=None,
    ):
        """
        Create an upsert session for table.

        :param table: table object to read
        :type table: str | :class:`odps.models.Table`
        :param partition_spec: partition spec
        :type partition_spec: str | :class:`odps.types.PartitionSpec`
        :param str upsert_id: existing upsert id
        :param int slot_num: number of slots to use
        :param int commit_timeout: timeout in seconds for committing the session
        :param compress_option: compress option
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm
        :param int compress_level: compress level
        :param str schema: name of schema of the table
        :param tags: tags of the upsert session
        :type tags: str | list
        :return: :class:`TableUpsertSession`
        """
        table = self._get_tunnel_table(table, schema)
        compress_option = compress_option or self._build_compress_option(
            compress_algo=compress_algo,
            level=compress_level,
            strategy=compress_strategy,
        )
        return TableUpsertSession(
            self.tunnel_rest,
            table,
            partition_spec,
            slot_num=slot_num,
            upsert_id=upsert_id,
            commit_timeout=commit_timeout,
            compress_option=compress_option,
            quota_name=self._quota_name,
            tags=tags,
        )

    def open_preview_reader(
        self,
        table,
        partition_spec=None,
        columns=None,
        limit=None,
        compress_option=None,
        compress_algo=None,
        compress_level=None,
        compress_strategy=None,
        arrow=True,
        timeout=None,
        make_compat=True,
        read_all=False,
        tags=None,
    ):
        """
        Open a preview reader for table to read initial rows.

        :param table: table object to read
        :type table: str | :class:`odps.models.Table`
        :param partition_spec: partition spec to read
        :type partition_spec: str | :class:`odps.types.PartitionSpec`
        :param columns: columns to read
        :param int limit: number of rows to read, 10000 by default
        :param compress_option: compress option
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str compress_algo: compress algorithm
        :param int compress_level: compress level
        :param bool arrow: if True, return an Arrow reader, otherwise return a record reader
        :param tags: tags of the preview request
        :type tags: str | list
        """
        if pa is None:
            raise ImportError("Need pyarrow to run open_preview_reader.")

        tunnel_table = self._get_tunnel_table(table)
        compress_option = compress_option or self._build_compress_option(
            compress_algo=compress_algo,
            level=compress_level,
            strategy=compress_strategy,
        )

        params = {"limit": str(limit) if limit else "-1"}
        partition_spec = BaseTableTunnelSession.normalize_partition_spec(partition_spec)
        if columns:
            col_set = set(columns)
            ordered_col = [c.name for c in table.table_schema if c.name in col_set]
            params["columns"] = ",".join(ordered_col)
        if partition_spec is not None and len(partition_spec) > 0:
            params["partition"] = partition_spec

        headers = BaseTableTunnelSession.get_common_headers(content_length=0, tags=tags)
        if compress_option:
            encoding = compress_option.algorithm.get_encoding(legacy=False)
            if encoding:
                headers["Accept-Encoding"] = encoding

        url = tunnel_table.table_resource(force_schema=True) + "/preview"
        resp = self.tunnel_rest.get(
            url, stream=True, params=params, headers=headers, timeout=timeout
        )
        if not self.tunnel_rest.is_ok(resp):  # pragma: no cover
            e = TunnelError.parse(resp)
            raise e

        input_stream = get_decompress_stream(resp)
        if input_stream.peek() is None:
            # stream is empty, replace with empty stream
            input_stream = None

        def stream_creator(pos):
            # part retry not supported currently
            assert pos == 0
            return input_stream

        reader = TunnelArrowReader(
            table.table_schema, stream_creator, columns=columns, use_ipc_stream=True
        )
        if not arrow:
            reader = ArrowRecordReader(
                reader, make_compat=make_compat, read_all=read_all
            )
        return reader
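# Illustrative usage sketch (not part of the original module): previewing the first
# rows of a table through ``open_preview_reader``. Requires pyarrow. ``o`` and the
# table name are hypothetical placeholders, and ``to_pandas`` is assumed to be
# available on the returned Arrow reader.
#
# >>> tunnel = TableTunnel(o)
# >>> table = o.get_table('test_table')
# >>> reader = tunnel.open_preview_reader(table, limit=100)
# >>> df = reader.to_pandas()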