ansible/roles/slurm/files/scripts/util.py (1,587 lines of code) (raw):

#!/usr/bin/env python3

# Copyright (C) SchedMD LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, List, Tuple, Optional
import argparse
import base64
import collections
import hashlib
import importlib.util
import inspect
import json
import logging
import logging.config
import math
import os
import re
import shelve
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
from enum import Enum
from collections import defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from functools import lru_cache, reduce, wraps
from itertools import chain, compress, islice
from pathlib import Path
from time import sleep, time

import slurm_gcp_plugins


def _module_available(module: str) -> bool:
    """Return True if ``module`` can be imported, without importing it.

    ``importlib.util.find_spec`` imports the parent package of a dotted name
    (e.g. ``google.cloud`` for ``google.cloud.tpu_v2``) and raises
    ModuleNotFoundError when that parent is missing, instead of returning None.
    Treat that case as "not available" so callers get the friendly message
    below rather than a traceback.
    """
    try:
        return importlib.util.find_spec(module) is not None
    except ModuleNotFoundError:
        return False


# (import name, pip package name) pairs that must be importable. TPU support is
# optional: its absence only disables TPU features.
required_modules = [
    ("googleapiclient", "google-api-python-client"),
    ("requests", "requests"),
    ("yaml", "yaml"),
    ("addict", "addict"),
    ("httplib2", "httplib2"),
    ("google.cloud.tpu_v2", "google-cloud-tpu"),
]
missing_imports = False
can_tpu = True
for module, name in required_modules:
    if not _module_available(module):
        if module == "google.cloud.tpu_v2":
            can_tpu = False
            print(
                f"WARNING: Missing Python module '{module}' (pip:{name}), TPU support will not work."
            )
        else:
            missing_imports = True
            print(f"ERROR: Missing Python module '{module}' (pip:{name})")
if missing_imports:
    print("Aborting due to missing Python modules")
    # sys.exit rather than the site-provided builtin exit(), which is not
    # guaranteed to exist (e.g. when Python runs with -S).
    sys.exit(1)

import google.auth  # noqa: E402
from google.oauth2 import service_account  # noqa: E402
import googleapiclient.discovery  # noqa: E402
import google_auth_httplib2  # noqa: E402
from googleapiclient.http import set_user_agent  # noqa: E402
from google.api_core.client_options import ClientOptions  # noqa: E402
import httplib2  # noqa: E402

if can_tpu:
    from google.cloud import tpu_v2 as tpu  # noqa: E402
import google.api_core.exceptions as gExceptions  # noqa: E402

from requests import get as get_url  # noqa: E402
from requests.exceptions import RequestException  # noqa: E402

import yaml  # noqa: E402
from addict import Dict as NSDict  # noqa: E402

# Modules whose absence only produces a warning; code using them imports them
# locally at call time.
optional_modules = [
    ("google.cloud.secretmanager", "google-cloud-secret-manager"),
]
for module, name in optional_modules:
    if not _module_available(module):
        print(f"WARNING: Missing Python module '{module}' (pip:{name}) ")

USER_AGENT = "Slurm_GCP_Scripts/1.5 (GPN:SchedMD)"
# Config file location: $SLURM_CONFIG_YAML if set, else config.yaml next to this file.
ENV_CONFIG_YAML = os.getenv("SLURM_CONFIG_YAML")
if ENV_CONFIG_YAML:
    CONFIG_FILE = Path(ENV_CONFIG_YAML)
else:
    CONFIG_FILE = Path(__file__).with_name("config.yaml")
API_REQ_LIMIT = 2000
URI_REGEX = r"[a-z]([-a-z0-9]*[a-z0-9])?"
def mkdirp(path: Path) -> None:
    """Create directory ``path`` and any missing parents; no-op if it exists."""
    path.mkdir(parents=True, exist_ok=True)


# Directory holding these scripts: this file's directory if it exists,
# otherwise /slurm/scripts.
scripts_dir = next(
    p for p in (Path(__file__).parent, Path("/slurm/scripts")) if p.is_dir()
)

# readily available compute api handle
compute = None

# slurm-gcp config object, could be empty if not available
cfg = NSDict()

# caching Lookup object
lkp = None

# load all directories as Paths into a dict-like namespace
dirs = NSDict(
    {
        n: Path(p)
        for n, p in {
            "home": "/home",
            "apps": "/opt/apps",
            "slurm": "/slurm",
            "scripts": scripts_dir,
            "custom_scripts": "/slurm/custom_scripts",
            "munge": "/etc/munge",
            "secdisk": "/mnt/disks/sec",
            "log": "/var/log/slurm",
        }.items()
    }
)

slurmdirs = NSDict(
    {
        n: Path(p)
        for n, p in {
            "prefix": "/usr/local",
            "etc": "/usr/local/etc/slurm",
            "state": "/var/spool/slurm",
        }.items()
    }
)

# Fallback SafeDumper representer (key None): values of types without a more
# specific representer are dumped as their str().
yaml.SafeDumper.yaml_representers[
    None
] = lambda self, data: yaml.representer.SafeRepresenter.represent_str(self, str(data))


class ApiEndpoint(Enum):
    """Google Cloud APIs whose endpoint version can be overridden.

    The member value is the key looked up in ``lkp.endpoint_versions`` and the
    host prefix used when building a versioned endpoint URL.
    """

    COMPUTE = "compute"
    BQ = "bq"
    STORAGE = "storage"
    TPU = "tpu"
    SECRET = "secret_manager"


@lru_cache(maxsize=1)
def default_credentials():
    """Return (and cache) the Application Default Credentials object."""
    return google.auth.default()[0]


@lru_cache(maxsize=1)
def authentication_project():
    """Return (and cache) the project id reported by google.auth.default()."""
    return google.auth.default()[1]


DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"


def universe_domain() -> str:
    """Return the ``universe_domain`` instance-metadata attribute.

    Falls back to DEFAULT_UNIVERSE_DOMAIN if the metadata lookup fails for any
    reason.
    """
    try:
        return instance_metadata("attributes/universe_domain")
    except Exception:
        return DEFAULT_UNIVERSE_DOMAIN


def _as_api_endpoint(api) -> Optional[ApiEndpoint]:
    """Normalize ``api`` to an ApiEndpoint member.

    Accepts an ApiEndpoint, its string value (e.g. ``"storage"``), or a falsy
    value (returned as None). Raises ValueError for an unknown string.
    """
    if not api:
        return None
    if isinstance(api, ApiEndpoint):
        return api
    return ApiEndpoint(api)


def endpoint_version(api: ApiEndpoint) -> Optional[str]:
    """Return the configured endpoint version override for ``api``, or None.

    ``api`` may be an ApiEndpoint or its string value. Returns None when no
    API is given, when the Lookup object has not been initialized yet, or
    when no override exists in ``lkp.endpoint_versions``.
    """
    api = _as_api_endpoint(api)
    if api is None or lkp is None:
        return None
    if api.value in lkp.endpoint_versions:
        return lkp.endpoint_versions[api.value]
    return None


@lru_cache(maxsize=1)
def get_credentials() -> Optional[service_account.Credentials]:
    """Get credentials for service account.

    Uses the key file named by $GOOGLE_APPLICATION_CREDENTIALS, scoped to
    cloud-platform in the current universe domain. Otherwise it uses the
    Application Default Credentials. The result is cached.
    """
    key_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if key_path is not None:
        credentials = service_account.Credentials.from_service_account_file(
            key_path, scopes=[f"https://www.{universe_domain()}/auth/cloud-platform"]
        )
    else:
        credentials = default_credentials()

    return credentials


