#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import struct
try:
import pyarrow as pa
except (AttributeError, ImportError):
pa = None
try:
import pyarrow.compute as pac
except (AttributeError, ImportError):
pac = None
try:
import numpy as np
except ImportError:
np = None
try:
import pandas as pd
except (ImportError, ValueError):
pd = None
from ... import compat, options, types, utils
from ...compat import Decimal, Enum, futures, six
from ...lib.monotonic import monotonic
from ..base import TunnelMetrics
from ..checksum import Checksum
from ..errors import TunnelError
from ..pb.encoder import Encoder
from ..pb.wire_format import (
WIRETYPE_FIXED32,
WIRETYPE_FIXED64,
WIRETYPE_LENGTH_DELIMITED,
WIRETYPE_VARINT,
)
from ..wireconstants import ProtoWireConstants
from .stream import RequestsIO, get_compress_stream
from .types import odps_schema_to_arrow_schema
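# Prefer the Cython record hasher and writer when available; fall back to the
# pure-Python hasher (and the fallback writer classes defined below) unless
# options.force_c requires the C extension to be importable.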
try:
if not options.force_py:
from ..hasher_c import RecordHasher
from .writer_c import BaseRecordWriter
else:
from ..hasher import RecordHasher
BaseRecordWriter = None
except ImportError as e:
if options.force_c:
raise e
from ..hasher import RecordHasher
BaseRecordWriter = None
MICRO_SEC_PER_SEC = 1000000
varint_tag_types = types.integer_types + (
types.boolean,
types.datetime,
types.date,
types.interval_year_month,
)
length_delim_tag_types = (
types.string,
types.binary,
types.timestamp,
types.timestamp_ntz,
types.interval_day_time,
types.json,
)
if BaseRecordWriter is None:
class ProtobufWriter(object):
"""
ProtobufWriter is a stream-style wrapper around the protobuf encoders
(encoder_c.Encoder in C and encoder.Encoder in pure Python).
"""
DEFAULT_BUFFER_SIZE = 4096
def __init__(self, output, buffer_size=None):
self._encoder = Encoder()
self._output = output
self._buffer_size = buffer_size or self.DEFAULT_BUFFER_SIZE
self._n_total = 0
def _re_init(self, output):
self._encoder = Encoder()
self._output = output
self._n_total = 0
def _mode(self):
return "py"
def flush(self):
if len(self._encoder) > 0:
data = self._encoder.tostring()
self._output.write(data)
self._n_total += len(self._encoder)
self._encoder = Encoder()
def close(self):
self.flush_all()
def flush_all(self):
self.flush()
self._output.flush()
def _refresh_buffer(self):
"""Control the buffer size of _encoder. Flush if necessary"""
if len(self._encoder) > self._buffer_size:
self.flush()
@property
def n_bytes(self):
return self._n_total + len(self._encoder)
def __len__(self):
return self.n_bytes
def _write_tag(self, field_num, wire_type):
self._encoder.append_tag(field_num, wire_type)
self._refresh_buffer()
def _write_raw_long(self, val):
self._encoder.append_sint64(val)
self._refresh_buffer()
def _write_raw_int(self, val):
self._encoder.append_sint32(val)
self._refresh_buffer()
def _write_raw_uint(self, val):
self._encoder.append_uint32(val)
self._refresh_buffer()
def _write_raw_bool(self, val):
self._encoder.append_bool(val)
self._refresh_buffer()
def _write_raw_float(self, val):
self._encoder.append_float(val)
self._refresh_buffer()
def _write_raw_double(self, val):
self._encoder.append_double(val)
self._refresh_buffer()
def _write_raw_string(self, val):
self._encoder.append_string(val)
self._refresh_buffer()
class BaseRecordWriter(ProtobufWriter):
def __init__(self, schema, out, encoding="utf-8"):
self._encoding = encoding
self._schema = schema
self._columns = self._schema.columns
self._crc = Checksum()
self._crccrc = Checksum()
self._curr_cursor = 0
self._to_milliseconds = utils.MillisecondsConverter().to_milliseconds
self._to_milliseconds_utc = utils.MillisecondsConverter(
local_tz=False
).to_milliseconds
self._to_days = utils.to_days
self._enable_client_metrics = options.tunnel.enable_client_metrics
self._local_wall_time_ms = 0
super(BaseRecordWriter, self).__init__(out)
def write(self, record):
n_record_fields = len(record)
n_columns = len(self._columns)
if self._enable_client_metrics:
ts = monotonic()
if n_record_fields > n_columns:
raise IOError("record fields count is more than schema.")
for i in range(min(n_record_fields, n_columns)):
if self._schema.is_partition(self._columns[i]):
continue
val = record[i]
if val is None:
continue
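# Each non-null, non-partition value is written as a protobuf field whose
# field number is the column index plus one; the wire type chosen below
# depends on the ODPS column type.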
pb_index = i + 1
self._crc.update_int(pb_index)
data_type = self._columns[i].type
if data_type in varint_tag_types:
self._write_tag(pb_index, WIRETYPE_VARINT)
elif data_type == types.float_:
self._write_tag(pb_index, WIRETYPE_FIXED32)
elif data_type == types.double:
self._write_tag(pb_index, WIRETYPE_FIXED64)
elif data_type in length_delim_tag_types:
self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
elif isinstance(
data_type,
(
types.Char,
types.Varchar,
types.Decimal,
types.Array,
types.Map,
types.Struct,
),
):
self._write_tag(pb_index, WIRETYPE_LENGTH_DELIMITED)
else:
raise IOError("Invalid data type: %s" % data_type)
self._write_field(val, data_type)
checksum = utils.long_to_int(self._crc.getvalue())
self._write_tag(ProtoWireConstants.TUNNEL_END_RECORD, WIRETYPE_VARINT)
self._write_raw_uint(utils.long_to_uint(checksum))
self._crc.reset()
self._crccrc.update_int(checksum)
self._curr_cursor += 1
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
def _write_bool(self, data):
self._crc.update_bool(data)
self._write_raw_bool(data)
def _write_long(self, data):
self._crc.update_long(data)
self._write_raw_long(data)
def _write_float(self, data):
self._crc.update_float(data)
self._write_raw_float(data)
def _write_double(self, data):
self._crc.update_double(data)
self._write_raw_double(data)
def _write_string(self, data):
if isinstance(data, six.text_type):
data = data.encode(self._encoding)
self._crc.update(data)
self._write_raw_string(data)
def _write_timestamp(self, data, ntz=False):
to_mills = self._to_milliseconds_utc if ntz else self._to_milliseconds
t_val = int(to_mills(data.to_pydatetime(warn=False)) / 1000)
nano_val = data.microsecond * 1000 + data.nanosecond
self._crc.update_long(t_val)
self._write_raw_long(t_val)
self._crc.update_int(nano_val)
self._write_raw_int(nano_val)
def _write_interval_day_time(self, data):
t_val = data.days * 3600 * 24 + data.seconds
nano_val = data.microseconds * 1000 + data.nanoseconds
self._crc.update_long(t_val)
self._write_raw_long(t_val)
self._crc.update_int(nano_val)
self._write_raw_int(nano_val)
def _write_array(self, data, data_type):
for value in data:
if value is None:
self._write_raw_bool(True)
else:
self._write_raw_bool(False)
self._write_field(value, data_type)
def _write_struct(self, data, data_type):
if isinstance(data, dict):
vals = [None] * len(data)
for idx, key in enumerate(data_type.field_types.keys()):
vals[idx] = data[key]
data = tuple(vals)
for value, typ in zip(data, data_type.field_types.values()):
if value is None:
self._write_raw_bool(True)
else:
self._write_raw_bool(False)
self._write_field(value, typ)
def _write_field(self, val, data_type):
if data_type == types.boolean:
self._write_bool(val)
elif data_type == types.datetime:
val = self._to_milliseconds(val)
self._write_long(val)
elif data_type == types.date:
val = self._to_days(val)
self._write_long(val)
elif data_type == types.float_:
self._write_float(val)
elif data_type == types.double:
self._write_double(val)
elif data_type in types.integer_types:
self._write_long(val)
elif data_type == types.string:
self._write_string(val)
elif data_type == types.binary:
self._write_string(val)
elif data_type == types.timestamp or data_type == types.timestamp_ntz:
self._write_timestamp(val, ntz=data_type == types.timestamp_ntz)
elif data_type == types.interval_day_time:
self._write_interval_day_time(val)
elif data_type == types.interval_year_month:
self._write_long(val.total_months())
elif isinstance(data_type, (types.Char, types.Varchar)):
self._write_string(val)
elif isinstance(data_type, types.Decimal):
self._write_string(str(val))
elif isinstance(data_type, types.Json):
self._write_string(json.dumps(val))
elif isinstance(data_type, types.Array):
self._write_raw_uint(len(val))
self._write_array(val, data_type.value_type)
elif isinstance(data_type, types.Map):
self._write_raw_uint(len(val))
self._write_array(compat.lkeys(val), data_type.key_type)
self._write_raw_uint(len(val))
self._write_array(compat.lvalues(val), data_type.value_type)
elif isinstance(data_type, types.Struct):
self._write_struct(val, data_type)
else:
raise IOError("Invalid data type: %s" % data_type)
@property
def count(self):
return self._curr_cursor
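# The closing tags carry the total record count and a checksum accumulated
# over all per-record checksums, so the server can validate the uploaded block.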
def _write_finish_tags(self):
self._write_tag(ProtoWireConstants.TUNNEL_META_COUNT, WIRETYPE_VARINT)
self._write_raw_long(self.count)
self._write_tag(ProtoWireConstants.TUNNEL_META_CHECKSUM, WIRETYPE_VARINT)
self._write_raw_uint(utils.long_to_uint(self._crccrc.getvalue()))
def close(self):
self._write_finish_tags()
super(BaseRecordWriter, self).close()
self._curr_cursor = 0
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# if an error occurs inside the with block, we do not commit
if exc_val is not None:
return
self.close()
class RecordWriter(BaseRecordWriter):
"""
Writer object that writes records to ODPS. Should be created
with :meth:`TableUploadSession.open_record_writer` with ``block_id`` specified.
:Example:
Here we show an example of writing data to ODPS with two records created in different ways.
.. code-block:: python
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')
# creates a RecordWriter instance for block 0
with upload_session.open_record_writer(0) as writer:
record = upload_session.new_record()
record[0] = 'test1'
record[1] = 'id1'
writer.write(record)
record = upload_session.new_record(['test2', 'id2'])
writer.write(record)
# commit block 0
upload_session.commit([0])
:Note:
``RecordWriter`` holds a long-lived HTTP connection which may be closed by
the server once it stays open for more than 3 minutes. Avoid keeping a
``RecordWriter`` open for long periods. Details can be found :ref:`here <tunnel>`.
"""
def __init__(
self, schema, request_callback, compress_option=None, encoding="utf-8"
):
self._enable_client_metrics = options.tunnel.enable_client_metrics
self._server_metrics_string = None
if self._enable_client_metrics:
ts = monotonic()
self._req_io = RequestsIO(
request_callback,
chunk_size=options.chunk_size,
record_io_time=self._enable_client_metrics,
)
out = get_compress_stream(self._req_io, compress_option)
super(RecordWriter, self).__init__(schema, out, encoding=encoding)
self._req_io.start()
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
@property
def metrics(self):
if self._server_metrics_string is None:
return None
return TunnelMetrics.from_server_json(
type(self).__name__,
self._server_metrics_string,
self._local_wall_time_ms,
self._req_io.io_time_ms,
)
def write(self, record):
"""
Write a record to the tunnel.
:param record: record to write
:type record: :class:`odps.models.Record`
"""
if self._req_io._async_err:
ex_type, ex_value, tb = self._req_io._async_err
six.reraise(ex_type, ex_value, tb)
super(RecordWriter, self).write(record)
def close(self):
"""
Close the writer and flush all data to server.
"""
if self._enable_client_metrics:
ts = monotonic()
super(RecordWriter, self).close()
resp = self._req_io.finish()
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
self._server_metrics_string = resp.headers.get("odps-tunnel-metrics")
def get_total_bytes(self):
return self.n_bytes
class BufferedRecordWriter(BaseRecordWriter):
"""
Writer object that writes records to ODPS. Should be created
with :meth:`TableUploadSession.open_record_writer` without ``block_id``.
Results should be submitted with :meth:`TableUploadSession.commit` with
returned value from :meth:`get_blocks_written`.
:Example:
Here we show an example of writing data to ODPS with two records created in different ways.
.. code-block:: python
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')
# creates a BufferedRecordWriter instance
with upload_session.open_record_writer() as writer:
record = upload_session.new_record()
record[0] = 'test1'
record[1] = 'id1'
writer.write(record)
record = upload_session.new_record(['test2', 'id2'])
writer.write(record)
# commit blocks
upload_session.commit(writer.get_blocks_written())
"""
def __init__(
self,
schema,
request_callback,
compress_option=None,
encoding="utf-8",
buffer_size=None,
block_id=None,
block_id_gen=None,
):
self._request_callback = request_callback
self._block_id = block_id or 0
self._blocks_written = []
self._buffer = compat.BytesIO()
self._n_bytes_written = 0
self._compress_option = compress_option
self._block_id_gen = block_id_gen
self._enable_client_metrics = options.tunnel.enable_client_metrics
self._server_metrics_string = None
self._network_wall_time_ms = 0
if not self._enable_client_metrics:
self._accumulated_metrics = None
else:
self._accumulated_metrics = TunnelMetrics(type(self).__name__)
ts = monotonic()
out = get_compress_stream(self._buffer, compress_option)
super(BufferedRecordWriter, self).__init__(schema, out, encoding=encoding)
# make sure block buffer size is applied correctly here
self._buffer_size = buffer_size or options.tunnel.block_buffer_size
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
@property
def cur_block_id(self):
return self._block_id
def _get_next_block_id(self):
if callable(self._block_id_gen):
return self._block_id_gen()
return self._block_id + 1
def write(self, record):
"""
Write a record to the tunnel.
:param record: record to write
:type record: :class:`odps.models.Record`
"""
super(BufferedRecordWriter, self).write(record)
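# Once the raw (uncompressed) size of the current block exceeds the configured
# buffer size, the block is flushed to the server and a new block id is taken.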
if 0 < self._buffer_size < self._n_raw_bytes:
self._flush()
def close(self):
"""
Close the writer and flush all data to server.
"""
if self._n_raw_bytes > 0:
self._flush()
self.flush_all()
self._buffer.close()
def _collect_metrics(self):
if self._enable_client_metrics:
if self._server_metrics_string is not None:
self._accumulated_metrics += TunnelMetrics.from_server_json(
type(self).__name__,
self._server_metrics_string,
self._local_wall_time_ms,
self._network_wall_time_ms,
)
self._server_metrics_string = None
self._local_wall_time_ms = 0
self._network_wall_time_ms = 0
def _reset_writer(self, write_response):
self._collect_metrics()
if self._enable_client_metrics:
ts = monotonic()
self._buffer = compat.BytesIO()
out = get_compress_stream(self._buffer, self._compress_option)
self._re_init(out)
self._curr_cursor = 0
self._crccrc.reset()
self._crc.reset()
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
def _send_buffer(self):
if self._enable_client_metrics:
ts = monotonic()
resp = self._request_callback(self._block_id, self._buffer.getvalue())
if self._enable_client_metrics:
self._network_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
return resp
def _flush(self):
if self._enable_client_metrics:
ts = monotonic()
self._write_finish_tags()
self._n_bytes_written += self._n_raw_bytes
self.flush_all()
resp = self._send_buffer()
self._server_metrics_string = resp.headers.get("odps-tunnel-metrics")
self._blocks_written.append(self._block_id)
self._block_id = self._get_next_block_id()
if self._enable_client_metrics:
self._local_wall_time_ms += int(MICRO_SEC_PER_SEC * (monotonic() - ts))
self._reset_writer(resp)
@property
def metrics(self):
return self._accumulated_metrics
@property
def _n_raw_bytes(self):
return super(BufferedRecordWriter, self).n_bytes
@property
def n_bytes(self):
return self._n_bytes_written + self._n_raw_bytes
def get_total_bytes(self):
return self.n_bytes
def get_blocks_written(self):
"""
Get block ids created during writing. Should be provided as the argument to
:meth:`TableUploadSession.commit`.
"""
return self._blocks_written
# keep the original, misspelled class name importable for backward compatibility
BufferredRecordWriter = BufferedRecordWriter
class StreamRecordWriter(BufferedRecordWriter):
def __init__(
self,
schema,
request_callback,
session,
slot,
compress_option=None,
encoding="utf-8",
buffer_size=None,
):
self.session = session
self.slot = slot
self._record_count = 0
super(StreamRecordWriter, self).__init__(
schema,
request_callback,
compress_option=compress_option,
encoding=encoding,
buffer_size=buffer_size,
)
def write(self, record):
super(StreamRecordWriter, self).write(record)
self._record_count += 1
def _reset_writer(self, write_response):
self._record_count = 0
slot_server = write_response.headers["odps-tunnel-routed-server"]
slot_num = int(write_response.headers["odps-tunnel-slot-num"])
self.session.reload_slots(self.slot, slot_server, slot_num)
super(StreamRecordWriter, self)._reset_writer(write_response)
def _send_buffer(self):
def gen(): # synchronize chunk upload
data = self._buffer.getvalue()
chunk_size = options.chunk_size
while data:
to_send = data[:chunk_size]
data = data[chunk_size:]
yield to_send
return self._request_callback(gen())
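# A minimal usage sketch for StreamRecordWriter (hedged: it assumes the streaming
# upload API such as TableTunnel.create_stream_upload_session and the session's
# open_record_writer method; names outside this module may differ):
#
#     tunnel = TableTunnel(o)
#     stream_session = tunnel.create_stream_upload_session('my_table')
#     with stream_session.open_record_writer() as writer:
#         writer.write(stream_session.new_record(['test1', 'id1']))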
class BaseArrowWriter(object):
def __init__(self, schema, out=None, chunk_size=None):
if pa is None:
raise ValueError("To use arrow writer you need to install pyarrow")
self._schema = schema
self._arrow_schema = odps_schema_to_arrow_schema(schema)
self._chunk_size = chunk_size or options.chunk_size
self._crc = Checksum()
self._crccrc = Checksum()
self._cur_chunk_size = 0
self._output = out
self._chunk_size_written = False
self._pd_mappers = self._build_pd_mappers()
def _re_init(self, output):
self._output = output
self._chunk_size_written = False
self._cur_chunk_size = 0
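# Arrow tunnel framing: the chunk size is written once as a big-endian uint32,
# then data is streamed in chunks; after every chunk_size bytes a checksum of
# those bytes is emitted, and _write_finish_tags appends a final checksum
# covering the whole stream.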
def _write_chunk_size(self):
self._write_uint32(self._chunk_size)
def _write_uint32(self, val):
data = struct.pack("!I", utils.long_to_uint(val))
self._output.write(data)
def _write_chunk(self, buf):
if not self._chunk_size_written:
self._write_chunk_size()
self._chunk_size_written = True
self._output.write(buf)
self._crc.update(buf)
self._crccrc.update(buf)
self._cur_chunk_size += len(buf)
if self._cur_chunk_size >= self._chunk_size:
checksum = self._crc.getvalue()
self._write_uint32(checksum)
self._crc.reset()
self._cur_chunk_size = 0
@classmethod
def _localize_timezone(cls, col, tz=None):
from ...lib import tzlocal
if tz is None:
if options.local_timezone is True or options.local_timezone is None:
tz = str(tzlocal.get_localzone())
elif options.local_timezone is False:
tz = "UTC"
else:
tz = str(options.local_timezone)
if col.type.tz is not None:
return col
if hasattr(pac, "assume_timezone") and isinstance(tz, str):
# pyarrow.compute.assume_timezone only accepts
# string-represented zones
col = pac.assume_timezone(col, timezone=tz)
return col
else:
pd_col = col.to_pandas().dt.tz_localize(tz)
return pa.Array.from_pandas(pd_col)
@classmethod
def _str_to_decimal_array(cls, col, dec_type):
dec_col = col.to_pandas().map(Decimal)
return pa.Array.from_pandas(dec_col, type=dec_type)
def _build_pd_mappers(self):
pa_dec_types = (pa.Decimal128Type,)
if hasattr(pa, "Decimal256Type"):
pa_dec_types += (pa.Decimal256Type,)
def _need_cast(arrow_type):
if isinstance(arrow_type, (pa.MapType, pa.StructType) + pa_dec_types):
return True
elif isinstance(arrow_type, pa.ListType):
return _need_cast(arrow_type.value_type)
else:
return False
def _build_mapper(cur_type):
if isinstance(cur_type, pa.MapType):
key_mapper = _build_mapper(cur_type.key_type)
value_mapper = _build_mapper(cur_type.item_type)
def mapper(data):
if isinstance(data, dict):
return [
(key_mapper(k), value_mapper(v)) for k, v in data.items()
]
else:
return data
elif isinstance(cur_type, pa.ListType):
item_mapper = _build_mapper(cur_type.value_type)
def mapper(data):
if data is None:
return data
return [item_mapper(element) for element in data]
elif isinstance(cur_type, pa.StructType):
val_mappers = dict()
for fid in range(cur_type.num_fields):
field = cur_type[fid]
val_mappers[field.name.lower()] = _build_mapper(field.type)
def mapper(data):
if isinstance(data, (list, tuple)):
field_names = getattr(data, "_fields", None) or [
cur_type[fid].name for fid in range(cur_type.num_fields)
]
data = dict(zip(field_names, data))
if isinstance(data, dict):
fields = dict()
for key, val in data.items():
fields[key] = val_mappers[key.lower()](val)
data = fields
return data
elif isinstance(cur_type, pa_dec_types):
def mapper(data):
if data is None:
return None
return Decimal(data)
else:
mapper = lambda x: x
return mapper
mappers = dict()
for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
if _need_cast(typ):
mappers[name.lower()] = _build_mapper(typ)
return mappers
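# Illustrative example of what the mappers above do: a pandas cell holding
# {"a": 1} that targets an Arrow map<string, int64> column is rewritten to
# [("a", 1)], since pyarrow expects map values as sequences of key/value pairs;
# decimal columns are converted element-wise to decimal.Decimal.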
def _convert_df_types(self, df):
dest_df = df.copy()
lower_to_df_name = {utils.to_lower_str(s): s for s in df.columns}
new_fields = []
for name, typ in zip(self._arrow_schema.names, self._arrow_schema.types):
df_name = lower_to_df_name.get(name.lower(), name)
new_fields.append(pa.field(df_name, typ))
if df_name not in df.columns:
dest_df[df_name] = None
continue
if name.lower() not in self._pd_mappers:
continue
dest_df[df_name] = df[df_name].map(self._pd_mappers[name.lower()])
df_arrow_schema = pa.schema(new_fields)
return pa.Table.from_pandas(dest_df, df_arrow_schema)
def write(self, data):
"""
Write an Arrow RecordBatch, an Arrow Table or a pandas DataFrame.
"""
if pd is not None and isinstance(data, pd.DataFrame):
arrow_data = self._convert_df_types(data)
elif isinstance(data, (pa.Table, pa.RecordBatch)):
arrow_data = data
else:
raise TypeError("Cannot support writing data type %s", type(data))
arrow_decimal_types = (pa.Decimal128Type,)
if hasattr(pa, "Decimal256Type"):
arrow_decimal_types += (pa.Decimal256Type,)
assert isinstance(arrow_data, (pa.RecordBatch, pa.Table))
if arrow_data.schema != self._arrow_schema or any(
isinstance(tp, pa.TimestampType) for tp in arrow_data.schema.types
):
lower_names = [n.lower() for n in arrow_data.schema.names]
type_dict = dict(zip(lower_names, arrow_data.schema.types))
column_dict = dict(zip(lower_names, arrow_data.columns))
arrays = []
for name, tp in zip(self._arrow_schema.names, self._arrow_schema.types):
lower_name = name.lower()
if lower_name not in column_dict:
raise ValueError(
"Input record batch does not contain column %s" % name
)
if isinstance(tp, pa.TimestampType):
if self._schema[lower_name].type == types.timestamp_ntz:
col = self._localize_timezone(column_dict[lower_name], "UTC")
else:
col = self._localize_timezone(column_dict[lower_name])
column_dict[lower_name] = col.cast(
pa.timestamp(tp.unit, col.type.tz)
)
elif (
isinstance(tp, arrow_decimal_types)
and isinstance(column_dict[lower_name], (pa.Array, pa.ChunkedArray))
and column_dict[lower_name].type in (pa.binary(), pa.string())
):
column_dict[lower_name] = self._str_to_decimal_array(
column_dict[lower_name], tp
)
if tp == type_dict[lower_name]:
arrays.append(column_dict[lower_name])
else:
try:
arrays.append(column_dict[lower_name].cast(tp, safe=False))
except (pa.ArrowInvalid, pa.ArrowNotImplementedError):
raise ValueError(
"Failed to cast column %s to type %s" % (name, tp)
)
pa_type = type(arrow_data)
arrow_data = pa_type.from_arrays(arrays, names=self._arrow_schema.names)
if isinstance(arrow_data, pa.RecordBatch):
batches = [arrow_data]
else: # pa.Table
batches = arrow_data.to_batches()
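# Each RecordBatch is serialized to raw Arrow IPC bytes and pushed through the
# chunked framing above, so a single batch may span multiple checksummed chunks.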
for batch in batches:
data = batch.serialize().to_pybytes()
written_bytes = 0
while written_bytes < len(data):
length = min(
self._chunk_size - self._cur_chunk_size, len(data) - written_bytes
)
chunk_data = data[written_bytes : written_bytes + length]
self._write_chunk(chunk_data)
written_bytes += length
def _write_finish_tags(self):
checksum = self._crccrc.getvalue()
self._write_uint32(checksum)
self._crccrc.reset()
def flush(self):
self._output.flush()
def _finish(self):
self._write_finish_tags()
self._output.flush()
def close(self):
"""
Close the writer and flush all data to server.
"""
self._finish()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# if an error occurs inside the with block, we do not commit
if exc_val is not None:
return
self.close()
class ArrowWriter(BaseArrowWriter):
"""
Writer object to write data to ODPS using Arrow format. Should be created
with :meth:`TableUploadSession.open_arrow_writer` with ``block_id`` specified.
:Example:
Here we show an example of writing a pandas DataFrame to ODPS.
.. code-block:: python
import pandas as pd
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')
# creates an ArrowWriter instance for block 0
with upload_session.open_arrow_writer(0) as writer:
df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
writer.write(df)
# commit block 0
upload_session.commit([0])
:Note:
``ArrowWriter`` holds a long-lived HTTP connection which may be closed by
the server once it stays open for more than 3 minutes. Avoid keeping an
``ArrowWriter`` open for long periods. Details can be found :ref:`here <tunnel>`.
"""
def __init__(self, schema, request_callback, compress_option=None, chunk_size=None):
self._req_io = RequestsIO(request_callback, chunk_size=chunk_size)
out = get_compress_stream(self._req_io, compress_option)
super(ArrowWriter, self).__init__(schema, out, chunk_size)
self._req_io.start()
def _finish(self):
super(ArrowWriter, self)._finish()
self._req_io.finish()
class BufferedArrowWriter(BaseArrowWriter):
"""
Writer object to write data to ODPS using Arrow format. Should be created
with :meth:`TableUploadSession.open_arrow_writer` without ``block_id``.
Results should be submitted with :meth:`TableUploadSession.commit` with
returned value from :meth:`get_blocks_written`.
:Example:
Here we show an example of writing a pandas DataFrame to ODPS.
.. code-block:: python
import pandas as pd
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
upload_session = tunnel.create_upload_session('my_table', partition_spec='pt=test')
# creates a BufferedArrowWriter instance
with upload_session.open_arrow_writer() as writer:
df = pd.DataFrame({'col1': ['test1', 'test2'], 'col2': ['id1', 'id2']})
writer.write(df)
# commit blocks
upload_session.commit(writer.get_blocks_written())
"""
def __init__(
self,
schema,
request_callback,
compress_option=None,
buffer_size=None,
chunk_size=None,
block_id=None,
block_id_gen=None,
):
self._buffer_size = buffer_size or options.tunnel.block_buffer_size
self._request_callback = request_callback
self._block_id = block_id or 0
self._blocks_written = []
self._buffer = compat.BytesIO()
self._compress_option = compress_option
self._n_bytes_written = 0
self._block_id_gen = block_id_gen
out = get_compress_stream(self._buffer, compress_option)
super(BufferedArrowWriter, self).__init__(schema, out, chunk_size=chunk_size)
@property
def cur_block_id(self):
return self._block_id
def _get_next_block_id(self):
if callable(self._block_id_gen):
return self._block_id_gen()
return self._block_id + 1
def write(self, data):
super(BufferedArrowWriter, self).write(data)
if 0 < self._buffer_size < self._n_raw_bytes:
self._flush()
def close(self):
if self._n_raw_bytes > 0:
self._flush()
self._finish()
self._buffer.close()
def _reset_writer(self):
self._buffer = compat.BytesIO()
out = get_compress_stream(self._buffer, self._compress_option)
self._re_init(out)
self._crccrc.reset()
self._crc.reset()
def _send_buffer(self):
return self._request_callback(self._block_id, self._buffer.getvalue())
def _flush(self):
self._write_finish_tags()
self._n_bytes_written += self._n_raw_bytes
self._send_buffer()
self._blocks_written.append(self._block_id)
self._block_id = self._get_next_block_id()
self._reset_writer()
@property
def _n_raw_bytes(self):
return self._buffer.tell()
@property
def n_bytes(self):
return self._n_bytes_written + self._n_raw_bytes
def get_total_bytes(self):
return self.n_bytes
def get_blocks_written(self):
"""
Get block ids created during writing. Should be provided as the argument to
:meth:`TableUploadSession.commit`.
"""
return self._blocks_written
class Upsert(object):
"""
Object that inserts, updates or deletes records in an ODPS upsert table. Should be
created with :meth:`TableUpsertSession.open_upsert_stream`.
:Example:
Here we show an example of inserting, updating and deleting data to an upsert table.
.. code-block:: python
from odps.tunnel import TableTunnel
tunnel = TableTunnel(o)
upsert_session = tunnel.create_upsert_session('my_table', partition_spec='pt=test')
# creates an upsert stream
stream = upsert_session.open_upsert_stream(compress=True)
rec = upsert_session.new_record(["0", "v1"])
stream.upsert(rec)
rec = upsert_session.new_record(["0", "v2"])
stream.upsert(rec)
rec = upsert_session.new_record(["1", "v1"])
stream.upsert(rec)
rec = upsert_session.new_record(["2", "v1"])
stream.upsert(rec)
stream.delete(rec)
stream.flush()
stream.close()
upsert_session.commit()
"""
DEFAULT_MAX_BUFFER_SIZE = 64 * 1024**2
DEFAULT_SLOT_BUFFER_SIZE = 1024**2
class Operation(Enum):
UPSERT = "UPSERT"
DELETE = "DELETE"
class Status(Enum):
NORMAL = "NORMAL"
ERROR = "ERROR"
CLOSED = "CLOSED"
def __init__(
self,
schema,
request_callback,
session,
compress_option=None,
encoding="utf-8",
max_buffer_size=None,
slot_buffer_size=None,
):
self._schema = schema
self._request_callback = request_callback
self._session = session
self._compress_option = compress_option
self._max_buffer_size = max_buffer_size or self.DEFAULT_MAX_BUFFER_SIZE
self._slot_buffer_size = slot_buffer_size or self.DEFAULT_SLOT_BUFFER_SIZE
self._total_n_bytes = 0
self._status = Upsert.Status.NORMAL
self._schema = session.schema
self._encoding = encoding
self._hash_keys = self._session.hash_keys
self._hasher = RecordHasher(schema, self._session.hasher, self._hash_keys)
self._buckets = self._session.buckets.copy()
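# Each bucket gets its own in-memory buffer and record writer; records are
# routed to buckets by hashing the session's hash keys.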
self._bucket_buffers = {}
self._bucket_writers = {}
for slot in session.buckets.keys():
self._build_bucket_writer(slot)
@property
def status(self):
return self._status
@property
def n_bytes(self):
return self._total_n_bytes
def upsert(self, record):
"""
Insert or update a record.
:param record: record to write
:type record: :class:`odps.models.Record`
"""
return self._write(record, Upsert.Operation.UPSERT)
def delete(self, record):
"""
Delete a record.
:param record: record to delete
:type record: :class:`odps.models.Record`
"""
return self._write(record, Upsert.Operation.DELETE)
def flush(self, flush_all=True):
"""
Flush buffered data to the server. When ``flush_all`` is False, only buckets
whose buffers exceed the slot buffer size are flushed.
"""
if len(self._session.buckets) != len(self._bucket_writers):
raise TunnelError("session slot map is changed")
else:
self._buckets = self._session.buckets.copy()
bucket_written = dict()
bucket_to_count = dict()
def write_bucket(bucket_id):
slot = self._buckets[bucket_id]
sio = self._bucket_buffers[bucket_id]
rec_count = bucket_to_count[bucket_id]
self._request_callback(bucket_id, slot, rec_count, sio.getvalue())
self._build_bucket_writer(bucket_id)
bucket_written[bucket_id] = True
retry = 0
while True:
futs = []
pool = futures.ThreadPoolExecutor(len(self._bucket_writers))
try:
self._check_status()
for bucket, writer in self._bucket_writers.items():
if writer.n_bytes == 0 or bucket_written.get(bucket):
continue
if not flush_all and writer.n_bytes <= self._slot_buffer_size:
continue
bucket_to_count[bucket] = writer.count
writer.close()
futs.append(pool.submit(write_bucket, bucket))
for fut in futs:
fut.result()
break
except KeyboardInterrupt:
raise TunnelError("flush interrupted")
except:
retry += 1
if retry == 3:
raise
finally:
pool.shutdown()
def close(self):
"""
Close the stream and write all data to server.
"""
if self.status == Upsert.Status.NORMAL:
self.flush()
self._status = Upsert.Status.CLOSED
def _build_bucket_writer(self, slot):
self._bucket_buffers[slot] = compat.BytesIO()
self._bucket_writers[slot] = BaseRecordWriter(
self._schema,
get_compress_stream(self._bucket_buffers[slot], self._compress_option),
encoding=self._encoding,
)
def _check_status(self):
if self._status == Upsert.Status.CLOSED:
raise TunnelError("Stream is closed!")
elif self._status == Upsert.Status.ERROR:
raise TunnelError("Stream has error!")
def _write(self, record, op, valid_columns=None):
self._check_status()
bucket = self._hasher.hash(record) % len(self._bucket_writers)
if bucket not in self._bucket_writers:
raise TunnelError(
"Tunnel internal error! Do not have bucket for hash key " + bucket
)
record[self._session.UPSERT_OPERATION_KEY] = ord(
b"U" if op == Upsert.Operation.UPSERT else b"D"
)
if valid_columns is None:
record[self._session.UPSERT_VALUE_COLS_KEY] = []
else:
valid_cols_set = set(valid_columns)
col_idxes = [
idx for idx, col in enumerate(self._schema.columns) if col in valid_cols_set
]
record[self._session.UPSERT_VALUE_COLS_KEY] = col_idxes
writer = self._bucket_writers[bucket]
prev_written_size = writer.n_bytes
writer.write(record)
written_size = writer.n_bytes
self._total_n_bytes += written_size - prev_written_size
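# Flush policy: a bucket whose buffer exceeds the per-slot limit triggers a
# partial flush of oversized buckets only, while exceeding the overall buffer
# limit forces a full flush of all buckets.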
if writer.n_bytes > self._slot_buffer_size:
self.flush(False)
elif self._total_n_bytes > self._max_buffer_size:
self.flush(True)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# if an error occurs inside the with block, we do not commit
if exc_val is not None:
return
self.close()