azure-slurm-install/installlib.py (502 lines of code) (raw):

import base64 import grp from hashlib import md5 import json import logging import os import pwd import re import shutil from ssl import SSLContext import ssl import subprocess import tempfile from time import sleep as _sleep from datetime import datetime, timezone from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import urllib import urllib.parse import urllib.request class ConvergeError(RuntimeError): pass class ConvergeRetry(RuntimeError): pass class Clock: def time(self) -> float: return datetime.now(timezone.utc).timestamp() def sleep(self, n: float) -> None: return _sleep(n) class MockClock: def __init__(self, now: float = 1000.0) -> None: self.now = now def time(self) -> float: return self.now def sleep(self, n: float) -> None: self.now += n _CLOCK = Clock() def use_mock_clock() -> MockClock: global _CLOCK _CLOCK = MockClock() return _CLOCK def time() -> float: return _CLOCK.time() def sleep(n: float) -> None: return _CLOCK.sleep(n) def blob_download(filename: str, project: str, node: Dict) -> str: downloads_dir = node["blobs"].get("downloads", "/opt/azurehpc/blobs") if not os.path.exists(downloads_dir): os.makedirs(downloads_dir) dest = os.path.join(downloads_dir, filename) if node["blobs"]["type"] == "simple": return dest # src = os.path.join(node["blobs"]["url"], filename) # shutil.copyfile(src=src, dst=dest) elif node["blobs"]["type"] == "jetpack": subprocess.check_call( ["jetpack", "download", filename, f"--project={project}", dest] ) return dest else: raise ConvergeError("Only blobs.type==simple or jetpack is valid at this time") def link( src: str, dst: str, owner: Optional[str] = None, group: Optional[str] = None ) -> None: if not os.path.islink(dst): logging.info("Linking {dst} to {src}".format(**locals())) os.symlink(src, dst) # chown(to, owner, group) else: logging.info("Link {dst} already exists".format(**locals())) def chown( dest: str, owner: Optional[str] = None, group: Optional[str] = None, recursive: bool = False, ) -> None: pwd_record = uid = gid = None if owner: pwd_record = pwd.getpwnam(owner) uid = pwd_record.pw_uid gid = pwd_record.pw_gid if group: gid = grp.getgrnam(group).gr_gid elif pwd_record: group = pwd_record.pw_name if uid and gid: recursive_arg = "-R" if recursive else "" logging.info(f"chown {recursive_arg} {dest} with {owner}({uid}):{group}({gid})") os.chown(dest, uid=uid, gid=gid) if recursive and os.path.isdir(dest): # TODO should probably use OS version for fil in os.listdir(dest): chown(os.path.join(dest, fil), owner, group, recursive=recursive) def chmod(dest: str, mode: Optional[Union[str, int]], recursive: bool = False) -> None: if mode is not None: # if isinstance(mode, str): # mode = int(mode) logging.info(f"chmod {mode} {dest}") if recursive: cmd = ["chmod", "-R", str(mode), dest] else: cmd = ["chmod", str(mode), dest] logging.info(" ".join(cmd)) subprocess.check_call(cmd) # os.chmod(dest, mode) # if recursive and os.path.isdir(dest): # for fil in os.listdir(dest): # chmod(os.path.join(dest, fil), mode, recursive=recursive) def copy_file( source: str, dest: str, owner: str, group: str, mode: Union[str, int] ) -> None: shutil.copyfile(src=source, dst=dest) chown(dest, owner=owner, group=group) chmod(dest, mode) def file( dest: str, content: Union[bytes, str] = "", owner: Optional[str] = None, group: Optional[str] = None, mode: Optional[Union[str, int]] = None, ) -> None: io_mode = "a" if isinstance(content, str) else "ab" tmp_dest = dest + ".tmp" with open(tmp_dest, io_mode) as fw: fw.write(content) chown(tmp_dest, owner, group) chmod(tmp_dest, mode) move(tmp_dest, dest) def append_file(dest: str, content: str, comment_prefix: str) -> None: """ provides monotonic appending of content to a file. This relies on the fact that we can append "md5 = <hash>" to the end of the comment to prevent duplicate appends. """ hash = md5(content.encode()).hexdigest() with open(dest, "r") as fr: already_written = hash in fr.read() if not already_written: logging.info(f"Appending to {dest}: content='{content}'") with open(dest, "a") as fa: fa.write(f"{comment_prefix} md5 = {hash}\n") fa.write(content) def move(src: str, dest: str) -> None: logging.info(f"mv {src} {dest}") shutil.move(src, dest) # TODO!!! def cookbook_file( dest: str, source: str, owner: str, group: str, mode: Union[str, int] ) -> None: full_source = os.path.abspath(os.path.join("cookbook_files", source)) if isinstance(mode, str): mode = int(mode) copy_file(full_source, dest, owner, group, mode) def template( dest: str, owner: str, group: str, source: str, mode: Union[str, int] = 600, variables: Optional[Dict] = None, ) -> None: if os.path.exists(dest): shutil.move(dest, f"{dest}.backup") variables = variables or {} if isinstance(mode, str): mode = int(mode) if not os.path.exists(source): raise ConvergeError(f"Template {source} does not exist!") with open(source) as fr: contents = fr.read() with open(dest, "w") as fw: fw.write(contents.format(**variables)) chmod(dest, mode) if owner and group: chown(dest, owner, group) def group(group_name: str, gid: Optional[int]) -> None: groups = dict([(g.gr_name, g.gr_gid) for g in grp.getgrall()]) if group_name in groups: # group already exists # TODO logging return if gid is not None: cmd = ["groupadd", "-g", str(gid), group_name] else: cmd = ["groupadd", group_name] subprocess.check_call(cmd) def group_members(group_name: str, members: List[str], append: bool = True) -> None: assert append for member in members: subprocess.check_call(["usermod", "-a", "-G", group_name, member]) def user( user_name: str, comment: str, shell: Optional[str] = None, uid: Optional[int] = None, gid: Optional[int] = None, ) -> None: users = dict([(p.pw_name, p.pw_uid) for p in pwd.getpwall()]) if user_name in users: return logging.info(comment) cmd = ["useradd"] if uid: cmd += ["-u", str(uid)] if gid: cmd += ["-g", str(gid)] if shell: cmd += ["-s", shell] subprocess.check_call(cmd + [user_name]) class guard: def __init__(self, path: str, content: str = "") -> None: self.path = path self.content = content def __enter__(self) -> "guard": return self def __exit__( self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType], ) -> bool: if not exctype: with open(self.path, "w") as fw: fw.write(self.content) return False return True def directory( path: str, owner: Optional[str] = None, group: Optional[str] = None, mode: Optional[int] = None, recursive: bool = False, ) -> None: if not os.path.exists(path): os.makedirs(path) chown(path, owner, group, recursive) chmod(path, mode, recursive) def create_service( name: str, exec_start: str, working_dir: str = "", user: str = "root", ) -> None: service_desc = f""" [Unit] Description={name} [Service] User={user} {"WorkingDirectory="+working_dir if working_dir else ""} ExecStart={exec_start} Restart=always [Install] WantedBy=multi-user.target""" with open(f"/etc/systemd/system/{name}.service", "w") as fw: fw.write(service_desc) def enable_service(name: str) -> None: execute(f"enable service {name}", command=["systemctl", "enable", name]) def start_service(name: str) -> None: execute(f"start service {name}", command=["systemctl", "start", name]) def restart_service(name: str) -> None: execute(f"restart service {name}", command=["systemctl", "restart", name]) def cron(desc: str, minute: str, command: str) -> None: temp_name = tempfile.NamedTemporaryFile(delete=False).name try: with open(temp_name, "w") as fw: fw.write(f"# {desc}\n") fw.write(f"{minute} * * * * {command}\n") with open(temp_name) as fr: logging.info("Adding crontab:") logging.info(fr.read()) subprocess.check_call(["crontab", temp_name]) finally: if os.path.exists(temp_name): os.remove(temp_name) def _merge_dict(a: Dict, b: Dict) -> Dict: for akey, avalue in a.items(): if isinstance(avalue, dict): bvalue = b.setdefault(akey, {}) _merge_dict(avalue, bvalue) else: b[akey] = avalue return b def read_node(path: str, initializer: "Initializer") -> Dict: with open(path) as fr: node = json.load(fr) defaults = initializer.defaults() if node.pop("_", {}): logging.warning("Purging top level key '_'") node["_"] = {} initializer.initialize(node) return _merge_dict(node, defaults) class Initializer: def initialize(self, node: Dict) -> None: pass def defaults(self) -> Dict: return {} def execute( desc: str, command: Union[str, List[str]], stdout: Optional[str] = None, retries: int = 0, retry_delay: int = 0, guard_file: Optional[str] = None, # ) -> None: if guard_file and os.path.exists(guard_file): logging.info(f"Skipping '{desc}' because {guard_file} exists.") return logging.getLogger("audit").info(f"execute: {desc}") logging.info(f"execute: {desc}") if stdout and os.path.exists(stdout): return for attempt in range(min(0, retries) + 1): try: stdout_content = subprocess.check_output(command) if stdout: with open(stdout, "w") as fw: fw.write(stdout_content.decode()) except: if retries and attempt < retries: logging.exception( f"Attempt {attempt + 1}. Sleeping {retry_delay} seconds" ) sleep(retry_delay) else: raise if guard_file: with open(guard_file, "w") as fw: fw.write("") def _waagent_service_name(platform_family: str) -> str: if platform_family in ["ubuntu", "debian"]: waagent_service_name = "walinuxagent" else: waagent_service_name = "waagent" return waagent_service_name def _ensure_monitoring(platform_family: str) -> None: with open("/etc/waagent.conf") as fr: lines = fr.readlines() modified = False for i in range(len(lines)): line = lines[i].strip().lower() if re.match("^provisioning.monitorhostname=n$", line): lines[i] = "Provisioning.MonitorHostName=y\n" modified = True if modified: dest_waagent = "/etc/waagent.conf" temp_waagent = dest_waagent + ".tmp" with open(temp_waagent, "w") as fw: for line in lines: fw.write(line) move(temp_waagent, dest_waagent) restart_service(_waagent_service_name(platform_family)) def _wait_for_hostname(hostname: str) -> None: attempts = 12 retry_delay = 10 for a in range(attempts): nslookup_stdout = _unchecked_output(["nslookup", hostname]) if hostname in nslookup_stdout: return logging.info(f"{a}/{attempts} waiting for hostname to register in dns.") sleep(retry_delay) raise RuntimeError("Could not register hostname in DNS") def _unchecked_output(cmd: List[str]) -> str: try: return subprocess.check_output(cmd).decode() except Exception as e: logging.debug(f"attempt to run {' '.join(cmd)} failed: {e}") return "" def set_hostname( hostname: str, platform_family: str, monitor_hostname: bool = True ) -> None: if monitor_hostname: _ensure_monitoring(platform_family) pub_hostname_path = "/var/lib/waagent/published_hostname" nslookup_stdout = _unchecked_output(["nslookup", hostname]) hostname_stdout = _unchecked_output(["hostname"]) pub_hostname_exists = os.path.exists(pub_hostname_path) if ( hostname not in nslookup_stdout and hostname not in hostname_stdout and pub_hostname_exists ): os.remove(pub_hostname_path) logging.warning("Restarting waagent service to force re-registration of hostname") restart_service(_waagent_service_name(platform_family)) execute("set hostname", command=["hostnamectl", "set-hostname", hostname]) execute( "update hostname via jetpack", command=[ "/opt/cycle/jetpack/system/embedded/bin/python", "-c", "import jetpack.converge as jc; jc._send_installation_status('warning')", ], ) class CCNode: """ Simple CycleCloud node representation. """ def __init__( self, name: str, nodearray_name: str, hostname: str, private_ipv4: str, status: str, software_configuration: Dict, ) -> None: self.name = name self.nodearray_name = nodearray_name self.hostname = hostname self.private_ipv4 = private_ipv4 self.status = status self.software_configuration = software_configuration def to_dict(self) -> Dict: ret = {} for attr in dir(self): if attr.startswith("_"): continue if attr == "to_dict": continue val = getattr(self, attr) if hasattr(val, "__call__"): continue ret[attr] = val return ret def is_failed(self) -> bool: return self.status == "Failed" def is_ready(self) -> bool: return self.status == "Ready" def is_booting(self) -> bool: return self.status != "Ready" def __repr__(self) -> str: return f"CCNode({self.to_dict()})" def __str__(self) -> str: return repr(self) def __eq__(self, other: object) -> bool: if hasattr(other, "to_dict"): return self.to_dict() == getattr(other, "to_dict")() return False def cluster_status(config: Dict) -> Dict: """ Makes a REST call to clusters/{cluster_name}/status """ if config.get("mock_provider"): return config["mock_provider"]["nodes"] cc_config = config["cyclecloud"]["config"] urlbase = cc_config["web_server"].rstrip("/") username = cc_config["username"] password = cc_config["password"] context = SSLContext(ssl.PROTOCOL_TLSv1_2) cluster_name = urllib.parse.quote(config["cyclecloud"]["cluster"]["name"]) url = f"{urlbase}/clusters/{cluster_name}/nodes" auth_token = base64.b64encode(f"{username}:{password}".encode("utf-8")).decode( "ascii" ) request = urllib.request.Request( url=url, headers={"Authorization": f"Basic {auth_token}"}, method="GET" ) response = urllib.request.urlopen( request, context=context, timeout=30, ) if response.getcode() != 200: raise RuntimeError(f"Error getting cluster status: {response.status}") return json.loads(response.read().decode("utf-8")) def await_node_hostname( config: Dict, node_name: str, timeout=300, cluster_status_func: Callable[[Dict], Dict] = cluster_status, ) -> CCNode: """ Blocks until the nodename has a valid registered hostname in CycleCloud. Handles ip-XXXXXXXX nodes, nodes with a node prefix and if someones specifies config["valid_hostnames"] = ["^myregex$"] """ omega = timeout + time() while time() < omega: referenced_node = get_ccnode(config, node_name, cluster_status_func) if referenced_node.hostname: if is_valid_hostname(config, referenced_node): return referenced_node else: logging.warning( "Invalid hostname detected, waiting for valid hostname %s", referenced_node.hostname, ) sleep(5) raise RuntimeError( f"Node {node_name} did not register hostname in {timeout} seconds" ) def is_valid_hostname(config: Dict, node: CCNode) -> bool: """ See await_node_hostname for details. """ if not node.hostname: return False valid_hostnames: Optional[List[str]] = config.get("valid_hostnames") if not valid_hostnames: if is_standalone_dns(node): valid_hostnames = ["^ip-[0-9A-Za-z]{8}$"] else: unescaped_nodename_prefix = ( node.software_configuration.get("slurm", {}).get("node_prefix") or "" ) nodename_prefix = re.sub("[^a-zA-Z0-9-]", "-", unescaped_nodename_prefix) if nodename_prefix: valid_hostnames = [f"^{nodename_prefix}{node.name}$".lower()] else: valid_hostnames = [f"^{node.name}$".lower()] for valid_hostname in valid_hostnames: # assert "D" not in node.hostname, f"{node.hostname} with {valid_hostname}" if re.match(valid_hostname, node.hostname): return True logging.warning( "Rejecting invalid hostname '%s': Did not match any of the following patterns: %s", node.hostname, valid_hostnames, ) return False def is_standalone_dns(node: CCNode) -> bool: return ( node.software_configuration.get("cyclecloud", {}) .get("hosts", {}) .get("standalone_dns", {}) .get("enabled", True) ) def get_ccnode( config: Dict, node_name: str, cluster_status_func: Callable[[Dict], Dict] = cluster_status, ) -> CCNode: status = cluster_status_func(config) for node in status["nodes"]: if node["Name"] == node_name: return CCNode( name=node["Name"], nodearray_name=node["Template"], hostname=node["Hostname"], private_ipv4=node["PrivateIp"], status=node["Status"], software_configuration=node.get("Configuration") or {}, ) raise RuntimeError(f"Node {node_name} not found in cluster status!")