azure-slurm/slurmcc/util.py (216 lines of code) (raw):
import logging
from abc import ABC, abstractmethod
import os
import random
import subprocess as subprocesslib
import tempfile
import sys
import time
import traceback
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from . import AzureSlurmError, custom_chaos_mode
class SrunExitCodeException(Exception):
    """Raised when an srun invocation exits with a non-zero status.

    Attributes:
        returncode: exit status reported for the srun command.
        stdout: captured stdout of the srun command.
        stderr: captured stderr of the srun command.
        stderr_content: contents of the file given to srun ``--error``; the
            srun implementation in this module passes None when that file is
            empty (or only newlines).
    """

    def __init__(self, returncode: int, stdout: str, stderr: str, stderr_content: Optional[str]):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr
        self.stderr_content = stderr_content
        super().__init__(f"srun command failed with exit code {returncode}")
class SrunOutput:
    """Result of a successful srun call.

    Attributes:
        returncode: exit status of the srun command.
        stdout: captured stdout text.
        stderr: captured stderr text, or None (the srun implementation in this
            module always passes None on success).
    """

    def __init__(self, returncode: int, stdout: str, stderr: Optional[str]):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(returncode={self.returncode!r}, "
            f"stdout={self.stdout!r}, stderr={self.stderr!r})"
        )
class NativeSlurmCLI(ABC):
    """Abstract interface over the Slurm command-line tools used by this module."""

    @abstractmethod
    def scontrol(self, args: List[str], retry: bool = True) -> str:
        """Run ``scontrol`` with ``args`` (not including "scontrol" itself) and return its output as text.

        :param args: scontrol arguments.
        :param retry: whether the implementation should retry on failure.
        """
        ...

    @abstractmethod
    def srun(self, hostlist: List[str], user_command: str, timeout: int, shell: bool, partition: str) -> SrunOutput:
        """Run ``user_command`` on the nodes in ``hostlist`` via srun.

        The first parameter is named ``hostlist`` to match the concrete
        implementation and the module-level ``srun`` helper, which passes it
        positionally.

        :param hostlist: node names to run on.
        :param user_command: command to execute.
        :param timeout: time limit in minutes.
        :param shell: whether to run the command through a shell.
        :param partition: partition to run in, or a falsy value for the default.
        """
        ...
