fbnet/command_runner/command_session.py (734 lines of code) (raw):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import abc
import asyncio
import logging
import re
import sys
import time
import traceback
import typing
from collections import namedtuple
from dataclasses import dataclass
from functools import wraps
from typing import List

import asyncssh
from fbnet.command_runner.exceptions import (
    FcrBaseException,
    RuntimeErrorException,
    AssertionErrorException,
    LookupErrorException,
    StreamReaderErrorException,
    CommandExecutionTimeoutErrorException,
    ConnectionErrorException,
    ConnectionTimeoutErrorException,
)
from fbnet.command_runner_asyncio.CommandRunner import ttypes

from .base_service import PeriodicServiceTask, ServiceObj
from .device_info import IPInfo
from .options import Option

if typing.TYPE_CHECKING:
    from fbnet.command_runner.service import FcrServiceBase

    from .counters import Counters
    from .device_info import DeviceInfo


# Register additional key exchange algorithms
asyncssh.public_key.register_public_key_alg(b"rsa-sha2-256", asyncssh.rsa._RSAKey)

log = logging.getLogger("fcr.CommandSession")

# Result of a regex read on the command stream:
#   data      - bytes received before the match
#   matched   - bytes that matched the regex
#   groupdict - named groups of the match (None when EOF was hit instead)
#   match     - the match object itself (None when EOF was hit instead)
ResponseMatch = namedtuple("ResponseMatch", ["data", "matched", "groupdict", "match"])


class PeerInfoList(typing.NamedTuple):
    """Candidate IPs (in the order they will be tried) and the port used."""

    # NOTE: the default list object is shared between instances created
    # without an explicit ip_list; treat it as read-only.
    ip_list: typing.Optional[List[IPInfo]] = []
    port: typing.Optional[typing.Union[int, str]] = None

    def __str__(self) -> str:
        return f"({self.ip_list}, {self.port})"


class PeerInfo(typing.NamedTuple):
    """The IP/port of the most recent connection attempt."""

    ip: typing.Optional[str] = None
    ip_is_pingable: typing.Optional[bool] = True
    port: typing.Optional[typing.Union[int, str]] = None

    def __str__(self) -> str:
        return f"({self.ip}, {self.ip_is_pingable}, {self.port})"


@dataclass(frozen=False)
class CapturedTimeMS:
    """
    Class for capturing different types of communication and processing times
    (in ms) during an API call. Currently includes external communication time.

    Add additional types of captured time as fields as needed along with their
    relevant increment methods
    """

    # captures external communication time (e.g. establishing SSH connection,
    # waiting for device to feed bytes to the stream, etc.)
    external_communication_time_ms: float = 0.0

    def __add__(self, other):
        """Return a new CapturedTimeMS holding the field-wise sum."""
        return CapturedTimeMS(
            external_communication_time_ms=self.external_communication_time_ms
            + other.external_communication_time_ms
        )

    def __radd__(self, other):
        # Reached when the left operand does not know how to add a
        # CapturedTimeMS (e.g. ``0 + captured``); this is rejected explicitly.
        raise RuntimeErrorException("Can only add CapturedTimeMS objects together")

    def reset_time(self) -> None:
        """
        Resets all captured time to 0.0. Any new added fields must be reset in
        this method.
        """
        self.external_communication_time_ms = 0.0

    def increment_external_communication_time_ms(self, time_ms: float) -> None:
        """Add `time_ms` milliseconds to the external communication time."""
        self.external_communication_time_ms += time_ms


class LogAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes every message with the session id."""

    def process(
        self, msg: str, kwargs: typing.MutableMapping[str, typing.Any]
    ) -> typing.Tuple[typing.Any, typing.MutableMapping[str, typing.Any]]:
        # pyre-fixme[16]: `object` has no attribute `id`.
        return f"[session_id={self.extra['session'].id}]: {msg}", kwargs


class SessionReaperTask(PeriodicServiceTask):
    """Periodic task that closes sessions that are idle or stuck for too long."""

    SESSION_REAP_PERIOD_S = Option(
        "--session_reap_period",
        help="Interval (in seconds) to cleanup stale or long-idle sessions "
        "(default: %(default)s)",
        type=int,
        default=60,
    )

    MAX_SESSION_IDLE_TIMEOUT_S = Option(
        "--max_session_idle_timeout",
        help="Maximal accepted value (in seconds) for session idle timeout "
        "(default: %(default)s)",
        type=int,
        default=30 * 60,
    )

    MAX_SESSION_LAST_ACCESS_TIMEOUT_S = Option(
        "--max_session_last_access_timeout",
        help="Max time a session can live since last access" "(default: %(default)s)",
        type=int,
        default=60 * 60,
    )

    COUNTER_KEY_REAPED_ALL = "session_reaper.reaped.all"

    def __init__(
        self,
        service: ServiceObj,
        sessions: typing.Optional[
            typing.Dict[typing.Hashable, "CommandSession"]
        ] = None,
    ) -> None:
        """
        :param service: the owning service object
        :param sessions: session registry to reap; defaults to the global
            registry of CommandSession when not supplied
        """
        super().__init__(
            service, name=self.__class__.__name__, period=self.SESSION_REAP_PERIOD_S
        )
        # Only fall back to the global registry when no registry was passed.
        # An explicitly supplied (possibly still empty) dict must be honoured.
        self._sessions = (
            sessions if sessions is not None else CommandSession._ALL_SESSIONS
        )

    @classmethod
    def register_counters(cls, stats_mgr: "Counters") -> None:
        stats_mgr.add_stats_counter(cls.COUNTER_KEY_REAPED_ALL, ["count"])

    def _bump_counters_for_reaped_session(self, session: "CommandSession") -> None:
        self.inc_counter(self.COUNTER_KEY_REAPED_ALL)

    async def run(self) -> None:
        """
        A session is accessed when a command begins executing, and is accessed
        again at the end of execution when it is released. A session is freed if

        1) it's idle for 'idle_timeout' sec after the last command execution; OR
        2) it exceeds the max session time out since last accessed (this could
           happen when a command gets stuck).

        This would prevent the thrift service from holding up open/stale
        connections to network devices.
        """
        try:
            self.logger.info(
                f"Session reaper woke up: curr_time={time.time()}, "
                f"session_count={len(self._sessions)}"
            )

            # Iterate over a snapshot of the keys: sessions are removed from
            # the registry while we loop (and while we await close()).
            for key in list(self._sessions.keys()):
                if key not in self._sessions:
                    # Since this is an async method, it's possible that the
                    # session is closed before being reaped
                    continue

                session = self._sessions[key]
                curr_time = time.time()
                time_since_last_access = curr_time - session.last_access_time
                idle_timeout = min(
                    session.idle_timeout, self.MAX_SESSION_IDLE_TIMEOUT_S
                )

                if time_since_last_access > self.MAX_SESSION_LAST_ACCESS_TIMEOUT_S or (
                    not session.in_use and time_since_last_access > idle_timeout
                ):
                    self.logger.info(
                        f"Reap session {key}, "
                        f"last_access_time={session.last_access_time}, "
                        f"curr_time={curr_time}"
                    )
                    await session.close()
                    if key in self._sessions:
                        del self._sessions[key]
                    self._bump_counters_for_reaped_session(session)

            self.logger.info(
                f"Session reaper finished: session_count={len(self._sessions)}"
            )
        except Exception as ex:
            # Best effort: a failure while reaping is logged, not propagated.
            self.logger.exception(f"Error when reaping session {ex!r}")


def _update_last_access_time_and_in_use(fn: typing.Callable) -> typing.Callable:
    """
    This is a decorator to update the last access time of the session before
    and after calling the wrapped function. It also keeps the session's in-use
    counter incremented for the duration of the call.

    NOTE: This is for internal use only within CommandSession
    """

    @wraps(fn)
    async def wrapper(self, *args, **kwargs):
        self._in_use_count += 1
        self._last_access_time = time.time()
        try:
            return await fn(self, *args, **kwargs)
        finally:
            self._in_use_count -= 1
            self._last_access_time = time.time()

    return wrapper


class CommandSession(ServiceObj):
    """
    A session for running commands on devices. Before running a command a
    CommandSession needs to be created. The connection to the device is
    established asynchronously. The user should wait for the session to be
    connected before trying to send commands to the device.

    Once a session is established, a set of read and write streams will be
    associated with the session.
    """

    # Registry of all live sessions, keyed by CommandSession.key
    _ALL_SESSIONS: typing.Dict[typing.Hashable, "CommandSession"] = {}

    # the prompt is at the end of input. So rather than searching in the entire
    # buffer, we will only look in the trailing data
    _MAX_PROMPT_SIZE = 100

    def __init__(
        self,
        service: "FcrServiceBase",
        devinfo: "DeviceInfo",
        options: typing.Dict[str, typing.Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        # Setup devinfo as this is needed to create the logger
        self._devinfo = devinfo

        super().__init__(service)

        self._opts = options
        self.device = self._opts.get("device")
        self._extra_options = (
            self.device
            and self.device.session_data
            and self.device.session_data.extra_options
        ) or {}
        self._hostname = devinfo.hostname
        self._pre_setup_commands: typing.List[str] = (
            (self.device.pre_setup_commands or []) if self.device else []
        )

        self._extra_info = {}
        self._exit_status = None

        # use the specified username/password passed in by user
        self._username = options.get("username")
        self._password = options.get("password")
        self._client_ip = options["client_ip"]
        self._client_port = options["client_port"]
        self._loop = loop

        # TODO: remove _cmd_stream from the base class CommandSession (some
        # session type, e.g., rpc base session, does not need this property)
        self._cmd_stream = None
        self._connected = False

        # NOTE(review): the explicit `loop` argument of asyncio primitives
        # (used here and in the wait_for/sleep calls of this module) was
        # removed in Python 3.10 - confirm the supported interpreter versions.
        self._event = asyncio.Condition(loop=self._loop)

        self.logger.info("Created key=%s", self.key)

        # Record the session in the cache
        self._ALL_SESSIONS[self.key] = self

        self._last_access_time: float = time.time()
        self._in_use_count: int = 0
        self._open_time_ms: int = 0

        # captures various types of communication and processing times
        # including external communication time
        self._captured_time_ms: CapturedTimeMS = CapturedTimeMS()

    def get_session_name(self) -> str:
        return self.objname

    def get_peer_info(self) -> typing.Optional[PeerInfo]:
        return self._extra_info.get("peer")

    def get_peer_info_list(self) -> typing.Optional[PeerInfoList]:
        return self._extra_info.get("peer_list")

    def create_logger(self) -> LogAdapter:
        """Create a per-class/vendor/host logger that tags the session id."""
        logger = logging.getLogger(
            "fcr.{klass}.{dev.vendor_name}.{dev.hostname}".format(
                klass=self.__class__.__name__, dev=self._devinfo
            )
        )
        return LogAdapter(logger, {"session": self})

    def build_result(
        self, output: str, status: str, command: str
    ) -> ttypes.CommandResult:
        return ttypes.CommandResult(output=output, status=status, command=command)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__} [{self._devinfo.hostname}] [{self.id}]"

    @classmethod
    def register_counters(cls, stats_mgr: "Counters") -> None:
        stats_mgr.register_counter(f"{cls.__name__}.setup")
        stats_mgr.register_counter(f"{cls.__name__}.connected")
        stats_mgr.register_counter(f"{cls.__name__}.failed")
        stats_mgr.register_counter(f"{cls.__name__}.closed")

    @classmethod
    def get_session_count(cls) -> int:
        return len(cls._ALL_SESSIONS)

    @classmethod
    async def wait_sessions(cls, req_name: str, service: ServiceObj) -> None:
        """Poll once per second until the session registry is empty."""
        session_count = cls.get_session_count()

        while session_count != 0:
            await asyncio.sleep(1, loop=service.loop)
            session_count = cls.get_session_count()
            service.logger.info(f"{req_name}: pending sessions: {session_count}")

        service.logger.info(f"{req_name}: no pending sesison")

    async def __aenter__(self) -> "CommandSession":
        """
        Set up the session. On failure the session is closed and the exception
        is re-raised with session details appended. The elapsed setup time is
        recorded in `open_time_ms` in either case.
        """
        open_connection_time = time.perf_counter()
        try:
            await self.setup()
        except Exception as e:
            await self.close()
            raise self._build_session_exc(e) from e
        finally:
            self._open_time_ms = int(
                (time.perf_counter() - open_connection_time) * 1000
            )

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.close()
        if exc_val:
            raise self._build_session_exc(exc_val) from exc_val

    def _build_session_exc(self, exc: Exception) -> Exception:
        """
        Annotate `exc` with session details and return it.

        The same exception object is returned (its type is therefore
        preserved); an additional message naming the session and peer is
        appended to its args.
        """
        peer_info = self.get_peer_info()
        msg = f"Failed (session: {self.get_session_name()}, peer: {peer_info})"
        if isinstance(peer_info, PeerInfo) and not peer_info.ip_is_pingable:
            msg += ", IP used in this connection is not pingable according to NetSonar"

        # Append message as new arg instead of constructing new exception
        # to account for exceptions having different required args
        exc.args = exc.args + (msg,)
        return exc

    @classmethod
    def get(cls, session_id: int, client_ip: str, client_port: int) -> "CommandSession":
        """
        Look up a live session.

        :raises LookupErrorException: if no session is registered for the key
        """
        key = (session_id, client_ip, client_port)
        try:
            return cls._ALL_SESSIONS[key]
        except KeyError as ke:
            raise LookupErrorException("Session not found", key) from ke

    @property
    def hostname(self) -> str:
        return self._hostname

    @property
    def username(self) -> str:
        return self._username

    @property
    def devinfo(self):
        return self._devinfo

    @property
    def id(self) -> int:
        return id(self)

    @property
    def key(self) -> typing.Tuple[int, str, int]:
        return (self.id, self._client_ip, self._client_port)

    @property
    def open_timeout(self) -> int:
        return self._opts.get("open_timeout")

    @property
    def open_time_ms(self) -> int:
        return self._open_time_ms

    @property
    def use_mgmt_ip(self) -> bool:
        return self._opts.get("mgmt_ip")

    @property
    def idle_timeout(self) -> int:
        return self._opts.get("idle_timeout")

    @property
    def connected(self) -> bool:
        return self._connected

    @property
    def last_access_time(self) -> float:
        return self._last_access_time

    @property
    def in_use(self) -> bool:
        return self._in_use_count > 0

    @property
    def exit_status(self) -> typing.Optional[str]:
        # None until an exit status is received; stored as a string after that
        return self._exit_status

    @property
    def captured_time_ms(self) -> CapturedTimeMS:
        return self._captured_time_ms

    async def _create_connection(self) -> None:
        await self.connect()

    @_update_last_access_time_and_in_use
    async def setup(self) -> "CommandSession":
        """
        Create the connection, bounded by `open_timeout`.

        :raises ConnectionTimeoutErrorException: if the connection could not be
            created in time. The message carries the tail of any data received
            so far.
        """
        self.inc_counter(f"{self.objname}.setup")
        try:
            await asyncio.wait_for(
                self._create_connection(), self.open_timeout, loop=self._loop
            )
        except asyncio.TimeoutError:
            self.logger.exception("Timeout during connection setup")
            data = []
            # TODO(mzheng): Move the _steam_reader check to subclasses that
            # define it
            # pyre-fixme
            if hasattr(self, "_stream_reader") and self._stream_reader:
                data = await self._stream_reader.drain()

            raise ConnectionTimeoutErrorException(
                "Timeout during connection setup. Currently received data "
                f"(last 200 char): {data[-200:]}"
            )
        return self

    async def connect(self) -> None:
        """
        Initiates a connection on the session

        :raises FcrBaseException: re-raised unchanged when raised by _connect
        :raises ConnectionErrorException: wrapping any other failure
        """
        try:
            self._cmd_stream = await self._connect()
            self.inc_counter(f"{self.objname}.connected")
            self.logger.info(f"Connected: {self._extra_info}")
        except Exception as e:
            self.logger.error(f"Connect Failed {e!r}")
            self.inc_counter(f"{self.objname}.failed")
            if isinstance(e, FcrBaseException):
                raise
            raise ConnectionErrorException(repr(e)) from e

    async def close(self) -> None:
        """
        Close the session. This removes the session from the cache. Also
        invokes the session specific _close method
        """
        try:
            self.logger.debug("Closing session")
            if self.key in self._ALL_SESSIONS:
                del self._ALL_SESSIONS[self.key]
        finally:
            await self._close()
            if self._cmd_stream is not None:
                self._cmd_stream.close()
            self._connected = False
            self.inc_counter(f"{self.objname}.closed")

    @_update_last_access_time_and_in_use
    async def run_command(
        self,
        command: bytes,
        timeout: typing.Optional[int] = None,
        prompt_re: typing.Optional[typing.Pattern] = None,
    ) -> bytes:
        """Run `command` through the session specific _run_command."""
        return await self._run_command(
            command=command, timeout=timeout, prompt_re=prompt_re
        )

    @abc.abstractmethod
    async def _connect(self) -> None:
        """
        This needs to be implemented by the actual session classes
        """
        pass

    @abc.abstractmethod
    async def _close(self) -> None:
        """
        This needs to be implemented by the actual session classes
        """
        pass

    @abc.abstractmethod
    async def _run_command(
        self,
        command: bytes,
        timeout: typing.Optional[int] = None,
        prompt_re: typing.Optional[typing.Pattern] = None,
    ) -> bytes:
        """
        This needs to be implemented by the actual session classes
        """
        pass

    async def wait_until_connected(self, timeout: typing.Optional[int] = None) -> None:
        """
        Wait until the session is marked as connected

        :raises ConnectionTimeoutErrorException: if `timeout` expires first
        """
        try:
            await self.wait_for(lambda _: self._connected, timeout=timeout)
        except asyncio.TimeoutError as exc:
            raise ConnectionTimeoutErrorException(
                "Timed out before session marked as connected"
            ) from exc

    async def _notify(self) -> None:
        """
        notify a change in stream state
        """
        async with self._event:
            self._event.notify_all()

    async def wait_for(
        self, predicate: typing.Callable, timeout: typing.Optional[int] = None
    ) -> None:
        """
        Wait for condition to become true on the session

        :param predicate: callable taking the session, re-evaluated on every
            notification
        :param timeout: seconds to wait, or None to wait indefinitely
        :raises asyncio.TimeoutError: if `timeout` expires first
        """
        # The condition re-acquires its lock before a cancelled/timed-out wait
        # returns, so the lock must be released on every exit path. Otherwise
        # a timeout here would leave the lock held forever and block every
        # later _notify()/wait_for() call.
        async with self._event:
            await asyncio.wait_for(
                self._event.wait_for(lambda: predicate(self)),
                timeout=timeout,
                loop=self._loop,
            )


class CommandStreamReader(asyncio.StreamReader):
    """
    A Reader for command responses

    Extends the asyncio.StreamReader and adds support for waiting for regex
    match on received data
    """

    # Seconds after which wait_for switches from re-matching on every chunk of
    # data to re-matching only once the stream has been quiet
    QUICK_COMMAND_RUNTIME = 1
    # Seconds without new data after which the stream is considered quiet
    COMMAND_DATA_TIMEOUT = 1

    def __init__(self, session: CommandSession, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._session = session
        # perf_counter() timestamp of the reference point for measuring
        # external communication time; 0.0 means "not measuring"
        self._last_feed_data_call_time_s: float = 0.0

    @property
    def logger(self) -> LogAdapter:
        return self._session.logger

    def feed_data(self, data: bytes) -> None:
        feed_data_call_time_s = time.perf_counter()

        # only increment external time if there was a last call time & session's
        # cmd_stream not None (i.e. _connect done, so not simultaneously
        # capturing session connection time in _connect)
        if self._last_feed_data_call_time_s and self._session._cmd_stream:
            self._session.captured_time_ms.increment_external_communication_time_ms(
                (feed_data_call_time_s - self._last_feed_data_call_time_s) * 1000
            )

        # Only update last call time if cmd_stream not None (i.e. _connect done)
        # to prevent overlapping captured time with _connect (which may also
        # call feed_data).
        # If feed_data is called for the first time without going through
        # wait_for, this will also start capturing the time for those
        # non-wait_for feed_data calls.
        if self._session._cmd_stream:
            self._last_feed_data_call_time_s = feed_data_call_time_s

        return super().feed_data(data)

    async def wait_for(
        self, predicate: typing.Callable, timeout: typing.Optional[int] = None
    ) -> typing.Match:
        """
        Wait for the predicate to become true on the stream. As and when new
        data is available, the predicate will be re-evaluated.

        :param predicate: callable taking the read buffer; returns None until
            satisfied
        :param timeout: optional bound (seconds) on the total matching time
        :returns: the first non-None value returned by `predicate`
        :raises StreamReaderErrorException: if the reader already holds an
            exception, or the buffer grows beyond its limit
        :raises CommandExecutionTimeoutErrorException: if `timeout` expires
        """
        if self._exception is not None:  # pyre-ignore
            raise StreamReaderErrorException(repr(self._exception)) from self._exception

        res = predicate(self._buffer)  # pyre-ignore

        start_ts = time.time()

        # Set an initial time so that first call to feed_data has a
        # reference from which to capture how much time has passed
        self._last_feed_data_call_time_s = time.perf_counter()

        try:
            while res is None:
                now = time.time()

                # Here we add a protection to avoid this function from doing
                # infinite regex matching. This will ensure that we will
                # eventually break out from the while loop if timeout is set
                if timeout and now - start_ts >= timeout:
                    raise CommandExecutionTimeoutErrorException(
                        "FCR timeout during matching regex against current buffer output from device."
                    )

                self.logger.debug(
                    f"match failed in: {len(self._buffer)}: {self._limit}: {self._buffer[-100:]}"  # pyre-ignore
                )

                self._session.inc_counter("streamreader.wait_for_retry")

                if len(self._buffer) > self._limit:
                    self._session.inc_counter("streamreader.overrun")
                    raise StreamReaderErrorException(
                        "Reader buffer overrun: %d: %d"
                        % (len(self._buffer), self._limit)
                    )

                if now - start_ts > self.QUICK_COMMAND_RUNTIME:
                    # Keep waiting for data till we get a timeout
                    try:
                        while True:
                            fut = self._wait_for_data(  # pyre-ignore
                                "CommandStreamReader.wait_for"
                            )
                            await asyncio.wait_for(
                                fut,
                                timeout=self.COMMAND_DATA_TIMEOUT,
                                loop=self._loop,  # pyre-ignore
                            )
                    except asyncio.TimeoutError:
                        # Now try to match the prompt
                        pass
                else:
                    # match quickly initially
                    await self._wait_for_data("CommandStreamReader.wait_for")

                res = predicate(self._buffer)

            self.logger.debug("match found at: %s", res)
            return res
        finally:
            # Reset last_feed_data_call_time_s to 0.0 so that in case of a
            # later call to feed_data that doesn't go through wait_for, we
            # don't accidentally capture the time from now until then. This is
            # done on the failure paths as well, so an aborted wait does not
            # leave a stale reference time behind.
            self._last_feed_data_call_time_s = 0.0

    def _search_re(
        self, regex: typing.Pattern, data: bytes, start: int = 0
    ) -> typing.Optional[typing.Match]:
        self.logger.debug(f"searching for: {regex}")
        return regex.search(data, start)

    async def readuntil_re(
        self,
        regex: typing.Pattern,
        timeout: typing.Optional[int] = None,
        start: int = 0,
    ) -> ResponseMatch:
        """
        Read data until a regex is matched on the input stream

        :param regex: compiled bytes pattern to search for
        :param timeout: optional bound (seconds) passed on to wait_for
        :param start: offset handed to regex.search (may be negative to only
            look at the tail of the buffer)
        :returns: ResponseMatch; when EOF is reached without a match, all
            buffered data is returned with an empty `matched` and
            `groupdict`/`match` set to None
        """
        self.logger.debug("readuntil_re: %s", regex)

        try:
            match = await self.wait_for(lambda data: regex.search(data, start), timeout)

            m_beg, m_end = match.span()
            # We are matching against the data stored in bytebuffer.
            # The bytebuffer is manipulated in place. After we read the data
            # the buffer may get overwritten. The match object seems to be
            # directly referring the data in bytebuffer. This causes a problem
            # when we try to find the matched groups in match object.
            #
            #  In [38]: data = bytearray(b"localhost login:")
            #
            #  In [39]: rex = re.compile(b'(?P<login>.*((?<!Last ).ogin|.sername):)|(?P<passwd>\n.*assword:)|(?P<prompt>\n.*[%#>])|(?P<ignore>( to cli \\])|(who is on this device.\\]\r\n)|(Press R
            #      ...: ETURN to get started\r\n))\\s*$')
            #
            #  In [40]: m = rex.search(data)
            #
            #  In [41]: m.groupdict()
            #  Out[41]: {'ignore': None, 'login': b'localhost login:', 'passwd': None, 'prompt': None}
            #
            #  In [42]: data[:]=b'overwrite'
            #
            #  In [43]: m.groupdict()
            #  Out[43]: {'ignore': None, 'login': b'overwrite', 'passwd': None, 'prompt': None}
            #
            # Hence the groups are extracted before the buffer is consumed.
            groupdict = match.groupdict()
            rdata = await self.read(m_end)
            data = rdata[:m_beg]  # Data before the regex match
            matched = rdata[m_beg:m_end]  # portion that matched regex
        except AssertionError as exc:
            if self._eof:  # pyre-ignore
                # We are at the EOF. Read the whole buffer and send it back
                data = await self.read(len(self._buffer))  # pyre-ignore
                matched = b""
                match = None
                groupdict = None
            else:
                # re-raise the exception
                raise AssertionErrorException(str(exc)) from exc

        return ResponseMatch(data, matched, groupdict, match)

    async def drain(self) -> bytes:
        """
        Drain the read buffer. Typically used before sending a new command to
        make sure the stream is in a sane state
        """
        return await self.read(len(self._buffer))  # pyre-ignore


class CommandStream(asyncio.StreamReaderProtocol):
    """Stream protocol that feeds a CommandStreamReader for a CLI session."""

    # TODO: make this tweakable from configerator
    _BUFFER_LIMIT = 100 * (2 ** 20)  # 100M

    def __init__(
        self, session: "CliCommandSession", loop: asyncio.AbstractEventLoop
    ) -> None:
        super().__init__(
            CommandStreamReader(session, limit=self._BUFFER_LIMIT, loop=loop),
            client_connected_cb=self._on_connect,
            loop=loop,
        )
        self._session = session
        self._loop = loop

    def _on_connect(
        self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
    ) -> None:
        """
        called when transport is connected
        """
        # Sometimes the remote side doesn't send the newline for the first
        # prompt. This causes our prompt matching to fail. Here we inject a
        # newline to normalize these cases. This keeps our prompt processing
        # simple.
        super().data_received(b"\n")
        self._session._session_connected(stream_reader, stream_writer)

    def close(self) -> None:
        if self._stream_writer:  # pyre-ignore
            self._stream_writer.close()

    def data_received(self, data: bytes, datatype=None) -> None:
        # TODO: check if we need to handle stderr data separately
        # for stderr data: datatype == EXTENDED_DATA_STDERR
        return super().data_received(data)

    def session_started(self) -> None:
        # Not used yet. But needs to be defined
        pass

    def exit_status_received(self, status: str) -> None:
        self._session.exit_status_received(status)


class CliCommandSession(CommandSession):
    """
    A command session for CLI commands. Does prompt processing on the command
    stream.
    """

    # Any single byte followed by a backspace, or a BEL character
    _SPECIAL_CHAR_REGEX = re.compile(br".\x08|\x07")
    _NEWLINE_REPLACE_REGEX = re.compile(br"(\r+\n)|(\n\r+)|\r")

    def __init__(
        self,
        service: "FcrServiceBase",
        devinfo: "DeviceInfo",
        options: typing.Dict[str, typing.Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(service, devinfo, options, loop)

        self._cmd_stream = None
        self._stream_reader = None  # for reading data from device
        self._stream_writer = None  # for writing data to the device
        # TODO: investigate if we need an error stream

    @classmethod
    def register_counters(cls, stats_mgr: "Counters") -> None:
        super().register_counters(stats_mgr)

        stats_mgr.register_counter("streamreader.wait_for_retry")
        stats_mgr.register_counter("streamreader.overrun")

    async def _setup_connection(self) -> None:
        # At this point login process should already be complete. If a
        # particular session needs to send password, it should override this
        # method and complete the login before calling this method
        await self.wait_prompt()

        for cmd in self._pre_setup_commands:
            self.logger.debug(f"Sending pre setup command: {cmd}")
            await self.run_command(cmd.encode("utf-8") + b"\n")

        for cmd in self._devinfo.vendor_data.cli_setup:
            self.logger.debug(f"Sending setup command: {cmd}")
            await self.run_command(cmd + b"\n")

    async def _create_connection(self) -> None:
        await super()._create_connection()
        await self.wait_until_connected(self.open_timeout)
        await self._setup_connection()

    async def wait_prompt(
        self,
        prompt_re: typing.Optional[typing.Pattern] = None,
        timeout: typing.Optional[int] = None,
    ) -> ResponseMatch:
        """
        Wait for a prompt

        Only the trailing _MAX_PROMPT_SIZE bytes of the buffer are searched.
        The device's default prompt regex is used when `prompt_re` is not given.
        """
        return await self._stream_reader.readuntil_re(
            prompt_re or self._devinfo.prompt_re,
            timeout,
            -self._MAX_PROMPT_SIZE,
        )

    async def _wait_response(
        self, prompt_re: typing.Pattern, timeout: int
    ) -> ResponseMatch:
        """
        Wait for command response from the device
        """
        self.logger.debug("Waiting for prompt")
        resp = await self.wait_prompt(prompt_re=prompt_re, timeout=timeout)
        return resp

    def _fixup_whitespace(self, output: bytes) -> bytes:
        # we need to sanitize the output to remove '\r' and other chars.
        # List of chars that will be removed
        #    '.\x08': any byte followed by a backspace character
        #    '\x07' : BEL(bell) char
        output = self._SPECIAL_CHAR_REGEX.sub(b"", output)

        #
        # We need to apply following transforms
        #  '\r+\n'  -> '\n'
        #  '\n\r+'  -> '\n'
        #  '\r'     -> '\n'   standalone \r
        output = self._NEWLINE_REPLACE_REGEX.sub(b"\n", output)

        return output.strip()

    def _format_output(self, cmd: bytes, resp: ResponseMatch) -> bytes:
        r"""
        Format the output to comply with following format

            <prompt> <command>
            command-output
            ...

        In addition '\r\n' | '\n\r' | '\r' will be replaced with '\n'

        :param cmd: the command that was sent to the device
        :param resp: the response read up to (and including) the prompt
        :returns: the stripped prompt, a space, then the sanitized output
        """
        cmd_words = cmd.split()

        # Fixup the white spaces first, as some devices are inserting backspace
        # characters in the command echo
        cmd_output = self._fixup_whitespace(resp.data)

        # Command regex in the output
        #   [SPACE]{Command string}[SPACE]
        # The words in the command string can be separated by multiple spaces.
        # for e.g regex for matching 'show version' command would be
        #   b'^\s*show\s+version\s*$'
        # We also need to escape the words to handle characters like '|'
        cmd_words_esc = (re.escape(w) for w in cmd_words)
        cmd_re = br"^\s*" + br"\s+".join(cmd_words_esc) + br"([ \t]*\n)*"

        # Now replace the 'command string' in the output with a sanitized
        # version (redundant spaces removed)
        #   '  show   version  ' ==> 'show version'
        #
        # The replacement is supplied through a callable so that it is inserted
        # literally: a bytes replacement is treated as a template by re.sub,
        # and a command containing a backslash would be mangled or rejected.
        sanitized_cmd = b" ".join(cmd_words) + b"\n"
        cmd_output = re.sub(
            cmd_re, lambda _match: sanitized_cmd, cmd_output, count=1, flags=re.M
        )

        # Now we need to prepend the prompt to the command output. The prompt is
        # the matched part in the 'resp'
        output = resp.matched.strip() + b" " + cmd_output

        return output

    async def _run_command(
        self,
        command: bytes,
        timeout: typing.Optional[int] = None,
        prompt_re: typing.Optional[typing.Pattern] = None,
    ) -> bytes:
        """
        Run a command and return response to user

        `command` may hold several lines; each line is sent separately and the
        formatted responses are joined with newlines.

        :raises RuntimeErrorException: if the session is not connected
        :raises CommandExecutionTimeoutErrorException: if a response does not
            arrive in time
        """
        if not self._connected:
            raise RuntimeErrorException(
                "Not Connected", f"status: {self.exit_status!r}", self.key
            )

        # Ideally there should be no data on the stream. We will in any case
        # drain any stale data. This is mostly for debugging and making sure
        # that we are in sane state
        stale_data = await self._stream_reader.drain()
        if len(stale_data) != 0:
            self.logger.warning(f"Stale data on session: {stale_data}")

        output = []
        for single_command in command.splitlines():
            cmdinfo = self._devinfo.get_command_info(
                single_command,
                self._opts.get("command_prompts"),
                self._opts.get("clear_command"),
            )

            self.logger.info(f"RUN: {cmdinfo.cmd!r}")

            # Send any precmd data (e.g. \x15 to clear the commandline)
            if cmdinfo.precmd:
                self._stream_writer.write(cmdinfo.precmd)

            self._stream_writer.write(cmdinfo.cmd)

            try:
                prompt = prompt_re or cmdinfo.prompt_re
                cmd_timeout = timeout or self._devinfo.vendor_data.cmd_timeout_sec

                resp = await asyncio.wait_for(
                    self._wait_response(prompt, cmd_timeout),
                    cmd_timeout,
                    loop=self._loop,
                )
                output.append(self._format_output(single_command, resp))
            except asyncio.TimeoutError:
                self.logger.error("Timeout waiting for command response")
                data = await self._stream_reader.drain()
                raise CommandExecutionTimeoutErrorException(
                    "Command Response Timeout", data[-200:]
                )

        return b"\n".join(output).rstrip()

    def _session_connected(
        self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter
    ) -> None:
        """
        This called once the session is connected to the transport.
        stream_reader and stream_writer are used for receiving and sending
        data on the session
        """
        self._stream_reader = stream_reader
        self._stream_writer = stream_writer
        self._connected = True

        # Notify anyone waiting for session to be connected
        asyncio.ensure_future(self._notify(), loop=self._loop)

    def exit_status_received(self, status: str) -> None:
        self.logger.info(f"exit status received: {status}")
        self._connected = False
        self._exit_status = str(status)


class SSHCommandClient(asyncssh.SSHClient):
    """
    The connection objects are leaked if the session times out while the
    authentication is in progress. The fix ideally needs to be implemented in
    asyncssh. For now we are adding a workaround in FCR.

    We will save the connection object when we get a connection_made callback.
    This will be used to close the connection when we close the session.
    """

    def __init__(self, session: "SSHCommandSession") -> None:
        super().__init__()
        self._session = session

    def connection_made(self, conn: asyncssh.SSHClientConnection) -> None:
        super().connection_made(conn)
        self._session.connection_made(conn)


class SSHCommandSession(CliCommandSession):
    """A CLI command session carried over SSH."""

    TERM_TYPE: typing.Optional[str] = "vt100"

    def __init__(
        self,
        service: "FcrServiceBase",
        devinfo: "DeviceInfo",
        options: typing.Dict[str, typing.Any],
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__(service, devinfo, options, loop)

        self._conn = None
        self._chan = None

    def connection_made(self, conn: asyncssh.SSHClientConnection) -> None:
        """Record the connection (and its socket details) as soon as it exists."""
        s = conn.get_extra_info("socket")
        self._extra_info["fd"] = s.fileno()
        self._extra_info["sockname"] = conn.get_extra_info("sockname")
        self._conn = conn

    def _client_factory(self) -> SSHCommandClient:
        return SSHCommandClient(self)

    async def dest_info(self) -> typing.Tuple[List[IPInfo], int, str, str]:
        """Return (candidate IPs, port, username, password) for the device."""
        ip_list = self.service.ip_utils.get_ip(
            options=self._opts, devinfo=self._devinfo, service=self.service
        )
        port = int(
            self._extra_options.get("port") or self._devinfo.vendor_data.get_port()
        )
        return (ip_list, port, self._username, self._password)

    # pyre-fixme: Inconsistent override
    async def _connect(
        self,
        subsystem: typing.Optional[str] = None,
        exec_command: typing.Optional[str] = None,
    ) -> asyncssh.SSHTCPSession:
        """
        Some session types require us to run a command to start a session.
        The SSH protocol defines three ways to start a session.

        1. shell: this starts a regular user shell on the remote system. This
           is the most common way of using SSH. If none of 'subsystem' or
           'command' is specified, this is the method that we use.

        2. exec: Here we specify the command we want to run on remote system.
           This allows the user start a custom shell. For example run a
           'netconf' command to start netconf session

        3. subsystem: Here instead of running a command we specify a subsystem
           that has been configured on the remote system. These are predefined
           systems

        see sec 6.5 https://tools.ietf.org/html/rfc4254 for more details

        The IP of the last attempt is always recorded in the "peer" entry of
        the session's extra info, whether or not the attempt succeeded.
        """
        ip_list, port, user, passwd = await self.dest_info()

        self.logger.debug(f"Order in which ips will be tried: {ip_list}")
        self._extra_info["peer_list"] = PeerInfoList(ip_list, port)

        if self.device and not self.device.failover_to_backup_ips:
            # Use the first IP in the list if failover is not enabled
            ip, ip_is_pingable = ip_list[0]
            try:
                return await self._connect_to_ip(
                    ip,
                    port,
                    user,
                    passwd,
                    subsystem,
                    exec_command,
                )
            finally:
                self._extra_info["peer"] = PeerInfo(ip, ip_is_pingable, port)

        ips_tried = []
        for index, (ip, ip_is_pingable) in enumerate(ip_list):
            try:
                return await self._connect_to_ip(
                    ip,
                    port,
                    user,
                    passwd,
                    subsystem,
                    exec_command,
                )
            except Exception as e:
                self.logger.exception(f"Connection to {ip} failed")
                ips_tried.append(ip)

                # Raise the last exception in the iteration
                if index == len(ip_list) - 1:
                    msg = f"IPs that failed to connect: {ips_tried}"

                    # Gather the information from the original exception:
                    exc_type, exc_value, exc_traceback = sys.exc_info()
                    traceback_string = "".join(
                        traceback.format_exception(exc_type, exc_value, exc_traceback)
                    )

                    # Re-raise a new exception of the same class as the original
                    # one, using custom message and the original traceback
                    if isinstance(e, asyncssh.misc.DisconnectError):
                        raise type(e)(code=e.code, reason=f"{msg}:{e.reason}")
                    raise type(e)(f"{msg}:{traceback_string}")
            finally:
                self._extra_info["peer"] = PeerInfo(ip, ip_is_pingable, port)

        # Only reached when ip_list is empty
        raise LookupErrorException(
            f"No Valid IP address was found for the device {self._hostname}: {ip_list}"
        )

    async def _connect_to_ip(
        self,
        ip: str,
        port: int,
        user: str,
        passwd: str,
        subsystem: typing.Optional[str] = None,
        exec_command: typing.Optional[str] = None,
    ) -> asyncssh.SSHTCPSession:
        """
        Open an SSH connection to `ip` and start a session on it.

        The time spent connecting and creating the session is added to the
        session's external communication time.

        :returns: the command stream of the new SSH session
        """
        if self.service.ip_utils.proxy_required(ip):
            host = self.service.get_http_proxy_url(ip)
        elif self.service.ip_utils.should_nat(ip, self.service):
            host = await self.service.ip_utils.translate_address(ip, self.service)
        else:
            host = ip

        self.logger.info("Connecting to: %s: %d", host, port)

        open_connection_time_s = time.perf_counter()

        # known_hosts is set to None to disable the host verifications. Without
        # this the connection setup fails for some devices
        conn, _ = await asyncssh.create_connection(
            self._client_factory,
            host=host,
            port=port,
            username=user,
            password=passwd,
            client_keys=None,
            known_hosts=None,
        )

        # self._conn was stored by the connection_made callback of the client
        # created through _client_factory.
        chan, cmd_stream = await self._conn.create_session(
            lambda: CommandStream(self, self._loop),
            encoding=None,
            term_type=self.TERM_TYPE,
            subsystem=subsystem,
            command=exec_command,
        )
        self._chan = chan

        end_connection_time_s = time.perf_counter()
        self.captured_time_ms.increment_external_communication_time_ms(
            (end_connection_time_s - open_connection_time_s) * 1000
        )

        return cmd_stream

    async def _close(self) -> None:
        if self._chan is not None:
            self._chan.close()
        if self._conn is not None:
            self._conn.close()