cupid/io/table/core.py (398 lines of code) (raw):

# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Table download / upload support for Cupid sessions.

Defines wrappers around table splits, download sessions, block writers and
upload sessions, and attaches ``create_download_session``,
``create_upload_session`` and ``query_table_meta`` to ``CupidSession``.
"""

import itertools
import json
import logging
import time
import warnings
from types import GeneratorType

from odps import options
from odps.models import Table, TableSchema
from odps.models.partition import Partition as TablePartition

from cupid.rpc import CupidRpcController, CupidTaskServiceRpcChannel, SandboxRpcChannel
from cupid.errors import CupidError
from cupid.session import CupidSession

try:
    from cupid.proto import cupid_task_service_pb2 as task_service_pb
    from cupid.proto import cupid_subprocess_service_pb2 as subprocess_pb
except TypeError:
    warnings.warn('Cannot import protos from pycupid: '
                  'consider upgrading your protobuf python package.', ImportWarning)
    raise ImportError

logger = logging.getLogger(__name__)

ATTEMPT_FILE_PREFIX = 'attempt_'

# Seconds to wait between attempts of a retried RPC.
_RETRY_INTERVAL = 0.1


def _assign_slots(obj, kwargs):
    """Populate every ``__slots__`` attribute of ``obj`` from ``kwargs``.

    Slot ``_x`` is read from ``kwargs['_x']`` or, failing that,
    ``kwargs['x']``. Slots that are still unset afterwards are set to
    ``None``. Keys that match no slot are ignored.
    """
    for slot in obj.__slots__:
        public_name = slot.lstrip('_')
        if slot in kwargs:
            setattr(obj, slot, kwargs[slot])
        elif public_name in kwargs:
            setattr(obj, slot, kwargs[public_name])
        elif not hasattr(obj, slot):
            setattr(obj, slot, None)


def _call_with_retries(rpc_method, request):
    """Call a stub method up to ``options.retry_times`` times (at least once).

    ``rpc_method`` is invoked as ``rpc_method(controller, request, None)``
    with a fresh ``CupidRpcController`` per attempt.

    :return: the response of the first successful attempt.
    :raises CupidError: with the last error text if every attempt failed.
    """
    attempts = max(options.retry_times or 0, 1)
    error_text = None
    for attempt in range(attempts):
        if attempt:
            time.sleep(_RETRY_INTERVAL)
        controller = CupidRpcController()
        resp = rpc_method(controller, request, None)
        if not controller.Failed():
            return resp
        # Keep the failure; the controller is replaced on the next attempt,
        # so its state must not be consulted after the loop.
        error_text = controller.ErrorText()
    raise CupidError(error_text)


class TableSplit(object):
    """A single input split of a table download session.

    Holds the split coordinates from an ``InputSplit`` proto, optional
    row-count / raw-size metadata and the input table handle. Can open
    record or Arrow readers over the split.
    """

    __slots__ = ('_handle', '_split_index', '_split_file_start', '_split_file_end',
                 '_schema_file_start', '_schema_file_end', '_meta_row_count',
                 '_meta_raw_size')

    def __init__(self, **kwargs):
        """Build a split from protos and/or keyword arguments.

        :param split_proto: optional ``InputSplit`` proto providing the split
            index and the split / schema file offsets.
        :param meta_proto: optional split-meta proto providing ``rowCount``
            and ``rawSize``.

        Any slot may also be passed by name, with or without the leading
        underscore. Explicit keywords override values taken from the protos.
        Unknown keywords are ignored.
        """
        if 'split_proto' in kwargs:
            split_pb = kwargs.pop('split_proto')
            self._split_index = split_pb.splitIndexId
            self._split_file_start = split_pb.splitFileStart
            self._split_file_end = split_pb.splitFileEnd
            self._schema_file_start = split_pb.schemaFileStart
            self._schema_file_end = split_pb.schemaFileEnd
        if kwargs.get('meta_proto'):
            meta_pb = kwargs.pop('meta_proto')
            try:
                self._meta_row_count = meta_pb.rowCount
                self._meta_raw_size = meta_pb.rawSize
            except AttributeError:
                # Tolerate meta protos lacking these fields; missing values
                # fall back to None below.
                pass
        _assign_slots(self, kwargs)

    @property
    def handle(self):
        return self._handle

    @property
    def split_index(self):
        return self._split_index

    @property
    def split_file_start(self):
        return self._split_file_start

    @property
    def split_file_end(self):
        return self._split_file_end

    @property
    def schema_file_start(self):
        return self._schema_file_start

    @property
    def schema_file_end(self):
        return self._schema_file_end

    @property
    def meta_row_count(self):
        return self._meta_row_count

    @property
    def meta_raw_size(self):
        return self._meta_raw_size

    @property
    def split_proto(self):
        """Rebuild the ``InputSplit`` proto describing this split."""
        return subprocess_pb.InputSplit(
            splitIndexId=self._split_index,
            splitFileStart=self._split_file_start,
            splitFileEnd=self._split_file_end,
            schemaFileStart=self._schema_file_start,
            schemaFileEnd=self._schema_file_end,
        )

    def _register_reader(self):
        """Register a table reader for this split with the sandbox service.

        :return: ``(read_iterator_label, TableSchema)``. The schema includes
            partition columns when the service reports a partition schema.
        :raises CupidError: if the RegisterTableReader RPC fails.
        """
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)

        req = subprocess_pb.RegisterTableReaderRequest(
            inputTableHandle=self._handle, inputSplit=self.split_proto)
        controller = CupidRpcController()
        resp = stub.RegisterTableReader(controller, req, None)
        if controller.Failed():
            raise CupidError(controller.ErrorText())

        logger.info("RegisterTableReader response: %s", resp)
        logger.info("RegisterTableReaderResponse protobuf field size = %d",
                    len(resp.ListFields()))

        schema_json = json.loads(resp.schema)
        # Non-partitioned tables have no partition schema; use an empty list
        # instead of None so the comprehensions below do not fail.
        if resp.HasField('partitionSchema'):
            partition_schema_json = json.loads(resp.partitionSchema) or []
        else:
            partition_schema_json = []

        schema_names = [d['name'] for d in schema_json]
        schema_types = [d['type'] for d in schema_json]
        pt_schema_names = [d['name'] for d in partition_schema_json]
        pt_schema_types = [d['type'] for d in partition_schema_json]
        schema = TableSchema.from_lists(schema_names, schema_types,
                                        pt_schema_names, pt_schema_types)
        return resp.readIterator, schema

    def open_record_reader(self):
        """Open a record reader over this split via the runtime channel client."""
        from ...runtime import context

        context = context()
        read_iter, schema = self._register_reader()
        logger.debug('Obtained schema: %s', schema)
        return context.channel_client.create_record_reader(read_iter, schema)

    def open_arrow_file_reader(self):
        """Open the raw table input stream for this split.

        The stream is requested with ``arrow=True, batch=True``. It is returned
        as produced by ``channel_client.create_file_reader``.
        """
        from ...runtime import context

        context = context()
        read_iter, _ = self._register_reader()
        params = dict(type='ReadByLabel', label=read_iter, arrow=True, batch=True)
        return context.channel_client.create_file_reader(
            'createTableInputStream', json.dumps(params).encode())

    def open_arrow_reader(self):
        """Open the split's input stream wrapped in a ``pa.RecordBatchStreamReader``."""
        from ...runtime import context
        import pyarrow as pa

        context = context()
        read_iter, _ = self._register_reader()
        params = dict(type='ReadByLabel', label=read_iter, arrow=True, batch=True)
        stream = context.channel_client.create_file_reader(
            'createTableInputStream', json.dumps(params).encode())
        return pa.RecordBatchStreamReader(stream)