def create_client_options(api: ApiEndpoint = None) -> ClientOptions:
    """Create client options for cloud endpoints.

    Args:
        api: ApiEndpoint (or its string value, e.g. "storage") to configure;
            None yields options carrying only a non-default universe domain.

    Returns:
        ClientOptions that set ``universe_domain`` when it differs from
        DEFAULT_UNIVERSE_DOMAIN. When an endpoint version override exists,
        they also set ``api_endpoint`` to ``https://<api>.<domain>/<version>/``.
    """
    api = _as_api_endpoint(api)
    ver = endpoint_version(api)
    ud = universe_domain()
    options = {}
    if ud and ud != DEFAULT_UNIVERSE_DOMAIN:
        options["universe_domain"] = ud
    if ver:
        # ver is only set when api is not None (see endpoint_version)
        options["api_endpoint"] = f"https://{api.value}.{ud}/{ver}/"
    co = ClientOptions(**options)
    # The f-string is evaluated eagerly, so guard api=None here.
    api_name = api.value if api is not None else None
    log.debug(f"Using ClientOptions = {co} for API: {api_name}")
    return co


class LogFormatter(logging.Formatter):
    """Formatter that adds logging flags to the levelname in log records.

    For records with a ``flag`` attribute, "%(levelname)s" becomes
    "%(levelname)s(%(flag)s)". For records below INFO, the message is
    prefixed with "%(funcName)s: ".
    """

    def format(self, record):
        # self._fmt is the untouched format passed at construction; the
        # per-record variant is written into the style object each call.
        new_fmt = self._fmt
        flag = getattr(record, "flag", None)
        if flag is not None:
            start, level, end = new_fmt.partition("%(levelname)s")
            if level:
                new_fmt = f"{start}{level}(%(flag)s){end}"
        # insert function name if record level is DEBUG
        if record.levelno < logging.INFO:
            prefix, msg, suffix = new_fmt.partition("%(message)s")
            new_fmt = f"{prefix}%(funcName)s: {msg}{suffix}"
        self._style._fmt = new_fmt
        return super().format(record)


class FlagLogAdapter(logging.LoggerAdapter):
    """Log adapter that tags each record with ``flag`` so it can be filtered.

    ``enabled`` reports whether the flag is turned on in
    ``cfg.extra_logging_flags``.
    """

    def __init__(self, logger, flag, extra=None):
        if extra is None:
            extra = {}
        self.flag = flag
        super().__init__(logger, extra)

    @property
    def enabled(self):
        return cfg.extra_logging_flags.get(self.flag, False)

    def process(self, msg, kwargs):
        extra = kwargs.setdefault("extra", {})
        extra.update(self.extra)
        extra["flag"] = self.flag
        return msg, kwargs


logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log = logging.getLogger(__name__)

# Names of the optional logging flags recognized in cfg.extra_logging_flags.
logging_flags = [
    "trace_api",
    "subproc",
    "hostlists",
]
log_trace_api = FlagLogAdapter(log, "trace_api")
log_subproc = FlagLogAdapter(log, "subproc")
log_hostlists = FlagLogAdapter(log, "hostlists")


def access_secret_version(project_id, secret_id, version_id="latest"):
    """
    Access the payload for the given secret version if one exists. The version
    can be a version number as a string (e.g. "5") or an alias (e.g. "latest").

    Returns the payload decoded as UTF-8, or None if the secret version is
    not found.
    """
    from google.cloud import secretmanager
    from google.api_core import exceptions

    co = create_client_options(ApiEndpoint.SECRET)
    client = secretmanager.SecretManagerServiceClient(client_options=co)
    name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"
    try:
        response = client.access_secret_version(request={"name": name})
        log.debug(f"Secret '{name}' was found.")
        payload = response.payload.data.decode("UTF-8")
    except exceptions.NotFound:
        log.debug(f"Secret '{name}' was not found!")
        payload = None

    return payload


def parse_self_link(self_link: str):
    """Parse a selfLink url, extracting all useful values

    https://.../v1/projects/<project>/regions/<region>/...
    {'project': <project>, 'region': <region>, ...}
    can also extract zone, instance (name), image, etc
    """
    link_patt = re.compile(r"(?P<key>[^\/\s]+)s\/(?P<value>[^\s\/]+)")
    return NSDict(link_patt.findall(self_link))


def parse_bucket_uri(uri: str):
    """
    Parse a bucket url
    E.g. gs://<bucket_name>/<path>

    Returns:
        (bucket_name, path) tuple.

    Raises:
        ValueError: if ``uri`` is not of the form gs://<bucket>/<path>.
    """
    pattern = re.compile(r"gs://(?P<bucket>[^/\s]+)/(?P<path>([^/\s]+)(/[^/\s]+)*)")
    matches = pattern.match(uri)
    if matches is None:
        raise ValueError(f"Invalid bucket URI, expected gs://<bucket>/<path>: '{uri}'")
    return matches.group("bucket"), matches.group("path")


def trim_self_link(link: str):
    """get resource name from self link url, eg.
    https://.../v1/projects/<project>/regions/<region>
    -> <region>
    """
    try:
        return link[link.rindex("/") + 1 :]
    except ValueError:
        # suppress the rindex ValueError context; the message says it all
        raise Exception(f"'/' not found, not a self link: '{link}' ") from None


def execute_with_futures(func, seq):
    """Call ``func`` on every item of ``seq`` in a thread pool.

    Re-raises the first exception observed (in completion order); results are
    discarded.
    """
    with ThreadPoolExecutor() as exe:
        futures = [exe.submit(func, i) for i in seq]
        for future in as_completed(futures):
            result = future.exception()
            if result is not None:
                raise result


def map_with_futures(func, seq):
    """Lazily yield ``func(item)`` for each item of ``seq``, in input order.

    Calls run in a thread pool. A call that raised yields its exception object
    instead of a result (it is not re-raised).
    """
    with ThreadPoolExecutor() as exe:
        futures = [exe.submit(func, i) for i in seq]
        for future in futures:
            # Will be result or raise Exception
            res = None
            try:
                res = future.result()
            except Exception as e:
                res = e
            yield res


def _slurm_bucket_client(project=None):
    """Return (storage client, bucket name, path prefix) for the slurm bucket.

    The bucket location comes from the ``slurm_bucket_path`` instance-metadata
    attribute; ``project`` defaults to ``lkp.project``.
    """
    from google.cloud import storage

    if project is None:
        project = lkp.project
    uri = instance_metadata("attributes/slurm_bucket_path")
    bucket_name, path = parse_bucket_uri(uri)
    co = create_client_options(ApiEndpoint.STORAGE)
    storage_client = storage.Client(project=project, client_options=co)
    return storage_client, bucket_name, path


def blob_get(file, project=None):
    """Return the Blob ``<path>/<file>`` in the cluster's slurm bucket."""
    storage_client, bucket_name, path = _slurm_bucket_client(project)
    return storage_client.get_bucket(bucket_name).blob(f"{path}/{file}")


def blob_list(prefix="", delimiter=None, project=None):
    """List blobs in the slurm bucket whose names start with ``<path>/<prefix>``."""
    storage_client, bucket_name, path = _slurm_bucket_client(project)
    # Note: The call returns a response only when the iterator is consumed.
    blobs = storage_client.list_blobs(
        bucket_name, prefix=f"{path}/{prefix}", delimiter=delimiter
    )
    return list(blobs)


def _hash_file(fullpath):
    """Return the base64-encoded MD5 digest of a file, read in 8 KiB chunks.

    This is the same encoding that is compared against ``blob.md5_hash`` in
    install_custom_scripts.
    """
    with open(fullpath, "rb") as f:
        file_hash = hashlib.md5()
        chunk = f.read(8192)
        while chunk:
            file_hash.update(chunk)
            chunk = f.read(8192)
    return base64.b64encode(file_hash.digest()).decode("utf-8")


