test-runner/adapters/direct_azure_rest/direct_eventhub_api.py (173 lines of code) (raw):
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE file in the project root for
# full license information.
import ast
import asyncio
import datetime
import json
import threading

from azure.eventhub.aio import EventHubConsumerClient

from ..adapter_config import logger
from . import eventhub_connection_string
def _reject_non_standard_json_constant(name):
    """Make json.loads refuse NaN/Infinity/-Infinity (not part of standard JSON)."""
    raise ValueError("non-standard JSON constant: {}".format(name))


def _parse_if_str(value):
    """If ``value`` is a str holding JSON or a Python literal, return the parsed object.

    JSON is tried first so that ``true``/``false``/``null`` are understood.
    ``ast.literal_eval`` is the fallback for single-quote-delimited text
    (Python repr style), which ``json.loads`` rejects. A string that parses
    as neither is returned unchanged; non-str values are returned as-is.
    """
    if not isinstance(value, str):
        return value
    try:
        return json.loads(value, parse_constant=_reject_non_standard_json_constant)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError subclass
        pass
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError, RecursionError):
        # TypeError covers e.g. unhashable dict keys such as "{[1]: 2}";
        # in every case the text is not a literal, so compare it as a string.
        return value


def json_is_same(a, b):
    """Return True if ``a`` and ``b`` represent the same value.

    Either argument may be an already-parsed object or a string containing
    JSON or a Python literal (e.g. single-quoted dict text); strings are
    parsed before comparison, and unparseable strings are compared verbatim.
    """
    return _parse_if_str(a) == _parse_if_str(b)
def get_device_id_from_event(event):
    """Return the ``iothub-connection-device-id`` annotation of ``event`` as a str.

    The annotation is looked up by its bytes key on ``event.message.annotations``
    and decoded; a missing annotation raises KeyError.
    """
    annotation_key = "iothub-connection-device-id".encode()
    raw_device_id = event.message.annotations[annotation_key]
    return raw_device_id.decode()
