# cupid/io/table/core.py (398 lines of code) (raw):
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import json
import logging
import time
import warnings
from types import GeneratorType
from odps import options
from odps.models import Table, TableSchema
from odps.models.partition import Partition as TablePartition
from cupid.rpc import CupidRpcController, CupidTaskServiceRpcChannel, SandboxRpcChannel
from cupid.errors import CupidError
from cupid.session import CupidSession
try:
from cupid.proto import cupid_task_service_pb2 as task_service_pb
from cupid.proto import cupid_subprocess_service_pb2 as subprocess_pb
except TypeError:
warnings.warn('Cannot import protos from pycupid: '
'consider upgrading your protobuf python package.', ImportWarning)
raise ImportError
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prefix prepended to a block id to build the ``attemptFileName`` sent when
# registering a table writer and when committing table files.
ATTEMPT_FILE_PREFIX = 'attempt_'
class TableSplit(object):
    """One input split of a table download session.

    Instances can be built from protobuf messages (``split_proto`` and,
    optionally, ``meta_proto``) and/or from keyword arguments named after the
    slots, with or without the leading underscore. Slots that are not
    supplied are initialised to ``None``; unknown keyword arguments are
    ignored.
    """
    __slots__ = '_handle', '_split_index', '_split_file_start', '_split_file_end', \
        '_schema_file_start', '_schema_file_end', '_meta_row_count', \
        '_meta_raw_size'

    def __init__(self, **kwargs):
        if 'split_proto' in kwargs:
            split_pb = kwargs.pop('split_proto')
            self._split_index = split_pb.splitIndexId
            self._split_file_start = split_pb.splitFileStart
            self._split_file_end = split_pb.splitFileEnd
            self._schema_file_start = split_pb.schemaFileStart
            self._schema_file_end = split_pb.schemaFileEnd
        if kwargs.get('meta_proto'):
            meta_pb = kwargs.pop('meta_proto')
            try:
                self._meta_row_count = meta_pb.rowCount
                self._meta_raw_size = meta_pb.rawSize
            except AttributeError:
                # Deliberate best effort: a meta message lacking these fields
                # simply leaves the values unset (they become None below).
                pass
        for k in self.__slots__:
            if k in kwargs:
                setattr(self, k, kwargs[k])
            elif k.lstrip('_') in kwargs:
                setattr(self, k, kwargs[k.lstrip('_')])
            elif not hasattr(self, k):
                setattr(self, k, None)

    @property
    def handle(self):
        return self._handle

    @property
    def split_index(self):
        return self._split_index

    @property
    def split_file_start(self):
        return self._split_file_start

    @property
    def split_file_end(self):
        return self._split_file_end

    @property
    def schema_file_start(self):
        return self._schema_file_start

    @property
    def schema_file_end(self):
        return self._schema_file_end

    @property
    def meta_row_count(self):
        return self._meta_row_count

    @property
    def meta_raw_size(self):
        return self._meta_raw_size

    @property
    def split_proto(self):
        """Rebuild the ``InputSplit`` protobuf message from the stored fields."""
        return subprocess_pb.InputSplit(
            splitIndexId=self._split_index,
            splitFileStart=self._split_file_start,
            splitFileEnd=self._split_file_end,
            schemaFileStart=self._schema_file_start,
            schemaFileEnd=self._schema_file_end,
        )

    @staticmethod
    def _column_lists(columns_json):
        """Split decoded column descriptors into ``(names, types)`` lists.

        :param columns_json: iterable of dicts with ``name`` and ``type``
            keys, or ``None`` when no such schema is available.
        :return: two parallel lists, or ``(None, None)`` for ``None`` input.
        """
        if columns_json is None:
            return None, None
        names = [col['name'] for col in columns_json]
        types = [col['type'] for col in columns_json]
        return names, types

    def _register_reader(self):
        """Register a table reader for this split via the sandbox channel.

        :return: tuple ``(read_iterator_label, schema)``.
        :raises CupidError: when the RPC controller reports a failure.
        """
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)
        req = subprocess_pb.RegisterTableReaderRequest(inputTableHandle=self._handle,
                                                       inputSplit=self.split_proto)
        controller = CupidRpcController()
        resp = stub.RegisterTableReader(controller, req, None)
        if controller.Failed():
            raise CupidError(controller.ErrorText())
        logger.info("RegisterTableReader response: %s", resp)
        logger.info("RegisterTableReaderResponse protobuf field size = %d", len(resp.ListFields()))

        schema_names, schema_types = self._column_lists(json.loads(resp.schema))
        # The partition schema is optional in the response. Without it the
        # partition lists must stay None instead of being iterated, which
        # previously raised TypeError.
        partition_schema_json = json.loads(resp.partitionSchema) \
            if resp.HasField('partitionSchema') else None
        pt_schema_names, pt_schema_types = self._column_lists(partition_schema_json)

        schema = TableSchema.from_lists(schema_names, schema_types, pt_schema_names, pt_schema_types)
        return resp.readIterator, schema

    def open_record_reader(self):
        """Register a reader and open a record reader on the channel client."""
        from ...runtime import context

        ctx = context()
        read_iter, schema = self._register_reader()
        logger.debug('Obtained schema: %s', schema)
        return ctx.channel_client.create_record_reader(read_iter, schema)

    def open_arrow_file_reader(self):
        """Register a reader and return the raw Arrow input stream object."""
        from ...runtime import context
        # NOTE(review): pyarrow is not used in this method; the import is kept
        # so the method still fails on a missing pyarrow as before - confirm
        # whether that is intended.
        import pyarrow as pa  # noqa: F401

        ctx = context()
        read_iter, _ = self._register_reader()
        params = dict(type='ReadByLabel', label=read_iter, arrow=True, batch=True)
        return ctx.channel_client.create_file_reader('createTableInputStream', json.dumps(params).encode())

    def open_arrow_reader(self):
        """Register a reader and wrap its stream in a ``RecordBatchStreamReader``."""
        from ...runtime import context
        import pyarrow as pa

        ctx = context()
        read_iter, _ = self._register_reader()
        params = dict(type='ReadByLabel', label=read_iter, arrow=True, batch=True)
        stream = ctx.channel_client.create_file_reader('createTableInputStream', json.dumps(params).encode())
        return pa.RecordBatchStreamReader(stream)
