src/common/utils.py (234 lines of code) (raw):

# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file.
# This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied.
# See the License for the specific language governing permissions and limitations under the License.

import collections
import contextlib
import functools
import itertools
import json
import logging
import os
import pwd
import shlex

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess  # nosec B404
import sys
import time
from datetime import datetime, timezone
from enum import Enum

log = logging.getLogger(__name__)


class CriticalError(Exception):
    """Critical error for the daemon."""

    pass


class EventType(Enum):
    """Enumeration with the two members ADD and REMOVE."""

    ADD = "ADD"
    REMOVE = "REMOVE"


# Lightweight immutable records; field meaning is defined by the callers that build them.
Host = collections.namedtuple("Host", ["instance_id", "hostname", "slots", "gpus"])
UpdateEvent = collections.namedtuple("UpdateEvent", ["action", "message", "host"])


class SlurmCommandError(Exception):
    """Error raised when a command wrapped by SlurmCommandErrorHandler fails."""

    def __init__(self, message: str):
        super().__init__(message)


class SlurmCommandErrorHandler:
    """Handle SlurmCommandError."""

    @staticmethod
    def handle_slurm_command_error(func):
        """
        Handle slurm command errors, can be used as a decorator.

        The wrapped callable is executed unchanged. If it raises subprocess.CalledProcessError, the error
        (including the captured stderr/stdout, when present) is logged and re-raised as SlurmCommandError.

        :param func: the callable to wrap
        :return: the wrapping callable
        :raise: SlurmCommandError if the wrapped callable raises subprocess.CalledProcessError
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except subprocess.CalledProcessError as e:
                error = f" Error is: {e.stderr.rstrip()}." if e.stderr else ""
                output = f" Output is: {e.stdout.rstrip()}." if e.stdout else ""
                msg = f"Failed to execute slurm command.{error}{output} {e}"
                log.error(msg)
                # Chain explicitly so the original CalledProcessError is available as __cause__.
                raise SlurmCommandError(msg) from e

        return wrapper


def load_module(module):
    """
    Load python module.

    :param module: dotted name of the module to import
    :return: the loaded module object
    """
    # import module
    __import__(module)
    # __import__ returns the top-level package for dotted names, so fetch the leaf module from sys.modules
    scheduler_module = sys.modules[module]
    return scheduler_module


def check_command_output(
    command, env=None, raise_on_error=True, execute_as_user=None, log_error=True, timeout=60, shell=False
):
    """
    Execute shell command and retrieve command output.

    Usage of this function will result in a B604 bandit violation. When building the command string argument,
    if using an external argument, please validate it using validate_subprocess_argument and/or
    validate_absolute_path functions based on the argument type.

    stderr is redirected to stdout, so the returned text contains both streams.

    :param command: command to execute; a string is split with shlex.split unless shell is True
    :param env: a dictionary containing environment variables
    :param raise_on_error: True to raise subprocess.CalledProcessError on errors
    :param execute_as_user: the user executing the command
    :param log_error: control whether to log or not an error
    :param timeout: seconds after which subprocess.run aborts the command
    :param shell: True to run the command through the shell
    :return: the command output
    :raise: subprocess.CalledProcessError if the command fails and raise_on_error is True;
        subprocess.TimeoutExpired is not handled here and propagates to the caller
    """
    if isinstance(command, str) and not shell:
        command = shlex.split(command)
    # A nosec B602 comment is appended to the following line in order to disable the B602 check.
    # This check is disabled for the following reasons:
    # - Some callers (e.g., common slurm commands) require the use of `shell=True`.
    # - All values passed as the command arg are constructed from known inputs and are properly validated.
    result = _run_command(
        lambda _command, _env, _preexec_fn: subprocess.run(
            _command,
            env=_env,
            preexec_fn=_preexec_fn,
            timeout=timeout,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            encoding="utf-8",
            shell=shell,  # nosec B602
        ),
        command,
        env,
        raise_on_error,
        execute_as_user,
        log_error,
    )

    # With raise_on_error=False the result may be the CalledProcessError itself, which also exposes stdout.
    return result.stdout if hasattr(result, "stdout") else ""


def run_command(command, env=None, raise_on_error=True, execute_as_user=None, log_error=True, timeout=60, shell=False):
    """
    Execute shell command.

    Usage of this function will result in a B604 bandit violation. When building the command string argument,
    if using an external argument, please validate it using validate_subprocess_argument and/or
    validate_absolute_path functions based on the argument type.

    :param command: command to execute; a string is split with shlex.split unless shell is True
    :param env: a dictionary containing environment variables
    :param raise_on_error: True to raise subprocess.CalledProcessError on errors
    :param execute_as_user: the user executing the command
    :param log_error: control whether to log or not an error
    :param timeout: seconds after which subprocess.run aborts the command
    :param shell: True to run the command through the shell
    :raise: subprocess.CalledProcessError if the command fails and raise_on_error is True;
        subprocess.TimeoutExpired is not handled here and propagates to the caller
    """
    if isinstance(command, str) and not shell:
        command = shlex.split(command)
    # A nosec B602 comment is appended to the following line in order to disable the B602 check.
    # This check is disabled for the following reasons:
    # - Some callers (e.g., common slurm commands) require the use of `shell=True`.
    # - All values passed as the command arg are constructed from known inputs and are properly validated.
    _run_command(
        lambda _command, _env, _preexec_fn: subprocess.run(
            _command,
            env=_env,
            preexec_fn=_preexec_fn,
            timeout=timeout,
            check=True,
            encoding="utf-8",
            shell=shell,  # nosec B602
        ),
        command,
        env,
        raise_on_error,
        execute_as_user,
        log_error,
    )


def _demote(user_uid, user_gid):
    """Return a callable that switches the current process to the given gid and uid."""

    def set_ids():
        # The group is changed first: once the uid is dropped the process may no longer be allowed to change gid.
        # NOTE(review): only setgid/setuid are called; supplementary groups are left untouched — confirm
        # this is the intended privilege model.
        os.setgid(user_gid)
        os.setuid(user_uid)

    return set_ids


def _run_command(command_function, command, env=None, raise_on_error=True, execute_as_user=None, log_error=True):
    """
    Run command through command_function, optionally as a different user, with shared error handling.

    :param command_function: callable invoked as command_function(command, env, preexec_fn)
    :param command: the command forwarded to command_function
    :param env: extra environment variables; the mapping passed by the caller is not modified
    :param raise_on_error: True to re-raise subprocess.CalledProcessError, False to return it
    :param execute_as_user: name of the user to run the command as, resolved with pwd.getpwnam
    :param log_error: control whether a CalledProcessError is logged
    :return: whatever command_function returns, or the CalledProcessError when raise_on_error is False
    :raise: subprocess.CalledProcessError (when raise_on_error is True) and OSError
    """
    try:
        # Build a fresh mapping instead of updating the caller's dict in place.
        # NOTE(review): values from os.environ take precedence over the ones in env, as in the
        # previous implementation — confirm this precedence is intended.
        env = {**(env or {}), **os.environ}
        if execute_as_user:
            log.debug("Executing command as user '%s': %s", execute_as_user, command)
            pw_record = pwd.getpwnam(execute_as_user)
            preexec_fn = _demote(pw_record.pw_uid, pw_record.pw_gid)
            return command_function(command, env, preexec_fn)
        else:
            log.debug("Executing command: %s", command)
            return command_function(command, env, None)
    except subprocess.CalledProcessError as e:
        # CalledProcessError.__str__ already produces a significant error message
        if raise_on_error:
            if log_error:
                log.error(e)
            raise
        else:
            if log_error:
                log.warning(e)
            return e
    except OSError as e:
        log.error("Unable to execute the command %s. Failed with exception: %s", command, e)
        raise


def sleep_remaining_loop_time(total_loop_time, loop_start_time=None):
    """Sleep for what is left of total_loop_time seconds, measured from loop_start_time."""
    wait_remaining_time(time.sleep, total_wait_time=total_loop_time, wait_start_time=loop_start_time)


def wait_remaining_time(wait_function, total_wait_time, wait_start_time=None):
    """
    Call wait_function with the part of total_wait_time that has not elapsed yet.

    :param wait_function: callable receiving the remaining number of seconds
    :param total_wait_time: total time to wait, in seconds
    :param wait_start_time: datetime at which the wait started; defaults to now, i.e. the full time is waited
    """
    end_time = datetime.now(tz=timezone.utc)
    if not wait_start_time:
        wait_start_time = end_time
    # Always convert the received loop_start_time to utc timezone. This is so that we never rely on the system local
    # time and risk to compare naive datatime instances with localized ones
    wait_start_time = wait_start_time.astimezone(tz=timezone.utc)
    time_delta = (end_time - wait_start_time).total_seconds()
    # Nothing is waited if the start time is in the future or the whole interval has already elapsed.
    if 0 <= time_delta < total_wait_time:
        wait_function(total_wait_time - time_delta)


def grouper(iterable, n):
    """Slice iterable into chunks of size n; chunks are tuples and the last one may be shorter."""
    it = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


def load_additional_instance_types_data(config, section):
    """
    Load instance types data from configuration, if set; an empty dict is returned otherwise.

    :param config: object exposing has_option(section, option) and get(section, option)
    :param section: the configuration section holding the "instance_types_data" option
    :return: the decoded JSON value, or an empty dict when the option is missing, blank or decodes to a falsy value
    :raise: CriticalError if the value cannot be loaded
    """
    instance_types_data = {}
    if config.has_option(section, "instance_types_data"):
        instance_types_data_str = config.get(section, "instance_types_data")
        if instance_types_data_str:
            try:
                instance_types_data_str = str(instance_types_data_str).strip()

                # Load json value if not empty
                if instance_types_data_str:
                    instance_types_data = json.loads(instance_types_data_str)

                # Fallback to empty dict if value is None
                if not instance_types_data:
                    instance_types_data = {}

                log.info(
                    "Additional instance types data loaded for instance types '%s': %s",
                    instance_types_data.keys(),
                    instance_types_data,
                )
            except Exception as e:
                raise CriticalError(f"Error loading instance types data from configuration: {e}") from e

    return instance_types_data


def convert_range_to_list(node_range):
    """
    Convert a number range to a list.

    Example input: Input can be like one of the format: "1-3", "1-2,6", "2, 8"
    Example output: [1, 2, 3]

    :param node_range: comma separated list of integers and inclusive "start-end" ranges
    :return: the list of integers, in the order they appear in node_range
    :raise: ValueError if an item is not an integer or a "start-end" pair of integers
    """
    values = []
    for item in node_range.split(","):
        if "-" in item:
            bounds = item.split("-")
            if len(bounds) != 2:
                raise ValueError(f"Invalid range '{item}' in '{node_range}'")
            # The upper bound of the textual range is inclusive
            values.extend(range(int(bounds[0]), int(bounds[1]) + 1))
        else:
            values.append(int(item))
    return values


def time_is_up(initial_time: datetime, current_time: datetime, grace_time: float):
    """
    Check if timeout is exceeded.

    :param initial_time: start of the interval; None is treated as "time is up"
    :param current_time: end of the interval
    :param grace_time: allowed interval length, in seconds
    :return: True if at least grace_time seconds elapsed between initial_time and current_time
    """
    # Localize datetime objects to UTC if not previously localized
    # All timestamps used in this function should be already localized
    # Assume timestamp was taken from system local time if there is no localization info
    if initial_time is None:
        return True
    if not initial_time.tzinfo:
        log.warning(
            "Timestamp %s is not localized. Please double check that this is expected, localizing to UTC.", initial_time
        )
        initial_time = initial_time.astimezone(tz=timezone.utc)
    if not current_time.tzinfo:
        log.warning(
            "Timestamp %s is not localized. Please double check that this is expected, localizing to UTC", current_time
        )
        current_time = current_time.astimezone(tz=timezone.utc)
    time_diff = (current_time - initial_time).total_seconds()
    return time_diff >= grace_time


def read_json(file_path, default=None):
    """
    Read json file into a dict.

    :param file_path: path of the JSON file to read
    :param default: value returned when the file cannot be read; when None the error is raised instead
    :return: the decoded JSON content, or default
    :raise: the original exception if the file cannot be read and default is None
    """
    try:
        # JSON documents are UTF-8 encoded: do not depend on the locale default encoding
        with open(file_path, encoding="utf-8") as mapping_file:
            return json.load(mapping_file)
    except Exception as e:
        if default is None:
            log.error("Unable to read file from '%s'. Failed with exception: %s", file_path, e)
            raise
        else:
            # A missing file is the expected reason to fall back to the default, so it is not logged
            if not isinstance(e, FileNotFoundError):
                log.info("Unable to read file '%s' due to an exception: %s. Using default: %s", file_path, e, default)
            return default


def validate_subprocess_argument(argument):
    """
    Validate an argument used to build a subprocess command.

    The validation is done forcing the encoding to be the standard Python Unicode / UTF-8
    and searching for forbidden patterns.

    :param argument: an argument string to validate
    :raise: ValueError if the argument contains a forbidden pattern
    :return: True if the argument does not contain forbidden patterns
    """
    forbidden_patterns = ["&", "|", ";", "$", ">", "<", "`", "\\", "!", "#", "\n"]

    # Forcing the encoding to be the standard Python Unicode / UTF-8
    # https://docs.python.org/3/howto/unicode.html
    # https://docs.python.org/3/library/codecs.html#standard-encodings
    _argument = (str(argument).encode("utf-8", "ignore")).decode()

    if any(pattern in _argument for pattern in forbidden_patterns):
        raise ValueError("Value of provided argument contains at least a forbidden pattern")
    return True


def validate_absolute_path(path):
    """
    Validate if a path string represents a valid absolute path.

    :param path: path to validate
    :raise: ValueError if the path is not a valid absolute path
    :return: True if the path is a valid absolute path
    """
    if not os.path.isabs(path):
        raise ValueError(f"The path {path} is not a valid absolute path")
    return True


@contextlib.contextmanager
def setup_logging_filter(logger: logging.Logger, custom_field: str):
    """
    Set up a custom logging filter and remove it at the end of the context.

    The yielded filter prefixes every record message with "<custom_field> <value> - " once a value
    has been set through set_custom_value; no record is ever dropped.

    :param logger: the logger the filter is attached to for the duration of the context
    :param custom_field: the label used in the message prefix
    """

    class CustomFilter(logging.Filter):
        def __init__(self, custom_field: str):
            super().__init__()
            self.field = custom_field
            self.value = None

        def set_custom_value(self, custom_value: str):
            self.value = custom_value

        def filter(self, record: logging.LogRecord) -> bool:
            if self.value:
                record.msg = f"{self.field} {self.value} - {record.msg}"
            return True

    custom_filter = CustomFilter(custom_field)
    logger.addFilter(custom_filter)
    try:
        yield custom_filter
    finally:
        # Remove the custom log filter
        logger.removeFilter(custom_filter)


class ApiMocker:
    """API mocker."""

    @staticmethod
    def mockable(func):
        """
        Try to mock passed function by searching for an overrides.py file in the same path of the given func.

        This function can be used as a decorator and applied to any method.
        The function will check if a function called with the name of the given function
        exists in the <function-dir>/overrides.py, and if it does, the function will execute it.
        If the override cannot be found the original function is executed. Functions defined in a
        top-level module (no parent package) are never overridden.

        E.g. if the method with ApiMocker.mockable decorator is defined in Ec2Client class of the
        ${node_virtualenv_path}/aws/ec2.py module, the mocked function should be defined in the
        ${node_virtualenv_path}/aws/overrides.py file.
        """

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            override = None
            function_name = None
            # Only the lookup is guarded: errors raised by the override itself must reach the caller
            # instead of silently triggering the real function.
            try:
                function_name = func.__name__
                # retrieve parent module of the given function that has the ApiMocker.mockable decorator
                parent_module, has_parent, _ = func.__module__.rpartition(".")
                if has_parent:
                    # try to import overrides.py module in the same folder of the module to mock
                    overrides_module = __import__(f"{parent_module}.overrides", fromlist=[function_name])
                    override = getattr(overrides_module, function_name)
            except (ImportError, AttributeError):
                override = None

            if override is None:
                return func(*args, **kwargs)

            log.info("Calling %s override with args: %s and kwargs: %s", function_name, args, kwargs)
            return override(*args, **kwargs)

        return wrapper