#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2022 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
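"""
Mars DataFrame operands that persist a Mars DataFrame into a MaxCompute (ODPS)
table. Two write paths are supported: the Cupid service, used when running
inside a MaxCompute cluster (detected via the ``CUPID_SERVICE_SOCKET``
environment variable), and the ODPS arrow tunnel, used everywhere else.
"""
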
import os
import time
import uuid
import logging
from typing import Dict, Tuple

from mars.core.context import get_context
from mars.oscar.errors import ActorNotExist
from mars.core import OutputType
from mars.dataframe.operands import DataFrameOperandMixin, DataFrameOperand
from mars.dataframe.utils import build_concatenated_rows_frame, parse_index
from mars.serialization.serializables import (
StringField,
SeriesField,
BoolField,
DictField,
Int64Field,
ListField,
FieldTypes,
)
from ....config import options
from ....utils import to_str
from ..cupid_service import CupidServiceClient

logger = logging.getLogger(__name__)

_output_type_kw = dict(_output_types=[OutputType.dataframe])


class DataFrameWriteTable(DataFrameOperand, DataFrameOperandMixin):
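    """Logical operand that writes a whole DataFrame into an ODPS table.

    At tile time it expands into one ``DataFrameWriteTableSplit`` chunk per
    input chunk and, on the Cupid path, a tree of ``DataFrameWriteTableCommit``
    chunks that finally commits the upload session.
    """
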
_op_type_ = 123460
dtypes = SeriesField("dtypes")
odps_params = DictField("odps_params", default=None)
table_name = StringField("table_name", default=None)
partition_spec = StringField("partition_spec", default=None)
partition_columns = ListField("partition_columns", FieldTypes.string, default=None)
overwrite = BoolField("overwrite", default=None)
write_batch_size = Int64Field("write_batch_size", default=None)
unknown_as_string = BoolField("unknown_as_string", default=None)
tunnel_quota_name = StringField("tunnel_quota_name", default=None)

    def __init__(self, **kw):
kw.update(_output_type_kw)
super(DataFrameWriteTable, self).__init__(**kw)
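
    # On the Cupid path every chunk writes into one shared upload session
    # handle, so a blind retry could register duplicate blocks; retries are
    # therefore only allowed on the tunnel path (an assumption drawn from how
    # the handle is used below).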
@property
def retryable(self):
return "CUPID_SERVICE_SOCKET" not in os.environ

    def __call__(self, x):
shape = (0,) * len(x.shape)
index_value = parse_index(x.index_value.to_pandas()[:0], x.key, "index")
columns_value = parse_index(
x.columns_value.to_pandas()[:0], x.key, "columns", store_data=True
)
return self.new_dataframe(
[x],
shape=shape,
dtypes=x.dtypes[:0],
index_value=index_value,
columns_value=columns_value,
)

    @classmethod
def _tile_cupid(cls, op):
cupid_client = CupidServiceClient()
upload_handle = cupid_client.create_table_upload_session(
op.odps_params, op.table_name
)
input_df = build_concatenated_rows_frame(op.inputs[0])
out_df = op.outputs[0]
out_chunks = []
out_chunk_shape = (0,) * len(input_df.shape)
blocks = {}
for chunk in input_df.chunks:
block_id = str(int(time.time())) + "_" + str(uuid.uuid4()).replace("-", "")
chunk_op = DataFrameWriteTableSplit(
dtypes=op.dtypes,
table_name=op.table_name,
odps_params=op.odps_params,
unknown_as_string=op.unknown_as_string,
partition_spec=op.partition_spec,
cupid_handle=to_str(upload_handle),
block_id=block_id,
write_batch_size=op.write_batch_size,
)
out_chunk = chunk_op.new_chunk(
[chunk],
shape=out_chunk_shape,
index=chunk.index,
index_value=out_df.index_value,
dtypes=chunk.dtypes,
)
out_chunks.append(out_chunk)
blocks[block_id] = op.partition_spec
        # Build the commit tree: merge output chunks in groups of combine_size
        # so the terminal commit chunk does not directly depend on every chunk.
        combine_size = 8
chunks = out_chunks
while len(chunks) >= combine_size:
new_chunks = []
for i in range(0, len(chunks), combine_size):
chks = chunks[i : i + combine_size]
if len(chks) == 1:
chk = chks[0]
else:
chk_op = DataFrameWriteTableCommit(
dtypes=op.dtypes, is_terminal=False
)
chk = chk_op.new_chunk(
chks,
shape=out_chunk_shape,
index_value=out_df.index_value,
dtypes=op.dtypes,
)
new_chunks.append(chk)
chunks = new_chunks
assert len(chunks) < combine_size
commit_table_op = DataFrameWriteTableCommit(
dtypes=op.dtypes,
table_name=op.table_name,
blocks=blocks,
cupid_handle=to_str(upload_handle),
overwrite=op.overwrite,
odps_params=op.odps_params,
is_terminal=True,
)
commit_table_chunk = commit_table_op.new_chunk(
chunks,
shape=out_chunk_shape,
dtypes=op.dtypes,
index_value=out_df.index_value,
index=(0,) * len(out_chunk_shape),
)
new_op = op.copy()
return new_op.new_dataframes(
op.inputs,
shape=out_df.shape,
index_value=out_df.index_value,
dtypes=out_df.dtypes,
columns_value=out_df.columns_value,
chunks=[commit_table_chunk],
nsplits=((0,),) * len(out_chunk_shape),
)

    @classmethod
def _tile_tunnel(cls, op):
from odps import ODPS
out_df = op.outputs[0]
if op.overwrite:
o = ODPS(
op.odps_params["access_id"],
op.odps_params["secret_access_key"],
project=op.odps_params["project"],
endpoint=op.odps_params["endpoint"],
)
data_target = o.get_table(op.table_name)
if op.partition_spec:
data_target = data_target.get_partition(op.partition_spec)
data_target.truncate()
in_df = build_concatenated_rows_frame(op.inputs[0])
logger.info("Tile table %s[%s]", op.table_name, op.partition_spec)
recorder_name = str(uuid.uuid4())
out_chunks = []
for chunk in in_df.chunks:
chunk_op = DataFrameWriteTableSplit(
dtypes=op.dtypes,
table_name=op.table_name,
odps_params=op.odps_params,
partition_spec=op.partition_spec,
commit_recorder_name=recorder_name,
tunnel_quota_name=op.tunnel_quota_name,
)
index_value = parse_index(chunk.index_value.to_pandas()[:0], chunk)
out_chunk = chunk_op.new_chunk(
[chunk],
shape=(0, 0),
index_value=index_value,
columns_value=out_df.columns_value,
dtypes=out_df.dtypes,
index=chunk.index,
)
out_chunks.append(out_chunk)
ctx = get_context()
ctx.create_remote_object(recorder_name, _TunnelCommitRecorder, len(out_chunks))
new_op = op.copy()
params = out_df.params.copy()
params.update(
dict(chunks=out_chunks, nsplits=((0,) * in_df.chunk_shape[0], (0,)))
)
return new_op.new_tileables([in_df], **params)

    @classmethod
def tile(cls, op):
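        """Dispatch tiling: use the Cupid service when running inside a
        MaxCompute cluster (``CUPID_SERVICE_SOCKET`` is set), otherwise fall
        back to the ODPS arrow tunnel."""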
if "CUPID_SERVICE_SOCKET" in os.environ:
return cls._tile_cupid(op)
else:
return cls._tile_tunnel(op)


class DataFrameWriteTableSplit(DataFrameOperand, DataFrameOperandMixin):
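    """Operand that writes a single chunk of the DataFrame.

    When ``cupid_handle`` is set, the chunk is written as one block of a Cupid
    upload session; otherwise it goes through an ODPS arrow tunnel upload
    session whose commit is coordinated by ``_TunnelCommitRecorder``.
    """
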
_op_type_ = 123461
dtypes = SeriesField("dtypes")
table_name = StringField("table_name")
partition_spec = StringField("partition_spec", default=None)
cupid_handle = StringField("cupid_handle", default=None)
block_id = StringField("block_id", default=None)
write_batch_size = Int64Field("write_batch_size", default=None)
unknown_as_string = BoolField("unknown_as_string", default=None)
commit_recorder_name = StringField("commit_recorder_name", default=None)
# for tunnel
odps_params = DictField("odps_params", default=None)
tunnel_quota_name = StringField("tunnel_quota_name", default=None)

    def __init__(self, **kw):
kw.update(_output_type_kw)
super(DataFrameWriteTableSplit, self).__init__(**kw)

    @property
def retryable(self):
return "CUPID_SERVICE_SOCKET" not in os.environ

    @classmethod
def _execute_in_cupid(cls, ctx, op):
import pandas as pd
from odps import ODPS
from odps.accounts import BearerTokenAccount
cupid_client = CupidServiceClient()
to_store_data = ctx[op.inputs[0].key]
bearer_token = cupid_client.get_bearer_token()
account = BearerTokenAccount(bearer_token)
project = os.environ.get("ODPS_PROJECT_NAME", None)
odps_params = op.odps_params.copy()
if project:
odps_params["project"] = project
endpoint = os.environ.get("ODPS_RUNTIME_ENDPOINT") or odps_params["endpoint"]
o = ODPS(
None,
None,
account=account,
project=odps_params["project"],
endpoint=endpoint,
)
table_obj = o.get_table(op.table_name)
odps_schema = table_obj.table_schema
project_name = table_obj.project.name
        schema_name = (
            table_obj.get_schema().name if table_obj.get_schema() is not None else None
        )
table_name = table_obj.name
writer_config = dict(
_table_name=table_name,
_project_name=project_name,
_schema_name=schema_name,
_table_schema=odps_schema,
_partition_spec=op.partition_spec,
_block_id=op.block_id,
_handle=op.cupid_handle,
)
cupid_client.write_table_data(writer_config, to_store_data, op.write_batch_size)
ctx[op.outputs[0].key] = pd.DataFrame()

    @classmethod
def _execute_arrow_tunnel(cls, ctx, op):
from odps import ODPS
from odps.tunnel import TableTunnel
import pyarrow as pa
import pandas as pd
project = os.environ.get("ODPS_PROJECT_NAME", None)
odps_params = op.odps_params.copy()
if project:
odps_params["project"] = project
endpoint = os.environ.get("ODPS_RUNTIME_ENDPOINT") or odps_params["endpoint"]
o = ODPS(
odps_params["access_id"],
odps_params["secret_access_key"],
project=odps_params["project"],
endpoint=endpoint,
)
t = o.get_table(op.table_name)
tunnel = TableTunnel(o, project=t.project, quota_name=op.tunnel_quota_name)
retry_times = options.retry_times
init_sleep_secs = 1
split_index = op.inputs[0].index
logger.info(
"Start creating upload session for table %s split index %s retry_times %s.",
op.table_name,
split_index,
retry_times,
)
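        # Retry with linear backoff: the n-th retry sleeps n * init_sleep_secs
        # seconds, giving up after options.retry_times attempts.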
retries = 0
while True:
try:
if op.partition_spec is not None:
upload_session = tunnel.create_upload_session(
t.name, partition_spec=op.partition_spec
)
else:
upload_session = tunnel.create_upload_session(t.name)
break
            except BaseException:
                if retries >= retry_times:
                    raise
                retries += 1
                sleep_secs = retries * init_sleep_secs
                logger.exception(
                    "Failed to create upload session, sleeping %s seconds before retrying",
                    sleep_secs,
                )
                time.sleep(sleep_secs)
logger.info(
"Start writing table %s. split_index: %s tunnel_session: %s",
op.table_name,
split_index,
upload_session.id,
)
retries = 0
while True:
try:
writer = upload_session.open_arrow_writer(0)
arrow_rb = pa.RecordBatch.from_pandas(ctx[op.inputs[0].key])
writer.write(arrow_rb)
writer.close()
break
            except BaseException:
                if retries >= retry_times:
                    raise
                retries += 1
                sleep_secs = retries * init_sleep_secs
                logger.exception(
                    "Failed to write data, sleeping %s seconds before retrying",
                    sleep_secs,
                )
                time.sleep(sleep_secs)
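
        # With speculative execution the same split may run more than once; the
        # remote _TunnelCommitRecorder guarantees that only the first finisher
        # of each split index commits its session, and that the recorder itself
        # is destroyed once every index has reported.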
recorder_name = op.commit_recorder_name
try:
recorder = ctx.get_remote_object(recorder_name)
except ActorNotExist:
while True:
                logger.info(
                    "Writing table %s has finished; waiting to be canceled by the speculative scheduler",
                    op.table_name,
                )
time.sleep(3)
can_commit, can_destroy = recorder.try_commit(split_index)
if can_commit:
            # FIXME If this commit fails or the process crashes, the whole write
            #  still raises an error. This situation is rare, so we skip the
            #  error handling.
logger.info(
"Committing to table %s with upload session %s", op.table_name, upload_session.id
)
upload_session.commit([0])
logger.info(
"Finish writing table %s. split_index: %s tunnel_session: %s",
op.table_name,
split_index,
upload_session.id,
)
else:
logger.info(
"Skip writing table %s. split_index: %s", op.table_name, split_index
)
if can_destroy:
try:
ctx.destroy_remote_object(recorder_name)
logger.info("Delete remote object %s", recorder_name)
except ActorNotExist:
pass
ctx[op.outputs[0].key] = pd.DataFrame()

    @classmethod
def execute(cls, ctx, op):
if op.cupid_handle is not None:
cls._execute_in_cupid(ctx, op)
else:
cls._execute_arrow_tunnel(ctx, op)


class _TunnelCommitRecorder:
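    """Remote object recording which split indexes have already committed.

    ``try_commit`` returns ``(can_commit, can_destroy)``: only the first caller
    for a given split index may commit, and ``can_destroy`` becomes ``True``
    once all ``n_chunk`` indexes have been seen.
    """
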
    _commit_status: Dict[Tuple[int, ...], bool]

    def __init__(self, n_chunk: int):
self._n_chunk = n_chunk
self._commit_status = {}

    def try_commit(self, index: tuple):
if index in self._commit_status:
return False, len(self._commit_status) == self._n_chunk
else:
self._commit_status[index] = True
return True, len(self._commit_status) == self._n_chunk


class DataFrameWriteTableCommit(DataFrameOperand, DataFrameOperandMixin):
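    """Operand that commits a Cupid table upload session.

    Non-terminal instances merely merge dependencies inside the commit tree;
    only the terminal instance performs the actual commit.
    """
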
_op_type_ = 123462
dtypes = SeriesField("dtypes")
odps_params = DictField("odps_params")
table_name = StringField("table_name")
overwrite = BoolField("overwrite", default=False)
blocks = DictField("blocks", default=None)
cupid_handle = StringField("cupid_handle", default=None)
is_terminal = BoolField("is_terminal", default=None)

    def __init__(self, **kw):
kw.update(_output_type_kw)
super(DataFrameWriteTableCommit, self).__init__(**kw)

    @classmethod
def execute(cls, ctx, op):
import pandas as pd
from ..cupid_service import CupidServiceClient
if op.is_terminal:
odps_params = op.odps_params.copy()
project = os.environ.get("ODPS_PROJECT_NAME", None)
if project:
odps_params["project"] = project
client = CupidServiceClient()
client.commit_table_upload_session(
odps_params, op.table_name, op.cupid_handle, op.blocks, op.overwrite
)
ctx[op.outputs[0].key] = pd.DataFrame()


def write_odps_table(
df,
table,
partition=None,
overwrite=False,
unknown_as_string=None,
odps_params=None,
write_batch_size=None,
tunnel_quota_name=None,
):
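    """Build a write operand over ``df`` and return a zero-row DataFrame that
    performs the write when executed.

    A minimal usage sketch (names are illustrative; ``o`` is assumed to be an
    existing ``odps.ODPS`` entry object and ``df`` a Mars DataFrame)::

        table = o.get_table("my_table")  # hypothetical table name
        write_odps_table(df, table, overwrite=True, odps_params=params).execute()
    """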
op = DataFrameWriteTable(
dtypes=df.dtypes,
odps_params=odps_params,
table_name=table.full_table_name,
unknown_as_string=unknown_as_string,
partition_spec=partition,
overwrite=overwrite,
write_batch_size=write_batch_size,
tunnel_quota_name=tunnel_quota_name,
)
return op(df)