def install_custom_scripts(check_hash=False):
    """download custom scripts from gcs bucket

    Which blobs are fetched depends on ``lkp.instance_role``. A blob named
    ``slurm-<path>-script-<name>`` is installed under ``dirs.custom_scripts``.
    The "-"-separated parts of <path> become directories, with ".d" appended
    to the first one. The last "_" in <name> becomes "." (e.g.
    slurm-controller-script-foo_sh -> controller.d/foo.sh).

    Args:
        check_hash: if True, skip files whose local MD5 already matches the
            blob's ``md5_hash``.
    """
    compute_tokens = ["compute", "prolog", "epilog"]
    if lkp.instance_role == "compute":
        try:
            compute_tokens.append(f"nodeset-{lkp.node_nodeset_name()}")
        except Exception as e:
            log.error(f"Failed to lookup nodeset: {e}")

    prefix_tokens = {
        "login": ["login"],
        "compute": compute_tokens,
        "controller": ["controller", "prolog", "epilog"],
    }.get(lkp.instance_role, [])
    prefixes = [f"slurm-{tok}-script" for tok in prefix_tokens]
    blobs = list(chain.from_iterable(blob_list(prefix=p) for p in prefixes))

    script_pattern = re.compile(r"slurm-(?P<path>\S+)-script-(?P<name>\S+)")
    for blob in blobs:
        m = script_pattern.match(Path(blob.name).name)
        if not m:
            log.warning(f"found blob that doesn't match expected pattern: {blob.name}")
            continue
        path_parts = m["path"].split("-")
        path_parts[0] += ".d"
        # NOTE(review): a <name> without "_" yields ".<name>" — confirm intended.
        stem, _, ext = m["name"].rpartition("_")
        filename = ".".join((stem, ext))

        path = Path(*path_parts, filename)
        fullpath = (dirs.custom_scripts / path).resolve()
        mkdirp(fullpath.parent)

        # path.parents ends with Path("."), so custom_scripts itself is included
        for par in path.parents:
            chown_slurm(dirs.custom_scripts / par)
        need_update = True
        if check_hash and fullpath.exists():
            need_update = _hash_file(fullpath) != blob.md5_hash
        if need_update:
            log.info(f"installing custom script: {path} from {blob.name}")
            with fullpath.open("wb") as f:
                blob.download_to_file(f)
            chown_slurm(fullpath, mode=0o755)


def reservation_resource_policies(reservation):
    """
    Inspects reservation object, returns list of resource policies names.
    Converts policy URLs to names, e.g.:
    projects/111111/regions/us-central1/resourcePolicies/zebra -> zebra
    """
    return [u.split("/")[-1] for u in reservation.get("resourcePolicies", {}).values()]


def compute_service(credentials=None, user_agent=USER_AGENT, version="beta"):
    """Make thread-safe compute service handle
    creates a new Http for each request

    Args:
        credentials: credentials to authorize requests with; defaults to
            get_credentials() when None.
        user_agent: User-Agent applied to every request (skipped if None).
        version: Compute API version; replaced by the configured endpoint
            version override when one exists.
    """
    if credentials is None:
        credentials = get_credentials()

    def build_request(http, *args, **kwargs):
        # Ignore the shared ``http`` and build a fresh one per request.
        new_http = httplib2.Http()
        if user_agent is not None:
            new_http = set_user_agent(new_http, user_agent)
        if credentials is not None:
            new_http = google_auth_httplib2.AuthorizedHttp(credentials, http=new_http)
        return googleapiclient.http.HttpRequest(new_http, *args, **kwargs)

    ver = endpoint_version(ApiEndpoint.COMPUTE)
    disc_url = googleapiclient.discovery.DISCOVERY_URI
    if ver:
        version = ver
        disc_url = disc_url.replace(DEFAULT_UNIVERSE_DOMAIN, universe_domain())

    log.debug(f"Using version={version} of Google Compute Engine API")
    return googleapiclient.discovery.build(
        "compute",
        version,
        requestBuilder=build_request,
        credentials=credentials,
        discoveryServiceUrl=disc_url,
    )


def load_config_data(config):
    """load dict-like data into a config object

    Fills defaults for slurm_log_dir, slurm_bin_dir, slurm_control_host,
    slurm_control_host_port and munge_mount when they are falsy, and
    normalizes extra_logging_flags to exactly the names in ``logging_flags``.
    """
    cfg = NSDict(config)
    if not cfg.slurm_log_dir:
        cfg.slurm_log_dir = dirs.log
    if not cfg.slurm_bin_dir:
        cfg.slurm_bin_dir = slurmdirs.prefix / "bin"
    if not cfg.slurm_control_host:
        cfg.slurm_control_host = f"{cfg.slurm_cluster_name}-controller"
    if not cfg.slurm_control_host_port:
        cfg.slurm_control_host_port = "6820-6830"
    if not cfg.munge_mount:
        # NOTE: should only happen with cloud controller
        cfg.munge_mount = NSDict(
            {
                "server_ip": cfg.slurm_control_addr or cfg.slurm_control_host,
                "remote_mount": "/etc/munge",
                "fs_type": "nfs",
                "mount_options": "defaults,hard,intr,_netdev",
            }
        )

    # presumably detects an unset key (addict yields an empty NSDict for
    # missing attributes) and turns it into an explicit False
    if not cfg.enable_debug_logging and isinstance(cfg.enable_debug_logging, NSDict):
        cfg.enable_debug_logging = False
    cfg.extra_logging_flags = NSDict(
        {flag: cfg.extra_logging_flags.get(flag, False) for flag in logging_flags}
    )
    return cfg