class CupidTableDownloadSession(object):
    """Holder for a download session: the owning session, the input table
    handle and the list of splits that can be read.

    Keyword arguments may be given with or without the leading underscore of
    the slot name; slots that are not supplied are set to ``None``.
    """
    __slots__ = '_session', '_handle', '_splits',

    def __init__(self, **kwargs):
        for slot in self.__slots__:
            public_name = slot.lstrip('_')
            if slot in kwargs:
                value = kwargs[slot]
            elif public_name in kwargs:
                value = kwargs[public_name]
            elif hasattr(self, slot):
                # Already populated; leave it untouched.
                continue
            else:
                value = None
            setattr(self, slot, value)

    @property
    def splits(self):
        """The splits stored in this session."""
        return self._splits

    def open_record_reader(self, split_id=0):
        """Open a record reader on the split at index ``split_id``."""
        chosen_split = self._splits[split_id]
        return chosen_split.open_record_reader()
class BlockWriter(object):
    """Writer for a single block of an output table.

    Keyword arguments may be given with or without the leading underscore of
    the slot name; slots that are not supplied are set to ``None``.
    """
    __slots__ = '_table_name', '_project_name', '_table_schema', '_partition_spec', '_block_id', '_handle'

    def __init__(self, **kwargs):
        for k in self.__slots__:
            if k in kwargs:
                setattr(self, k, kwargs[k])
            elif k.lstrip('_') in kwargs:
                setattr(self, k, kwargs[k.lstrip('_')])
            elif not hasattr(self, k):
                setattr(self, k, None)

    @property
    def table_name(self):
        return self._table_name

    @property
    def project_name(self):
        return self._project_name

    @property
    def block_id(self):
        return self._block_id

    @property
    def handle(self):
        return self._handle

    def new_record(self, values=None):
        """Create a ``Record`` bound to this writer's table schema."""
        from odps.models import Record
        return Record(schema=self._table_schema, values=values)

    def _register_writer(self, partition=None):
        """Register a table writer for this block via the sandbox channel.

        :param partition: partition spec string or ``TablePartition``; single
            quotes are stripped before it is sent.
        :return: the write label from the response.
        :raises CupidError: when the RPC controller reports a failure.
        """
        if isinstance(partition, TablePartition):
            partition = str(partition.spec)

        controller = CupidRpcController()
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)

        table_schema = self._table_schema
        # Schema string is a '|'-prefixed, '|'-separated list of column types.
        schema_str = '|' + '|'.join(str(col.type) for col in table_schema.simple_columns)
        req = subprocess_pb.RegisterTableWriterRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            attemptFileName=ATTEMPT_FILE_PREFIX + self._block_id,
            partSpec=partition.replace("'", '') if partition else None,
            schema=schema_str,
        )
        resp = stub.RegisterTableWriter(controller, req, None)
        # Surface RPC failures like every other call in this module instead of
        # returning a label taken from a failed response.
        if controller.Failed():
            raise CupidError(controller.ErrorText())
        return resp.subprocessWriteTableLabel

    def _open_writer(self, partition=None, create_method=None):
        """Register a writer and create it through ``channel_client``.

        :param create_method: name of the factory method on the channel
            client, called with the write label and the table schema.
        """
        from ...runtime import context

        ctx = context()
        write_label = self._register_writer(partition)
        writer = getattr(ctx.channel_client, create_method)(write_label, self._table_schema)
        writer._block_id = self._block_id
        writer._partition_spec = partition
        return writer

    def open_arrow_writer(self, partition=None):
        """Open an Arrow writer; falls back to the writer's own partition spec."""
        from ...runtime import context

        ctx = context()
        write_label = self._register_writer(partition or self._partition_spec)
        return ctx.channel_client.create_arrow_writer(write_label)

    def open_record_writer(self, partition=None):
        """Open a record writer; falls back to the writer's own partition spec."""
        return self._open_writer(partition=partition or self._partition_spec, create_method='create_record_writer')

    def commit(self):
        """Commit this block's attempt file, retrying on RPC failure.

        Up to ``options.retry_times`` attempts are made (at least one), with
        a 0.1 second pause between attempts.

        :raises CupidError: when the last attempt still fails.
        """
        channel = SandboxRpcChannel()
        stub = subprocess_pb.CupidSubProcessService_Stub(channel)
        commit_actions = [subprocess_pb.CommitActionInfo(
            commitFileName=self._block_id,
            attemptFileName=ATTEMPT_FILE_PREFIX + self._block_id,
            partSpec=self._partition_spec,
        )]
        req = subprocess_pb.CommitTableFilesRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            commitActionInfos=commit_actions,
        )

        controller = CupidRpcController()
        # A fresh controller is created *before* each retry rather than after
        # each failure, so the controller checked after the loop is the one
        # used by the last attempt and its failure is not lost.
        for attempt in range(max(1, options.retry_times)):
            if attempt:
                time.sleep(0.1)
                controller = CupidRpcController()
            stub.CommitTableFiles(controller, req, None)
            if not controller.Failed():
                break
        if controller.Failed():
            raise CupidError(controller.ErrorText())
