# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import functools
import itertools
import logging
import multiprocessing
import os
import socket
import struct
import sys
import threading
import uuid
import warnings
from collections import OrderedDict, defaultdict
from types import GeneratorType, MethodType

try:
    import pyarrow as pa
except (AttributeError, ImportError):
    pa = None
try:
    import pandas as pd
except ImportError:
    pd = None

from .. import errors
from .. import types as odps_types
from .. import utils
from ..compat import Iterable, six
from ..config import options
from ..dag import DAG
from ..lib import cloudpickle
from ..lib.tblib import pickling_support
from .readers import TunnelArrowReader, TunnelRecordReader

logger = logging.getLogger(__name__)
pickling_support.install()

_GET_NEXT_BLOCK_CMD = 0x01
_PUT_WRITTEN_BLOCKS_CMD = 0x02
_SERVER_ERROR_CMD = 0xFD
_STOP_SERVER_CMD = 0xFE
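
# Wire format between MPBlockClient and MPBlockServer (UDP, little-endian):
#   request:  <authkey bytes> <1-byte command> [payload]
#     _GET_NEXT_BLOCK_CMD     -> no payload
#     _PUT_WRITTEN_BLOCKS_CMD -> uint16 block count, then that many uint32 block ids
#     _STOP_SERVER_CMD        -> no payload, only honored when sent from the same host
#   response: <1-byte command> [payload]; on failure the server replies with
#     _SERVER_ERROR_CMD, a uint32 length and the pickled sys.exc_info()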


if sys.version_info[0] == 2:
    _ord_if_possible = ord

    def _load_classmethod(cls, func_name):
        return getattr(cls, func_name)

    class InstanceMethodWrapper(object):
        """Trick for classmethods under Python 2.7 to be pickleable"""

        def __init__(self, func):
            assert isinstance(func, MethodType)
            assert isinstance(func.im_self, type)
            self._func = func

        def __call__(self, *args, **kw):
            return self._func(*args, **kw)

        def __reduce__(self):
            return _load_classmethod, (self._func.im_self, self._func.im_func.__name__)

    _wrap_classmethod = InstanceMethodWrapper
else:

    def _ord_if_possible(x):
        return x

    def _wrap_classmethod(x):
        return x


class SpawnedTableReaderMixin(object):
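    """Mixin that lets table readers fetch splits in spawned subprocesses
    by re-creating the download session from pickled parameters."""
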
    @property
    def schema(self):
        return self._parent.table_schema

    @staticmethod
    def _read_table_split(
        conn,
        download_id,
        start,
        count,
        idx,
        rest_client=None,
        project=None,
        table_name=None,
        partition_spec=None,
        tunnel_endpoint=None,
        quota_name=None,
        columns=None,
        arrow=False,
        schema_name=None,
        append_partitions=None,
    ):
        # read part data
        from ..tunnel import TableTunnel

        try:
            tunnel = TableTunnel(
                client=rest_client,
                project=project,
                endpoint=tunnel_endpoint,
                quota_name=quota_name,
            )
            session = utils.call_with_retry(
                tunnel.create_download_session,
                table_name,
                schema=schema_name,
                download_id=download_id,
                partition_spec=partition_spec,
            )

            def _data_to_pandas():
                if not arrow:
                    with session.open_record_reader(
                        start,
                        count,
                        columns=columns,
                        append_partitions=append_partitions,
                    ) as reader:
                        return reader.to_pandas()
                else:
                    with session.open_arrow_reader(
                        start,
                        count,
                        columns=columns,
                        append_partitions=append_partitions,
                    ) as reader:
                        return reader.to_pandas()

            data = utils.call_with_retry(_data_to_pandas)
            conn.send((idx, data, True))
        except BaseException:
            try:
                conn.send((idx, sys.exc_info(), False))
            except BaseException:
                logger.exception(
                    "Failed to send error info from reader process %d", idx
                )
                raise

    def _get_process_split_reader(self, columns=None, append_partitions=None):
        rest_client = self._parent._client
        table_name = self._parent.name
        schema_name = self._parent.get_schema()
        project = self._parent.project.name
        tunnel_endpoint = self._download_session._client.endpoint
        quota_name = self._download_session._quota_name
        partition_spec = self._partition_spec

        return functools.partial(
            self._read_table_split,
            rest_client=rest_client,
            project=project,
            table_name=table_name,
            partition_spec=partition_spec,
            tunnel_endpoint=tunnel_endpoint,
            quota_name=quota_name,
            arrow=isinstance(self, TunnelArrowReader),
            columns=columns or self._column_names,
            schema_name=schema_name,
            append_partitions=append_partitions,
        )


class TableRecordReader(SpawnedTableReaderMixin, TunnelRecordReader):
    def __init__(
        self,
        table,
        download_session,
        partition_spec=None,
        columns=None,
        append_partitions=True,
    ):
        super(TableRecordReader, self).__init__(
            table,
            download_session,
            columns=columns,
            append_partitions=append_partitions,
        )
        self._partition_spec = partition_spec


class TableArrowReader(SpawnedTableReaderMixin, TunnelArrowReader):
    def __init__(
        self,
        table,
        download_session,
        partition_spec=None,
        columns=None,
        append_partitions=False,
    ):
        super(TableArrowReader, self).__init__(
            table,
            download_session,
            columns=columns,
            append_partitions=append_partitions,
        )
        self._partition_spec = partition_spec


class MPBlockServer(object):
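    """UDP server running in the main writer process.

    Hands out fresh block ids to writers in subprocesses and collects the ids
    of blocks they have written so that the main process can commit them.
    """
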
    def __init__(self, writer):
        self._writer = writer
        self._sock = None
        self._serve_thread_obj = None
        self._authkey = multiprocessing.current_process().authkey

    @property
    def address(self):
        return self._sock.getsockname() if self._sock else None

    @property
    def authkey(self):
        return self._authkey

    def _serve_thread(self):
        while True:
            data, addr = self._sock.recvfrom(4096)
            try:
                pos = len(self._authkey)
                assert data[:pos] == self._authkey, "Authentication key mismatched!"

                cmd_code = _ord_if_possible(data[pos])
                pos += 1
                if cmd_code == _GET_NEXT_BLOCK_CMD:
                    block_id = self._writer._gen_next_block_id()
                    data = struct.pack("<B", _GET_NEXT_BLOCK_CMD) + struct.pack(
                        "<I", block_id
                    )
                    self._sock.sendto(data, addr)
                elif cmd_code == _PUT_WRITTEN_BLOCKS_CMD:
                    blocks_queue = self._writer._used_block_id_queue
                    (count,) = struct.unpack("<H", data[pos : pos + 2])
                    pos += 2
                    assert (
                        pos + 4 * count <= len(data)
                    ), "Data too short for block count!"
                    block_ids = struct.unpack(
                        "<%dI" % count, data[pos : pos + 4 * count]
                    )
                    blocks_queue.put(block_ids)
                    self._sock.sendto(struct.pack("<B", _PUT_WRITTEN_BLOCKS_CMD), addr)
                elif cmd_code == _STOP_SERVER_CMD:
                    assert (
                        addr[0] == self._sock.getsockname()[0]
                    ), "Cannot stop server from other hosts!"
                    break
                else:  # pragma: no cover
                    raise AssertionError("Unrecognized command %x", cmd_code)
            except BaseException:
                pk_exc_info = cloudpickle.dumps(sys.exc_info())
                data = (
                    struct.pack("<B", _SERVER_ERROR_CMD)
                    + struct.pack("<I", len(pk_exc_info))
                    + pk_exc_info
                )
                self._sock.sendto(data, addr)
                logger.exception("Serve thread error.")

        self._sock.close()
        self._sock = None

    def start(self):
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self._sock.bind(("127.0.0.1", 0))
        self._serve_thread_obj = threading.Thread(target=self._serve_thread)
        self._serve_thread_obj.daemon = True
        self._serve_thread_obj.start()

    def stop(self):
        stop_sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        stop_data = self._authkey + struct.pack("<B", _STOP_SERVER_CMD)
        stop_sock.sendto(stop_data, self._sock.getsockname())
        stop_sock.close()
        self._serve_thread_obj.join()


class MPBlockClient(object):
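    """UDP client used by writers in subprocesses to request block ids from,
    and report written blocks to, the MPBlockServer in the main process."""
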
    _MAX_BLOCK_COUNT = 256

    def __init__(self, address, authkey):
        self._addr = address
        self._authkey = authkey
        self._sock = None

    def __del__(self):
        self.close()

    def close(self):
        if self._sock is not None:
            self._sock.close()
            self._sock = None

    def _get_socket(self):
        if self._sock is None:
            self._sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        return self._sock

    @staticmethod
    def _reraise_remote_error(recv_data):
        if _ord_if_possible(recv_data[0]) != _SERVER_ERROR_CMD:
            return
        (pk_len,) = struct.unpack("<I", recv_data[1:5])
        exc_info = cloudpickle.loads(recv_data[5 : 5 + pk_len])
        six.reraise(*exc_info)

    def get_next_block_id(self):
        sock = self._get_socket()
        data = self._authkey + struct.pack("<B", _GET_NEXT_BLOCK_CMD)
        sock.sendto(data, self._addr)
        recv_data, server_addr = sock.recvfrom(1024)
        self._reraise_remote_error(recv_data)
        assert _ord_if_possible(recv_data[0]) == _GET_NEXT_BLOCK_CMD
        assert self._addr == server_addr
        (count,) = struct.unpack("<I", recv_data[1:5])
        return count

    def put_written_blocks(self, block_ids):
        sock = self._get_socket()
        for pos in range(0, len(block_ids), self._MAX_BLOCK_COUNT):
            sub_block_ids = block_ids[pos : pos + self._MAX_BLOCK_COUNT]
            data = (
                self._authkey
                + struct.pack("<B", _PUT_WRITTEN_BLOCKS_CMD)
                + struct.pack("<H", len(sub_block_ids))
                + struct.pack("<%dI" % len(sub_block_ids), *sub_block_ids)
            )
            sock.sendto(data, self._addr)
            recv_data, server_addr = sock.recvfrom(1024)
            self._reraise_remote_error(recv_data)
            assert _ord_if_possible(recv_data[0]) == _PUT_WRITTEN_BLOCKS_CMD
            assert self._addr == server_addr


class AbstractTableWriter(object):
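    """Base class for table writers.

    Supports two mutually exclusive modes: block writing, where every write call
    targets an explicit block id, and buffered writing, where each writing thread
    gets its own buffered writer. When the writer is pickled into a subprocess,
    an MPBlockServer is started in the parent process so that block ids stay
    coordinated across processes.
    """
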
    def __init__(
        self, table, upload_session, blocks=None, commit=True, on_close=None, **kwargs
    ):
        self._table = table
        self._upload_session = upload_session
        self._commit = commit
        self._closed = False
        self._on_close = on_close

        self._use_buffered_writer = None
        if blocks is not None:
            self._use_buffered_writer = False

        # block writer options
        self._blocks = blocks or upload_session.blocks or [0]
        self._blocks_writes = [False] * len(self._blocks)
        self._blocks_writers = [None] * len(self._blocks)

        for block in upload_session.blocks or ():
            self._blocks_writes[self._blocks.index(block)] = True

        # buffered writer options
        self._thread_to_buffered_writers = dict()

        # objects for cross-process sharings
        self._mp_server = None
        self._main_pid = kwargs.get("main_pid") or os.getpid()
        self._mp_fixed = kwargs.get("mp_fixed")
        if kwargs.get("mp_client"):
            self._mp_client = kwargs["mp_client"]
            self._mp_context = self._block_id_counter = self._used_block_id_queue = None
        else:
            self._mp_client = self._mp_authkey = None
            self._mp_context = kwargs.get("mp_context") or multiprocessing
            self._block_id_counter = kwargs.get(
                "block_id_counter"
            ) or self._mp_context.Value("i", 1 + max(upload_session.blocks or [0]))
            self._used_block_id_queue = (
                kwargs.get("used_block_id_queue") or self._mp_context.Queue()
            )

    @classmethod
    def _restore_subprocess_writer(
        cls,
        mp_server_address,
        mp_server_auth,
        upload_id,
        main_pid=None,
        blocks=None,
        rest_client=None,
        project=None,
        table_name=None,
        partition_spec=None,
        tunnel_endpoint=None,
        quota_name=None,
        schema=None,
    ):
        from ..core import ODPS
        from ..tunnel import TableTunnel

        odps_entry = ODPS(
            account=rest_client.account,
            app_account=rest_client.app_account,
            endpoint=rest_client.endpoint,
            overwrite_global=False,
        )
        tunnel = TableTunnel(
            client=rest_client,
            project=project,
            endpoint=tunnel_endpoint,
            quota_name=quota_name,
        )
        table = odps_entry.get_table(table_name, schema=schema, project=project)
        session = utils.call_with_retry(
            tunnel.create_upload_session,
            table_name,
            schema=schema,
            upload_id=upload_id,
            partition_spec=partition_spec,
        )
        mp_client = MPBlockClient(mp_server_address, mp_server_auth)
        writer = cls(
            table,
            session,
            commit=False,
            blocks=blocks,
            main_pid=main_pid,
            mp_fixed=True,
            mp_client=mp_client,
        )
        return writer

    def _start_mp_server(self):
        if self._mp_server is not None:
            return
        self._mp_server = MPBlockServer(self)
        self._mp_server.start()
        # replace mp queue with ordinary queue
        self._used_block_id_queue = six.moves.queue.Queue()

    def __reduce__(self):
        rest_client = self._table._client
        table_name = self._table.name
        schema_name = self._table.get_schema()
        project = self._table.project.name
        tunnel_endpoint = self._upload_session._client.endpoint
        quota_name = self._upload_session._quota_name
        partition_spec = self._upload_session._partition_spec
        blocks = None if self._use_buffered_writer is not False else self._blocks

        self._start_mp_server()

        return _wrap_classmethod(self._restore_subprocess_writer), (
            self._mp_server.address,
            bytes(self._mp_server.authkey),
            self.upload_id,
            self._main_pid,
            blocks,
            rest_client,
            project,
            table_name,
            partition_spec,
            tunnel_endpoint,
            quota_name,
            schema_name,
        )

    @property
    def upload_id(self):
        return self._upload_session.id

    @property
    def schema(self):
        return self._table.table_schema

    @property
    def status(self):
        return self._upload_session.status

    def _open_writer(self, block_id, compress):
        raise NotImplementedError

    def _write_contents(self, writer, *args):
        raise NotImplementedError

    def _gen_next_block_id(self):
        if self._mp_client is not None:
            return self._mp_client.get_next_block_id()
        with self._block_id_counter.get_lock():
            block_id = self._block_id_counter.value
            self._block_id_counter.value += 1
            return block_id

    def _fix_mp_attributes(self):
        if not self._mp_fixed:
            self._mp_fixed = True
            if os.getpid() == self._main_pid:
                return
            self._commit = False
            self._on_close = None

    def write(self, *args, **kwargs):
        if self._closed:
            raise IOError("Cannot write to a closed writer.")
        self._fix_mp_attributes()

        compress = kwargs.get("compress", False)

        block_id = kwargs.get("block_id")
        if block_id is None:
            if type(args[0]) in six.integer_types:
                block_id = args[0]
                args = args[1:]
            else:
                block_id = 0 if options.tunnel.use_block_writer_by_default else None

        use_buffered_writer = block_id is None
        if self._use_buffered_writer is None:
            self._use_buffered_writer = use_buffered_writer
        elif self._use_buffered_writer is not use_buffered_writer:
            raise ValueError(
                "Cannot mix block writing mode with non-block writing mode within a single writer"
            )

        if use_buffered_writer:
            idx = None
            writer = self._thread_to_buffered_writers.get(
                threading.current_thread().ident
            )
        else:
            idx = self._blocks.index(block_id)
            writer = self._blocks_writers[idx]

        if writer is None:
            writer = self._open_writer(block_id, compress)

        self._write_contents(writer, *args)
        if not use_buffered_writer:
            self._blocks_writes[idx] = True

    def close(self):
        if self._closed:
            return

        written_blocks = []
        if self._use_buffered_writer:
            for writer in self._thread_to_buffered_writers.values():
                writer.close()
                written_blocks.extend(writer.get_blocks_written())
        else:
            for writer in self._blocks_writers:
                if writer is not None:
                    writer.close()
            written_blocks = [
                block
                for block, block_write in zip(self._blocks, self._blocks_writes)
                if block_write
            ]

        if written_blocks:
            if self._mp_client is not None:
                self._mp_client.put_written_blocks(written_blocks)
            else:
                self._used_block_id_queue.put(written_blocks)

        if self._commit:
            collected_blocks = []
            # as queue.empty() is not reliable, add locally-written blocks explicitly
            collected_blocks.extend(written_blocks)
            while not self._used_block_id_queue.empty():
                collected_blocks.extend(self._used_block_id_queue.get())
            collected_blocks.extend(self._upload_session.blocks or [])
            collected_blocks = sorted(set(collected_blocks))
            self._upload_session.commit(collected_blocks)

        if callable(self._on_close):
            self._on_close()

        if self._mp_client is not None:
            self._mp_client.close()
            self._mp_client = None
        if self._mp_server is not None:
            self._mp_server.stop()
            self._mp_server = None
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            return
        self.close()


class ToRecordsMixin(object):
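    """Mixin that normalizes write() arguments (records, value lists / tuples,
    or generators of either) into an iterable of records."""
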
    def _new_record(self, arg):
        raise NotImplementedError

    def _to_records(self, *args):
        def convert_records(arg, sample_rec):
            if odps_types.is_record(sample_rec):
                return arg
            elif isinstance(sample_rec, (list, tuple)):
                return (self._new_record(vals) for vals in arg)
            else:
                return [self._new_record(arg)]

        if len(args) == 0:
            return
        if len(args) > 1:
            args = [args]

        arg = args[0]
        if odps_types.is_record(arg):
            return [arg]
        elif isinstance(arg, (list, tuple)):
            return convert_records(arg, arg[0])
        elif isinstance(arg, GeneratorType):
            try:
                # peek the first element and then put back
                next_arg = six.next(arg)
                chained = itertools.chain((next_arg,), arg)
                return convert_records(chained, next_arg)
            except StopIteration:
                return ()
        else:
            raise ValueError("Unsupported record type.")


class TableRecordWriter(ToRecordsMixin, AbstractTableWriter):
    def _open_writer(self, block_id, compress):
        if self._use_buffered_writer:
            writer = self._upload_session.open_record_writer(
                compress=compress,
                initial_block_id=self._gen_next_block_id(),
                block_id_gen=self._gen_next_block_id,
            )
            thread_ident = threading.current_thread().ident
            self._thread_to_buffered_writers[thread_ident] = writer
        else:
            writer = self._upload_session.open_record_writer(
                block_id, compress=compress
            )
            self._blocks_writers[self._blocks.index(block_id)] = writer
        return writer

    def _new_record(self, arg):
        return self._upload_session.new_record(arg)

    def _write_contents(self, writer, *args):
        for record in self._to_records(*args):
            writer.write(record)


class TableArrowWriter(AbstractTableWriter):
    def _open_writer(self, block_id, compress):
        if self._use_buffered_writer:
            writer = self._upload_session.open_arrow_writer(
                compress=compress,
                initial_block_id=self._gen_next_block_id(),
                block_id_gen=self._gen_next_block_id,
            )
            thread_ident = threading.current_thread().ident
            self._thread_to_buffered_writers[thread_ident] = writer
        else:
            writer = self._upload_session.open_arrow_writer(block_id, compress=compress)
            self._blocks_writers[self._blocks.index(block_id)] = writer
        return writer

    def _write_contents(self, writer, *args):
        for arg in args:
            writer.write(arg)


class TableUpsertWriter(ToRecordsMixin):
    def __init__(self, table, upsert_session, commit=True, on_close=None, **_):
        self._table = table
        self._upsert_session = upsert_session
        self._closed = False
        self._commit = commit
        self._on_close = on_close

        self._upsert = None

    def _open_upsert(self, compress):
        self._upsert = self._upsert_session.open_upsert_stream(compress=compress)

    def _new_record(self, arg):
        return self._upsert_session.new_record(arg)

    def _write(self, *args, **kw):
        compress = kw.pop("compress", None)
        delete = kw.pop("delete", False)
        if not self._upsert:
            self._open_upsert(compress)

        for record in self._to_records(*args):
            if delete:
                self._upsert.delete(record)
            else:
                self._upsert.upsert(record)

    def write(self, *args, **kw):
        compress = kw.pop("compress", None)
        self._write(*args, compress=compress, delete=False)

    def delete(self, *args, **kw):
        compress = kw.pop("compress", None)
        self._write(*args, compress=compress, delete=True)

    def close(self, success=True, commit=True):
        if self._closed:
            return
        if not success:
            self._upsert_session.abort()
        else:
            if self._upsert is not None:
                self._upsert.close()
            if commit and self._commit:
                self._upsert_session.commit()

        if callable(self._on_close):
            self._on_close()
        self._closed = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # if an error occurs inside the with block, we do not commit
        if exc_val is not None:
            self.close(success=False, commit=False)
            return
        self.close()


class TableIOMethods(object):
    @classmethod
    def _get_table_obj(cls, odps, name, project=None, schema=None):
        if isinstance(name, six.string_types) and "." in name:
            project, schema, name = odps._split_object_dots(name)

        if not isinstance(name, six.string_types):
            if name.get_schema():
                schema = name.get_schema().name
            project, name = name.project.name, name.name

        parent = odps._get_project_or_schema(project, schema)
        return parent.tables[name]

    @classmethod
    def read_table(
        cls,
        odps,
        name,
        limit=None,
        start=0,
        step=None,
        project=None,
        schema=None,
        partition=None,
        **kw
    ):
        """
        Read table's records.

        :param name: table or table name
        :type name: :class:`odps.models.table.Table` or str
        :param limit: the max number of records to read; if None, all records in the table are read
        :param start: the index of the record to start reading from
        :param step: the step between records read, 1 by default
        :param project: project name, if not provided, will be the default project
        :param str schema: schema name, if not provided, will be the default schema
        :param partition: the partition of this table to read
        :param list columns: names of the columns to read, a subset of the table's columns
        :param bool compress: if True, the data will be compressed during downloading
        :param compress_option: the compression algorithm, level and strategy
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param endpoint: tunnel service URL
        :param reopen: if False (default), reuse the download session opened last time;
                       if True, open a new download session
        :return: records
        :rtype: generator

        :Example:

        >>> for record in odps.read_table('test_table', 100):
        >>>     # deal with such 100 records
        >>> for record in odps.read_table('test_table', partition='pt=test', start=100, limit=100):
        >>>     # read the `pt=test` partition, skip 100 records and read 100 records

        .. seealso:: :class:`odps.models.Record`
        """

        table = cls._get_table_obj(odps, name, project=project, schema=schema)

        compress = kw.pop("compress", False)
        columns = kw.pop("columns", None)

        with table.open_reader(partition=partition, **kw) as reader:
            for record in reader.read(
                start, limit, step=step, compress=compress, columns=columns
            ):
                yield record

    @classmethod
    def _is_pa_collection(cls, obj):
        return pa is not None and isinstance(obj, (pa.Table, pa.RecordBatch))

    @classmethod
    def _is_pd_df(cls, obj):
        return pd is not None and isinstance(obj, pd.DataFrame)

    @classmethod
    def _resolve_schema(
        cls,
        records_list=None,
        data_schema=None,
        unknown_as_string=False,
        partition=None,
        partition_cols=None,
        type_mapping=None,
    ):
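        """Infer an ODPS schema from the given data or schema, append partition
        columns derived from ``partition`` / ``partition_cols``, and apply
        ``type_mapping`` overrides."""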
        from ..df.backends.odpssql.types import df_schema_to_odps_schema
        from ..df.backends.pd.types import pd_to_df_schema
        from ..tunnel.io.types import arrow_schema_to_odps_schema

        type_mapping = type_mapping or {}
        type_mapping = {
            k: odps_types.validate_data_type(v) for k, v in type_mapping.items()
        }
        if records_list is not None:
            if cls._is_pa_collection(records_list):
                data_schema = arrow_schema_to_odps_schema(records_list.schema)
            elif cls._is_pd_df(records_list):
                data_schema = df_schema_to_odps_schema(
                    pd_to_df_schema(
                        records_list,
                        unknown_as_string=unknown_as_string,
                        type_mapping=type_mapping,
                    )
                )
            elif isinstance(records_list, list) and odps_types.is_record(
                records_list[0]
            ):
                data_schema = odps_types.OdpsSchema(records_list[0]._columns)
            else:
                raise TypeError(
                    "Inferring schema from provided data not implemented. "
                    "You need to supply a pandas DataFrame or records."
                )
        assert data_schema is not None

        # copy to avoid mutating the caller's partition_cols list
        part_col_names = list(partition_cols or [])
        if partition is not None:
            part_spec = odps_types.PartitionSpec(partition)
            part_col_names.extend(part_spec.keys())
        if part_col_names:
            part_col_set = set(part_col_names)
            simple_cols = [c for c in data_schema.columns if c.name not in part_col_set]
            part_cols = [
                odps_types.Column(n, odps_types.string) for n in part_col_names
            ]
            data_schema = odps_types.OdpsSchema(simple_cols, part_cols)

        if not type_mapping:
            return data_schema

        simple_cols, part_cols = [], []
        unmapped_cols = set(type_mapping.keys())
        for col in data_schema.columns:
            if col.name not in type_mapping:
                simple_cols.append(col)
            else:
                unmapped_cols.remove(col.name)
                simple_cols.append(col.replace(type=type_mapping[col.name]))
        for col in getattr(data_schema, "partitions", None) or ():
            if col.name not in type_mapping:
                part_cols.append(col)
            else:
                unmapped_cols.remove(col.name)
                part_cols.append(col.replace(type=type_mapping[col.name]))

        for col_name in unmapped_cols:
            simple_cols.append(
                odps_types.Column(name=col_name, type=type_mapping[col_name])
            )
        return odps_types.OdpsSchema(simple_cols, part_cols or None)

    @classmethod
    def _calc_schema_diff(cls, src_schema, dest_schema, partition_cols=None):
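        """Return columns present in both schemas (ordered as in ``dest_schema``)
        and source columns absent from ``dest_schema`` that are not partition columns."""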
        if not src_schema or not dest_schema:
            return [], []
        union_cols, diff_cols = [], []
        part_col_set = set(partition_cols or [])
        # collect union columns in the order of dest schema
        for col in dest_schema.simple_columns:
            if col.name in src_schema:
                union_cols.append(col)
        # collect columns not in dest schema
        for col in src_schema.simple_columns:
            if col.name not in dest_schema and col.name not in part_col_set:
                diff_cols.append(col)
        return union_cols, diff_cols

    @classmethod
    def _check_partition_specified(
        cls, table_name, table_schema, partition_cols=None, partition=None
    ):
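        """Validate ``partition`` / ``partition_cols`` against the table schema and
        return the effective dynamic partition columns, including partition columns
        carrying generate expressions."""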
        # copy to avoid mutating the caller's list below
        partition_cols = list(partition_cols or [])
        no_eval_set = set(n.lower() for n in partition_cols)
        if partition:
            no_eval_set.update(
                c.lower() for c in odps_types.PartitionSpec(partition).keys()
            )
        expr_cols = [
            c.name.lower()
            for c in table_schema.partitions
            if c.generate_expression and c.name.lower() not in no_eval_set
        ]
        partition_cols += expr_cols
        if not partition_cols and not partition:
            if table_schema.partitions:
                raise ValueError(
                    "Partition spec is required for table %s with partitions, "
                    "please specify a partition with `partition` argument or "
                    "specify a list of columns with `partition_cols` argument "
                    "to enable dynamic partitioning." % table_name
                )
            return partition_cols
        else:
            if not table_schema.partitions:
                raise ValueError(
                    "Cannot store into a non-partitioned table %s when `partition` "
                    "or `partition_cols` is specified." % table_name
                )
            all_parts = (
                [n.lower() for n in odps_types.PartitionSpec(partition).keys()]
                if partition
                else []
            )
            if partition_cols:
                all_parts.extend(partition_cols)
            req_all_parts_set = set(all_parts)
            table_all_parts = [c.name.lower() for c in table_schema.partitions]
            no_exist_parts = req_all_parts_set - set(table_all_parts)
            if no_exist_parts:
                raise ValueError(
                    "Partitions %s are not in table %s whose partitions are (%s)."
                    % (sorted(no_exist_parts), table_name, ", ".join(table_all_parts))
                )
            no_specified_parts = set(table_all_parts) - req_all_parts_set
            if no_specified_parts:
                raise ValueError(
                    "Partitions %s in table %s are not specified in `partition_cols` or "
                    "`partition` argument." % (sorted(no_specified_parts), table_name)
                )
            return partition_cols

    @classmethod
    def _get_ordered_col_expressions(cls, table, partition):
        """
        Get column expressions in topological order
        by variable dependencies
        """
        part_spec = odps_types.PartitionSpec(partition)
        col_to_expr = {
            c.name.lower(): table._get_column_generate_expression(c.name)
            for c in table.table_schema.columns
            if c.name not in part_spec
        }
        col_to_expr = {c: expr for c, expr in col_to_expr.items() if expr}
        if not col_to_expr:
            # no columns with expressions, quit
            return {}
        col_dag = DAG()
        for col in col_to_expr:
            col_dag.add_node(col)
            for ref in col_to_expr[col].references:
                ref_col_name = ref.lower()
                col_dag.add_node(ref_col_name)
                col_dag.add_edge(ref_col_name, col)

        out_col_to_expr = OrderedDict()
        for col in col_dag.topological_sort():
            if col not in col_to_expr:
                continue
            out_col_to_expr[col] = col_to_expr[col]
        return out_col_to_expr

    @classmethod
    def _fill_missing_expressions(cls, data, col_to_expr):
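        """Fill columns missing from ``data``, or None-valued record fields,
        by evaluating their generate expressions."""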
        def handle_recordbatch(batch):
            col_names = list(batch.schema.names)
            col_arrays = list(batch.columns)
            for col in missing_cols:
                col_names.append(col)
                col_arrays.append(col_to_expr[col].eval(batch))
            return pa.RecordBatch.from_arrays(col_arrays, col_names)

        if pa and isinstance(data, (pa.Table, pa.RecordBatch)):
            col_name_set = set(c.lower() for c in data.schema.names)
            missing_cols = [c for c in col_to_expr if c not in col_name_set]
            if not missing_cols:
                return data
            if isinstance(data, pa.Table):
                batches = [handle_recordbatch(b) for b in data.to_batches()]
                return pa.Table.from_batches(batches)
            else:
                return handle_recordbatch(data)
        elif pd and isinstance(data, pd.DataFrame):
            col_name_set = set(c.lower() for c in data.columns)
            missing_cols = [c for c in col_to_expr if c not in col_name_set]
            if not missing_cols:
                return data
            data = data.copy()
            for col in missing_cols:
                data[col] = col_to_expr[col].eval(data)
            return data
        else:
            wrapped = False
            if odps_types.is_record(data):
                data = [data]
                wrapped = True
            for rec in data:
                if not odps_types.is_record(rec):
                    continue
                for c in col_to_expr:
                    if rec[c] is not None:
                        continue
                    rec[c] = col_to_expr[c].eval(rec)
            return data[0] if wrapped else data

    @classmethod
    def _split_block_data_in_partitions(
        cls, table, block_data, partition_cols=None, partition=None
    ):
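        """Split block data by dynamic partition values, returning a dict keyed
        by ``(is_arrow, partition_name)`` mapping to lists of data pieces."""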
        from . import Record

        table_schema = table.table_schema
        col_to_expr = cls._get_ordered_col_expressions(table, partition)

        def _fill_cols(data):
            if col_to_expr:
                data = cls._fill_missing_expressions(data, col_to_expr)
            if not pd or not isinstance(data, pd.DataFrame):
                return data
            data.columns = [col.lower() for col in data.columns]
            tb_col_names = [c.name.lower() for c in table_schema.simple_columns]
            tb_col_set = set(tb_col_names)
            extra_cols = [col for col in data.columns if col not in tb_col_set]
            return data.reindex(tb_col_names + extra_cols, axis=1)

        if not partition_cols:
            is_arrow = cls._is_pa_collection(block_data) or cls._is_pd_df(block_data)
            return {(is_arrow, None): [_fill_cols(block_data)]}

        input_cols = list(table_schema.simple_columns) + [
            odps_types.Column(part, odps_types.string) for part in partition_cols
        ]
        input_schema = odps_types.OdpsSchema(input_cols)

        non_generate_idxes = [
            idx
            for idx, c in enumerate(table_schema.simple_columns)
            if not table._get_column_generate_expression(c.name)
        ]
        num_generate_pts = len(
            [
                c
                for c in (table_schema.partitions or [])
                if table._get_column_generate_expression(c.name)
            ]
        )

        parted_data = defaultdict(list)
        if (
            cls._is_pa_collection(block_data)
            or cls._is_pd_df(block_data)
            or odps_types.is_record(block_data)
            or (
                isinstance(block_data, list)
                and block_data
                and not isinstance(block_data[0], list)
            )
        ):
            # pd dataframes, arrow RecordBatch, single record or single record-like array
            block_data = [block_data]
        for data in block_data:
            if cls._is_pa_collection(data):
                data = data.to_pandas()
            elif isinstance(data, list):
                if len(data) != len(input_schema):
                    # pad None for generated columns when the data size
                    #  matches the number of non-generated columns
                    if len(data) < len(table_schema.simple_columns) and len(
                        data
                    ) == len(non_generate_idxes):
                        new_data = [None] * len(table_schema.simple_columns)
                        for idx, d in zip(non_generate_idxes, data):
                            new_data[idx] = d
                        data = new_data
                    # pad None for generated partition columns to be filled later
                    data += [None] * (
                        num_generate_pts
                        - (len(data) - len(table_schema.simple_columns))
                    )
                    if len(data) != len(input_schema):
                        raise ValueError(
                            "Need to specify %d values when writing table "
                            "with dynamic partition." % len(input_schema)
                        )
                data = Record(schema=input_schema, values=data)

            if cls._is_pd_df(data):
                data = _fill_cols(data)
                part_set = set(partition_cols)
                for name, group in data.groupby(partition_cols):
                    name = name if isinstance(name, tuple) else (name,)
                    pt_name = ",".join(
                        "=".join([str(n), str(v)]) for n, v in zip(partition_cols, name)
                    )
                    parted_data[(True, pt_name)].append(
                        group.drop(part_set, axis=1, errors="ignore")
                    )
            elif odps_types.is_record(data):
                data = _fill_cols(data)
                pt_name = ",".join(
                    "=".join([str(n), data[str(n)]]) for n in partition_cols
                )
                values = [data[str(c.name)] for c in table_schema.simple_columns]
                if not parted_data[(False, pt_name)]:
                    parted_data[(False, pt_name)].append([])
                parted_data[(False, pt_name)][0].append(
                    Record(schema=table_schema, values=values)
                )
            else:
                raise ValueError(
                    "Cannot accept data with type %s" % type(data).__name__
                )
        return parted_data

    @classmethod
    def write_table(cls, odps, name, *block_data, **kw):
        """
        Write records or pandas DataFrame into given table.

        :param name: table or table name
        :type name: :class:`.models.table.Table` or str
        :param block_data: records / DataFrame, or alternating block ids and
            records / DataFrames. If only records or a DataFrame is given, block id
            0 is used by default.
        :param str project: project name, if not provided, will be the default project
        :param str schema: schema name, if not provided, will be the default schema
        :param partition: the partition of this table to write into
        :param list partition_cols: columns representing dynamic partitions
        :param bool append_missing_cols: Whether to append missing columns to the target
            table. False by default.
        :param bool overwrite: if True, will overwrite existing data
        :param bool create_table: if True, create the table if it does not exist
        :param dict table_kwargs: specify other kwargs for :meth:`~odps.ODPS.create_table`
        :param dict type_mapping: specify type mapping for columns when creating tables,
            can be dicts like ``{"column": "bigint"}``. If column does not exist in data,
            it will be added as an empty column.
        :param table_schema_callback: a function to accept table schema resolved from data
            and return a new schema for table to create. Only works when target table does
            not exist and ``create_table`` is True.
        :param int lifecycle: specify table lifecycle when creating tables
        :param bool create_partition: if True, create the partition if it does not exist
        :param compress_option: the compression algorithm, level and strategy
        :type compress_option: :class:`odps.tunnel.CompressOption`
        :param str endpoint:  tunnel service URL
        :param bool reopen: if False (default), reuse the upload session opened
            last time; if True, open a new upload session
        :return: None

        :Example:

        Write records into a specified table.

        >>> odps.write_table('test_table', data)

        Write records into multiple blocks.

        >>> odps.write_table('test_table', 0, records1, 1, records2)

        Write into a given partition.

        >>> odps.write_table('test_table', data, partition='pt=test')

        Write a pandas DataFrame. Create the table if it does not exist.

        >>> import pandas as pd
        >>> df = pd.DataFrame([
        >>>     [111, 'aaa', True],
        >>>     [222, 'bbb', False],
        >>>     [333, 'ccc', True],
        >>>     [444, '中文', False]
        >>> ], columns=['num_col', 'str_col', 'bool_col'])
        >>> o.write_table('test_table', df, partition='pt=test', create_table=True, create_partition=True)

        Passing more arguments when creating table.

        >>> import pandas as pd
        >>> df = pd.DataFrame([
        >>>     [111, 'aaa', True],
        >>>     [222, 'bbb', False],
        >>>     [333, 'ccc', True],
        >>>     [444, '中文', False]
        >>> ], columns=['num_col', 'str_col', 'bool_col'])
        >>> # this dict will be passed to `create_table` as kwargs.
        >>> table_kwargs = {"transactional": True, "primary_key": "num_col"}
        >>> o.write_table('test_table', df, partition='pt=test', create_table=True, create_partition=True,
        >>>               table_kwargs=table_kwargs)

        Write with dynamic partitioning.

        >>> import pandas as pd
        >>> df = pd.DataFrame([
        >>>     [111, 'aaa', True, 'p1'],
        >>>     [222, 'bbb', False, 'p1'],
        >>>     [333, 'ccc', True, 'p2'],
        >>>     [444, '中文', False, 'p2']
        >>> ], columns=['num_col', 'str_col', 'bool_col', 'pt'])
        >>> o.write_table('test_part_table', df, partition_cols=['pt'], create_partition=True)

        :Note:

        ``write_table`` treats object-typed columns of pandas data as strings, as it is often hard to
        determine their actual types when creating a new table for your data. To make sure the column
        types meet your needs, you can specify them with the `type_mapping` argument, for instance,
        ``type_mapping={"col1": "array<struct<id:string>>"}``.

        .. seealso:: :class:`odps.models.Record`
        """
        project = kw.pop("project", None)
        schema = kw.pop("schema", None)
        append_missing_cols = kw.pop("append_missing_cols", False)
        overwrite = kw.pop("overwrite", False)

        single_block_types = (Iterable,)
        if pa is not None:
            single_block_types += (pa.RecordBatch, pa.Table)

        if len(block_data) == 1 and isinstance(block_data[0], single_block_types):
            blocks = [None]
            data_list = list(block_data)
        else:
            blocks = list(block_data[::2])
            data_list = list(block_data[1::2])

            if len(blocks) != len(data_list):
                raise ValueError(
                    "Should invoke like odps.write_table(block_id, records, "
                    "block_id2, records2, ..., **kw)"
                )

        unknown_as_string = kw.pop("unknown_as_string", False)
        create_table = kw.pop("create_table", False)
        create_partition = kw.pop(
            "create_partition", kw.pop("create_partitions", False)
        )
        partition = kw.pop("partition", None)
        partition_cols = kw.pop("partition_cols", None) or kw.pop("partitions", None)
        lifecycle = kw.pop("lifecycle", None)
        type_mapping = kw.pop("type_mapping", None)
        table_schema_callback = kw.pop("table_schema_callback", None)
        table_kwargs = dict(kw.pop("table_kwargs", None) or {})
        if lifecycle:
            table_kwargs["lifecycle"] = lifecycle

        if isinstance(partition_cols, six.string_types):
            partition_cols = [partition_cols]

        try:
            data_sample = data_list[0]
            if isinstance(data_sample, GeneratorType):
                data_gen = data_sample
                data_sample = [next(data_gen)]
                data_list[0] = utils.chain_generator([data_sample[0]], data_gen)
            table_schema = cls._resolve_schema(
                data_sample,
                unknown_as_string=unknown_as_string,
                partition=partition,
                partition_cols=partition_cols,
                type_mapping=type_mapping,
            )
        except TypeError:
            table_schema = None

        if not odps.exist_table(name, project=project, schema=schema):
            if not create_table:
                raise errors.NoSuchTable(
                    "Target table %s not exist. To create a new table "
                    "you can add an argument `create_table=True`." % name
                )
            if callable(table_schema_callback):
                table_schema = table_schema_callback(table_schema)
            target_table = odps.create_table(
                name, table_schema, project=project, schema=schema, **table_kwargs
            )
        else:
            target_table = cls._get_table_obj(
                odps, name, project=project, schema=schema
            )

        union_cols, diff_cols = cls._calc_schema_diff(
            table_schema, target_table.table_schema, partition_cols=partition_cols
        )
        if table_schema and not union_cols:
            warnings.warn(
                "No columns overlapped between source and target table. If result "
                "is not as expected, please check if your query provides correct "
                "column names."
            )
        if diff_cols:
            if append_missing_cols:
                target_table.add_columns(diff_cols)
            else:
                warnings.warn(
                    "Columns in source data %s are missing in target table %s. "
                    "Specify append_missing_cols=True to append missing columns "
                    "to the target table."
                    % (", ".join(c.name for c in diff_cols), target_table.name)
                )

        partition_cols = cls._check_partition_specified(
            name,
            target_table.table_schema,
            partition_cols=partition_cols,
            partition=partition,
        )

        data_lists = defaultdict(lambda: defaultdict(list))
        for block, data in zip(blocks, data_list):
            for key, parted_data in cls._split_block_data_in_partitions(
                target_table,
                data,
                partition_cols=partition_cols,
                partition=partition,
            ).items():
                data_lists[key][block].extend(parted_data)

        if partition is None or isinstance(partition, six.string_types):
            partition_str = partition
        else:
            partition_str = str(odps_types.PartitionSpec(partition))

        # FIXME workaround for overwrite failure on table.format.version=2:
        #  only applicable for transactional tables whose partitions
        #  carry generate expressions
        manual_truncate = (
            overwrite
            and target_table.is_transactional
            and any(
                pt_col.generate_expression
                for pt_col in target_table.table_schema.partitions
            )
        )

        for (is_arrow, pt_name), block_to_data in data_lists.items():
            if not block_to_data:
                continue

            blocks, data_list = [], []
            for block, data in block_to_data.items():
                blocks.append(block)
                data_list.extend(data)

            if len(blocks) == 1 and blocks[0] is None:
                blocks = None

            final_pt = ",".join(p for p in (pt_name, partition_str) if p is not None)
            # FIXME workaround for overwrite failure on table.format.version=2
            if manual_truncate:
                if not final_pt or target_table.exist_partition(final_pt):
                    target_table.truncate(partition_spec=final_pt or None)
            with target_table.open_writer(
                partition=final_pt or None,
                blocks=blocks,
                arrow=is_arrow,
                create_partition=create_partition,
                reopen=append_missing_cols,
                overwrite=overwrite,
                **kw
            ) as writer:
                if blocks is None:
                    for data in data_list:
                        writer.write(data)
                else:
                    for block, data in zip(blocks, data_list):
                        writer.write(block, data)

    @classmethod
    def write_sql_result_to_table(
        cls,
        odps,
        table_name,
        sql,
        partition=None,
        partition_cols=None,
        create_table=False,
        create_partition=False,
        append_missing_cols=False,
        overwrite=False,
        project=None,
        schema=None,
        lifecycle=None,
        type_mapping=None,
        table_schema_callback=None,
        table_kwargs=None,
        hints=None,
        running_cluster=None,
        unique_identifier_id=None,
        **kwargs
    ):
        """
        Write SQL query results into a specified table and partition. If the target
        table does not exist, you may specify the argument create_table=True. Columns
        are inserted into the target table aligned by column names. Note that column
        order in the target table will NOT be changed.

        :param str table_name: The target table name
        :param str sql: The SQL query to execute
        :param str partition: Target partition in the format "part=value" or
            "part1=value1,part2=value2"
        :param list partition_cols: List of dynamic partition fields. If not provided,
            all partition fields of the target table are used.
        :param bool create_table: Whether to create the target table if it does not exist.
            False by default.
        :param bool create_partition: Whether to create partitions if they do not exist.
            False by default.
        :param bool append_missing_cols: Whether to append missing columns to the target
            table. False by default.
        :param bool overwrite: Whether to overwrite existing data. False by default.
        :param str project: project name, if not provided, will be the default project
        :param str schema: schema name, if not provided, will be the default schema
        :param int lifecycle: specify table lifecycle when creating tables
        :param dict type_mapping: specify type mapping for columns when creating tables,
            can be dicts like ``{"column": "bigint"}``. If column does not exist in data,
            it will be added as an empty column.
        :param table_schema_callback: a function to accept table schema resolved from data
            and return a new schema for table to create. Only works when target table does
            not exist and ``create_table`` is True.
        :param dict table_kwargs: specify other kwargs for :meth:`~odps.ODPS.create_table`
        :param dict hints: specify hints for SQL statements, will be passed through
            to execute_sql method
        :param running_cluster: specify running cluster for SQL statements, will
            be passed through to execute_sql method
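
        :Example:

        A typical call, assuming ``o`` is an :class:`~odps.ODPS` entry object and
        ``src_table`` is an existing table readable by the query:

        >>> o.write_sql_result_to_table(
        >>>     'test_table', 'SELECT * FROM src_table', partition='pt=test',
        >>>     create_table=True, create_partition=True)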
        """
        legacy_partitions = kwargs.pop("partitions", None)
        partition_cols = partition_cols or legacy_partitions
        if isinstance(partition_cols, six.string_types):
            partition_cols = [partition_cols]

        temp_table_name = "_".join(
            [utils.TEMP_TABLE_PREFIX, utils.md5_hexdigest(table_name), uuid.uuid4().hex]
        )
        insert_mode = "OVERWRITE" if overwrite else "INTO"

        # move table params in kwargs into table_kwargs
        table_kwargs = dict(table_kwargs or {})
        for extra_table_arg in (
            "table_properties",
            "shard_num",
            "transactional",
            "primary_key",
            "storage_tier",
        ):
            if extra_table_arg in kwargs:
                table_kwargs[extra_table_arg] = kwargs.pop(extra_table_arg)
        # if extra table kwargs are supplied, CREATE TABLE ... AS ...
        #  may not work, so the table needs to be created first
        with_extra_table_kw = bool(table_kwargs)

        table_kwargs.update(
            {"schema": schema, "project": project, "lifecycle": lifecycle}
        )
        sql_kwargs = kwargs.copy()
        sql_kwargs.update(
            {
                "hints": copy.deepcopy(hints or {}),
                "running_cluster": running_cluster,
                "unique_identifier_id": unique_identifier_id,
                "default_schema": schema,
                "project": project,
            }
        )

        def _format_raw_sql(fmt, args):
            """Add DDLs for existing SQL, multiple statements acceptable"""
            args = list(args)
            sql_parts = utils.split_sql_by_semicolon(args[-1])
            if len(sql_parts) == 1:
                return fmt % tuple(args)
            # need script mode for multiple statements
            sql_kwargs["hints"]["odps.sql.submit.mode"] = "script"
            sql_parts[-1] = fmt % tuple(args[:-1] + [sql_parts[-1]])
            return "\n".join(sql_parts)

        # Check if the target table exists
        if not odps.exist_table(table_name, project=project, schema=schema):
            if not create_table:
                raise ValueError(
                    "Table %s does not exist and create_table is set to False."
                    % table_name
                )
            elif (
                not partition
                and not partition_cols
                and not with_extra_table_kw
                and table_schema_callback is None
            ):
                # create the target table directly via CREATE TABLE ... AS
                #  when no partitions or special kwargs are involved
                if not lifecycle:
                    lifecycle_clause = ""
                else:
                    lifecycle_clause = "LIFECYCLE %d " % lifecycle
                sql_stmt = _format_raw_sql(
                    "CREATE TABLE `%s` %sAS %s", (table_name, lifecycle_clause, sql)
                )
                odps.execute_sql(sql_stmt, **sql_kwargs)
                return
            else:
                # create temp table, get result schema and create target table
                sql_stmt = _format_raw_sql(
                    "CREATE TABLE `%s` LIFECYCLE %d AS %s",
                    (temp_table_name, options.temp_lifecycle, sql),
                )
                odps.execute_sql(sql_stmt, **sql_kwargs)
                tmp_schema = odps.get_table(temp_table_name).table_schema
                out_table_schema = cls._resolve_schema(
                    data_schema=tmp_schema,
                    partition=partition,
                    partition_cols=partition_cols,
                    type_mapping=type_mapping,
                )
                if table_schema_callback:
                    out_table_schema = table_schema_callback(out_table_schema)
                target_table = odps.create_table(
                    table_name, table_schema=out_table_schema, **table_kwargs
                )
        else:
            target_table = cls._get_table_obj(
                odps, table_name, project=project, schema=schema
            )
            # target table exists: create a temp table first and store results there
            sql_stmt = _format_raw_sql(
                "CREATE TABLE `%s` LIFECYCLE %d AS %s",
                (temp_table_name, options.temp_lifecycle, sql),
            )
            odps.execute_sql(sql_stmt, **sql_kwargs)

        try:
            partition_cols = cls._check_partition_specified(
                table_name,
                target_table.table_schema,
                partition_cols=partition_cols,
                partition=partition,
            )

            temp_table = odps.get_table(temp_table_name)
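            # union_cols: columns shared by source and target;
            #  diff_cols: source columns missing from the target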
            union_cols, diff_cols = cls._calc_schema_diff(
                temp_table.table_schema,
                target_table.table_schema,
                partition_cols=partition_cols,
            )
            if not union_cols:
                warnings.warn(
                    "No columns overlap between the source and target table. If the "
                    "result is not as expected, please check whether your query "
                    "provides correct column names."
                )
            if diff_cols:
                if append_missing_cols:
                    target_table.add_columns(diff_cols, hints=hints)
                    union_cols += diff_cols
                else:
                    warnings.warn(
                        "Columns %s in the source query are missing in target table %s. "
                        "Specify append_missing_cols=True to append the missing columns "
                        "to the target table."
                        % (", ".join(c.name for c in diff_cols), target_table.name)
                    )

            target_columns = [col.name for col in union_cols]

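            # build the static partition spec from the explicit partition argument, if any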
            if partition:
                static_part_spec = odps_types.PartitionSpec(partition)
            else:
                static_part_spec = odps_types.PartitionSpec()

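            # when all partition columns are specified statically,
            #  make sure the target partition exists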
            if (
                target_table.table_schema.partitions
                and len(static_part_spec) == len(target_table.table_schema.partitions)
                and not target_table.exist_partition(static_part_spec)
            ):
                if create_partition:
                    target_table.create_partition(static_part_spec)
                else:
                    raise ValueError(
                        "Partition %s does not exist and create_partition is set to False."
                        % static_part_spec
                    )

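            # classify partition columns into static specs, dynamic columns
            #  taken from the source, and generated columns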
            all_parts, part_specs, dyn_parts = [], [], []
            has_dyn_parts = False
            for col in target_table.table_schema.partitions:
                if col.name in static_part_spec:
                    spec = "%s='%s'" % (
                        col.name,
                        utils.escape_odps_string(static_part_spec[col.name]),
                    )
                    part_specs.append(spec)
                    all_parts.append(spec)
                elif col.name not in temp_table.table_schema:
                    if col.generate_expression:
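                        # generated partition column: its value is derived from the
                        #  generation expression rather than selected from the source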
                        all_parts.append(col.name)
                        has_dyn_parts = True
                        continue
                    else:
                        raise ValueError(
                            "Partition column %s does not exist in source query."
                            % col.name
                        )
                else:
                    has_dyn_parts = True
                    all_parts.append(col.name)
                    dyn_parts.append(col.name)
                    part_specs.append(col.name)

            if not part_specs:
                part_clause = ""
            else:
                part_clause = "PARTITION (%s) " % ", ".join(part_specs)

            if overwrite:
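                # existing data is cleared explicitly via truncate below,
                #  so a plain INSERT INTO is used afterwards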
                insert_mode = "INTO"
                if not has_dyn_parts:
                    target_table.truncate(partition or None)
                elif any(target_table.partitions):
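                    # with dynamic partitions, only partitions actually produced
                    #  by the source query are truncated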
                    # generate column expressions in topological order
                    col_to_exprs = cls._get_ordered_col_expressions(
                        target_table, partition
                    )
                    part_expr_map = {
                        col: "`%s`" % col
                        for col in partition_cols
                        if col in temp_table.table_schema
                    }
                    generated_cols = set()
                    for col_name, expr in col_to_exprs.items():
                        if col_name in part_expr_map:
                            continue
                        generated_cols.add(col_name)
                        part_expr_map[col_name] = expr.to_str(part_expr_map)
                    # add aliases for generated columns
                    part_expr_map = {
                        col: (
                            v if col not in generated_cols else "%s AS `%s`" % (v, col)
                        )
                        for col, v in part_expr_map.items()
                    }

                    # query for partitions that need to be truncated
                    part_selections = [
                        part_expr_map[col_name] for col_name in partition_cols
                    ]
                    part_distinct_sql = "SELECT DISTINCT %s FROM `%s`" % (
                        ", ".join(part_selections),
                        temp_table_name,
                    )
                    distinct_inst = odps.execute_sql(part_distinct_sql)
                    trunc_part_specs = []
                    with distinct_inst.open_reader(tunnel=True) as reader:
                        for row in reader:
                            local_part_specs = [
                                "%s='%s'" % (c, utils.escape_odps_string(row[c]))
                                if c in row
                                else c
                                for c in all_parts
                            ]
                            local_part_str = ",".join(local_part_specs)
                            if target_table.exist_partition(local_part_str):
                                trunc_part_specs.append(local_part_str)
                    target_table.truncate(trunc_part_specs)

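            # copy the selected columns (plus dynamic partition columns)
            #  from the temp table into the target table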
            col_selection = ", ".join("`%s`" % s for s in (target_columns + dyn_parts))
            sql_stmt = "INSERT %s `%s` %s (%s) SELECT %s FROM %s" % (
                insert_mode,
                table_name,
                part_clause,
                col_selection,
                col_selection,
                temp_table_name,
            )
            odps.execute_sql(sql_stmt, **sql_kwargs)
        finally:
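            # always drop the temp table asynchronously, even if the insertion fails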
            odps.delete_table(
                temp_table_name,
                project=project,
                schema=schema,
                if_exists=True,
                async_=True,
            )
