fbnet/command_runner/service.py (130 lines of code) (raw):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import logging
import signal
import typing
from concurrent.futures import ThreadPoolExecutor
from fbnet.command_runner_asyncio.CommandRunner.Command import Client as FcrClient
from .base_service import ServiceObjMeta, ServiceTask
from .command_server import CommandServer
from .command_session import CommandSession
from .exceptions import (
LookupErrorException,
ValueErrorException,
NotImplementedErrorException,
)
from .options import Option
from .thrift_client import AsyncioThriftClient
from .utils import IPUtils
class FcrServiceBase:
    """
    Main Application object.

    This manages application resources and provides a common orchestration
    point for the application modules.
    """

    # Command-line options. NOTE(review): the way these are read below (e.g.
    # ``self.ASYNCIO_DEBUG`` passed straight to ``loop.set_debug``) suggests
    # that ``Option`` resolves to the parsed value on attribute access --
    # confirm against ``options.Option``.
    ASYNCIO_DEBUG = Option(
        "--asyncio_debug",
        help="turn on debug for asyncio",
        action="store_true",
        default=False,
    )

    LOG_LEVEL = Option(
        "--log_level",
        help="logging level",
        choices=["debug", "info", "warning", "error", "critical"],
        default="info",
    )

    MAX_DEFAULT_EXECUTOR_THREADS = Option(
        "--max_default_executor_threads",
        help="Max number of worker threads",
        type=int,
        default=10,
    )

    EXIT_MAX_WAIT = Option(
        "--exit_max_wait",
        help="Max time (seconds) to wait for session to terminate",
        type=int,
        default=300,
    )

    def __init__(self, app_name, args=None, loop=None):
        """
        Parse options and prepare the event loop for the service.

        Args:
            app_name: name of the application; also used as the logger name.
            args: forwarded to ``Option.parse_args``.
            loop: event loop to run on; when omitted (or falsy) the result of
                ``asyncio.get_event_loop()`` is used.
        """
        self._app_name = app_name
        self._shutting_down = False
        self._stats_mgr = None

        Option.parse_args(args)

        self._loop = loop or asyncio.get_event_loop()
        self._loop.set_debug(self.ASYNCIO_DEBUG)

        # Work handed to the loop's default executor runs on this pool.
        executor = ThreadPoolExecutor(max_workers=self.MAX_DEFAULT_EXECUTOR_THREADS)
        self._loop.set_default_executor(executor)

        self._init_logging()

        # SIGINT / SIGTERM start a clean shutdown; a second signal while
        # shutting down forces termination (see shutdown()).
        self._loop.add_signal_handler(signal.SIGINT, self.shutdown)
        self._loop.add_signal_handler(signal.SIGTERM, self.shutdown)

        self._tasks = {}
        self.logger = logging.getLogger(self._app_name)

    def register_stats_mgr(self, stats_mgr):
        """
        Store the stats manager and register all service counters with it.
        """
        self.logger.info("Registering Counter manager")
        self._stats_mgr = stats_mgr
        ServiceObjMeta.register_all_counters(stats_mgr)

    @property
    def stats_mgr(self):
        """The registered stats manager, or None if none was registered."""
        return self._stats_mgr

    def incrementCounter(self, counter):
        """
        Increment `counter` on the registered stats manager.

        A stats manager must have been registered with register_stats_mgr()
        first; until then the manager is None and this call fails.
        """
        self._stats_mgr.incrementCounter(counter)

    @property
    def config(self):
        """The configuration exposed by ``Option.config``."""
        return Option.config

    @property
    def app_name(self):
        """Name of the application, as passed to the constructor."""
        return self._app_name

    @property
    def loop(self):
        """The event loop this service runs on."""
        return self._loop

    @property
    def ip_utils(self) -> typing.Type[IPUtils]:
        """The ``IPUtils`` class itself (not an instance)."""
        return IPUtils

    def add_task(self, key, task):
        """
        Record `task` under `key`.

        Raises:
            LookupErrorException: if `key` is already in use.
        """
        if key in self._tasks:
            raise LookupErrorException(f"Duplicated key: {key}")
        self._tasks[key] = task

    def start(self):
        """
        Run the event loop until it is stopped, then cancel and drain any
        tasks that are still pending and close the loop.
        """
        try:
            self._loop.run_forever()
        finally:
            pending_tasks = asyncio.all_tasks(loop=self._loop)
            for task in pending_tasks:
                task.cancel()

            # Only drain when there is something to wait for: gather() with
            # no awaitables creates its result future on the *current* event
            # loop, which is not necessarily self._loop, and
            # run_until_complete() rejects a future that belongs to another
            # loop.
            if pending_tasks:
                self._loop.run_until_complete(
                    asyncio.gather(*pending_tasks, return_exceptions=True)
                )
            self._loop.close()

    async def _clean_shutdown(self):
        """
        Wait up to EXIT_MAX_WAIT seconds for sessions to finish, then
        terminate the application whether or not the wait succeeded.
        """
        try:
            coro = CommandSession.wait_sessions("Shutdown", service=self)
            # Do not pass ``loop=`` here: the parameter was removed from
            # asyncio.wait_for() in Python 3.10 and would raise TypeError,
            # skipping the wait entirely. This coroutine already runs on the
            # service loop, which wait_for() picks up on its own.
            await asyncio.wait_for(coro, timeout=self.EXIT_MAX_WAIT)
        except asyncio.TimeoutError:
            self.logger.error("Timeout waiting for sessions, shutting down anyway")
        finally:
            self.terminate()

    def terminate(self):
        """
        Terminate the application. We cancel all the tasks that are currently active
        """
        self.logger.info("Terminating")
        pending_tasks = asyncio.all_tasks(loop=self.loop)
        for t in pending_tasks:
            t.cancel()
        self.loop.stop()

    def shutdown(self):
        """initiate a clean shutdown"""
        if not self._shutting_down:
            self._shutting_down = True

            for name, task in ServiceTask.all_tasks():
                self.logger.info("Stopping: %s", name)
                task.cancel()

            asyncio.ensure_future(self._clean_shutdown(), loop=self.loop)
        else:
            # Forcibly shutdown.
            self.terminate()

    def _init_logging(self):
        """
        Configure the root logger from the LOG_LEVEL option.

        Raises:
            ValueErrorException: if LOG_LEVEL does not name an integer level
                constant of the ``logging`` module.
        """
        level = getattr(logging, self.LOG_LEVEL.upper(), None)
        if not isinstance(level, int):
            raise ValueErrorException("Invalid log level: %s" % self.LOG_LEVEL)
        logging.basicConfig(level=level)

    def decrypt(self, data):
        """helper method to decrypt data.

        The default implementation doesn't do anything. Override this method to
        implement security according to your needs
        """
        return data

    async def get_fcr_client(self, timeout=None):
        """
        Get a FCR client for your service.

        This client is used to distribute requests for bulk calls. The
        default implementation targets "localhost" on ``CommandServer.PORT``.
        """
        return AsyncioThriftClient(
            FcrClient, "localhost", CommandServer.PORT, service=self, timeout=timeout
        )

    def check_ip(self, ipaddr):
        """
        Check if ip address is usable.

        You will likely need to override this function to implement the ip
        validation logic. For eg. a service could periodically check what ip
        addresses are reachable. The application can then use this data to
        filter out non-reachable addresses.

        The default implementation assumes that everything is reachable
        """
        return True

    def get_http_proxy_url(self, host):
        """
        build a url for http proxy

        The default implementation always raises NotImplementedErrorException;
        override it to add proxy support.
        """
        raise NotImplementedErrorException("Proxy support not implemented")