#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import asyncio
import functools
import inspect
import random
import re
import sys
import typing
from dataclasses import dataclass
from functools import wraps
from itertools import islice
from uuid import uuid4
from fb303_asyncio.FacebookBase import FacebookBase
from fbnet.command_runner_asyncio.CommandRunner import constants, ttypes
from fbnet.command_runner_asyncio.CommandRunner.Command import Iface as FcrIface
from .command_session import CommandSession, CapturedTimeMS
from .counters import Counters
from .exceptions import ensure_thrift_exception, convert_to_fcr_exception
from .global_namespace import GlobalNamespace
from .options import Option
from .utils import input_fields_validator
@dataclass(frozen=True)
class DeviceResult:
"""
    Container for the information collected after sending commands to a device.
    Includes the CommandResult(s), the captured communication/processing time,
    and any caught exception (left as None when exceptions are returned as part
    of the response instead of raised)
"""
device_response: typing.Union[
typing.List[typing.Optional[ttypes.CommandResult]],
typing.Optional[ttypes.CommandResult],
] = None
captured_time_ms: CapturedTimeMS = CapturedTimeMS()
exception: typing.Optional[Exception] = None
def _append_debug_info_to_exception(fn):
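    """Decorator that appends request debug info (the thrift uuid) to any
    exception raised by the wrapped coroutine, either via the exception's
    'message' attribute or as an extra argument."""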
@wraps(fn)
async def wrapper(self, *args, **kwargs):
try:
return await fn(self, *args, **kwargs)
except Exception as ex:
# Retrieve request uuid from the global namespace
uuid = GlobalNamespace.get_request_uuid()
            # Exceptions defined in the command runner Thrift spec have a 'message' attribute
if hasattr(ex, "message"):
ex.message = await self.add_debug_info_to_error_message( # noqa
error_msg=ex.message, uuid=uuid # noqa
)
raise ex
else:
# Don't pass in ex.message to debug info to avoid duplicate message
debug_info = await self.add_debug_info_to_error_message(
error_msg="", uuid=uuid
)
# Append message as new arg instead of constructing new exception
# to account for exceptions having different required args
ex.args = ex.args + (debug_info,)
raise ex.with_traceback(sys.exc_info()[2])
return wrapper
def _ensure_uuid(fn):
"""Make sure the 'uuid' parameter for both input and return is non-empty"""
@wraps(fn)
async def wrapper(*args, **kwargs):
uuid = ""
callargs = inspect.getcallargs(fn, *args, **kwargs)
if "uuid" in callargs:
uuid = callargs["uuid"] or uuid4().hex[:8]
callargs["uuid"] = uuid
# Note: this won't work for functions that specify positional-only or
# kwarg-only parameters.
result = await fn(**callargs)
        # Set the UUID on the resulting struct if it supports it: a raw
        # CommandResult or Session, or a map of CommandResults
if isinstance(result, (ttypes.CommandResult, ttypes.Session)):
result.uuid = uuid
elif isinstance(result, dict):
for val in result.values():
# If size of result is large, we can refactor to only check the
# first element. (time taken: ~100ns * N)
if not isinstance(val, ttypes.CommandResult):
break
val.uuid = uuid
return result
return wrapper
class CommandHandler(Counters, FacebookBase, FcrIface):
"""
    Implementation of the command runner API defined in the Thrift spec
"""
_COUNTER_PREFIX = "fbnet.command_runner"
_LB_THRESHOLD = 100
REMOTE_CALL_OVERHEAD = Option(
"--remote_call_overhead",
help="Overhead for running commands remotely (for bulk calls)",
type=int,
default=20,
)
LB_THRESHOLD = Option(
"--lb_threshold",
help="""Load Balance threashold for bulk_run calls. If number of
devices is greater than this threashold, the requests are broken and
send to other instances using bulk_run_local() api""",
type=int,
default=100,
)
BULK_SESSION_LIMIT = Option(
"--bulk_session_limit",
help="""session limit above which we reject the bulk run local
calls""",
type=int,
default=200,
)
BULK_RETRY_LIMIT = Option(
"--bulk_retry_limit",
help="""number of times to retry bulk call on the remote instances""",
type=int,
default=5,
)
BULK_RUN_JITTER = Option(
"--bulk_run_jitter",
help="""A random delay added for bulk commands to stagger the calls to
distribute the load.""",
type=int,
default=5,
)
BULK_RETRY_DELAY_MIN = Option(
"--bulk_retry_delay_min",
help="""number of seconds to wait before retrying""",
type=int,
default=5,
)
BULK_RETRY_DELAY_MAX = Option(
"--bulk_retry_delay_max",
help="""number of seconds to wait before retrying""",
type=int,
default=10,
)
_bulk_session_count = 0
def __init__(self, service, name=None):
Counters.__init__(self, service, name)
FacebookBase.__init__(self, service.app_name)
FcrIface.__init__(self)
self.service.register_stats_mgr(self)
def cleanup(self):
pass
@classmethod
def register_counters(cls, stats_mgr):
stats_mgr.register_counter("bulk_run.remote")
stats_mgr.register_counter("bulk_run.local")
stats_mgr.register_counter("bulk_run.local.overload_error")
def getCounters(self):
ret = {}
for key, value in self.counters.items():
if not key.startswith(self._COUNTER_PREFIX):
key = self._COUNTER_PREFIX + "." + key
ret[key] = value() if callable(value) else value
return ret
@classmethod
def _set_bulk_session_count(cls, new_count: int) -> None:
"""Method to set the class variable _bulk_session_count"""
cls._bulk_session_count = new_count
async def add_debug_info_to_error_message(self, uuid, error_msg=None):
if not error_msg:
return f"(DebugInfo: thrift_uuid={uuid})"
return f"{error_msg} (DebugInfo: thrift_uuid={uuid})"
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def run(
self, command, device, timeout, open_timeout, client_ip, client_port, uuid
) -> ttypes.CommandResult:
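        """Run a single command on a device and return its CommandResult.

        Thin wrapper around _run_commands for a one-command request; any
        exception captured while running the command is re-raised here.
        """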
result = await self._run_commands(
[command], device, timeout, open_timeout, client_ip, client_port, uuid
)
GlobalNamespace.set_api_captured_time_ms(
GlobalNamespace.get_api_captured_time_ms() + result.captured_time_ms
)
        # Raise the exception here, since _run_commands does not raise and this
        # API does not return exceptions to the caller
if isinstance(result.exception, Exception):
raise result.exception
cmd_result = result.device_response
if isinstance(cmd_result, list):
cmd_result = result.device_response[0] # pyre-ignore checked is list
return cmd_result
def _bulk_failure(self, device_to_commands, message):
def command_failures(cmds):
return [
ttypes.CommandResult(
output=message, status=constants.FAILURE_STATUS, command=cmd
)
for cmd in cmds
]
return {
self._get_result_key(dev): command_failures(cmds)
for dev, cmds in device_to_commands.items()
}
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def bulk_run_v2(
self, request: ttypes.BulkRunCommandRequest
) -> ttypes.BulkRunCommandResponse:
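        """Struct-based variant of bulk_run: unpack the BulkRunCommandRequest
        into a device -> commands mapping, delegate to bulk_run, and wrap the
        result in a BulkRunCommandResponse."""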
device_to_commands = {}
for device_commands in request.device_commands_list:
device_details = device_commands.device
# Since this is in the handler, is it okay to have struct as a key?
device_to_commands[device_details] = device_commands.commands
result = await self.bulk_run(
device_to_commands,
request.timeout,
request.open_timeout,
request.client_ip,
request.client_port,
request.uuid,
)
response = ttypes.BulkRunCommandResponse()
response.device_to_result = result
return response
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def bulk_run(
self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
) -> typing.Dict[str, typing.List[ttypes.CommandResult]]:
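        """Run commands on multiple devices.

        Small requests (below LB_THRESHOLD, while under BULK_SESSION_LIMIT)
        are run locally; larger requests are split into chunks that are fanned
        out to other instances via bulk_run_local, with retries on overload.
        """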
if (len(device_to_commands) < self.LB_THRESHOLD) and (
self._bulk_session_count < self.BULK_SESSION_LIMIT
):
# Run these command locally.
self.incrementCounter("bulk_run.local")
return await self._bulk_run_local(
device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
)
async def _remote_task(chunk):
# Run the chunk of commands on remote instance
self.incrementCounter("bulk_run.remote")
retry_count = 0
while True:
try:
return await self._bulk_run_remote(
chunk, timeout, open_timeout, client_ip, client_port, uuid
)
except Exception as e:
if isinstance(e, ttypes.InstanceOverloaded):
# Instance we ran the call on was overloaded. We can retry
# the command again, hopefully on a different instance
self.incrementCounter("bulk_run.remote.overload_error")
self.logger.error("Instance Overloaded: %d: %s", retry_count, e)
if (
self._remote_task_should_retry(ex=e)
and retry_count <= self.BULK_RETRY_LIMIT
):
# Stagger the retries
delay = random.uniform(
self.BULK_RETRY_DELAY_MIN, self.BULK_RETRY_DELAY_MAX
)
await asyncio.sleep(delay)
retry_count += 1
else:
# Append message as new arg instead of constructing new exception
# to account for exceptions having different required args
e.args = e.args + ("bulk_run_remote failed",)
return self._bulk_failure(chunk, str(e))
# Split the request into chunks and run them on remote hosts
tasks = [
_remote_task(chunk)
for chunk in self._chunked_dict(device_to_commands, self.LB_THRESHOLD)
]
all_results = {}
for task in asyncio.as_completed(tasks, loop=self.loop):
result = await task
all_results.update(result)
return all_results
@ensure_thrift_exception
@_ensure_uuid
async def bulk_run_local(
self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
):
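        """Entry point used for the bulk_run fan-out: run the request on this
        instance via _bulk_run_local."""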
return await self._bulk_run_local(
device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
)
@_ensure_uuid
async def _bulk_run_local(
self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
) -> typing.Dict[str, typing.List[ttypes.CommandResult]]:
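        """Run the bulk request locally.

        Raises InstanceOverloaded if the request would push the open session
        count above BULK_SESSION_LIMIT; otherwise runs all devices in
        parallel, staggering each device by a small random jitter.
        """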
devices = sorted(device_to_commands.keys(), key=lambda d: d.hostname)
session_count = self._bulk_session_count
if session_count + len(device_to_commands) > self.BULK_SESSION_LIMIT:
self.logger.error("Too many session open: %d", session_count)
raise ttypes.InstanceOverloaded(
message="Too many session open: %d" % session_count
)
self._set_bulk_session_count(self._bulk_session_count + len(devices))
async def _run_one_device(device) -> DeviceResult:
if not device_to_commands[device]:
return DeviceResult(device_response=[])
# Instead of running all commands at once, stagger the commands to
# distribute the load
delay = random.uniform(0, self.BULK_RUN_JITTER)
await asyncio.sleep(delay)
return await self._run_commands(
device_to_commands[device],
device,
timeout,
open_timeout,
client_ip,
client_port,
uuid,
return_exceptions=True,
)
captured_time_ms_list = []
try:
commands = []
for device in devices:
commands.append(_run_one_device(device))
# Run commands in parallel
results: typing.List[DeviceResult] = await asyncio.gather(
*commands, loop=self.loop, return_exceptions=True
)
# pyre-ignore: _run_commands always returns list of CommandResults
# when return_exceptions is True
cmd_results: typing.List[typing.List[ttypes.CommandResult]] = [
result.device_response for result in results
]
# Get the captured communication/processing time to save later in finally block
captured_time_ms_list = [result.captured_time_ms for result in results]
finally:
self._set_bulk_session_count(self._bulk_session_count - len(devices))
GlobalNamespace.set_api_captured_time_ms(
functools.reduce(
CapturedTimeMS.__add__,
[GlobalNamespace.get_api_captured_time_ms()]
+ captured_time_ms_list,
)
)
return {
self._get_result_key(dev): res for dev, res in zip(devices, cmd_results)
}
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def open_session(
self, device, open_timeout, idle_timeout, client_ip, client_port, uuid
):
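        """Open a command session to a device and return its thrift Session handle."""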
return await self._open_session(
device,
open_timeout,
idle_timeout,
client_ip,
client_port,
uuid,
raw_session=False,
)
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def run_session(
self, session, command, timeout, client_ip, client_port, uuid
):
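        """Run a command on an existing session previously opened with open_session."""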
return await self._run_session(
session, command, timeout, client_ip, client_port, uuid
)
@ensure_thrift_exception
@input_fields_validator
@_append_debug_info_to_exception
@_ensure_uuid
async def close_session(self, session, client_ip, client_port, uuid):
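        """Close a previously opened session, recording its captured time."""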
closed_session = None
try:
closed_session = CommandSession.get(session.id, client_ip, client_port)
# Reset captured time field so we don't include
# the time from a previous API call
closed_session.captured_time_ms.reset_time()
await closed_session.close()
except Exception as e:
# Append message as new arg instead of constructing new exception
# to account for exceptions having different required args
e.args = e.args + ("close_session failed",)
raise e
finally:
if closed_session:
GlobalNamespace.set_api_captured_time_ms(
GlobalNamespace.get_api_captured_time_ms()
+ closed_session.captured_time_ms
)
@ensure_thrift_exception
@_append_debug_info_to_exception
@_ensure_uuid
async def open_raw_session(
self, device, open_timeout, idle_timeout, client_ip, client_port, uuid
):
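        """Variant of open_session that opens the session with raw_session=True."""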
return await self._open_session(
device,
open_timeout,
idle_timeout,
client_ip,
client_port,
uuid,
raw_session=True,
)
@ensure_thrift_exception
@_append_debug_info_to_exception
@_ensure_uuid
async def run_raw_session(
self, tsession, command, timeout, prompt_regex, client_ip, client_port, uuid
):
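        """Run a command on a raw session, using the caller-supplied
        prompt_regex (required, compiled as a multiline regex) as the prompt
        matcher."""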
if not prompt_regex:
raise ttypes.SessionException(message="prompt_regex not specified")
prompt_re = re.compile(prompt_regex.encode("utf8"), re.M)
return await self._run_session(
tsession, command, timeout, client_ip, client_port, uuid, prompt_re
)
@ensure_thrift_exception
@_append_debug_info_to_exception
@_ensure_uuid
async def close_raw_session(self, tsession, client_ip, client_port, uuid):
return await self.close_session(tsession, client_ip, client_port, uuid)
async def _open_session(
self,
device,
open_timeout,
idle_timeout,
client_ip,
client_port,
uuid,
raw_session=False,
):
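        """Shared implementation for open_session/open_raw_session: look up the
        device, set up a session with the assembled options, and return its
        thrift Session struct."""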
options = self._get_options(
device,
client_ip,
client_port,
open_timeout,
idle_timeout,
raw_session=raw_session,
)
session = None
try:
devinfo = await self._lookup_device(device)
session = await devinfo.setup_session(
self.service, device, options, loop=self.loop
)
return ttypes.Session(
id=session.id, name=session.hostname, hostname=device.hostname
)
except Exception as e:
# Append message as new arg instead of constructing new exception
# to account for exceptions having different required args
e.args = e.args + ("open_session failed",)
raise e
finally:
if session:
GlobalNamespace.set_api_captured_time_ms(
GlobalNamespace.get_api_captured_time_ms()
+ session.captured_time_ms
)
async def _run_session(
self, tsession, command, timeout, client_ip, client_port, uuid, prompt_re=None
):
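        """Shared implementation for run_session/run_raw_session: fetch the
        live CommandSession by id, reset its captured time, and run the
        command on it."""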
session = None
captured_time_ms = CapturedTimeMS()
try:
session = CommandSession.get(tsession.id, client_ip, client_port)
# Reset captured time field so we don't include
# the time from a previous API call
session.captured_time_ms.reset_time()
return await self._run_command(session, command, timeout, uuid, prompt_re)
except Exception as e:
# Append message as new arg instead of constructing new exception
# to account for exceptions having different required args
e.args = e.args + ("run_session failed",)
raise e
finally:
if session:
captured_time_ms = session.captured_time_ms
GlobalNamespace.set_api_captured_time_ms(
GlobalNamespace.get_api_captured_time_ms() + captured_time_ms
)
def _get_result_key(self, device):
# TODO: just returning the hostname for now. Some additional processing
# may be required e.g. using shortnames, adding console info, etc
return device.hostname
async def _run_command(self, session, command, timeout, uuid, prompt_re=None):
self.logger.info(f"[request_id={uuid}]: Run command with session {session.id}")
output = await session.run_command(
command.encode("utf8"), timeout=timeout, prompt_re=prompt_re
)
return session.build_result(
output=output.decode("utf8", errors="ignore"),
status=session.exit_status or constants.SUCCESS_STATUS,
command=command,
)
async def _run_commands(
self,
commands,
device,
timeout,
open_timeout,
client_ip,
client_port,
uuid,
return_exceptions=False,
) -> DeviceResult:
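        """Open a session to the device and run the given commands in order.

        On failure, if return_exceptions is True the error is converted into a
        failed CommandResult (with debug info appended); otherwise the
        exception is stored in the returned DeviceResult for the caller to
        raise.
        """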
options = self._get_options(
device, client_ip, client_port, open_timeout, timeout
)
if device.command_prompts:
options["command_prompts"] = {
c.encode(): p.encode() for c, p in device.command_prompts.items()
}
command = commands[0]
devinfo = None
session = None
results = []
captured_time_ms = CapturedTimeMS()
exc = None
try:
devinfo = await self._lookup_device(device)
async with devinfo.create_session(
self.service, device, options, loop=self.loop
) as session:
for command in commands:
result = await self._run_command(session, command, timeout, uuid)
results.append(result)
except Exception as e:
await self._record_error(e, command, uuid, options, devinfo, session)
if return_exceptions:
if not isinstance(e, ttypes.SessionException):
e = convert_to_fcr_exception(e)
e = ttypes.SessionException(message=f"{e!s}", code=e._CODE)
e.message = await self.add_debug_info_to_error_message( # noqa
error_msg=e.message, uuid=uuid # noqa
)
results = [
ttypes.CommandResult(output="", status="%r" % e, command=command)
]
else:
                # Save the exception where it was caught so we keep the full
                # stack trace when it is re-raised at the API level later
exc = e
if session:
captured_time_ms = session.captured_time_ms
return DeviceResult(
device_response=results,
captured_time_ms=captured_time_ms,
exception=exc,
)
def _chunked_dict(self, data, chunk_size):
"""split the dict into smaller dicts"""
items = iter(data.items()) # get an iterator for items
for _ in range(0, len(data), chunk_size):
yield dict(islice(items, chunk_size))
async def _bulk_run_remote(
self, device_to_commands, timeout, open_timeout, client_ip, client_port, uuid
):
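        """Run a chunk of the bulk request on a peer instance via its
        bulk_run_local API, reserving REMOTE_CALL_OVERHEAD of the timeout
        budget for the overhead of the remote call itself."""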
# Determine a timeout for remote call.
call_timeout = open_timeout + timeout
remote_timeout = timeout - self.REMOTE_CALL_OVERHEAD
# Make sure we have a sane timeout value
assert remote_timeout > 10, "timeout: '%d' value too low for bulk_run" % timeout
fcr_client = await self._get_fcr_client(timeout=call_timeout)
async with fcr_client as client:
result = await client.bulk_run_local(
device_to_commands,
remote_timeout,
open_timeout,
client_ip,
client_port,
uuid,
)
return result
async def _lookup_device(self, device):
return await self.service.device_db.get(device)
async def _get_fcr_client(self, timeout):
return await self.service.get_fcr_client(timeout=timeout)
def _get_options(
self,
device,
client_ip,
client_port,
open_timeout,
idle_timeout,
raw_session=False,
):
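        """Assemble the session options dict (credentials, addressing,
        timeouts, prompts, etc.) passed to the device session layer."""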
options = {
"username": self._get_device_username(device),
"password": self._get_device_password(device),
"console": device.console,
"command_prompts": {},
"client_ip": client_ip,
"client_port": client_port,
"mgmt_ip": device.mgmt_ip or False,
"open_timeout": open_timeout,
"idle_timeout": idle_timeout,
"ip_address": device.ip_address,
"session_type": device.session_type,
"device": device,
"raw_session": raw_session,
"clear_command": device.clear_command,
}
if device.command_prompts:
options["command_prompts"] = {
c.encode(): p.encode() for c, p in device.command_prompts.items()
}
return options
def _get_device_username(self, device):
return device.username
def _get_device_password(self, device):
# If the username is specified then password must also be specified.
if device.username:
return self._decrypt(device.password)
def _decrypt(self, data):
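        """Return the stored credential as-is; a deployment-specific subclass
        could override this to decrypt it."""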
return data
async def _record_error(
self, error, command, uuid, options, devinfo, session, **kwargs
):
"""
        Subclasses can override this method to export error messages of
        interest to an appropriate data store
"""
pass
def _remote_task_should_retry(self, ex: Exception) -> bool:
"""
        Decide whether a remote task (fan-out request) should be retried
"""
if not ex:
return False
if isinstance(ex, ttypes.InstanceOverloaded):
return True
return False