odps/models/tableio.py (1,337 lines of code) (raw):

# -*- coding: utf-8 -*- # Copyright 1999-2025 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import copy import functools import itertools import logging import multiprocessing import os import socket import struct import sys import threading import uuid import warnings from collections import OrderedDict, defaultdict from types import GeneratorType, MethodType try: import pyarrow as pa except (AttributeError, ImportError): pa = None try: import pandas as pd except ImportError: pd = None from .. import errors from .. import types as odps_types from .. import utils from ..compat import Iterable, six from ..config import options from ..dag import DAG from ..lib import cloudpickle from ..lib.tblib import pickling_support from .readers import TunnelArrowReader, TunnelRecordReader logger = logging.getLogger(__name__) pickling_support.install() _GET_NEXT_BLOCK_CMD = 0x01 _PUT_WRITTEN_BLOCKS_CMD = 0x02 _SERVER_ERROR_CMD = 0xFD _STOP_SERVER_CMD = 0xFE if sys.version_info[0] == 2: _ord_if_possible = ord def _load_classmethod(cls, func_name): return getattr(cls, func_name) class InstanceMethodWrapper(object): """Trick for classmethods under Python 2.7 to be pickleable""" def __init__(self, func): assert isinstance(func, MethodType) assert isinstance(func.im_self, type) self._func = func def __call__(self, *args, **kw): return self._func() def __reduce__(self): return _load_classmethod, (self._func.im_self, self._func.im_func.__name__) _wrap_classmethod = InstanceMethodWrapper else: def _ord_if_possible(x): return x def _wrap_classmethod(x): return x class SpawnedTableReaderMixin(object): @property def schema(self): return self._parent.table_schema @staticmethod def _read_table_split( conn, download_id, start, count, idx, rest_client=None, project=None, table_name=None, partition_spec=None, tunnel_endpoint=None, quota_name=None, columns=None, arrow=False, schema_name=None, append_partitions=None, ): # read part data from ..tunnel import TableTunnel try: tunnel = TableTunnel( client=rest_client, project=project, endpoint=tunnel_endpoint, quota_name=quota_name, ) session = utils.call_with_retry( tunnel.create_download_session, table_name, schema=schema_name, download_id=download_id, partition_spec=partition_spec, ) def _data_to_pandas(): if not arrow: with session.open_record_reader( start, count, columns=columns, append_partitions=append_partitions, ) as reader: return reader.to_pandas() else: with session.open_arrow_reader( start, count, columns=columns, append_partitions=append_partitions, ) as reader: return reader.to_pandas() data = utils.call_with_retry(_data_to_pandas) conn.send((idx, data, True)) except: try: conn.send((idx, sys.exc_info(), False)) except: logger.exception("Failed to write in process %d", idx) raise def _get_process_split_reader(self, columns=None, append_partitions=None): rest_client = self._parent._client table_name = self._parent.name schema_name = self._parent.get_schema() project = self._parent.project.name tunnel_endpoint = self._download_session._client.endpoint quota_name = self._download_session._quota_name partition_spec = self._partition_spec return functools.partial( self._read_table_split, rest_client=rest_client, project=project, table_name=table_name, partition_spec=partition_spec, tunnel_endpoint=tunnel_endpoint, quota_name=quota_name, arrow=isinstance(self, TunnelArrowReader), columns=columns or self._column_names, schema_name=schema_name, append_partitions=append_partitions, ) class TableRecordReader(SpawnedTableReaderMixin, TunnelRecordReader): def __init__( self, table, download_session, partition_spec=None, columns=None, append_partitions=True, ): super(TableRecordReader, self).__init__( table, download_session, columns=columns, append_partitions=append_partitions, ) self._partition_spec = partition_spec class TableArrowReader(SpawnedTableReaderMixin, TunnelArrowReader): def __init__( self, table, download_session, partition_spec=None, columns=None, append_partitions=False, ): super(TableArrowReader, self).__init__( table, download_session, columns=columns, append_partitions=append_partitions, ) self._partition_spec = partition_spec class MPBlockServer(object): def __init__(self, writer): self._writer = writer self._sock = None self._serve_thread_obj = None self._authkey = multiprocessing.current_process().authkey @property def address(self): return self._sock.getsockname() if self._sock else None @property def authkey(self): return self._authkey def _serve_thread(self): while True: data, addr = self._sock.recvfrom(4096) try: pos = len(self._authkey) assert data[:pos] == self._authkey, "Authentication key mismatched!" cmd_code = _ord_if_possible(data[pos]) pos += 1 if cmd_code == _GET_NEXT_BLOCK_CMD: block_id = self._writer._gen_next_block_id() data = struct.pack("<B", _GET_NEXT_BLOCK_CMD) + struct.pack( "<I", block_id ) self._sock.sendto(data, addr) elif cmd_code == _PUT_WRITTEN_BLOCKS_CMD: blocks_queue = self._writer._used_block_id_queue (count,) = struct.unpack("<H", data[pos : pos + 2]) pos += 2 assert 4 * count < len(data), "Data too short for block count!" block_ids = struct.unpack( "<%dI" % count, data[pos : pos + 4 * count] ) blocks_queue.put(block_ids) self._sock.sendto(struct.pack("<B", _PUT_WRITTEN_BLOCKS_CMD), addr) elif cmd_code == _STOP_SERVER_CMD: assert ( addr[0] == self._sock.getsockname()[0] ), "Cannot stop server from other hosts!" break else: # pragma: no cover raise AssertionError("Unrecognized command %x", cmd_code) except BaseException: pk_exc_info = cloudpickle.dumps(sys.exc_info()) data = ( struct.pack("<B", _SERVER_ERROR_CMD) + struct.pack("<I", len(pk_exc_info)) + pk_exc_info ) self._sock.sendto(data, addr) logger.exception("Serve thread error.") self._sock.close() self._sock = None def start(self): self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self._sock.bind(("127.0.0.1", 0)) self._serve_thread_obj = threading.Thread(target=self._serve_thread) self._serve_thread_obj.daemon = True self._serve_thread_obj.start() def stop(self): stop_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) stop_data = self._authkey + struct.pack("<B", _STOP_SERVER_CMD) stop_sock.sendto(stop_data, self._sock.getsockname()) stop_sock.close() self._serve_thread_obj.join() class MPBlockClient(object): _MAX_BLOCK_COUNT = 256 def __init__(self, address, authkey): self._addr = address self._authkey = authkey self._sock = None def __del__(self): self.close() def close(self): if self._sock is not None: self._sock.close() self._sock = None def _get_socket(self): if self._sock is None: self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) return self._sock @staticmethod def _reraise_remote_error(recv_data): if recv_data[0] != _SERVER_ERROR_CMD: return (pk_len,) = struct.unpack("<I", recv_data[1:5]) exc_info = cloudpickle.loads(recv_data[5 : 5 + pk_len]) six.reraise(*exc_info) def get_next_block_id(self): sock = self._get_socket() data = self._authkey + struct.pack("<B", _GET_NEXT_BLOCK_CMD) sock.sendto(data, self._addr) recv_data, server_addr = sock.recvfrom(1024) self._reraise_remote_error(recv_data) assert _ord_if_possible(recv_data[0]) == _GET_NEXT_BLOCK_CMD assert self._addr == server_addr (count,) = struct.unpack("<I", recv_data[1:5]) return count def put_written_blocks(self, block_ids): sock = self._get_socket() for pos in range(0, len(block_ids), self._MAX_BLOCK_COUNT): sub_block_ids = block_ids[pos : pos + self._MAX_BLOCK_COUNT] data = ( self._authkey + struct.pack("<B", _PUT_WRITTEN_BLOCKS_CMD) + struct.pack("<H", len(sub_block_ids)) + struct.pack("<%dI" % len(sub_block_ids), *sub_block_ids) ) sock.sendto(data, self._addr) recv_data, server_addr = sock.recvfrom(1024) self._reraise_remote_error(recv_data) assert _ord_if_possible(recv_data[0]) == _PUT_WRITTEN_BLOCKS_CMD assert self._addr == server_addr class AbstractTableWriter(object): def __init__( self, table, upload_session, blocks=None, commit=True, on_close=None, **kwargs ): self._table = table self._upload_session = upload_session self._commit = commit self._closed = False self._on_close = on_close self._use_buffered_writer = None if blocks is not None: self._use_buffered_writer = False # block writer options self._blocks = blocks or upload_session.blocks or [0] self._blocks_writes = [False] * len(self._blocks) self._blocks_writers = [None] * len(self._blocks) for block in upload_session.blocks or (): self._blocks_writes[self._blocks.index(block)] = True # buffered writer options self._thread_to_buffered_writers = dict() # objects for cross-process sharings self._mp_server = None self._main_pid = kwargs.get("main_pid") or os.getpid() self._mp_fixed = kwargs.get("mp_fixed") if kwargs.get("mp_client"): self._mp_client = kwargs["mp_client"] self._mp_context = self._block_id_counter = self._used_block_id_queue = None else: self._mp_client = self._mp_authkey = None self._mp_context = kwargs.get("mp_context") or multiprocessing self._block_id_counter = kwargs.get( "block_id_counter" ) or self._mp_context.Value("i", 1 + max(upload_session.blocks or [0])) self._used_block_id_queue = ( kwargs.get("used_block_id_queue") or self._mp_context.Queue() ) @classmethod def _restore_subprocess_writer( cls, mp_server_address, mp_server_auth, upload_id, main_pid=None, blocks=None, rest_client=None, project=None, table_name=None, partition_spec=None, tunnel_endpoint=None, quota_name=None, schema=None, ): from ..core import ODPS from ..tunnel import TableTunnel odps_entry = ODPS( account=rest_client.account, app_account=rest_client.app_account, endpoint=rest_client.endpoint, overwrite_global=False, ) tunnel = TableTunnel( client=rest_client, project=project, endpoint=tunnel_endpoint, quota_name=quota_name, ) table = odps_entry.get_table(table_name, schema=schema, project=project) session = utils.call_with_retry( tunnel.create_upload_session, table_name, schema=schema, upload_id=upload_id, partition_spec=partition_spec, ) mp_client = MPBlockClient(mp_server_address, mp_server_auth) writer = cls( table, session, commit=False, blocks=blocks, main_pid=main_pid, mp_fixed=True, mp_client=mp_client, ) return writer def _start_mp_server(self): if self._mp_server is not None: return self._mp_server = MPBlockServer(self) self._mp_server.start() # replace mp queue with ordinary queue self._used_block_id_queue = six.moves.queue.Queue() def __reduce__(self): rest_client = self._table._client table_name = self._table.name schema_name = self._table.get_schema() project = self._table.project.name tunnel_endpoint = self._upload_session._client.endpoint quota_name = self._upload_session._quota_name partition_spec = self._upload_session._partition_spec blocks = None if self._use_buffered_writer is not False else self._blocks self._start_mp_server() return _wrap_classmethod(self._restore_subprocess_writer), ( self._mp_server.address, bytes(self._mp_server.authkey), self.upload_id, self._main_pid, blocks, rest_client, project, table_name, partition_spec, tunnel_endpoint, quota_name, schema_name, ) @property def upload_id(self): return self._upload_session.id @property def schema(self): return self._table.table_schema @property def status(self): return self._upload_session.status def _open_writer(self, block_id, compress): raise NotImplementedError def _write_contents(self, writer, *args): raise NotImplementedError def _gen_next_block_id(self): if self._mp_client is not None: return self._mp_client.get_next_block_id() with self._block_id_counter.get_lock(): block_id = self._block_id_counter.value self._block_id_counter.value += 1 return block_id def _fix_mp_attributes(self): if not self._mp_fixed: self._mp_fixed = True if os.getpid() == self._main_pid: return self._commit = False self._on_close = None def write(self, *args, **kwargs): if self._closed: raise IOError("Cannot write to a closed writer.") self._fix_mp_attributes() compress = kwargs.get("compress", False) block_id = kwargs.get("block_id") if block_id is None: if type(args[0]) in six.integer_types: block_id = args[0] args = args[1:] else: block_id = 0 if options.tunnel.use_block_writer_by_default else None use_buffered_writer = block_id is None if self._use_buffered_writer is None: self._use_buffered_writer = use_buffered_writer elif self._use_buffered_writer is not use_buffered_writer: raise ValueError( "Cannot mix block writing mode with non-block writing mode within a single writer" ) if use_buffered_writer: idx = None writer = self._thread_to_buffered_writers.get( threading.current_thread().ident ) else: idx = self._blocks.index(block_id) writer = self._blocks_writers[idx] if writer is None: writer = self._open_writer(block_id, compress) self._write_contents(writer, *args) if not use_buffered_writer: self._blocks_writes[idx] = True def close(self): if self._closed: return written_blocks = [] if self._use_buffered_writer: for writer in self._thread_to_buffered_writers.values(): writer.close() written_blocks.extend(writer.get_blocks_written()) else: for writer in self._blocks_writers: if writer is not None: writer.close() written_blocks = [ block for block, block_write in zip(self._blocks, self._blocks_writes) if block_write ] if written_blocks: if self._mp_client is not None: self._mp_client.put_written_blocks(written_blocks) else: self._used_block_id_queue.put(written_blocks) if self._commit: collected_blocks = [] # as queue.empty() not reliable, we need to fill local blocks manually collected_blocks.extend(written_blocks) while not self._used_block_id_queue.empty(): collected_blocks.extend(self._used_block_id_queue.get()) collected_blocks.extend(self._upload_session.blocks or []) collected_blocks = sorted(set(collected_blocks)) self._upload_session.commit(collected_blocks) if callable(self._on_close): self._on_close() if self._mp_client is not None: self._mp_client.close() self._mp_client = None if self._mp_server is not None: self._mp_server.stop() self._mp_server = None self._closed = True def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): # if an error occurs inside the with block, we do not commit if exc_val is not None: return self.close() class ToRecordsMixin(object): def _new_record(self, arg): raise NotImplementedError def _to_records(self, *args): def convert_records(arg, sample_rec): if odps_types.is_record(sample_rec): return arg elif isinstance(sample_rec, (list, tuple)): return (self._new_record(vals) for vals in arg) else: return [self._new_record(arg)] if len(args) == 0: return if len(args) > 1: args = [args] arg = args[0] if odps_types.is_record(arg): return [arg] elif isinstance(arg, (list, tuple)): return convert_records(arg, arg[0]) elif isinstance(arg, GeneratorType): try: # peek the first element and then put back next_arg = six.next(arg) chained = itertools.chain((next_arg,), arg) return convert_records(chained, next_arg) except StopIteration: return () else: raise ValueError("Unsupported record type.") class TableRecordWriter(ToRecordsMixin, AbstractTableWriter): def _open_writer(self, block_id, compress): if self._use_buffered_writer: writer = self._upload_session.open_record_writer( compress=compress, initial_block_id=self._gen_next_block_id(), block_id_gen=self._gen_next_block_id, ) thread_ident = threading.current_thread().ident self._thread_to_buffered_writers[thread_ident] = writer else: writer = self._upload_session.open_record_writer( block_id, compress=compress ) self._blocks_writers[block_id] = writer return writer def _new_record(self, arg): return self._upload_session.new_record(arg) def _write_contents(self, writer, *args): for record in self._to_records(*args): writer.write(record) class TableArrowWriter(AbstractTableWriter): def _open_writer(self, block_id, compress): if self._use_buffered_writer: writer = self._upload_session.open_arrow_writer( compress=compress, initial_block_id=self._gen_next_block_id(), block_id_gen=self._gen_next_block_id, ) thread_ident = threading.current_thread().ident self._thread_to_buffered_writers[thread_ident] = writer else: writer = self._upload_session.open_arrow_writer(block_id, compress=compress) self._blocks_writers[block_id] = writer return writer def _write_contents(self, writer, *args): for arg in args: writer.write(arg) class TableUpsertWriter(ToRecordsMixin): def __init__(self, table, upsert_session, commit=True, on_close=None, **_): self._table = table self._upsert_session = upsert_session self._closed = False self._commit = commit self._on_close = on_close self._upsert = None def _open_upsert(self, compress): self._upsert = self._upsert_session.open_upsert_stream(compress=compress) def _new_record(self, arg): return self._upsert_session.new_record(arg) def _write(self, *args, **kw): compress = kw.pop("compress", None) delete = kw.pop("delete", False) if not self._upsert: self._open_upsert(compress) for record in self._to_records(*args): if delete: self._upsert.delete(record) else: self._upsert.upsert(record) def write(self, *args, **kw): compress = kw.pop("compress", None) self._write(*args, compress=compress, delete=False) def delete(self, *args, **kw): compress = kw.pop("compress", None) self._write(*args, compress=compress, delete=True) def close(self, success=True, commit=True): if not success: self._upsert_session.abort() else: self._upsert.close() if commit and self._commit: self._upsert_session.commit() if callable(self._on_close): self._on_close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): # if an error occurs inside the with block, we do not commit if exc_val is not None: self.close(success=False, commit=False) return self.close() class TableIOMethods(object): @classmethod def _get_table_obj(cls, odps, name, project=None, schema=None): if isinstance(name, six.string_types) and "." in name: project, schema, name = odps._split_object_dots(name) if not isinstance(name, six.string_types): if name.get_schema(): schema = name.get_schema().name project, name = name.project.name, name.name parent = odps._get_project_or_schema(project, schema) return parent.tables[name] @classmethod def read_table( cls, odps, name, limit=None, start=0, step=None, project=None, schema=None, partition=None, **kw ): """ Read table's records. :param name: table or table name :type name: :class:`odps.models.table.Table` or str :param limit: the records' size, if None will read all records from the table :param start: the record where read starts with :param step: default as 1 :param project: project name, if not provided, will be the default project :param str schema: schema name, if not provided, will be the default schema :param partition: the partition of this table to read :param list columns: the columns' names which are the parts of table's columns :param bool compress: if True, the data will be compressed during downloading :param compress_option: the compression algorithm, level and strategy :type compress_option: :class:`odps.tunnel.CompressOption` :param endpoint: tunnel service URL :param reopen: reading the table will reuse the session which opened last time, if set to True will open a new download session, default as False :return: records :rtype: generator :Example: >>> for record in odps.read_table('test_table', 100): >>> # deal with such 100 records >>> for record in odps.read_table('test_table', partition='pt=test', start=100, limit=100): >>> # read the `pt=test` partition, skip 100 records and read 100 records .. seealso:: :class:`odps.models.Record` """ table = cls._get_table_obj(odps, name, project=project, schema=schema) compress = kw.pop("compress", False) columns = kw.pop("columns", None) with table.open_reader(partition=partition, **kw) as reader: for record in reader.read( start, limit, step=step, compress=compress, columns=columns ): yield record @classmethod def _is_pa_collection(cls, obj): return pa is not None and isinstance(obj, (pa.Table, pa.RecordBatch)) @classmethod def _is_pd_df(cls, obj): return pd is not None and isinstance(obj, pd.DataFrame) @classmethod def _resolve_schema( cls, records_list=None, data_schema=None, unknown_as_string=False, partition=None, partition_cols=None, type_mapping=None, ): from ..df.backends.odpssql.types import df_schema_to_odps_schema from ..df.backends.pd.types import pd_to_df_schema from ..tunnel.io.types import arrow_schema_to_odps_schema type_mapping = type_mapping or {} type_mapping = { k: odps_types.validate_data_type(v) for k, v in type_mapping.items() } if records_list is not None: if cls._is_pa_collection(records_list): data_schema = arrow_schema_to_odps_schema(records_list.schema) elif cls._is_pd_df(records_list): data_schema = df_schema_to_odps_schema( pd_to_df_schema( records_list, unknown_as_string=unknown_as_string, type_mapping=type_mapping, ) ) elif isinstance(records_list, list) and odps_types.is_record( records_list[0] ): data_schema = odps_types.OdpsSchema(records_list[0]._columns) else: raise TypeError( "Inferring schema from provided data not implemented. " "You need to supply a pandas DataFrame or records." ) assert data_schema is not None part_col_names = partition_cols or [] if partition is not None: part_spec = odps_types.PartitionSpec(partition) part_col_names.extend(k for k in part_spec.keys()) if part_col_names: part_col_set = set(part_col_names) simple_cols = [c for c in data_schema.columns if c.name not in part_col_set] part_cols = [ odps_types.Column(n, odps_types.string) for n in part_col_names ] data_schema = odps_types.OdpsSchema(simple_cols, part_cols) if not type_mapping: return data_schema simple_cols, part_cols = [], [] unmapped_cols = set(type_mapping.keys()) for col in data_schema.columns: if col.name not in type_mapping: simple_cols.append(col) else: unmapped_cols.remove(col.name) simple_cols.append(col.replace(type=type_mapping[col.name])) for col in getattr(data_schema, "partitions", None) or (): if col.name not in type_mapping: part_cols.append(col) else: unmapped_cols.remove(col.name) part_cols.append(col.replace(type=type_mapping[col.name])) for col_name in unmapped_cols: simple_cols.append( odps_types.Column(name=col_name, type=type_mapping[col_name]) ) return odps_types.OdpsSchema(simple_cols, part_cols or None) @classmethod def _calc_schema_diff(cls, src_schema, dest_schema, partition_cols=None): if not src_schema or not dest_schema: return [], [] union_cols, diff_cols = [], [] part_col_set = set(partition_cols or []) # collect union columns in the order of dest schema for col in dest_schema.simple_columns: if col.name in src_schema: union_cols.append(col) # collect columns not in dest schema for col in src_schema.simple_columns: if col.name not in dest_schema and col.name not in part_col_set: diff_cols.append(col) return union_cols, diff_cols @classmethod def _check_partition_specified( cls, table_name, table_schema, partition_cols=None, partition=None ): partition_cols = partition_cols or [] no_eval_set = set(n.lower() for n in partition_cols) if partition: no_eval_set.update( c.lower() for c in odps_types.PartitionSpec(partition).keys() ) expr_cols = [ c.name.lower() for c in table_schema.partitions if c.generate_expression and c.name.lower() not in no_eval_set ] partition_cols += expr_cols if not partition_cols and not partition: if table_schema.partitions: raise ValueError( "Partition spec is required for table %s with partitions, " "please specify a partition with `partition` argument or " "specify a list of columns with `partition_cols` argument " "to enable dynamic partitioning." % table_name ) return partition_cols else: if not table_schema.partitions: raise ValueError( "Cannot store into a non-partitioned table %s when `partition` " "or `partition_cols` is specified." % table_name ) all_parts = ( [n.lower() for n in odps_types.PartitionSpec(partition).keys()] if partition else [] ) if partition_cols: all_parts.extend(partition_cols) req_all_parts_set = set(all_parts) table_all_parts = [c.name.lower() for c in table_schema.partitions] no_exist_parts = req_all_parts_set - set(table_all_parts) if no_exist_parts: raise ValueError( "Partitions %s are not in table %s whose partitions are (%s)." % (sorted(no_exist_parts), table_name, table_all_parts) ) no_specified_parts = set(table_all_parts) - req_all_parts_set if no_specified_parts: raise ValueError( "Partitions %s in table %s are not specified in `partition_cols` or " "`partition` argument." % (sorted(no_specified_parts), table_name) ) return partition_cols @classmethod def _get_ordered_col_expressions(cls, table, partition): """ Get column expressions in topological order by variable dependencies """ part_spec = odps_types.PartitionSpec(partition) col_to_expr = { c.name.lower(): table._get_column_generate_expression(c.name) for c in table.table_schema.columns if c.name not in part_spec } col_to_expr = {c: expr for c, expr in col_to_expr.items() if expr} if not col_to_expr: # no columns with expressions, quit return {} col_dag = DAG() for col in col_to_expr: col_dag.add_node(col) for ref in col_to_expr[col].references: ref_col_name = ref.lower() col_dag.add_node(ref_col_name) col_dag.add_edge(ref_col_name, col) out_col_to_expr = OrderedDict() for col in col_dag.topological_sort(): if col not in col_to_expr: continue out_col_to_expr[col] = col_to_expr[col] return out_col_to_expr @classmethod def _fill_missing_expressions(cls, data, col_to_expr): def handle_recordbatch(batch): col_names = list(batch.schema.names) col_arrays = list(batch.columns) for col in missing_cols: col_names.append(col) col_arrays.append(col_to_expr[col].eval(batch)) return pa.RecordBatch.from_arrays(col_arrays, col_names) if pa and isinstance(data, (pa.Table, pa.RecordBatch)): col_name_set = set(c.lower() for c in data.schema.names) missing_cols = [c for c in col_to_expr if c not in col_name_set] if not missing_cols: return data if isinstance(data, pa.Table): batches = [handle_recordbatch(b) for b in data.to_batches()] return pa.Table.from_batches(batches) else: return handle_recordbatch(data) elif pd and isinstance(data, pd.DataFrame): col_name_set = set(c.lower() for c in data.columns) missing_cols = [c for c in col_to_expr if c not in col_name_set] if not missing_cols: return data data = data.copy() for col in missing_cols: data[col] = col_to_expr[col].eval(data) return data else: wrapped = False if odps_types.is_record(data): data = [data] wrapped = True for rec in data: if not odps_types.is_record(rec): continue for c in col_to_expr: if rec[c] is not None: continue rec[c] = col_to_expr[c].eval(rec) return data[0] if wrapped else data @classmethod def _split_block_data_in_partitions( cls, table, block_data, partition_cols=None, partition=None ): from . import Record table_schema = table.table_schema col_to_expr = cls._get_ordered_col_expressions(table, partition) def _fill_cols(data): if col_to_expr: data = cls._fill_missing_expressions(data, col_to_expr) if not pd or not isinstance(data, pd.DataFrame): return data data.columns = [col.lower() for col in data.columns] tb_col_names = [c.name.lower() for c in table_schema.simple_columns] tb_col_set = set(tb_col_names) extra_cols = [col for col in data.columns if col not in tb_col_set] return data.reindex(tb_col_names + extra_cols, axis=1) if not partition_cols: is_arrow = cls._is_pa_collection(block_data) or cls._is_pd_df(block_data) return {(is_arrow, None): [_fill_cols(block_data)]} input_cols = list(table_schema.simple_columns) + [ odps_types.Column(part, odps_types.string) for part in partition_cols ] input_schema = odps_types.OdpsSchema(input_cols) non_generate_idxes = [ idx for idx, c in enumerate(table_schema.simple_columns) if not table._get_column_generate_expression(c.name) ] num_generate_pts = len( [ c for c in (table_schema.partitions or []) if table._get_column_generate_expression(c.name) ] ) parted_data = defaultdict(list) if ( cls._is_pa_collection(block_data) or cls._is_pd_df(block_data) or odps_types.is_record(block_data) or ( isinstance(block_data, list) and block_data and not isinstance(block_data[0], list) ) ): # pd dataframes, arrow RecordBatch, single record or single record-like array block_data = [block_data] for data in block_data: if cls._is_pa_collection(data): data = data.to_pandas() elif isinstance(data, list): if len(data) != len(input_schema): # fill None columns for generate functions to fill them # once size of data matches size of non-generative columns if len(data) < len(table_schema.simple_columns) and len( data ) == len(non_generate_idxes): new_data = [None] * len(table_schema.simple_columns) for idx, d in zip(non_generate_idxes, data): new_data[idx] = d data = new_data # fill None partitions for generate functions to fill them data += [None] * ( num_generate_pts - (len(data) - len(table_schema.simple_columns)) ) if len(data) != len(input_schema): raise ValueError( "Need to specify %d values when writing table " "with dynamic partition." % len(input_schema) ) data = Record(schema=input_schema, values=data) if cls._is_pd_df(data): data = _fill_cols(data) part_set = set(partition_cols) for name, group in data.groupby(partition_cols): name = name if isinstance(name, tuple) else (name,) pt_name = ",".join( "=".join([str(n), str(v)]) for n, v in zip(partition_cols, name) ) parted_data[(True, pt_name)].append( group.drop(part_set, axis=1, errors="ignore") ) elif odps_types.is_record(data): data = _fill_cols(data) pt_name = ",".join( "=".join([str(n), data[str(n)]]) for n in partition_cols ) values = [data[str(c.name)] for c in table_schema.simple_columns] if not parted_data[(False, pt_name)]: parted_data[(False, pt_name)].append([]) parted_data[(False, pt_name)][0].append( Record(schema=table_schema, values=values) ) else: raise ValueError( "Cannot accept data with type %s" % type(data).__name__ ) return parted_data @classmethod def write_table(cls, odps, name, *block_data, **kw): """ Write records or pandas DataFrame into given table. :param name: table or table name :type name: :class:`.models.table.Table` or str :param block_data: records / DataFrame, or block ids and records / DataFrame. If given records or DataFrame only, the block id will be 0 as default. :param str project: project name, if not provided, will be the default project :param str schema: schema name, if not provided, will be the default schema :param partition: the partition of this table to write into :param list partition_cols: columns representing dynamic partitions :param bool append_missing_cols: Whether to append missing columns to the target table. False by default. :param bool overwrite: if True, will overwrite existing data :param bool create_table: if true, the table will be created if not exist :param dict table_kwargs: specify other kwargs for :meth:`~odps.ODPS.create_table` :param dict type_mapping: specify type mapping for columns when creating tables, can be dicts like ``{"column": "bigint"}``. If column does not exist in data, it will be added as an empty column. :param table_schema_callback: a function to accept table schema resolved from data and return a new schema for table to create. Only works when target table does not exist and ``create_table`` is True. :param int lifecycle: specify table lifecycle when creating tables :param bool create_partition: if true, the partition will be created if not exist :param compress_option: the compression algorithm, level and strategy :type compress_option: :class:`odps.tunnel.CompressOption` :param str endpoint: tunnel service URL :param bool reopen: writing the table will reuse the session which opened last time, if set to True will open a new upload session, default as False :return: None :Example: Write records into a specified table. >>> odps.write_table('test_table', data) Write records into multiple blocks. >>> odps.write_table('test_table', 0, records1, 1, records2) Write into a given partition. >>> odps.write_table('test_table', data, partition='pt=test') Write a pandas DataFrame. Create the table if it does not exist. >>> import pandas as pd >>> df = pd.DataFrame([ >>> [111, 'aaa', True], >>> [222, 'bbb', False], >>> [333, 'ccc', True], >>> [444, '中文', False] >>> ], columns=['num_col', 'str_col', 'bool_col']) >>> o.write_table('test_table', df, partition='pt=test', create_table=True, create_partition=True) Passing more arguments when creating table. >>> import pandas as pd >>> df = pd.DataFrame([ >>> [111, 'aaa', True], >>> [222, 'bbb', False], >>> [333, 'ccc', True], >>> [444, '中文', False] >>> ], columns=['num_col', 'str_col', 'bool_col']) >>> # this dict will be passed to `create_table` as kwargs. >>> table_kwargs = {"transactional": True, "primary_key": "num_col"} >>> o.write_table('test_table', df, partition='pt=test', create_table=True, create_partition=True, >>> table_kwargs=table_kwargs) Write with dynamic partitioning. >>> import pandas as pd >>> df = pd.DataFrame([ >>> [111, 'aaa', True, 'p1'], >>> [222, 'bbb', False, 'p1'], >>> [333, 'ccc', True, 'p2'], >>> [444, '中文', False, 'p2'] >>> ], columns=['num_col', 'str_col', 'bool_col', 'pt']) >>> o.write_table('test_part_table', df, partition_cols=['pt'], create_partition=True) :Note: ``write_table`` treats object type of Pandas data as strings as it is often hard to determine their types when creating a new table for your data. To make sure the column type meet your need, you can specify `type_mapping` argument to specify the column types, for instance, ``type_mapping={"col1": "array<struct<id:string>>"}``. .. seealso:: :class:`odps.models.Record` """ project = kw.pop("project", None) schema = kw.pop("schema", None) append_missing_cols = kw.pop("append_missing_cols", False) overwrite = kw.pop("overwrite", False) single_block_types = (Iterable,) if pa is not None: single_block_types += (pa.RecordBatch, pa.Table) if len(block_data) == 1 and isinstance(block_data[0], single_block_types): blocks = [None] data_list = list(block_data) else: blocks = list(block_data[::2]) data_list = list(block_data[1::2]) if len(blocks) != len(data_list): raise ValueError( "Should invoke like odps.write_table(block_id, records, " "block_id2, records2, ..., **kw)" ) unknown_as_string = kw.pop("unknown_as_string", False) create_table = kw.pop("create_table", False) create_partition = kw.pop( "create_partition", kw.pop("create_partitions", False) ) partition = kw.pop("partition", None) partition_cols = kw.pop("partition_cols", None) or kw.pop("partitions", None) lifecycle = kw.pop("lifecycle", None) type_mapping = kw.pop("type_mapping", None) table_schema_callback = kw.pop("table_schema_callback", None) table_kwargs = dict(kw.pop("table_kwargs", None) or {}) if lifecycle: table_kwargs["lifecycle"] = lifecycle if isinstance(partition_cols, six.string_types): partition_cols = [partition_cols] try: data_sample = data_list[0] if isinstance(data_sample, GeneratorType): data_gen = data_sample data_sample = [next(data_gen)] data_list[0] = utils.chain_generator([data_sample[0]], data_gen) table_schema = cls._resolve_schema( data_sample, unknown_as_string=unknown_as_string, partition=partition, partition_cols=partition_cols, type_mapping=type_mapping, ) except TypeError: table_schema = None if not odps.exist_table(name, project=project, schema=schema): if not create_table: raise errors.NoSuchTable( "Target table %s not exist. To create a new table " "you can add an argument `create_table=True`." % name ) if callable(table_schema_callback): table_schema = table_schema_callback(table_schema) target_table = odps.create_table( name, table_schema, project=project, schema=schema, **table_kwargs ) else: target_table = cls._get_table_obj( odps, name, project=project, schema=schema ) union_cols, diff_cols = cls._calc_schema_diff( table_schema, target_table.schema, partition_cols=partition_cols ) if table_schema and not union_cols: warnings.warn( "No columns overlapped between source and target table. If result " "is not as expected, please check if your query provides correct " "column names." ) if diff_cols: if append_missing_cols: target_table.add_columns(diff_cols) else: warnings.warn( "Columns in source data %s are missing in target table %s. " "Specify append_missing_cols=True to append missing columns " "to the target table." % (", ".join(c.name for c in diff_cols), target_table.name) ) partition_cols = cls._check_partition_specified( name, target_table.table_schema, partition_cols=partition_cols, partition=partition, ) data_lists = defaultdict(lambda: defaultdict(list)) for block, data in zip(blocks, data_list): for key, parted_data in cls._split_block_data_in_partitions( target_table, data, partition_cols=partition_cols, partition=partition, ).items(): data_lists[key][block].extend(parted_data) if partition is None or isinstance(partition, six.string_types): partition_str = partition else: partition_str = str(odps_types.PartitionSpec(partition)) # fixme cover up for overwrite failure on table.format.version=2: # only applicable for transactional table with partitions # with generate expressions manual_truncate = ( overwrite and target_table.is_transactional and any( pt_col.generate_expression for pt_col in target_table.table_schema.partitions ) ) for (is_arrow, pt_name), block_to_data in data_lists.items(): if not block_to_data: continue blocks, data_list = [], [] for block, data in block_to_data.items(): blocks.append(block) data_list.extend(data) if len(blocks) == 1 and blocks[0] is None: blocks = None final_pt = ",".join(p for p in (pt_name, partition_str) if p is not None) # fixme cover up for overwrite failure on table.format.version=2 if overwrite and manual_truncate: if not final_pt or target_table.exist_partition(final_pt): target_table.truncate(partition_spec=final_pt or None) with target_table.open_writer( partition=final_pt or None, blocks=blocks, arrow=is_arrow, create_partition=create_partition, reopen=append_missing_cols, overwrite=overwrite, **kw ) as writer: if blocks is None: for data in data_list: writer.write(data) else: for block, data in zip(blocks, data_list): writer.write(block, data) @classmethod def write_sql_result_to_table( cls, odps, table_name, sql, partition=None, partition_cols=None, create_table=False, create_partition=False, append_missing_cols=False, overwrite=False, project=None, schema=None, lifecycle=None, type_mapping=None, table_schema_callback=None, table_kwargs=None, hints=None, running_cluster=None, unique_identifier_id=None, **kwargs ): """ Write SQL query results into a specified table and partition. If the target table does not exist, you may specify the argument create_table=True. Columns are inserted into the target table aligned by column names. Note that column order in the target table will NOT be changed. :param str table_name: The target table name :param str sql: The SQL query to execute :param str partition: Target partition in the format "part=value" or "part1=value1,part2=value2" :param list partition_cols: List of dynamic partition fields. If not provided, all partition fields of the target table are used. :param bool create_table: Whether to create the target table if it does not exist. False by default. :param bool create_partition: Whether to create partitions if they do not exist. False by default. :param bool append_missing_cols: Whether to append missing columns to the target table. False by default. :param bool overwrite: Whether to overwrite existing data. False by default. :param str project: project name, if not provided, will be the default project :param str schema: schema name, if not provided, will be the default schema :param int lifecycle: specify table lifecycle when creating tables :param dict type_mapping: specify type mapping for columns when creating tables, can be dicts like ``{"column": "bigint"}``. If column does not exist in data, it will be added as an empty column. :param table_schema_callback: a function to accept table schema resolved from data and return a new schema for table to create. Only works when target table does not exist and ``create_table`` is True. :param dict table_kwargs: specify other kwargs for :meth:`~odps.ODPS.create_table` :param dict hints: specify hints for SQL statements, will be passed through to execute_sql method :param dict running_cluster: specify running cluster for SQL statements, will be passed through to execute_sql method """ partition_cols = partition_cols or kwargs.pop("partitions", None) if isinstance(partition_cols, six.string_types): partition_cols = [partition_cols] temp_table_name = "_".join( [utils.TEMP_TABLE_PREFIX, utils.md5_hexdigest(table_name), uuid.uuid4().hex] ) insert_mode = "OVERWRITE" if overwrite else "INTO" # move table params in kwargs into table_kwargs table_kwargs = dict(table_kwargs or {}) for extra_table_arg in ( "table_properties", "shard_num", "transactional", "primary_key", "storage_tier", ): if extra_table_arg in kwargs: table_kwargs[extra_table_arg] = kwargs.pop(extra_table_arg) # if extra table kwargs are supported, create table ... as ... # may not work, and the table need to be created first with_extra_table_kw = bool(table_kwargs) table_kwargs.update( {"schema": schema, "project": project, "lifecycle": lifecycle} ) sql_kwargs = kwargs.copy() sql_kwargs.update( { "hints": copy.deepcopy(hints or {}), "running_cluster": running_cluster, "unique_identifier_id": unique_identifier_id, "default_schema": schema, "project": project, } ) def _format_raw_sql(fmt, args): """Add DDLs for existing SQL, multiple statements acceptable""" args = list(args) sql_parts = utils.split_sql_by_semicolon(args[-1]) if len(sql_parts) == 1: return fmt % tuple(args) # need script mode for multiple statements sql_kwargs["hints"]["odps.sql.submit.mode"] = "script" sql_parts[-1] = fmt % tuple(args[:-1] + [sql_parts[-1]]) return "\n".join(sql_parts) # Check if the target table exists if not odps.exist_table(table_name, project=project, schema=schema): if not create_table: raise ValueError( "Table %s does not exist and create_table is set to False." % table_name ) elif ( not partition and not partition_cols and not with_extra_table_kw and table_schema_callback is None ): # return directly when creating table without partitions # and special kwargs if not lifecycle: lifecycle_clause = "" else: lifecycle_clause = "LIFECYCLE %d " % lifecycle sql_stmt = _format_raw_sql( "CREATE TABLE `%s` %sAS %s", (table_name, lifecycle_clause, sql) ) odps.execute_sql(sql_stmt, **sql_kwargs) return else: # create temp table, get result schema and create target table sql_stmt = _format_raw_sql( "CREATE TABLE `%s` LIFECYCLE %d AS %s", (temp_table_name, options.temp_lifecycle, sql), ) odps.execute_sql(sql_stmt, **sql_kwargs) tmp_schema = odps.get_table(temp_table_name).table_schema out_table_schema = cls._resolve_schema( data_schema=tmp_schema, partition=partition, partition_cols=partition_cols, type_mapping=type_mapping, ) if table_schema_callback: out_table_schema = table_schema_callback(out_table_schema) target_table = odps.create_table( table_name, table_schema=out_table_schema, **table_kwargs ) else: target_table = cls._get_table_obj( odps, table_name, project=project, schema=schema ) # for partitioned target, create a temp table and store results sql_stmt = _format_raw_sql( "CREATE TABLE `%s` LIFECYCLE %d AS %s", (temp_table_name, options.temp_lifecycle, sql), ) odps.execute_sql(sql_stmt, **sql_kwargs) try: partition_cols = cls._check_partition_specified( table_name, target_table.table_schema, partition_cols=partition_cols, partition=partition, ) temp_table = odps.get_table(temp_table_name) union_cols, diff_cols = cls._calc_schema_diff( temp_table.table_schema, target_table.table_schema, partition_cols=partition_cols, ) if not union_cols: warnings.warn( "No columns overlapped between source and target table. If result " "is not as expected, please check if your query provides correct " "column names." ) if diff_cols: if append_missing_cols: target_table.add_columns(diff_cols, hints=hints) union_cols += diff_cols else: warnings.warn( "Columns in source query %s are missing in target table %s. " "Specify append_missing_cols=True to append missing columns " "to the target table." % (", ".join(c.name for c in diff_cols), target_table.name) ) target_columns = [col.name for col in union_cols] if partition: static_part_spec = odps_types.PartitionSpec(partition) else: static_part_spec = odps_types.PartitionSpec() if ( target_table.table_schema.partitions and len(static_part_spec) == len(target_table.table_schema.partitions) and not target_table.exist_partition(static_part_spec) ): if create_partition: target_table.create_partition(static_part_spec) else: raise ValueError( "Partition %s does not exist and create_partition is set to False." % static_part_spec ) all_parts, part_specs, dyn_parts = [], [], [] has_dyn_parts = False for col in target_table.table_schema.partitions: if col.name in static_part_spec: spec = "%s='%s'" % ( col.name, utils.escape_odps_string(static_part_spec[col.name]), ) part_specs.append(spec) all_parts.append(spec) elif col.name not in temp_table.table_schema: if col.generate_expression: all_parts.append(col.name) has_dyn_parts = True continue else: raise ValueError( "Partition column %s does not exist in source query." % col.name ) else: has_dyn_parts = True all_parts.append(col.name) dyn_parts.append(col.name) part_specs.append(col.name) if not part_specs: part_clause = "" else: part_clause = "PARTITION (%s) " % ", ".join(part_specs) if overwrite: insert_mode = "INTO" if not has_dyn_parts: target_table.truncate(partition or None) elif any(target_table.partitions): # generate column expressions in topological order col_to_exprs = cls._get_ordered_col_expressions( target_table, partition ) part_expr_map = { col: "`%s`" % col for col in partition_cols if col in temp_table.table_schema } generated_cols = set() for col_name, expr in col_to_exprs.items(): if col_name in part_expr_map: continue generated_cols.add(col_name) part_expr_map[col_name] = expr.to_str(part_expr_map) # add an alias for generated columns part_expr_map = { col: ( v if col not in generated_cols else "%s AS `%s`" % (v, col) ) for col, v in part_expr_map.items() } # query for partitions need to be truncated part_selections = [ part_expr_map[col_name] for col_name in partition_cols ] part_distinct_sql = "SELECT DISTINCT %s FROM `%s`" % ( ", ".join(part_selections), temp_table_name, ) distinct_inst = odps.execute_sql(part_distinct_sql) trunc_part_specs = [] with distinct_inst.open_reader(tunnel=True) as reader: for row in reader: local_part_specs = [ "%s='%s'" % (c, utils.escape_odps_string(row[c])) if c in row else c for c in all_parts ] local_part_str = ",".join(local_part_specs) if target_table.exist_partition(local_part_str): trunc_part_specs.append(local_part_str) target_table.truncate(trunc_part_specs) col_selection = ", ".join("`%s`" % s for s in (target_columns + dyn_parts)) sql_stmt = "INSERT %s `%s` %s (%s) SELECT %s FROM %s" % ( insert_mode, table_name, part_clause, col_selection, col_selection, temp_table_name, ) odps.execute_sql(sql_stmt, **sql_kwargs) finally: odps.delete_table( temp_table_name, project=project, schema=schema, if_exists=True, async_=True, )