elastic_agent_client/util/async_tools.py (128 lines of code) (raw):

# # Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. # import asyncio import time from elastic_agent_client.util.logger import logger class CancellableSleeps: def __init__(self): self._sleeps = set() async def sleep(self, delay, result=None, *, loop=None): async def _sleep(delay, result=None, *, loop=None): coro = asyncio.sleep(delay, result=result) task = asyncio.ensure_future(coro) self._sleeps.add(task) try: return await task except asyncio.CancelledError: logger.debug("Sleep canceled") return result finally: self._sleeps.remove(task) await _sleep(delay, result=result, loop=loop) def cancel(self, sig=None): if sig: logger.debug(f"Caught {sig}. Cancelling sleeps...") else: logger.debug("Cancelling sleeps...") for task in self._sleeps: task.cancel() sleeps_for_retryable = CancellableSleeps() def _get_uvloop(): import uvloop return uvloop def get_event_loop(): # activate uvloop if lib is present try: asyncio.set_event_loop_policy(_get_uvloop().EventLoopPolicy()) except Exception as e: logger.warning(f"Unable to enable uvloop: {e}. Running with default event loop") pass try: loop = asyncio.get_running_loop() except RuntimeError: loop = asyncio.get_event_loop_policy().get_event_loop() if loop is None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) return loop class BaseService: """Base class for creating a service. Any class deriving from this class will get added to the registry, given its `name` class attribute (unless it's not set). A concrete service class needs to implement `_run`. """ name: str def __init__(self, client, service_name): self.running = False self._sleeps = CancellableSleeps() self.errors = [0, time.time()] def stop(self): self.running = False self._sleeps.cancel() async def _run(self): raise NotImplementedError() async def run(self): """Runs the service""" if self.running: msg = f"{self.__class__.__name__} is already running." raise Exception(msg) self.running = True try: await self._run() finally: self.stop() def _callback(self, task): if task.cancelled(): logger.error( f"Task {task.get_name()} was cancelled", ) elif task.exception(): logger.exception( f"Exception found for task {task.get_name()}: {task.exception()}", exc_info=task.exception(), ) class MultiService: """Wrapper class to run multiple services against the same client.""" def __init__(self, *services): self._services = services async def run(self): """Runs every service in a task and wait for all tasks.""" tasks = [ asyncio.create_task(service.run(), name=service.name) for service in self._services ] done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) exception = None for task in done: if task.done() and not task.cancelled(): if task.exception(): logger.exception( f"Exception found for task {task.get_name()}: {task.exception()}", exc_info=task.exception(), ) exception = task.exception() for task in pending: task.cancel() try: await task except asyncio.CancelledError: logger.error("Service did not handle cancellation gracefully.") if exception: raise exception def shutdown(self, sig): logger.info(f"Caught {sig}. Graceful shutdown.") for service in self._services: logger.debug(f"Shutting down {service.__class__.__name__}...") service.stop() logger.debug(f"Done shutting down {service.__class__.__name__}...") class AsyncQueueIterator: def __init__(self, queue): self.queue = queue def __aiter__(self): return self async def __anext__(self): try: item = await self.queue.get() except Exception as e: raise StopAsyncIteration() from e else: return item class AsyncIterator: def __init__(self, seq): self.iter = iter(seq) def __aiter__(self): return self async def __anext__(self): try: return next(self.iter) except StopIteration as ex: raise StopAsyncIteration from ex