fbnet/command_runner/service.py (130 lines of code) (raw):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import signal
import typing
from concurrent.futures import ThreadPoolExecutor

from fbnet.command_runner_asyncio.CommandRunner.Command import Client as FcrClient

from .base_service import ServiceObjMeta, ServiceTask
from .command_server import CommandServer
from .command_session import CommandSession
from .exceptions import (
    LookupErrorException,
    ValueErrorException,
    NotImplementedErrorException,
)
from .options import Option
from .thrift_client import AsyncioThriftClient
from .utils import IPUtils


class FcrServiceBase:
    """
    Main Application object.

    This manages application resources and provides a common orchestration
    point for the application modules.
    """

    # NOTE(review): the Option attributes below are read through ``self`` as
    # plain values (bool / str / int) further down, so Option presumably
    # behaves as a descriptor that resolves to the parsed command-line value
    # -- confirm against .options.
    ASYNCIO_DEBUG = Option(
        "--asyncio_debug",
        help="turn on debug for asyncio",
        action="store_true",
        default=False,
    )

    LOG_LEVEL = Option(
        "--log_level",
        help="logging level",
        choices=["debug", "info", "warning", "error", "critical"],
        default="info",
    )

    MAX_DEFAULT_EXECUTOR_THREADS = Option(
        "--max_default_executor_threads",
        help="Max number of worker threads",
        type=int,
        default=10,
    )

    EXIT_MAX_WAIT = Option(
        "--exit_max_wait",
        help="Max time (seconds) to wait for session to terminate",
        type=int,
        default=300,
    )

    def __init__(self, app_name, args=None, loop=None):
        """
        Set up the event loop, default executor, logging and signal handlers.

        Args:
            app_name: name of the application; also used as the logger name.
            args: arguments handed to ``Option.parse_args``.
            loop: event loop to run on; when omitted the loop returned by
                ``asyncio.get_event_loop()`` is used.
        """
        self._app_name = app_name
        self._shutting_down = False
        self._stats_mgr = None
        # Strong reference to the task running _clean_shutdown (see shutdown()).
        self._shutdown_task = None

        Option.parse_args(args)

        self._loop = loop or asyncio.get_event_loop()
        self._loop.set_debug(self.ASYNCIO_DEBUG)

        executor = ThreadPoolExecutor(max_workers=self.MAX_DEFAULT_EXECUTOR_THREADS)
        self._loop.set_default_executor(executor)

        self._init_logging()

        # Both signals start a clean shutdown; a second signal forces
        # termination (see shutdown()).
        self._loop.add_signal_handler(signal.SIGINT, self.shutdown)
        self._loop.add_signal_handler(signal.SIGTERM, self.shutdown)

        self._tasks = {}

        self.logger = logging.getLogger(self._app_name)

    def register_stats_mgr(self, stats_mgr):
        """Store the stats manager and register all service counters with it."""
        self.logger.info("Registering Counter manager")
        self._stats_mgr = stats_mgr
        ServiceObjMeta.register_all_counters(stats_mgr)

    @property
    def stats_mgr(self):
        """The registered stats manager, or None if none was registered."""
        return self._stats_mgr

    def incrementCounter(self, counter):
        """Increment `counter` on the registered stats manager."""
        self._stats_mgr.incrementCounter(counter)

    @property
    def config(self):
        """The configuration object exposed by ``Option.config``."""
        return Option.config

    @property
    def app_name(self):
        """Application name given at construction time."""
        return self._app_name

    @property
    def loop(self):
        """Event loop this service runs on."""
        return self._loop

    @property
    def ip_utils(self) -> typing.Type[IPUtils]:
        """The IPUtils class (not an instance); override to customise."""
        return IPUtils

    def add_task(self, key, task):
        """
        Record `task` under `key`.

        Raises:
            LookupErrorException: if `key` is already registered.
        """
        if key in self._tasks:
            raise LookupErrorException(f"Duplicated key: {key}")
        self._tasks[key] = task

    def start(self):
        """
        Run the event loop until it is stopped, then clean up.

        Once the loop stops, every task still pending on it is cancelled and
        awaited (exceptions are collected, not raised) before the loop is
        closed.
        """
        try:
            self._loop.run_forever()
        finally:
            pending_tasks = asyncio.all_tasks(loop=self._loop)
            for task in pending_tasks:
                task.cancel()
            # Give the cancelled tasks a chance to unwind before closing.
            self._loop.run_until_complete(
                asyncio.gather(*pending_tasks, return_exceptions=True)
            )
            self._loop.close()

    async def _clean_shutdown(self):
        """
        Wait up to EXIT_MAX_WAIT seconds for sessions, then terminate.

        terminate() is always called, whether the wait finished, timed out
        or failed.
        """
        try:
            coro = CommandSession.wait_sessions("Shutdown", service=self)
            # No ``loop=`` argument: it was removed from asyncio.wait_for in
            # Python 3.10 and raised TypeError there. This coroutine already
            # runs on self.loop, so the running loop is the right one.
            await asyncio.wait_for(coro, timeout=self.EXIT_MAX_WAIT)
        except asyncio.TimeoutError:
            self.logger.error("Timeout waiting for sessions, shutting down anyway")
        finally:
            self.terminate()

    def terminate(self):
        """
        Terminate the application. We cancel all the tasks that are currently
        active
        """
        self.logger.info("Terminating")
        pending_tasks = asyncio.all_tasks(loop=self.loop)
        for t in pending_tasks:
            t.cancel()
        self.loop.stop()

    def shutdown(self):
        """
        Initiate a clean shutdown.

        The first call cancels every ServiceTask and schedules
        _clean_shutdown on the loop. Any further call (e.g. a second signal)
        terminates immediately.
        """
        if not self._shutting_down:
            self._shutting_down = True

            for name, task in ServiceTask.all_tasks():
                self.logger.info("Stopping: %s", name)
                task.cancel()

            # Keep a reference: asyncio documents that an unreferenced task
            # may be garbage-collected before it finishes.
            self._shutdown_task = asyncio.ensure_future(
                self._clean_shutdown(), loop=self.loop
            )
        else:
            # Forcibly shutdown.
            self.terminate()

    def _init_logging(self):
        """
        Configure root logging from the LOG_LEVEL option.

        Raises:
            ValueErrorException: if LOG_LEVEL does not name an integer level
                in the ``logging`` module.
        """
        level = getattr(logging, self.LOG_LEVEL.upper(), None)

        if not isinstance(level, int):
            raise ValueErrorException("Invalid log level: %s" % self.LOG_LEVEL)

        logging.basicConfig(level=level)

    def decrypt(self, data):
        """helper method to decrypt data.

        The default implementation doesn't do anything. Override this method
        to implement security according to your needs
        """
        return data

    async def get_fcr_client(self, timeout=None):
        """
        Get a FCR client for your service. This client is used to distribute
        requests for bulk calls
        """
        return AsyncioThriftClient(
            FcrClient, "localhost", CommandServer.PORT, service=self, timeout=timeout
        )

    def check_ip(self, ipaddr):
        """
        Check if ip address is usable.

        You will likely need to override this function to implement the ip
        validation logic. For eg. a service could periodically check what ip
        addresses are reachable. The application can then use this data to
        filter out non-reachable addresses.

        The default implementation assumes that everything is reachable
        """
        return True

    def get_http_proxy_url(self, host):
        """build a url for http proxy"""
        raise NotImplementedErrorException("Proxy support not implemented")