class EventHubApi:
    """Listen for events on the Event Hub endpoint derived from an IoT Hub connection string.

    A background listener task (started by ``connect``/``start_new_listener``)
    puts every received event onto ``self.received_events``;
    ``wait_for_next_event`` takes events off that queue and filters them.
    """

    def __init__(self):
        self.consumer_client = None  # EventHubConsumerClient, created in connect()
        self.iothub_connection_string = None
        self.eventhub_connection_string = None
        self.received_events = None  # asyncio.Queue of events, created in connect()
        self.listener_future = None  # future running the listener() coroutine
        # Where the next listener starts receiving: either a single position
        # (e.g. a datetime) or a dict of partition_id -> position that is
        # updated as events arrive.
        self.starting_position = None

    async def create_from_connection_string(self, connection_string):
        """Store the IoT Hub connection string and derive the Event Hub one from it."""
        self.iothub_connection_string = connection_string
        self.eventhub_connection_string = eventhub_connection_string.convert_iothub_to_eventhub_conn_str(
            connection_string
        )

    async def connect(self, starting_position=None):
        """Create the consumer client and start the background listener.

        :param starting_position: where to start receiving. If falsy, defaults
            to 10 seconds before the current UTC time.
        """
        logger(
            "EventHubApi: connect: thread={} {} loop={}".format(
                threading.current_thread(),
                id(threading.current_thread()),
                id(asyncio.get_running_loop()),
            )
        )
        # timezone-aware so the instant is unambiguous regardless of how the
        # SDK converts datetimes (a naive value could be read as local time)
        self.starting_position = starting_position or (
            datetime.datetime.now(datetime.timezone.utc)
            - datetime.timedelta(seconds=10)
        )
        self.received_events = asyncio.Queue()
        # Create a consumer client for the event hub.
        self.consumer_client = EventHubConsumerClient.from_connection_string(
            self.eventhub_connection_string, consumer_group="$Default"
        )
        await self.start_new_listener()

    @staticmethod
    async def _cancel_and_wait(future):
        """Cancel ``future`` and wait for it to finish.

        Cancellation is expected and ignored. Any other exception the listener
        ended with is logged rather than re-raised, so that restarting or
        closing is not blocked by an earlier listener failure (listener()
        already logged that failure when it happened).
        """
        future.cancel()
        try:
            await future
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger("EventHubApi: listener ended with exception: {}".format(e))

    async def start_new_listener(self):
        """(Re)start the background task that receives events into the queue.

        Any existing listener is cancelled first. The new listener starts from
        ``self.starting_position`` (see listener() below).
        """
        if self.listener_future:
            logger("EventHubApi: cancelling old listener")
            await self._cancel_and_wait(self.listener_future)
            self.listener_future = None

        async def on_event(partition_context, event):
            # this receives all events. they get filtered by device_id (if necessary)
            # when pulled from the queue
            await self.received_events.put(event)
            await partition_context.update_checkpoint(event)
            # Record this partition's latest offset so a restarted listener
            # starts here. listener() makes self.starting_position a dict
            # before receive() is called.
            self.starting_position[partition_context.partition_id] = event.offset

        async def on_error(partition_context, error):
            if partition_context:
                logger(
                    "EventHubApi: An exception: {} occurred during receiving from Partition: {}.".format(
                        error, partition_context.partition_id
                    )
                )
            else:
                logger(
                    "EventHubApi: An exception: {} occurred during the load balance process.".format(
                        error
                    )
                )

        async def on_partition_initialize(partition_context):
            logger(
                "EventHubApi: Partition: {} has been initialized.".format(
                    partition_context.partition_id
                )
            )

        async def on_partition_close(partition_context, reason):
            logger(
                "EventHubApi: Partition: {} has been closed, reason for closing: {}.".format(
                    partition_context.partition_id, reason
                )
            )

        async def get_current_position():
            # partition_id -> last enqueued sequence number, or "-1" when that
            # property is missing or falsy
            positions = {}
            partition_ids = await self.consumer_client.get_partition_ids()
            for partition_id in partition_ids:
                properties = await self.consumer_client.get_partition_properties(
                    partition_id
                )
                positions[partition_id] = (
                    properties.get("last_enqueued_sequence_number") or "-1"
                )
            return positions

        async def listener():
            try:
                if not self.starting_position:
                    # if we don't have a starting position, start at the current one and
                    # save it so we can update it as we receive events.
                    starting_position = await get_current_position()
                    self.starting_position = starting_position
                elif isinstance(self.starting_position, dict):
                    # if our starting position is a dict, use it and keep updating it as we
                    # receive events.
                    starting_position = self.starting_position
                else:
                    # if we do have a starting position, but it's not a dict, use it and get
                    # the current position so we can update it as events come in.
                    starting_position = self.starting_position
                    self.starting_position = await get_current_position()
                logger("EventHubApi: listening at {}".format(starting_position))
                logger(
                    "EventHubApi: next starting position = {}".format(
                        self.starting_position
                    )
                )
                await self.consumer_client.receive(
                    on_event=on_event,
                    starting_position=starting_position,
                    starting_position_inclusive=True,
                    on_error=on_error,
                    on_partition_initialize=on_partition_initialize,
                    on_partition_close=on_partition_close,
                )
            except Exception as e:
                logger("EventHubApi exception: {}".format(e))
                raise

        self.listener_future = asyncio.ensure_future(listener())
        logger("EventHubApi: Listener Created")

    async def _close_eventhub_client(self):
        """Close the consumer client, then cancel and wait for the listener task."""
        logger(
            "EventHubApi: close: thread={} {} loop={}".format(
                threading.current_thread(),
                id(threading.current_thread()),
                id(asyncio.get_running_loop()),
            )
        )
        if self.consumer_client:
            logger("EventHubApi: _close_eventhub_client: stopping consumer client")
            await self.consumer_client.close()
            logger("EventHubApi: _close_eventhub_client: done stopping consumer client")
            self.consumer_client = None
        if self.listener_future:
            logger("EventHubApi: _close_eventhub_client: cancelling listener")
            logger(
                "EventHubApi: _close_eventhub_client: waiting for listener to complete"
            )
            await self._cancel_and_wait(self.listener_future)
            self.listener_future = None
            logger("EventHubApi: _close_eventhub_client: listener is complete")

    async def disconnect(self):
        """Close the consumer client and stop the listener."""
        logger("EventHubApi: async disconnect")
        await self._close_eventhub_client()

    @staticmethod
    def _get_event_body(event):
        """Return the event body parsed as JSON, or as a str if it is not JSON."""
        try:
            return event.body_as_json()
        except TypeError:
            return event.body_as_str()

    async def wait_for_next_event(self, device_id, expected=None):
        """Wait for and return the body of the next matching event.

        :param device_id: if falsy, the next event of any origin is returned.
            Otherwise events whose ``iothub-connection-device-id`` annotation
            differs (or is absent) are discarded.
        :param expected: if not None, events from ``device_id`` whose body is
            not ``json_is_same`` as ``expected`` are discarded.
        :returns: the event body, parsed as JSON when possible, else a str.
        """
        while True:
            event = await self.received_events.get()
            if not device_id:
                logger("EventHubAPI: body = {}".format(event.body_as_str()))
                return self._get_event_body(event)

            try:
                event_device_id = get_device_id_from_event(event)
            except KeyError:
                # no device-id annotation, so it cannot be the event we want
                continue
            if event_device_id != device_id:
                continue

            logger("EventHubApi: received event: {}".format(event))
            received = self._get_event_body(event)
            if expected is None:
                return received
            if json_is_same(expected, received):
                logger("EventHubApi: message received as expected")
                return received
            logger("EventHubApi: unexpected message. skipping")