# azure-slurm/slurmcc/cli.py
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
import argparse
import json
import logging
import os
import shutil
import sys
import traceback
from argparse import ArgumentParser
from datetime import date, datetime, time, timedelta
from math import ceil
from subprocess import SubprocessError, check_output
from typing import Any, Callable, Dict, Iterable, List, Optional, TextIO, Union
from hpc.autoscale.cost.azurecost import azurecost
from hpc.autoscale.ccbindings import new_cluster_bindings
from hpc.autoscale.hpctypes import Memory
from hpc.autoscale import clock
from hpc.autoscale import util as hpcutil
from hpc.autoscale.cli import GenericDriver
from hpc.autoscale.clilib import CommonCLI, ShellDict, disablecommand
from hpc.autoscale.clilib import main as clilibmain
from hpc.autoscale.job.demandprinter import OutputFormat
from hpc.autoscale.job.driver import SchedulerDriver
from hpc.autoscale.node.node import Node
from hpc.autoscale.node.nodemanager import NodeManager
from hpc.autoscale.results import ShutdownResult
from slurmcc import allocation
from . import AzureSlurmError
from . import partition as partitionlib
from . import util as slutil
from .util import is_autoscale_enabled, scontrol
from . import cost
from . import topology
VERSION = "4.0.0"
def csv_list(x: str) -> List[str]:
# used in argument parsing
return [x.strip() for x in x.split(",")]
def init_power_saving_log(function: Callable) -> Callable:
def wrapped(*args: Any, **kwargs: Any) -> Any:
root_logger = logging.getLogger()
for handler in root_logger.handlers:
if hasattr(handler, "baseFilename"):
fname = getattr(handler, "baseFilename")
if fname and fname.endswith(f"{function.__name__}.log"):
handler.setLevel(logging.INFO)
logging.info(f"initialized {function.__name__}.log")
return function(*args, **kwargs)
wrapped.__doc__ = function.__doc__
return wrapped
class SlurmDriver(GenericDriver):
def __init__(self) -> None:
super().__init__("slurm")
def preprocess_node_mgr(self, config: Dict, node_mgr: NodeManager) -> None:
def default_dampened_memory(node: Node) -> Memory:
return min(node.memory - Memory.value_of("1g"), node.memory * 0.95)
node_mgr.add_default_resource(
selection={},
resource_name="slurm_memory",
default_value=default_dampened_memory,
)
for b in node_mgr.get_buckets():
if "nodearrays" not in config:
config["nodearrays"] = {}
if b.nodearray not in config["nodearrays"]:
config["nodearrays"][b.nodearray] = {}
if "generated_placement_group_buffer" in config["nodearrays"][b.nodearray]:
continue
is_hpc = (
str(
b.software_configuration.get("slurm", {}).get("hpc") or "false"
).lower()
== "true"
)
if is_hpc:
buffer = 1
max_pgs = 1
else:
buffer = 0
max_pgs = 0
config["nodearrays"][b.nodearray][
"generated_placement_group_buffer"
] = buffer
config["nodearrays"][b.nodearray][
"max_placement_groups"
] = max_pgs
super().preprocess_node_mgr(config, node_mgr)
class SlurmCLI(CommonCLI):
def __init__(self) -> None:
super().__init__(project_name="slurm")
self.slurm_node_names = []
@disablecommand
def create_nodes(self, *args: Any, **kwargs: Dict) -> None:
assert False
@disablecommand
def delete_nodes(
self,
config: Dict,
hostnames: List[str],
node_names: List[str],
do_delete: bool = True,
force: bool = False,
permanent: bool = False,
) -> None:
assert False
def _add_completion_data(self, completion_json: Dict) -> None:
node_names = slutil.check_output(["sinfo", "-N", "-h", "-o", "%N"]).splitlines(
keepends=False
)
node_lists = slutil.check_output(["sinfo", "-h", "-o", "%N"]).strip().split(",")
completion_json["slurm_node_names"] = node_names + node_lists
def _read_completion_data(self, completion_json: Dict) -> None:
self.slurm_node_names = completion_json.get("slurm_node_names", [])
def _slurm_node_name_completer(
self,
prefix: str,
action: argparse.Action,
parser: ArgumentParser,
parsed_args: argparse.Namespace,
) -> List[str]:
self._get_example_nodes(parsed_args.config)
output_prefix = ""
if prefix.endswith(","):
output_prefix = prefix
return [output_prefix + x + "," for x in self.slurm_node_names]
def cost_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("-s", "--start", type=lambda s: datetime.strptime(s, '%Y-%m-%d'),
default=date.today().isoformat(),
help="Start time period (yyyy-mm-dd), defaults to current day.")
parser.add_argument("-e", "--end", type=lambda s: datetime.strptime(s, '%Y-%m-%d'),
default=date.today().isoformat(),
help="End time period (yyyy-mm-dd), defaults to current day.")
parser.add_argument("-o", "--out", required=True, help="Directory name for output CSV")
#parser.add_argument("-p", "--partition", action='store_true', help="Show costs aggregated by partitions")
parser.add_argument("-f", "--fmt", type=str,
help="Comma separated list of SLURM formatting options. Otherwise defaults are applied")
def cost(self, config: Dict, start, end, out, fmt=None):
"""
Cost analysis and reporting tool that maps Azure costs
to SLURM Job Accounting data. This is an experimental
feature.
"""
curr = datetime.today()
delta = timedelta(days=365)
if (curr - start) >= delta:
raise ValueError("Start date cannot be more than 1 year back from today")
if start > end:
raise ValueError("Start date cannot be after end date")
if end > curr:
raise ValueError("End date cannot be in the future")
if start == end:
end = datetime.combine(end.date(), time(hour=23,minute=59,second=59))
azcost = azurecost(config)
driver = cost.CostDriver(azcost, config)
driver.run(start, end, out, fmt)
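# Illustrative invocation (assumed command shape; dates must lie within the last year):
#   azslurm cost -s <yyyy-mm-dd> -e <yyyy-mm-dd> -o /tmp/azslurm-costs
# which feeds the date window into cost.CostDriver.run and writes the CSV report under --out.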
def topology_parser(self, parser: ArgumentParser) -> None:
group = parser.add_mutually_exclusive_group(required=False)
parser.add_argument('-p', '--partition', type=str, help="Specify the partition")
parser.add_argument('-o', '--output', type=str, help="Specify slurm topology file output")
group.add_argument('-v', '--use_vmss', action='store_true', default=True, help='Use VMSS (default: True)')
group.add_argument('-f', '--use_fabric_manager', action='store_true', default=False, help='Use Fabric Manager (default: False)')
def topology(self, config: Dict, partition, output, use_vmss, use_fabric_manager):
"""
Generates Topology Plugin Configuration
"""
if use_fabric_manager:
if not partition:
raise ValueError("--partition is required when using --use_fabric_manager")
config_dir = config.get("config_dir")
topo = topology.Topology(partition,output,config_dir)
topo.run()
elif use_vmss:
if output:
with open(output, 'w', encoding='utf-8') as file_writer:
return _generate_topology(self._get_node_manager(config), file_writer)
else:
return _generate_topology(self._get_node_manager(config), sys.stdout)
else:
raise ValueError("Please specify either --use_vmss or --use_fabric_manager")
def partitions_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("--allow-empty", action="store_true", default=False)
def partitions(self, config: Dict, allow_empty: bool = False) -> None:
"""
Generates partition configuration
"""
node_mgr = self._get_node_manager(config)
partitions = partitionlib.fetch_partitions(node_mgr, include_dynamic=True) # type: ignore
_partitions(
partitions,
sys.stdout,
allow_empty=allow_empty,
autoscale=is_autoscale_enabled(),
)
def resume_parser(self, parser: ArgumentParser) -> None:
parser.set_defaults(read_only=False)
parser.add_argument(
"--node-list", type=hostlist, required=True
).completer = self._slurm_node_name_completer # type: ignore
parser.add_argument("--no-wait", action="store_true", default=False)
@init_power_saving_log
def resume(self, config: Dict, node_list: List[str], no_wait: bool = False) -> None:
"""
Equivalent to ResumeProgram, starts and waits for a set of nodes.
"""
bindings = new_cluster_bindings(config)
allocation.wait_for_nodes_to_terminate(bindings, node_list)
node_mgr = self._get_node_manager(config)
partitions = partitionlib.fetch_partitions(node_mgr, include_dynamic=True)
bootup_result = allocation.resume(config, node_mgr, node_list, partitions)
if not bootup_result:
raise AzureSlurmError(
f"Failed to boot {node_list} - {bootup_result.message}"
)
if no_wait:
return
def get_latest_nodes() -> List[Node]:
node_mgr = self._get_node_manager(config, force=True)
return node_mgr.get_nodes()
booted_node_list = [n.name for n in (bootup_result.nodes or [])]
allocation.wait_for_resume(
config, bootup_result.operation_id, booted_node_list, get_latest_nodes
)
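# Example (hedged; node names assumed): Slurm's ResumeProgram typically calls
#   azslurm resume --node-list hpc-[1-4]
# which waits for any prior terminations, boots the nodes through CycleCloud and,
# unless --no-wait is given, blocks until they converge.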
def wait_for_resume_parser(self, parser: ArgumentParser) -> None:
parser.set_defaults(read_only=False)
parser.add_argument(
"--node-list", type=hostlist, required=True
).completer = self._slurm_node_name_completer # type: ignore
def wait_for_resume(self, config: Dict, node_list: List[str]) -> None:
"""
Wait for a set of nodes to converge.
"""
def get_latest_nodes() -> List[Node]:
node_mgr = self._get_node_manager(config, force=True)
return node_mgr.get_nodes()
allocation.wait_for_resume(config, "noop", node_list, get_latest_nodes)
def _shutdown(self, config: Dict, node_list: List[str], node_mgr: NodeManager) -> None:
by_name = hpcutil.partition_single(node_mgr.get_nodes(), lambda node: node.name)
node_list_filtered = []
to_keep_alive = []
for node_name in node_list:
if node_name in by_name:
node = by_name[node_name]
if node.keep_alive:
to_keep_alive.append(node_name)
logging.warning(f"{node_name} has KeepAlive=true in CycleCloud. Cannot terminate.")
else:
node_list_filtered.append(node_name)
else:
logging.info(f"{node_name} does not exist. Skipping.")
if to_keep_alive:
# This prevents the node from repeatedly cycling through resume/resume_fail.
logging.warning(f"Nodes {to_keep_alive} have KeepAlive=true in CycleCloud. Cannot terminate." +
" Setting state to down reason=keep_alive")
to_keep_alive_str = slutil.to_hostlist(to_keep_alive)
scontrol(["update", f"nodename={to_keep_alive_str}", "state=down", "reason=keep_alive"])
if not node_list_filtered:
logging.warning(f"No nodes out of node list {node_list} could be shutdown." +
" Post-processing the nodes only.")
else:
result = _safe_shutdown(node_list_filtered, node_mgr)
if not result:
raise AzureSlurmError(f"Failed to shutdown {node_list_filtered} - {result.message}")
if slutil.is_autoscale_enabled():
# undo internal DNS
for node_name in node_list:
_undo_internal_dns(node_name)
else:
# set states back to future and set NodeAddr/NodeHostName to node name
_update_future_states(self._get_node_manager(config, force=True), node_list)
def suspend_parser(self, parser: ArgumentParser) -> None:
parser.set_defaults(read_only=False)
parser.add_argument(
"--node-list", type=hostlist, required=True
).completer = self._slurm_node_name_completer # type: ignore
@init_power_saving_log
def suspend(self, config: Dict, node_list: List[str]) -> None:
"""
Equivalent to SuspendProgram, shuts down nodes
"""
return self._shutdown(config, node_list, self._node_mgr(config))
def resume_fail_parser(self, parser: ArgumentParser) -> None:
self.suspend_parser(parser)
@init_power_saving_log
def resume_fail(
self, config: Dict, node_list: List[str], drain_timeout: int = 300
) -> None:
"""
Equivalent to SuspendFailProgram, shuts down nodes
"""
node_mgr = self._node_mgr(config, self._driver(config))
self._shutdown(config, node_list=node_list, node_mgr=node_mgr)
def return_to_idle_parser(self, parser: ArgumentParser) -> None:
parser.set_defaults(read_only=False)
parser.add_argument("--terminate-zombie-nodes", action="store_true", default=False)
def return_to_idle(
self, config: Dict, terminate_zombie_nodes: bool = False
) -> None:
"""
Nodes that fail to resume in ResumeTimeout seconds will be left
in a down~ state - i.e. down and powered_down. It is also possible
the nodes will be in a drained~ state, if the node was drained during
resume. This command will set those nodes to idle~.
The one exception is nodes that have KeepAlive set in CycleCloud.
Those nodes are left as down~ and logged. Once the user unchecks
KeepAlive, the node can be shut down automatically if --terminate-zombie-nodes
is set, or if config["return-to-idle"]["terminate-zombie-nodes"] is true.
"""
if not slutil.is_autoscale_enabled():
return
# this is always run as root, so bump up the loglevel to info
stream_handlers = [
x
for x in logging.getLogger().handlers
if isinstance(x, logging.StreamHandler)
]
for sh in stream_handlers:
sh.setLevel(logging.INFO)
node_mgr = self._node_mgr(config)
ccnodes = node_mgr.get_nodes()
ccnodes_by_name = hpcutil.partition_single(ccnodes, lambda node: node.name)
snodes = slutil.show_nodes()
if terminate_zombie_nodes:
if "return_to_idle" not in config:
config["return-to-idle"] = {}
config["return-to-idle"]["terminate-zombie-nodes"] = True
SlurmCLI._return_to_idle(config, snodes, ccnodes_by_name, scontrol, node_mgr)
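# A minimal sketch of the corresponding autoscale config stanza (key names taken from the
# reads below; the surrounding file layout is assumed):
#   {
#     "return-to-idle": {"terminate-zombie-nodes": true}
#   }
# With this set, zombie nodes (powered_down in Slurm but still present in CycleCloud)
# are put into state=power_down instead of only being logged.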
@staticmethod
def _return_to_idle(
config: Dict,
snodes: List[Dict],
ccnodes_by_name: Dict[str, Node],
scontrol_func: Callable,
node_mgr: NodeManager,
) -> None:
to_set_to_idle = []
to_power_down = []
for snode in snodes:
slurm_states = set(snode["State"].split("+"))
# ignore non-cloud nodes, as they aren't our responsibility
if "CLOUD" not in slurm_states:
continue
power_down_states = set(["POWERED_DOWN"])
if not power_down_states.intersection(slurm_states):
continue
node_name = snode["NodeName"]
if "DOWN" in slurm_states or "DRAINED" in slurm_states:
# Only nodes that do not exist in CycleCloud can be set to idle~
if node_name not in ccnodes_by_name:
to_set_to_idle.append(node_name)
continue
# keepalive nodes must be left alone, and obviously cannot be terminated.
ccnode = ccnodes_by_name[node_name]
if ccnode.keep_alive:
logging.warning(
f"{node_name} exists and has KeepAlive=true in CycleCloud. Cannot set to idle."
)
continue
terminate_zombie_nodes = config.get("return-to-idle", {}).get(
"terminate-zombie-nodes", False
)
if terminate_zombie_nodes:
logging.warning(
f"Found zombie node {node_name}. Will set to power_down because terminate-zombie-nodes is set."
)
# Terminate the node - note that the next round the node will be set to idle.
to_power_down.append(node_name)
else:
logging.warning(
f"Node {node_name} is in DOWN~ state but exists in CycleCloud. To terminate the node"
+ ", shutdown the node manually (via azslurm suspend or the UI) or, if you want the node"
+ " to join the cluster, login to it and restart slurmd."
)
if to_power_down:
to_power_down_idle_str = slutil.to_hostlist(to_power_down, scontrol_func=scontrol_func)
slutil.scontrol(["update", f"nodename={to_power_down_idle_str}", "state=power_down"])
if to_set_to_idle:
to_set_to_idle_str = slutil.to_hostlist(to_set_to_idle, scontrol_func=scontrol_func)
logging.warning(f"Setting nodes {to_set_to_idle} to idle.")
scontrol_func(["update", f"nodename={to_set_to_idle_str}", "state=idle"])
def _get_node_manager(self, config: Dict, force: bool = False) -> NodeManager:
return self._node_mgr(config, self._driver(config), force=force)
def _setup_shell_locals(self, config: Dict) -> Dict:
# TODO
shell = {}
partitions = partitionlib.fetch_partitions(self._get_node_manager(config)) # type: ignore
shell["partitions"] = ShellDict(
hpcutil.partition_single(partitions, lambda p: p.name)
)
shell["node_mgr"] = node_mgr = self._get_node_manager(config)
nodes = {}
for node in node_mgr.get_nodes():
node.shellify()
nodes[node.name] = node
if node.hostname:
nodes[node.hostname] = node
shell["nodes"] = ShellDict(nodes)
def slurmhelp() -> None:
def _print(key: str, desc: str) -> None:
print("%-20s %s" % (key, desc))
_print("partitions", "partition information")
_print("node_mgr", "NodeManager")
_print(
"nodes",
"Current nodes according to the provider. May include nodes that have not joined yet.",
)
shell["slurmhelp"] = slurmhelp
return shell
def _driver(self, config: Dict) -> SchedulerDriver:
return SlurmDriver()
def _default_output_columns(
self, config: Dict, cmd: Optional[str] = None
) -> List[str]:
if hpcutil.LEGACY:
return ["nodearray", "name", "hostname", "private_ip", "status"]
return ["pool", "name", "hostname", "private_ip", "status"]
def _initconfig_parser(self, parser: ArgumentParser) -> None:
# TODO
parser.add_argument("--accounting-tag-name", dest="accounting__tag_name")
parser.add_argument("--accounting-tag-value", dest="accounting__tag_value")
parser.add_argument(
"--accounting-subscription-id", dest="accounting__subscription_id"
)
parser.add_argument("--cost-cache-root", dest="cost__cache_root")
parser.add_argument("--config-dir", required=True)
def _initconfig(self, config: Dict) -> None:
# TODO
...
@disablecommand
def analyze(self, config: Dict, job_id: str, long: bool = False) -> None:
...
@disablecommand
def validate_constraint(
self,
config: Dict,
constraint_expr: List[str],
writer: TextIO = sys.stdout,
quiet: bool = False,
) -> Union[List, Dict]:
return super().validate_constraint(
config, constraint_expr, writer=writer, quiet=quiet
)
@disablecommand
def join_nodes(
self, config: Dict, hostnames: List[str], node_names: List[str]
) -> None:
return super().join_nodes(config, hostnames, node_names)
@disablecommand
def jobs(self, config: Dict) -> None:
return super().jobs(config)
@disablecommand
def demand(
self,
config: Dict,
output_columns: Optional[List[str]],
output_format: OutputFormat,
long: bool = False,
) -> None:
return super().demand(config, output_columns, output_format, long=long)
@disablecommand
def autoscale(
self,
config: Dict,
output_columns: Optional[List[str]],
output_format: OutputFormat,
dry_run: bool = False,
long: bool = False,
) -> None:
return super().autoscale(
config, output_columns, output_format, dry_run=dry_run, long=long
)
def scale_parser(self, parser: ArgumentParser) -> None:
parser.add_argument("--no-restart", action="store_true", default=False, help="Don't restart slurm controller")
return
def scale(
self,
config: Dict,
no_restart=False,
backup_dir="/etc/slurm/.backups",
slurm_conf_dir="/etc/slurm",
config_only=False,
):
"""
Create or update slurm partition and/or gres information
"""
sched_dir = config.get("config_dir")
node_mgr = self._get_node_manager(config)
# make sure .backups exists
now = clock.time()
backup_dir = os.path.join(backup_dir, str(now))
logging.debug(
"Using backup directory %s for azure.conf and gres.conf", backup_dir
)
os.makedirs(backup_dir)
azure_conf = os.path.join(sched_dir, "azure.conf")
gres_conf = os.path.join(sched_dir, "gres.conf")
linked_gres_conf = os.path.join(slurm_conf_dir, "gres.conf")
if os.path.isfile(linked_gres_conf) and not os.path.islink(linked_gres_conf):
msg = f"{linked_gres_conf} should be a symlink to {gres_conf}! Changes will not take effect locally."
print("WARNING: " + msg, file=sys.stderr)
logging.warning(msg)
if not os.path.exists(linked_gres_conf):
msg = f"please run 'ln -fs {gres_conf} {linked_gres_conf} && chown slurm:slurm {linked_gres_conf}'"
print("WARNING: " + msg, file=sys.stderr)
logging.warning(msg)
if os.path.exists(azure_conf):
shutil.copyfile(azure_conf, os.path.join(backup_dir, "azure.conf"))
if os.path.exists(gres_conf):
shutil.copyfile(gres_conf, os.path.join(backup_dir, "gres.conf"))
partition_dict = partitionlib.fetch_partitions(node_mgr)
with open(azure_conf + ".tmp", "w") as fw:
_partitions(
partition_dict,
fw,
allow_empty=False,
autoscale=is_autoscale_enabled(),
)
# Issue #193 - failure to maintain ownership/permissions when
# rewriting azure.conf and gres.conf
_move_with_permissions(azure_conf + ".tmp", azure_conf)
_update_future_states(node_mgr)
with open(gres_conf + ".tmp", "w") as fw:
_generate_gres_conf(partition_dict, fw)
_move_with_permissions(gres_conf + ".tmp", gres_conf)
if not no_restart:
logging.info("Restarting slurmctld...")
check_output(["systemctl", "restart", "slurmctld"])
logging.info("")
logging.info("Re-scaling cluster complete.")
def keep_alive_parser(self, parser: ArgumentParser) -> None:
parser.set_defaults(read_only=False)
parser.add_argument(
"--node-list", type=hostlist, required=True
).completer = self._slurm_node_name_completer # type: ignore
parser.add_argument("--remove", "-r", action="store_true", default=False)
parser.add_argument(
"--set", "-s", action="store_true", default=False, dest="set_nodes"
)
def keep_alive(
self,
config: Dict,
node_list: List[str],
remove: bool = False,
set_nodes: bool = False,
) -> None:
"""
Add, remove or set which nodes should be prevented from being shut down.
"""
config_dir = config.get("config_dir")
if remove and set_nodes:
raise AzureSlurmError("Please define only --set or --remove, not both.")
lines = slutil.check_output(["scontrol", "show", "config"]).splitlines()
filtered = [
line for line in lines if line.lower().startswith("suspendexcnodes")
]
current_susp_nodes = []
if filtered:
current_susp_nodes_expr = filtered[0].split("=")[-1].strip()
if current_susp_nodes_expr != "(null)":
current_susp_nodes = slutil.from_hostlist(current_susp_nodes_expr)
if set_nodes:
hostnames = list(set(node_list))
elif remove:
hostnames = list(set(current_susp_nodes) - set(node_list))
else:
hostnames = current_susp_nodes + node_list
all_susp_hostnames = (
slutil.check_output(
[
"scontrol",
"show",
"hostnames",
",".join(hostnames),
]
)
.strip()
.split()
)
all_susp_hostnames = sorted(
list(set(all_susp_hostnames)), key=slutil.get_sort_key_func(False)
)
all_susp_hostlist = slutil.check_output(
["scontrol", "show", "hostlist", ",".join(all_susp_hostnames)]
).strip()
with open(f"{config_dir}/keep_alive.conf.tmp", "w") as fw:
if all_susp_hostlist:
fw.write(f"SuspendExcNodes = {all_susp_hostlist}")
else:
fw.write("# SuspendExcNodes = ")
shutil.move(f"{config_dir}/keep_alive.conf.tmp", f"{config_dir}/keep_alive.conf")
slutil.check_output(["scontrol", "reconfig"])
def accounting_info_parser(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument("--node-name", required=True)
def accounting_info(self, config: Dict, node_name: str) -> None:
node_mgr = self._get_node_manager(config)
nodes = node_mgr.get_nodes()
nodes_filtered = [n for n in nodes if n.name == node_name]
if not nodes_filtered:
json.dump([], sys.stdout)
return
assert len(nodes_filtered) == 1
node = nodes_filtered[0]
toks = check_output(["scontrol", "show", "node", node_name]).decode().split()
cpus = -1
for tok in toks:
tok = tok.lower()
if tok.startswith("cputot"):
cpus = int(tok.split("=")[1])
json.dump(
[
{
"name": node.name,
"location": node.location,
"vm_size": node.vm_size,
"spot": node.spot,
"nodearray": node.nodearray,
"cpus": cpus,
"pcpu_count": node.pcpu_count,
"vcpu_count": node.vcpu_count,
"gpu_count": node.gpu_count,
"memgb": node.memory.value,
}
],
sys.stdout
)
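# The emitted JSON is a one-element list; an illustrative record (values assumed):
#   [{"name": "hpc-1", "location": "eastus", "vm_size": "Standard_F2s_v2",
#     "spot": false, "nodearray": "hpc", "cpus": 2, "pcpu_count": 1,
#     "vcpu_count": 2, "gpu_count": 0, "memgb": 4.0}]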
def _move_with_permissions(src: str, dst: str) -> None:
if os.path.exists(dst):
st = os.stat(dst)
os.chmod(src, st.st_mode)
os.chown(src, st.st_uid, st.st_gid)
logging.debug("Moving %s to %s", src, dst)
shutil.move(src, dst)
if not slutil.is_autoscale_enabled():
def sync_future_states_parser(self, parser: ArgumentParser) -> None:
parser.add_argument(
"--node-list", type=hostlist_null_star, help="Optional subset of nodes to sync. Default is all."
)
def sync_future_states(self, config: Dict, node_list: Optional[List[str]] = None) -> None:
_update_future_states(self._get_node_manager(config), node_list)
def _dynamic_partition(partition: partitionlib.Partition, writer: TextIO) -> None:
assert partition.dynamic_feature
writer.write(
"# Creating dynamic nodeset and partition using slurm.dynamic_feature=%s\n"
% partition.dynamic_feature
)
if not partition.features:
logging.error(
f"slurm.dynamic_feature was set for {partition.name}"
+ "but it did not include a feature declaration. Slurm requires this! Skipping for now.ß"
)
return
writer.write(f"Nodeset={partition.name}ns Feature={partition.features[0]}\n")
writer.write(f"PartitionName={partition.name} Nodes={partition.name}ns")
if partition.is_default:
writer.write(" Default=YES")
writer.write("\n")
def _partitions(
partitions: List[partitionlib.Partition],
writer: TextIO,
allow_empty: bool = False,
autoscale: bool = True,
) -> None:
written_dynamic_partitions = set()
writer.write(
f"# Note: To account for OS/VM overhead, by default we reduce the reported memory from CycleCloud by 5%.\n"
)
writer.write(
"# We do this because Slurm will reject a node that reports less than what is defined in this config.\n"
)
writer.write(
"# There are two ways to change this:\n" +
"# 1) edit slurm.dampen_memory=X in the nodearray's Configuration where X is percentage (5 = 5%).\n"
"# 2) Edit the slurm_memory value defined in /opt/azurehpc/slurm/autosacle.json.\n" +
"# Note that slurm.dampen_memory will take precedence.\n"
)
for partition in partitions:
if partition.dynamic_feature:
if partition.name in written_dynamic_partitions:
logging.warning("Duplicate partition found mapped to the same name." +
" Using first Feature= declaration and ignoring the rest!")
continue
_dynamic_partition(partition, writer)
written_dynamic_partitions.add(partition.name)
continue
node_list = partition.node_list or []
max_count = min(partition.max_vm_count, partition.max_scaleset_size)
default_yn = "YES" if partition.is_default else "NO"
memory = max(1024, partition.memory)
if partition.use_pcpu:
cpus = partition.pcpu_count
threads = max(1, partition.vcpu_count // partition.pcpu_count)
else:
cpus = partition.vcpu_count
threads = 1
def_mem_per_cpu = memory // cpus
comment_out = ""
if max_count <= 0:
writer.write(f"# The following partition has no capacity! {partition.name} - {partition.nodearray} - {partition.machine_type} \n")
comment_out = "# "
writer.write(
f"{comment_out}PartitionName={partition.name} Nodes={partition.node_list} Default={default_yn} DefMemPerCPU={def_mem_per_cpu} MaxTime=INFINITE State=UP\n"
)
state = "CLOUD" if autoscale else "FUTURE"
writer.write(
f"{comment_out}Nodename={node_list} Feature=cloud STATE={state} CPUs={cpus} ThreadsPerCore={threads} RealMemory={memory}"
)
if partition.gpu_count:
writer.write(" Gres=gpu:{}".format(partition.gpu_count))
writer.write("\n")
def _generate_topology(node_mgr: NodeManager, writer: TextIO) -> None:
partitions = partitionlib.fetch_partitions(node_mgr)
nodes_by_pg = {}
for partition in partitions:
for pg, node_list in partition.node_list_by_pg.items():
if pg not in nodes_by_pg:
nodes_by_pg[pg] = []
nodes_by_pg[pg].extend(node_list)
if not nodes_by_pg:
raise AzureSlurmError(
"No nodes found to create topology! Do you need to run create_nodes first?"
)
for pg in sorted(nodes_by_pg.keys(), key=lambda x: x if x is not None else ""):
nodes = nodes_by_pg[pg]
if not nodes:
continue
nodes = sorted(nodes, key=slutil.get_sort_key_func(bool(pg)))
slurm_node_expr = ",".join(nodes) # slutil.to_hostlist(",".join(nodes))
writer.write("SwitchName={} Nodes={}\n".format(pg or "htc", slurm_node_expr))
def _generate_nvidia_devices(gpu_count: int) -> str:
if gpu_count == 1:
return "/dev/nvidia0"
return "/dev/nvidia[0-{}]".format(gpu_count - 1)
def _generate_amd_devices(gpu_count: int) -> str:
if gpu_count == 1:
return "/dev/dri/renderD128"
# NOTE: AMD GPU devices should be comma-separated and may not have spaces in the list
amd_gpu_list = ",".join([f"{128+8*index}" for index in range(0, gpu_count)])
return "/dev/dri/renderD[{}]".format(amd_gpu_list)
def _generate_gpu_devices(partition: partitionlib.Partition) -> str:
# Override from node configuration
if partition.gpu_device_config:
return partition.gpu_device_config
try:
# TODO: azure sku should eventually indicate the GPU type
# (currently, attribute is not available, so gpu_family is not available to the partition)
# We'll fall back to parsing machine type until it is available
has_amd_gpu = "amd" in partition.gpu_family.lower()
except AttributeError:
has_amd_gpu = "mi300" in partition.machine_type.lower()
if has_amd_gpu:
gpu_devices = _generate_amd_devices(partition.gpu_count)
else:
gpu_devices = _generate_nvidia_devices(partition.gpu_count)
return gpu_devices
def _generate_gres_conf(partitions: List[partitionlib.Partition], writer: TextIO):
for partition in partitions:
if partition.node_list is None:
raise RuntimeError(
"No nodes found for nodearray %s. Please run 'azslurm create_nodes' first!"
% partition.nodearray
)
num_placement_groups = int(
ceil(float(partition.max_vm_count) / partition.max_scaleset_size)
)
all_nodes = sorted(
slutil.from_hostlist(partition.node_list),
key=slutil.get_sort_key_func(partition.is_hpc),
)
for pg_index in range(num_placement_groups):
start = pg_index * partition.max_scaleset_size
end = min(
partition.max_vm_count, (pg_index + 1) * partition.max_scaleset_size
)
subset_of_nodes = all_nodes[start:end]
if not subset_of_nodes:
continue
node_list = slutil.to_hostlist(",".join((subset_of_nodes)))
# cut out 1gb so that the node reports at least this amount of memory. - recommended by schedmd
if partition.gpu_count:
gpu_devices = _generate_gpu_devices(partition)
writer.write(
"Nodename={} Name=gpu Count={} File={}".format(
node_list, partition.gpu_count, gpu_devices
)
)
writer.write("\n")
def _update_future_states(node_mgr: NodeManager, node_list: Optional[List[str]] = None) -> None:
autoscale_enabled = is_autoscale_enabled()
if autoscale_enabled:
return
nodes = node_mgr.get_nodes()
for node in nodes:
if node_list and node.name not in node_list:
continue
if node.target_state != "Started":
name = node.name
try:
cmd = [
"scontrol",
"update",
f"NodeName={name}",
f"NodeAddr={name}",
f"NodeHostName={name}",
"state=FUTURE",
]
check_output(cmd)
except SubprocessError:
logging.warning(f"Could not set {node.get('Name')} state=FUTURE")
def _undo_internal_dns(node_name: str) -> None:
try:
cmd = [
"scontrol",
"update",
f"NodeName={node_name}",
f"NodeAddr={node_name}",
f"NodeHostName={node_name}",
]
check_output(cmd)
except SubprocessError:
logging.warning(f"Could not set {node_name}'s nodeaddr/nodehostname!")
def _retry_rest(func: Callable, attempts: int = 5) -> Any:
attempts = max(1, attempts)
last_exception = None
for attempt in range(1, attempts + 1):
try:
return func()
except Exception as e:
last_exception = e
logging.debug(traceback.format_exc())
clock.sleep(attempt * attempt)
raise AzureSlurmError(str(last_exception))
def hostlist(hostlist_expr: str) -> List[str]:
if hostlist_expr == "*":
all_node_names = slutil.check_output(
["sinfo", "-O", "nodelist", "-h", "-N"]
).split()
return all_node_names
return slutil.from_hostlist(hostlist_expr)
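# hostlist("hpc-[1-3],htc-1") is expected to yield ["hpc-1", "hpc-2", "hpc-3", "htc-1"]
# via slutil.from_hostlist, while hostlist("*") expands to every node reported by sinfo.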
def hostlist_null_star(hostlist_expr: str) -> Optional[List[str]]:
if hostlist_expr == "*":
return None
return slutil.from_hostlist(hostlist_expr)
def _safe_shutdown(node_list: List[str], node_mgr: NodeManager) -> ShutdownResult:
assert node_list
logging.info(f"Shutting down nodes {node_list}")
nodes = _as_nodes(node_list, node_mgr)
ret = _retry_rest(lambda: node_mgr.shutdown_nodes(nodes))
if ret:
logging.info(str(ret))
else:
logging.error(str(ret))
return ret
def _as_nodes(node_list: List[str], node_mgr: NodeManager) -> List[Node]:
nodes: List[Node] = []
by_name = hpcutil.partition_single(node_mgr.get_nodes(), lambda node: node.name)
for node_name in node_list:
# TODO error handling on missing node names
if node_name not in by_name:
raise AzureSlurmError(f"Unknown node - {node_name}")
nodes.append(by_name[node_name])
return nodes
def main(argv: Optional[Iterable[str]] = None) -> None:
try:
clilibmain(argv or sys.argv[1:], "slurm", SlurmCLI())
except AzureSlurmError as e:
logging.error(e.message)
sys.exit(1)
except Exception:
log_files = [x.baseFilename for x in logging.getLogger().handlers if hasattr(x, "baseFilename")]
logging.exception(f"Unexpected error. See {','.join(log_files)} for more information.")
raise
if __name__ == "__main__":
main()