class NativeSlurmCLIImpl(NativeSlurmCLI):
    """NativeSlurmCLI implementation that invokes the real scontrol/srun binaries."""

    def scontrol(self, args: List[str], retry: bool = True) -> str:
        """Run ``scontrol <args>`` and return its stdout with surrounding whitespace stripped.

        :param args: scontrol arguments, without the leading "scontrol".
        :param retry: if True, the call is wrapped in retry_subprocess.
        """
        assert args[0] != "scontrol"
        full_args = ["scontrol"] + args
        if retry:
            return retry_subprocess(lambda: check_output(full_args)).strip()
        return check_output(full_args).strip()

    def srun(self, hostlist: List[str], user_command: str, timeout: int, shell: bool, partition: str) -> SrunOutput:
        """Run ``user_command`` on ``hostlist`` with srun.

        :param hostlist: node names, comma-joined into ``srun -w``.
        :param user_command: command to run; passed as one ``bash -c`` argument when ``shell`` is True.
        :param timeout: srun ``--time`` limit in minutes.
        :param shell: wrap ``user_command`` in ``bash -c``.
        :param partition: partition passed as ``-p`` when truthy.
        :returns: SrunOutput with the return code and stdout; stderr is None.
        :raises SrunExitCodeException: srun exited non-zero; includes the ``--error`` file contents.
        :raises subprocess.TimeoutExpired: srun did not return within ``timeout`` minutes + 180 seconds.
        """
        # Local import keeps this block self-contained.
        import shlex

        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            temp_file_path = temp_file.name
            # shlex.quote keeps user_command a single `bash -c` argument even when it
            # contains single quotes; for quote-free commands it matches the old '...' form.
            command = f"bash -c {shlex.quote(user_command)}" if shell else user_command
            partition_flag = f"-p {partition} " if partition else ""
            # adding deadline timeout 1 minute more than the srun timeout to avoid deadline
            # timeout before srun can finish running
            srun_command = (
                f"srun {partition_flag}-w {','.join(hostlist)} --error {temp_file_path} "
                f"--deadline=now+{timeout + 1}minute --time={timeout} {command}"
            )
            logging.debug(srun_command)
            # subprocess timeout is in seconds, so convert the timeout to seconds and add
            # 3 minutes so it doesn't fire before srun can kill the job from its own timeout
            subp_timeout = timeout * 60 + 180
            try:
                result = subprocesslib.run(
                    srun_command,
                    check=True,
                    timeout=subp_timeout,
                    shell=True,
                    stdout=subprocesslib.PIPE,
                    stderr=subprocesslib.PIPE,
                    universal_newlines=True,
                )
            except subprocesslib.CalledProcessError as e:
                logging.error(f"Command: {srun_command} failed with return code {e.returncode}")
                with open(temp_file_path, "r") as f:
                    stderr_content: Optional[str] = f.read()
                # An --error file holding only newlines carries no information.
                if not stderr_content.strip("\n"):
                    stderr_content = None
                raise SrunExitCodeException(
                    returncode=e.returncode,
                    stdout=e.stdout,
                    stderr=e.stderr,
                    stderr_content=stderr_content,
                ) from e
            except subprocesslib.TimeoutExpired:
                logging.error("Srun command timed out!")
                raise
            return SrunOutput(returncode=result.returncode, stdout=result.stdout, stderr=None)
# Module-wide Slurm CLI backend used by the helpers in this file (scontrol,
# srun, is_slurmctld_up). Defaults to the subprocess-based implementation.
SLURM_CLI: NativeSlurmCLI = NativeSlurmCLIImpl()
def set_slurm_cli(cli: NativeSlurmCLI) -> None:
    """Replace the module-wide SLURM_CLI backend with ``cli``."""
    global SLURM_CLI
    SLURM_CLI = cli
def scontrol(args: List[str], retry: bool = True) -> str:
    """Run scontrol through the currently configured SLURM_CLI backend.

    :param args: scontrol arguments; must not start with "scontrol".
    :param retry: forwarded to the backend's scontrol.
    :returns: whatever the backend's scontrol returns.
    """
    assert args[0] != "scontrol"
    backend = SLURM_CLI
    return backend.scontrol(args, retry)
def srun(hostlist: List[str], user_command: str, timeout: int = 2, shell: bool = False, partition: Optional[str] = None) -> SrunOutput:
    """Run ``user_command`` on ``hostlist`` via the configured SLURM_CLI backend.

    :param hostlist: node names to run on (must not be None).
    :param user_command: command to run (must not be None).
    :param timeout: srun ``--time`` limit in minutes; must be at least 1.
    :param shell: whether to wrap the command in a shell.
    :param partition: optional partition name.
    :returns: the backend's SrunOutput.
    """
    # srun --time option is in minutes and needs at least 1 minute to run
    assert timeout >= 1
    assert hostlist is not None
    assert user_command is not None
    return SLURM_CLI.srun(hostlist, user_command, timeout=timeout, shell=shell, partition=partition)
# When True, is_slurmctld_up() reports True without contacting slurmctld.
TEST_MODE = False
def is_slurmctld_up() -> bool:
    """Return True if ``scontrol ping`` succeeds (no retries), False on any error.

    Always True when TEST_MODE is set.
    """
    if not TEST_MODE:
        try:
            SLURM_CLI.scontrol(["ping"], retry=False)
        except Exception:
            return False
    return True
