python/rocketmq/v5/client/connection/rpc_channel.py (196 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time

import grpc
from grpc import ChannelConnectivity, aio
from grpc.aio import AioRpcError
from rocketmq.grpc_protocol import (Address, AddressScheme, Code, Endpoints,
                                    MessagingServiceStub)
from rocketmq.v5.exception import (IllegalArgumentException,
                                   UnsupportedException)
from rocketmq.v5.log import logger


class RpcAddress:
    """Hashable, orderable wrapper around a protobuf ``Address`` (host + port).

    Equality, hashing and ordering are all based on the ``"host:port"`` string.
    """

    def __init__(self, address: Address):
        self.__host = address.host
        self.__port = address.port

    def __hash__(self) -> int:
        return hash(str(self))

    def __str__(self) -> str:
        return self.__host + ":" + str(self.__port)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, RpcAddress):
            return False
        return str(self) == str(other)

    def __lt__(self, other):
        # NotImplemented (rather than False) lets Python raise TypeError for a
        # meaningless comparison instead of silently yielding an arbitrary order.
        if not isinstance(other, RpcAddress):
            return NotImplemented
        return str(self) < str(other)

    def address0(self):
        """Build a new protobuf ``Address`` carrying this wrapper's host and port."""
        address = Address()
        address.host = self.__host
        address.port = self.__port
        return address


class RpcEndpoints:
    """Wrapper around a protobuf ``Endpoints`` message.

    The addresses are de-duplicated into a set of :class:`RpcAddress`. Two
    strings are derived once, at construction:

    * ``facade``: a scheme prefix (``dns:``, ``ipv4:`` or ``ipv6:``) followed by
      the sorted, comma-joined ``host:port`` list. :class:`RpcChannel` uses it
      as the gRPC channel target.
    * the description returned by ``str()``: the same list without the prefix.

    Both are empty strings when the scheme is unspecified or there are no
    addresses. Equality and hashing use the scheme and the facade.

    :raises UnsupportedException: if the scheme is ``DOMAIN_NAME`` and more
        than one distinct address is given.
    """

    def __init__(self, endpoints: Endpoints):
        self.__endpoints = endpoints
        self.__scheme = endpoints.scheme
        self.__addresses = {RpcAddress(address) for address in endpoints.addresses}
        if self.__scheme == AddressScheme.DOMAIN_NAME and len(self.__addresses) > 1:
            raise UnsupportedException(
                "Multiple addresses not allowed in domain schema"
            )
        self.__facade, self.__endpoint_desc = self.__build_facade()

    def __hash__(self) -> int:
        return hash(str(self.__scheme) + ":" + self.__facade)

    def __eq__(self, other):
        if not isinstance(other, RpcEndpoints):
            return False
        return self.__facade == other.__facade and self.__scheme == other.__scheme

    def __str__(self):
        return self.__endpoint_desc

    # private

    def __build_facade(self):
        """Return ``(facade, description)`` for the current scheme and addresses."""
        # Named differently from the ``__facade`` attribute: with the same
        # mangled name, the attribute set in ``__init__`` would shadow this method.
        if (
            self.__scheme is None
            or not self.__addresses
            or self.__scheme == AddressScheme.ADDRESS_SCHEME_UNSPECIFIED
        ):
            return "", ""
        prefix = "dns:"
        if self.__scheme == AddressScheme.IPv4:
            prefix = "ipv4:"
        elif self.__scheme == AddressScheme.IPv6:
            prefix = "ipv6:"
        # formatted as: ip:port,ip:port,ip:port (sorted, no spaces)
        desc = ",".join(str(address) for address in sorted(self.__addresses))
        return prefix + desc, desc

    # property

    @property
    def endpoints(self):
        return self.__endpoints

    @property
    def facade(self):
        return self.__facade


class RpcStreamStreamCall:
    """Reader/writer around a gRPC aio stream-stream (telemetry) call.

    Messages read from the server side of the stream are dispatched to
    ``handler``.
    """

    def __init__(self, endpoints: RpcEndpoints, stream_stream_call, handler):
        self.__endpoints = endpoints
        self.__stream_stream_call = stream_stream_call  # grpc stream_stream_call
        # handler responsible for handling data from the server side stream
        self.__handler = handler

    async def start_stream_read(self):
        """Read server messages until the stream ends or fails.

        Two kinds of server message are handled:

        * ``settings``: the result of a settings sync. On ``Code.OK``, the
          handler's ``reset_metric`` is called with ``settings.metric``.
        * ``recover_orphaned_transaction_command``: a server check for a
          transaction message. It is forwarded to the handler's
          ``on_recover_orphaned_transaction_command``.

        Reading stops quietly when the server closes its side of the stream.
        ``AioRpcError`` is logged as a warning and any other ``Exception`` as
        an error; neither is re-raised.
        """
        if self.__stream_stream_call is None:
            return
        try:
            while True:
                res = await self.__stream_stream_call.read()
                if res is aio.EOF:
                    # read() returns the EOF sentinel (not a message) once the
                    # server half-closes the stream; HasField() on it would raise.
                    break
                if res.HasField("settings"):
                    # read a response for send setting result
                    if res.status.code == Code.OK and self.__handler is not None:
                        logger.debug(
                            f"{self.__handler} sync setting success. response status code: {res.status.code}"
                        )
                        # reset metrics (protobuf sub-messages are never None)
                        self.__handler.reset_metric(res.settings.metric)
                elif res.HasField("recover_orphaned_transaction_command"):
                    # server check for a transaction message
                    if self.__handler is not None:
                        command = res.recover_orphaned_transaction_command
                        self.__handler.on_recover_orphaned_transaction_command(
                            self.__endpoints, command.message, command.transaction_id
                        )
        except AioRpcError as e:
            logger.warning(
                f"{self.__handler} read stream from endpoints {self.__endpoints} occurred AioRpcError. code: {e.code()}, message: {e.details()}"
            )
        except Exception as e:
            logger.error(
                f"{self.__handler} read stream from endpoints {self.__endpoints} exception, {e}"
            )

    async def stream_write(self, req):
        """Write ``req`` to the stream; errors from the call propagate to the caller.

        Does nothing when no call is attached.
        """
        if self.__stream_stream_call is not None:
            await self.__stream_stream_call.write(req)

    def close(self):
        """Cancel the underlying call, if any."""
        if self.__stream_stream_call is not None:
            self.__stream_stream_call.cancel()