class CupidTableDownloadSession(object):
    """Result of :func:`create_download_session`: a handle plus its splits."""

    __slots__ = '_session', '_handle', '_splits',

    def __init__(self, **kwargs):
        _assign_slots(self, kwargs)

    @property
    def splits(self):
        return self._splits

    def open_record_reader(self, split_id=0):
        """Open a record reader over ``self.splits[split_id]``."""
        return self._splits[split_id].open_record_reader()


class BlockWriter(object):
    """Writes one block (attempt file) of an output table and commits it."""

    __slots__ = ('_table_name', '_project_name', '_table_schema', '_partition_spec',
                 '_block_id', '_handle')

    def __init__(self, **kwargs):
        _assign_slots(self, kwargs)

    @property
    def table_name(self):
        return self._table_name

    @property
    def project_name(self):
        return self._project_name

    @property
    def block_id(self):
        return self._block_id

    @property
    def handle(self):
        return self._handle

    def new_record(self, values=None):
        """Create an ``odps.models.Record`` bound to this writer's table schema."""
        from odps.models import Record

        return Record(schema=self._table_schema, values=values)

    def _register_writer(self, partition=None):
        """Register a table writer for this block with the sandbox service.

        :param partition: ``TablePartition`` or partition-spec string. Single
            quotes are stripped before sending; ``None`` means no partition.
        :return: the subprocess write-table label.
        :raises CupidError: if the RegisterTableWriter RPC fails.
        """
        if isinstance(partition, TablePartition):
            partition = str(partition.spec)

        controller = CupidRpcController()
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)

        # Column types joined as "|type1|type2|...".
        schema_str = '|' + '|'.join(
            str(col.type) for col in self._table_schema.simple_columns)
        req = subprocess_pb.RegisterTableWriterRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            attemptFileName=ATTEMPT_FILE_PREFIX + self._block_id,
            partSpec=partition.replace("'", '') if partition else None,
            schema=schema_str,
        )
        resp = stub.RegisterTableWriter(controller, req, None)
        if controller.Failed():
            raise CupidError(controller.ErrorText())
        return resp.subprocessWriteTableLabel

    def _open_writer(self, partition=None, create_method=None):
        """Register the block and create a writer via ``channel_client.<create_method>``."""
        from ...runtime import context

        context = context()
        write_label = self._register_writer(partition)
        writer = getattr(context.channel_client, create_method)(
            write_label, self._table_schema)
        writer._block_id = self._block_id
        writer._partition_spec = partition
        return writer

    def open_arrow_writer(self, partition=None):
        """Open an Arrow writer; ``partition`` defaults to this block's spec."""
        from ...runtime import context

        context = context()
        write_label = self._register_writer(partition or self._partition_spec)
        return context.channel_client.create_arrow_writer(write_label)

    def open_record_writer(self, partition=None):
        """Open a record writer; ``partition`` defaults to this block's spec."""
        return self._open_writer(partition=partition or self._partition_spec,
                                 create_method='create_record_writer')

    def commit(self):
        """Commit this block's attempt file, retrying on RPC failure.

        :raises CupidError: if every CommitTableFiles attempt fails.
        """
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)

        commit_actions = [subprocess_pb.CommitActionInfo(
            commitFileName=self._block_id,
            attemptFileName=ATTEMPT_FILE_PREFIX + self._block_id,
            partSpec=self._partition_spec,
        )]
        req = subprocess_pb.CommitTableFilesRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            commitActionInfos=commit_actions,
        )
        _call_with_retries(stub.CommitTableFiles, req)