class CupidTableUploadSession(object):
    """Upload session for one output table.

    ``blocks`` maps block ids to partition specs (``None`` for no partition).
    It may be passed as a dict, as a list/set/generator of block ids or
    ``(block_id, partition)`` tuples, or as a single such item. Other keyword
    arguments may be given with or without the leading underscore of the slot
    name; slots that are not supplied are set to ``None``.
    """
    __slots__ = '_session', '_table_name', '_project_name', '_handle', '_blocks'

    def __init__(self, **kwargs):
        self._blocks = dict()
        if 'blocks' in kwargs:
            blocks = kwargs.pop('blocks')
            if isinstance(blocks, dict):
                self._blocks.update(blocks)
            else:
                # A bare tuple is treated as one (block_id, partition) item,
                # not as a collection of blocks.
                if not isinstance(blocks, (list, set, GeneratorType)):
                    blocks = [blocks]
                for bl in blocks:
                    if isinstance(bl, tuple):
                        self._blocks[bl[0]] = bl[1]
                    else:
                        self._blocks[bl] = None

        for k in self.__slots__:
            if k in kwargs:
                setattr(self, k, kwargs[k])
            elif k.lstrip('_') in kwargs:
                setattr(self, k, kwargs[k.lstrip('_')])
            elif not hasattr(self, k):
                setattr(self, k, None)

    @property
    def handle(self):
        return self._handle

    def commit(self, overwrite=False):
        """Commit the table for all distinct partitions of the known blocks.

        Up to ``options.retry_times`` attempts are made (at least one), with
        a 0.1 second pause between attempts.

        :param overwrite: sent as ``isOverWrite`` in the commit request.
        :raises CupidError: when the last attempt still fails.
        """
        partitions = list(set(p for p in self._blocks.values() if p is not None))
        if not partitions:
            # No partitioned block: commit with a single empty partition spec.
            partitions = ['']

        channel = CupidTaskServiceRpcChannel(self._session)
        stub = task_service_pb.CupidTaskService_Stub(channel)

        part_specs = [pt.replace("'", '') for pt in partitions]
        req = task_service_pb.CommitTableRequest(
            outputTableHandle=self._handle,
            projectName=self._project_name,
            tableName=self._table_name,
            isOverWrite=overwrite,
            lookupName=self._session.lookup_name,
            partSpecs=part_specs,
        )

        controller = CupidRpcController()
        resp = None
        # A fresh controller is created *before* each retry rather than after
        # each failure, so the controller checked after the loop is the one
        # used by the last attempt and its failure is not lost.
        for attempt in range(max(1, options.retry_times)):
            if attempt:
                time.sleep(0.1)
                controller = CupidRpcController()
            resp = stub.CommitTable(controller, req, None)
            if not controller.Failed():
                break
        if controller.Failed():
            raise CupidError(controller.ErrorText())
        logger.info(
            "[CupidTask] commitTable call, CurrentInstanceId: %s, "
            "request: %s, response: %s", self._session.lookup_name, req, resp,
        )
