src/slurm_plugin/instance_manager.py
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the inputs passed to the subprocess module are trusted.
import subprocess # nosec B404
from collections import defaultdict
from typing import Dict, Iterable, List
import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from common.ec2_utils import get_private_ip_address_and_dns_name
from common.schedulers.slurm_commands import get_nodes_info, update_nodes
from common.utils import grouper, setup_logging_filter
from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory
from slurm_plugin.slurm_resources import (
EC2_HEALTH_STATUS_UNHEALTHY_STATES,
EC2_INSTANCE_ALIVE_STATES,
EC2_SCHEDULED_EVENT_CODES,
EC2InstanceHealthState,
InvalidNodenameError,
SlurmNode,
SlurmResumeData,
SlurmResumeJob,
parse_nodename,
)
logger = logging.getLogger(__name__)
# PageSize parameter used for Boto3 paginated calls
# Corresponds to MaxResults in describe_instances and describe_instance_status API
BOTO3_PAGINATION_PAGE_SIZE = 1000
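# Batch size used when grouping node names for DynamoDB batch_get_item lookups of instance IDs
# (kept well below the API's 100-item-per-request limit)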
BOTO3_MAX_BATCH_SIZE = 50
class HostnameTableStoreError(Exception):
"""Raised when error occurs while writing into hostname table."""
class HostnameDnsStoreError(Exception):
"""Raised when error occurs while writing into hostname DNS."""
class InstanceToNodeAssignmentError(Exception):
"""Raised when error occurs while assigning EC2 instance to Slurm node."""
class NodeAddrUpdateError(Exception):
"""Raised when error occurs while updating NodeAddrs in Slurm node."""
class InstanceManager:
"""
InstanceManager class.
Class implementing instance management actions.
Used when launching instance, terminating instance, and retrieving instance info for slurm integration.
"""
def __init__(
self,
region: str,
cluster_name: str,
boto3_config: Config,
table_name: str = None,
hosted_zone: str = None,
dns_domain: str = None,
use_private_hostname: bool = False,
head_node_private_ip: str = None,
head_node_hostname: str = None,
fleet_config: Dict[str, any] = None,
run_instances_overrides: dict = None,
create_fleet_overrides: dict = None,
job_level_scaling: bool = False,
):
"""Initialize InstanceLauncher with required attributes."""
self._region = region
self._cluster_name = cluster_name
self._boto3_config = boto3_config
self.failed_nodes = {}
self._ddb_resource = boto3.resource("dynamodb", region_name=region, config=boto3_config)
self._table = self._ddb_resource.Table(table_name) if table_name else None
self._hosted_zone = hosted_zone
self._dns_domain = dns_domain
self._use_private_hostname = use_private_hostname
self._head_node_private_ip = head_node_private_ip
self._head_node_hostname = head_node_hostname
self._fleet_config = fleet_config
self._run_instances_overrides = run_instances_overrides or {}
self._create_fleet_overrides = create_fleet_overrides or {}
self._boto3_resource_factory = lambda resource_name: boto3.session.Session().resource(
resource_name, region_name=region, config=boto3_config
)
self.nodes_assigned_to_instances = {}
self.unused_launched_instances = {}
self.job_level_scaling = job_level_scaling
def _clear_failed_nodes(self):
"""Clear and reset failed nodes list."""
self.failed_nodes = {}
@log_exception(
logger, "saving assigned hostnames in DynamoDB", raise_on_error=True, exception_to_raise=HostnameTableStoreError
)
def _store_assigned_hostnames(self, nodes):
logger.info("Saving assigned hostnames in DynamoDB")
if not self._table:
raise HostnameTableStoreError("Empty table name configuration parameter.")
if nodes:
with self._table.batch_writer() as batch_writer:
for nodename, instance in nodes.items():
# Note: These items will never be removed, but the put_item method
# will replace old items if the hostnames are already associated with an old instance_id.
batch_writer.put_item(
Item={
"Id": nodename,
"InstanceId": instance.id,
"HeadNodePrivateIp": self._head_node_private_ip,
"HeadNodeHostname": self._head_node_hostname,
}
)
@log_exception(logger, "updating DNS records", raise_on_error=True, exception_to_raise=HostnameDnsStoreError)
def _update_dns_hostnames(self, nodes, update_dns_batch_size=500):
logger.info(
"Updating DNS records for %s - %s",
self._hosted_zone,
self._dns_domain,
)
if not self._hosted_zone or not self._dns_domain:
logger.info(
"Empty DNS domain name or hosted zone configuration parameter",
)
return
changes = []
for hostname, instance in nodes.items():
changes.append(
{
"Action": "UPSERT",
"ResourceRecordSet": {
"Name": f"{hostname}.{self._dns_domain}",
"ResourceRecords": [{"Value": instance.private_ip}],
"Type": "A",
"TTL": 120,
},
}
)
if changes:
# Submit calls to change_resource_record_sets in batches of max 500 elements each.
# The change_resource_record_sets API call has a limit of 1000 changes,
# but the UPSERT action counts as 2 changes.
# Also pick the number of retries to be the max between the globally configured one and 4.
# This is done to address Route53 API throttling without changing the configured retries for all API calls.
configured_retry = self._boto3_config.retries.get("max_attempts", 0) if self._boto3_config.retries else 0
boto3_config = self._boto3_config.merge(
Config(retries={"max_attempts": max([configured_retry, 4]), "mode": "standard"})
)
route53_client = boto3.client("route53", region_name=self._region, config=boto3_config)
changes_batch_size = min(update_dns_batch_size, 500)
for changes_batch in grouper(changes, changes_batch_size):
route53_client.change_resource_record_sets(
HostedZoneId=self._hosted_zone, ChangeBatch={"Changes": list(changes_batch)}
)
def _parse_nodes_resume_list(self, node_list: List[str]) -> defaultdict[str, defaultdict[str, List[str]]]:
"""
Parse out which launch configurations (queue/compute resource) are requested by slurm nodes from NodeName.
Valid NodeName format: {queue_name}-{st/dy}-{compute_resource_name}-{number}
Sample NodeName: queue1-st-computeres1-2
"""
nodes_to_launch = defaultdict(lambda: defaultdict(list))
logger.debug("Nodes already assigned to running instances: %s", self.nodes_assigned_to_instances)
for node in node_list:
try:
queue_name, node_type, compute_resource_name = parse_nodename(node)
if node in self.nodes_assigned_to_instances.get(queue_name, {}).get(compute_resource_name, []):
# skip node for which there is already an instance assigned (oversubscribe case)
logger.info("Discarding NodeName already assigned to running instance: %s", node)
else:
nodes_to_launch[queue_name][compute_resource_name].append(node)
except (InvalidNodenameError, KeyError):
logger.warning("Discarding NodeName with invalid format: %s", node)
self._update_failed_nodes({node}, "InvalidNodenameError")
logger.debug("Launch configuration requested by nodes = %s", nodes_to_launch)
return nodes_to_launch
def delete_instances(self, instance_ids_to_terminate, terminate_batch_size):
"""Terminate corresponding EC2 instances."""
ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
logger.info("Terminating instances %s", print_with_count(instance_ids_to_terminate))
for instances in grouper(instance_ids_to_terminate, terminate_batch_size):
try:
# Boto3 clients retry on connection errors only
ec2_client.terminate_instances(InstanceIds=list(instances))
except ClientError as e:
logger.error(
"Failed TerminateInstances request: %s", e.response.get("ResponseMetadata").get("RequestId")
)
logger.error("Failed when terminating instances %s with error %s", print_with_count(instances), e)
@log_exception(logger, "getting health status for unhealthy EC2 instances", raise_on_error=True)
def get_unhealthy_cluster_instance_status(self, cluster_instance_ids):
"""
Get health status for unhealthy EC2 instances.
Retrieve instance status with 3 separate paginated calls filtering on different health check attributes
Rather than doing call with instance ids
Reason being number of unhealthy instances is in general lower than number of instances in cluster
In addition, while specifying instance ids, the max result returned by 1 API call is 100
As opposed to 1000 when not specifying instance ids and using filters
"""
instance_health_states = {}
health_check_filters = {
"instance_status": {
"Filters": [{"Name": "instance-status.status", "Values": list(EC2_HEALTH_STATUS_UNHEALTHY_STATES)}]
},
"system_status": {
"Filters": [{"Name": "system-status.status", "Values": list(EC2_HEALTH_STATUS_UNHEALTHY_STATES)}]
},
"scheduled_events": {"Filters": [{"Name": "event.code", "Values": EC2_SCHEDULED_EVENT_CODES}]},
}
for health_check_type in health_check_filters:
ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
paginator = ec2_client.get_paginator("describe_instance_status")
response_iterator = paginator.paginate(
PaginationConfig={"PageSize": BOTO3_PAGINATION_PAGE_SIZE}, **health_check_filters[health_check_type]
)
filtered_iterator = response_iterator.search("InstanceStatuses")
for instance_status in filtered_iterator:
instance_id = instance_status.get("InstanceId")
if instance_id in cluster_instance_ids and instance_id not in instance_health_states:
instance_health_states[instance_id] = EC2InstanceHealthState(
instance_id,
instance_status.get("InstanceState").get("Name"),
instance_status.get("InstanceStatus"),
instance_status.get("SystemStatus"),
instance_status.get("Events"),
)
return list(instance_health_states.values())
@log_exception(logger, "getting cluster instances from EC2", raise_on_error=True)
def get_cluster_instances(self, include_head_node=False, alive_states_only=True):
"""
Get instances that are associated with the cluster.
Instances without all the info set are ignored and not returned
"""
ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
paginator = ec2_client.get_paginator("describe_instances")
args = {
"Filters": [{"Name": "tag:parallelcluster:cluster-name", "Values": [self._cluster_name]}],
}
if alive_states_only:
args["Filters"].append({"Name": "instance-state-name", "Values": list(EC2_INSTANCE_ALIVE_STATES)})
if not include_head_node:
args["Filters"].append({"Name": "tag:parallelcluster:node-type", "Values": ["Compute"]})
response_iterator = paginator.paginate(PaginationConfig={"PageSize": BOTO3_PAGINATION_PAGE_SIZE}, **args)
filtered_iterator = response_iterator.search("Reservations[].Instances[]")
instances = []
for instance_info in filtered_iterator:
try:
private_ip, private_dns_name, all_private_ips = get_private_ip_address_and_dns_name(instance_info)
instances.append(
EC2Instance(
instance_info["InstanceId"],
private_ip,
private_dns_name.split(".")[0],
all_private_ips,
instance_info["LaunchTime"],
)
)
except Exception as e:
logger.warning(
"Ignoring instance %s because not all EC2 info are available, exception: %s, message: %s",
instance_info["InstanceId"],
type(e).__name__,
e,
)
return instances
def terminate_all_compute_nodes(self, terminate_batch_size):
try:
compute_nodes = self.get_cluster_instances()
self.delete_instances(
instance_ids_to_terminate=[instance.id for instance in compute_nodes],
terminate_batch_size=terminate_batch_size,
)
return True
except Exception as e:
logger.error("Failed when terminating compute fleet with error %s", e)
return False
def _update_failed_nodes(self, nodeset, error_code="Exception", override=True):
"""Update failed nodes dict with error code as key and nodeset value."""
if not override:
# Remove nodes already present in any failed_nodes key so as not to override the error_code if already set
for nodes in self.failed_nodes.values():
if nodes:
nodeset = nodeset.difference(nodes)
if nodeset:
self.failed_nodes[error_code] = self.failed_nodes.get(error_code, set()).union(nodeset)
def get_compute_node_instances(
self, compute_nodes: Iterable[SlurmNode], max_retrieval_count: int
) -> Iterable[ComputeInstanceDescriptor]:
"""Return an iterable of dicts containing a node name and instance ID for each node in compute_nodes."""
return InstanceManager._get_instances_for_nodes(
compute_nodes=compute_nodes,
table_name=self._table.table_name,
resource_factory=self._boto3_resource_factory,
max_retrieval_count=max_retrieval_count if max_retrieval_count > 0 else None,
)
@staticmethod
def _get_instances_for_nodes(
compute_nodes, table_name, resource_factory, max_retrieval_count
) -> Iterable[ComputeInstanceDescriptor]:
# Partition compute_nodes into a list of nodes with an instance ID and a list of nodes without an instance ID.
nodes_with_instance_id = []
nodes_without_instance_id = []
for node in compute_nodes:
(nodes_with_instance_id if node.instance else nodes_without_instance_id).append(
{
"Name": node.name,
"InstanceId": node.instance.id if node.instance else None,
}
)
# Make sure that we don't return more than max_retrieval_count if set.
nodes_with_instance_id = (
nodes_with_instance_id[:max_retrieval_count] if max_retrieval_count else nodes_with_instance_id
)
# Determine the remaining number of nodes we will need to retrieve from DynamoDB.
remaining = (
max(0, max_retrieval_count - len(nodes_with_instance_id))
if max_retrieval_count
else len(nodes_without_instance_id)
)
# Return instance ids that don't require a DDB lookup first.
yield from nodes_with_instance_id
# Look up instance IDs in DynamoDB for nodes that we don't already have the instance ID for, but only
# if we haven't already returned max_retrieval_count instances.
if remaining:
yield from InstanceManager._retrieve_instance_ids_from_dynamo(
ddb_resource=resource_factory("dynamodb"),
table_name=table_name,
compute_nodes=nodes_without_instance_id,
max_retrieval_count=remaining,
)
@staticmethod
def _retrieve_instance_ids_from_dynamo(
ddb_resource, table_name, compute_nodes, max_retrieval_count
) -> Iterable[ComputeInstanceDescriptor]:
node_name_partitions = InstanceManager._partition_nodes(node.get("Name") for node in compute_nodes)
for node_name_partition in node_name_partitions:
node_name_partition = itertools.islice(node_name_partition, max_retrieval_count)
query = InstanceManager._create_request_for_nodes(table_name=table_name, node_names=node_name_partition)
response = ddb_resource.batch_get_item(RequestItems=query)
# Because we can't assume that len(partition) equals len(Responses.table_name), e.g. when a node name does
# not exist in the DynamoDB table, we only decrement the remaining number of nodes we need when we actually
# return an instance ID.
for item in response.get("Responses").get(table_name):
yield {
"Name": item.get("Id"),
"InstanceId": item.get("InstanceId"),
}
max_retrieval_count -= 1
if max_retrieval_count < 1:
break
@staticmethod
def _partition_nodes(node_names, size=BOTO3_MAX_BATCH_SIZE):
yield from grouper(node_names, size)
@staticmethod
def _create_request_for_nodes(table_name, node_names):
return {
str(table_name): {
"Keys": [
{
"Id": str(node_name),
}
for node_name in node_names
],
"ProjectionExpression": "Id, InstanceId",
}
}
def _clear_unused_launched_instances(self):
"""Clear and reset unused launched instances list."""
self.unused_launched_instances = {}
def _clear_nodes_assigned_to_instances(self):
"""Clear and reset nodes assigned to instances list."""
self.nodes_assigned_to_instances = {}
def _update_dict(self, target_dict: dict, update: dict) -> dict:
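# Recursively merge `update` into `target_dict`: nested dicts are merged, lists are concatenated,
# sets are unioned and any other value is overwritten.
# For example (hypothetical values):
#   _update_dict({"q1": {"cr1": [i1]}}, {"q1": {"cr1": [i2], "cr2": [i3]}})
#   returns {"q1": {"cr1": [i1, i2], "cr2": [i3]}}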
logger.debug("Updating target dict (%s) with update (%s)", target_dict, update)
for update_key, update_value in update.items():
if isinstance(update_value, dict):
target_dict[update_key] = self._update_dict(target_dict.get(update_key, {}), update_value)
elif isinstance(update_value, list):
target_dict[update_key] = target_dict.get(update_key, []) + update_value
elif isinstance(update_value, set):
target_dict[update_key] = target_dict.get(update_key, set()) | update_value
else:
target_dict[update_key] = update_value
logger.debug("Updated target dict is (%s)", target_dict)
return target_dict
def add_instances(
self,
node_list: List[str],
launch_batch_size: int,
update_node_address: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
slurm_resume: Dict[str, any] = None,
assign_node_batch_size: int = None,
terminate_batch_size: int = None,
):
"""Add EC2 instances to Slurm nodes."""
# Reset failed nodes pool
self._clear_failed_nodes()
# Reset unused instances pool
self._clear_unused_launched_instances()
# Reset nodes assigned to instances pool
self._clear_nodes_assigned_to_instances()
if self.job_level_scaling:
if slurm_resume:
logger.debug("Performing job level scaling using Slurm resume file")
self._add_instances_for_resume_file(
slurm_resume=slurm_resume,
node_list=node_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
"Not possible to perform job level scaling because Slurm resume file content is empty. "
"No scaling actions will be taken."
)
else:
if node_list:
logger.debug("Performing node list scaling using Slurm node resume list")
self._add_instances_for_nodes(
node_list=node_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
else:
logger.error(
"Not possible to perform node list scaling because Slurm node resume list is empty. "
"No scaling actions will be taken."
)
self._terminate_unassigned_launched_instances(terminate_batch_size)
self._clear_nodes_assigned_to_instances()
def _scaling_for_jobs(
self,
job_list: List[SlurmResumeJob],
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
scaling_strategy: ScalingStrategy,
skip_launch: bool = False,
) -> None:
"""Scaling for job list."""
# Setup custom logging filter
with setup_logging_filter(logger, "JobID") as job_id_logging_filter:
for job in job_list:
job_id_logging_filter.set_custom_value(job.job_id)
logger.debug(f"Job info: {job}")
logger.info("The nodes_resume list from Slurm Resume File is %s", print_with_count(job.nodes_resume))
self._add_instances_for_nodes(
job=job,
node_list=job.nodes_resume,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
skip_launch=skip_launch,
)
def _terminate_unassigned_launched_instances(self, terminate_batch_size):
# If there are remaining unassigned instances, terminate them
unassigned_launched_instances = [
instance
for compute_resources in self.unused_launched_instances.values()
for instance_list in compute_resources.values()
for instance in instance_list
]
if unassigned_launched_instances:
logger.info("Terminating unassigned launched instances: %s", self.unused_launched_instances)
self.delete_instances(
[instance.id for instance in unassigned_launched_instances],
terminate_batch_size,
)
self._clear_unused_launched_instances()
def _scaling_for_jobs_single_node(
self,
job_list: List[SlurmResumeJob],
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool,
scaling_strategy: ScalingStrategy,
) -> None:
"""Scaling for job single node list."""
if job_list:
if len(job_list) == 1:
# call _scaling_for_jobs so that JobID is logged
self._scaling_for_jobs(
job_list=job_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
else:
# Batch all single node jobs in a single best-effort EC2 launch request.
# This is done to reduce scaling time and save launch API calls.
# Remove duplicated node entries (possible in the oversubscribe case)
single_nodes = list(dict.fromkeys([job.nodes_resume[0] for job in job_list]))
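# e.g. (hypothetical names) jobs resuming ["q1-dy-cr1-1", "q1-dy-cr1-1", "q2-dy-cr2-3"]
# collapse to ["q1-dy-cr1-1", "q2-dy-cr2-3"], preserving order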
self._add_instances_for_nodes(
node_list=single_nodes,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=ScalingStrategy.BEST_EFFORT,
)
def _add_instances_for_resume_file(
self,
slurm_resume: Dict[str, any],
node_list: List[str],
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
):
"""Launch requested EC2 instances for resume file."""
slurm_resume_data = self._get_slurm_resume_data(slurm_resume=slurm_resume, node_list=node_list)
self._clear_unused_launched_instances()
self._scaling_for_jobs_single_node(
job_list=slurm_resume_data.jobs_single_node,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
self._scaling_for_jobs_multi_node(
job_list=slurm_resume_data.jobs_multi_node,
node_list=slurm_resume_data.multi_node,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
)
def _scaling_for_jobs_multi_node(
self,
job_list,
node_list,
launch_batch_size,
assign_node_batch_size,
update_node_address,
scaling_strategy: ScalingStrategy,
):
if not (scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING] and len(job_list) <= 1):
# Optimize job level scaling with a preliminary scale-all-nodes attempt,
# except for the all-or-nothing / single job case, to avoid scaling the same node list twice
self._update_dict(
self.unused_launched_instances,
self._launch_instances(
nodes_to_launch=self._parse_nodes_resume_list(node_list),
launch_batch_size=launch_batch_size,
scaling_strategy=scaling_strategy,
),
)
# Avoid a job level launch if the scaling strategy is BEST_EFFORT or GREEDY_ALL_OR_NOTHING.
# The scale-all launch has already been performed, hence from this point we want to skip the extra
# job level launch of instances for jobs that are unable to get the needed capacity from the all-in scaling
skip_launch = scaling_strategy in [ScalingStrategy.BEST_EFFORT, ScalingStrategy.GREEDY_ALL_OR_NOTHING]
self._scaling_for_jobs(
job_list=job_list,
launch_batch_size=launch_batch_size,
assign_node_batch_size=assign_node_batch_size,
update_node_address=update_node_address,
scaling_strategy=scaling_strategy,
skip_launch=skip_launch,
)
def _get_slurm_resume_data(self, slurm_resume: Dict[str, any], node_list: List[str]) -> SlurmResumeData:
"""
Get SlurmResumeData object.
SlurmResumeData object contains the following:
* the node list for jobs allocated to single node
* the node list for jobs allocated to multiple nodes
* the job list with single node allocation
* the job list with multi node allocation
Example of Slurm Resume File (ref. https://slurm.schedmd.com/elastic_computing.html):
{
"all_nodes_resume": "cloud[1-3,7-8]",
"jobs": [
{
"extra": "An arbitrary string from --extra",
"features": "c1,c2",
"job_id": 140814,
"nodes_alloc": "queue1-st-c5xlarge-[4-5]",
"nodes_resume": "queue1-st-c5xlarge-[1,3]",
"oversubscribe": "OK",
"partition": "cloud",
"reservation": "resv_1234",
},
{
"extra": None,
"features": "c1,c2",
"job_id": 140815,
"nodes_alloc": "queue2-st-c5xlarge-[1-2]",
"nodes_resume": "queue2-st-c5xlarge-[1-2]",
"oversubscribe": "OK",
"partition": "cloud",
"reservation": None,
},
{
"extra": None,
"features": None,
"job_id": 140816,
"nodes_alloc": "queue2-st-c5xlarge-[7,8]",
"nodes_resume": "queue2-st-c5xlarge-[7,8]",
"oversubscribe": "NO",
"partition": "cloud_exclusive",
"reservation": None,
},
],
}
"""
jobs_single_node = []
jobs_multi_node = []
single_node = []
multi_node = []
slurm_resume_jobs = self._parse_slurm_resume(slurm_resume)
for job in slurm_resume_jobs:
if len(job.nodes_resume) == 1:
jobs_single_node.append(job)
single_node.extend(job.nodes_resume)
else:
jobs_multi_node.append(job)
multi_node.extend(job.nodes_resume)
nodes_difference = list(set(node_list) - (set(single_node) | set(multi_node)))
if nodes_difference:
logger.warning(
"Discarding NodeNames because of mismatch in Slurm Resume File Vs Nodes passed to Resume Program: %s",
", ".join(nodes_difference),
)
self._update_failed_nodes(set(nodes_difference), "InvalidNodenameError")
return SlurmResumeData(
single_node=list(dict.fromkeys(single_node)),
multi_node=list(dict.fromkeys(multi_node)),
jobs_single_node=jobs_single_node,
jobs_multi_node=jobs_multi_node,
)
def _parse_slurm_resume(self, slurm_resume: Dict[str, any]) -> List[SlurmResumeJob]:
slurm_resume_jobs = []
for job in slurm_resume.get("jobs", {}):
try:
slurm_resume_jobs.append(SlurmResumeJob(**job))
except InvalidNodenameError:
nodes_resume = job.get("nodes_resume", "")
nodes_alloc = job.get("nodes_alloc", "")
job_id = job.get("job_id", "")
logger.warning(
"Discarding NodeNames with invalid format for Job Id (%s): nodes_alloc (%s), nodes_resume (%s)",
job_id,
nodes_alloc,
nodes_resume,
)
self._update_failed_nodes(
# if NodeNames in nodes_resume cannot be parsed, try to get info directly from Slurm
set([node.name for node in get_nodes_info(nodes_resume)]),
"InvalidNodenameError",
)
return slurm_resume_jobs
def _add_instances_for_nodes(
self,
launch_batch_size: int,
assign_node_batch_size: int,
update_node_address: bool = True,
scaling_strategy: ScalingStrategy = ScalingStrategy.ALL_OR_NOTHING,
node_list: List[str] = None,
job: SlurmResumeJob = None,
skip_launch: bool = False,
):
"""Launch requested EC2 instances for nodes."""
nodes_resume_mapping = self._parse_nodes_resume_list(node_list=node_list)
# nodes in the resume list, mapped for queues and compute resources, e.g.
# {
# queue_1: {cr_1: [nodes_1, nodes_2, nodes_3], cr_2: [nodes_4]},
# queue_2: {cr_3: [nodes_5]}
# }
nodes_resume_list = []
for compute_resources in nodes_resume_mapping.values():
for slurm_node_list in compute_resources.values():
nodes_resume_list.extend(slurm_node_list)
# nodes in the resume flattened list, e.g.
# [nodes_1, nodes_2, nodes_3, nodes_4, nodes_5]
instances_launched = self._launch_instances(
job=job if job else None,
nodes_to_launch=nodes_resume_mapping,
launch_batch_size=launch_batch_size,
scaling_strategy=scaling_strategy,
skip_launch=skip_launch,
)
# instances launched, e.g.
# {
# queue_1: {cr_1: list[EC2Instance], cr_2: list[EC2Instance]},
# queue_2: {cr_3: list[EC2Instance]}
# }
successful_launched_nodes = []
failed_launch_nodes = []
for queue, compute_resources in nodes_resume_mapping.items():
for compute_resource, slurm_node_list in compute_resources.items():
q_cr_instances_launched_length = len(instances_launched.get(queue, {}).get(compute_resource, []))
successful_launched_nodes += slurm_node_list[:q_cr_instances_launched_length]
failed_launch_nodes += slurm_node_list[q_cr_instances_launched_length:]
if scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING, ScalingStrategy.GREEDY_ALL_OR_NOTHING]:
logger.info("Assigning nodes with all-or-nothing strategy")
self._all_or_nothing_node_assignment(
assign_node_batch_size=assign_node_batch_size,
instances_launched=instances_launched,
nodes_resume_list=nodes_resume_list,
nodes_resume_mapping=nodes_resume_mapping,
successful_launched_nodes=successful_launched_nodes,
update_node_address=update_node_address,
)
else:
logger.info("Assigning nodes with best-effort strategy")
self._best_effort_node_assignment(
assign_node_batch_size=assign_node_batch_size,
failed_launch_nodes=failed_launch_nodes,
instances_launched=instances_launched,
nodes_resume_list=nodes_resume_list,
nodes_resume_mapping=nodes_resume_mapping,
successful_launched_nodes=successful_launched_nodes,
update_node_address=update_node_address,
)
def _reset_failed_nodes(self, nodeset):
"""Remove nodeset from failed nodes dict."""
if nodeset:
for error_code in self.failed_nodes:
self.failed_nodes[error_code] = self.failed_nodes.get(error_code, set()).difference(nodeset)
def _best_effort_node_assignment(
self,
assign_node_batch_size,
failed_launch_nodes,
instances_launched,
nodes_resume_list,
nodes_resume_mapping,
successful_launched_nodes,
update_node_address,
):
# best-effort job level scaling
if 0 < len(successful_launched_nodes) <= len(nodes_resume_list):
# All or partial requested EC2 capacity for the Job has been launched
# Assign launched EC2 instances to the requested Slurm nodes
self._assign_instances_to_nodes(
update_node_address=update_node_address,
nodes_to_launch=nodes_resume_mapping,
instances_launched=instances_launched,
assign_node_batch_size=assign_node_batch_size,
raise_on_error=False,
)
logger.info(
"Successful launched and assigned %s instances for nodes %s",
"all" if len(successful_launched_nodes) == len(nodes_resume_list) else "partial",
print_with_count(successful_launched_nodes),
)
nodes_assigned_mapping = defaultdict(lambda: defaultdict(list))
for queue, compute_resources in nodes_resume_mapping.items():
for compute_resource, slurm_node_list in compute_resources.items():
launched_ec2_instances = instances_launched.get(queue, {}).get(compute_resource, [])
# fmt: off
nodes_assigned_mapping[queue][compute_resource] = slurm_node_list[:len(launched_ec2_instances)]
# fmt: on
self._update_dict(self.nodes_assigned_to_instances, nodes_assigned_mapping)
self._reset_failed_nodes(set(successful_launched_nodes))
if len(successful_launched_nodes) < len(nodes_resume_list):
# mark the nodes that failed to launch with limited capacity
self._update_failed_nodes(set(failed_launch_nodes), "LimitedInstanceCapacity", override=False)
else:
# No instances launched at all, e.g. the CreateFleet API returned no EC2 instances,
# or no instances were left available from a best-effort EC2 launch
logger.info("No launched instances found for nodes %s", print_with_count(nodes_resume_list))
self._update_failed_nodes(set(nodes_resume_list), "InsufficientInstanceCapacity", override=False)
def _all_or_nothing_node_assignment(
self,
assign_node_batch_size,
instances_launched,
nodes_resume_list,
nodes_resume_mapping,
successful_launched_nodes,
update_node_address,
):
# all-or-nothing job level scaling
if len(successful_launched_nodes) == len(nodes_resume_list):
# All requested EC2 capacity for the Job has been launched
# Assign launched EC2 instances to the requested Slurm nodes
try:
self._assign_instances_to_nodes(
update_node_address=update_node_address,
nodes_to_launch=nodes_resume_mapping,
instances_launched=instances_launched,
assign_node_batch_size=assign_node_batch_size,
raise_on_error=True,
)
logger.info(
"Successful launched and assigned all instances for nodes %s",
print_with_count(nodes_resume_list),
)
self._update_dict(self.nodes_assigned_to_instances, nodes_resume_mapping)
self._reset_failed_nodes(set(nodes_resume_list))
except InstanceToNodeAssignmentError:
# Failed to assign EC2 instances to nodes.
# EC2 instances already assigned are going to be terminated by
# setting the nodes DOWN.
# EC2 instances not yet assigned are going to fail during bootstrap,
# because no entry would be found in the DynamoDB table
self._update_failed_nodes(set(nodes_resume_list))
elif 0 < len(successful_launched_nodes) < len(nodes_resume_list):
# Try to reuse partial capacity of already launched EC2 instances
logger.info(
"Releasing launched and booked instances %s",
print_with_count(
[
(queue, compute_resource, instance)
for queue, compute_resources in instances_launched.items()
for compute_resource, instances in compute_resources.items()
for instance in instances
]
),
)
self._update_dict(self.unused_launched_instances, instances_launched)
self._update_failed_nodes(set(nodes_resume_list), "LimitedInstanceCapacity", override=False)
else:
# No instances launched at all, e.g. the CreateFleet API returned no EC2 instances,
# or no instances were left available from a best-effort EC2 launch
logger.info("No launched instances found for nodes %s", print_with_count(nodes_resume_list))
self._update_failed_nodes(set(nodes_resume_list), "InsufficientInstanceCapacity", override=False)
def _launch_instances( # noqa: C901
self,
nodes_to_launch: Dict[str, any],
launch_batch_size: int,
scaling_strategy: ScalingStrategy,
job: SlurmResumeJob = None,
skip_launch: bool = False,
):
instances_launched = defaultdict(lambda: defaultdict(list))
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
slurm_node_list = self._resize_slurm_node_list(
queue=queue,
compute_resource=compute_resource,
instances_launched=instances_launched,
slurm_node_list=slurm_node_list,
)
if slurm_node_list and not skip_launch:
logger.info(
"Launching %s instances for nodes %s",
"all-or-nothing" if scaling_strategy == ScalingStrategy.ALL_OR_NOTHING else "best-effort",
print_with_count(slurm_node_list),
)
# At the instance launch level, the various scaling strategies can be grouped based on the actual
# launch behaviour, i.e. all-or-nothing or best-effort
all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]
fleet_manager = self._get_fleet_manager(all_or_nothing_batch, compute_resource, queue)
for batch_nodes in grouper(slurm_node_list, launch_batch_size):
try:
launched_ec2_instances = self._launch_ec2_instances(
batch_nodes, compute_resource, fleet_manager, instances_launched, job, queue
)
if job and all_or_nothing_batch and len(launched_ec2_instances) < len(batch_nodes):
# When launching instances for a specific Job,
# exit fast if not all the requested capacity can be launched,
# returning the EC2 instances launched so far,
# so that they can eventually be allocated to other Slurm nodes.
# This path handles the CreateFleet case, which doesn't fail when no capacity is returned
return instances_launched
except (ClientError, Exception) as e:
logger.error(
"Encountered exception when launching instances for nodes %s: %s",
print_with_count(batch_nodes),
e,
)
update_failed_nodes_parameters = {"nodeset": set(batch_nodes)}
if isinstance(e, ClientError):
update_failed_nodes_parameters["error_code"] = e.response.get("Error", {}).get("Code")
elif isinstance(e, Exception) and hasattr(e, "code"):
update_failed_nodes_parameters["error_code"] = e.code
self._update_failed_nodes(**update_failed_nodes_parameters)
if job and all_or_nothing_batch:
# When launching instances for a specific Job,
# exit fast if not all the requested capacity can be launched,
# returning the EC2 instances launched so far,
# so that they can eventually be allocated to other Slurm nodes.
# This path handles the RunInstances case, which throws an exception when
# no capacity is returned, and the CreateFleet case when an exception is thrown
return instances_launched
return instances_launched
def _launch_ec2_instances(self, batch_nodes, compute_resource, fleet_manager, instances_launched, job, queue):
launched_ec2_instances = fleet_manager.launch_ec2_instances(
len(batch_nodes), job_id=job.job_id if job else None
)
# launched_ec2_instances e.g. list[EC2Instance]
if len(launched_ec2_instances) > 0:
instances_launched[queue][compute_resource].extend(launched_ec2_instances)
# instances_launched e.g.
# {
# queue_1: {cr_1: list[EC2Instance], cr_2: list[EC2Instance]},
# queue_2: {cr_3: list[EC2Instance]}
# }
else:
self._update_failed_nodes(set(batch_nodes), "InsufficientInstanceCapacity")
return launched_ec2_instances
def _get_fleet_manager(self, all_or_nothing_batch, compute_resource, queue):
# Set the number of retries to be the max between the globally configured one and 3.
# This is done to try to avoid launch instances API throttling
# without changing the configured retries for all API calls.
configured_retry = self._boto3_config.retries.get("max_attempts", 0) if self._boto3_config.retries else 0
boto3_config = self._boto3_config.merge(
Config(retries={"max_attempts": max([configured_retry, 3]), "mode": "standard"})
)
# Each compute resource can be configured to use create_fleet or run_instances
fleet_manager = FleetManagerFactory.get_manager(
cluster_name=self._cluster_name,
region=self._region,
boto3_config=boto3_config,
fleet_config=self._fleet_config,
queue=queue,
compute_resource=compute_resource,
all_or_nothing=all_or_nothing_batch,
run_instances_overrides=self._run_instances_overrides,
create_fleet_overrides=self._create_fleet_overrides,
)
return fleet_manager
def _resize_slurm_node_list(
self, queue: str, compute_resource: str, slurm_node_list: List[str], instances_launched: Dict[str, any]
):
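# Book instances left over from a previous launch for this (queue, compute resource) and return only the
# node names that still need new capacity. For example (hypothetical counts), with 3 reusable instances
# and 5 requested nodes, the first 3 nodes are booked and the remaining 2 node names are returned.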
reusable_instances = self.unused_launched_instances.get(queue, {}).get(compute_resource, [])
if len(reusable_instances) > 0:
# Reuse already launched capacity
# fmt: off
logger.info(
"Booking already launched instances for nodes %s",
print_with_count(slurm_node_list[:len(reusable_instances)]),
)
instances_launched[queue][compute_resource].extend(reusable_instances[:len(slurm_node_list)])
# Remove reusable instances from unused instances
self.unused_launched_instances[queue][compute_resource] = reusable_instances[len(slurm_node_list):]
# Reduce slurm_node_list
slurm_node_list = slurm_node_list[len(reusable_instances):]
# fmt: on
return slurm_node_list
def _assign_instances_to_nodes(
self,
update_node_address: bool,
nodes_to_launch: Dict[str, any],
instances_launched: Dict[str, any],
assign_node_batch_size: int,
raise_on_error: bool,
):
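# For each (queue, compute resource) pair Slurm node names with the launched EC2 instances in order,
# then for each batch: store the mapping in DynamoDB, upsert the DNS records and update the Slurm
# nodeaddrs; on failure either raise InstanceToNodeAssignmentError (raise_on_error=True) or mark the
# batch of nodes as failed and continue.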
if update_node_address:
for queue, compute_resources in nodes_to_launch.items():
for compute_resource, slurm_node_list in compute_resources.items():
launched_ec2_instances = instances_launched.get(queue, {}).get(compute_resource, [])
for batch in grouper(list(zip(slurm_node_list, launched_ec2_instances)), assign_node_batch_size):
batch_nodes = []
try:
batch_nodes, batch_launched_ec2_instances = zip(*batch)
assigned_nodes = dict(batch)
self._store_assigned_hostnames(nodes=assigned_nodes)
self._update_dns_hostnames(
nodes=assigned_nodes, update_dns_batch_size=assign_node_batch_size
)
self._update_slurm_node_addrs(
slurm_nodes=list(batch_nodes), launched_instances=batch_launched_ec2_instances
)
except (NodeAddrUpdateError, HostnameTableStoreError, HostnameDnsStoreError):
if raise_on_error:
raise InstanceToNodeAssignmentError
# Update the batch of failed node and continue
self._update_failed_nodes(set(batch_nodes))
def _update_slurm_node_addrs(self, slurm_nodes: List[str], launched_instances: List[EC2Instance]):
"""Update node information in slurm with info from launched EC2 instance."""
try:
# When using a cluster DNS domain we don't need to pass nodehostnames
# because they are equal to node names.
# It is possible to force the use of private hostnames by setting
# use_private_hostname = "true" as extra json parameter
node_hostnames = (
None if not self._use_private_hostname else [instance.hostname for instance in launched_instances]
)
update_nodes(
slurm_nodes,
nodeaddrs=[instance.private_ip for instance in launched_instances],
nodehostnames=node_hostnames,
)
logger.info(
"Nodes are now configured with instances %s",
print_with_count(zip(slurm_nodes, launched_instances)),
)
except subprocess.CalledProcessError:
logger.error(
"Encountered error when updating nodes %s with instances %s",
print_with_count(slurm_nodes),
print_with_count(launched_instances),
)
raise NodeAddrUpdateError