cookbooks/aws-parallelcluster-computefleet/files/clusterstatusmgtd/clusterstatusmgtd.py
# Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the
# License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=W0719
# FIXME: Fix Code Duplication
# pylint: disable=R0801
import functools
import json
import logging
import os
import shlex
# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess # nosec B404
import time
from configparser import ConfigParser
from datetime import datetime, timezone
from enum import Enum
from logging.config import fileConfig
import boto3
from boto3.dynamodb.conditions import Attr
from botocore.config import Config
CONFIG_FILE_DIR = "/etc/parallelcluster"
LOOP_TIME = 60
log = logging.getLogger(__name__)
# Utils
def _seconds(sec):
"""Convert seconds to milliseconds."""
return sec * 1000
def _minutes(mins):
    """Convert minutes to seconds."""
    return mins * 60
def _sleep_remaining_loop_time(total_loop_time, loop_start_time=None):
    """Sleep for the time remaining to complete a loop iteration of total_loop_time seconds."""
    end_time = datetime.now(tz=timezone.utc)
    if not loop_start_time:
        loop_start_time = end_time
    # Always convert the received loop_start_time to the UTC timezone, so that we never rely on the system local
    # time and risk comparing naive datetime instances with timezone-aware ones
    loop_start_time = loop_start_time.astimezone(tz=timezone.utc)
    time_delta = (end_time - loop_start_time).total_seconds()
    if 0 <= time_delta < total_loop_time:
        time.sleep(total_loop_time - time_delta)
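# Illustrative example of _sleep_remaining_loop_time (values assumed, not from the original code): with
# total_loop_time=60 and a loop body that took 12 seconds since loop_start_time, the function sleeps for the
# remaining ~48 seconds; if the body already exceeded 60 seconds or the clock moved backwards, it returns
# immediately.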
def log_exception(
logger,
action_desc,
log_level=logging.ERROR,
catch_exception=Exception,
raise_on_error=True,
exception_to_raise=None,
):
def decorator_log_exception(function):
@functools.wraps(function)
def wrapper_log_exception(*args, **kwargs): # pylint: disable=R1710
try:
return function(*args, **kwargs)
except catch_exception as e:
logger.log(log_level, "Failed when %s with exception %s", action_desc, e)
if raise_on_error:
if exception_to_raise:
raise exception_to_raise
raise
return None
return wrapper_log_exception
return decorator_log_exception
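# Illustrative usage of log_exception (it mirrors how the decorator is applied to methods later in this module;
# the function name below is hypothetical):
#   @log_exception(log, "reading configuration", catch_exception=IOError, raise_on_error=True)
#   def _read_configuration(path): ...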
def _run_command( # noqa: C901
command,
capture_output=False,
log_error=True,
env=None,
timeout=None,
raise_on_error=True,
):
"""Execute shell command."""
if isinstance(command, str):
command = shlex.split(command)
    log_command = " ".join(str(arg) for arg in command)  # command is always a list at this point
log.info("Executing command: %s", log_command)
    result = None
    try:
        result = subprocess.run(  # nosec - trusted input
command,
capture_output=capture_output,
universal_newlines=True,
encoding="utf-8",
env=env,
timeout=timeout,
check=False,
)
result.check_returncode()
except subprocess.CalledProcessError:
if log_error:
log.error(
"Command %s failed",
log_command,
)
if raise_on_error:
raise
except subprocess.TimeoutExpired:
if log_error:
log.error("Command %s timed out after %s sec", log_command, timeout)
if raise_on_error:
raise
return result
def _write_json_to_file(filename, json_data):
"""Write json to file."""
with open(filename, "w", encoding="utf-8") as file:
file.write(json.dumps(json_data))
class ComputeFleetStatus(Enum):
"""Represents the status of the cluster compute fleet."""
STOPPED = "STOPPED" # Fleet is stopped, partitions are inactive.
RUNNING = "RUNNING" # Fleet is running, partitions are active.
STOPPING = "STOPPING" # clusterstatusmgtd is handling the stop request.
STARTING = "STARTING" # clusterstatusmgtd is handling the start request.
STOP_REQUESTED = "STOP_REQUESTED" # A request to stop the fleet has been submitted.
START_REQUESTED = "START_REQUESTED" # A request to start the fleet has been submitted.
# PROTECTED indicates that some partitions have consistent bootstrap failures. Affected partitions are inactive.
PROTECTED = "PROTECTED"
    # Mapping from the internal compute fleet statuses to the statuses exposed to the update event handler
EVENT_HANDLER_STATUS_MAPPING = {STOPPING: STOP_REQUESTED, STARTING: START_REQUESTED}
UNKNOWN = "UNKNOWN"
def __str__(self): # noqa: D105
return str(self.value)
@staticmethod
    def _transform_compute_fleet_data(compute_fleet_data):
        """Map the internal compute fleet status to the status exposed to the update event handler."""
try:
compute_fleet_data[ComputeFleetStatusManager.COMPUTE_FLEET_STATUS_ATTRIBUTE] = (
ComputeFleetStatus.EVENT_HANDLER_STATUS_MAPPING.value.get(
compute_fleet_data.get(ComputeFleetStatusManager.COMPUTE_FLEET_STATUS_ATTRIBUTE),
str(ComputeFleetStatus.UNKNOWN),
)
)
return compute_fleet_data
except AttributeError as e:
raise Exception(f"Unable to parse compute fleet status data: {e}")
@staticmethod
def is_start_in_progress(status): # noqa: D102
return status in {ComputeFleetStatus.START_REQUESTED, ComputeFleetStatus.STARTING}
@staticmethod
def is_stop_in_progress(status): # noqa: D102
return status in {ComputeFleetStatus.STOP_REQUESTED, ComputeFleetStatus.STOPPING}
@staticmethod
def is_protected_status(status): # noqa: D102
return status == ComputeFleetStatus.PROTECTED
class ComputeFleetStatusManager:
"""
    Manage the compute fleet status stored in the DynamoDB table.
    The value stored in the table is a JSON document of the following form:
{
"status": "STOPPING",
"lastStatusUpdatedTime": "2021-12-21 18:12:07.485674+00:00",
"queues": {
"queue_name1": {
"status": "RUNNING",
"lastStatusUpdatedTime": "2021-12-21 18:10:02.485674+00:00",
}
}
}
"""
DB_KEY = "COMPUTE_FLEET"
DB_DATA = "Data"
COMPUTE_FLEET_STATUS_ATTRIBUTE = "status"
COMPUTE_FLEET_LAST_UPDATED_TIME_ATTRIBUTE = "lastStatusUpdatedTime"
QUEUES_ATTRIBUTE = "queues"
QUEUE_STATUS_ATTRIBUTE = "status"
QUEUE_LAST_UPDATED_TIME_ATTRIBUTE = "lastStatusUpdatedTime"
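    # The attributes above describe the content of the "Data" map; the full DynamoDB item handled by this
    # class has the following shape (illustrative sketch based on the keys used in get_item/update_item):
    #   {"Id": "COMPUTE_FLEET", "Data": {"status": "...", "lastStatusUpdatedTime": "...", "queues": {...}}}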
class ConditionalStatusUpdateFailedError(Exception):
"""Raised when there is a failure in updating the status due to a change occurred after retrieving its value."""
pass
class FleetDataNotFoundError(Exception):
"""Raised when compute fleet data cannot be found in db table."""
pass
def __init__(self, table_name, boto3_config, region):
self._ddb_resource = boto3.resource("dynamodb", region_name=region, config=boto3_config)
self._table = self._ddb_resource.Table(table_name)
def get_status(self): # noqa: D102
compute_fleet_item = self._table.get_item(
ConsistentRead=True,
Key={"Id": self.DB_KEY},
)
if not compute_fleet_item or "Item" not in compute_fleet_item:
raise ComputeFleetStatusManager.FleetDataNotFoundError("COMPUTE_FLEET data not found in db table")
log.debug("Found COMPUTE_FLEET data (%s)", compute_fleet_item)
return compute_fleet_item["Item"].get(self.DB_DATA)
def update_status(self, current_status, next_status): # noqa: D102
try:
updated_attributes = self._table.update_item(
Key={"Id": self.DB_KEY},
UpdateExpression="set #dt.#st=:s, #dt.#lut=:t",
ExpressionAttributeNames={
"#dt": self.DB_DATA,
"#st": self.COMPUTE_FLEET_STATUS_ATTRIBUTE,
"#lut": self.COMPUTE_FLEET_LAST_UPDATED_TIME_ATTRIBUTE,
},
ExpressionAttributeValues={
":s": str(next_status),
":t": str(datetime.now(tz=timezone.utc)),
},
ConditionExpression=Attr(f"{self.DB_DATA}.{self.COMPUTE_FLEET_STATUS_ATTRIBUTE}").eq(
str(current_status)
),
ReturnValues="ALL_NEW",
)
            return updated_attributes.get("Attributes").get(self.DB_DATA)
except self._ddb_resource.meta.client.exceptions.ConditionalCheckFailedException as e:
raise ComputeFleetStatusManager.ConditionalStatusUpdateFailedError(e)
class ClusterstatusmgtdConfig:
"""Represents the cluster status management demon configuration."""
DEFAULTS = {
"max_retry": 5,
"loop_time": LOOP_TIME,
"proxy": "NONE",
"logging_config": os.path.join(os.path.dirname(__file__), "clusterstatusmgtd_logging.conf"),
"update_event_timeout_minutes": 15,
}
def __init__(self, config_file_path):
self._get_config(config_file_path)
def __repr__(self): # noqa: D105
attrs = ", ".join([f"{key}={repr(value)}" for key, value in self.__dict__.items()])
return f"{self.__class__.__name__}({attrs})"
def __eq__(self, other): # noqa: D105
if type(other) is type(self):
return self._config == other._config
return False
def __ne__(self, other): # noqa: D105
return not self.__eq__(other)
@log_exception(
log, "reading cluster status manger configuration file", catch_exception=IOError, raise_on_error=True
)
def _get_config(self, config_file_path):
"""Get clusterstatusmgtd configuration."""
log.info("Reading %s", config_file_path)
self._config = ConfigParser()
with open(config_file_path, "r", encoding="utf-8") as config_file:
self._config.read_file(config_file)
# Get config settings
self._get_basic_config(self._config)
def _get_basic_config(self, config):
"""Get basic config options."""
self.region = config.get("clusterstatusmgtd", "region")
self.cluster_name = config.get("clusterstatusmgtd", "cluster_name")
self.dynamodb_table = config.get("clusterstatusmgtd", "dynamodb_table")
self.computefleet_status_path = config.get("clusterstatusmgtd", "computefleet_status_path")
self.logging_config = config.get(
"clusterstatusmgtd", "logging_config", fallback=self.DEFAULTS.get("logging_config")
)
self.loop_time = config.getint("clusterstatusmgtd", "loop_time", fallback=self.DEFAULTS.get("loop_time"))
self.update_event_timeout_minutes = config.getint(
"clusterstatusmgtd",
"update_event_timeout_minutes",
fallback=self.DEFAULTS.get("update_event_timeout_minutes"),
)
        # Configure boto3 to retry 5 times by default
self._boto3_retry = config.getint("clusterstatusmgtd", "boto3_retry", fallback=self.DEFAULTS.get("max_retry"))
self._boto3_config = {"retries": {"max_attempts": self._boto3_retry, "mode": "standard"}}
# Configure proxy
proxy = config.get("clusterstatusmgtd", "proxy", fallback=self.DEFAULTS.get("proxy"))
if proxy != "NONE":
self._boto3_config["proxies"] = {"https": proxy}
self.boto3_config = Config(**self._boto3_config)
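    # Illustrative example of the configuration file parsed by ClusterstatusmgtdConfig (the key names come
    # from the getters above; the values below are placeholders, not ParallelCluster defaults):
    #   [clusterstatusmgtd]
    #   region = us-east-1
    #   cluster_name = mycluster
    #   dynamodb_table = parallelcluster-mycluster
    #   computefleet_status_path = /path/to/computefleet-status.json
    #   loop_time = 60
    #   update_event_timeout_minutes = 15
    #   boto3_retry = 5
    #   proxy = NONE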
class ClusterStatusManager:
"""The cluster status manager."""
def __init__(self, config):
"""Initialize ClusterStatusManager."""
self._config = None
self._current_time = None
self._compute_fleet_status_manager = None
self._compute_fleet_status = ComputeFleetStatus.RUNNING
self._compute_fleet_data = {}
self.set_config(config)
class ClusterStatusUpdateEventError(Exception):
"""Raised when there is a failure in updating the status due to an error on update event handler execution."""
pass
def set_config(self, config): # noqa: D102
if self._config != config:
log.info("Applying new clusterstatusmgtd config: %s", config)
self._config = config
self._compute_fleet_status_manager = self._initialize_compute_fleet_status_manager(config)
@staticmethod
def _initialize_compute_fleet_status_manager(config):
return ComputeFleetStatusManager(
table_name=config.dynamodb_table, boto3_config=config.boto3_config, region=config.region
)
def _get_compute_fleet_status(self, fallback=None):
try:
log.info("Getting compute fleet status")
self._compute_fleet_data = self._compute_fleet_status_manager.get_status()
return ComputeFleetStatus(
self._compute_fleet_data.get(self._compute_fleet_status_manager.COMPUTE_FLEET_STATUS_ATTRIBUTE)
)
except Exception as e:
log.error(
"Failed when retrieving computefleet status from DynamoDB with error %s, using fallback value %s",
e,
fallback,
)
return fallback
def _update_compute_fleet_status(self, new_status):
log.info("Updating compute fleet status from %s to %s", self._compute_fleet_status, new_status)
self._compute_fleet_data = self._compute_fleet_status_manager.update_status(
current_status=self._compute_fleet_status, next_status=new_status
)
self._compute_fleet_status = new_status
def _call_update_event(self):
try:
compute_fleet_data = ComputeFleetStatus._transform_compute_fleet_data( # pylint: disable=W0212
self._compute_fleet_data
)
_write_json_to_file(self._config.computefleet_status_path, compute_fleet_data)
except Exception as e:
log.error("Update event handler failed during fleet status translation: %s", e)
raise ClusterStatusManager.ClusterStatusUpdateEventError(e)
cinc_log_file = "/var/log/chef-client.log"
log.info("Calling update event handler, log can be found at %s", cinc_log_file)
cmd = (
"sudo cinc-client "
"--local-mode "
"--config /etc/chef/client.rb "
"--log_level auto "
f"--logfile {cinc_log_file} "
"--force-formatter "
"--no-color "
"--chef-zero-port 8889 "
"--json-attributes /etc/chef/dna.json "
"--override-runlist aws-parallelcluster-entrypoints::update_computefleet_status"
)
try:
# The command being passed has been built from string literals and local variables and can be trusted.
            _run_command(cmd, timeout=_minutes(self._config.update_event_timeout_minutes))
except Exception as e:
log.error("Update event handler failed. Check log file %s", cinc_log_file)
raise ClusterStatusManager.ClusterStatusUpdateEventError(e)
def _update_status(self, request_status, in_progress_status, final_status):
if self._compute_fleet_status == request_status:
self._update_compute_fleet_status(in_progress_status)
self._call_update_event()
if self._compute_fleet_status == in_progress_status:
self._update_compute_fleet_status(final_status)
@log_exception(log, "handling compute fleet status transitions", catch_exception=Exception, raise_on_error=False)
def manage_cluster_status(self):
"""
Manage cluster status.
        When running the pcluster start/stop command, the fleet status is set to START_REQUESTED/STOP_REQUESTED.
The function fetches the current fleet status and performs the following transitions:
- START_REQUESTED -> STARTING -> RUNNING
- STOP_REQUESTED -> STOPPING -> STOPPED
STARTING/STOPPING states are only used to communicate that the request is being processed by clusterstatusmgtd.
        On status STARTING|STOPPING, the update event handler backed by the recipe
        aws-parallelcluster-entrypoints::update_computefleet_status is called.
"""
self._current_time = datetime.now(tz=timezone.utc)
self._compute_fleet_status = self._get_compute_fleet_status(fallback=self._compute_fleet_status)
log.info("Current compute fleet status: %s", self._compute_fleet_status)
try:
if ComputeFleetStatus.is_stop_in_progress(self._compute_fleet_status):
self._update_status(
ComputeFleetStatus.STOP_REQUESTED, ComputeFleetStatus.STOPPING, ComputeFleetStatus.STOPPED
)
elif ComputeFleetStatus.is_start_in_progress(self._compute_fleet_status):
self._update_status(
ComputeFleetStatus.START_REQUESTED, ComputeFleetStatus.STARTING, ComputeFleetStatus.RUNNING
)
except ComputeFleetStatusManager.ConditionalStatusUpdateFailedError:
log.warning(
"Cluster status was updated while handling a transition from %s. "
"Status transition will be retried at the next iteration",
self._compute_fleet_status,
)
def _run_clusterstatusmgtd(config_file):
config = ClusterstatusmgtdConfig(config_file)
cluster_status_manager = ClusterStatusManager(config=config)
while True:
# Get loop start time
start_time = datetime.now(tz=timezone.utc)
# Get program config
try:
config = ClusterstatusmgtdConfig(config_file)
cluster_status_manager.set_config(config)
except Exception as e:
log.warning("Unable to reload daemon config from %s, using previous one.\nException: %s", config_file, e)
# Configure root logger
try:
fileConfig(config.logging_config, disable_existing_loggers=False)
except Exception as e:
log.warning(
"Unable to configure logging from %s, using default logging settings.\nException: %s",
config.logging_config,
e,
)
# Manage cluster
cluster_status_manager.manage_cluster_status()
_sleep_remaining_loop_time(config.loop_time, start_time)
def retry(delay):
def decorator_retry(func):
@functools.wraps(func)
def wrapper_retry(*args, **kwargs):
while True:
try:
return func(*args, **kwargs)
except Exception:
time.sleep(delay)
return wrapper_retry
return decorator_retry
@retry(LOOP_TIME)
def main():
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - [%(module)s:%(funcName)s] - %(levelname)s - %(message)s"
)
log.info("Clusterstatusmgtd Startup")
try:
clusterstatusmgtd_config_file = os.environ.get(
"CONFIG_FILE", os.path.join(CONFIG_FILE_DIR, "clusterstatusmgtd.conf")
)
_run_clusterstatusmgtd(clusterstatusmgtd_config_file)
except Exception as e:
log.exception("An unexpected error occurred: %s.\nRestarting in %s seconds...", e, LOOP_TIME)
raise
if __name__ == "__main__":
main()