# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
import logging
from typing import Callable, Dict, List, Optional, Set, Tuple
from hpc.autoscale import util as hpcutil
from hpc.autoscale import clock
from hpc.autoscale.ccbindings import ClusterBindingInterface
from hpc.autoscale.node.bucket import NodeBucket
from hpc.autoscale.node.node import Node
from hpc.autoscale.node.nodemanager import NodeManager
from hpc.autoscale.results import AllocationResult, BootupResult
from . import AzureSlurmError
from . import partition as partitionlib
from . import util as slutil


def resume(
config: Dict,
node_mgr: NodeManager,
node_list: List[str],
partitions: List[partitionlib.Partition],
) -> BootupResult:
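    """
    Allocate and boot the CycleCloud nodes backing the requested Slurm node names.

    Each name is mapped to its partition and bucket, allocated as an exclusive
    node (colocated for HPC partitions), and booted. Nodes that already exist
    and are not Deallocated are skipped. Raises AzureSlurmError if any name does
    not belong to a known partition.
    """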
name_to_partition = {}
for partition in partitions:
for name in partition.all_nodes():
name_to_partition[name] = partition
existing_nodes_by_name = hpcutil.partition(node_mgr.get_nodes(), lambda n: n.name)
nodes = []
unknown_node_names = []
for name in node_list:
if name not in name_to_partition:
unknown_node_names.append(name)
if unknown_node_names:
raise AzureSlurmError("Unknown node name(s): %s" % ",".join(unknown_node_names))
for name in node_list:
if name in existing_nodes_by_name:
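            # Skip nodes that already exist unless they are Deallocated;
            # Deallocated nodes fall through to the allocation path below.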
node = existing_nodes_by_name[name][0]
if node.state != "Deallocated":
logging.info(f"{name} already exists.")
continue
if name not in name_to_partition:
raise AzureSlurmError(
f"Unknown node name: {name}: {list(name_to_partition.keys())}"
)
partition = name_to_partition[name]
bucket = partition.bucket_for_node(name)
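        # name_hook closes over the current loop's `name`: the single node
        # created for this allocation is given exactly that Slurm node name.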
def name_hook(bucket: NodeBucket, index: int) -> str:
            if index != 1:
                raise RuntimeError(
                    f"Could not create node with name {name}. Perhaps the node "
                    "already exists in a terminating state?"
                )
return name
node_mgr.set_node_name_hook(name_hook)
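        # Request a single exclusive node from this name's bucket; HPC
        # partitions additionally require colocated nodes.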
constraints = {"node.bucket_id": bucket.bucket_id, "exclusive": True}
if partition.is_hpc:
constraints["node.colocated"] = True
result: AllocationResult = node_mgr.allocate(
constraints, node_count=1, allow_existing=False
)
        if len(result.nodes) != 1:
            raise RuntimeError(
                f"Expected to allocate exactly one node for {name}, got {len(result.nodes)}"
            )
result.nodes[0].name_format = name
nodes.extend(result.nodes)
boot_result = node_mgr.bootup(nodes)
return boot_result


def wait_for_nodes_to_terminate(
bindings: ClusterBindingInterface, node_list: List[str]
) -> None:
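    """
    Poll CycleCloud until every node in node_list has either disappeared, is
    targeted to be Started again, or has reached its target state, checking
    every 5 seconds. Raises AzureSlurmError if nodes are still terminating
    after roughly 30 minutes.
    """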
    # Poll every 5 seconds for up to 30 minutes.
    attempts = 1800 // 5
waiting_for_nodes = []
while attempts > 0:
attempts -= 1
waiting_for_nodes = []
cc_by_name = hpcutil.partition_single(bindings.get_nodes().nodes, lambda n: n["Name"])
for name in node_list:
if name in cc_by_name:
cc_node = cc_by_name[name]
target_state = cc_node.get("TargetState") or "undefined"
status = cc_node.get("Status") or "undefined"
is_booting = target_state == "Started"
reached_final_state = target_state == status
if is_booting or reached_final_state:
continue
waiting_for_nodes.append(name)
if not waiting_for_nodes:
logging.info("All nodes are available to be started.")
return
logging.info(f"Waiting for nodes to terminate: {waiting_for_nodes}")
clock.sleep(5)
raise AzureSlurmError(f"Timed out waiting for nodes to terminate: {waiting_for_nodes}")


class WaitForResume:
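    """
    Tracks node state across polling iterations of wait_for_resume: which nodes
    have already been seen in a Failed state, and which (name, ip) scontrol
    address updates have already been applied.
    """
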
def __init__(self) -> None:
self.failed_node_names: Set[str] = set()
self.ip_already_set: Set[Tuple[str, str]] = set()

    def check_nodes(
self, node_list: List[str], latest_nodes: List[Node]
) -> Tuple[Dict, List[Node]]:
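        """
        Compare the latest CycleCloud view of node_list against what has been
        seen so far: count nodes per state, push NodeAddr/NodeHostName updates,
        mark newly failed nodes down and recovered nodes idle in Slurm.
        Returns (state counts, nodes that are Ready with an IP address).
        """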
ready_nodes = []
states = {}
by_name = hpcutil.partition_single(latest_nodes, lambda node: node.name)
relevant_nodes: List[Node] = []
recovered_node_names: Set[str] = set()
newly_failed_node_names: List[str] = []
deleted_nodes = []
for name in node_list:
node = by_name.get(name)
if not node:
                deleted_nodes.append(name)
continue
is_dynamic = node.software_configuration.get("slurm", {}).get("dynamic_config")
relevant_nodes.append(node)
state = node.state
if state and state.lower() == "failed":
states["Failed"] = states.get("Failed", 0) + 1
if name not in self.failed_node_names:
newly_failed_node_names.append(name)
if not is_dynamic:
slutil.scontrol(["update", f"NodeName={name}", f"NodeAddr={name}", f"NodeHostName={name}"])
self.failed_node_names.add(name)
continue
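            # Unless node names double as hostnames, point Slurm's NodeAddr and
            # NodeHostName at the node's private IP (once per name/IP pair).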
use_nodename_as_hostname = node.software_configuration.get("slurm", {}).get(
"use_nodename_as_hostname", False
)
if not use_nodename_as_hostname:
ip_already_set_key = (name, node.private_ip)
if node.private_ip and ip_already_set_key not in self.ip_already_set:
slutil.scontrol(
[
"update",
"NodeName=%s" % name,
"NodeAddr=%s" % node.private_ip,
"NodeHostName=%s" % node.private_ip,
]
)
self.ip_already_set.add(ip_already_set_key)
if name in self.failed_node_names:
recovered_node_names.add(name)
if node.target_state != "Started":
states["UNKNOWN"] = states.get("UNKNOWN", {})
states["UNKNOWN"][node.state] = states["UNKNOWN"].get(state, 0) + 1
continue
if node.state == "Ready":
if not node.private_ip:
state = "WaitingOnIPAddress"
else:
ready_nodes.append(node)
states[state] = states.get(state, 0) + 1
if newly_failed_node_names:
failed_node_names_str = ",".join(self.failed_node_names)
try:
logging.error(
"The following nodes failed to start: %s", failed_node_names_str
)
for failed_name in self.failed_node_names:
slutil.scontrol(
[
"update",
"NodeName=%s" % failed_name,
"State=down",
"Reason=cyclecloud_node_failure",
]
)
except Exception:
logging.exception(
"Failed to mark the following nodes as down: %s. Will re-attempt next iteration.",
failed_node_names_str,
)
        if recovered_node_names:
            recovered_node_names_str = ",".join(recovered_node_names)
            try:
                logging.info(
                    "The following nodes have recovered from failure: %s",
                    recovered_node_names_str,
                )
                for recovered_name in recovered_node_names:
                    # Use this node's own slurm configuration rather than the
                    # is_dynamic value left over from the loop above.
                    recovered_node = by_name.get(recovered_name)
                    recovered_is_dynamic = (
                        recovered_node.software_configuration.get("slurm", {}).get(
                            "dynamic_config"
                        )
                        if recovered_node
                        else False
                    )
                    if not recovered_is_dynamic:
                        slutil.scontrol(
                            [
                                "update",
                                "NodeName=%s" % recovered_name,
                                "State=idle",
                                "Reason=cyclecloud_node_recovery",
                            ]
                        )
                    if recovered_name in self.failed_node_names:
                        self.failed_node_names.remove(recovered_name)
except Exception:
logging.exception(
"Failed to mark the following nodes as recovered: %s. Will re-attempt next iteration.",
recovered_node_names_str,
)
return (states, ready_nodes)


def wait_for_resume(
config: Dict,
operation_id: str,
node_list: List[str],
get_latest_nodes: Callable[[], List[Node]],
waiter: "WaitForResume" = WaitForResume(),
) -> None:
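    """
    Poll CycleCloud for up to one hour until every node in node_list is in a
    terminal state (Ready, Failed, or no longer targeted to be Started), then
    update NodeAddr/NodeHostName in Slurm for the non-dynamic Ready nodes.
    """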
    # A default WaitForResume() instance would be shared across calls, so
    # construct the state tracker here instead.
    waiter = waiter or WaitForResume()
    previous_states = {}
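    # Only the first few node names are logged to keep messages readable.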
nodes_str = ",".join(node_list[:5])
    deadline = clock.time() + 3600  # give up waiting after one hour
ready_nodes: List[Node] = []
    while clock.time() < deadline:
states, ready_nodes = waiter.check_nodes(node_list, get_latest_nodes())
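        # Terminal states: Ready, Failed, or any state where the node is no
        # longer targeted to be Started.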
terminal_states = (
states.get("Ready", 0)
+ sum(states.get("UNKNOWN", {}).values())
+ states.get("Failed", 0)
)
if states != previous_states:
states_messages = []
for key in sorted(states.keys()):
if key != "UNKNOWN":
states_messages.append("{}={}".format(key, states[key]))
else:
for ukey in sorted(states["UNKNOWN"].keys()):
states_messages.append(
"{}={}".format(ukey, states["UNKNOWN"][ukey])
)
            states_message = ", ".join(states_messages)
logging.info(
"OperationId=%s NodeList=%s: Number of nodes in each state: %s",
operation_id,
nodes_str,
states_message,
)
if terminal_states == len(node_list):
break
previous_states = states
clock.sleep(5)
logging.info(
"The following nodes reached Ready state: %s",
",".join([x.name for x in ready_nodes]),
)
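    # Point Slurm at each Ready node's IP address and hostname; dynamic nodes
    # and nodes without a valid hostname are skipped.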
for node in ready_nodes:
if not hpcutil.is_valid_hostname(config, node):
continue
is_dynamic = node.software_configuration.get("slurm", {}).get("dynamic_config")
if is_dynamic:
continue
slutil.scontrol(
[
"update",
"NodeName=%s" % node.name,
"NodeAddr=%s" % node.private_ip,
"NodeHostName=%s" % node.hostname,
]
)
    logging.info(
        "OperationId=%s NodeList=%s: finished updating NodeAddr/NodeHostName for Ready nodes. Exiting",
        operation_id,
        nodes_str,
    )