src/slurm_plugin/instance_manager.py (795 lines of code) (raw):

# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
# with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging

# A nosec comment is appended to the following line in order to disable the B404 check.
# In this file the input of the module subprocess is trusted.
import subprocess  # nosec B404
from collections import defaultdict
from typing import Dict, Iterable, List

import boto3
from botocore.config import Config
from botocore.exceptions import ClientError
from common.ec2_utils import get_private_ip_address_and_dns_name
from common.schedulers.slurm_commands import get_nodes_info, update_nodes
from common.utils import grouper, setup_logging_filter
from slurm_plugin.common import ComputeInstanceDescriptor, ScalingStrategy, log_exception, print_with_count
from slurm_plugin.fleet_manager import EC2Instance, FleetManagerFactory
from slurm_plugin.slurm_resources import (
    EC2_HEALTH_STATUS_UNHEALTHY_STATES,
    EC2_INSTANCE_ALIVE_STATES,
    EC2_SCHEDULED_EVENT_CODES,
    EC2InstanceHealthState,
    InvalidNodenameError,
    SlurmNode,
    SlurmResumeData,
    SlurmResumeJob,
    parse_nodename,
)

logger = logging.getLogger(__name__)

# PageSize parameter used for Boto3 paginated calls
# Corresponds to MaxResults in describe_instances and describe_instance_status API
BOTO3_PAGINATION_PAGE_SIZE = 1000
BOTO3_MAX_BATCH_SIZE = 50


class HostnameTableStoreError(Exception):
    """Raised when error occurs while writing into hostname table."""


class HostnameDnsStoreError(Exception):
    """Raised when error occurs while writing into hostname DNS."""


class InstanceToNodeAssignmentError(Exception):
    """Raised when error occurs while assigning EC2 instance to Slurm node."""


class NodeAddrUpdateError(Exception):
    """Raised when error occurs while updating NodeAddrs in Slurm node."""


class InstanceManager:
    """
    InstanceManager class.

    Class implementing instance management actions. Used when launching instance, terminating instance,
    and retrieving instance info for slurm integration.
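
    A minimal usage sketch (all argument values below are illustrative, not defaults):

        instance_manager = InstanceManager(
            region="us-east-1",
            cluster_name="mycluster",
            boto3_config=Config(retries={"max_attempts": 1}),
            fleet_config={},  # queue/compute-resource launch configuration
        )
        instance_manager.add_instances(
            node_list=["queue1-st-computeres1-1"],
            launch_batch_size=500,
            assign_node_batch_size=500,
            terminate_batch_size=1000,
        )
        # failed_nodes maps an error code to the set of node names that could not be scaled up
        failed = instance_manager.failed_nodes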
""" def __init__( self, region: str, cluster_name: str, boto3_config: Config, table_name: str = None, hosted_zone: str = None, dns_domain: str = None, use_private_hostname: bool = False, head_node_private_ip: str = None, head_node_hostname: str = None, fleet_config: Dict[str, any] = None, run_instances_overrides: dict = None, create_fleet_overrides: dict = None, job_level_scaling: bool = False, ): """Initialize InstanceLauncher with required attributes.""" self._region = region self._cluster_name = cluster_name self._boto3_config = boto3_config self.failed_nodes = {} self._ddb_resource = boto3.resource("dynamodb", region_name=region, config=boto3_config) self._table = self._ddb_resource.Table(table_name) if table_name else None self._hosted_zone = hosted_zone self._dns_domain = dns_domain self._use_private_hostname = use_private_hostname self._head_node_private_ip = head_node_private_ip self._head_node_hostname = head_node_hostname self._fleet_config = fleet_config self._run_instances_overrides = run_instances_overrides or {} self._create_fleet_overrides = create_fleet_overrides or {} self._boto3_resource_factory = lambda resource_name: boto3.session.Session().resource( resource_name, region_name=region, config=boto3_config ) self.nodes_assigned_to_instances = {} self.unused_launched_instances = {} self.job_level_scaling = job_level_scaling def _clear_failed_nodes(self): """Clear and reset failed nodes list.""" self.failed_nodes = {} @log_exception( logger, "saving assigned hostnames in DynamoDB", raise_on_error=True, exception_to_raise=HostnameTableStoreError ) def _store_assigned_hostnames(self, nodes): logger.info("Saving assigned hostnames in DynamoDB") if not self._table: raise HostnameTableStoreError("Empty table name configuration parameter.") if nodes: with self._table.batch_writer() as batch_writer: for nodename, instance in nodes.items(): # Note: These items will never be removed, but the put_item method # will replace old items if the hostnames are already associated with an old instance_id. batch_writer.put_item( Item={ "Id": nodename, "InstanceId": instance.id, "HeadNodePrivateIp": self._head_node_private_ip, "HeadNodeHostname": self._head_node_hostname, } ) @log_exception(logger, "updating DNS records", raise_on_error=True, exception_to_raise=HostnameDnsStoreError) def _update_dns_hostnames(self, nodes, update_dns_batch_size=500): logger.info( "Updating DNS records for %s - %s", self._hosted_zone, self._dns_domain, ) if not self._hosted_zone or not self._dns_domain: logger.info( "Empty DNS domain name or hosted zone configuration parameter", ) return changes = [] for hostname, instance in nodes.items(): changes.append( { "Action": "UPSERT", "ResourceRecordSet": { "Name": f"{hostname}.{self._dns_domain}", "ResourceRecords": [{"Value": instance.private_ip}], "Type": "A", "TTL": 120, }, } ) if changes: # Submit calls to change_resource_record_sets in batches of max 500 elements each. # change_resource_record_sets API call has limit of 1000 changes, # but the UPSERT action counts for 2 calls # Also pick the number of retries to be the max between the globally configured one and 4. # This is done to address Route53 API throttling without changing the configured retries for all API calls. 
            configured_retry = self._boto3_config.retries.get("max_attempts", 0) if self._boto3_config.retries else 0
            boto3_config = self._boto3_config.merge(
                Config(retries={"max_attempts": max([configured_retry, 4]), "mode": "standard"})
            )
            route53_client = boto3.client("route53", region_name=self._region, config=boto3_config)
            changes_batch_size = min(update_dns_batch_size, 500)
            for changes_batch in grouper(changes, changes_batch_size):
                route53_client.change_resource_record_sets(
                    HostedZoneId=self._hosted_zone, ChangeBatch={"Changes": list(changes_batch)}
                )

    def _parse_nodes_resume_list(self, node_list: List[str]) -> defaultdict[str, defaultdict[str, List[str]]]:
        """
        Parse out which launch configurations (queue/compute resource) are requested by slurm nodes from NodeName.

        Valid NodeName format: {queue_name}-{st/dy}-{compute_resource_name}-{number}
        Sample NodeName: queue1-st-computeres1-2
        """
        nodes_to_launch = defaultdict(lambda: defaultdict(list))
        logger.debug("Nodes already assigned to running instances: %s", self.nodes_assigned_to_instances)
        for node in node_list:
            try:
                queue_name, node_type, compute_resource_name = parse_nodename(node)
                if node in self.nodes_assigned_to_instances.get(queue_name, {}).get(compute_resource_name, []):
                    # skip node for which there is already an instance assigned (oversubscribe case)
                    logger.info("Discarding NodeName already assigned to running instance: %s", node)
                else:
                    nodes_to_launch[queue_name][compute_resource_name].append(node)
            except (InvalidNodenameError, KeyError):
                logger.warning("Discarding NodeName with invalid format: %s", node)
                self._update_failed_nodes({node}, "InvalidNodenameError")
        logger.debug("Launch configuration requested by nodes = %s", nodes_to_launch)

        return nodes_to_launch

    def delete_instances(self, instance_ids_to_terminate, terminate_batch_size):
        """Terminate corresponding EC2 instances."""
        ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
        logger.info("Terminating instances %s", print_with_count(instance_ids_to_terminate))
        for instances in grouper(instance_ids_to_terminate, terminate_batch_size):
            try:
                # Boto3 clients retries on connection errors only
                ec2_client.terminate_instances(InstanceIds=list(instances))
            except ClientError as e:
                logger.error(
                    "Failed TerminateInstances request: %s", e.response.get("ResponseMetadata").get("RequestId")
                )
                logger.error("Failed when terminating instances %s with error %s", print_with_count(instances), e)

    @log_exception(logger, "getting health status for unhealthy EC2 instances", raise_on_error=True)
    def get_unhealthy_cluster_instance_status(self, cluster_instance_ids):
        """
        Get health status for unhealthy EC2 instances.

        Retrieve instance status with 3 separate paginated calls filtering on different health check attributes,
        rather than doing a single call with instance ids.
        The reason is that the number of unhealthy instances is in general lower than the number of instances
        in the cluster. In addition, when specifying instance ids, the max number of results returned by
        1 API call is 100, as opposed to 1000 when not specifying instance ids and using filters.
        """
        instance_health_states = {}
        health_check_filters = {
            "instance_status": {
                "Filters": [{"Name": "instance-status.status", "Values": list(EC2_HEALTH_STATUS_UNHEALTHY_STATES)}]
            },
            "system_status": {
                "Filters": [{"Name": "system-status.status", "Values": list(EC2_HEALTH_STATUS_UNHEALTHY_STATES)}]
            },
            "scheduled_events": {"Filters": [{"Name": "event.code", "Values": EC2_SCHEDULED_EVENT_CODES}]},
        }

        for health_check_type in health_check_filters:
            ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
            paginator = ec2_client.get_paginator("describe_instance_status")
            response_iterator = paginator.paginate(
                PaginationConfig={"PageSize": BOTO3_PAGINATION_PAGE_SIZE}, **health_check_filters[health_check_type]
            )
            filtered_iterator = response_iterator.search("InstanceStatuses")
            for instance_status in filtered_iterator:
                instance_id = instance_status.get("InstanceId")
                if instance_id in cluster_instance_ids and instance_id not in instance_health_states:
                    instance_health_states[instance_id] = EC2InstanceHealthState(
                        instance_id,
                        instance_status.get("InstanceState").get("Name"),
                        instance_status.get("InstanceStatus"),
                        instance_status.get("SystemStatus"),
                        instance_status.get("Events"),
                    )

        return list(instance_health_states.values())

    @log_exception(logger, "getting cluster instances from EC2", raise_on_error=True)
    def get_cluster_instances(self, include_head_node=False, alive_states_only=True):
        """
        Get instances that are associated with the cluster.

        Instances without all the info set are ignored and not returned.
        """
        ec2_client = boto3.client("ec2", region_name=self._region, config=self._boto3_config)
        paginator = ec2_client.get_paginator("describe_instances")
        args = {
            "Filters": [{"Name": "tag:parallelcluster:cluster-name", "Values": [self._cluster_name]}],
        }
        if alive_states_only:
            args["Filters"].append({"Name": "instance-state-name", "Values": list(EC2_INSTANCE_ALIVE_STATES)})
        if not include_head_node:
            args["Filters"].append({"Name": "tag:parallelcluster:node-type", "Values": ["Compute"]})
        response_iterator = paginator.paginate(PaginationConfig={"PageSize": BOTO3_PAGINATION_PAGE_SIZE}, **args)
        filtered_iterator = response_iterator.search("Reservations[].Instances[]")

        instances = []
        for instance_info in filtered_iterator:
            try:
                private_ip, private_dns_name, all_private_ips = get_private_ip_address_and_dns_name(instance_info)
                instances.append(
                    EC2Instance(
                        instance_info["InstanceId"],
                        private_ip,
                        private_dns_name.split(".")[0],
                        all_private_ips,
                        instance_info["LaunchTime"],
                    )
                )
            except Exception as e:
                logger.warning(
                    "Ignoring instance %s because not all EC2 info are available, exception: %s, message: %s",
                    instance_info["InstanceId"],
                    type(e).__name__,
                    e,
                )

        return instances

    def terminate_all_compute_nodes(self, terminate_batch_size):
        try:
            compute_nodes = self.get_cluster_instances()
            self.delete_instances(
                instance_ids_to_terminate=[instance.id for instance in compute_nodes],
                terminate_batch_size=terminate_batch_size,
            )
            return True
        except Exception as e:
            logger.error("Failed when terminating compute fleet with error %s", e)
            return False

    def _update_failed_nodes(self, nodeset, error_code="Exception", override=True):
        """Update failed nodes dict with error code as key and nodeset value."""
        if not override:
            # Remove nodes already present in any failed_nodes key so as to not override the error_code if already set
            for nodes in self.failed_nodes.values():
                if nodes:
                    nodeset = nodeset.difference(nodes)
        if nodeset:
            self.failed_nodes[error_code] = self.failed_nodes.get(error_code, set()).union(nodeset)

    def get_compute_node_instances(
        self, compute_nodes: Iterable[SlurmNode], max_retrieval_count: int
    ) -> Iterable[ComputeInstanceDescriptor]:
        """Return an iterable of dicts containing a node name and instance ID for each node in compute_nodes."""
        return InstanceManager._get_instances_for_nodes(
            compute_nodes=compute_nodes,
            table_name=self._table.table_name,
            resource_factory=self._boto3_resource_factory,
            max_retrieval_count=max_retrieval_count if max_retrieval_count > 0 else None,
        )

    @staticmethod
    def _get_instances_for_nodes(
        compute_nodes, table_name, resource_factory, max_retrieval_count
    ) -> Iterable[ComputeInstanceDescriptor]:
        # Partition compute_nodes into a list of nodes with an instance ID and a list of nodes without an instance ID.
        nodes_with_instance_id = []
        nodes_without_instance_id = []
        for node in compute_nodes:
            (nodes_with_instance_id if node.instance else nodes_without_instance_id).append(
                {
                    "Name": node.name,
                    "InstanceId": node.instance.id if node.instance else None,
                }
            )

        # Make sure that we don't return more than max_retrieval_count if set.
        nodes_with_instance_id = (
            nodes_with_instance_id[:max_retrieval_count] if max_retrieval_count else nodes_with_instance_id
        )

        # Determine the remaining number of nodes we will need to retrieve from DynamoDB.
        remaining = (
            max(0, max_retrieval_count - len(nodes_with_instance_id))
            if max_retrieval_count
            else len(nodes_without_instance_id)
        )

        # Return instance ids that don't require a DDB lookup first.
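        # Each yielded item is a ComputeInstanceDescriptor-shaped dict, e.g. (values are illustrative):
        # {"Name": "queue1-st-computeres1-1", "InstanceId": "i-0123456789abcdef0"}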
        yield from nodes_with_instance_id

        # Lookup instance IDs in DynamoDB for nodes that we don't already have the instance ID for; but only
        # if we haven't already returned max_retrieval_count instances.
        if remaining:
            yield from InstanceManager._retrieve_instance_ids_from_dynamo(
                ddb_resource=resource_factory("dynamodb"),
                table_name=table_name,
                compute_nodes=nodes_without_instance_id,
                max_retrieval_count=remaining,
            )

    @staticmethod
    def _retrieve_instance_ids_from_dynamo(
        ddb_resource, table_name, compute_nodes, max_retrieval_count
    ) -> Iterable[ComputeInstanceDescriptor]:
        node_name_partitions = InstanceManager._partition_nodes(node.get("Name") for node in compute_nodes)

        for node_name_partition in node_name_partitions:
            node_name_partition = itertools.islice(node_name_partition, max_retrieval_count)
            query = InstanceManager._create_request_for_nodes(table_name=table_name, node_names=node_name_partition)
            response = ddb_resource.batch_get_item(RequestItems=query)

            # Because we can't assume that len(partition) equals len(Responses.table_name), e.g. when a node name does
            # not exist in the DynamoDB table, we only decrement the remaining number of nodes we need when we actually
            # return an instance ID.
            for item in response.get("Responses").get(table_name):
                yield {
                    "Name": item.get("Id"),
                    "InstanceId": item.get("InstanceId"),
                }
                max_retrieval_count -= 1
                if max_retrieval_count < 1:
                    break

    @staticmethod
    def _partition_nodes(node_names, size=BOTO3_MAX_BATCH_SIZE):
        yield from grouper(node_names, size)

    @staticmethod
    def _create_request_for_nodes(table_name, node_names):
        return {
            str(table_name): {
                "Keys": [
                    {
                        "Id": str(node_name),
                    }
                    for node_name in node_names
                ],
                "ProjectionExpression": "Id, InstanceId",
            }
        }

    def _clear_unused_launched_instances(self):
        """Clear and reset unused launched instances list."""
        self.unused_launched_instances = {}

    def _clear_nodes_assigned_to_instances(self):
        """Clear and reset nodes assigned to instances list."""
        self.nodes_assigned_to_instances = {}

    def _update_dict(self, target_dict: dict, update: dict) -> dict:
        logger.debug("Updating target dict (%s) with update (%s)", target_dict, update)
        for update_key, update_value in update.items():
            if isinstance(update_value, dict):
                target_dict[update_key] = self._update_dict(target_dict.get(update_key, {}), update_value)
            elif isinstance(update_value, list):
                target_dict[update_key] = target_dict.get(update_key, []) + update_value
            elif isinstance(update_value, set):
                target_dict[update_key] = target_dict.get(update_key, set()) | update_value
            else:
                target_dict[update_key] = update_value
        logger.debug("Updated target dict is (%s)", target_dict)
        return target_dict

    def add_instances(
        self,
        node_list: List[str],
        launch_batch_size: int,
        update_node_address: bool = True,
        scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
        slurm_resume: Dict[str, any] = None,
        assign_node_batch_size: int = None,
        terminate_batch_size: int = None,
    ):
        """Add EC2 instances to Slurm nodes."""
        # Reset failed nodes pool
        self._clear_failed_nodes()
        # Reset unused instances pool
        self._clear_unused_launched_instances()
        # Reset nodes assigned to instances pool
        self._clear_nodes_assigned_to_instances()

        if self.job_level_scaling:
            if slurm_resume:
                logger.debug("Performing job level scaling using Slurm resume file")
                self._add_instances_for_resume_file(
                    slurm_resume=slurm_resume,
                    node_list=node_list,
                    launch_batch_size=launch_batch_size,
                    assign_node_batch_size=assign_node_batch_size,
                    update_node_address=update_node_address,
                    scaling_strategy=scaling_strategy,
                )
            else:
                logger.error(
                    "Not possible to perform job level scaling because Slurm resume file content is empty. "
                    "No scaling actions will be taken."
                )
        else:
            if node_list:
                logger.debug("Performing node list scaling using Slurm node resume list")
                self._add_instances_for_nodes(
                    node_list=node_list,
                    launch_batch_size=launch_batch_size,
                    assign_node_batch_size=assign_node_batch_size,
                    update_node_address=update_node_address,
                    scaling_strategy=scaling_strategy,
                )
            else:
                logger.error(
                    "Not possible to perform node list scaling because Slurm node resume list is empty. "
                    "No scaling actions will be taken."
                )

        self._terminate_unassigned_launched_instances(terminate_batch_size)
        self._clear_nodes_assigned_to_instances()

    def _scaling_for_jobs(
        self,
        job_list: List[SlurmResumeJob],
        launch_batch_size: int,
        assign_node_batch_size: int,
        update_node_address: bool,
        scaling_strategy: ScalingStrategy,
        skip_launch: bool = False,
    ) -> None:
        """Scaling for job list."""
        # Setup custom logging filter
        with setup_logging_filter(logger, "JobID") as job_id_logging_filter:
            for job in job_list:
                job_id_logging_filter.set_custom_value(job.job_id)
                logger.debug(f"Job info: {job}")

                logger.info("The nodes_resume list from Slurm Resume File is %s", print_with_count(job.nodes_resume))
                self._add_instances_for_nodes(
                    job=job,
                    node_list=job.nodes_resume,
                    launch_batch_size=launch_batch_size,
                    assign_node_batch_size=assign_node_batch_size,
                    update_node_address=update_node_address,
                    scaling_strategy=scaling_strategy,
                    skip_launch=skip_launch,
                )

    def _terminate_unassigned_launched_instances(self, terminate_batch_size):
        # If there are remaining unassigned instances, terminate them
        unassigned_launched_instances = [
            instance
            for compute_resources in self.unused_launched_instances.values()
            for instance_list in compute_resources.values()
            for instance in instance_list
        ]
        if unassigned_launched_instances:
            logger.info("Terminating unassigned launched instances: %s", self.unused_launched_instances)
            self.delete_instances(
                [instance.id for instance in unassigned_launched_instances],
                terminate_batch_size,
            )
            self._clear_unused_launched_instances()

    def _scaling_for_jobs_single_node(
        self,
        job_list: List[SlurmResumeJob],
        launch_batch_size: int,
        assign_node_batch_size: int,
        update_node_address: bool,
        scaling_strategy: ScalingStrategy,
    ) -> None:
        """Scaling for job single node list."""
        if job_list:
            if len(job_list) == 1:
                # call _scaling_for_jobs so that JobID is logged
                self._scaling_for_jobs(
                    job_list=job_list,
                    launch_batch_size=launch_batch_size,
                    assign_node_batch_size=assign_node_batch_size,
                    update_node_address=update_node_address,
                    scaling_strategy=scaling_strategy,
                )
            else:
                # Batch all single node jobs in a single best-effort EC2 launch request
                # This is done to reduce scaling time and save launch API calls
                # Remove duplicated node entries (possible in oversubscribe case)
                single_nodes = list(dict.fromkeys([job.nodes_resume[0] for job in job_list]))
                self._add_instances_for_nodes(
                    node_list=single_nodes,
                    launch_batch_size=launch_batch_size,
                    assign_node_batch_size=assign_node_batch_size,
                    update_node_address=update_node_address,
                    scaling_strategy=ScalingStrategy.BEST_EFFORT,
                )

    def _add_instances_for_resume_file(
        self,
        slurm_resume: Dict[str, any],
        node_list: List[str],
        launch_batch_size: int,
        assign_node_batch_size: int,
        update_node_address: bool = True,
        scaling_strategy: ScalingStrategy = ScalingStrategy.BEST_EFFORT,
    ):
        """Launch requested EC2 instances for resume file."""
        slurm_resume_data = self._get_slurm_resume_data(slurm_resume=slurm_resume, node_list=node_list)
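        # slurm_resume_data partitions the resume request into single-node and multi-node jobs,
        # which are scaled separately below (see _get_slurm_resume_data for the full structure)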
        self._clear_unused_launched_instances()

        self._scaling_for_jobs_single_node(
            job_list=slurm_resume_data.jobs_single_node,
            launch_batch_size=launch_batch_size,
            assign_node_batch_size=assign_node_batch_size,
            update_node_address=update_node_address,
            scaling_strategy=scaling_strategy,
        )

        self._scaling_for_jobs_multi_node(
            job_list=slurm_resume_data.jobs_multi_node,
            node_list=slurm_resume_data.multi_node,
            launch_batch_size=launch_batch_size,
            assign_node_batch_size=assign_node_batch_size,
            update_node_address=update_node_address,
            scaling_strategy=scaling_strategy,
        )

    def _scaling_for_jobs_multi_node(
        self,
        job_list,
        node_list,
        launch_batch_size,
        assign_node_batch_size,
        update_node_address,
        scaling_strategy: ScalingStrategy,
    ):
        if not (scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING] and len(job_list) <= 1):
            # Optimize job level scaling with preliminary scale-all nodes attempt
            # Except for the all-or-nothing / single job case, to avoid scaling the same node list twice
            self._update_dict(
                self.unused_launched_instances,
                self._launch_instances(
                    nodes_to_launch=self._parse_nodes_resume_list(node_list),
                    launch_batch_size=launch_batch_size,
                    scaling_strategy=scaling_strategy,
                ),
            )

        # Avoid a job level launch if scaling strategy is BEST_EFFORT or GREEDY_ALL_OR_NOTHING
        # The scale all-in launch has been performed already, hence from this point we want to skip the extra
        # job level launch of instances for jobs that are unable to get the needed capacity from the all-in scaling
        skip_launch = scaling_strategy in [ScalingStrategy.BEST_EFFORT, ScalingStrategy.GREEDY_ALL_OR_NOTHING]

        self._scaling_for_jobs(
            job_list=job_list,
            launch_batch_size=launch_batch_size,
            assign_node_batch_size=assign_node_batch_size,
            update_node_address=update_node_address,
            scaling_strategy=scaling_strategy,
            skip_launch=skip_launch,
        )

    def _get_slurm_resume_data(self, slurm_resume: Dict[str, any], node_list: List[str]) -> SlurmResumeData:
        """
        Get SlurmResumeData object.

        SlurmResumeData object contains the following:
        * the node list for jobs allocated to single node
        * the node list for jobs allocated to multiple nodes
        * the job list with single node allocation
        * the job list with multi node allocation

        Example of Slurm Resume File (ref. https://slurm.schedmd.com/elastic_computing.html):
        {
            "all_nodes_resume": "cloud[1-3,7-8]",
            "jobs": [
                {
                    "extra": "An arbitrary string from --extra",
                    "features": "c1,c2",
                    "job_id": 140814,
                    "nodes_alloc": "queue1-st-c5xlarge-[4-5]",
                    "nodes_resume": "queue1-st-c5xlarge-[1,3]",
                    "oversubscribe": "OK",
                    "partition": "cloud",
                    "reservation": "resv_1234",
                },
                {
                    "extra": None,
                    "features": "c1,c2",
                    "job_id": 140815,
                    "nodes_alloc": "queue2-st-c5xlarge-[1-2]",
                    "nodes_resume": "queue2-st-c5xlarge-[1-2]",
                    "oversubscribe": "OK",
                    "partition": "cloud",
                    "reservation": None,
                },
                {
                    "extra": None,
                    "features": None,
                    "job_id": 140816,
                    "nodes_alloc": "queue2-st-c5xlarge-[7,8]",
                    "nodes_resume": "queue2-st-c5xlarge-[7,8]",
                    "oversubscribe": "NO",
                    "partition": "cloud_exclusive",
                    "reservation": None,
                },
            ],
        }
        """
        jobs_single_node = []
        jobs_multi_node = []
        single_node = []
        multi_node = []

        slurm_resume_jobs = self._parse_slurm_resume(slurm_resume)

        for job in slurm_resume_jobs:
            if len(job.nodes_resume) == 1:
                jobs_single_node.append(job)
                single_node.extend(job.nodes_resume)
            else:
                jobs_multi_node.append(job)
                multi_node.extend(job.nodes_resume)

        nodes_difference = list(set(node_list) - (set(single_node) | set(multi_node)))
        if nodes_difference:
            logger.warning(
                "Discarding NodeNames because of mismatch in Slurm Resume File Vs Nodes passed to Resume Program: %s",
                ", ".join(nodes_difference),
            )
            self._update_failed_nodes(set(nodes_difference), "InvalidNodenameError")

        return SlurmResumeData(
            single_node=list(dict.fromkeys(single_node)),
            multi_node=list(dict.fromkeys(multi_node)),
            jobs_single_node=jobs_single_node,
            jobs_multi_node=jobs_multi_node,
        )

    def _parse_slurm_resume(self, slurm_resume: Dict[str, any]) -> List[SlurmResumeJob]:
        slurm_resume_jobs = []
        for job in slurm_resume.get("jobs", {}):
            try:
                slurm_resume_jobs.append(SlurmResumeJob(**job))
            except InvalidNodenameError:
                nodes_resume = job.get("nodes_resume", "")
                nodes_alloc = job.get("nodes_alloc", "")
                job_id = job.get("job_id", "")
                logger.warning(
                    "Discarding NodeNames with invalid format for Job Id (%s): nodes_alloc (%s), nodes_resume (%s)",
                    job_id,
                    nodes_alloc,
                    nodes_resume,
                )
                self._update_failed_nodes(
                    # if NodeNames in nodes_resume cannot be parsed, try to get info directly from Slurm
                    set([node.name for node in get_nodes_info(nodes_resume)]),
                    "InvalidNodenameError",
                )

        return slurm_resume_jobs

    def _add_instances_for_nodes(
        self,
        launch_batch_size: int,
        assign_node_batch_size: int,
        update_node_address: bool = True,
        scaling_strategy: ScalingStrategy = ScalingStrategy.ALL_OR_NOTHING,
        node_list: List[str] = None,
        job: SlurmResumeJob = None,
        skip_launch: bool = False,
    ):
        """Launch requested EC2 instances for nodes."""
        nodes_resume_mapping = self._parse_nodes_resume_list(node_list=node_list)
        # nodes in the resume list, mapped by queue and compute resource, e.g.
        # {
        #     queue_1: {cr_1: [nodes_1, nodes_2, nodes_3], cr_2: [nodes_4]},
        #     queue_2: {cr_3: [nodes_5]}
        # }

        nodes_resume_list = []
        for compute_resources in nodes_resume_mapping.values():
            for slurm_node_list in compute_resources.values():
                nodes_resume_list.extend(slurm_node_list)
        # nodes in the resume list, flattened, e.g.
        # [nodes_1, nodes_2, nodes_3, nodes_4, nodes_5]

        instances_launched = self._launch_instances(
            job=job if job else None,
            nodes_to_launch=nodes_resume_mapping,
            launch_batch_size=launch_batch_size,
            scaling_strategy=scaling_strategy,
            skip_launch=skip_launch,
        )
        # instances launched, e.g.
        # {
        #     queue_1: {cr_1: list[EC2Instance], cr_2: list[EC2Instance]},
        #     queue_2: {cr_3: list[EC2Instance]}
        # }

        successful_launched_nodes = []
        failed_launch_nodes = []
        for queue, compute_resources in nodes_resume_mapping.items():
            for compute_resource, slurm_node_list in compute_resources.items():
                q_cr_instances_launched_length = len(instances_launched.get(queue, {}).get(compute_resource, []))
                successful_launched_nodes += slurm_node_list[:q_cr_instances_launched_length]
                failed_launch_nodes += slurm_node_list[q_cr_instances_launched_length:]

        if scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING, ScalingStrategy.GREEDY_ALL_OR_NOTHING]:
            logger.info("Assigning nodes with all-or-nothing strategy")
            self._all_or_nothing_node_assignment(
                assign_node_batch_size=assign_node_batch_size,
                instances_launched=instances_launched,
                nodes_resume_list=nodes_resume_list,
                nodes_resume_mapping=nodes_resume_mapping,
                successful_launched_nodes=successful_launched_nodes,
                update_node_address=update_node_address,
            )
        else:
            logger.info("Assigning nodes with best-effort strategy")
            self._best_effort_node_assignment(
                assign_node_batch_size=assign_node_batch_size,
                failed_launch_nodes=failed_launch_nodes,
                instances_launched=instances_launched,
                nodes_resume_list=nodes_resume_list,
                nodes_resume_mapping=nodes_resume_mapping,
                successful_launched_nodes=successful_launched_nodes,
                update_node_address=update_node_address,
            )

    def _reset_failed_nodes(self, nodeset):
        """Remove nodeset from failed nodes dict."""
        if nodeset:
            for error_code in self.failed_nodes:
                self.failed_nodes[error_code] = self.failed_nodes.get(error_code, set()).difference(nodeset)

    def _best_effort_node_assignment(
        self,
        assign_node_batch_size,
        failed_launch_nodes,
        instances_launched,
        nodes_resume_list,
        nodes_resume_mapping,
        successful_launched_nodes,
        update_node_address,
    ):
        # best-effort job level scaling
        if 0 < len(successful_launched_nodes) <= len(nodes_resume_list):
            # All or partial requested EC2 capacity for the Job has been launched
            # Assign launched EC2 instances to the requested Slurm nodes
            self._assign_instances_to_nodes(
                update_node_address=update_node_address,
                nodes_to_launch=nodes_resume_mapping,
                instances_launched=instances_launched,
                assign_node_batch_size=assign_node_batch_size,
                raise_on_error=False,
            )
            logger.info(
                "Successfully launched and assigned %s instances for nodes %s",
                "all" if len(successful_launched_nodes) == len(nodes_resume_list) else "partial",
                print_with_count(successful_launched_nodes),
            )
            nodes_assigned_mapping = defaultdict(lambda: defaultdict(list))
            for queue, compute_resources in nodes_resume_mapping.items():
                for compute_resource, slurm_node_list in compute_resources.items():
                    launched_ec2_instances = instances_launched.get(queue, {}).get(compute_resource, [])
                    # fmt: off
                    nodes_assigned_mapping[queue][compute_resource] = slurm_node_list[:len(launched_ec2_instances)]
                    # fmt: on
            self._update_dict(self.nodes_assigned_to_instances, nodes_assigned_mapping)
            self._reset_failed_nodes(set(successful_launched_nodes))
            if len(successful_launched_nodes) < len(nodes_resume_list):
                # set limited capacity on the failed to launch nodes
                self._update_failed_nodes(set(failed_launch_nodes), "LimitedInstanceCapacity", override=False)
        else:
            # No instances launched at all, e.g. CreateFleet API returns no EC2 instances,
            # or no instances left available from a best-effort EC2 launch
            logger.info("No launched instances found for nodes %s", print_with_count(nodes_resume_list))
            self._update_failed_nodes(set(nodes_resume_list), "InsufficientInstanceCapacity", override=False)

    def _all_or_nothing_node_assignment(
        self,
        assign_node_batch_size,
        instances_launched,
        nodes_resume_list,
        nodes_resume_mapping,
        successful_launched_nodes,
        update_node_address,
    ):
        # all-or-nothing job level scaling
        if len(successful_launched_nodes) == len(nodes_resume_list):
            # All requested EC2 capacity for the Job has been launched
            # Assign launched EC2 instances to the requested Slurm nodes
            try:
                self._assign_instances_to_nodes(
                    update_node_address=update_node_address,
                    nodes_to_launch=nodes_resume_mapping,
                    instances_launched=instances_launched,
                    assign_node_batch_size=assign_node_batch_size,
                    raise_on_error=True,
                )
                logger.info(
                    "Successfully launched and assigned all instances for nodes %s",
                    print_with_count(nodes_resume_list),
                )
                self._update_dict(self.nodes_assigned_to_instances, nodes_resume_mapping)
                self._reset_failed_nodes(set(nodes_resume_list))
            except InstanceToNodeAssignmentError:
                # Failed to assign EC2 instances to nodes
                # EC2 instances already assigned are going to be terminated by
                # setting the nodes into DOWN.
                # EC2 instances not yet assigned are going to fail during bootstrap,
                # because no entry in the DynamoDB table would be found
                self._update_failed_nodes(set(nodes_resume_list))
        elif 0 < len(successful_launched_nodes) < len(nodes_resume_list):
            # Try to reuse partial capacity of already launched EC2 instances
            logger.info(
                "Releasing launched and booked instances %s",
                print_with_count(
                    [
                        (queue, compute_resource, instance)
                        for queue, compute_resources in instances_launched.items()
                        for compute_resource, instances in compute_resources.items()
                        for instance in instances
                    ]
                ),
            )
            self._update_dict(self.unused_launched_instances, instances_launched)
            self._update_failed_nodes(set(nodes_resume_list), "LimitedInstanceCapacity", override=False)
        else:
            # No instances launched at all, e.g. CreateFleet API returns no EC2 instances,
            # or no instances left available from a best-effort EC2 launch
            logger.info("No launched instances found for nodes %s", print_with_count(nodes_resume_list))
            self._update_failed_nodes(set(nodes_resume_list), "InsufficientInstanceCapacity", override=False)

    def _launch_instances(  # noqa: C901
        self,
        nodes_to_launch: Dict[str, any],
        launch_batch_size: int,
        scaling_strategy: ScalingStrategy,
        job: SlurmResumeJob = None,
        skip_launch: bool = False,
    ):
        instances_launched = defaultdict(lambda: defaultdict(list))
        for queue, compute_resources in nodes_to_launch.items():
            for compute_resource, slurm_node_list in compute_resources.items():
                slurm_node_list = self._resize_slurm_node_list(
                    queue=queue,
                    compute_resource=compute_resource,
                    instances_launched=instances_launched,
                    slurm_node_list=slurm_node_list,
                )

                if slurm_node_list and not skip_launch:
                    logger.info(
                        "Launching %s instances for nodes %s",
                        "all-or-nothing" if scaling_strategy == ScalingStrategy.ALL_OR_NOTHING else "best-effort",
                        print_with_count(slurm_node_list),
                    )
                    # At instance launch level, the various scaling strategies can be grouped based on the actual
                    # launch behaviour, i.e. all-or-nothing or best-effort
                    all_or_nothing_batch = scaling_strategy in [ScalingStrategy.ALL_OR_NOTHING]
                    fleet_manager = self._get_fleet_manager(all_or_nothing_batch, compute_resource, queue)

                    for batch_nodes in grouper(slurm_node_list, launch_batch_size):
                        try:
                            launched_ec2_instances = self._launch_ec2_instances(
                                batch_nodes, compute_resource, fleet_manager, instances_launched, job, queue
                            )
                            if job and all_or_nothing_batch and len(launched_ec2_instances) < len(batch_nodes):
                                # When launching instances for a specific Job,
                                # exit fast if not all the requested capacity can be launched,
                                # returning the EC2 instances launched so far,
                                # so that they can be eventually allocated to other Slurm nodes
                                # This path handles the CreateFleet case,
                                # which doesn't fail when no capacity is returned
                                return instances_launched
                        except (ClientError, Exception) as e:
                            logger.error(
                                "Encountered exception when launching instances for nodes %s: %s",
                                print_with_count(batch_nodes),
                                e,
                            )
                            update_failed_nodes_parameters = {"nodeset": set(batch_nodes)}
                            if isinstance(e, ClientError):
                                update_failed_nodes_parameters["error_code"] = e.response.get("Error", {}).get("Code")
                            elif isinstance(e, Exception) and hasattr(e, "code"):
                                update_failed_nodes_parameters["error_code"] = e.code
                            self._update_failed_nodes(**update_failed_nodes_parameters)

                            if job and all_or_nothing_batch:
                                # When launching instances for a specific Job,
                                # exit fast if not all the requested capacity can be launched,
                                # returning the EC2 instances launched so far,
                                # so that they can be eventually allocated to other Slurm nodes
                                # This path handles the RunInstances case, which throws an exception when
                                # no capacity is returned, and the CreateFleet case when an exception is thrown
                                return instances_launched

        return instances_launched

    def _launch_ec2_instances(self, batch_nodes, compute_resource, fleet_manager, instances_launched, job, queue):
        launched_ec2_instances = fleet_manager.launch_ec2_instances(
            len(batch_nodes), job_id=job.job_id if job else None
        )
        # launched_ec2_instances e.g. list[EC2Instance]

        if len(launched_ec2_instances) > 0:
            instances_launched[queue][compute_resource].extend(launched_ec2_instances)
            # instances_launched e.g.
            # {
            #     queue_1: {cr_1: list[EC2Instance], cr_2: list[EC2Instance]},
            #     queue_2: {cr_3: list[EC2Instance]}
            # }
        else:
            self._update_failed_nodes(set(batch_nodes), "InsufficientInstanceCapacity")
        return launched_ec2_instances

    def _get_fleet_manager(self, all_or_nothing_batch, compute_resource, queue):
        # Set the number of retries to be the max between the globally configured one and 3.
        # This is done to try to avoid launch instances API throttling
        # without changing the configured retries for all API calls.
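        # For example (illustrative): a globally configured max_attempts of 1 is raised to 3 here,
        # while a globally configured value of 10 is kept as is.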
        configured_retry = self._boto3_config.retries.get("max_attempts", 0) if self._boto3_config.retries else 0
        boto3_config = self._boto3_config.merge(
            Config(retries={"max_attempts": max([configured_retry, 3]), "mode": "standard"})
        )
        # Each compute resource can be configured to use create_fleet or run_instances
        fleet_manager = FleetManagerFactory.get_manager(
            cluster_name=self._cluster_name,
            region=self._region,
            boto3_config=boto3_config,
            fleet_config=self._fleet_config,
            queue=queue,
            compute_resource=compute_resource,
            all_or_nothing=all_or_nothing_batch,
            run_instances_overrides=self._run_instances_overrides,
            create_fleet_overrides=self._create_fleet_overrides,
        )
        return fleet_manager

    def _resize_slurm_node_list(
        self, queue: str, compute_resource: str, slurm_node_list: List[str], instances_launched: Dict[str, any]
    ):
        reusable_instances = self.unused_launched_instances.get(queue, {}).get(compute_resource, [])
        if len(reusable_instances) > 0:
            # Reuse already launched capacity
            # fmt: off
            logger.info(
                "Booking already launched instances for nodes %s",
                print_with_count(slurm_node_list[:len(reusable_instances)]),
            )
            instances_launched[queue][compute_resource].extend(reusable_instances[:len(slurm_node_list)])
            # Remove reusable instances from unused instances
            self.unused_launched_instances[queue][compute_resource] = reusable_instances[len(slurm_node_list):]
            # Reduce slurm_node_list
            slurm_node_list = slurm_node_list[len(reusable_instances):]
            # fmt: on
        return slurm_node_list

    def _assign_instances_to_nodes(
        self,
        update_node_address: bool,
        nodes_to_launch: Dict[str, any],
        instances_launched: Dict[str, any],
        assign_node_batch_size: int,
        raise_on_error: bool,
    ):
        if update_node_address:
            for queue, compute_resources in nodes_to_launch.items():
                for compute_resource, slurm_node_list in compute_resources.items():
                    launched_ec2_instances = instances_launched.get(queue, {}).get(compute_resource, [])

                    for batch in grouper(list(zip(slurm_node_list, launched_ec2_instances)), assign_node_batch_size):
                        batch_nodes = []
                        try:
                            batch_nodes, batch_launched_ec2_instances = zip(*batch)
                            assigned_nodes = dict(batch)
                            self._store_assigned_hostnames(nodes=assigned_nodes)
                            self._update_dns_hostnames(
                                nodes=assigned_nodes, update_dns_batch_size=assign_node_batch_size
                            )
                            self._update_slurm_node_addrs(
                                slurm_nodes=list(batch_nodes), launched_instances=batch_launched_ec2_instances
                            )
                        except (NodeAddrUpdateError, HostnameTableStoreError, HostnameDnsStoreError):
                            if raise_on_error:
                                raise InstanceToNodeAssignmentError
                            # Update the batch of failed nodes and continue
                            self._update_failed_nodes(set(batch_nodes))

    def _update_slurm_node_addrs(self, slurm_nodes: List[str], launched_instances: List[EC2Instance]):
        """Update node information in slurm with info from launched EC2 instance."""
        try:
            # When using a cluster DNS domain we don't need to pass nodehostnames
            # because they are equal to node names.
            # It is possible to force the use of private hostnames by setting
            # use_private_hostname = "true" as extra json parameter
            node_hostnames = (
                None if not self._use_private_hostname else [instance.hostname for instance in launched_instances]
            )
            update_nodes(
                slurm_nodes,
                nodeaddrs=[instance.private_ip for instance in launched_instances],
                nodehostnames=node_hostnames,
            )
            logger.info(
                "Nodes are now configured with instances %s",
                print_with_count(zip(slurm_nodes, launched_instances)),
            )
        except subprocess.CalledProcessError:
            logger.error(
                "Encountered error when updating nodes %s with instances %s",
                print_with_count(slurm_nodes),
                print_with_count(launched_instances),
            )
            raise NodeAddrUpdateError