elastic_agent_client/service/checkin.py (158 lines of code) (raw):

# # Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one # or more contributor license agreements. Licensed under the Elastic License 2.0; # you may not use this file except in compliance with the Elastic License 2.0. # import asyncio import functools import logging from asyncio import CancelledError, Task, sleep from typing import Any, Optional import elastic_agent_client.generated.elastic_agent_client_pb2 as proto from elastic_agent_client.client import V2 from elastic_agent_client.handler.checkin import BaseCheckinHandler from elastic_agent_client.util.async_tools import AsyncQueueIterator, BaseService from elastic_agent_client.util.logger import convert_agent_log_level, logger, set_logger class CheckinV2Service(BaseService): name = "checkinV2" CHECKIN_INTERVAL = 5 def __init__(self, client: V2, checkin_handler: BaseCheckinHandler): super().__init__(client, "checkinV2") logger.debug(f"Initializing the {self.name} service") self.client = client self.checkin_handler = checkin_handler self._send_checkins_task: Optional[Task[Any]] = None self._receive_checkins_task: Optional[Task[Any]] = None def stop(self): super().stop() if self._send_checkins_task: logger.info(f"Cancelling task: {self._send_checkins_task.get_name()}") self._send_checkins_task.cancel() if self._receive_checkins_task: logger.info(f"Cancelling task: {self._receive_checkins_task.get_name()}") self._receive_checkins_task.cancel() async def _run(self): logger.info(f"Starting {self.name} service") if self.client.client is None: msg = "gRPC client is not yet set" raise RuntimeError(msg) send_queue: asyncio.Queue = asyncio.Queue() checkin_stream = self.client.client.CheckinV2(AsyncQueueIterator(send_queue)) send_checkins_task = asyncio.create_task( self.send_checkins(send_queue), name="Checkin Writer" ) receive_checkins_task = asyncio.create_task( self.receive_checkins(checkin_stream), name="Checkin Reader" ) send_checkins_task.add_done_callback(functools.partial(self._callback)) receive_checkins_task.add_done_callback(functools.partial(self._callback)) self._send_checkins_task = send_checkins_task self._receive_checkins_task = receive_checkins_task logger.debug(f"Running {self.name} service loop") done, pending = await asyncio.wait( [send_checkins_task, receive_checkins_task], return_when=asyncio.FIRST_EXCEPTION, ) for task in pending: task.cancel() try: await task except CancelledError: logger.error("Task did not handle cancellation gracefully") # Separated these two to log all errors if both tasks error out together # Which is unlikely, but it's cheap to do it the way I did for task in done: if not task.cancelled() and task.exception(): logger.error( f"Task {task.get_name()} terminated due to exception:", exc_info=task.exception(), ) for task in done: if not task.cancelled(): task_exception = task.exception() if task_exception: raise task_exception async def send_checkins(self, send_queue): while self.running: if send_queue.empty(): await self.do_checkin(send_queue) # Sleep if still running if self.running: await sleep(self.CHECKIN_INTERVAL) async def receive_checkins(self, checkin_stream): checkin: proto.CheckinExpected logger.info(f"{self.name} service is listening for check-in events") async for checkin in checkin_stream: logger.debug("Received a check-in event from CheckinV2 stream") await self.apply_expected(checkin) async def apply_expected(self, checkin: proto.CheckinExpected): if self.client.units and self.client.component_idx == checkin.component_idx: change_detected = False expected_units = [ (unit.id, unit.config_state_idx, unit.log_level) for unit in checkin.units ] current_units = [ (unit.id, unit.config_idx, unit.log_level) for unit in self.client.units ] for current_unit in current_units: if current_unit not in expected_units: change_detected = True break for expected_unit in expected_units: if expected_unit not in current_units: change_detected = True break if not change_detected: logger.debug("No change detected") return logger.debug("Detected change in units") self.client.agent_info = proto.AgentInfo( id=checkin.agent_info.id, version=checkin.agent_info.version, snapshot=checkin.agent_info.snapshot, ) self.client.sync_component(checkin) self.client.sync_units(checkin) logger.debug("Calling apply_from_client with new units") self.pre_process_units() await self.checkin_handler.apply_from_client() def pre_process_units(self): logger.debug("Pre-processing units") if self.client.units is None: logger.debug("No units found") return outputs = [ unit for unit in self.client.units if unit.unit_type == proto.UnitType.OUTPUT ] if len(outputs): unit = outputs[0] log_level = unit.log_level if log_level: # Convert the UnitLogLevel to the corresponding Python logging level python_log_level = convert_agent_log_level(log_level) logger.info( f"Updating log level to {logging.getLevelName(python_log_level)}" ) set_logger(log_level=python_log_level) else: logger.debug("No log level found for the output unit") else: logger.info("No outputs found") async def do_checkin(self, send_queue): if self.client.units is None: return logger.debug("Doing a check-in") units_observed = [unit.to_observed() for unit in self.client.units] if not self.client.version_info_sent and self.client.version_info: version_info = proto.CheckinObservedVersionInfo( name=self.client.version_info.name, meta=self.client.version_info.meta, build_hash=self.client.version_info.build_hash, ) else: version_info = None supports: list[str] = [] msg = proto.CheckinObserved( token=self.client.token, units=units_observed, version_info=version_info, features_idx=self.client.features_idx, component_idx=self.client.component_idx, supports=supports, ) await send_queue.put(msg) self.client.version_info_sent = True