datahub/client/producer/shard_writer.py

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import atomic
import logging
import threading

from datahub.exceptions import DatahubException
from .write_result import WriteResult
from .record_pack_queue import RecordPackQueue
from ..common.datahub_factory import DatahubFactory


class ShardWriter:
    def __init__(self, project_name, topic_name, sub_id, message_writer, producer_config, shard_id):
        self._closed = False
        self._logger = logging.getLogger(ShardWriter.__name__)
        self._lock = threading.Lock()
        self._project_name = project_name
        self._topic_name = topic_name
        self._sub_id = sub_id
        self._uniq_key = "{}:{}:{}".format(project_name, topic_name, sub_id)
        self._message_writer = message_writer
        self._shard_id = shard_id
        self._max_retry_times = producer_config.retry_times
        self._task_num = atomic.AtomicLong(0)
        self._condition = threading.Condition()
        self._has_write_count = atomic.AtomicLong(0)
        self._datahub_client = DatahubFactory.create_datahub_client(producer_config)
        self._record_package_queue = RecordPackQueue(producer_config.max_async_buffer_size,
                                                     producer_config.max_async_buffer_records,
                                                     producer_config.max_async_buffer_time,
                                                     producer_config.max_record_pack_queue_limit)

    def close(self):
        self._closed = True
        self._logger.info("ShardWriter closed. key: {}, shard_id: {}, write count: {}"
                          .format(self._uniq_key, self._shard_id, self._has_write_count.value))

    @property
    def shard_id(self):
        return self._shard_id

    def write(self, records):
        # Synchronous write: put the records directly, retrying on failure.
        if self._closed:
            self._logger.warning("ShardWriter closed when write. key: {}, shard_id: {}"
                                 .format(self._uniq_key, self._shard_id))
            raise DatahubException("ShardWriter closed when write")
        self.__write_once(records)
        self._logger.debug("Send next write task success. key: {}, record count: {}"
                           .format(self._uniq_key, len(records)))

    def write_async(self, records):
        # Asynchronous write: buffer the records and schedule a background task
        # if no task is currently in flight.
        if self._closed:
            self._logger.warning("ShardWriter closed when write async. key: {}, shard_id: {}"
                                 .format(self._uniq_key, self._shard_id))
            raise DatahubException("ShardWriter closed when write async")
        result = self._record_package_queue.append_record(records)
        if self._task_num.value == 0:
            self.__send_next_task()
        return result

    def flush(self):
        # Force all buffered records into ready packs, then block until every
        # pending task has completed.
        if self._closed:
            self._logger.warning("ShardWriter closed when flush. key: {}, shard_id: {}"
                                 .format(self._uniq_key, self._shard_id))
            raise DatahubException("ShardWriter closed when flush")
        self._record_package_queue.flush()
        self.__send_next_task()
        while self._task_num.value > 0:
            with self._condition:
                self._condition.wait()

    def __send_next_task(self):
        with self._lock:
            pack = self._record_package_queue.obtain_ready_record_pack()
            if pack is not None:
                if not self._message_writer.send_task(int(self._shard_id), self.__gen_next_write_task, pack):
                    # Adding the task fails when the thread pool is full
                    self._logger.warning("Send next task fail. key: {}, shard_id: {}, task num: {}"
                                         .format(self._uniq_key, self._shard_id, self._task_num.value))
                    raise DatahubException("Send next task fail. key: {}, shard_id: {}"
                                           .format(self._uniq_key, self._shard_id))
                self._task_num += 1
                self._logger.debug("Send next task once. key: {}, shard_id: {}, task_num: {}"
                                   .format(self._uniq_key, self._shard_id, self._task_num.value))

    def __write_once(self, records):
        # Put records to the shard, retrying on DatahubException up to the
        # configured retry limit; any other exception is raised immediately.
        retry_time = 0
        while True:
            try:
                self._message_writer.put_record_by_shard(self._shard_id, records)
                self._has_write_count += len(records)
                return
            except DatahubException as e:
                self._logger.warning("Write records fail. key: {}, shard_id: {}, records size: {}, "
                                     "max retry time: {}, this time: {}, DatahubException: {}"
                                     .format(self._uniq_key, self._shard_id, len(records),
                                             self._max_retry_times, retry_time, e))
                retry_time += 1
                if retry_time >= self._max_retry_times:
                    raise e
            except Exception as e:
                self._logger.warning("Write records fail. key: {}, shard_id: {}, records size: {}, {}"
                                     .format(self._uniq_key, self._shard_id, len(records), e))
                raise e

    def __gen_next_write_task(self, record_pack):
        # Task body executed in the message writer's thread pool for one record pack.
        records = record_pack.records
        futures = record_pack.write_result_futures
        init_time = record_pack.init_time
        try:
            start_time = time.time()
            self.__write_once(records)
            end_time = time.time()
            self._logger.debug("write async once success. key: {}, shard_id: {}, records size: {}"
                               .format(self._uniq_key, self._shard_id, len(records)))
            self.__set_result_to_futures(futures, WriteResult(self._shard_id, end_time - init_time, end_time - start_time))
        except DatahubException as e:
            self._logger.warning("write async once fail. key: {}, shard_id: {}, records size: {}, DatahubException: {}"
                                 .format(self._uniq_key, self._shard_id, len(records), e))
            self.__set_exception_to_futures(futures, e)
        except Exception as e:
            self._logger.warning("write async once fail. key: {}, shard_id: {}, records size: {}, Exception: {}"
                                 .format(self._uniq_key, self._shard_id, len(records), e))
            self.__set_exception_to_futures(futures, e)

    def __set_result_to_futures(self, futures, target):
        for future in futures:
            future.set_result(target)
        self.__task_done()

    def __set_exception_to_futures(self, futures, target):
        for future in futures:
            future.set_exception(target)
        self.__task_done()

    def __task_done(self):
        # Schedule the next ready pack (if any), decrement the pending-task
        # counter, then wake any thread blocked in flush().
        self.__send_next_task()
        self._task_num -= 1
        with self._condition:
            self._condition.notify_all()
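

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the SDK: flush() above blocks on a Condition
# until the pending-task counter reaches zero, and __task_done() notifies the
# waiters after each record pack is written. The standalone helper below mirrors
# that coordination pattern using only the standard library; the names
# _demo_flush_pattern, pending and worker are hypothetical and exist purely for
# this illustration. Calling _demo_flush_pattern() returns only after all
# simulated tasks have completed.
def _demo_flush_pattern():
    pending = {"tasks": 3}
    condition = threading.Condition()

    def worker():
        time.sleep(0.1)                  # simulate one asynchronous write task
        with condition:
            pending["tasks"] -= 1        # task finished
            condition.notify_all()       # wake any thread waiting in the flush loop

    for _ in range(pending["tasks"]):
        threading.Thread(target=worker).start()

    with condition:
        while pending["tasks"] > 0:      # wait until every task has signalled completion
            condition.wait()             # lock is released while waiting, reacquired on notify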