def show_nodes(node_list: Optional[List[str]] = None) -> List[Dict[str, Any]]:
    """Return ``scontrol show nodes`` output parsed into one dict per node.

    :param node_list: optional node names to restrict the query to.
    """
    cmd_args = ["show", "nodes"]
    # With autoscale disabled, also pass scontrol's --future flag.
    if not is_autoscale_enabled():
        cmd_args.append("--future")
    if node_list:
        cmd_args.append(",".join(node_list))
    return parse_show_nodes(scontrol(cmd_args))
def parse_show_nodes(stdout: str) -> List[Dict[str, Any]]:
    """Parse ``scontrol show nodes`` text into a list of ``{key: value}`` dicts.

    Every whitespace-separated ``Key=Value`` token is recorded. A ``NodeName``
    token starts a new node. Tokens without ``=`` are ignored, so a value
    that contains spaces keeps only its first word.
    """
    nodes: List[Dict[str, Any]] = []
    node: Optional[Dict[str, Any]] = None
    for token in stdout.split():
        key, sep, value = token.partition("=")
        if not sep:
            continue
        if key == "NodeName":
            if node:
                nodes.append(node)
            node = {}
        assert node is not None
        node[key] = value
    if node:
        nodes.append(node)
    return nodes
def _split_hostlist_terms(expr: str) -> List[str]:
    """Split ``expr`` on commas outside ``[...]`` ranges, e.g. "a-[1,3],b" -> ["a-[1,3]", "b"]."""
    terms: List[str] = []
    depth = 0
    start = 0
    for i, ch in enumerate(expr):
        if ch == "[":
            depth += 1
        elif ch == "]" and depth > 0:
            depth -= 1
        elif ch == "," and depth == 0:
            terms.append(expr[start:i])
            start = i + 1
    terms.append(expr[start:])
    return terms


def to_hostlist(nodes: Union[str, List[str]], scontrol_func: Callable = scontrol) -> str:
    """
    convert name-1,name-2,name-3,name-4,name-5 into name-[1-5]

    :param nodes: a list of node names, or a comma-separated string of them.
    :param scontrol_func: callable used to run scontrol; defaults to the module-level scontrol.
    :returns: the output of ``scontrol show hostlist`` for the sorted node names.
    """
    assert nodes
    for n in nodes:
        assert n
    if isinstance(nodes, list):
        nodes_str = ",".join(nodes)
    else:
        nodes_str = nodes
    # prevent poor sorting of nodes and getting node lists like htc-1,htc-10-19, htc-2, htc-20-29 etc.
    # Split only on top-level commas so a bracketed range such as name-[1,3] is not cut apart
    # and reordered into an invalid expression.
    sorted_nodes = sorted(_split_hostlist_terms(nodes_str), key=get_sort_key_func(is_hpc=False))
    nodes_str = ",".join(sorted_nodes)
    return scontrol_func(["show", "hostlist", nodes_str])
def from_hostlist(hostlist_expr: str) -> List[str]:
    """
    convert name-[1-5] into name-1 name-2 name-3 name-4 name-5

    Expands a Slurm hostlist expression using ``scontrol show hostnames``.

    :param hostlist_expr: hostlist expression (must be a str).
    :returns: one entry per whitespace-separated name in the scontrol output.
    """
    assert isinstance(hostlist_expr, str)
    stdout = scontrol(["show", "hostnames", hostlist_expr])
    # str.split() with no separator already drops all surrounding whitespace.
    return stdout.split()