class CupidTableUploadSession(object):
    """Result of :func:`create_upload_session`; commits the written blocks."""

    __slots__ = '_session', '_table_name', '_project_name', '_handle', '_blocks'

    def __init__(self, **kwargs):
        """Create an upload session.

        :param blocks: a dict ``{block_id: partition_spec}``, an iterable
            (list / set / generator) of block ids or ``(block_id, spec)``
            tuples, or a single such item.
        """
        self._blocks = dict()
        if 'blocks' in kwargs:
            blocks = kwargs.pop('blocks')
            if isinstance(blocks, dict):
                self._blocks.update(blocks)
            else:
                if not isinstance(blocks, (list, set, GeneratorType)):
                    blocks = [blocks]
                for block in blocks:
                    if isinstance(block, tuple):
                        self._blocks[block[0]] = block[1]
                    else:
                        self._blocks[block] = None
        _assign_slots(self, kwargs)

    @property
    def handle(self):
        return self._handle

    def commit(self, overwrite=False):
        """Commit the output table through the Cupid task service.

        The distinct non-None partition specs of the blocks are sent with
        single quotes stripped. A single empty spec is sent when there are
        none.

        :param overwrite: forwarded as ``isOverWrite``.
        :raises CupidError: if every CommitTable attempt fails.
        """
        partitions = list(set(p for p in self._blocks.values() if p is not None))
        if not partitions:
            partitions = ['']

        channel = CupidTaskServiceRpcChannel(self._session)
        stub = task_service_pb.CupidTaskService_Stub(channel)

        req = task_service_pb.CommitTableRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            isOverWrite=overwrite,
            lookupName=self._session.lookup_name,
            partSpecs=[pt.replace("'", '') for pt in partitions],
        )
        resp = _call_with_retries(stub.CommitTable, req)
        logger.info(
            "[CupidTask] commitTable call, CurrentInstanceId: %s, "
            "request: %s, response: %s",
            self._session.lookup_name, req, resp,
        )