def create_download_session(session, table_or_parts, split_size=None, split_count=None,
                            columns=None, with_split_meta=False):
    """Split the given tables/partitions and build a download session.

    :param session: Cupid session; its ``lookup_name`` is sent with the request.
    :param table_or_parts: a ``Table`` / table partition, or a
        list/tuple/set/generator of them.
    :param split_size: requested split size; floor-divided by ``1024 ** 2``
        before being sent (presumably bytes converted to MiB - confirm).
        Defaults to ``1024 ** 2``.
    :param split_count: requested number of splits; defaults to 1 when
        neither ``split_size`` nor ``split_count`` is given, otherwise 0.
    :param columns: column names to read; when empty, each table's own
        schema names are used.
    :param with_split_meta: also request per-split metadata.
    :return: a ``CupidTableDownloadSession`` holding the created splits.
    :raises CupidError: when the SplitTables or GetSplits call fails.
    :raises NotImplementedError: for an input that is neither a table nor a
        table partition.
    """
    channel = CupidTaskServiceRpcChannel(session)
    stub = task_service_pb.CupidTaskService_Stub(channel)

    if not isinstance(table_or_parts, (list, tuple, set, GeneratorType)):
        table_or_parts = [table_or_parts]

    if split_size is None and split_count is None:
        split_count = 1
    split_count = split_count or 0
    split_size = (split_size or 1024 ** 2) // 1024 ** 2

    table_pbs = []
    for t in table_or_parts:
        # Resolve default columns per table. The caller's ``columns`` must not
        # be overwritten, otherwise the first table's schema names would be
        # reused for every following table.
        if isinstance(t, Table):
            table_columns = columns or t.table_schema.names
            table_kw = dict(
                projectName=t.project.name,
                tableName=t.name,
                columns=','.join(table_columns),
            )
        elif isinstance(t, TablePartition):
            table_columns = columns or t.table.table_schema.names
            table_kw = dict(
                projectName=t.table.project.name,
                tableName=t.table.name,
                columns=','.join(table_columns),
                partSpec=str(t.partition_spec).replace("'", '').strip(),
            )
        else:
            raise NotImplementedError
        table_pbs.append(task_service_pb.TableInputInfo(**table_kw))

    request = task_service_pb.SplitTablesRequest(
        lookupName=session.lookup_name,
        splitSize=split_size,
        splitCount=split_count,
        tableInputInfos=table_pbs,
        allowNoColumns=True,
        requireSplitMeta=with_split_meta,
    )
    controller = CupidRpcController()
    resp = stub.SplitTables(controller, request, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())
    logger.info(
        "[CupidTask] splitTables call, CurrentInstanceId: %s, "
        "request: %s, response: %s", session.lookup_name, request, resp,
    )
    handle = resp.inputTableHandle

    channel = SandboxRpcChannel()
    stub = subprocess_pb.CupidSubProcessService_Stub(channel)

    if not with_split_meta:
        split_meta = itertools.repeat(None)
    else:
        req = subprocess_pb.GetSplitsMetaRequest(
            inputTableHandle=handle,
        )
        controller = CupidRpcController()
        resp = stub.GetSplitsMeta(controller, req, None)
        # Log the GetSplitsMeta request itself (previously the SplitTables
        # request was logged here by mistake).
        logger.info(
            "[CupidTask] getSplitsMeta call, CurrentInstanceId: %s, "
            "request: %s, response: %s", session.lookup_name, req, resp,
        )
        if controller.Failed():
            # Best effort: carry on without metadata instead of failing.
            split_meta = itertools.repeat(None)
            logger.warning('Failed to get results of getSplitsMeta, '
                           'may running on an old service')
        else:
            split_meta = resp.inputSplitsMeta

    req = subprocess_pb.GetSplitsRequest(inputTableHandle=handle)
    controller = CupidRpcController()
    resp = stub.GetSplits(controller, req, None)
    if controller.Failed():
        raise CupidError(controller.ErrorText())

    input_splits = [
        TableSplit(split_proto=info, meta_proto=meta, handle=handle, columns=columns)
        for info, meta in zip(resp.inputSplits, split_meta)
    ]
    logger.info(
        "[SubProcess] getSplits call, CurrentInstanceId: %s, "
        "request: %s, response: %s", session.lookup_name, req, resp,
    )
    return CupidTableDownloadSession(session=session, handle=handle, splits=input_splits)
