test-runner/adapters/direct_azure_rest/direct_eventhub_api.py (173 lines of code) (raw):

# Copyright (c) Microsoft. All rights reserved. # Licensed under the MIT license. See LICENSE file in the project root for # full license information. import asyncio import ast import datetime import threading from azure.eventhub.aio import EventHubConsumerClient from ..adapter_config import logger from . import eventhub_connection_string def json_is_same(a, b): # If either parameter is a string, convert it to an object. # use ast.literal_eval because they might be single-quote delimited which fails with json.loads. # if ast.literal_eval raises a ValueError, leave it as a string -- it must not be json after all. if isinstance(a, str): try: a = ast.literal_eval(a) except (ValueError, SyntaxError): pass if isinstance(b, str): try: b = ast.literal_eval(b) except (ValueError, SyntaxError): pass return a == b def get_device_id_from_event(event): return event.message.annotations["iothub-connection-device-id".encode()].decode() class EventHubApi: def __init__(self): self.consumer_client = None self.iothub_connection_string = None self.eventhub_connection_string = None self.received_events = None self.listener_future = None self.starting_position = None async def create_from_connection_string(self, connection_string): self.iothub_connection_string = connection_string self.eventhub_connection_string = eventhub_connection_string.convert_iothub_to_eventhub_conn_str( connection_string ) async def connect(self, starting_position=None): logger( "EventHubApi: connect: thread={} {} loop={}".format( threading.current_thread(), id(threading.current_thread()), id(asyncio.get_running_loop()), ) ) self.starting_position = starting_position or ( datetime.datetime.utcnow() - datetime.timedelta(seconds=10) ) self.received_events = asyncio.Queue() # Create a consumer client for the event hub. self.consumer_client = EventHubConsumerClient.from_connection_string( self.eventhub_connection_string, consumer_group="$Default" ) await self.start_new_listener() async def start_new_listener(self): if self.listener_future: logger("EventHubApi: cancelling old listener") self.listener_future.cancel() try: await self.listener_future except asyncio.CancelledError: pass self.listener_future = None async def on_event(partition_context, event): # this receives all events. they get filtered by device_id (if necessary) when # pulled from the queue await self.received_events.put(event) await partition_context.update_checkpoint(event) self.starting_position[partition_context.partition_id] = event.offset async def on_error(partition_context, error): if partition_context: logger( "EventHubApi: An exception: {} occurred during receiving from Partition: {}.".format( partition_context.partition_id, error ) ) else: logger( "EventHubApi: An exception: {} occurred during the load balance process.".format( error ) ) async def on_partition_initialize(partition_context): logger( "EventHubApi: Partition: {} has been initialized.".format( partition_context.partition_id ) ) async def on_partition_close(partition_context, reason): logger( "EventHubApi: Partition: {} has been closed, reason for closing: {}.".format( partition_context.partition_id, reason ) ) async def get_current_position(): positions = {} ids = await self.consumer_client.get_partition_ids() for id in ids: properties = await self.consumer_client.get_partition_properties(id) positions[id] = properties.get("last_enqueued_sequence_number") or "-1" return positions async def listener(): try: if not self.starting_position: # if we don't have a starting position, start at the current one and # save it so we can update it as we receive events. starting_position = await get_current_position() self.starting_position = starting_position elif isinstance(self.starting_position, dict): # if our starting position is a dict, use it and keep updating it as we # receive events. starting_position = self.starting_position else: # if we do have a starting position, but it's not a dict, use it and get # the current position so we can update it as events come in. starting_position = self.starting_position self.starting_position = await get_current_position() logger("EventHubApi: listening at {}".format(starting_position)) logger( "EventHubApi: next starting position = {}".format( self.starting_position ) ) await self.consumer_client.receive( on_event=on_event, starting_position=starting_position, starting_position_inclusive=True, on_error=on_error, on_partition_initialize=on_partition_initialize, on_partition_close=on_partition_close, ) except Exception as e: logger("EventHubApi exception: {}".format(e)) raise self.listener_future = asyncio.ensure_future(listener()) logger("EventHubApi: Listener Created") async def _close_eventhub_client(self): logger( "EventHubApi: close: thread={} {} loop={}".format( threading.current_thread(), id(threading.current_thread()), id(asyncio.get_running_loop()), ) ) if self.consumer_client: logger("EventHubApi: _close_eventhub_client: stopping consumer client") await self.consumer_client.close() logger("EventHubApi: _close_eventhub_client: done stopping consumer client") self.consumer_client = None if self.listener_future: logger("EventHubApi: _close_eventhub_client: cancelling listener") self.listener_future.cancel() logger( "EventHubApi: _close_eventhub_client: waiting for listener to complete" ) try: await self.listener_future except asyncio.CancelledError: pass logger("EventHubApi: _close_eventhub_client: listener is complete") async def disconnect(self): logger("EventHubApi: async disconnect") await self._close_eventhub_client() async def wait_for_next_event(self, device_id, expected=None): # logger("EventHubApi: waiting for next event for {}".format(device_id)) while True: event = await self.received_events.get() if not device_id: logger("EventHubAPI: body = {}".format(event.body_as_str())) try: return event.body_as_json() except TypeError: return event.body_as_str() elif get_device_id_from_event(event) == device_id: logger("EventHubApi: received event: {}".format(event)) received = event.body_as_json() if expected is not None: if json_is_same(expected, received): logger("EventHubApi: message received as expected") return received else: logger("EventHubApi: unexpected message. skipping") else: return received else: pass # logger("EventHubApi: event not for me received")