def _table_input_info(table_or_part, columns):
    """Build a ``TableInputInfo`` proto for a ``Table`` or ``TablePartition``.

    When ``columns`` is empty, all columns of that object's own table are used.

    :raises NotImplementedError: for any other input type.
    """
    if isinstance(table_or_part, Table):
        table = table_or_part
        part_spec = None
    elif isinstance(table_or_part, TablePartition):
        table = table_or_part.table
        part_spec = str(table_or_part.partition_spec).replace("'", '').strip()
    else:
        raise NotImplementedError

    table_kw = dict(
        projectName=table.project.name,
        tableName=table.name,
        columns=','.join(columns or table.table_schema.names),
    )
    if part_spec is not None:
        table_kw['partSpec'] = part_spec
    return task_service_pb.TableInputInfo(**table_kw)


def _get_splits_meta(stub, handle, lookup_name):
    """Fetch per-split metadata for ``handle``.

    :return: a list of meta protos, or ``[]`` if the GetSplitsMeta call failed.
    """
    req = subprocess_pb.GetSplitsMetaRequest(inputTableHandle=handle)
    controller = CupidRpcController()
    resp = stub.GetSplitsMeta(controller, req, None)
    logger.info(
        "[CupidTask] getSplitsMeta call, CurrentInstanceId: %s, "
        "request: %s, response: %s",
        lookup_name, req, resp,
    )
    if controller.Failed():
        logger.warning('Failed to get results of getSplitsMeta, '
                       'may running on an old service')
        return []
    return list(resp.inputSplitsMeta)