class RpcChannel:
    """A gRPC aio channel and ``MessagingServiceStub`` for one :class:`RpcEndpoints`.

    Also holds at most one telemetry stream-stream call. ``update_time`` is
    initialised to the construction time in whole seconds since the epoch.
    """

    def __init__(self, endpoints: RpcEndpoints, tls_enabled=False):
        self.__async_channel = None
        self.__async_stub = None
        self.__telemetry_stream_stream_call = None
        self.__tls_enabled = tls_enabled
        self.__endpoints = endpoints
        self.__update_time = int(time.time())

    def create_channel(self, loop):
        """Create the gRPC channel and stub with ``loop`` as the current event loop.

        ``loop`` is installed for the calling thread via ``asyncio.set_event_loop``
        before the channel is created.
        """
        # assert loop == RpcClient._io_loop
        asyncio.set_event_loop(loop)
        self.__create_aio_channel()

    def close_channel(self, loop):
        """Close the telemetry stream and the gRPC channel, then reset state.

        The channel's ``close()`` coroutine is scheduled on ``loop`` with
        ``asyncio.run_coroutine_threadsafe`` and is not awaited here.
        Does nothing if no channel has been created.
        """
        if self.__async_channel is None:
            return
        # close stream_stream_call
        if self.__telemetry_stream_stream_call is not None:
            self.__telemetry_stream_stream_call.close()
            self.__telemetry_stream_stream_call = None
            logger.info(
                f"channel[{self.__endpoints}] close stream_stream_call success."
            )
        # Query the state without try_to_connect: asking an idle channel to
        # connect right before closing it would start a useless connection attempt.
        if self.channel_state(False) is not ChannelConnectivity.SHUTDOWN:
            # close grpc channel
            asyncio.run_coroutine_threadsafe(self.__async_channel.close(), loop)
            self.__async_channel = None
            logger.info(f"channel[{self.__endpoints}] close success.")
        self.__async_stub = None
        self.__endpoints = None
        self.__update_time = None

    def channel_state(self, wait_for_ready=True):
        """Return the channel's ``ChannelConnectivity`` state.

        ``wait_for_ready`` is forwarded as gRPC's ``try_to_connect`` argument:
        when true, an idle channel is asked to start connecting.
        """
        return self.__async_channel.get_state(wait_for_ready)

    def register_telemetry_stream_stream_call(self, stream_stream_call, handler):
        """Attach a new telemetry call, cancelling any previously registered one."""
        if self.__telemetry_stream_stream_call is not None:
            self.__telemetry_stream_stream_call.close()
        self.__telemetry_stream_stream_call = RpcStreamStreamCall(
            self.__endpoints, stream_stream_call, handler
        )

    # private

    def __create_aio_channel(self):
        """Create the aio channel (TLS or plaintext) and the messaging stub.

        Any exception is logged and re-raised.

        :raises IllegalArgumentException: if endpoints is None (as it is after
            ``close_channel``).
        """
        try:
            if self.__endpoints is None:
                raise IllegalArgumentException(
                    "create_aio_channel exception, endpoints is None"
                )
            options = [
                ("grpc.enable_retries", 0),
                # -1: no size limit on sent/received messages
                ("grpc.max_send_message_length", -1),
                ("grpc.max_receive_message_length", -1),
                ("grpc.use_local_subchannel_pool", 1),
            ]
            if self.__tls_enabled:
                self.__async_channel = aio.secure_channel(
                    self.__endpoints.facade,
                    grpc.ssl_channel_credentials(),
                    options=options,
                )
            else:
                self.__async_channel = aio.insecure_channel(
                    self.__endpoints.facade, options=options
                )
            self.__async_stub = MessagingServiceStub(self.__async_channel)
            logger.info(
                f"create_aio_channel to [{self.__endpoints}] success. channel state:{self.__async_channel.get_state()}"
            )
        except Exception as e:
            logger.error(
                f"create_aio_channel to [{self.__endpoints}] exception: {e}"
            )
            raise

    # property

    @property
    def async_stub(self):
        return self.__async_stub

    @property
    def telemetry_stream_stream_call(self):
        return self.__telemetry_stream_stream_call

    @property
    def update_time(self):
        return self.__update_time

    @update_time.setter
    def update_time(self, update_time):
        self.__update_time = update_time