def run(args: list, stdout=subprocesslib.PIPE, stderr=subprocesslib.PIPE, timeout=120, shell=False, check=True, universal_newlines=True, **kwargs):
    """
    Thin wrapper around subprocess.run with convenient defaults.

    stdout and stderr are captured via pipes, output is returned as text
    (universal_newlines=True), the timeout is 120 seconds and a non-zero exit
    raises (check=True). Failures are logged before being re-raised unchanged.
    """
    try:
        return subprocesslib.run(
            args=args,
            stdout=stdout,
            stderr=stderr,
            timeout=timeout,
            shell=shell,
            check=check,
            universal_newlines=universal_newlines,
            **kwargs,
        )
    except subprocesslib.CalledProcessError as failed:
        logging.error(f"cmd: {failed.cmd}, rc: {failed.returncode}")
        logging.error(failed.stderr)
        raise
    except subprocesslib.TimeoutExpired:
        logging.error("Timeout Expired")
        raise
    except Exception as err:
        logging.error(err)
        raise
def retry_rest(func: Callable, attempts: int = 5) -> Any:
    """Call ``func`` until it succeeds, at most ``attempts`` times (minimum 1).

    Between failed attempts it sleeps attempt**2 seconds (1, 4, 9, ...). There is
    no sleep after the final attempt, since no retry follows.

    :returns: the first successful return value of ``func``.
    :raises AzureSlurmError: with the last failure's message (chained as the cause)
        once every attempt has failed.
    """
    attempts = max(1, attempts)
    last_exception: Optional[Exception] = None
    for attempt in range(1, attempts + 1):
        try:
            return func()
        except Exception as e:
            last_exception = e
            logging.debug(traceback.format_exc())
            if attempt < attempts:
                time.sleep(attempt * attempt)
    raise AzureSlurmError(str(last_exception)) from last_exception
def retry_subprocess(func: Callable, attempts: int = 5) -> Any:
    """Call ``func`` (typically a subprocess invocation) until it succeeds, at most ``attempts`` times.

    Each failure is logged. Between failed attempts it sleeps attempt**2 seconds;
    it does not sleep, or log "retrying", after the final attempt.

    :returns: the first successful return value of ``func``.
    :raises AzureSlurmError: with the last failure's message (chained as the cause)
        once every attempt has failed.
    """
    attempts = max(1, attempts)
    last_exception: Optional[Exception] = None
    for attempt in range(1, attempts + 1):
        try:
            return func()
        except Exception as e:
            last_exception = e
            logging.debug(traceback.format_exc())
            if attempt < attempts:
                logging.warning("Command failed, retrying: %s", str(e))
                time.sleep(attempt * attempt)
            else:
                logging.warning("Command failed after %d attempts: %s", attempts, str(e))
    raise AzureSlurmError(str(last_exception)) from last_exception
def check_output(args: List[str], **kwargs: Any) -> str:
    """Run ``args`` via the module's subprocess wrapper and return its output as str.

    Byte output is decoded with the default codec; str output is returned as-is.
    """
    raw = _SUBPROCESS_MODULE.check_output(args=args, **kwargs)
    return raw if isinstance(raw, str) else raw.decode()
def _raise_proc_exception() -> Any:
    """Raise a randomly chosen failure for chaos-mode injection.

    Raises either ``OSError("Random failure")`` or
    ``CalledProcessError(1, "Random failure")``; never returns.
    """

    def _raise_called_process_error(message: str) -> Any:
        raise subprocesslib.CalledProcessError(1, message)

    factories = [OSError, _raise_called_process_error]
    picked: Callable = random.choice(factories)  # type: ignore
    raise picked("Random failure")
class SubprocessModuleWithChaosMode:
    """Pass-through wrapper over subprocess.check_call / check_output.

    Each method is decorated with ``custom_chaos_mode(_raise_proc_exception)``.
    Presumably the decorator sometimes calls _raise_proc_exception instead of
    the real method to inject failures. When and how often it does so is
    defined by custom_chaos_mode, which is not visible here — TODO confirm.
    """
    @custom_chaos_mode(_raise_proc_exception)
    def check_call(self, *args, **kwargs): # type: ignore
        """Forward all arguments to subprocess.check_call."""
        return subprocesslib.check_call(*args, **kwargs)
    @custom_chaos_mode(_raise_proc_exception)
    def check_output(self, *args, **kwargs): # type: ignore
        """Forward all arguments to subprocess.check_output."""
        return subprocesslib.check_output(*args, **kwargs)