def new_config(config): """initialize a new config object necessary defaults are handled here """ cfg = load_config_data(config) network_storage_iter = filter( None, ( *cfg.network_storage, *cfg.login_network_storage, *chain.from_iterable(ns.network_storage for ns in cfg.nodeset.values()), *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_dyn.values()), *chain.from_iterable(ns.network_storage for ns in cfg.nodeset_tpu.values()), ), ) for netstore in network_storage_iter: if netstore != "gcsfuse" and ( netstore.server_ip is None or netstore.server_ip == "$controller" ): netstore.server_ip = cfg.slurm_control_host return cfg def fetch_config_yaml(): """Fetch config.yaml from bucket""" config_yaml = blob_get("config.yaml").download_as_text() cfg = new_config(yaml.safe_load(config_yaml)) return cfg def fetch_config_yaml_md5(): """Fetch config.yaml blob md5 from bucket""" import hashlib blob = blob_get("config.yaml") blob.reload() # Populate blob with metadata hash_str = str(blob.md5_hash).encode(encoding="utf-8") return hashlib.md5(hash_str) def load_config_file(path): """load config from file""" content = None try: content = yaml.safe_load(Path(path).read_text()) except FileNotFoundError: log.warning(f"config file not found: {path}") return NSDict() return load_config_data(content) def save_config(cfg, path): """save given config to file at path""" Path(path).write_text(yaml.dump(cfg, Dumper=Dumper)) def filter_logging_flags(record): """logging filter for flags if there are no flags, always pass. 
If there are flags, only pass if a flag matches an enabled flag in cfg.extra_logging_flags""" flag = getattr(record, "flag", None) if flag is None: return True return cfg.extra_logging_flags.get(flag, False) def owned_file_handler(filename): """create file handler""" if filename is None: return None chown_slurm(filename) return logging.handlers.WatchedFileHandler(filename, delay=True) def config_root_logger(caller_logger, level="DEBUG", stdout=True, logfile=None): """configure the root logger, disabling all existing loggers""" handlers = list(compress(("stdout_handler", "file_handler"), (stdout, logfile))) config = { "version": 1, "disable_existing_loggers": True, "formatters": { "standard": { "()": LogFormatter, "fmt": "%(levelname)s: %(message)s", }, "stamp": { "()": LogFormatter, "fmt": "%(asctime)s %(levelname)s: %(message)s", }, }, "filters": { "logging_flags": {"()": lambda: filter_logging_flags}, }, "handlers": { "stdout_handler": { "level": logging.DEBUG, "formatter": "standard", "class": "logging.StreamHandler", "stream": sys.stdout, "filters": ["logging_flags"], }, "file_handler": { "()": owned_file_handler, "level": logging.DEBUG, "formatter": "stamp", "filters": ["logging_flags"], "filename": logfile, }, }, "root": { "handlers": handlers, "level": level, }, } if not logfile: del config["handlers"]["file_handler"] logging.config.dictConfig(config) loggers = ( __name__, "resume", "suspend", "slurmsync", "setup", caller_logger, ) for logger in map(logging.getLogger, loggers): logger.disabled = False def log_api_request(request): """log.trace info about a compute API request""" if log_trace_api.enabled: # output the whole request object as pretty yaml # the body is nested json, so load it as well rep = json.loads(request.to_json()) if rep.get("body", None) is not None: rep["body"] = json.loads(rep["body"]) pretty_req = yaml.safe_dump(rep).rstrip() # label log message with the calling function 
log_trace_api.debug(f"{inspect.stack()[1].function}:\n{pretty_req}") def handle_exception(exc_type, exc_value, exc_trace): """log exceptions other than KeyboardInterrupt""" # TODO does this work? if not issubclass(exc_type, KeyboardInterrupt): log.exception("Fatal exception", exc_info=(exc_type, exc_value, exc_trace)) sys.__excepthook__(exc_type, exc_value, exc_trace) def run( args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, timeout=None, check=True, universal_newlines=True, **kwargs, ): """Wrapper for subprocess.run() with convenient defaults""" if isinstance(args, list): args = list(filter(lambda x: x is not None, args)) args = " ".join(args) if not shell and isinstance(args, str): args = shlex.split(args) log_subproc.debug(f"run: {args}") result = subprocess.run( args, stdout=stdout, stderr=stderr, shell=shell, timeout=timeout, check=check, universal_newlines=universal_newlines, **kwargs, ) return result def spawn(cmd, quiet=False, shell=False, **kwargs): """nonblocking spawn of subprocess""" if not quiet: log_subproc.debug(f"spawn: {cmd}") args = cmd if shell else shlex.split(cmd) return subprocess.Popen(args, shell=shell, **kwargs) def chown_slurm(path: Path, mode=None) -> None: if path.exists(): if mode: path.chmod(mode) else: mkdirp(path.parent) if mode: path.touch(mode=mode) else: path.touch() try: shutil.chown(path, user="slurm", group="slurm") except LookupError: log.warning(f"User 'slurm' does not exist. Cannot 'chown slurm:slurm {path}'.") except PermissionError: log.warning(f"Not authorized to 'chown slurm:slurm {path}'.") except Exception as err: log.error(err) @contextmanager def cd(path): """Change working directory for context""" prev = Path.cwd() os.chdir(path) try: yield finally: os.chdir(prev) def cached_property(f): return property(lru_cache()(f)) def retry(max_retries: int, init_wait_time: float, warn_msg: str, exc_type: Exception): """Retries functions that raises the exception exc_type. 
Retry time is increased by a factor of two for every iteration. Args: max_retries (int): Maximum number of retries init_wait_time (float): Initial wait time in secs warn_msg (str): Message to print during retries exc_type (Exception): Exception type to check for """ if max_retries <= 0: raise ValueError("Incorrect value for max_retries, must be >= 1") if init_wait_time <= 0.0: raise ValueError("Invalid value for init_wait_time, must be > 0.0") def decorator(f): @wraps(f) def wrapper(*args, **kwargs): retry = 0 secs = init_wait_time captured_exc = None while retry < max_retries: try: return f(*args, **kwargs) except exc_type as e: captured_exc = e log.warn(f"{warn_msg}, retrying in {secs}") sleep(secs) retry += 1 secs *= 2 raise captured_exc return wrapper return decorator def separate(pred, coll): """filter into 2 lists based on pred returning True or False returns ([False], [True]) """ return reduce(lambda acc, el: acc[pred(el)].append(el) or acc, coll, ([], [])) def chunked(iterable, n=API_REQ_LIMIT): """group iterator into chunks of max size n""" it = iter(iterable) while True: chunk = list(islice(it, n)) if not chunk: return yield chunk def groupby_unsorted(seq, key): indices = defaultdict(list) for i, el in enumerate(seq): indices[key(el)].append(i) for k, idxs in indices.items(): yield k, (seq[i] for i in idxs) @lru_cache(maxsize=32) def find_ratio(a, n, s, r0=None): """given the start (a), count (n), and sum (s), find the ratio required""" if n == 2: return s / a - 1 an = a * n if n == 1 or s == an: return 1 if r0 is None: # we only need to know which side of 1 to guess, and the iteration will work r0 = 1.1 if an < s else 0.9 # geometric sum formula def f(r): return a * (1 - r**n) / (1 - r) - s # derivative of f def df(r): rm1 = r - 1 rn = r**n return (a * (rn * (n * rm1 - r) + r)) / (r * rm1**2) MIN_DR = 0.0001 # negligible change r = r0 # print(f"r(0)={r0}") MAX_TRIES = 64 for i in range(1, MAX_TRIES + 1): try: dr = f(r) / df(r) except ZeroDivisionError: 
log.error(f"Failed to find ratio due to zero division! Returning r={r0}") return r0 r = r - dr # print(f"r({i})={r}") # if the change in r is small, we are close enough if abs(dr) < MIN_DR: break else: log.error(f"Could not find ratio after {MAX_TRIES}! Returning r={r0}") return r0 return r def backoff_delay(start, timeout=None, ratio=None, count: int = 0): """generates `count` waits starting at `start` sum of waits is `timeout` or each one is `ratio` bigger than the last the last wait is always 0""" # timeout or ratio must be set but not both assert (timeout is None) ^ (ratio is None) assert ratio is None or ratio > 0 assert timeout is None or timeout >= start assert (count > 1 or timeout is not None) and isinstance(count, int) assert start > 0 if count == 0: # Equation for auto-count is tuned to have a max of # ~int(timeout) counts with a start wait of <0.01. # Increasing start wait decreases count eg. # backoff_delay(10, timeout=60) -> count = 5 count = int( (timeout / ((start + 0.05) ** (1 / 2)) + 2) // math.log(timeout + 2) ) yield start # if ratio is set: # timeout = start * (1 - ratio**(count - 1)) / (1 - ratio) if ratio is None: ratio = find_ratio(start, count - 1, timeout) wait = start # we have start and 0, so we only need to generate count - 2 for _ in range(count - 2): wait *= ratio yield wait yield 0 return ROOT_URL = "http://metadata.google.internal/computeMetadata/v1" def get_metadata(path, root=ROOT_URL): """Get metadata relative to metadata/computeMetadata/v1""" HEADERS = {"Metadata-Flavor": "Google"} url = f"{root}/{path}" try: resp = get_url(url, headers=HEADERS) resp.raise_for_status() return resp.text except RequestException: log.debug(f"metadata not found ({url})") raise Exception(f"failed to get_metadata from {url}") @lru_cache(maxsize=None) def instance_metadata(path): """Get instance metadata""" return get_metadata(path, root=f"{ROOT_URL}/instance") @lru_cache(maxsize=None) def project_metadata(key): """Get project metadata 
project/attributes/<slurm_cluster_name>-<path>""" return get_metadata(key, root=f"{ROOT_URL}/project/attributes") def bucket_blob_download(bucket_name, blob_name): from google.cloud import storage co = create_client_options("storage") storage_client = storage.Client(client_options=co) bucket = storage_client.bucket(bucket_name) blob = bucket.blob(blob_name) contents = None with tempfile.NamedTemporaryFile(mode="w+t") as tmp: blob.download_to_filename(tmp.name) with open(tmp.name, "r") as f: contents = f.read() return contents def natural_sort(text): def atoi(text): return int(text) if text.isdigit() else text return [atoi(w) for w in re.split(r"(\d+)", text)] # TODO: replace with to_hostlist_fast def to_hostlist(nodenames) -> str: """make hostlist from list of node names""" # use tmp file because list could be large tmp_file = tempfile.NamedTemporaryFile(mode="w+t", delete=False) tmp_file.writelines("\n".join(sorted(nodenames, key=natural_sort))) tmp_file.close() hostlist = run(f"{lkp.scontrol} show hostlist {tmp_file.name}").stdout.rstrip() log_hostlists.debug(f"hostlist({len(nodenames)}): {hostlist}".format(hostlist)) os.remove(tmp_file.name) return hostlist def to_hostlist_fast(names: Iterable[str]) -> str: """ Fast implementation of to_hostlist that doesn't invoke `scontrol` IMPORTANT: * Acts as `scontrol show hostlistsorted`, i.e. 
original order is not preserved * Achieves worse compression than `to_hostlist` for some cases """ pref = defaultdict(list) tokenizer = re.compile(r"^(.*?)(\d*)$") for name in filter(None, names): p, s = tokenizer.match(name).groups() pref[p].append(s) def _compress_suffixes(ss: List[str]) -> List[str]: cur, res = None, [] def cur_repr(): nums, strs = cur if nums[0] == nums[1]: return strs[0] return f"{strs[0]}-{strs[1]}" for s in sorted(ss, key=int): n = int(s) if cur is None: cur = ((n, n), (s, s)) continue nums, strs = cur if n == nums[1] + 1: cur = ((nums[0], n), (strs[0], s)) else: res.append(cur_repr()) cur = ((n, n), (s, s)) if cur: res.append(cur_repr()) return res res = [] for p in sorted(pref.keys()): sl = defaultdict(list) for s in pref[p]: sl[len(s)].append(s) cs = [] for ln in sorted(sl.keys()): if ln == 0: res.append(p) else: cs.extend(_compress_suffixes(sl[ln])) if not cs: continue if len(cs) == 1 and "-" not in cs[0]: res.append(f"{p}{cs[0]}") else: res.append(f"{p}[{','.join(cs)}]") return ",".join(res) def part_is_tpu(part): """check if partition with name part contains a nodeset of type tpu""" return len(lkp.cfg.partitions[part].partition_nodeset_tpu) > 0 def get_vmcount_of_tpu_part(part): res = 0 for ns in lkp.cfg.partitions[part].partition_nodeset_tpu: tpu_obj = TPU(lkp.cfg.nodeset_tpu[ns]) if res == 0: res = tpu_obj.vmcount else: if res != tpu_obj.vmcount: # this should not happen, that in the same partition there are different vmcount nodesets return -1 return res def to_hostnames(nodelist: str) -> List[str]: """make list of hostnames from hostlist expression""" if not nodelist: return [] # avoid degenerate invocation of scontrol if isinstance(nodelist, str): hostlist = nodelist else: hostlist = ",".join(nodelist) hostnames = run(f"{lkp.scontrol} show hostnames {hostlist}").stdout.splitlines() log_hostlists.debug(f"hostnames({len(hostnames)}) from {hostlist}") return hostnames def retry_exception(exc): """return true for exceptions that 
should always be retried""" retry_errors = ( "Rate Limit Exceeded", "Quota Exceeded", ) return any(e in str(exc) for e in retry_errors) def ensure_execute(request): """Handle rate limits and socket time outs""" for retry, wait in enumerate(backoff_delay(0.5, timeout=10 * 60, count=20)): try: return request.execute() except googleapiclient.errors.HttpError as e: if retry_exception(e): log.error(f"retry:{retry} '{e}'") sleep(wait) continue raise except socket.timeout as e: # socket timed out, try again log.debug(e) except Exception as e: log.error(e, exc_info=True) raise break def batch_execute(requests, retry_cb=None, log_err=log.error): """execute list or dict<req_id, request> as batch requests retry if retry_cb returns true """ compute = globals()["compute"] BATCH_LIMIT = 1000 if not isinstance(requests, dict): requests = {str(k): v for k, v in enumerate(requests)} # rid generated here done = {} failed = {} timestamps = [] rate_limited = False def batch_callback(rid, resp, exc): nonlocal rate_limited if exc is not None: log_err(f"compute request exception {rid}: {exc}") if retry_exception(exc): rate_limited = True else: req = requests.pop(rid) failed[rid] = (req, exc) else: # if retry_cb is set, don't move to done until it returns false if retry_cb is None or not retry_cb(resp): requests.pop(rid) done[rid] = resp def batch_request(reqs): batch = compute.new_batch_http_request(callback=batch_callback) for rid, req in reqs: batch.add(req, request_id=rid) return batch while requests: if timestamps: timestamps = [stamp for stamp in timestamps if stamp > time()] if rate_limited and timestamps: stamp = next(iter(timestamps)) sleep(max(stamp - time(), 0)) rate_limited = False # up to API_REQ_LIMIT (2000) requests # in chunks of up to BATCH_LIMIT (1000) batches = [ batch_request(chunk) for chunk in chunked(islice(requests.items(), API_REQ_LIMIT), BATCH_LIMIT) ] timestamps.append(time() + 100) with ThreadPoolExecutor() as exe: futures = [] for batch in batches: future = 
exe.submit(ensure_execute, batch) futures.append(future) for future in futures: result = future.exception() if result is not None: raise result return done, failed def wait_request(operation, project=None, compute=None): """makes the appropriate wait request for a given operation""" if not compute: compute = globals()["compute"] if project is None: project = lkp.project if "zone" in operation: req = compute.zoneOperations().wait( project=project, zone=trim_self_link(operation["zone"]), operation=operation["name"], ) elif "region" in operation: req = compute.regionOperations().wait( project=project, region=trim_self_link(operation["region"]), operation=operation["name"], ) else: req = compute.globalOperations().wait( project=project, operation=operation["name"] ) return req def wait_for_operation(operation, project=None, compute=None): """wait for given operation""" if not compute: compute = globals()["compute"] if project is None: project = parse_self_link(operation["selfLink"]).project wait_req = wait_request(operation, project=project, compute=compute) while True: result = ensure_execute(wait_req) if result["status"] == "DONE": log_errors = " with errors" if "error" in result else "" log.debug( f"operation complete{log_errors}: type={result['operationType']}, name={result['name']}" ) return result def wait_for_operations(operations, project=None, compute=None): if not compute: compute = globals()["compute"] return [ wait_for_operation(op, project=project, compute=compute) for op in operations ] def get_filtered_operations( op_filter, zone=None, region=None, only_global=False, project=None, compute=None, ): """get list of operations associated with group id""" if not compute: compute = globals()["compute"] if project is None: project = lkp.project operations = [] def get_aggregated_operations(items): # items is a dict of location key to value: dict(operations=<list of operations>) or an empty dict operations.extend( chain.from_iterable( ops["operations"] for ops 
in items.values() if "operations" in ops ) ) def get_list_operations(items): operations.extend(items) handle_items = get_list_operations if only_global: act = compute.globalOperations() op = act.list(project=project, filter=op_filter) nxt = act.list_next elif zone is not None: act = compute.zoneOperations() op = act.list(project=project, zone=zone, filter=op_filter) nxt = act.list_next elif region is not None: act = compute.regionOperations() op = act.list(project=project, region=region, filter=op_filter) nxt = act.list_next else: act = compute.globalOperations() op = act.aggregatedList( project=project, filter=op_filter, fields="items.*.operations,nextPageToken" ) nxt = act.aggregatedList_next handle_items = get_aggregated_operations while op is not None: result = ensure_execute(op) handle_items(result["items"]) op = nxt(op, result) return operations def get_insert_operations(group_ids, flt=None, project=None, compute=None): """get all insert operations from a list of operationGroupId""" if not compute: compute = globals()["compute"] if project is None: project = lkp.project if isinstance(group_ids, str): group_ids = group_ids.split(",") filters = [ "operationType=insert", flt, " OR ".join(f"(operationGroupId={id})" for id in group_ids), ] return get_filtered_operations(" AND ".join(f"({f})" for f in filters if f)) def machine_type_sockets(template): pattern = re.compile("^(?P<family>[^-]+)") m = pattern.match(template.machineType) if not m: raise Exception(f"template {template} does not match expected regex") family = m.group("family") guestCpus: int = int(template.machine_info.guestCpus) socket_count = dict.get( { "h3": 2, "c2d": 2 if guestCpus > 56 else 1, "a3": 2, }, family, 1, # assume 1 socket for all other families ) return socket_count def isSmt(template): machineType: str = template.machineType guestCpus: int = int(template.machine_info.guestCpus) pattern = re.compile("^(?P<family>[^-]+)") matches = pattern.match(machineType) machineTypeFamily: str = 
matches["family"] # https://cloud.google.com/compute/docs/cpu-platforms noSmtFamily = [ "t2a", "t2d", "h3", ] if machineTypeFamily in noSmtFamily: return False elif guestCpus == 1: return False return True def getThreadsPerCore(template): threadsPerCore: int = template.advancedMachineFeatures.threadsPerCore if not isSmt(template): return 1 elif threadsPerCore: return threadsPerCore else: return 2 @retry( max_retries=9, init_wait_time=1, warn_msg="Temporary failure in name resolution", exc_type=socket.gaierror, ) def host_lookup(host_name: str) -> str: return socket.gethostbyname(host_name) class Dumper(yaml.SafeDumper): """Add representers for pathlib.Path and NSDict for yaml serialization""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.add_representer(NSDict, self.represent_nsdict) self.add_multi_representer(Path, self.represent_path) @staticmethod def represent_nsdict(dumper, data): return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) @staticmethod def represent_path(dumper, path): return dumper.represent_scalar("tag:yaml.org,2002:str", str(path)) class TPU: """Class for handling the TPU-vm nodes""" if can_tpu: State = tpu.types.cloud_tpu.Node.State TPUS_PER_VM = 4 __expected_states = { "create": State.READY, "start": State.READY, "stop": State.STOPPED, } __tpu_version_mapping = { "V2": tpu.AcceleratorConfig().Type.V2, "V3": tpu.AcceleratorConfig().Type.V3, "V4": tpu.AcceleratorConfig().Type.V4, } def __init__(self, nodeset): if not can_tpu: raise Exception("TPU pip package not installed") self._nodeset = nodeset self._parent = f"projects/{lkp.project}/locations/{nodeset.zone}" co = create_client_options(ApiEndpoint.TPU) self._client = tpu.TpuClient(client_options=co) self.data_disks = [] for data_disk in nodeset.data_disks: ad = tpu.AttachedDisk() ad.source_disk = data_disk ad.mode = tpu.AttachedDisk.DiskMode.DISK_MODE_UNSPECIFIED self.data_disks.append(ad) ns_ac = nodeset.accelerator_config if ns_ac.topology != 
"" and ns_ac.version != "": ac = tpu.AcceleratorConfig() ac.topology = ns_ac.topology ac.type_ = self.__tpu_version_mapping[ns_ac.version] self.ac = ac else: req = tpu.GetAcceleratorTypeRequest( name=f"{self._parent}/acceleratorTypes/{nodeset.node_type}" ) self.ac = self._client.get_accelerator_type(req).accelerator_configs[0] self.vmcount = self.__calc_vm_from_topology(self.ac.topology) @property def nodeset(self): return self._nodeset @property def preserve_tpu(self): return self._nodeset.preserve_tpu @property def node_type(self): return self._nodeset.node_type @property def tf_version(self): return self._nodeset.tf_version @property def enable_public_ip(self): return self._nodeset.enable_public_ip @property def preemptible(self): return self._nodeset.preemptible @property def reserved(self): return self._nodeset.reserved @property def service_account(self): return self._nodeset.service_account @property def zone(self): return self._nodeset.zone def check_node_type(self): if self.node_type is None: return False try: request = tpu.GetAcceleratorTypeRequest( name=f"{self._parent}/acceleratorTypes/{self.node_type}" ) return self._client.get_accelerator_type(request=request) is not None except Exception: return False def check_tf_version(self): try: request = tpu.GetRuntimeVersionRequest( name=f"{self._parent}/runtimeVersions/{self.tf_version}" ) return self._client.get_runtime_version(request=request) is not None except Exception: return False def __calc_vm_from_topology(self, topology): topo = topology.split("x") tot = 1 for num in topo: tot = tot * int(num) return tot // self.TPUS_PER_VM def __check_resp(self, response, op_name): des_state = self.__expected_states.get(op_name) # If the state is not in the table just print the response if des_state is None: return False if response.__class__.__name__ != "Node": # If the response is not a node fail return False if response.state == des_state: return True return False def list_nodes(self): try: request = 
tpu.ListNodesRequest(parent=self._parent) res = self._client.list_nodes(request=request) except gExceptions.NotFound: res = None return res def list_node_names(self): return [node.name.split("/")[-1] for node in self.list_nodes()] def start_node(self, nodename): request = tpu.StartNodeRequest(name=f"{self._parent}/nodes/{nodename}") resp = self._client.start_node(request=request).result() return self.__check_resp(resp, "start") def stop_node(self, nodename): request = tpu.StopNodeRequest(name=f"{self._parent}/nodes/{nodename}") resp = self._client.stop_node(request=request).result() return self.__check_resp(resp, "stop") def get_node(self, nodename): try: request = tpu.GetNodeRequest(name=f"{self._parent}/nodes/{nodename}") res = self._client.get_node(request=request) except gExceptions.NotFound: res = None return res def _register_node(self, nodename, ip_addr): dns_name = socket.getnameinfo((ip_addr, 0), 0)[0] run( f"{lkp.scontrol} update nodename={nodename} nodeaddr={ip_addr} nodehostname={dns_name}" ) def create_node(self, nodename): if self.vmcount > 1 and not isinstance(nodename, list): log.error( f"Tried to create a {self.vmcount} node TPU on nodeset {self._nodeset.nodeset_name} but only received one nodename {nodename}" ) return False if self.vmcount > 1 and ( isinstance(nodename, list) and len(nodename) != self.vmcount ): log.error( f"Expected to receive a list of {self.vmcount} nodenames for TPU node creation in nodeset {self._nodeset.nodeset_name}, but received this list {nodename}" ) return False node = tpu.Node() node.accelerator_config = self.ac node.runtime_version = f"tpu-vm-tf-{self.tf_version}" startup_script = """ #!/bin/bash echo "startup script not found > /var/log/startup_error.log" """ with open( Path(cfg.slurm_scripts_dir or dirs.scripts) / "startup.sh", "r" ) as script: startup_script = script.read() if isinstance(nodename, list): node_id = nodename[0] slurm_names = [] wid = 0 for node_wid in nodename: 
slurm_names.append(f"WORKER_{wid}:{node_wid}") wid += 1 else: node_id = nodename slurm_names = [f"WORKER_0:{nodename}"] node.metadata = { "slurm_docker_image": self.nodeset.docker_image, "startup-script": startup_script, "slurm_instance_role": "compute", "slurm_cluster_name": lkp.cfg.slurm_cluster_name, "slurm_bucket_path": lkp.cfg.bucket_path, "slurm_names": ";".join(slurm_names), "universe_domain": universe_domain(), } node.tags = [lkp.cfg.slurm_cluster_name] if self.nodeset.service_account: node.service_account.email = self.nodeset.service_account.email node.service_account.scope = self.nodeset.service_account.scopes node.scheduling_config.preemptible = self.preemptible node.scheduling_config.reserved = self.reserved node.network_config.subnetwork = self.nodeset.subnetwork node.network_config.enable_external_ips = self.enable_public_ip if self.data_disks: node.data_disks = self.data_disks request = tpu.CreateNodeRequest(parent=self._parent, node=node, node_id=node_id) resp = self._client.create_node(request=request).result() if not self.__check_resp(resp, "create"): return False if isinstance(nodename, list): for node_id, net_endpoint in zip(nodename, resp.network_endpoints): self._register_node(node_id, net_endpoint.ip_address) else: ip_add = resp.network_endpoints[0].ip_address self._register_node(nodename, ip_add) return True def delete_node(self, nodename): request = tpu.DeleteNodeRequest(name=f"{self._parent}/nodes/{nodename}") try: resp = self._client.delete_node(request=request).result() if resp: return self.get_node(nodename=nodename) is None return False except gExceptions.NotFound: # log only error if vmcount is 1 as for other tpu vm count, this could be "phantom" nodes if self.vmcount == 1: log.error(f"Tpu single node {nodename} not found") else: # for the TPU nodes that consist in more than one vm, only the first node of the TPU a.k.a. 
the master node will # exist as real TPU nodes, so the other ones are expected to not be found, check the hostname of the node that has # not been found, and if it ends in 0, it means that is the master node and it should have been found, and in consequence # log an error nodehostname = yaml.safe_load( run(f"{lkp.scontrol} --yaml show node {nodename}").stdout.rstrip() )["nodes"][0]["hostname"] if nodehostname.split("-")[-1] == "0": log.error(f"TPU master node {nodename} not found") else: log.info(f"Deleted TPU 'phantom' node {nodename}") # If the node is not found it is tecnichally deleted, so return success. return True class Lookup: """Wrapper class for cached data access""" def __init__(self, cfg=None): self._cfg = cfg or NSDict() self.template_cache_path = Path(__file__).parent / "template_info.cache" @property def cfg(self): return self._cfg @property def project(self): return self.cfg.project or authentication_project() @property def control_addr(self): return self.cfg.slurm_control_addr @property def control_host(self): return self.cfg.slurm_control_host @cached_property def control_host_addr(self): return host_lookup(self.cfg.slurm_control_host) @property def control_host_port(self): return self.cfg.slurm_control_host_port @property def endpoint_versions(self): return self.cfg.endpoint_versions @property def scontrol(self): return Path(self.cfg.slurm_bin_dir if cfg else "") / "scontrol" @cached_property def instance_role(self): return instance_metadata("attributes/slurm_instance_role") @cached_property def instance_role_safe(self): try: role = self.instance_role except Exception as e: log.error(e) role = None return role @cached_property def compute(self): # TODO evaluate when we need to use google_app_cred_path if self.cfg.google_app_cred_path: os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = self.cfg.google_app_cred_path return compute_service() @cached_property def hostname(self): return socket.gethostname() @cached_property def hostname_fqdn(self): return 
socket.getfqdn() @cached_property def zone(self): return instance_metadata("zone") node_desc_regex = re.compile( r"^(?P<prefix>(?P<cluster>[^\s\-]+)-(?P<nodeset>\S+))-(?P<node>(?P<suffix>\w+)|(?P<range>\[[\d,-]+\]))$" ) @lru_cache(maxsize=None) def _node_desc(self, node_name): """Get parts from node name""" if not node_name: node_name = self.hostname # workaround below is for VMs whose hostname is FQDN node_name_short = node_name.split(".")[0] m = self.node_desc_regex.match(node_name_short) if not m: raise Exception(f"node name {node_name} is not valid") return m.groupdict() def node_prefix(self, node_name=None): return self._node_desc(node_name)["prefix"] def node_nodeset_name(self, node_name=None): return self._node_desc(node_name)["nodeset"] def node_nodeset(self, node_name=None): nodeset_name = self.node_nodeset_name(node_name) ns = self.cfg.nodeset.get(nodeset_name) if ns: return ns return self.cfg.nodeset_tpu.get(nodeset_name) def node_is_tpu(self, node_name=None): nodeset_name = self.node_nodeset_name(node_name) return self.cfg.nodeset_tpu.get(nodeset_name) is not None def node_is_dyn(self, node_name=None) -> bool: nodeset = self.node_nodeset_name(node_name) return self.cfg.nodeset_dyn.get(nodeset) is not None def chunk_tpu_nodes(self, tpu_nodes): model = tpu_nodes[0] tpu = TPU(self.node_nodeset(model)) return chunked(tpu_nodes, n=tpu.vmcount) def node_template(self, node_name=None): return self.node_nodeset(node_name).instance_template def node_template_info(self, node_name=None): return self.template_info(self.node_template(node_name)) def node_region(self, node_name=None): nodeset = self.node_nodeset(node_name) return parse_self_link(nodeset.subnetwork).region def nodeset_prefix(self, nodeset_name): return f"{self.cfg.slurm_cluster_name}-{nodeset_name}" def nodelist_range(self, nodeset_name: str, start: int, count: int) -> str: assert 0 <= start and 0 < count pref = self.nodeset_prefix(nodeset_name) if count == 1: return f"{pref}-{start}" return 
f"{pref}-[{start}-{start + count - 1}]" def static_dynamic_sizes(self, nodeset: object) -> int: return (nodeset.node_count_static or 0, nodeset.node_count_dynamic_max or 0) def nodelist(self, nodeset) -> str: cnt = sum(self.static_dynamic_sizes(nodeset)) if cnt == 0: return "" return self.nodelist_range(nodeset.nodeset_name, 0, cnt) def nodenames(self, nodeset) -> Tuple[Iterable[str], Iterable[str]]: pref = self.nodeset_prefix(nodeset.nodeset_name) s_count, d_count = self.static_dynamic_sizes(nodeset) return ( (f"{pref}-{i}" for i in range(s_count)), (f"{pref}-{i}" for i in range(s_count, s_count + d_count)), ) def power_managed_nodesets(self) -> Iterable[object]: return chain(self.cfg.nodeset.values(), self.cfg.nodeset_tpu.values()) def is_power_managed_node(self, node_name: str) -> bool: try: ns = self.node_nodeset(node_name) if ns is None: return False idx = int(self._node_desc(node_name)["suffix"]) return idx < sum(self.static_dynamic_sizes(ns)) except Exception: return False def is_static_node(self, node_name: str) -> bool: if not self.is_power_managed_node(node_name): return False idx = int(self._node_desc(node_name)["suffix"]) return idx < self.node_nodeset(node_name).node_count_static @lru_cache(maxsize=None) def slurm_nodes(self): StateTuple = namedtuple("StateTuple", "base,flags") def make_node_tuple(node_line): """turn node,state line to (node, StateTuple(state))""" # state flags include: CLOUD, COMPLETING, DRAIN, FAIL, POWERED_DOWN, # POWERING_DOWN node, fullstate = node_line.split(",") state = fullstate.split("+") state_tuple = StateTuple(state[0], set(state[1:])) return (node, state_tuple) cmd = ( f"{self.scontrol} show nodes | " r"grep -oP '^NodeName=\K(\S+)|\s+State=\K(\S+)' | " r"paste -sd',\n'" ) node_lines = run(cmd, shell=True).stdout.rstrip().splitlines() nodes = { node: state for node, state in map(make_node_tuple, node_lines) if "CLOUD" in state.flags or "DYNAMIC_NORM" in state.flags } return nodes def slurm_node(self, nodename): return 
self.slurm_nodes().get(nodename) @lru_cache(maxsize=1) def instances(self, project=None, slurm_cluster_name=None): slurm_cluster_name = slurm_cluster_name or self.cfg.slurm_cluster_name project = project or self.project instance_information_fields = [ "advancedMachineFeatures", "cpuPlatform", "creationTimestamp", "disks", "disks", "fingerprint", "guestAccelerators", "hostname", "id", "kind", "labelFingerprint", "labels", "lastStartTimestamp", "lastStopTimestamp", "lastSuspendedTimestamp", "machineType", "metadata", "name", "networkInterfaces", "resourceStatus", "scheduling", "selfLink", "serviceAccounts", "shieldedInstanceConfig", "shieldedInstanceIntegrityPolicy", "sourceMachineImage", "status", "statusMessage", "tags", "zone", # "deletionProtection", # "startRestricted", ] if lkp.cfg.enable_slurm_gcp_plugins: slurm_gcp_plugins.register_instance_information_fields( lkp=lkp, project=project, slurm_cluster_name=slurm_cluster_name, instance_information_fields=instance_information_fields, ) instance_information_fields = sorted(set(instance_information_fields)) instance_fields = ",".join(instance_information_fields) fields = f"items.zones.instances({instance_fields}),nextPageToken" flt = f"labels.slurm_cluster_name={slurm_cluster_name} AND name:{slurm_cluster_name}-*" act = self.compute.instances() op = act.aggregatedList(project=project, fields=fields, filter=flt) def properties(inst): """change instance properties to a preferred format""" inst["zone"] = trim_self_link(inst["zone"]) inst["machineType"] = trim_self_link(inst["machineType"]) # metadata is fetched as a dict of dicts like: # {'key': key, 'value': value}, kinda silly metadata = {i["key"]: i["value"] for i in inst["metadata"].get("items", [])} if "slurm_instance_role" not in metadata: return None inst["role"] = metadata["slurm_instance_role"] inst["metadata"] = metadata # del inst["metadata"] # no need to store all the metadata return NSDict(inst) instances = {} while op is not None: result = 
ensure_execute(op) instance_iter = ( (inst["name"], properties(inst)) for inst in chain.from_iterable( m["instances"] for m in result.get("items", {}).values() ) ) instances.update( {name: props for name, props in instance_iter if props is not None} ) op = act.aggregatedList_next(op, result) return instances def instance(self, instance_name, project=None, slurm_cluster_name=None): instances = self.instances( project=project, slurm_cluster_name=slurm_cluster_name ) return instances.get(instance_name) @lru_cache() def reservation(self, name: str, zone: str) -> object: """See https://cloud.google.com/compute/docs/reference/rest/v1/reservations""" try: _, project, _, short_name = name.split("/") except ValueError: raise ValueError( f"Invalid reservation name: '{name}', expected format is 'projects/PROJECT/reservations/NAME'" ) return ( self.compute.reservations() .get(project=project, zone=zone, reservation=short_name) .execute() ) @lru_cache(maxsize=1) def machine_types(self, project=None): project = project or self.project field_names = "name,zone,guestCpus,memoryMb,accelerators" fields = f"items.zones.machineTypes({field_names}),nextPageToken" machines = defaultdict(dict) act = self.compute.machineTypes() op = act.aggregatedList(project=project, fields=fields) while op is not None: result = ensure_execute(op) machine_iter = chain.from_iterable( m["machineTypes"] for m in result["items"].values() if "machineTypes" in m ) for machine in machine_iter: name = machine["name"] zone = machine["zone"] machines[name][zone] = machine op = act.aggregatedList_next(op, result) return machines def machine_type(self, machine_type, project=None, zone=None): """ """ custom_patt = re.compile( r"((?P<family>\w+)-)?custom-(?P<cpus>\d+)-(?P<mem>\d+)" ) custom_match = custom_patt.match(machine_type) if zone: project = project or self.project machine_info = ensure_execute( self.compute.machineTypes().get( project=project, zone=zone, machineType=machine_type ) ) elif custom_match is not 
None: groups = custom_match.groupdict() cpus, mem = (groups[k] for k in ["cpus", "mem"]) machine_info = { "guestCpus": int(cpus), "memoryMb": int(mem), } else: machines = self.machine_types(project=project) machine_info = next(iter(machines[machine_type].values()), None) if machine_info is None: raise Exception(f"machine type {machine_type} not found") return NSDict(machine_info) def template_machine_conf(self, template_link, project=None, zone=None): template = self.template_info(template_link) if not template.machineType: temp_name = trim_self_link(template_link) raise Exception(f"instance template {temp_name} has no machine type") template.machine_info = self.machine_type(template.machineType, zone=zone) machine = template.machine_info machine_conf = NSDict() machine_conf.boards = 1 # No information, assume 1 machine_conf.sockets = machine_type_sockets(template) # the value below for SocketsPerBoard must be type int machine_conf.sockets_per_board = machine_conf.sockets // machine_conf.boards machine_conf.threads_per_core = 1 _div = 2 if getThreadsPerCore(template) == 1 else 1 machine_conf.cpus = ( int(machine.guestCpus / _div) if isSmt(template) else machine.guestCpus ) machine_conf.cores_per_socket = int(machine_conf.cpus / machine_conf.sockets) # Because the actual memory on the host will be different than # what is configured (e.g. kernel will take it). From # experiments, about 16 MB per GB are used (plus about 400 MB # buffer for the first couple of GB's. Using 30 MB to be safe. 
gb = machine.memoryMb // 1024 machine_conf.memory = machine.memoryMb - (400 + (30 * gb)) return machine_conf @contextmanager def template_cache(self, writeback=False): flag = "c" if writeback else "r" err = None for wait in backoff_delay(0.125, timeout=60, count=20): try: cache = shelve.open( str(self.template_cache_path), flag=flag, writeback=writeback ) break except OSError as e: err = e log.debug(f"Failed to access template info cache: {e}") sleep(wait) continue else: # reached max_count of waits raise Exception(f"Failed to access cache file. latest error: {err}") try: yield cache finally: cache.close() @lru_cache(maxsize=None) def template_info(self, template_link, project=None): project = project or self.project template_name = trim_self_link(template_link) # split read and write access to minimize write-lock. This might be a # bit slower? TODO measure if self.template_cache_path.exists(): with self.template_cache() as cache: if template_name in cache: return NSDict(cache[template_name]) template = ensure_execute( self.compute.instanceTemplates().get( project=project, instanceTemplate=template_name ) ).get("properties") template = NSDict(template) # name and link are not in properties, so stick them in template.name = template_name template.link = template_link # TODO delete metadata to reduce memory footprint? 
# del template.metadata # translate gpus into an easier-to-read format machine_info = self.machine_type(template.machineType, project=project) if machine_info.accelerators: template.gpu_type = machine_info.accelerators[0].guestAcceleratorType template.gpu_count = machine_info.accelerators[0].guestAcceleratorCount elif template.guestAccelerators: template.gpu_type = template.guestAccelerators[0].acceleratorType template.gpu_count = template.guestAccelerators[0].acceleratorCount else: template.gpu_type = None template.gpu_count = 0 # keep write access open for minimum time with self.template_cache(writeback=True) as cache: cache[template_name] = template.to_dict() # cache should be owned by slurm chown_slurm(self.template_cache_path) return template def nodeset_map(self, hostnames: list): """Convert a list of nodes into a map of nodeset_name to hostnames""" nodeset_map = collections.defaultdict(list) for node in hostnames: nodeset_map[self.node_nodeset_name(node)].append(node) return nodeset_map # Define late globals lkp = Lookup() cfg = load_config_file(CONFIG_FILE) if not cfg: try: cfg = fetch_config_yaml() except Exception as e: log.warning(f"config not found in bucket: {e}") if cfg: save_config(cfg, CONFIG_FILE) lkp = Lookup(cfg) # Needs to be run after the lookup is complete to get endpoint versions compute = compute_service() if __name__ == "__main__": parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter ) parser.add_argument( "--partitions", "-p", help="The partition(s) to retrieve the TPU vmcount value for.", ) args = parser.parse_args() if args.partitions: # useful exit code # partition does not exists in config.yaml, thus do not exist in slurm PART_INVALID = -1 # in the same partition there are nodesets with different vmcounts DIFF_VMCOUNTS_SAME_PART = -2 # partition is a list of partitions in which at least two of them have different vmcount DIFF_PART_DIFFERENT_VMCOUNTS = -3 vmcounts = [] # valid 
equals to 0 means that we are ok, otherwise it will be set to one of the previously defined exit codes valid = 0 for part in args.partitions.split(","): if part not in lkp.cfg.partitions: valid = PART_INVALID break else: if part_is_tpu(part): vmcount = get_vmcount_of_tpu_part(part) if vmcount == -1: valid = DIFF_VMCOUNTS_SAME_PART break vmcounts.append(vmcount) else: vmcounts.append(0) # this means that there are different vmcounts for these partitions if valid == 0 and len(set(vmcounts)) != 1: valid = DIFF_PART_DIFFERENT_VMCOUNTS if valid != 0: print(f"VMCOUNT:{valid}") else: print(f"VMCOUNT:{vmcounts[0]}")