def create_upload_session(session, table):
    """Issue a WriteTable request for ``table`` and wrap the returned handle.

    :param session: Cupid session; its ``lookup_name`` is sent with the request.
    :param table: table object providing ``name`` and ``project.name``.
    :return: a ``CupidTableUploadSession`` bound to the output table handle.
    :raises CupidError: when the RPC controller reports a failure.
    """
    rpc_controller = CupidRpcController()
    service = task_service_pb.CupidTaskService_Stub(CupidTaskServiceRpcChannel(session))

    write_request = task_service_pb.WriteTableRequest(
        lookupName=session.lookup_name,
        tableName=table.name,
        projectName=table.project.name,
    )
    write_response = service.WriteTable(rpc_controller, write_request, None)
    if rpc_controller.Failed():
        raise CupidError(rpc_controller.ErrorText())

    logger.info(
        "[CupidTask] writeTable call, CurrentInstanceId: %s, "
        "request: %s, response: %s", session.lookup_name, write_request, write_response,
    )
    session_kwargs = dict(
        session=session,
        table_name=table.name,
        project_name=table.project.name,
        handle=write_response.outputTableHandle,
    )
    return CupidTableUploadSession(**session_kwargs)
def query_table_meta(session, table):
    """Fetch the metadata of ``table`` and return it decoded from JSON.

    :param session: Cupid session; its ``lookup_name`` is sent with the request.
    :param table: table object providing ``name`` and ``project.name``.
    :return: the result of ``json.loads`` on the response's
        ``getTableMetaContent`` field.
    :raises CupidError: when the RPC controller reports a failure.
    """
    rpc_controller = CupidRpcController()
    service = task_service_pb.CupidTaskService_Stub(CupidTaskServiceRpcChannel(session))

    meta_request = task_service_pb.GetTableMetaRequest(
        lookupName=session.lookup_name,
        tableInfo=task_service_pb.TableInfo(projectName=table.project.name, tableName=table.name),
        needContent=True,
        uploadFile='',
    )
    meta_response = service.GetTableMeta(rpc_controller, meta_request, None)
    if rpc_controller.Failed():
        raise CupidError(rpc_controller.ErrorText())

    logger.info(
        "[CupidTask] getTableMeta call, CurrentInstanceId: %s, "
        "request: %s, response: %s", session.lookup_name, meta_request, meta_response,
    )
    meta_content = meta_response.getTableMetaContent
    return json.loads(meta_content)
# Attach the module-level helpers to CupidSession as methods: when called on a
# session instance, that instance is passed as their ``session`` argument.
CupidSession.create_download_session = create_download_session
CupidSession.create_upload_session = create_upload_session
CupidSession.query_table_meta = query_table_meta