# Module-wide instance; the module-level check_output() helper calls through it.
_SUBPROCESS_MODULE = SubprocessModuleWithChaosMode()
def get_sort_key_func(is_hpc: bool) -> Callable[[str], Union[str, int]]:
    """Return the node-name sort key function.

    For HPC nodes the key accounts for the placement group ("name-pg#-#");
    otherwise only the trailing node index is used.
    """
    if is_hpc:
        return _node_index_and_pg_as_sort_key
    return _node_index_as_sort_key
def _node_index_as_sort_key(nodename: str) -> Union[str, int]:
    """
    Sort key for node names that do not follow the name-pg#-# format.

    Uses the integer after the last "-" when it parses. Otherwise it falls back
    to the name's UTF-8 bytes read as a little-endian integer.
    """
    try:
        key = int(nodename.split("-")[-1])
    except Exception:
        key = int.from_bytes(nodename.encode(), byteorder="little")
    return key
def _node_index_and_pg_as_sort_key(nodename: str) -> Union[str, int]:
    """
    Sort key for node names that have the name-pg#-# format.

    The key is ``pg * 100000 + node_index``, so nodes order by placement group
    and then by index. Names that do not match fall back to
    _node_index_as_sort_key, so every key is an int. Previously the fallback
    returned the name itself (a str), and sorting a list that mixed both kinds
    of name raised TypeError.
    """
    try:
        parts = nodename.split("-")
        node_index = int(parts[-1])
        pg = int(parts[-2].replace("pg", "")) * 100000
    except (IndexError, ValueError):
        return _node_index_as_sort_key(nodename)
    return pg + node_index
# Cached result of is_autoscale_enabled(); None until it has been determined.
_IS_AUTOSCALE_ENABLED: Optional[bool] = None


def _autoscale_from_slurm_conf(lines: List[str]) -> Optional[bool]:
    """Interpret the SuspendTime entries in slurm.conf ``lines``.

    The last parseable entry wins. It yields False for NONE/INFINITE or a
    negative integer, and True for a non-negative integer. Returns None when no
    parseable SuspendTime entry exists. Entries with no value or a non-integer
    value are skipped instead of raising.
    """
    enabled: Optional[bool] = None
    for raw_line in lines:
        key, sep, rest = raw_line.strip().partition("=")
        # Compare the whole key case-insensitively so e.g. "SuspendTimeout" does not match.
        if not sep or key.strip().lower() != "suspendtime":
            continue
        tokens = rest.split()
        if not tokens:
            continue
        value = tokens[0]
        if value.upper() in ("NONE", "INFINITE"):
            enabled = False
            continue
        try:
            enabled = int(value) >= 0
        except ValueError:
            continue
    return enabled


def is_autoscale_enabled() -> bool:
    """Return whether autoscale is enabled, judged by SuspendTime in /etc/slurm/slurm.conf.

    The first determined answer is cached in _IS_AUTOSCALE_ENABLED. If the file
    cannot be read, autoscale is assumed enabled and that answer is cached. If
    no usable SuspendTime is found, a warning is logged and True is returned
    without caching.
    """
    global _IS_AUTOSCALE_ENABLED
    if _IS_AUTOSCALE_ENABLED is not None:
        return _IS_AUTOSCALE_ENABLED
    try:
        with open("/etc/slurm/slurm.conf") as fr:
            lines = fr.readlines()
    except Exception:
        _IS_AUTOSCALE_ENABLED = True
        return _IS_AUTOSCALE_ENABLED
    parsed = _autoscale_from_slurm_conf(lines)
    if parsed is not None:
        _IS_AUTOSCALE_ENABLED = parsed
        return parsed
    logging.warning("Could not determine if autoscale is enabled. Assuming yes")
    return True