def create_download_session(session, table_or_parts, split_size=None, split_count=None,
                            columns=None, with_split_meta=False):
    """Split one or more tables / partitions for reading.

    :param session: the ``CupidSession`` to issue RPCs through.
    :param table_or_parts: a ``Table`` / ``TablePartition`` or a collection
        (list, tuple, set, generator) of them.
    :param split_size: split size in bytes. It is converted to whole MiB by
        integer division and defaults to 1 MiB.
    :param split_count: number of splits. When neither this nor
        ``split_size`` is given, one split is requested.
    :param columns: columns to read. Defaults to all columns of each table.
    :param with_split_meta: also fetch per-split row count / raw size.
    :return: a :class:`CupidTableDownloadSession`.
    :raises CupidError: if SplitTables or GetSplits fails.
    """
    channel = CupidTaskServiceRpcChannel(session)
    stub = task_service_pb.CupidTaskService_Stub(channel)

    if not isinstance(table_or_parts, (list, tuple, set, GeneratorType)):
        table_or_parts = [table_or_parts]

    if split_size is None and split_count is None:
        split_count = 1
    split_count = split_count or 0
    split_size = (split_size or 1024 ** 2) // 1024 ** 2

    # Resolve default columns per table so one table's schema does not leak
    # into the next.
    table_pbs = [_table_input_info(t, columns) for t in table_or_parts]

    request = task_service_pb.SplitTablesRequest(
        lookupName=session.lookup_name,
        splitSize=split_size,
        splitCount=split_count,
        tableInputInfos=table_pbs,
        allowNoColumns=True,
        requireSplitMeta=with_split_meta,
    )
    controller = CupidRpcController()
    resp = stub.SplitTables(controller, request, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())
    logger.info(
        "[CupidTask] splitTables call, CurrentInstanceId: %s, "
        "request: %s, response: %s",
        session.lookup_name, request, resp,
    )
    handle = resp.inputTableHandle

    channel = SandboxRpcChannel()
    stub = subprocess_pb.CupidSubProcessService_Stub(channel)

    metas = _get_splits_meta(stub, handle, session.lookup_name) if with_split_meta else []
    # Pad with None so a short (or missing) meta list never drops splits.
    split_meta = itertools.chain(metas, itertools.repeat(None))

    req = subprocess_pb.GetSplitsRequest(inputTableHandle=handle)
    controller = CupidRpcController()
    resp = stub.GetSplits(controller, req, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())

    input_splits = [
        TableSplit(split_proto=info, meta_proto=meta, handle=handle, columns=columns)
        for info, meta in zip(resp.inputSplits, split_meta)
    ]
    logger.info(
        "[SubProcess] getSplits call, CurrentInstanceId: %s, "
        "request: %s, response: %s",
        session.lookup_name, req, resp,
    )
    return CupidTableDownloadSession(session=session, handle=handle, splits=input_splits)


def create_upload_session(session, table):
    """Open an output table handle for ``table`` via the WriteTable RPC.

    :return: a :class:`CupidTableUploadSession` with no blocks.
    :raises CupidError: if the WriteTable RPC fails.
    """
    controller = CupidRpcController()
    channel = CupidTaskServiceRpcChannel(session)
    stub = task_service_pb.CupidTaskService_Stub(channel)

    req = task_service_pb.WriteTableRequest(
        lookupName=session.lookup_name,
        tableName=table.name,
        projectName=table.project.name,
    )
    resp = stub.WriteTable(controller, req, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())
    logger.info(
        "[CupidTask] writeTable call, CurrentInstanceId: %s, "
        "request: %s, response: %s",
        session.lookup_name, req, resp,
    )
    return CupidTableUploadSession(
        session=session, table_name=table.name, project_name=table.project.name,
        handle=resp.outputTableHandle)


def query_table_meta(session, table):
    """Fetch ``table``'s metadata via GetTableMeta and return it JSON-decoded.

    :raises CupidError: if the GetTableMeta RPC fails.
    """
    controller = CupidRpcController()
    channel = CupidTaskServiceRpcChannel(session)
    stub = task_service_pb.CupidTaskService_Stub(channel)

    table_info = task_service_pb.TableInfo(projectName=table.project.name,
                                           tableName=table.name)
    req = task_service_pb.GetTableMetaRequest(lookupName=session.lookup_name,
                                              tableInfo=table_info,
                                              needContent=True,
                                              uploadFile='')
    resp = stub.GetTableMeta(controller, req, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())
    logger.info(
        "[CupidTask] getTableMeta call, CurrentInstanceId: %s, "
        "request: %s, response: %s",
        session.lookup_name, req, resp,
    )
    return json.loads(resp.getTableMetaContent)


CupidSession.create_download_session = create_download_session
CupidSession.create_upload_session = create_upload_session
CupidSession.query_table_meta = query_table_meta